07_probability_distributions.py

Download
python 368 lines 12.5 KB
  1"""
  2Probability Distributions and Statistical Inference
  3
  4This script demonstrates:
  51. Common probability distributions (Gaussian, Bernoulli, Poisson, Exponential)
  62. Maximum Likelihood Estimation (MLE) for Gaussian parameters
  73. Maximum A Posteriori (MAP) estimation with Gaussian prior
  84. Bayesian update visualization
  9
 10Author: Math for AI Examples
 11"""
 12
 13import numpy as np
 14import matplotlib.pyplot as plt
 15from scipy import stats
 16from typing import Tuple
 17
 18
 19def plot_common_distributions():
 20    """Visualize common probability distributions used in ML."""
 21    print("\n" + "="*60)
 22    print("1. Common Probability Distributions")
 23    print("="*60)
 24
 25    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
 26
 27    # 1. Gaussian (Normal) Distribution
 28    ax = axes[0, 0]
 29    x = np.linspace(-5, 5, 1000)
 30
 31    for mu, sigma in [(0, 0.5), (0, 1), (0, 2), (2, 1)]:
 32        pdf = stats.norm.pdf(x, mu, sigma)
 33        ax.plot(x, pdf, linewidth=2, label=f'μ={mu}, σ={sigma}')
 34
 35    ax.set_xlabel('x', fontsize=11)
 36    ax.set_ylabel('Probability Density', fontsize=11)
 37    ax.set_title('Gaussian Distribution', fontsize=13, fontweight='bold')
 38    ax.legend(fontsize=9)
 39    ax.grid(True, alpha=0.3)
 40
 41    print("Gaussian: f(x|μ,σ²) = (1/√(2πσ²)) exp(-(x-μ)²/(2σ²))")
 42    print("  Use: Continuous data, errors, neural network activations")
 43
 44    # 2. Bernoulli Distribution
 45    ax = axes[0, 1]
 46    x = np.array([0, 1])
 47
 48    for p in [0.2, 0.5, 0.8]:
 49        pmf = np.array([1-p, p])
 50        ax.bar(x + (p-0.5)*0.15, pmf, width=0.12, alpha=0.7, label=f'p={p}')
 51
 52    ax.set_xlabel('x', fontsize=11)
 53    ax.set_ylabel('Probability Mass', fontsize=11)
 54    ax.set_title('Bernoulli Distribution', fontsize=13, fontweight='bold')
 55    ax.set_xticks([0, 1])
 56    ax.legend(fontsize=9)
 57    ax.grid(True, alpha=0.3, axis='y')
 58
 59    print("\nBernoulli: P(X=1) = p, P(X=0) = 1-p")
 60    print("  Use: Binary classification, coin flips")
 61
 62    # 3. Poisson Distribution
 63    ax = axes[1, 0]
 64    x = np.arange(0, 20)
 65
 66    for lambda_ in [1, 4, 10]:
 67        pmf = stats.poisson.pmf(x, lambda_)
 68        ax.plot(x, pmf, 'o-', linewidth=2, markersize=5, label=f'λ={lambda_}')
 69
 70    ax.set_xlabel('k (count)', fontsize=11)
 71    ax.set_ylabel('Probability Mass', fontsize=11)
 72    ax.set_title('Poisson Distribution', fontsize=13, fontweight='bold')
 73    ax.legend(fontsize=9)
 74    ax.grid(True, alpha=0.3)
 75
 76    print("\nPoisson: P(X=k) = (λ^k * e^(-λ)) / k!")
 77    print("  Use: Count data, rare events, arrivals per time period")
 78
 79    # 4. Exponential Distribution
 80    ax = axes[1, 1]
 81    x = np.linspace(0, 5, 1000)
 82
 83    for lambda_ in [0.5, 1, 2]:
 84        pdf = stats.expon.pdf(x, scale=1/lambda_)
 85        ax.plot(x, pdf, linewidth=2, label=f'λ={lambda_}')
 86
 87    ax.set_xlabel('x', fontsize=11)
 88    ax.set_ylabel('Probability Density', fontsize=11)
 89    ax.set_title('Exponential Distribution', fontsize=13, fontweight='bold')
 90    ax.legend(fontsize=9)
 91    ax.grid(True, alpha=0.3)
 92
 93    print("\nExponential: f(x|λ) = λ * e^(-λx) for x >= 0")
 94    print("  Use: Time between events, survival analysis, waiting times")
 95
 96    plt.tight_layout()
 97    plt.savefig('probability_distributions.png', dpi=150, bbox_inches='tight')
 98    print("\nSaved distributions plot to 'probability_distributions.png'")
 99
