1"""
206_bayesian_inference.py
3
4Demonstrates advanced Bayesian inference methods:
5- MCMC basics (Metropolis-Hastings)
6- Gibbs sampling simple example
7- Posterior sampling
8- Convergence diagnostics (trace plots)
9- Bayesian linear regression
10"""
11
12import numpy as np
13from scipy import stats
14
15try:
16 import matplotlib.pyplot as plt
17 HAS_PLT = True
18except ImportError:
19 HAS_PLT = False
20 print("matplotlib not available; skipping plots\n")
21
22
def print_section(title):
    """Print a section banner: a blank line, then the title framed by '=' rules."""
    rule = "=" * 70
    print(f"\n{rule}\n {title}\n{rule}")
28
29
def metropolis_hastings_normal():
    """Demonstrate Metropolis-Hastings for sampling from a normal distribution.

    Runs a random-walk Metropolis-Hastings chain targeting N(10, 2^2),
    reports acceptance rate, post-burn-in moments, lag-1 autocorrelation
    and an approximate effective sample size, and (if matplotlib is
    available) saves diagnostic plots to /tmp/mcmc_metropolis.png.
    """
    print_section("1. Metropolis-Hastings Algorithm")

    print("Sampling from N(10, 2²) using Metropolis-Hastings")

    # Target distribution parameters
    target_mean = 10
    target_std = 2

    def log_target(x):
        """Log-density of the target N(target_mean, target_std^2)."""
        return stats.norm.logpdf(x, target_mean, target_std)

    # MCMC parameters
    n_samples = 10000
    proposal_std = 3
    burn_in = 1000

    # Initialize chain
    current = 0  # Starting point, deliberately away from the mode
    samples = []
    accepted = 0

    print(f"\nMCMC parameters:")
    print(f" Iterations: {n_samples}")
    print(f" Burn-in: {burn_in}")
    print(f" Proposal std: {proposal_std}")

    # Run Metropolis-Hastings
    for i in range(n_samples):
        # Propose new state from a symmetric (normal) random walk
        proposed = current + np.random.normal(0, proposal_std)

        # Log acceptance ratio; the symmetric proposal density cancels
        log_ratio = log_target(proposed) - log_target(current)

        # Accept or reject.  Short-circuit to probability 1 when
        # log_ratio >= 0 so np.exp() is never called on a large positive
        # argument (which would emit overflow warnings).  One uniform is
        # still drawn per iteration, so the random stream — and hence the
        # sampled chain for a fixed seed — is unchanged.
        accept_prob = 1.0 if log_ratio >= 0 else np.exp(log_ratio)
        if np.random.uniform() < accept_prob:
            current = proposed
            accepted += 1

        samples.append(current)

    samples = np.array(samples)
    samples_after_burnin = samples[burn_in:]

    print(f"\nResults:")
    print(f" Acceptance rate: {accepted/n_samples:.3f}")
    print(f" Sample mean: {np.mean(samples_after_burnin):.3f} (true: {target_mean})")
    print(f" Sample std: {np.std(samples_after_burnin, ddof=1):.3f} (true: {target_std})")

    # Effective sample size: simple AR(1)-style estimate from lag-1
    # autocorrelation, ESS ≈ N * (1 - rho) / (1 + rho)
    autocorr_lag1 = np.corrcoef(samples_after_burnin[:-1], samples_after_burnin[1:])[0, 1]
    ess_approx = len(samples_after_burnin) * (1 - autocorr_lag1) / (1 + autocorr_lag1)

    print(f" Autocorrelation (lag 1): {autocorr_lag1:.3f}")
    print(f" Approx. effective sample size: {ess_approx:.0f}")

    if HAS_PLT:
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))

        # Trace plot: full chain with burn-in boundary marked
        axes[0, 0].plot(samples, alpha=0.7, linewidth=0.5)
        axes[0, 0].axvline(burn_in, color='red', linestyle='--', label='Burn-in')
        axes[0, 0].axhline(target_mean, color='green', linestyle='--', label='True mean')
        axes[0, 0].set_xlabel('Iteration')
        axes[0, 0].set_ylabel('Value')
        axes[0, 0].set_title('Trace Plot')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # Histogram of post-burn-in samples vs. the true density
        axes[0, 1].hist(samples_after_burnin, bins=50, density=True,
                        alpha=0.7, edgecolor='black', label='MCMC samples')
        x = np.linspace(target_mean - 4*target_std, target_mean + 4*target_std, 200)
        axes[0, 1].plot(x, stats.norm.pdf(x, target_mean, target_std),
                        'r-', linewidth=2, label='True distribution')
        axes[0, 1].set_xlabel('Value')
        axes[0, 1].set_ylabel('Density')
        axes[0, 1].set_title('Posterior Distribution')
        axes[0, 1].legend()
        axes[0, 1].grid(True, alpha=0.3)

        # Running mean: should converge toward the true mean
        running_mean = np.cumsum(samples) / np.arange(1, len(samples) + 1)
        axes[1, 0].plot(running_mean, alpha=0.7)
        axes[1, 0].axhline(target_mean, color='red', linestyle='--', label='True mean')
        axes[1, 0].axvline(burn_in, color='orange', linestyle='--', alpha=0.5)
        axes[1, 0].set_xlabel('Iteration')
        axes[1, 0].set_ylabel('Running Mean')
        axes[1, 0].set_title('Convergence of Mean')
        axes[1, 0].legend()
        axes[1, 0].grid(True, alpha=0.3)

        # Autocorrelation function up to max_lag (lag 0 is 1 by definition)
        max_lag = 50
        autocorr = [np.corrcoef(samples_after_burnin[:-lag], samples_after_burnin[lag:])[0, 1]
                    if lag > 0 else 1.0 for lag in range(max_lag)]
        axes[1, 1].bar(range(max_lag), autocorr, alpha=0.7)
        axes[1, 1].set_xlabel('Lag')
        axes[1, 1].set_ylabel('Autocorrelation')
        axes[1, 1].set_title('Autocorrelation Function')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('/tmp/mcmc_metropolis.png', dpi=100)
        print("\n[Plot saved to /tmp/mcmc_metropolis.png]")
        plt.close()
