05_bayesian_basics.py

Download
python 353 lines 11.5 KB
  1"""
  205_bayesian_basics.py
  3
  4Demonstrates basic Bayesian inference concepts:
  5- Bayes' theorem examples
  6- Conjugate priors (Beta-Binomial, Normal-Normal)
  7- Posterior computation
  8- Credible intervals
  9- Comparison with frequentist approach
 10"""
 11
 12import numpy as np
 13from scipy import stats
 14
 15try:
 16    import matplotlib.pyplot as plt
 17    HAS_PLT = True
 18except ImportError:
 19    HAS_PLT = False
 20    print("matplotlib not available; skipping plots\n")
 21
 22
 23def print_section(title):
 24    """Print formatted section header."""
 25    print("\n" + "=" * 70)
 26    print(f"  {title}")
 27    print("=" * 70)
 28
 29
 30def bayes_theorem_discrete():
 31    """Demonstrate Bayes' theorem with discrete example."""
 32    print_section("1. Bayes' Theorem - Discrete Example")
 33
 34    # Medical test example
 35    print("Medical diagnostic test scenario:")
 36    print("  Disease prevalence: 1%")
 37    print("  Test sensitivity (P(+|Disease)): 95%")
 38    print("  Test specificity (P(-|No disease)): 90%")
 39
 40    # Prior
 41    p_disease = 0.01
 42    p_no_disease = 0.99
 43
 44    # Likelihood
 45    p_pos_given_disease = 0.95
 46    p_pos_given_no_disease = 0.10
 47
 48    # Marginal probability of positive test
 49    p_pos = (p_pos_given_disease * p_disease +
 50             p_pos_given_no_disease * p_no_disease)
 51
 52    # Posterior using Bayes' theorem
 53    p_disease_given_pos = (p_pos_given_disease * p_disease) / p_pos
 54
 55    print(f"\nQuestion: If test is positive, what's P(Disease|+)?")
 56    print(f"\nCalculation:")
 57    print(f"  P(+) = P(+|D)P(D) + P(+|¬D)P(¬D)")
 58    print(f"       = {p_pos_given_disease}×{p_disease} + {p_pos_given_no_disease}×{p_no_disease}")
 59    print(f"       = {p_pos:.4f}")
 60    print(f"\n  P(D|+) = P(+|D)P(D) / P(+)")
 61    print(f"         = ({p_pos_given_disease}×{p_disease}) / {p_pos:.4f}")
 62    print(f"         = {p_disease_given_pos:.4f}")
 63
 64    print(f"\nResult: Only {p_disease_given_pos*100:.2f}% chance of disease despite positive test!")
 65    print(f"Reason: Low prevalence (strong prior)")
 66
 67
 68def beta_binomial_conjugate():
 69    """Demonstrate Beta-Binomial conjugate prior."""
 70    print_section("2. Beta-Binomial Conjugate Prior")
 71
 72    print("Estimating coin bias θ (probability of heads)")
 73
 74    # Prior: Beta(α, β)
 75    alpha_prior = 2
 76    beta_prior = 2
 77    print(f"\nPrior: Beta({alpha_prior}, {beta_prior})")
 78    print(f"  Prior mean: {alpha_prior/(alpha_prior + beta_prior):.3f}")
 79    print(f"  Prior expresses weak belief in fairness")
 80
 81    # Data: observed coin flips
 82    n_heads = 7
 83    n_tails = 3
 84    n_total = n_heads + n_tails
 85
 86    print(f"\nObserved data: {n_heads} heads, {n_tails} tails in {n_total} flips")
 87
 88    # Posterior: Beta(α + n_heads, β + n_tails)
 89    alpha_post = alpha_prior + n_heads
 90    beta_post = beta_prior + n_tails
 91
 92    print(f"\nPosterior: Beta({alpha_post}, {beta_post})")
 93    post_mean = alpha_post / (alpha_post + beta_post)
 94    post_var = (alpha_post * beta_post) / ((alpha_post + beta_post)**2 * (alpha_post + beta_post + 1))
 95
 96    print(f"  Posterior mean: {post_mean:.3f}")
 97    print(f"  Posterior std: {np.sqrt(post_var):.3f}")
 98
 99    # Credible interval
