1#!/usr/bin/env python3
2"""
31D Finite Element Method (FEM)
4===============================
5
6This module demonstrates the Finite Element Method for solving 1D boundary
7value problems using piecewise linear (hat) basis functions.
8
9Problem:
10 -u''(x) = f(x), x in [0, 1]
11 u(0) = 0, u(1) = 0 (Dirichlet boundary conditions)
12
13Method:
14 - Discretize domain into N elements
15 - Use piecewise linear hat basis functions
16 - Assemble global stiffness matrix and load vector
17 - Solve linear system Au = b
18 - Compare with analytical solution
19
20Key Concepts:
21 - Weak formulation and Galerkin method
22 - Element-wise assembly
23 - Hat basis functions
24 - Numerical integration (quadrature)
25
26Author: Educational example for Numerical Simulation
27License: MIT
28"""
29
30import numpy as np
31import matplotlib.pyplot as plt
32from scipy.sparse import lil_matrix, csr_matrix
33from scipy.sparse.linalg import spsolve
34
35
class FEM1D:
    """
    1D Finite Element Method solver for -u'' = f with Dirichlet BCs.

    The interval [a, b] is split into N equal elements; each of the N + 1
    nodes carries a piecewise linear hat basis function.

    Attributes:
        a (float): Left boundary
        b (float): Right boundary
        N (int): Number of elements
        nodes (ndarray): Node coordinates, shape (N + 1,)
        h (float): Element size
        n_nodes (int): Number of nodes (N + 1), boundary nodes included
    """

    # 2-point Gauss-Legendre rule on the reference element [-1, 1];
    # integrates polynomials up to degree 3 exactly.
    _GAUSS_POINTS = np.array([-1.0, 1.0]) / np.sqrt(3.0)
    _GAUSS_WEIGHTS = np.array([1.0, 1.0])

    def __init__(self, a, b, N):
        """
        Initialize the FEM solver.

        Args:
            a (float): Left boundary
            b (float): Right boundary (must satisfy b > a)
            N (int): Number of elements (must be >= 1)

        Raises:
            ValueError: If N < 1 or b <= a.
        """
        if N < 1:
            raise ValueError(f"N must be at least 1, got {N}")
        if not b > a:
            raise ValueError(f"require b > a, got a={a}, b={b}")

        self.a = a
        self.b = b
        self.N = N

        # Generate uniform mesh
        self.nodes = np.linspace(a, b, N + 1)
        self.h = (b - a) / N

        # Number of nodes (unknowns, including the two boundary nodes)
        self.n_nodes = N + 1

    @staticmethod
    def _as_float_array(x):
        """Return (x as an at-least-1-D float array, whether x was a scalar)."""
        # Converting to float avoids truncating results when x is integer-typed.
        return np.atleast_1d(np.asarray(x, dtype=float)), np.ndim(x) == 0

    def _check_node_index(self, i):
        """Raise IndexError unless 0 <= i <= N."""
        if not 0 <= i <= self.N:
            raise IndexError(f"node index {i} out of range [0, {self.N}]")

    def _element_mask(self, x, e):
        """Boolean mask of points in element e, as [x_e, x_{e+1}).

        The last element is closed on the right so that x = b belongs to it.
        """
        x_l, x_r = self.nodes[e], self.nodes[e + 1]
        if e == self.N - 1:
            return (x >= x_l) & (x <= x_r)
        return (x >= x_l) & (x < x_r)

    def hat_function(self, x, i):
        """
        Piecewise linear hat basis function φ_i(x).

        φ_i(x) = 1 at node i, 0 at all other nodes, linear in between,
        and 0 outside [x_{i-1}, x_{i+1}].

        Args:
            x (float or array_like): Evaluation point(s)
            i (int): Node index, 0 <= i <= N

        Returns:
            float or ndarray: Value of φ_i(x). A scalar for scalar input;
            otherwise a float array with the same shape as x.

        Raises:
            IndexError: If i is not a valid node index.
        """
        self._check_node_index(i)
        x_arr, is_scalar = self._as_float_array(x)
        phi = np.zeros_like(x_arr)

        # Left support: [x_{i-1}, x_i], rising from 0 to 1
        if i > 0:
            x_l, x_i = self.nodes[i - 1], self.nodes[i]
            mask = (x_arr >= x_l) & (x_arr <= x_i)
            phi[mask] = (x_arr[mask] - x_l) / (x_i - x_l)

        # Right support: [x_i, x_{i+1}], falling from 1 to 0
        if i < self.N:
            x_i, x_r = self.nodes[i], self.nodes[i + 1]
            mask = (x_arr >= x_i) & (x_arr <= x_r)
            phi[mask] = (x_r - x_arr[mask]) / (x_r - x_i)

        return phi[0] if is_scalar else phi

    def hat_derivative(self, x, i):
        """
        Derivative of hat basis function dφ_i/dx.

        φ_i'(x) = 1/h on [x_{i-1}, x_i), -1/h on [x_i, x_{i+1}), 0 elsewhere.
        At an interior node, the value of the element to the right is used.
        The right boundary x = b is assigned to the last element.

        Args:
            x (float or array_like): Evaluation point(s)
            i (int): Node index, 0 <= i <= N

        Returns:
            float or ndarray: Value of dφ_i/dx. A scalar for scalar input;
            otherwise a float array with the same shape as x.

        Raises:
            IndexError: If i is not a valid node index.
        """
        self._check_node_index(i)
        x_arr, is_scalar = self._as_float_array(x)
        dphi = np.zeros_like(x_arr)

        # Left support (element i-1): slope +1/h
        if i > 0:
            h_left = self.nodes[i] - self.nodes[i - 1]
            dphi[self._element_mask(x_arr, i - 1)] = 1.0 / h_left

        # Right support (element i): slope -1/h
        if i < self.N:
            h_right = self.nodes[i + 1] - self.nodes[i]
            dphi[self._element_mask(x_arr, i)] = -1.0 / h_right

        return dphi[0] if is_scalar else dphi

    def assemble_element_stiffness(self, e):
        """
        Assemble local stiffness matrix for element e.

        For -u'', the element stiffness matrix is:
            K_e[i,j] = ∫_{x_e}^{x_{e+1}} φ_i' φ_j' dx

        For linear elements the integrand is constant, so the integral is
        exact: K_e = (1/h_e) * [[1, -1], [-1, 1]].

        Args:
            e (int): Element index (0 to N-1)

        Returns:
            ndarray: 2x2 local stiffness matrix
        """
        h_e = self.nodes[e + 1] - self.nodes[e]
        return np.array([[1.0, -1.0],
                         [-1.0, 1.0]]) / h_e

    def assemble_element_load(self, e, f_func):
        """
        Assemble local load vector for element e.

        F_e[i] = ∫_{x_e}^{x_{e+1}} f(x) φ_i(x) dx

        Uses 2-point Gauss quadrature, which is exact when f is linear on
        the element. f_func is called once per Gauss point with a scalar.

        Args:
            e (int): Element index (0 to N-1)
            f_func (callable): Right-hand side function f(x)

        Returns:
            ndarray: Local load vector of length 2
        """
        x_left = self.nodes[e]
        x_right = self.nodes[e + 1]
        h_e = x_right - x_left

        # Map Gauss points from [-1, 1] to [x_left, x_right]; dx = (h_e/2) dξ
        x_gauss = 0.5 * h_e * self._GAUSS_POINTS + 0.5 * (x_right + x_left)
        jacobian = 0.5 * h_e

        F_local = np.zeros(2)
        for xq, wq in zip(x_gauss, self._GAUSS_WEIGHTS):
            # Evaluate f once per point and reuse it for both shape functions
            weighted_f = f_func(xq) * wq * jacobian
            # Local node 0 sits at x_left, local node 1 at x_right
            F_local[0] += weighted_f * (x_right - xq) / h_e
            F_local[1] += weighted_f * (xq - x_left) / h_e

        return F_local

    def assemble_global_system(self, f_func):
        """
        Assemble global stiffness matrix K and load vector F.

        Args:
            f_func (callable): Right-hand side function f(x)

        Returns:
            tuple: (K, F) with K a CSR sparse matrix of shape
            (n_nodes, n_nodes) and F a dense vector of length n_nodes
        """
        # LIL format is efficient for incremental insertion
        K = lil_matrix((self.n_nodes, self.n_nodes))
        F = np.zeros(self.n_nodes)

        for e in range(self.N):
            K_local = self.assemble_element_stiffness(e)
            F_local = self.assemble_element_load(e, f_func)

            # Element e couples global nodes e and e + 1
            dofs = (e, e + 1)
            for p, gi in enumerate(dofs):
                F[gi] += F_local[p]
                for q, gj in enumerate(dofs):
                    K[gi, gj] += K_local[p, q]

        # CSR is the efficient format for the sparse solver
        return csr_matrix(K), F

    def apply_boundary_conditions(self, K, F, u_left=0.0, u_right=0.0):
        """
        Apply Dirichlet boundary conditions u(a) = u_left, u(b) = u_right.

        The boundary rows are replaced by identity rows and the matching
        right-hand side entries by the BC values. The columns are left
        untouched, so the modified matrix is no longer symmetric. The
        coupling terms of the neighbouring equations still see the
        prescribed boundary values.

        Args:
            K (sparse matrix): Global stiffness matrix
            F (ndarray): Global load vector
            u_left (float): BC at left boundary
            u_right (float): BC at right boundary

        Returns:
            tuple: (K_bc, F_bc) modified system (K_bc in CSR format);
            the inputs are not modified
        """
        K_bc = K.tolil()  # LIL supports cheap row modification
        F_bc = F.copy()

        for node, value in ((0, u_left), (self.N, u_right)):
            K_bc[node, :] = 0
            K_bc[node, node] = 1
            F_bc[node] = value

        return K_bc.tocsr(), F_bc

    def solve(self, f_func, u_left=0.0, u_right=0.0):
        """
        Solve the BVP -u'' = f with Dirichlet BCs.

        Args:
            f_func (callable): Right-hand side function f(x)
            u_left (float): BC at left boundary
            u_right (float): BC at right boundary

        Returns:
            ndarray: Solution vector at the nodes, length n_nodes
        """
        K, F = self.assemble_global_system(f_func)
        K_bc, F_bc = self.apply_boundary_conditions(K, F, u_left, u_right)
        return spsolve(K_bc, F_bc)

    def evaluate_solution(self, u_nodes, x_eval):
        """
        Evaluate FEM solution at arbitrary points using basis functions.

        Computes u_h(x) = sum_i u_i φ_i(x), which is 0 outside [a, b].

        Args:
            u_nodes (ndarray): Solution coefficients at nodes
            x_eval (float or array_like): Evaluation points (any numeric
                dtype, including integers)

        Returns:
            ndarray: Float solution values with the same shape as x_eval
        """
        # Force a float result so integer x_eval does not break "+=" below
        x_arr = np.asarray(x_eval, dtype=float)
        u_eval = np.zeros(x_arr.shape)

        for i in range(self.n_nodes):
            u_eval += u_nodes[i] * self.hat_function(x_arr, i)

        return u_eval
