1#!/usr/bin/env python3
2"""
3Fourier Spectral Method for PDEs
4=================================
5
6This module demonstrates the Fourier spectral method for solving partial
7differential equations (PDEs) using Fast Fourier Transform (FFT).
8
9Examples:
10 1. 1D Heat Equation: u_t = nu * u_xx
11 2. 1D Burgers Equation: u_t + u * u_x = nu * u_xx (nonlinear)
12
13Key Concepts:
14 - Spectral differentiation using FFT
15 - Time integration with RK4
    - Dealiasing using the 2/3 truncation rule for nonlinear terms
17 - High-order accuracy in space
18
19Author: Educational example for Numerical Simulation
20License: MIT
21"""
22
23import numpy as np
24from scipy.fft import fft, ifft, fftfreq
25import matplotlib.pyplot as plt
26from matplotlib.animation import FuncAnimation
27
28
class FourierSpectralSolver:
    """
    Fourier spectral method solver for 1D PDEs with periodic boundary conditions.

    Time stepping uses classical (explicit) RK4. For the diffusion term this
    is only stable while ``nu * max|k|**2 * dt`` stays below roughly 2.785,
    and :meth:`solve` warns when that limit is exceeded.

    Attributes:
        N (int): Number of spatial grid points
        L (float): Domain length [0, L]
        nu (float): Viscosity/diffusion coefficient
        x (ndarray): Spatial grid points
        k (ndarray): Wavenumber array for spectral differentiation, in FFT
            ordering (0, 1, ..., -2, -1) scaled by 2*pi/L
        RK4_REAL_STABILITY_LIMIT (float): Extent of the classical RK4
            stability region along the negative real axis
    """

    # |1 + z + z^2/2 + z^3/6 + z^4/24| <= 1 on the real axis only for
    # -2.785... <= z <= 0, so larger nu*k^2*dt makes explicit RK4 diverge.
    RK4_REAL_STABILITY_LIMIT = 2.785

    def __init__(self, N, L, nu):
        """
        Initialize the spectral solver.

        Args:
            N (int): Number of grid points (should be even for dealiasing)
            L (float): Domain length
            nu (float): Viscosity coefficient
        """
        self.N = N
        self.L = L
        self.nu = nu

        # Spatial grid (periodic, exclude endpoint)
        self.x = np.linspace(0, L, N, endpoint=False)

        # Wavenumbers for FFT (properly ordered for fft)
        self.k = 2 * np.pi * fftfreq(N, d=L/N)

    def spectral_derivative(self, u_hat, order=1):
        """
        Compute spectral derivative in Fourier space.

        For a function u(x) with Fourier transform û(k):
            d^n u/dx^n  <-->  (ik)^n û(k)

        Args:
            u_hat (ndarray): Fourier coefficients of u (length N, FFT ordering)
            order (int): Derivative order

        Returns:
            ndarray: Fourier coefficients of d^order u / dx^order
        """
        return (1j * self.k)**order * u_hat

    def dealias_3_2(self, u_hat):
        """
        Zero the Fourier modes that receive aliasing errors from quadratic
        products (Orszag's 2/3 truncation rule).

        Despite the method name, this implements the 2/3 rule. Every mode with
        integer wavenumber |n| >= M/3 (M = len(u_hat), coefficients in standard
        FFT ordering) is set to zero. The mask is symmetric in n, so the
        spectrum of a real field stays conjugate-symmetric. The "3/2 rule" is
        the alternative zero-padding approach; both remove quadratic aliasing.

        To actually remove aliasing, the truncation must be applied to the
        factors before the product is formed, as well as to the product.
        burgers_equation_rhs does both.

        Args:
            u_hat (ndarray): Fourier coefficients

        Returns:
            ndarray: Truncated copy of the Fourier coefficients
        """
        u_hat_dealiased = np.array(u_hat, copy=True)
        M = len(u_hat_dealiased)
        idx = np.arange(M)
        # |n| for each FFT slot: 0, 1, ..., M//2, ..., 2, 1
        abs_mode = np.minimum(idx, M - idx)
        u_hat_dealiased[abs_mode >= M / 3] = 0
        return u_hat_dealiased

    def heat_equation_rhs(self, u_hat):
        """
        Right-hand side for heat equation: u_t = nu * u_xx

        In Fourier space: û_t = -nu * k^2 * û

        Args:
            u_hat (ndarray): Fourier coefficients of u

        Returns:
            ndarray: Time derivative in Fourier space
        """
        return -self.nu * self.k**2 * u_hat

    def burgers_equation_rhs(self, u_hat, use_dealiasing=True):
        """
        Right-hand side for Burgers equation: u_t + u * u_x = nu * u_xx

        The nonlinear term -u*u_x is evaluated pseudo-spectrally (product in
        physical space). With dealiasing enabled, both factors and the product
        are truncated with the 2/3 rule, so no aliased energy reaches the
        retained modes.

        Args:
            u_hat (ndarray): Fourier coefficients of u
            use_dealiasing (bool): Whether to apply 2/3-rule dealiasing

        Returns:
            ndarray: Time derivative in Fourier space
        """
        # Linear term: diffusion (exact in Fourier space, no aliasing)
        diffusion = -self.nu * self.k**2 * u_hat

        # Band-limit the factors first; truncating only the product would
        # leave aliases that already folded into the low modes.
        u_hat_factor = self.dealias_3_2(u_hat) if use_dealiasing else u_hat

        # Compute -u * du/dx in physical space
        u = ifft(u_hat_factor).real
        u_x = ifft(self.spectral_derivative(u_hat_factor, order=1)).real
        nonlinear_hat = fft(-u * u_x)

        if use_dealiasing:
            nonlinear_hat = self.dealias_3_2(nonlinear_hat)

        return diffusion + nonlinear_hat

    def rk4_step(self, u_hat, dt, equation='heat'):
        """
        Fourth-order Runge-Kutta time integration step.

        Args:
            u_hat (ndarray): Current Fourier coefficients
            dt (float): Time step
            equation (str): 'heat' or 'burgers'

        Returns:
            ndarray: Updated Fourier coefficients

        Raises:
            ValueError: If equation is not 'heat' or 'burgers'.
        """
        if equation == 'heat':
            rhs = self.heat_equation_rhs
        elif equation == 'burgers':
            rhs = self.burgers_equation_rhs
        else:
            raise ValueError("equation must be 'heat' or 'burgers'")

        # RK4 stages
        k1 = rhs(u_hat)
        k2 = rhs(u_hat + 0.5 * dt * k1)
        k3 = rhs(u_hat + 0.5 * dt * k2)
        k4 = rhs(u_hat + dt * k3)

        return u_hat + (dt / 6.0) * (k1 + 2*k2 + 2*k3 + k4)

    def solve(self, u0, T, dt, equation='heat', save_every=10):
        """
        Solve the PDE from t=0 to t=T with fixed-step RK4.

        If T is not (up to round-off) an integer multiple of dt, the step is
        shrunk to T / ceil(T / dt), so the last step lands exactly on T.

        Args:
            u0 (array_like): Initial condition in physical space (length N)
            T (float): Final time
            dt (float): Time step (an upper bound, see above)
            equation (str): 'heat' or 'burgers'
            save_every (int): Store a snapshot every this many steps. The
                initial and final states are always stored.

        Returns:
            tuple: (time_points, solution_history), where solution_history
            has one row per stored time.

        Raises:
            ValueError: If save_every < 1, or (via rk4_step) if equation is
                not 'heat' or 'burgers'.

        Warns:
            RuntimeWarning: If nu * max|k|**2 * dt exceeds the RK4 stability
                limit for the diffusion term; the solution will then blow up.
        """
        import warnings

        if save_every < 1:
            raise ValueError("save_every must be a positive integer")

        u0 = np.asarray(u0)

        # Number of time steps. Round rather than truncate, so that e.g.
        # T/dt == 2.9999999999999996 still gives 3 steps.
        ratio = T / dt
        Nt = int(round(ratio))
        if not np.isclose(ratio, Nt, rtol=1e-9, atol=1e-9):
            Nt = int(np.ceil(ratio))
            if Nt > 0:
                dt = T / Nt  # never larger than the requested dt

        # Explicit RK4 diverges on the stiffest diffusive mode past this bound.
        if self.k.size:
            z = self.nu * np.max(np.abs(self.k))**2 * dt
            if z > self.RK4_REAL_STABILITY_LIMIT:
                warnings.warn(
                    f"RK4 stability limit exceeded: nu*k_max^2*dt = {z:.3g} > "
                    f"{self.RK4_REAL_STABILITY_LIMIT}; reduce dt below "
                    f"{dt * self.RK4_REAL_STABILITY_LIMIT / z:.3g}",
                    RuntimeWarning,
                    stacklevel=2,
                )

        u_hat = fft(u0)

        t_history = [0]
        u_history = [u0.copy()]

        for n in range(Nt):
            u_hat = self.rk4_step(u_hat, dt, equation)

            # Store periodically (saves memory) and always at the final step
            if (n + 1) % save_every == 0 or n == Nt - 1:
                t_history.append((n + 1) * dt)
                u_history.append(ifft(u_hat).real)

        return np.array(t_history), np.array(u_history)