100    credible_interval = stats.beta.interval(0.95, alpha_post, beta_post)
101    print(f"  95% Credible Interval: [{credible_interval[0]:.3f}, {credible_interval[1]:.3f}]")
102
103    # MLE for comparison
104    mle = n_heads / n_total
105    print(f"\nFrequentist MLE: {mle:.3f}")
106    print(f"Bayesian posterior mean: {post_mean:.3f}")
107    print(f"  (Bayesian estimate pulled toward prior)")
108
109    if HAS_PLT:
110        theta = np.linspace(0, 1, 200)
111        prior_pdf = stats.beta.pdf(theta, alpha_prior, beta_prior)
112        likelihood = stats.beta.pdf(theta, n_heads + 1, n_tails + 1)  # proportional
113        posterior_pdf = stats.beta.pdf(theta, alpha_post, beta_post)
114
115        plt.figure(figsize=(10, 6))
116        plt.plot(theta, prior_pdf, 'b--', label=f'Prior: Beta({alpha_prior},{beta_prior})', linewidth=2)
117        plt.plot(theta, likelihood / np.max(likelihood) * np.max(posterior_pdf),
118                'g:', label='Likelihood (scaled)', linewidth=2)
119        plt.plot(theta, posterior_pdf, 'r-', label=f'Posterior: Beta({alpha_post},{beta_post})', linewidth=2)
120        plt.axvline(post_mean, color='red', linestyle='--', alpha=0.5, label=f'Posterior mean={post_mean:.3f}')
121        plt.axvline(mle, color='orange', linestyle='--', alpha=0.5, label=f'MLE={mle:.3f}')
122        plt.xlabel('θ (probability of heads)')
123        plt.ylabel('Density')
124        plt.title('Beta-Binomial Conjugate Prior')
125        plt.legend()
126        plt.grid(True, alpha=0.3)
127        plt.savefig('/tmp/beta_binomial.png', dpi=100)
128        print("\n[Plot saved to /tmp/beta_binomial.png]")
129        plt.close()
130
131
def normal_normal_conjugate():
    """Demonstrate the Normal-Normal conjugate update for a mean with known σ.

    Draws 20 seeded samples from N(100, 15²), combines them with a
    N(110, 20²) prior through the precision-weighted update, and prints the
    posterior, its 95% credible interval and the frequentist 95% interval.

    When matplotlib is available (``HAS_PLT``), a prior/likelihood/posterior
    plot is saved as ``normal_normal.png`` in the system temporary directory.
    """
    print_section("3. Normal-Normal Conjugate Prior")

    print("Estimating mean μ of normal distribution (known variance)")

    # True data generation (seeded so the printed numbers are reproducible)
    np.random.seed(42)
    true_mean = 100
    known_sigma = 15
    n = 20
    data = np.random.normal(true_mean, known_sigma, n)
    # Computed once; reused by the update, the report and the plot.
    sample_mean = np.mean(data)

    print(f"\nKnown: σ = {known_sigma}")
    print(f"Data: n = {n}, sample mean = {sample_mean:.2f}")

    # Prior: N(μ₀, σ₀²)
    mu_0 = 110  # Prior belief
    sigma_0 = 20

    print(f"\nPrior: N({mu_0}, {sigma_0}²)")

    # Posterior: N(μₙ, σₙ²), in the precision formulation: precisions add,
    # and the posterior mean is the precision-weighted average.
    tau_0 = 1 / sigma_0**2  # Prior precision
    tau_likelihood = n / known_sigma**2  # Likelihood precision

    tau_post = tau_0 + tau_likelihood
    mu_post = (tau_0 * mu_0 + tau_likelihood * sample_mean) / tau_post
    sigma_post = np.sqrt(1 / tau_post)

    print(f"\nPosterior: N({mu_post:.2f}, {sigma_post:.2f}²)")
    print(f"  Posterior mean: {mu_post:.2f}")
    print(f"  Posterior std: {sigma_post:.2f}")

    # Credible interval
    ci = stats.norm.interval(0.95, mu_post, sigma_post)
    print(f"  95% Credible Interval: [{ci[0]:.2f}, {ci[1]:.2f}]")

    # Compare with frequentist: interval around the sample mean using the
    # standard error σ/√n.
    se = known_sigma / np.sqrt(n)
    freq_ci = stats.norm.interval(0.95, sample_mean, se)

    print(f"\nFrequentist 95% CI: [{freq_ci[0]:.2f}, {freq_ci[1]:.2f}]")

    print("\nPrior influence:")
    print(f"  Prior mean: {mu_0}")
    print(f"  Sample mean: {sample_mean:.2f}")
    print(f"  Posterior mean: {mu_post:.2f} (weighted average)")

    if HAS_PLT:
        # Local imports keep this function self-contained.
        import tempfile
        from pathlib import Path

        # Use the platform's temp directory rather than a hard-coded '/tmp',
        # which does not exist on every operating system.
        plot_path = Path(tempfile.gettempdir()) / "normal_normal.png"

        x = np.linspace(70, 130, 300)
        prior_pdf = stats.norm.pdf(x, mu_0, sigma_0)
        # Likelihood of μ, drawn as a normal curve centred on the sample mean.
        likelihood_pdf = stats.norm.pdf(x, sample_mean, se)
        posterior_pdf = stats.norm.pdf(x, mu_post, sigma_post)

        fig = plt.figure(figsize=(10, 6))
        try:
            plt.plot(x, prior_pdf, 'b--', label=f'Prior: N({mu_0},{sigma_0}²)', linewidth=2)
            plt.plot(x, likelihood_pdf, 'g:', label='Likelihood', linewidth=2)
            plt.plot(x, posterior_pdf, 'r-', label=f'Posterior: N({mu_post:.1f},{sigma_post:.1f}²)', linewidth=2)
            plt.axvline(mu_post, color='red', linestyle='--', alpha=0.5)
            plt.xlabel('μ')
            plt.ylabel('Density')
            plt.title('Normal-Normal Conjugate Prior')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.savefig(plot_path, dpi=100)
            print(f"\n[Plot saved to {plot_path}]")
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)
201
202
def prior_influence():
    """Show how Beta priors of increasing strength pull the posterior mean.

    Draws 10 seeded Bernoulli(0.7) trials, updates three symmetric Beta
    priors with that same data, and prints each posterior and its mean,
    followed by the maximum-likelihood estimate.
    """
    print_section("4. Prior Influence on Posterior")
    print("Comparing weak vs strong priors")

    # One fixed, reproducible data set shared by every prior.
    np.random.seed(123)
    n = 10
    successes = np.random.binomial(1, 0.7, n).sum()
    failures = n - successes

    print(f"\nData: {successes} successes in {n} trials")

    # Each entry: (label, α₀, β₀)
    candidate_priors = (
        ("Weak (uninformative)", 1, 1),
        ("Moderate", 5, 5),
        ("Strong (favor 0.5)", 20, 20),
    )

    print("\nPosterior means with different priors:")

    for label, a0, b0 in candidate_priors:
        # Conjugate update: add the observed counts to the prior's parameters.
        a_n, b_n = a0 + successes, b0 + failures
        print(f"\n  {label}: Beta({a0}, {b0})")
        print(f"    Posterior: Beta({a_n}, {b_n})")
        print(f"    Posterior mean: {a_n / (a_n + b_n):.3f}")

    print(f"\nMLE (no prior): {successes / n:.3f}")
    print("\nWith more data, all posteriors converge to MLE")