297
298
def example_1():
    """
    Example 1: -u'' = 2, u(0) = 0, u(1) = 0

    Analytical solution: u(x) = x(1 - x)

    Runs a mesh-refinement study and plots the solutions and the convergence
    rate. The error is the relative L2 error of the piecewise-linear FEM
    solution, approximated on a dense sample of [0, 1]. A nodal error is not
    used for the study: in 1D, linear FEM with an exactly integrated load
    (here f is constant) reproduces the exact solution at the nodes. A nodal
    error would therefore only show round-off and cannot exhibit O(h^2).
    """
    print("=" * 60)
    print("Example 1: -u'' = 2 with homogeneous Dirichlet BCs")
    print("=" * 60)

    # Right-hand side function
    def f(x):
        return 2.0

    # Analytical solution
    def u_exact(x):
        return x * (1 - x)

    # Solve with different mesh refinements
    N_values = [4, 8, 16, 32]
    errors = []

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Fine grid for plotting exact solution
    x_fine = np.linspace(0, 1, 200)
    ax1.plot(x_fine, u_exact(x_fine), 'k-', linewidth=2, label='Exact')

    # Dense sample used to approximate the continuous L2 norm of the error
    x_err = np.linspace(0, 1, 2001)
    u_exact_err = u_exact(x_err)

    for N in N_values:
        # Initialize and solve
        fem = FEM1D(a=0, b=1, N=N)
        u_fem = fem.solve(f)

        # Evaluate at fine grid for plotting
        u_fem_fine = fem.evaluate_solution(u_fem, x_fine)

        # Plot solution
        ax1.plot(x_fine, u_fem_fine, '--', label=f'FEM N={N}', alpha=0.7)
        ax1.plot(fem.nodes, u_fem, 'o', markersize=4)

        # Relative L2 error over the whole domain (converges like O(h^2))
        u_fem_err = fem.evaluate_solution(u_fem, x_err)
        error = np.linalg.norm(u_fem_err - u_exact_err) / np.linalg.norm(u_exact_err)
        errors.append(error)

        # Nodal error, shown for contrast: round-off level for this problem
        nodal_error = np.max(np.abs(u_fem - u_exact(fem.nodes)))
        print(f"N = {N:3d}: Relative L2 error = {error:.6e} "
              f"(max nodal error = {nodal_error:.2e})")

    ax1.set_xlabel('x')
    ax1.set_ylabel('u(x)')
    ax1.set_title('FEM Solution vs Exact Solution')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Convergence plot
    ax2.loglog(N_values, errors, 'bo-', linewidth=2, markersize=8, label='FEM Error')
    ax2.loglog(N_values, [errors[0] * (N_values[0] / N) ** 2 for N in N_values],
               'r--', label='$O(h^2)$ reference')
    ax2.set_xlabel('Number of elements N')
    ax2.set_ylabel('Relative L2 Error')
    ax2.set_title('Convergence Rate')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('/tmp/fem_example1.png', dpi=150)
    print("Saved plot to /tmp/fem_example1.png")
    plt.show()