140
141
def gibbs_sampling_bivariate():
    """Demonstrate Gibbs sampling for a correlated bivariate normal.

    Alternately samples x | y and y | x from their exact normal
    conditionals, compares sample moments against the true parameters,
    and (if matplotlib is available) saves trace/joint plots to
    /tmp/gibbs_sampling.png.
    """
    print_section("2. Gibbs Sampling")

    print("Sampling from bivariate normal using Gibbs sampling")

    # Target: bivariate normal with correlation rho
    mu = np.array([3, 5])
    rho = 0.7
    sigma_x = 1.5
    sigma_y = 2.0
    cov_matrix = np.array([
        [sigma_x**2, rho * sigma_x * sigma_y],
        [rho * sigma_x * sigma_y, sigma_y**2]
    ])

    print(f"\nTarget distribution:")
    print(f" μ = {mu}")
    print(f" ρ = {rho}")
    print(f" σ_x = {sigma_x}, σ_y = {sigma_y}")

    # Exact conditional distributions of a bivariate normal
    def sample_x_given_y(y):
        """Sample x | y ~ N(mu_x + rho*(sx/sy)*(y - mu_y), sx^2*(1 - rho^2))."""
        mu_cond = mu[0] + rho * (sigma_x / sigma_y) * (y - mu[1])
        sigma_cond = sigma_x * np.sqrt(1 - rho**2)
        return np.random.normal(mu_cond, sigma_cond)

    def sample_y_given_x(x):
        """Sample y | x ~ N(mu_y + rho*(sy/sx)*(x - mu_x), sy^2*(1 - rho^2))."""
        mu_cond = mu[1] + rho * (sigma_y / sigma_x) * (x - mu[0])
        sigma_cond = sigma_y * np.sqrt(1 - rho**2)
        return np.random.normal(mu_cond, sigma_cond)

    # Gibbs sampling parameters
    n_samples = 5000
    burn_in = 500

    x_samples = np.zeros(n_samples)
    y_samples = np.zeros(n_samples)

    # Initialize away from the mean
    x_samples[0] = 0
    y_samples[0] = 0

    print(f"\nGibbs sampling:")
    print(f" Iterations: {n_samples}")
    print(f" Burn-in: {burn_in}")

    # Run Gibbs: each update conditions on the most recent other coordinate
    for i in range(1, n_samples):
        x_samples[i] = sample_x_given_y(y_samples[i-1])
        y_samples[i] = sample_y_given_x(x_samples[i])

    # Remove burn-in
    x_final = x_samples[burn_in:]
    y_final = y_samples[burn_in:]

    print(f"\nResults (after burn-in):")
    print(f" Mean x: {np.mean(x_final):.3f} (true: {mu[0]})")
    print(f" Mean y: {np.mean(y_final):.3f} (true: {mu[1]})")
    print(f" Std x: {np.std(x_final, ddof=1):.3f} (true: {sigma_x})")
    print(f" Std y: {np.std(y_final, ddof=1):.3f} (true: {sigma_y})")
    print(f" Correlation: {np.corrcoef(x_final, y_final)[0,1]:.3f} (true: {rho})")

    if HAS_PLT:
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Trace plots for both coordinates
        axes[0].plot(x_samples, alpha=0.5, label='x')
        axes[0].plot(y_samples, alpha=0.5, label='y')
        axes[0].axvline(burn_in, color='red', linestyle='--', alpha=0.5)
        axes[0].axhline(mu[0], color='blue', linestyle='--', alpha=0.3)
        axes[0].axhline(mu[1], color='orange', linestyle='--', alpha=0.3)
        axes[0].set_xlabel('Iteration')
        axes[0].set_ylabel('Value')
        axes[0].set_title('Trace Plots')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Joint distribution scatter with the true mean marked
        axes[1].scatter(x_final, y_final, alpha=0.3, s=5)
        axes[1].scatter([mu[0]], [mu[1]], color='red', s=100, marker='x',
                        label='True mean', zorder=5)

        # 95% confidence ellipse of the true distribution.  Use eigh (not
        # eig): cov_matrix is symmetric, so eigh guarantees real-valued
        # eigenvalues/eigenvectors and a deterministic (ascending) order.
        from matplotlib.patches import Ellipse
        eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
        angle = np.degrees(np.arctan2(eigenvectors[1, 0], eigenvectors[0, 0]))
        width, height = 2 * np.sqrt(5.991 * eigenvalues)  # 5.991 = chi2(2) 95% quantile
        ellipse = Ellipse(mu, width, height, angle=angle,
                          facecolor='none', edgecolor='red', linewidth=2, label='95% ellipse')
        axes[1].add_patch(ellipse)

        axes[1].set_xlabel('x')
        axes[1].set_ylabel('y')
        axes[1].set_title('Joint Distribution')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)
        axes[1].axis('equal')

        plt.tight_layout()
        plt.savefig('/tmp/gibbs_sampling.png', dpi=100)
        print("\n[Plot saved to /tmp/gibbs_sampling.png]")
        plt.close()