100
def mle_gaussian_demo():
    """
    Maximum Likelihood Estimation for Gaussian parameters.

    Given data X = {x₁, ..., xₙ} from N(μ, σ²), find MLE for μ and σ².

    Likelihood: L(μ,σ²) = ∏ᵢ (1/√(2πσ²)) exp(-(xᵢ-μ)²/(2σ²))
    Log-likelihood: ℓ(μ,σ²) = -n/2 log(2π) - n/2 log(σ²) - Σ(xᵢ-μ)²/(2σ²)

    MLE solutions:
    μ̂ = (1/n) Σxᵢ  (sample mean)
    σ̂² = (1/n) Σ(xᵢ-μ̂)²  (sample variance with divisor n, i.e. the biased estimator)

    Draws 100 samples from N(2, 1.5²) with a fixed seed and prints the
    estimates. Saves a histogram with the fitted and true densities, plus the
    log-likelihood contour over a (μ, σ) grid, to 'mle_gaussian.png'.
    """
    print("\n" + "="*60)
    print("2. Maximum Likelihood Estimation (MLE)")
    print("="*60)

    # True parameters
    true_mu = 2.0
    true_sigma = 1.5

    # Generate sample data
    np.random.seed(42)
    n_samples = 100
    data = np.random.normal(true_mu, true_sigma, n_samples)

    # MLE estimates
    mle_mu = np.mean(data)
    mle_sigma = np.std(data, ddof=0)  # ddof=0 for MLE (biased estimator)

    print(f"\nTrue parameters: μ = {true_mu}, σ = {true_sigma}")
    print(f"MLE estimates:   μ̂ = {mle_mu:.4f}, σ̂ = {mle_sigma:.4f}")
    print(f"Sample size: n = {n_samples}")

    # Parameter grid for the log-likelihood surface
    mu_range = np.linspace(0, 4, 100)
    sigma_range = np.linspace(0.5, 3, 100)
    MU, SIGMA = np.meshgrid(mu_range, sigma_range)

    def log_likelihood(mu, sigma, data):
        """Compute the Gaussian log-likelihood of `data` at (mu, sigma).

        `mu` and `sigma` may be scalars or broadcast-compatible arrays, so a
        whole parameter grid is evaluated in one call. Uses the identity
        Σ(xᵢ-μ)² = Σ(xᵢ-x̄)² + n(x̄-μ)². Only x̄ and Σ(xᵢ-x̄)² are computed
        from the data, instead of one pass over the data per grid point.
        """
        data = np.asarray(data)
        n = data.size
        x_bar = data.mean()
        sum_sq = np.sum((data - x_bar) ** 2) + n * (x_bar - mu) ** 2
        return (-n / 2 * np.log(2 * np.pi)
                - n * np.log(sigma)
                - sum_sq / (2 * sigma ** 2))

    # Evaluate the whole (μ, σ) grid in a single vectorized call.
    LL = log_likelihood(MU, SIGMA, data)

    # Visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Plot 1: Data histogram with MLE fit
    ax1.hist(data, bins=20, density=True, alpha=0.6, color='skyblue',
             edgecolor='black', label='Data')

    x_plot = np.linspace(data.min(), data.max(), 200)
    ax1.plot(x_plot, stats.norm.pdf(x_plot, mle_mu, mle_sigma),
             'r-', linewidth=3, label=f'MLE Fit: N({mle_mu:.2f}, {mle_sigma:.2f}²)')
    ax1.plot(x_plot, stats.norm.pdf(x_plot, true_mu, true_sigma),
             'g--', linewidth=2, label=f'True: N({true_mu}, {true_sigma}²)')

    ax1.set_xlabel('x', fontsize=12)
    ax1.set_ylabel('Density', fontsize=12)
    ax1.set_title('MLE Fit to Data', fontsize=13, fontweight='bold')
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    # Plot 2: Log-likelihood contour
    contour = ax2.contour(MU, SIGMA, LL, levels=20, cmap='viridis')
    ax2.clabel(contour, inline=True, fontsize=8)

    ax2.plot(mle_mu, mle_sigma, 'r*', markersize=20, label='MLE Estimate')
    ax2.plot(true_mu, true_sigma, 'go', markersize=12, label='True Parameters')

    ax2.set_xlabel('μ', fontsize=12)
    ax2.set_ylabel('σ', fontsize=12)
    ax2.set_title('Log-Likelihood Surface', fontsize=13, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig('mle_gaussian.png', dpi=150, bbox_inches='tight')
    # pyplot keeps every figure it creates alive until it is explicitly closed.
    plt.close(fig)
    print("\nSaved MLE visualization to 'mle_gaussian.png'")
188
189
def map_estimation_demo():
    """
    Maximum A Posteriori (MAP) estimation with Gaussian prior.

    Prior: μ ~ N(μ₀, σ₀²)
    Likelihood: x | μ ~ N(μ, σ²)
    Posterior: μ | x ~ N(μₙ, σₙ²)

    where:
    μₙ = (σ²μ₀ + nσ₀²x̄) / (σ² + nσ₀²)
    σₙ² = (σ²σ₀²) / (σ² + nσ₀²)

    Draws one seeded stream of observations from N(5, 1²). For n in
    [1, 5, 20, 100] it uses the first n observations, prints the MLE and
    MAP estimates, and plots prior, scaled likelihood and posterior. The
    2x2 figure is saved to 'map_estimation.png'.
    """
    print("\n" + "="*60)
    print("3. Maximum A Posteriori (MAP) Estimation")
    print("="*60)

    # True mean
    true_mu = 5.0
    data_sigma = 1.0

    # Prior parameters
    prior_mu = 3.0
    prior_sigma = 2.0

    print(f"\nPrior belief: μ ~ N({prior_mu}, {prior_sigma}²)")
    print(f"Data distribution: X ~ N(μ, {data_sigma}²)")
    print(f"True mean: μ = {true_mu}")

    # Generate increasing amounts of data: draw one stream of observations
    # and reveal progressively longer prefixes of it. Each panel then sees
    # the previous panel's data plus more, rather than an unrelated sample.
    np.random.seed(42)
    sample_sizes = [1, 5, 20, 100]
    all_data = np.random.normal(true_mu, data_sigma, max(sample_sizes))

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Prior support grid; the prior curve is identical in every panel.
    mu_range = np.linspace(-2, 10, 500)
    prior_pdf = stats.norm.pdf(mu_range, prior_mu, prior_sigma)

    for ax, n in zip(axes.flat, sample_sizes):
        data = all_data[:n]
        data_mean = np.mean(data)

        # MLE (just the sample mean)
        mle_mu = data_mean

        # MAP estimate: the posterior is Gaussian, so its mode (MAP) equals its mean.
        precision_denom = data_sigma**2 + n * prior_sigma**2
        posterior_mu = (data_sigma**2 * prior_mu + n * prior_sigma**2 * data_mean) / precision_denom
        posterior_sigma = np.sqrt((data_sigma**2 * prior_sigma**2) / precision_denom)

        print(f"\nn = {n} samples:")
        print(f"  Sample mean: {data_mean:.4f}")
        print(f"  MLE:  μ̂ = {mle_mu:.4f}")
        print(f"  MAP:  μ̂ = {posterior_mu:.4f}")
        print(f"  Posterior: N({posterior_mu:.4f}, {posterior_sigma:.4f}²)")

        # Prior
        ax.plot(mu_range, prior_pdf, 'b--', linewidth=2, label='Prior')

        # Likelihood as a function of μ is proportional to N(x̄, σ²/n).
        likelihood_sigma = data_sigma / np.sqrt(n)
        likelihood_pdf = stats.norm.pdf(mu_range, data_mean, likelihood_sigma)
        # Rescale to the prior's peak height purely for visual comparison.
        likelihood_pdf = likelihood_pdf / likelihood_pdf.max() * prior_pdf.max()
        ax.plot(mu_range, likelihood_pdf, 'g:', linewidth=2, label='Likelihood (scaled)')

        # Posterior
        posterior_pdf = stats.norm.pdf(mu_range, posterior_mu, posterior_sigma)
        ax.plot(mu_range, posterior_pdf, 'r-', linewidth=2, label='Posterior')

        # Mark estimates
        ax.axvline(prior_mu, color='b', linestyle='--', alpha=0.5, label='Prior mean')
        ax.axvline(mle_mu, color='g', linestyle=':', alpha=0.5, label='MLE')
        ax.axvline(posterior_mu, color='r', linestyle='-', alpha=0.5, label='MAP')
        ax.axvline(true_mu, color='k', linestyle='-', linewidth=2, label='True μ')

        ax.set_xlabel('μ', fontsize=11)
        ax.set_ylabel('Density', fontsize=11)
        ax.set_title(f'n = {n} samples', fontsize=12, fontweight='bold')
        ax.legend(fontsize=8, loc='upper left')
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig('map_estimation.png', dpi=150, bbox_inches='tight')
    # pyplot keeps every figure it creates alive until it is explicitly closed.
    plt.close(fig)
    print("\nSaved MAP estimation plot to 'map_estimation.png'")
279
280
def bayesian_update_demo():
    """Visualize sequential Bayesian updates of a Gaussian mean.

    Starts from the prior μ ~ N(2, 3²) and draws 5 seeded observations from
    N(5, 1²). After each observation it applies the conjugate Gaussian update
    with n = 1, and the resulting posterior becomes the prior for the next
    step. Every intermediate density is plotted on one axis and saved to
    'bayesian_update.png'. Per-step and final posteriors are printed.
    """
    print("\n" + "="*60)
    print("4. Sequential Bayesian Update")
    print("="*60)

    # Setup
    true_mu = 5.0
    data_sigma = 1.0
    prior_mu = 2.0
    prior_sigma = 3.0

    np.random.seed(42)
    n_updates = 5
    data_points = np.random.normal(true_mu, data_sigma, n_updates)

    print(f"Prior: μ ~ N({prior_mu}, {prior_sigma}²)")
    print(f"True mean: {true_mu}")
    print(f"\nSequential observations: {data_points}")

    fig, ax = plt.subplots(figsize=(12, 7))
    mu_range = np.linspace(-5, 12, 500)

    # Plot prior
    current_mu = prior_mu
    current_sigma = prior_sigma
    pdf = stats.norm.pdf(mu_range, current_mu, current_sigma)
    ax.plot(mu_range, pdf, linewidth=3, label=f'Prior: N({current_mu:.2f}, {current_sigma:.2f}²)',
            color='blue')

    # Progressively darker reds for later posteriors.
    colors = plt.cm.Reds(np.linspace(0.3, 0.9, n_updates))

    # Sequential updates
    for i, x in enumerate(data_points):
        # Conjugate update with a single observation (n = 1):
        # μ' = (σ²μ + τ²x) / (σ² + τ²),  τ'² = σ²τ² / (σ² + τ²)
        # where τ is the current (prior) std and σ the known data std.
        denom = data_sigma**2 + current_sigma**2
        new_mu = (data_sigma**2 * current_mu + current_sigma**2 * x) / denom
        new_sigma = np.sqrt((data_sigma**2 * current_sigma**2) / denom)

        pdf = stats.norm.pdf(mu_range, new_mu, new_sigma)
        ax.plot(mu_range, pdf, linewidth=2.5,
                label=f'After x_{i+1}={x:.2f}: N({new_mu:.2f}, {new_sigma:.2f}²)',
                color=colors[i])

        # The posterior becomes the prior for the next observation.
        current_mu = new_mu
        current_sigma = new_sigma

        print(f"\nUpdate {i+1}: observed x = {x:.4f}")
        print(f"  Posterior: N({new_mu:.4f}, {new_sigma:.4f}²)")

    # Mark true value
    ax.axvline(true_mu, color='black', linestyle='--', linewidth=2,
               label=f'True μ = {true_mu}')

    ax.set_xlabel('μ', fontsize=13)
    ax.set_ylabel('Probability Density', fontsize=13)
    ax.set_title('Sequential Bayesian Update', fontsize=14, fontweight='bold')
    ax.legend(fontsize=9, loc='upper left')
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig('bayesian_update.png', dpi=150, bbox_inches='tight')
    # pyplot keeps every figure it creates alive until it is explicitly closed.
    plt.close(fig)
    print("\nSaved Bayesian update plot to 'bayesian_update.png'")
    print(f"\nFinal posterior: N({current_mu:.4f}, {current_sigma:.4f}²)")
    print(f"Distance from true mean: {abs(current_mu - true_mu):.4f}")
347
348
if __name__ == "__main__":
    banner = "=" * 60

    print(banner)
    print("Probability Distributions and Statistical Inference")
    print(banner)

    # Run each demonstration in order; each one prints its own section and
    # writes its figure to a PNG in the working directory.
    for demo in (
        plot_common_distributions,
        mle_gaussian_demo,
        map_estimation_demo,
        bayesian_update_demo,
    ):
        demo()

    takeaways = (
        "1. Gaussian: Most common distribution, Central Limit Theorem",
        "2. MLE: Find parameters that maximize likelihood of observed data",
        "3. MAP: Incorporates prior knowledge, balances prior and likelihood",
        "4. Bayesian update: Posterior from step n becomes prior for step n+1",
        "5. More data → posterior concentrates around true value",
    )

    print("\n" + banner)
    print("Key Takeaways:")
    print(banner)
    for takeaway in takeaways:
        print(takeaway)