13_adaptive_lms.py

Download
python 419 lines 15.0 KB
  1#!/usr/bin/env python3
  2"""
  3Adaptive Filters: LMS and NLMS Algorithms
  4==========================================
  5
  6Adaptive filters adjust their coefficients automatically to minimise an error
  7signal.  Unlike fixed FIR/IIR filters, they require no a-priori knowledge of
the signal statistics — they learn on-line.
  9
 10The Least Mean Squares (LMS) Algorithm
 11---------------------------------------
 12Update rule:
 13    w[n+1] = w[n] + mu * e[n] * x[n]
 14
 15where
 16    x[n]  : input vector (most recent M samples)
 17    d[n]  : desired signal
 18    y[n]  = w[n]^T x[n]   : filter output
 19    e[n]  = d[n] - y[n]   : error signal
 20    mu    : step size (controls speed vs stability)
 21
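Stability condition (convergence in the mean):
    0 < mu < 2 / lambda_max,   lambda_max = largest eigenvalue of R = E[x[n] x[n]^T]
Since lambda_max <= trace(R) = M * P_x, a practical sufficient bound is
    0 < mu < 2 / (M * P_x),    P_x = input signal power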
 24
 25Normalised LMS (NLMS)
 26----------------------
NLMS divides the step size by the instantaneous energy of the input vector,
x[n]^T x[n], making it insensitive to input amplitude variations:
 29
 30    w[n+1] = w[n] + (mu_n / (epsilon + x[n]^T x[n])) * e[n] * x[n]
 31
 32where mu_n is now dimensionless (0 < mu_n < 2) and epsilon prevents
 33division by zero.
 34
 35Applications demonstrated:
    1. System identification — estimate an unknown FIR filter's coefficients
    2. Noise cancellation — remove noise from a desired signal using a
       reference input that is correlated with that noise
 38
 39Author: Educational example for Signal Processing
 40License: MIT
 41"""
 42
 43import numpy as np
 44import matplotlib.pyplot as plt
 45
 46
 47# ============================================================================
 48# LMS ADAPTIVE FILTER
 49# ============================================================================
 50
 51def lms_filter(x, d, mu, M):
 52    """
 53    Least Mean Squares (LMS) adaptive filter.
 54
 55    Args:
 56        x  (ndarray): Input signal, shape (N,)
 57        d  (ndarray): Desired signal, shape (N,)
 58        mu (float)  : Step size.  Larger mu β†’ faster but less stable.
 59        M  (int)    : Filter length (number of taps).
 60
 61    Returns:
 62        y    (ndarray): Filter output, shape (N,)
 63        e    (ndarray): Error signal e[n] = d[n] - y[n], shape (N,)
 64        W    (ndarray): Weight history, shape (N, M)
 65        mse  (ndarray): Instantaneous squared error |e[n]|^2, shape (N,)
 66    """
 67    N = len(x)
 68    w = np.zeros(M)          # filter weights initialised to zero
 69    y = np.zeros(N)
 70    e = np.zeros(N)
 71    W = np.zeros((N, M))     # weight trajectory (for analysis)
 72
 73    for n in range(N):
 74        # Build input vector: x[n], x[n-1], ..., x[n-M+1]
 75        if n < M:
 76            x_vec = np.concatenate([x[:n+1][::-1], np.zeros(M - n - 1)])
 77        else:
 78            x_vec = x[n:n-M:-1]     # reversed window
 79
 80        y[n] = w @ x_vec             # filter output
 81        e[n] = d[n] - y[n]           # error
 82        w = w + mu * e[n] * x_vec    # weight update (LMS rule)
 83        W[n] = w
 84
 85    mse = e ** 2
 86    return y, e, W, mse
 87
 88
 89# ============================================================================
 90# NLMS ADAPTIVE FILTER
 91# ============================================================================
 92
 93def nlms_filter(x, d, mu_n, M, epsilon=1e-6):
 94    """
 95    Normalised Least Mean Squares (NLMS) adaptive filter.
 96
 97    The step size is normalised by the instantaneous input energy, so the
 98    algorithm is robust to changes in input power.
 99
100    Args:
101        x       (ndarray): Input signal, shape (N,)
102        d       (ndarray): Desired signal, shape (N,)
103        mu_n    (float)  : Normalised step size, 0 < mu_n < 2.
104        M       (int)    : Filter length (number of taps).
105        epsilon (float)  : Small regularisation constant (prevents /0).
106
107    Returns:
108        Same as lms_filter.
109    """
110    N = len(x)
111    w = np.zeros(M)
112    y = np.zeros(N)
113    e = np.zeros(N)
114    W = np.zeros((N, M))
115
116    for n in range(N):
117        if n < M:
118            x_vec = np.concatenate([x[:n+1][::-1], np.zeros(M - n - 1)])
119        else:
120            x_vec = x[n:n-M:-1]
121
122        y[n] = w @ x_vec
123        e[n] = d[n] - y[n]
124
125        # NLMS: divide step size by input power
126        power = x_vec @ x_vec
127        w = w + (mu_n / (epsilon + power)) * e[n] * x_vec
128        W[n] = w
129
130    mse = e ** 2
131    return y, e, W, mse
132
133
134# ============================================================================
135# APPLICATION 1: SYSTEM IDENTIFICATION
136# ============================================================================
137
def demo_system_identification():
    """
    Use LMS/NLMS to identify the impulse response of an unknown FIR system.

    Setup:
        - 'Unknown system' H(z) is a random FIR filter of length M.
        - The adaptive filter w of the same length learns H iteratively.
        - After convergence, w ≈ H (system identified).

    Side effects: prints the final weight-error norm of every run, saves
    "13_system_identification.png" in the working directory and shows the
    figure.
    """
    print("=" * 60)
    print("APPLICATION 1: System Identification")
    print("=" * 60)

    rng = np.random.default_rng(42)
    N = 2000          # number of samples
    M = 16            # filter length (same for unknown system and adaptive filter)

    # Unknown system impulse response (what we want to identify).  Unit norm
    # means unit-variance white input gives a noise-free output of unit power.
    h_true = rng.standard_normal(M)
    h_true /= np.linalg.norm(h_true)
    print(f"True impulse response (first 8 coefficients): {h_true[:8].round(3)}")

    # White noise excitation signal (wide-band -> all frequencies excited)
    x = rng.standard_normal(N)

    # Desired signal = output of the unknown system + small observation noise
    d = np.convolve(x, h_true, mode='full')[:N]
    d += 0.05 * rng.standard_normal(N)     # SNR ≈ 10*log10(1 / 0.05**2) ≈ 26 dB

    # Run LMS with three different step sizes to illustrate the trade-off
    results = {}
    for mu in [0.01, 0.05, 0.2]:
        _, _, W, mse = lms_filter(x, d, mu=mu, M=M)
        results[f'LMS mu={mu}'] = (W, mse)
        err = np.linalg.norm(W[-1] - h_true)
        print(f"  LMS mu={mu:4.2f}  |  final ||w - h_true|| = {err:.4f}")

    # Run NLMS
    _, _, W_n, mse_n = nlms_filter(x, d, mu_n=0.5, M=M)
    final_nlms = W_n[-1]
    err_nlms = np.linalg.norm(final_nlms - h_true)
    print(f"  NLMS mu_n=0.5   |  final ||w - h_true|| = {err_nlms:.4f}")

    # ---- Plot ---------------------------------------------------------------
    fig, axes = plt.subplots(2, 2, figsize=(13, 9))
    fig.suptitle("System Identification via LMS / NLMS", fontsize=14, fontweight='bold')

    smooth = 50                          # running-average window for MSE curves
    kernel = np.ones(smooth) / smooth
    checkpoints = np.arange(0, N, 10)    # iterations where the weight error is sampled
    W_mid, _ = results['LMS mu=0.05']    # mid step size, used in panels (b) and (c)

    # (a) Learning curves (smoothed MSE)
    ax = axes[0, 0]
    for label, (_, mse) in results.items():
        ax.semilogy(np.convolve(mse, kernel, mode='valid'), label=label)
    ax.semilogy(np.convolve(mse_n, kernel, mode='valid'),
                label='NLMS mu_n=0.5', linestyle='--')
    ax.set_xlabel("Iteration")
    ax.set_ylabel("MSE (smoothed, log scale)")
    ax.set_title("(a) Learning Curves — Effect of Step Size")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # (b) True vs identified impulse response
    ax = axes[0, 1]
    ax.stem(h_true, linefmt='C0-', markerfmt='C0o', basefmt='k-', label='True h')
    ax.stem(W_mid[-1], linefmt='C1--', markerfmt='C1x', basefmt='k-', label='LMS (mu=0.05)')
    ax.stem(final_nlms, linefmt='C2:', markerfmt='C2^', basefmt='k-', label='NLMS')
    ax.set_xlabel("Tap index")
    ax.set_ylabel("Coefficient value")
    ax.set_title("(b) Identified vs True Impulse Response")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # (c) Weight evolution over time (LMS mu=0.05, first 4 taps)
    ax = axes[1, 0]
    for k in range(4):
        ax.plot(W_mid[:, k], label=f'w[{k}]')
        ax.axhline(h_true[k], color=f'C{k}', linestyle=':', linewidth=1)
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Weight value")
    ax.set_title("(c) Weight Convergence (LMS mu=0.05, taps 0-3)\nDotted = true value")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    # (d) Weight-error norm over time (row-wise norms in one vectorised call)
    ax = axes[1, 1]
    for label, (W_hist, _) in results.items():
        ax.semilogy(checkpoints,
                    np.linalg.norm(W_hist[checkpoints] - h_true, axis=1),
                    label=label)
    ax.semilogy(checkpoints,
                np.linalg.norm(W_n[checkpoints] - h_true, axis=1),
                label='NLMS mu_n=0.5', linestyle='--')
    ax.set_xlabel("Iteration")
    ax.set_ylabel("||w[n] - h_true||")
    ax.set_title("(d) Weight Error Norm over Time")
    ax.legend(fontsize=8)
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig("13_system_identification.png", dpi=120)
    print("  Saved: 13_system_identification.png")
    plt.show()
    plt.close(fig)   # release the figure when show() returns immediately (e.g. Agg)
242
243
244# ============================================================================
245# APPLICATION 2: NOISE CANCELLATION
246# ============================================================================
247
def demo_noise_cancellation():
    """
    Use NLMS to cancel correlated noise from a desired signal.

    Setup (classic adaptive noise canceller, Widrow 1975):
        - Primary input   : d[n] = s[n] + v1[n]   (signal + noise)
        - Reference input : x[n] = v2[n]           (correlated noise, no signal)
        - The adaptive filter estimates v1 from v2:
              y[n] ≈ v1[n]
        - Output: e[n] = d[n] - y[n] ≈ s[n]        (cleaned signal)

    The key requirement is that the reference noise v2 is correlated with
    the primary noise v1 but uncorrelated with the desired signal s.

    Side effects: prints the input/output SNR, saves
    "13_noise_cancellation.png" in the working directory and shows the figure.
    """
    print("\n" + "=" * 60)
    print("APPLICATION 2: Adaptive Noise Cancellation")
    print("=" * 60)

    rng = np.random.default_rng(7)
    fs = 1000            # sample rate (Hz)
    t = np.arange(2000) / fs

    # Desired signal: 50 Hz sinusoid
    s = np.sin(2 * np.pi * 50 * t)

    # Noise source: white Gaussian noise.  Two sensors pick it up through
    # different FIR paths, so v1 and v2 are coloured and mutually correlated.
    v_source = rng.standard_normal(len(t))
    noise_path1 = np.array([0.8, 0.3, -0.2, 0.1])   # path to primary sensor
    noise_path2 = np.array([0.5, -0.4, 0.6])          # path to reference sensor
    # NOTE: noise_path2 has complex zeros with |z| = sqrt(0.6 / 0.5) ≈ 1.095,
    # i.e. outside the unit circle, so the ideal canceller H1(z)/H2(z) has no
    # stable causal realisation.  The causal M-tap filter below can only
    # approximate it, which bounds the achievable SNR improvement.
    v1 = np.convolve(v_source, noise_path1, mode='full')[:len(t)]
    v2 = np.convolve(v_source, noise_path2, mode='full')[:len(t)]

    # Primary: signal + noise; reference: correlated noise only
    primary = s + v1
    reference = v2

    # Input SNR before cancellation
    snr_in = 10 * np.log10(np.var(s) / np.var(v1))
    print(f"  Input SNR  : {snr_in:.1f} dB")

    # Run the NLMS noise canceller; e[n] is the cleaned output ≈ s[n]
    M = 12    # adaptive filter length
    _, e, _, mse = nlms_filter(reference, primary, mu_n=0.8, M=M)

    # Output SNR after cancellation: compare the cleaned output with the true s
    # over the second half only, so the start-up transient is excluded.
    half = len(t) // 2
    snr_out = 10 * np.log10(np.var(s[half:]) / np.var(e[half:] - s[half:]))
    print(f"  Output SNR : {snr_out:.1f} dB  (improvement: {snr_out - snr_in:.1f} dB)")

    # ---- Plot ---------------------------------------------------------------
    fig, axes = plt.subplots(3, 2, figsize=(13, 10))
    fig.suptitle("Adaptive Noise Cancellation (NLMS)", fontsize=14, fontweight='bold')
    seg = slice(0, 300)        # first 0.3 s (300 samples at fs = 1000 Hz)
    seg2 = slice(1700, 2000)   # last 0.3 s (300 samples), after convergence

    axes[0, 0].plot(t[seg], s[seg], 'C0')
    axes[0, 0].set_title("(a) Desired Signal s[n] (50 Hz sinusoid)")
    axes[0, 0].set_ylabel("Amplitude")

    axes[0, 1].plot(t[seg], primary[seg], 'C1')
    axes[0, 1].set_title(f"(b) Primary Input d[n] = s + noise  (SNR = {snr_in:.1f} dB)")
    axes[0, 1].set_ylabel("Amplitude")

    axes[1, 0].plot(t[seg], reference[seg], 'C2')
    axes[1, 0].set_title("(c) Reference Input (correlated noise)")
    axes[1, 0].set_ylabel("Amplitude")

    axes[1, 1].plot(t[seg], e[seg], 'C3')
    axes[1, 1].set_title(f"(d) Cleaned Output e[n] ≈ s[n]  (SNR = {snr_out:.1f} dB)")
    axes[1, 1].set_ylabel("Amplitude")

    # (e) Learning curve: running mean of the squared error
    smooth = 30
    kernel = np.ones(smooth) / smooth
    axes[2, 0].semilogy(np.convolve(mse, kernel, mode='valid'), 'C4')
    axes[2, 0].set_title("(e) Learning Curve (MSE)")
    axes[2, 0].set_xlabel("Iteration")
    axes[2, 0].set_ylabel("MSE (log)")

    # (f) Overlay comparison over the last 300 samples
    axes[2, 1].plot(t[seg2], s[seg2], 'C0', label='True s[n]', linewidth=2)
    axes[2, 1].plot(t[seg2], e[seg2], 'C3--', label='Cleaned e[n]', linewidth=1.5)
    axes[2, 1].set_title("(f) True vs Cleaned (after convergence)")
    axes[2, 1].set_xlabel("Time (s)")
    axes[2, 1].set_ylabel("Amplitude")
    axes[2, 1].legend()

    # One grid pass for every panel
    for ax in axes.flat:
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig("13_noise_cancellation.png", dpi=120)
    print("  Saved: 13_noise_cancellation.png")
    plt.show()
    plt.close(fig)   # release the figure when show() returns immediately (e.g. Agg)
348
349
350# ============================================================================
351# BONUS: LMS vs NLMS CONVERGENCE COMPARISON
352# ============================================================================
353
def demo_lms_vs_nlms():
    """
    Direct comparison of LMS and NLMS when input amplitude changes mid-way.

    LMS is sensitive to input power (step size must be retuned).
    NLMS adapts automatically.

    Side effects: saves "13_lms_vs_nlms.png" in the working directory and
    shows the figure.
    """
    print("\n" + "=" * 60)
    print("COMPARISON: LMS vs NLMS under non-stationary input")
    print("=" * 60)

    rng = np.random.default_rng(99)
    N = 3000
    M = 8
    n_jump = 1500     # sample index at which the input amplitude jumps
    gain = 5.0        # amplitude factor after the jump (power grows by gain**2 = 25)
    h_true = np.array([0.5, 0.3, -0.2, 0.1, 0.05, -0.05, 0.02, 0.01])

    # Input signal: amplitude increases 5x at n = n_jump (non-stationary scenario)
    x = rng.standard_normal(N)
    x[n_jump:] *= gain

    d = np.convolve(x, h_true, mode='full')[:N]
    d += 0.02 * rng.standard_normal(N)

    # Rule-of-thumb LMS bound 0 < mu < 2 / (M * P_x): 0.25 before the jump but
    # 2 / (8 * 25) = 0.01 after it, so the fixed mu = 0.01 is no longer inside
    # the stable range.  NLMS rescales its step by the tap energy instead.
    _, _, _, mse_lms = lms_filter(x, d, mu=0.01, M=M)
    _, _, _, mse_nlms = nlms_filter(x, d, mu_n=0.5, M=M)

    smooth = 40
    kernel = np.ones(smooth) / smooth
    mse_lms_s = np.convolve(mse_lms, kernel, mode='valid')
    mse_nlms_s = np.convolve(mse_nlms, kernel, mode='valid')
    # Entry k of a 'valid' running mean averages samples k .. k+smooth-1; plot it
    # at the last of those samples so the curves line up with the n_jump marker.
    iters = np.arange(smooth - 1, N)

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.semilogy(iters, mse_lms_s, label='LMS (mu=0.01)', color='C0')
    ax.semilogy(iters, mse_nlms_s, label='NLMS (mu_n=0.5)', color='C1', linestyle='--')
    ax.axvline(n_jump, color='red', linestyle=':', linewidth=1.5,
               label=f'Amplitude×{gain:g} at n={n_jump}')
    ax.set_xlabel("Iteration")
    ax.set_ylabel("MSE (log scale)")
    ax.set_title(f"LMS vs NLMS: Non-Stationary Input (amplitude jump at n={n_jump})\n"
                 "LMS diverges; NLMS adapts automatically")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("13_lms_vs_nlms.png", dpi=120)
    print("  Saved: 13_lms_vs_nlms.png")
    plt.show()
    plt.close(fig)   # release the figure when show() returns immediately (e.g. Agg)
399
400
401# ============================================================================
402# MAIN
403# ============================================================================
404
if __name__ == "__main__":
    # Short banner describing the tunable parameters, then the three demos in
    # order (each one saves a PNG and shows its figure).
    intro = (
        "Adaptive Filters: LMS and NLMS",
        "=" * 60,
        "Key parameters:",
        "  mu  (LMS step size) : controls speed vs stability trade-off",
        "  mu_n (NLMS)         : normalised step size, 0 < mu_n < 2",
        "  M   (filter order)  : must be >= true system order",
        "",
    )
    for text in intro:
        print(text)

    for run_demo in (demo_system_identification,
                     demo_noise_cancellation,
                     demo_lms_vs_nlms):
        run_demo()

    print("\nDone.  Three PNG files saved.")