247
248
def bayesian_linear_regression():
    """Demonstrate conjugate Bayesian linear regression.

    With a normal prior on the coefficients and known noise variance the
    posterior is normal in closed form.  Reports posterior moments and
    95% credible intervals, draws posterior/posterior-predictive samples,
    and (if matplotlib is available) saves plots to
    /tmp/bayesian_regression.png.
    """
    print_section("3. Bayesian Linear Regression")

    # Generate synthetic data from a known linear model
    np.random.seed(42)
    n = 50
    x = np.random.uniform(0, 10, n)
    true_beta_0 = 5
    true_beta_1 = 2
    sigma_true = 2
    y = true_beta_0 + true_beta_1 * x + np.random.normal(0, sigma_true, n)

    print(f"True model: y = {true_beta_0} + {true_beta_1}*x + N(0, {sigma_true}²)")
    print(f"Sample size: {n}")

    # Design matrix with an intercept column
    X = np.column_stack([np.ones(n), x])

    # Prior for coefficients: N(0, 100*I) - weakly informative
    prior_mean = np.zeros(2)
    prior_cov = 100 * np.eye(2)
    # Prior precision, computed once and reused below (the original code
    # inverted prior_cov twice)
    prior_prec = np.linalg.inv(prior_cov)

    # Likelihood precision (noise variance assumed known for simplicity)
    tau = 1 / sigma_true**2

    # Posterior (closed form for normal prior, normal likelihood):
    #   precision: Lambda = prior_prec + tau * X'X
    #   mean:      solves Lambda @ post_mean = prior_prec @ prior_mean + tau * X'y
    post_prec = prior_prec + tau * (X.T @ X)
    post_cov = np.linalg.inv(post_prec)  # explicit covariance needed for sampling

    # Solve the linear system instead of multiplying by an explicit
    # inverse: cheaper and numerically more stable.
    post_mean = np.linalg.solve(post_prec, prior_prec @ prior_mean + tau * (X.T @ y))

    print(f"\nPosterior distribution of coefficients:")
    print(f" β₀: mean={post_mean[0]:.3f}, std={np.sqrt(post_cov[0,0]):.3f}")
    print(f" β₁: mean={post_mean[1]:.3f}, std={np.sqrt(post_cov[1,1]):.3f}")
    print(f" Correlation: {post_cov[0,1] / np.sqrt(post_cov[0,0] * post_cov[1,1]):.3f}")

    # Sample coefficient vectors from the posterior
    n_posterior_samples = 2000
    beta_samples = np.random.multivariate_normal(post_mean, post_cov, n_posterior_samples)

    print(f"\nPosterior samples (n={n_posterior_samples}):")
    print(f" β₀: mean={np.mean(beta_samples[:,0]):.3f}")
    print(f" β₁: mean={np.mean(beta_samples[:,1]):.3f}")

    # Equal-tailed 95% credible intervals from the posterior samples
    beta_0_ci = np.percentile(beta_samples[:,0], [2.5, 97.5])
    beta_1_ci = np.percentile(beta_samples[:,1], [2.5, 97.5])

    print(f"\n95% Credible intervals:")
    print(f" β₀: [{beta_0_ci[0]:.3f}, {beta_0_ci[1]:.3f}]")
    print(f" β₁: [{beta_1_ci[0]:.3f}, {beta_1_ci[1]:.3f}]")

    # Prediction grid
    x_new = np.linspace(0, 10, 100)
    X_new = np.column_stack([np.ones(len(x_new)), x_new])

    # Posterior predictive samples: coefficient uncertainty + noise
    y_pred_samples = []
    for beta in beta_samples[:500]:  # Use subset for speed
        y_pred_mean = X_new @ beta
        y_pred = y_pred_mean + np.random.normal(0, sigma_true, len(x_new))
        y_pred_samples.append(y_pred)

    y_pred_samples = np.array(y_pred_samples)
    y_pred_mean = np.mean(y_pred_samples, axis=0)
    y_pred_lower = np.percentile(y_pred_samples, 2.5, axis=0)
    y_pred_upper = np.percentile(y_pred_samples, 97.5, axis=0)

    if HAS_PLT:
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Posterior samples of the coefficients
        axes[0].scatter(beta_samples[:,0], beta_samples[:,1],
                        alpha=0.3, s=5, label='Posterior samples')
        axes[0].scatter([true_beta_0], [true_beta_1], color='red',
                        s=100, marker='*', label='True values', zorder=5)
        axes[0].scatter([post_mean[0]], [post_mean[1]], color='green',
                        s=100, marker='x', label='Posterior mean', zorder=5)
        axes[0].set_xlabel('β₀ (intercept)')
        axes[0].set_ylabel('β₁ (slope)')
        axes[0].set_title('Posterior Distribution of Coefficients')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3)

        # Data, posterior-mean prediction, and 95% prediction band
        axes[1].scatter(x, y, alpha=0.6, s=30, label='Data')
        axes[1].plot(x_new, y_pred_mean, 'r-', linewidth=2, label='Posterior mean')
        axes[1].fill_between(x_new, y_pred_lower, y_pred_upper,
                             alpha=0.3, color='red', label='95% Prediction interval')
        axes[1].set_xlabel('x')
        axes[1].set_ylabel('y')
        axes[1].set_title('Bayesian Linear Regression Predictions')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig('/tmp/bayesian_regression.png', dpi=100)
        print("\n[Plot saved to /tmp/bayesian_regression.png]")
        plt.close()
