08_fem_1d.py

Download
python 480 lines 13.5 KB
  1#!/usr/bin/env python3
  2"""
  31D Finite Element Method (FEM)
  4===============================
  5
  6This module demonstrates the Finite Element Method for solving 1D boundary
  7value problems using piecewise linear (hat) basis functions.
  8
  9Problem:
 10    -u''(x) = f(x),  x in [0, 1]
 11    u(0) = 0,  u(1) = 0  (Dirichlet boundary conditions)
 12
 13Method:
 14    - Discretize domain into N elements
 15    - Use piecewise linear hat basis functions
 16    - Assemble global stiffness matrix and load vector
 17    - Solve linear system Au = b
 18    - Compare with analytical solution
 19
 20Key Concepts:
 21    - Weak formulation and Galerkin method
 22    - Element-wise assembly
 23    - Hat basis functions
 24    - Numerical integration (quadrature)
 25
 26Author: Educational example for Numerical Simulation
 27License: MIT
 28"""
 29
 30import numpy as np
 31import matplotlib.pyplot as plt
 32from scipy.sparse import lil_matrix, csr_matrix
 33from scipy.sparse.linalg import spsolve
 34
 35
 36class FEM1D:
 37    """
 38    1D Finite Element Method solver for -u'' = f with Dirichlet BCs.
 39
 40    Attributes:
 41        a (float): Left boundary
 42        b (float): Right boundary
 43        N (int): Number of elements
 44        nodes (ndarray): Node coordinates
 45        h (float): Element size
 46    """
 47
 48    def __init__(self, a, b, N):
 49        """
 50        Initialize the FEM solver.
 51
 52        Args:
 53            a (float): Left boundary
 54            b (float): Right boundary
 55            N (int): Number of elements
 56        """
 57        self.a = a
 58        self.b = b
 59        self.N = N
 60
 61        # Generate uniform mesh
 62        self.nodes = np.linspace(a, b, N + 1)
 63        self.h = (b - a) / N
 64
 65        # Number of nodes (unknowns)
 66        self.n_nodes = N + 1
 67
 68    def hat_function(self, x, i):
 69        """
 70        Piecewise linear hat basis function φ_i(x).
 71
 72        φ_i(x) = 1 at node i, 0 at all other nodes, linear in between.
 73
 74        Args:
 75            x (float or ndarray): Evaluation point(s)
 76            i (int): Node index
 77
 78        Returns:
 79            float or ndarray: Value of φ_i(x)
 80        """
 81        x = np.atleast_1d(x)
 82        phi = np.zeros_like(x)
 83
 84        # Left support: [x_{i-1}, x_i]
 85        if i > 0:
 86            mask = (x >= self.nodes[i-1]) & (x <= self.nodes[i])
 87            phi[mask] = (x[mask] - self.nodes[i-1]) / self.h
 88
 89        # Right support: [x_i, x_{i+1}]
 90        if i < self.N:
 91            mask = (x >= self.nodes[i]) & (x <= self.nodes[i+1])
 92            phi[mask] = (self.nodes[i+1] - x[mask]) / self.h
 93
 94        return phi if len(phi) > 1 else phi[0]
 95
 96    def hat_derivative(self, x, i):
 97        """
 98        Derivative of hat basis function dφ_i/dx.
 99
100        φ_i'(x) = 1/h on [x_{i-1}, x_i], -1/h on [x_i, x_{i+1}], 0 elsewhere.
101
102        Args:
103            x (float or ndarray): Evaluation point(s)
104            i (int): Node index
105
106        Returns:
107            float or ndarray: Value of dφ_i/dx
108        """
109        x = np.atleast_1d(x)
110        dphi = np.zeros_like(x)
111
112        # Left support
113        if i > 0:
114            mask = (x >= self.nodes[i-1]) & (x < self.nodes[i])
115            dphi[mask] = 1.0 / self.h
116
117        # Right support
118        if i < self.N:
119            mask = (x >= self.nodes[i]) & (x < self.nodes[i+1])
120            dphi[mask] = -1.0 / self.h
121
122        return dphi if len(dphi) > 1 else dphi[0]
123
124    def assemble_element_stiffness(self, e):
125        """
126        Assemble local stiffness matrix for element e.
127
128        For -u'', the element stiffness matrix is:
129            K_e[i,j] = ∫_{x_e}^{x_{e+1}} φ_i' φ_j' dx
130
131        For linear elements, this integral can be computed exactly.
132
133        Args:
134            e (int): Element index (0 to N-1)
135
136        Returns:
137            ndarray: 2x2 local stiffness matrix
138        """
139        # For linear hat functions on uniform mesh:
140        # K_local = (1/h) * [[1, -1], [-1, 1]]
141        K_local = (1.0 / self.h) * np.array([
142            [1.0, -1.0],
143            [-1.0, 1.0]
144        ])
145        return K_local
146
147    def assemble_element_load(self, e, f_func):
148        """
149        Assemble local load vector for element e.
150
151        F_e[i] = ∫_{x_e}^{x_{e+1}} f(x) φ_i(x) dx
152
153        Uses 2-point Gauss quadrature for integration.
154
155        Args:
156            e (int): Element index
157            f_func (callable): Right-hand side function f(x)
158
159        Returns:
160            ndarray: 2x1 local load vector
161        """
162        # Element boundaries
163        x_left = self.nodes[e]
164        x_right = self.nodes[e + 1]
165
166        # 2-point Gauss quadrature on reference element [-1, 1]
167        # Gauss points and weights
168        gauss_points = np.array([-1/np.sqrt(3), 1/np.sqrt(3)])
169        gauss_weights = np.array([1.0, 1.0])
170
171        # Map to physical element [x_left, x_right]
172        x_gauss = 0.5 * (x_right - x_left) * gauss_points + 0.5 * (x_right + x_left)
173        jacobian = 0.5 * (x_right - x_left)  # dx/dξ
174
175        # Local load vector
176        F_local = np.zeros(2)
177
178        # Integrate using quadrature
179        for i in range(2):  # Two local nodes
180            for q, (xq, wq) in enumerate(zip(x_gauss, gauss_weights)):
181                # Hat function value at Gauss point
182                # For element e: node 0 is at x_left, node 1 is at x_right
183                if i == 0:
184                    phi_val = (x_right - xq) / self.h
185                else:
186                    phi_val = (xq - x_left) / self.h
187
188                F_local[i] += f_func(xq) * phi_val * wq * jacobian
189
190        return F_local
191
192    def assemble_global_system(self, f_func):
193        """
194        Assemble global stiffness matrix K and load vector F.
195
196        Args:
197            f_func (callable): Right-hand side function f(x)
198
199        Returns:
200            tuple: (K, F) global stiffness matrix and load vector
201        """
202        # Initialize global system (use sparse matrix)
203        K = lil_matrix((self.n_nodes, self.n_nodes))
204        F = np.zeros(self.n_nodes)
205
206        # Loop over elements
207        for e in range(self.N):
208            # Local stiffness and load
209            K_local = self.assemble_element_stiffness(e)
210            F_local = self.assemble_element_load(e, f_func)
211
212            # Global node indices for this element
213            global_indices = [e, e + 1]
214
215            # Add local contributions to global system
216            for i in range(2):
217                for j in range(2):
218                    K[global_indices[i], global_indices[j]] += K_local[i, j]
219                F[global_indices[i]] += F_local[i]
220
221        # Convert to CSR format for efficient solving
222        K = csr_matrix(K)
223
224        return K, F
225
226    def apply_boundary_conditions(self, K, F, u_left=0.0, u_right=0.0):
227        """
228        Apply Dirichlet boundary conditions u(a) = u_left, u(b) = u_right.
229
230        Modify the system to enforce BC by setting diagonal to 1 and RHS to BC value.
231
232        Args:
233            K (sparse matrix): Global stiffness matrix
234            F (ndarray): Global load vector
235            u_left (float): BC at left boundary
236            u_right (float): BC at right boundary
237
238        Returns:
239            tuple: (K_bc, F_bc) modified system
240        """
241        K_bc = K.tolil()  # Convert to lil for modification
242        F_bc = F.copy()
243
244        # Left boundary (node 0)
245        K_bc[0, :] = 0
246        K_bc[0, 0] = 1
247        F_bc[0] = u_left
248
249        # Right boundary (node N)
250        K_bc[self.N, :] = 0
251        K_bc[self.N, self.N] = 1
252        F_bc[self.N] = u_right
253
254        return K_bc.tocsr(), F_bc
255
256    def solve(self, f_func, u_left=0.0, u_right=0.0):
257        """
258        Solve the BVP -u'' = f with Dirichlet BCs.
259
260        Args:
261            f_func (callable): Right-hand side function f(x)
262            u_left (float): BC at left boundary
263            u_right (float): BC at right boundary
264
265        Returns:
266            ndarray: Solution vector at nodes
267        """
268        # Assemble global system
269        K, F = self.assemble_global_system(f_func)
270
271        # Apply boundary conditions
272        K_bc, F_bc = self.apply_boundary_conditions(K, F, u_left, u_right)
273
274        # Solve linear system
275        u = spsolve(K_bc, F_bc)
276
277        return u
278
279    def evaluate_solution(self, u_nodes, x_eval):
280        """
281        Evaluate FEM solution at arbitrary points using basis functions.
282
283        Args:
284            u_nodes (ndarray): Solution coefficients at nodes
285            x_eval (ndarray): Evaluation points
286
287        Returns:
288            ndarray: Solution values at x_eval
289        """
290        u_eval = np.zeros_like(x_eval)
291
292        # Sum over all basis functions
293        for i in range(self.n_nodes):
294            u_eval += u_nodes[i] * self.hat_function(x_eval, i)
295
296        return u_eval
297
298
def example_1():
    """
    Example 1: -u'' = 2, u(0) = 0, u(1) = 0

    Analytical solution: u(x) = x(1 - x)

    Solves the problem on a sequence of refined meshes, plots each FEM
    solution against the exact one, and plots the error against N.

    The error is measured on the fine evaluation grid rather than at the
    mesh nodes. For this problem the linear FEM solution matches the exact
    solution at the nodes up to round-off, so a nodal error shows only
    floating-point noise; the error between the nodes is what decays
    like O(h^2).
    """
    print("=" * 60)
    print("Example 1: -u'' = 2 with homogeneous Dirichlet BCs")
    print("=" * 60)

    # Right-hand side function
    def f(x):
        return 2.0

    # Analytical solution
    def u_exact(x):
        return x * (1 - x)

    # Solve with different mesh refinements
    N_values = [4, 8, 16, 32]
    errors = []

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Fine grid for plotting and for measuring the error between nodes
    x_fine = np.linspace(0, 1, 200)
    u_exact_fine = u_exact(x_fine)
    ax1.plot(x_fine, u_exact_fine, 'k-', linewidth=2, label='Exact')

    for N in N_values:
        # Initialize and solve
        fem = FEM1D(a=0, b=1, N=N)
        u_fem = fem.solve(f)

        # Evaluate at fine grid for plotting and error measurement
        u_fem_fine = fem.evaluate_solution(u_fem, x_fine)

        # Plot solution
        ax1.plot(x_fine, u_fem_fine, '--', label=f'FEM N={N}', alpha=0.7)
        ax1.plot(fem.nodes, u_fem, 'o', markersize=4)

        # Relative discrete L2 error sampled on the fine grid, which
        # captures the piecewise linear interpolation error inside elements
        error = (np.linalg.norm(u_fem_fine - u_exact_fine)
                 / np.linalg.norm(u_exact_fine))
        errors.append(error)
        print(f"N = {N:3d}: Relative L2 error = {error:.6e}")

    ax1.set_xlabel('x')
    ax1.set_ylabel('u(x)')
    ax1.set_title('FEM Solution vs Exact Solution')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Convergence plot; the reference line is anchored at the coarsest mesh
    ax2.loglog(N_values, errors, 'bo-', linewidth=2, markersize=8, label='FEM Error')
    ax2.loglog(N_values, [errors[0] * (N_values[0]/N)**2 for N in N_values],
               'r--', label='$O(h^2)$ reference')
    ax2.set_xlabel('Number of elements N')
    ax2.set_ylabel('Relative L2 Error')
    ax2.set_title('Convergence Rate')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('/tmp/fem_example1.png', dpi=150)
    print("Saved plot to /tmp/fem_example1.png")
    plt.show()
