# 06_bayesian_inference.py  (python, 454 lines, 15.2 KB)
  1"""
  206_bayesian_inference.py
  3
  4Demonstrates advanced Bayesian inference methods:
  5- MCMC basics (Metropolis-Hastings)
  6- Gibbs sampling simple example
  7- Posterior sampling
  8- Convergence diagnostics (trace plots)
  9- Bayesian linear regression
 10"""
 11
 12import numpy as np
 13from scipy import stats
 14
 15try:
 16    import matplotlib.pyplot as plt
 17    HAS_PLT = True
 18except ImportError:
 19    HAS_PLT = False
 20    print("matplotlib not available; skipping plots\n")
 21
 22
 23def print_section(title):
 24    """Print formatted section header."""
 25    print("\n" + "=" * 70)
 26    print(f"  {title}")
 27    print("=" * 70)
 28
 29
 30def metropolis_hastings_normal():
 31    """Demonstrate Metropolis-Hastings for sampling from normal distribution."""
 32    print_section("1. Metropolis-Hastings Algorithm")
 33
 34    print("Sampling from N(10, 2²) using Metropolis-Hastings")
 35
 36    # Target distribution
 37    target_mean = 10
 38    target_std = 2
 39
 40    def log_target(x):
 41        """Log of target distribution."""
 42        return stats.norm.logpdf(x, target_mean, target_std)
 43
 44    # MCMC parameters
 45    n_samples = 10000
 46    proposal_std = 3
 47    burn_in = 1000
 48
 49    # Initialize
 50    current = 0  # Starting point
 51    samples = []
 52    accepted = 0
 53
 54    print(f"\nMCMC parameters:")
 55    print(f"  Iterations: {n_samples}")
 56    print(f"  Burn-in: {burn_in}")
 57    print(f"  Proposal std: {proposal_std}")
 58
 59    # Run Metropolis-Hastings
 60    for i in range(n_samples):
 61        # Propose new state
 62        proposed = current + np.random.normal(0, proposal_std)
 63
 64        # Acceptance ratio
 65        log_ratio = log_target(proposed) - log_target(current)
 66        accept_prob = min(1, np.exp(log_ratio))
 67
 68        # Accept or reject
 69        if np.random.uniform() < accept_prob:
 70            current = proposed
 71            accepted += 1
 72
 73        samples.append(current)
 74
 75    samples = np.array(samples)
 76    samples_after_burnin = samples[burn_in:]
 77
 78    print(f"\nResults:")
 79    print(f"  Acceptance rate: {accepted/n_samples:.3f}")
 80    print(f"  Sample mean: {np.mean(samples_after_burnin):.3f} (true: {target_mean})")
 81    print(f"  Sample std: {np.std(samples_after_burnin, ddof=1):.3f} (true: {target_std})")
 82
 83    # Effective sample size (simple autocorrelation-based estimate)
 84    autocorr_lag1 = np.corrcoef(samples_after_burnin[:-1], samples_after_burnin[1:])[0, 1]
 85    ess_approx = len(samples_after_burnin) * (1 - autocorr_lag1) / (1 + autocorr_lag1)
 86
 87    print(f"  Autocorrelation (lag 1): {autocorr_lag1:.3f}")
 88    print(f"  Approx. effective sample size: {ess_approx:.0f}")
 89
 90    if HAS_PLT:
 91        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
 92
 93        # Trace plot
 94        axes[0, 0].plot(samples, alpha=0.7, linewidth=0.5)
 95        axes[0, 0].axvline(burn_in, color='red', linestyle='--', label='Burn-in')
 96        axes[0, 0].axhline(target_mean, color='green', linestyle='--', label='True mean')
 97        axes[0, 0].set_xlabel('Iteration')
 98        axes[0, 0].set_ylabel('Value')
 99        axes[0, 0].set_title('Trace Plot')