365
366
def example_2():
    """
    Example 2: -u'' = π^2 sin(πx), u(0) = 0, u(1) = 0

    Analytical solution: u(x) = sin(πx)

    Solves on one mesh and plots the solution and the pointwise error.
    Two error measures are printed:
    - the relative L2 error of the piecewise-linear FEM solution over
      [0, 1], approximated on a dense sample;
    - the relative error at the mesh nodes only, which in 1D is much smaller
      because nodal values are exact up to the load-quadrature error.
    """
    print("\n" + "=" * 60)
    print("Example 2: -u'' = π² sin(πx) with homogeneous Dirichlet BCs")
    print("=" * 60)

    # Right-hand side function
    def f(x):
        return np.pi**2 * np.sin(np.pi * x)

    # Analytical solution
    def u_exact(x):
        return np.sin(np.pi * x)

    # Solve
    N = 32
    fem = FEM1D(a=0, b=1, N=N)
    u_fem = fem.solve(f)

    # Evaluation
    x_fine = np.linspace(0, 1, 200)
    u_fem_fine = fem.evaluate_solution(u_fem, x_fine)
    u_exact_fine = u_exact(x_fine)

    # Plot
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    # Solution comparison
    ax1.plot(x_fine, u_exact_fine, 'k-', linewidth=2, label='Exact')
    ax1.plot(x_fine, u_fem_fine, 'b--', linewidth=2, label=f'FEM (N={N})')
    ax1.plot(fem.nodes, u_fem, 'ro', markersize=6, label='FEM nodes')
    ax1.set_xlabel('x')
    ax1.set_ylabel('u(x)')
    ax1.set_title('Solution: -u\'\' = π² sin(πx)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Error plot
    error_fine = np.abs(u_fem_fine - u_exact_fine)
    ax2.plot(x_fine, error_fine, 'r-', linewidth=2)
    ax2.set_xlabel('x')
    ax2.set_ylabel('|u_FEM - u_exact|')
    ax2.set_title('Pointwise Absolute Error')
    ax2.grid(True, alpha=0.3)
    ax2.set_yscale('log')

    # Relative L2 error over the whole domain, approximated on a dense grid
    x_err = np.linspace(0, 1, 2001)
    u_exact_err = u_exact(x_err)
    u_fem_err = fem.evaluate_solution(u_fem, x_err)
    l2_error = np.linalg.norm(u_fem_err - u_exact_err) / np.linalg.norm(u_exact_err)

    # Relative error at the nodes only (much smaller; see docstring)
    u_exact_nodes = u_exact(fem.nodes)
    nodal_error = np.linalg.norm(u_fem - u_exact_nodes) / np.linalg.norm(u_exact_nodes)

    print(f"N = {N}: Relative L2 error = {l2_error:.6e}")
    print(f"N = {N}: Relative nodal error = {nodal_error:.6e}")

    plt.tight_layout()
    plt.savefig('/tmp/fem_example2.png', dpi=150)
    print("Saved plot to /tmp/fem_example2.png")
    plt.show()
426
427
def example_3_basis_functions():
    """
    Plot every hat basis function of a 5-element mesh on [0, 1].

    Each curve φ_i is drawn along with a black marker at (x_i, 1.0), the
    node where it peaks. The figure is saved to
    /tmp/fem_basis_functions.png and then shown.
    """
    separator = "=" * 60
    print("\n" + separator)
    print("Example 3: Visualizing Hat Basis Functions")
    print(separator)

    N = 5
    mesh = FEM1D(a=0, b=1, N=N)
    xs = np.linspace(0, 1, 500)

    _, ax = plt.subplots(figsize=(12, 6))

    # One curve per node, plus a marker at the node where φ_i equals 1
    for idx, node_x in enumerate(mesh.nodes):
        ax.plot(xs, mesh.hat_function(xs, idx), linewidth=2, label=f'φ_{idx}')
        ax.plot(node_x, 1.0, 'ko', markersize=8)

    ax.set_xlabel('x')
    ax.set_ylabel('φ_i(x)')
    ax.set_title(f'Hat Basis Functions (N = {N} elements)')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    ax.set_ylim([-0.1, 1.2])

    plt.tight_layout()
    out_path = '/tmp/fem_basis_functions.png'
    plt.savefig(out_path, dpi=150)
    print("Saved plot to " + out_path)
    plt.show()
460
461
if __name__ == "__main__":
    rule = "=" * 60

    print("1D Finite Element Method")
    print(rule)
    print("This script demonstrates FEM for solving -u'' = f with Dirichlet BCs.")
    print()

    # Run the demos: basis-function picture first, then the two model problems
    for demo in (example_3_basis_functions, example_1, example_2):
        demo()

    takeaways = (
        " - FEM uses piecewise polynomial basis functions (hat functions)",
        " - Assembly is done element-by-element (local to global)",
        " - Convergence rate is O(h²) for linear elements",
        " - Sparse matrices enable efficient solution of large systems",
    )
    print("\n" + rule)
    print("Key Takeaways:")
    for takeaway in takeaways:
        print(takeaway)
    print(rule)