365
366
def example_2():
    """
    Example 2: -u'' = π^2 sin(πx), u(0) = 0, u(1) = 0

    Analytical solution: u(x) = sin(πx)

    Solves on a single 32-element mesh, plots the FEM solution against the
    exact one together with the pointwise absolute error (log scale), and
    prints the relative error measured at the mesh nodes.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Example 2: -u'' = π² sin(πx) with homogeneous Dirichlet BCs")
    print(banner)

    def source(x):
        # Right-hand side f(x)
        return np.pi**2 * np.sin(np.pi * x)

    def exact(x):
        # Analytical solution u(x)
        return np.sin(np.pi * x)

    # Solve
    N = 32
    solver = FEM1D(a=0, b=1, N=N)
    nodal_values = solver.solve(source)

    # Sample both solutions on a fine grid
    grid = np.linspace(0, 1, 200)
    fem_on_grid = solver.evaluate_solution(nodal_values, grid)
    exact_on_grid = exact(grid)

    fig, axes = plt.subplots(2, 1, figsize=(12, 10))
    top, bottom = axes

    # Upper panel: solution comparison
    top.plot(grid, exact_on_grid, 'k-', linewidth=2, label='Exact')
    top.plot(grid, fem_on_grid, 'b--', linewidth=2, label=f'FEM (N={N})')
    top.plot(solver.nodes, nodal_values, 'ro', markersize=6, label='FEM nodes')
    top.set_xlabel('x')
    top.set_ylabel('u(x)')
    top.set_title('Solution: -u\'\' = π² sin(πx)')
    top.legend()
    top.grid(True, alpha=0.3)

    # Lower panel: pointwise error on the fine grid
    bottom.plot(grid, np.abs(fem_on_grid - exact_on_grid), 'r-', linewidth=2)
    bottom.set_xlabel('x')
    bottom.set_ylabel('|u_FEM - u_exact|')
    bottom.set_title('Pointwise Absolute Error')
    bottom.grid(True, alpha=0.3)
    bottom.set_yscale('log')

    # Relative error measured at the mesh nodes
    exact_at_nodes = exact(solver.nodes)
    difference_norm = np.linalg.norm(nodal_values - exact_at_nodes)
    error = difference_norm / np.linalg.norm(exact_at_nodes)
    print(f"N = {N}: Relative L2 error = {error:.6e}")

    plt.tight_layout()
    plt.savefig('/tmp/fem_example2.png', dpi=150)
    print("Saved plot to /tmp/fem_example2.png")
    plt.show()
426
427
def example_3_basis_functions():
    """
    Visualize hat basis functions.

    Draws every hat function of a 5-element mesh on [0, 1] and marks the
    peak of each one (value 1 at its own node).
    """
    banner = "=" * 60
    print("\n" + banner)
    print("Example 3: Visualizing Hat Basis Functions")
    print(banner)

    N = 5
    mesh = FEM1D(a=0, b=1, N=N)
    samples = np.linspace(0, 1, 500)

    fig, ax = plt.subplots(figsize=(12, 6))

    # One curve per node, plus a marker at the node where the hat peaks
    for idx, node_x in enumerate(mesh.nodes):
        ax.plot(samples, mesh.hat_function(samples, idx),
                linewidth=2, label=f'φ_{idx}')
        ax.plot(node_x, 1.0, 'ko', markersize=8)

    ax.set_xlabel('x')
    ax.set_ylabel('φ_i(x)')
    ax.set_title(f'Hat Basis Functions (N = {N} elements)')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    ax.set_ylim([-0.1, 1.2])

    plt.tight_layout()
    plt.savefig('/tmp/fem_basis_functions.png', dpi=150)
    print("Saved plot to /tmp/fem_basis_functions.png")
    plt.show()
460
461
if __name__ == "__main__":
    # Script entry point: print a short header, run the three demos in
    # order, then print the summary.
    rule = "=" * 60

    print("1D Finite Element Method")
    print(rule)
    print("This script demonstrates FEM for solving -u'' = f with Dirichlet BCs.")
    print()

    # Run examples (basis-function picture first, then the two BVPs)
    for demo in (example_3_basis_functions, example_1, example_2):
        demo()

    takeaways = (
        "  - FEM uses piecewise polynomial basis functions (hat functions)",
        "  - Assembly is done element-by-element (local to global)",
        "  - Convergence rate is O(h²) for linear elements",
        "  - Sparse matrices enable efficient solution of large systems",
    )

    print("\n" + rule)
    print("Key Takeaways:")
    for line in takeaways:
        print(line)
    print(rule)