203
204
def example_heat_equation():
    """
    Solve the 1D heat equation with a Gaussian initial condition.

    PDE: u_t = nu * u_xx, x in [0, 2π], periodic BCs
    IC:  u(x, 0) = exp(-10 * (x - π)^2)

    The numerical result at the final time is compared against the exact
    periodic solution (a sum of spreading Gaussians, one per periodic image),
    and the relative L2 error is printed. The figure is saved to /tmp and shown.
    """
    print("=" * 60)
    print("Example 1: Heat Equation")
    print("=" * 60)

    # Parameters
    N = 128  # Grid points
    L = 2 * np.pi  # Domain length
    nu = 0.1  # Diffusion coefficient
    T = 2.0  # Final time
    # Explicit RK4 is stable for diffusion only while nu * k_max**2 * dt is
    # at most ~2.785. Here k_max = N/2 = 64, so dt must be <= ~0.0068. With
    # dt = 0.01 the highest modes grow ~5.6x per step and the run blows up.
    dt = 0.005  # Time step

    # Initialize solver
    solver = FourierSpectralSolver(N, L, nu)

    # Initial condition: Gaussian
    u0 = np.exp(-10 * (solver.x - np.pi)**2)

    # Solve
    print(f"Solving heat equation with N={N} points, dt={dt}")
    t_history, u_history = solver.solve(u0, T, dt, equation='heat')
    print(f"Computed {len(t_history)} time snapshots")

    def analytical_heat(x, t, nu, n_images=3):
        """Exact heat-equation solution on the periodic domain [0, L)."""
        # exp(-10 s^2) == exp(-s^2 / (2 * sigma0^2)) with sigma0^2 = 1/20
        sigma0_sq = 1.0 / (2 * 10)
        sigma_t_sq = sigma0_sq + 2 * nu * t
        amplitude = np.sqrt(sigma0_sq / sigma_t_sq)
        # The free-space Gaussian ignores the periodic BCs; sum the images
        # centred at pi + m*L so the reference matches the periodic problem.
        shifts = L * np.arange(-n_images, n_images + 1)
        offsets = np.asarray(x)[:, None] - np.pi - shifts[None, :]
        return amplitude * np.exp(-offsets**2 / (2 * sigma_t_sq)).sum(axis=1)

    # Plot results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Left: Evolution over time
    snapshot_indices = [0, len(t_history)//3, 2*len(t_history)//3, -1]
    for idx in snapshot_indices:
        ax1.plot(solver.x, u_history[idx], label=f't = {t_history[idx]:.2f}')
    ax1.set_xlabel('x')
    ax1.set_ylabel('u(x, t)')
    ax1.set_title('Heat Equation Evolution (Spectral Method)')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Right: Comparison with analytical solution at final time
    u_analytical = analytical_heat(solver.x, t_history[-1], nu)
    ax2.plot(solver.x, u_history[-1], 'b-', label='Spectral', linewidth=2)
    ax2.plot(solver.x, u_analytical, 'r--', label='Analytical', linewidth=2)
    ax2.set_xlabel('x')
    ax2.set_ylabel('u(x, T)')
    ax2.set_title(f'Comparison at t = {T}')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Compute error
    error = np.linalg.norm(u_history[-1] - u_analytical) / np.linalg.norm(u_analytical)
    print(f"Relative L2 error: {error:.2e}")

    plt.tight_layout()
    plt.savefig('/tmp/heat_equation_spectral.png', dpi=150)
    print("Saved plot to /tmp/heat_equation_spectral.png")
    plt.show()
273
274
def example_burgers_equation():
    """
    Solve the 1D viscous Burgers equation.

    PDE: u_t + u * u_x = nu * u_xx, x in [0, 2π], periodic BCs
    IC:  u(x, 0) = sin(x) + 0.5 * sin(2x)

    This demonstrates steepening toward a shock and its viscous dissipation.
    The figure is saved to /tmp and shown.
    """
    print("\n" + "=" * 60)
    print("Example 2: Burgers Equation")
    print("=" * 60)

    # Parameters
    N = 256  # Grid points (more for shock resolution)
    L = 2 * np.pi  # Domain length
    nu = 0.05  # Viscosity
    T = 3.0  # Final time
    # Explicit RK4 is stable for the diffusion term only while
    # nu * k_max**2 * dt <= ~2.785. With k_max = N/2 = 128 this requires
    # dt <= ~0.0034 (dt = 0.005 diverges). The advective limit is far looser.
    dt = 0.002  # Time step

    # Initialize solver
    solver = FourierSpectralSolver(N, L, nu)

    # Initial condition: smooth wave
    u0 = np.sin(solver.x) + 0.5 * np.sin(2 * solver.x)

    # Solve
    print(f"Solving Burgers equation with N={N} points, dt={dt}")
    t_history, u_history = solver.solve(u0, T, dt, equation='burgers')
    print(f"Computed {len(t_history)} time snapshots")

    # Plot results
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    # Top: Evolution over time
    snapshot_indices = [0, len(t_history)//4, len(t_history)//2, 3*len(t_history)//4, -1]
    colors = plt.cm.viridis(np.linspace(0, 1, len(snapshot_indices)))

    for i, idx in enumerate(snapshot_indices):
        ax1.plot(solver.x, u_history[idx], color=colors[i],
                 label=f't = {t_history[idx]:.2f}', linewidth=2)
    ax1.set_xlabel('x')
    ax1.set_ylabel('u(x, t)')
    ax1.set_title('Burgers Equation: Shock Formation and Dissipation')
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # Bottom: Space-time contour plot (rows of u_history are time snapshots)
    X, T_grid = np.meshgrid(solver.x, t_history)
    contour = ax2.contourf(X, T_grid, u_history, levels=20, cmap='RdBu_r')
    plt.colorbar(contour, ax=ax2, label='u(x, t)')
    ax2.set_xlabel('x')
    ax2.set_ylabel('t')
    ax2.set_title('Space-Time Evolution')

    plt.tight_layout()
    plt.savefig('/tmp/burgers_equation_spectral.png', dpi=150)
    print("Saved plot to /tmp/burgers_equation_spectral.png")
    plt.show()
334
335
if __name__ == "__main__":
    # Script entry point: print an overview, run both demos, then summarize.
    print("Fourier Spectral Method for PDEs")
    print("=" * 60)
    print("This script demonstrates spectral methods using FFT for:")
    print(" 1. Heat equation (linear diffusion)")
    print(" 2. Burgers equation (nonlinear advection + diffusion)")
    print()

    # Run examples
    example_heat_equation()
    example_burgers_equation()

    print("\n" + "=" * 60)
    print("Key Takeaways:")
    print(" - Spectral methods provide exponential convergence for smooth solutions")
    print(" - FFT enables O(N log N) computation of derivatives")
    # The solver truncates modes with |n| >= N/3, i.e. the 2/3 rule.
    print(" - Dealiasing (2/3 rule) prevents aliasing in nonlinear terms")
    print(" - RK4 time integration maintains high temporal accuracy")
    print("=" * 60)