100        axes[0, 0].legend()
101        axes[0, 0].grid(True, alpha=0.3)
102
103        # Histogram
104        axes[0, 1].hist(samples_after_burnin, bins=50, density=True,
105                       alpha=0.7, edgecolor='black', label='MCMC samples')
106        x = np.linspace(target_mean - 4*target_std, target_mean + 4*target_std, 200)
107        axes[0, 1].plot(x, stats.norm.pdf(x, target_mean, target_std),
108                       'r-', linewidth=2, label='True distribution')
109        axes[0, 1].set_xlabel('Value')
110        axes[0, 1].set_ylabel('Density')
111        axes[0, 1].set_title('Posterior Distribution')
112        axes[0, 1].legend()
113        axes[0, 1].grid(True, alpha=0.3)
114
115        # Running mean
116        running_mean = np.cumsum(samples) / np.arange(1, len(samples) + 1)
117        axes[1, 0].plot(running_mean, alpha=0.7)
118        axes[1, 0].axhline(target_mean, color='red', linestyle='--', label='True mean')
119        axes[1, 0].axvline(burn_in, color='orange', linestyle='--', alpha=0.5)
120        axes[1, 0].set_xlabel('Iteration')
121        axes[1, 0].set_ylabel('Running Mean')
122        axes[1, 0].set_title('Convergence of Mean')
123        axes[1, 0].legend()
124        axes[1, 0].grid(True, alpha=0.3)
125
126        # Autocorrelation
127        max_lag = 50
128        autocorr = [np.corrcoef(samples_after_burnin[:-lag], samples_after_burnin[lag:])[0, 1]
129                   if lag > 0 else 1.0 for lag in range(max_lag)]
130        axes[1, 1].bar(range(max_lag), autocorr, alpha=0.7)
131        axes[1, 1].set_xlabel('Lag')
132        axes[1, 1].set_ylabel('Autocorrelation')
133        axes[1, 1].set_title('Autocorrelation Function')
134        axes[1, 1].grid(True, alpha=0.3)
135
136        plt.tight_layout()
137        plt.savefig('/tmp/mcmc_metropolis.png', dpi=100)
138        print("\n[Plot saved to /tmp/mcmc_metropolis.png]")
139        plt.close()
140
141
def gibbs_sampling_bivariate():
    """Sample a correlated bivariate normal with a two-block Gibbs sampler.

    Alternately draws x | y and y | x from their exact Gaussian full
    conditionals and discards a burn-in prefix. It then prints the sample
    means, standard deviations and correlation next to the true values.
    When matplotlib is available, it saves the trace plots and a scatter of
    the joint samples, overlaid with the true 95% probability ellipse, to
    /tmp/gibbs_sampling.png.

    Draws from the global NumPy random state.
    """
    print_section("2. Gibbs Sampling")

    print("Sampling from bivariate normal using Gibbs sampling")

    # Target: bivariate normal with correlation rho
    mu = np.array([3, 5])
    rho = 0.7
    sigma_x = 1.5
    sigma_y = 2.0
    cov_matrix = np.array([
        [sigma_x**2, rho * sigma_x * sigma_y],
        [rho * sigma_x * sigma_y, sigma_y**2]
    ])

    print("\nTarget distribution:")
    print(f"  μ = {mu}")
    print(f"  ρ = {rho}")
    print(f"  σ_x = {sigma_x}, σ_y = {sigma_y}")

    # Both full conditionals of a bivariate normal are normal. The
    # conditional std is the marginal std scaled by sqrt(1 - rho^2) and does
    # not depend on the conditioning value, so it is computed once.
    cond_std_x = sigma_x * np.sqrt(1 - rho**2)
    cond_std_y = sigma_y * np.sqrt(1 - rho**2)

    def sample_x_given_y(y):
        """Draw x ~ p(x | y)."""
        mu_cond = mu[0] + rho * (sigma_x / sigma_y) * (y - mu[1])
        return np.random.normal(mu_cond, cond_std_x)

    def sample_y_given_x(x):
        """Draw y ~ p(y | x)."""
        mu_cond = mu[1] + rho * (sigma_y / sigma_x) * (x - mu[0])
        return np.random.normal(mu_cond, cond_std_y)

    n_samples = 5000
    burn_in = 500

    x_samples = np.zeros(n_samples)
    y_samples = np.zeros(n_samples)

    # Start at (0, 0), away from the mean (3, 5)
    x_samples[0] = 0
    y_samples[0] = 0

    print("\nGibbs sampling:")
    print(f"  Iterations: {n_samples}")
    print(f"  Burn-in: {burn_in}")

    # Systematic scan: each new x conditions on the previous y, and the new
    # y conditions on that freshly drawn x.
    for i in range(1, n_samples):
        x_samples[i] = sample_x_given_y(y_samples[i - 1])
        y_samples[i] = sample_y_given_x(x_samples[i])

    x_final = x_samples[burn_in:]
    y_final = y_samples[burn_in:]

    print("\nResults (after burn-in):")
    print(f"  Mean x: {np.mean(x_final):.3f} (true: {mu[0]})")
    print(f"  Mean y: {np.mean(y_final):.3f} (true: {mu[1]})")
    print(f"  Std x: {np.std(x_final, ddof=1):.3f} (true: {sigma_x})")
    print(f"  Std y: {np.std(y_final, ddof=1):.3f} (true: {sigma_y})")
    print(f"  Correlation: {np.corrcoef(x_final, y_final)[0,1]:.3f} (true: {rho})")

    if HAS_PLT:
        from matplotlib.patches import Ellipse

        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Trace plots of both coordinates, including burn-in
        axes[0].plot(x_samples, alpha=0.5, label='x')
        axes[0].plot(y_samples, alpha=0.5, label='y')
        axes[0].axvline(burn_in, color='red', linestyle='--', alpha=0.5)
        axes[0].axhline(mu[0], color='blue', linestyle='--', alpha=0.3)
        axes[0].axhline(mu[1], color='orange', linestyle='--', alpha=0.3)
        axes[0].set_xlabel('Iteration')
        axes[0].set_ylabel('Value')
        axes[0].set_title('Trace Plots')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Joint distribution of the retained samples
        axes[1].scatter(x_final, y_final, alpha=0.3, s=5)
        axes[1].scatter([mu[0]], [mu[1]], color='red', s=100, marker='x',
                        label='True mean', zorder=5)

        # True 95% probability ellipse. Its full axes are
        # 2 * sqrt(chi2_{2, 0.95} * lambda_i), oriented along the covariance
        # eigenvectors. eigh suits the symmetric covariance: it gives real
        # eigenvalues and orthonormal eigenvectors, with the width paired to
        # eigenvector column 0.
        eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
        angle = np.degrees(np.arctan2(eigenvectors[1, 0], eigenvectors[0, 0]))
        chi2_95 = stats.chi2.ppf(0.95, df=2)
        width, height = 2 * np.sqrt(chi2_95 * eigenvalues)
        ellipse = Ellipse(mu, width, height, angle=angle,
                          facecolor='none', edgecolor='red', linewidth=2, label='95% ellipse')
        axes[1].add_patch(ellipse)

        axes[1].set_xlabel('x')
        axes[1].set_ylabel('y')
        axes[1].set_title('Joint Distribution')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        axes[1].axis('equal')

        plt.tight_layout()
        plt.savefig('/tmp/gibbs_sampling.png', dpi=100)
        print("\n[Plot saved to /tmp/gibbs_sampling.png]")
        plt.close(fig)