239
240
def credible_vs_confidence():
    """Compare a Bayesian credible interval with a frequentist confidence interval.

    Prints the two interpretations, then simulates 50 seeded Bernoulli(0.6)
    coin flips and reports a normal-approximation 95% confidence interval
    alongside the 95% credible interval from a uniform Beta(1, 1) prior.
    The final line states whether *both* intervals contain the true θ.
    """
    print_section("5. Credible vs Confidence Intervals")

    print("Interpretation differences:\n")

    print("Frequentist Confidence Interval:")
    print("  'If we repeat the experiment many times,")
    print("   95% of computed intervals will contain the true parameter'")
    print("  Parameter is FIXED, interval is RANDOM")

    print("\nBayesian Credible Interval:")
    print("  'The parameter has 95% probability of being in this interval")
    print("   given the observed data'")
    print("  Parameter is RANDOM, interval is FIXED (given data)")

    # Example with coin flips (seeded so the printed numbers are reproducible)
    np.random.seed(456)
    n = 50
    true_theta = 0.6
    data = np.random.binomial(1, true_theta, n)
    n_heads = np.sum(data)

    # Frequentist CI: normal approximation around the sample proportion.
    p_hat = n_heads / n
    se = np.sqrt(p_hat * (1 - p_hat) / n)
    freq_ci = stats.norm.interval(0.95, p_hat, se)

    # Bayesian CI (uniform prior): posterior is Beta(1 + heads, 1 + tails).
    alpha_post = 1 + n_heads
    beta_post = 1 + (n - n_heads)
    bayes_ci = stats.beta.interval(0.95, alpha_post, beta_post)

    print(f"\n\nExample: {n_heads}/{n} heads observed")
    print(f"True θ = {true_theta}")

    print(f"\nFrequentist 95% CI: [{freq_ci[0]:.3f}, {freq_ci[1]:.3f}]")
    print("  Cannot say 'P(θ in interval) = 0.95'")

    print(f"\nBayesian 95% CI: [{bayes_ci[0]:.3f}, {bayes_ci[1]:.3f}]")
    print(f"  Can say 'P(θ in [{bayes_ci[0]:.3f}, {bayes_ci[1]:.3f}] | data) = 0.95'")

    # The message claims coverage by both intervals, so test both of them
    # rather than only the frequentist one.
    freq_covers = freq_ci[0] <= true_theta <= freq_ci[1]
    bayes_covers = bayes_ci[0] <= true_theta <= bayes_ci[1]
    print(f"\nBoth intervals contain true value: {bool(freq_covers and bayes_covers)}")