350
351
def convergence_diagnostics():
    """Demonstrate MCMC convergence diagnostics with multiple chains.

    Runs several Metropolis-Hastings chains from dispersed starting
    points, computes a simple Gelman-Rubin R-hat statistic, and (if
    matplotlib is available) saves overlaid trace plots to
    /tmp/convergence_chains.png.
    """
    print_section("4. Convergence Diagnostics")

    print("Multiple chains for convergence assessment")

    # Target: N(5, 1.5²)
    target_mean = 5
    target_std = 1.5

    def log_target(x):
        """Log-density of the target N(target_mean, target_std^2)."""
        return stats.norm.logpdf(x, target_mean, target_std)

    # Sampler settings (shared by all chains)
    n_chains = 4
    n_samples = 3000
    burn_in = 500
    proposal_std = 2

    chains = []

    print(f"\nRunning {n_chains} chains:")
    print(f" Samples per chain: {n_samples}")
    print(f" Burn-in: {burn_in}")

    # Dispersed starting points so disagreement between chains is visible
    starting_points = [-5, 0, 10, 15]

    for chain_id, start in enumerate(starting_points):
        current = start
        samples = []

        for i in range(n_samples):
            proposed = current + np.random.normal(0, proposal_std)
            log_ratio = log_target(proposed) - log_target(current)

            # Short-circuit to acceptance probability 1 when log_ratio >= 0
            # so np.exp() never overflows on large positive log-ratios;
            # one uniform is still drawn per step, preserving the stream.
            accept_prob = 1.0 if log_ratio >= 0 else np.exp(log_ratio)
            if np.random.uniform() < accept_prob:
                current = proposed

            samples.append(current)

        chains.append(np.array(samples))
        print(f" Chain {chain_id+1}: start={start:5.1f}, mean={np.mean(samples[burn_in:]):.3f}")

    # Gelman-Rubin diagnostic (simple version, no chain splitting)
    chains_after_burnin = [c[burn_in:] for c in chains]

    # Within-chain variance: mean of the per-chain sample variances
    W = np.mean([np.var(c, ddof=1) for c in chains_after_burnin])

    # Between-chain variance: n * variance of the chain means
    chain_means = [np.mean(c) for c in chains_after_burnin]
    B = np.var(chain_means, ddof=1) * len(chains_after_burnin[0])

    # R-hat: sqrt of pooled variance estimate over within-chain variance
    var_plus = ((len(chains_after_burnin[0]) - 1) * W + B) / len(chains_after_burnin[0])
    R_hat = np.sqrt(var_plus / W)

    print(f"\nGelman-Rubin R̂ statistic: {R_hat:.4f}")
    if R_hat < 1.1:
        print(f" Chains have converged (R̂ < 1.1)")
    else:
        print(f" Chains may not have converged (R̂ ≥ 1.1)")

    if HAS_PLT:
        plt.figure(figsize=(12, 5))

        for i, chain in enumerate(chains):
            plt.plot(chain, alpha=0.7, label=f'Chain {i+1}')

        plt.axvline(burn_in, color='red', linestyle='--', alpha=0.5, label='Burn-in end')
        plt.axhline(target_mean, color='black', linestyle='--', alpha=0.5, label='True mean')
        plt.xlabel('Iteration')
        plt.ylabel('Value')
        plt.title(f'Multiple Chains (R̂={R_hat:.3f})')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig('/tmp/convergence_chains.png', dpi=100)
        print("\n[Plot saved to /tmp/convergence_chains.png]")
        plt.close()
432
433
def main():
    """Run every Bayesian-inference demonstration in order."""
    banner = "=" * 70
    print(banner)
    print(" BAYESIAN INFERENCE DEMONSTRATIONS")
    print(banner)

    # Fixed seed so all demonstrations are reproducible run-to-run
    np.random.seed(42)

    for demo in (
        metropolis_hastings_normal,
        gibbs_sampling_bivariate,
        bayesian_linear_regression,
        convergence_diagnostics,
    ):
        demo()

    print("\n" + banner)
    print(" All demonstrations completed successfully!")
    print(banner)
450
451
# Script entry point: run all demonstrations when executed directly.
if __name__ == "__main__":
    main()