def bayesian_linear_regression():
    """Fit y = b0 + b1*x by conjugate Bayesian linear regression.

    Simulates 50 points from y = 5 + 2x + N(0, 2^2) and places an
    independent N(0, 100) prior on (b0, b1). The noise std is treated as
    known, so the posterior over the coefficients is Gaussian in closed
    form. The function prints:
    - the posterior mean, std and correlation;
    - Monte-Carlo means from posterior draws;
    - 95% credible intervals.
    When matplotlib is available, it saves the coefficient draws and a 95%
    posterior-predictive band to /tmp/bayesian_regression.png.

    Reseeds the global NumPy RNG with 42, so this section's output does not
    depend on earlier sections. Later sections continue from this reseeded
    stream.
    """
    print_section("3. Bayesian Linear Regression")

    # Generate data
    np.random.seed(42)
    n = 50
    x = np.random.uniform(0, 10, n)
    true_beta_0 = 5
    true_beta_1 = 2
    sigma_true = 2
    y = true_beta_0 + true_beta_1 * x + np.random.normal(0, sigma_true, n)

    print(f"True model: y = {true_beta_0} + {true_beta_1}*x + N(0, {sigma_true}²)")
    print(f"Sample size: {n}")

    # Design matrix with an intercept column
    X = np.column_stack([np.ones(n), x])

    # Prior for coefficients: N(0, 100*I), weakly informative
    prior_mean = np.zeros(2)
    prior_cov = 100 * np.eye(2)
    prior_precision = np.linalg.inv(prior_cov)

    # Likelihood precision (noise std assumed known for simplicity)
    tau = 1 / sigma_true**2

    # Conjugate normal-normal update:
    #   Sigma_post = (Sigma_0^{-1} + tau * X^T X)^{-1}
    #   mu_post    = Sigma_post (Sigma_0^{-1} mu_0 + tau * X^T y)
    post_cov = np.linalg.inv(prior_precision + tau * (X.T @ X))
    post_mean = post_cov @ (prior_precision @ prior_mean + tau * (X.T @ y))

    print("\nPosterior distribution of coefficients:")
    print(f"  β₀: mean={post_mean[0]:.3f}, std={np.sqrt(post_cov[0,0]):.3f}")
    print(f"  β₁: mean={post_mean[1]:.3f}, std={np.sqrt(post_cov[1,1]):.3f}")
    print(f"  Correlation: {post_cov[0,1] / np.sqrt(post_cov[0,0] * post_cov[1,1]):.3f}")

    # Sample from the posterior
    n_posterior_samples = 2000
    beta_samples = np.random.multivariate_normal(post_mean, post_cov, n_posterior_samples)

    print(f"\nPosterior samples (n={n_posterior_samples}):")
    print(f"  β₀: mean={np.mean(beta_samples[:,0]):.3f}")
    print(f"  β₁: mean={np.mean(beta_samples[:,1]):.3f}")

    # Equal-tailed 95% credible intervals from the draws
    beta_0_ci = np.percentile(beta_samples[:, 0], [2.5, 97.5])
    beta_1_ci = np.percentile(beta_samples[:, 1], [2.5, 97.5])

    print("\n95% Credible intervals:")
    print(f"  β₀: [{beta_0_ci[0]:.3f}, {beta_0_ci[1]:.3f}]")
    print(f"  β₁: [{beta_1_ci[0]:.3f}, {beta_1_ci[1]:.3f}]")

    # Prediction grid
    x_new = np.linspace(0, 10, 100)
    X_new = np.column_stack([np.ones(len(x_new)), x_new])

    # Posterior predictive draws: one row per coefficient draw, each row with
    # fresh observation noise. One (n_pred, len(x_new)) noise array replaces
    # the original per-row loop.
    n_pred = 500  # subset of the coefficient draws, for speed
    noise = np.random.normal(0, sigma_true, size=(n_pred, len(x_new)))
    y_pred_samples = beta_samples[:n_pred] @ X_new.T + noise
    y_pred_lower = np.percentile(y_pred_samples, 2.5, axis=0)
    y_pred_upper = np.percentile(y_pred_samples, 97.5, axis=0)

    # The predictive mean is exactly X_new @ post_mean (the noise has mean 0).
    # Use it instead of the noisy Monte-Carlo average of the draws.
    y_pred_mean = X_new @ post_mean

    if HAS_PLT:
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Posterior samples of the coefficients
        axes[0].scatter(beta_samples[:, 0], beta_samples[:, 1],
                        alpha=0.3, s=5, label='Posterior samples')
        axes[0].scatter([true_beta_0], [true_beta_1], color='red',
                        s=100, marker='*', label='True values', zorder=5)
        axes[0].scatter([post_mean[0]], [post_mean[1]], color='green',
                        s=100, marker='x', label='Posterior mean', zorder=5)
        axes[0].set_xlabel('β₀ (intercept)')
        axes[0].set_ylabel('β₁ (slope)')
        axes[0].set_title('Posterior Distribution of Coefficients')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Predictions with the 95% posterior-predictive band
        axes[1].scatter(x, y, alpha=0.6, s=30, label='Data')
        axes[1].plot(x_new, y_pred_mean, 'r-', linewidth=2, label='Posterior mean')
        axes[1].fill_between(x_new, y_pred_lower, y_pred_upper,
                             alpha=0.3, color='red', label='95% Prediction interval')
        axes[1].set_xlabel('x')
        axes[1].set_ylabel('y')
        axes[1].set_title('Bayesian Linear Regression Predictions')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('/tmp/bayesian_regression.png', dpi=100)
        print("\n[Plot saved to /tmp/bayesian_regression.png]")
        plt.close(fig)