284
285
def sequential_updating():
    """Update a Beta posterior batch by batch as new coin-flip data arrives.

    Starting from Beta(2, 2), draws seeded Bernoulli(0.7) batches of 5, 10,
    20 and 50 observations; after each batch prints the updated posterior,
    its mean and standard deviation, and the distance of the mean from 0.7.
    """
    print_section("6. Sequential Bayesian Updating")
    print("Updating beliefs as data arrives sequentially")

    # Running Beta parameters; they begin at the prior and absorb each batch.
    a, b = 2, 2
    print(f"\nInitial prior: Beta({a}, {b})")

    np.random.seed(789)
    true_p = 0.7

    print(f"\nTrue probability: {true_p}")
    print("\nSequential updates:")

    for size in (5, 10, 20, 50):
        # Draw the next batch and count its successes.
        hits = np.sum(np.random.binomial(1, true_p, size))

        # Yesterday's posterior is today's prior: just add the new counts.
        a = a + hits
        b = b + (size - hits)

        total = a + b
        mean = a / total
        std = np.sqrt((a * b) / (total**2 * (total + 1)))

        print(f"\n  After {size} more observations:")
        print(f"    New data: {hits}/{size} successes")
        print(f"    Posterior: Beta({a}, {b})")
        print(f"    Mean: {mean:.4f}, Std: {std:.4f}")
        print(f"    Distance from truth: {abs(mean - true_p):.4f}")

    print("\n  Posterior converges to truth as data accumulates")
329
330
def main():
    """Run every demonstration in order, framed by opening and closing banners."""
    rule = "=" * 70
    print(f"{rule}\n  BAYESIAN BASICS DEMONSTRATIONS\n{rule}")

    # Seed the global NumPy RNG before any demonstration runs.
    np.random.seed(42)

    demos = (
        bayes_theorem_discrete,
        beta_binomial_conjugate,
        normal_normal_conjugate,
        prior_influence,
        credible_vs_confidence,
        sequential_updating,
    )
    for demo in demos:
        demo()

    print(f"\n{rule}\n  All demonstrations completed successfully!\n{rule}")


# Run the demonstrations only when executed as a script, not on import.
if __name__ == "__main__":
    main()