def convergence_diagnostics():
    """Assess MCMC convergence with several chains and the Gelman-Rubin R-hat.

    Runs one random-walk Metropolis chain on N(5, 1.5^2) from each of
    several dispersed starting points and prints each chain's post-burn-in
    mean. It then computes the potential scale reduction factor R-hat from
    the within-chain and between-chain variances. When matplotlib is
    available, it saves the overlaid traces to /tmp/convergence_chains.png.

    Draws from the global NumPy random state.
    """
    print_section("4. Convergence Diagnostics")

    print("Multiple chains for convergence assessment")

    # Target: N(5, 1.5²)
    target_mean = 5
    target_std = 1.5

    def log_target(x):
        """Log-density of the target N(target_mean, target_std^2)."""
        return stats.norm.logpdf(x, target_mean, target_std)

    # Dispersed starting points on both sides of the mode. The chain count
    # is derived from this list so the two cannot drift out of sync.
    starting_points = [-5, 0, 10, 15]
    n_chains = len(starting_points)
    n_samples = 3000
    burn_in = 500
    proposal_std = 2

    chains = []

    print(f"\nRunning {n_chains} chains:")
    print(f"  Samples per chain: {n_samples}")
    print(f"  Burn-in: {burn_in}")

    for chain_id, start in enumerate(starting_points):
        current = start
        samples = []

        for _ in range(n_samples):
            # Symmetric random-walk proposal, so acceptance uses only the
            # target density ratio.
            proposed = current + np.random.normal(0, proposal_std)
            log_ratio = log_target(proposed) - log_target(current)

            if np.random.uniform() < min(1, np.exp(log_ratio)):
                current = proposed

            samples.append(current)

        chains.append(np.array(samples))
        print(f"  Chain {chain_id+1}: start={start:5.1f}, mean={np.mean(samples[burn_in:]):.3f}")

    # Gelman-Rubin R-hat (basic, non-split form) on the post-burn-in draws,
    # arranged as an (n_chains, n_kept) array.
    kept = np.array([c[burn_in:] for c in chains])
    n_kept = kept.shape[1]

    # W: mean within-chain variance
    W = np.mean(np.var(kept, axis=1, ddof=1))

    # B: between-chain variance of the chain means, scaled by chain length
    B = n_kept * np.var(np.mean(kept, axis=1), ddof=1)

    # Pooled estimate of the target variance. R-hat = sqrt(var_plus / W)
    # approaches 1 as the chains agree with each other.
    var_plus = ((n_kept - 1) * W + B) / n_kept
    R_hat = np.sqrt(var_plus / W)

    print(f"\nGelman-Rubin R̂ statistic: {R_hat:.4f}")
    if R_hat < 1.1:
        print("  Chains have converged (R̂ < 1.1)")
    else:
        print("  Chains may not have converged (R̂ ≥ 1.1)")

    if HAS_PLT:
        fig = plt.figure(figsize=(12, 5))

        for i, chain in enumerate(chains):
            plt.plot(chain, alpha=0.7, label=f'Chain {i+1}')

        plt.axvline(burn_in, color='red', linestyle='--', alpha=0.5, label='Burn-in end')
        plt.axhline(target_mean, color='black', linestyle='--', alpha=0.5, label='True mean')
        plt.xlabel('Iteration')
        plt.ylabel('Value')
        plt.title(f'Multiple Chains (R̂={R_hat:.3f})')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig('/tmp/convergence_chains.png', dpi=100)
        print("\n[Plot saved to /tmp/convergence_chains.png]")
        plt.close(fig)


def main():
    """Run all four demonstrations in order, starting from a fixed seed."""
    print("=" * 70)
    print("  BAYESIAN INFERENCE DEMONSTRATIONS")
    print("=" * 70)

    # Every demo draws from the global NumPy RNG. Seeding it here makes the
    # whole run reproducible; bayesian_linear_regression also reseeds it
    # internally.
    np.random.seed(42)

    metropolis_hastings_normal()
    gibbs_sampling_bivariate()
    bayesian_linear_regression()
    convergence_diagnostics()

    print("\n" + "=" * 70)
    print("  All demonstrations completed successfully!")
    print("=" * 70)


if __name__ == "__main__":
    main()