1"""
2Probability Distributions and Statistical Inference
3
4This script demonstrates:
51. Common probability distributions (Gaussian, Bernoulli, Poisson, Exponential)
62. Maximum Likelihood Estimation (MLE) for Gaussian parameters
73. Maximum A Posteriori (MAP) estimation with Gaussian prior
84. Bayesian update visualization
9
10Author: Math for AI Examples
11"""
12
13import numpy as np
14import matplotlib.pyplot as plt
15from scipy import stats
16from typing import Tuple
17
18
def plot_common_distributions():
    """Visualize common probability distributions used in ML.

    Draws a 2x2 panel of Gaussian, Bernoulli, Poisson and Exponential
    distributions, prints each formula with typical ML use cases, and
    saves the figure to 'probability_distributions.png'.
    """
    print("\n" + "=" * 60)
    print("1. Common Probability Distributions")
    print("=" * 60)

    fig, panels = plt.subplots(2, 2, figsize=(14, 10))

    # --- Gaussian (Normal): bell curves for several (μ, σ) pairs ---
    panel = panels[0, 0]
    grid = np.linspace(-5, 5, 1000)
    for mu, sigma in ((0, 0.5), (0, 1), (0, 2), (2, 1)):
        panel.plot(grid, stats.norm.pdf(grid, mu, sigma),
                   linewidth=2, label=f'μ={mu}, σ={sigma}')
    panel.set_xlabel('x', fontsize=11)
    panel.set_ylabel('Probability Density', fontsize=11)
    panel.set_title('Gaussian Distribution', fontsize=13, fontweight='bold')
    panel.legend(fontsize=9)
    panel.grid(True, alpha=0.3)

    print("Gaussian: f(x|μ,σ²) = (1/√(2πσ²)) exp(-(x-μ)²/(2σ²))")
    print(" Use: Continuous data, errors, neural network activations")

    # --- Bernoulli: two-point mass for several success probabilities ---
    panel = panels[0, 1]
    outcomes = np.array([0, 1])
    for p in (0.2, 0.5, 0.8):
        # Offset each pair of bars slightly so the three p-values don't overlap.
        panel.bar(outcomes + (p - 0.5) * 0.15, np.array([1 - p, p]),
                  width=0.12, alpha=0.7, label=f'p={p}')
    panel.set_xlabel('x', fontsize=11)
    panel.set_ylabel('Probability Mass', fontsize=11)
    panel.set_title('Bernoulli Distribution', fontsize=13, fontweight='bold')
    panel.set_xticks([0, 1])
    panel.legend(fontsize=9)
    panel.grid(True, alpha=0.3, axis='y')

    print("\nBernoulli: P(X=1) = p, P(X=0) = 1-p")
    print(" Use: Binary classification, coin flips")

    # --- Poisson: counts per interval for several rates λ ---
    panel = panels[1, 0]
    counts = np.arange(0, 20)
    for lambda_ in (1, 4, 10):
        panel.plot(counts, stats.poisson.pmf(counts, lambda_), 'o-',
                   linewidth=2, markersize=5, label=f'λ={lambda_}')
    panel.set_xlabel('k (count)', fontsize=11)
    panel.set_ylabel('Probability Mass', fontsize=11)
    panel.set_title('Poisson Distribution', fontsize=13, fontweight='bold')
    panel.legend(fontsize=9)
    panel.grid(True, alpha=0.3)

    print("\nPoisson: P(X=k) = (λ^k * e^(-λ)) / k!")
    print(" Use: Count data, rare events, arrivals per time period")

    # --- Exponential: waiting-time densities for several rates λ ---
    panel = panels[1, 1]
    times = np.linspace(0, 5, 1000)
    for lambda_ in (0.5, 1, 2):
        # scipy parameterizes the exponential by scale = 1/λ, not the rate.
        panel.plot(times, stats.expon.pdf(times, scale=1/lambda_),
                   linewidth=2, label=f'λ={lambda_}')
    panel.set_xlabel('x', fontsize=11)
    panel.set_ylabel('Probability Density', fontsize=11)
    panel.set_title('Exponential Distribution', fontsize=13, fontweight='bold')
    panel.legend(fontsize=9)
    panel.grid(True, alpha=0.3)

    print("\nExponential: f(x|λ) = λ * e^(-λx) for x >= 0")
    print(" Use: Time between events, survival analysis, waiting times")

    plt.tight_layout()
    plt.savefig('probability_distributions.png', dpi=150, bbox_inches='tight')
    print("\nSaved distributions plot to 'probability_distributions.png'")
99
100
def mle_gaussian_demo():
    """
    Maximum Likelihood Estimation for Gaussian parameters.

    Given data X = {x₁, ..., xₙ} from N(μ, σ²), find MLE for μ and σ².

    Likelihood: L(μ,σ²) = ∏ᵢ (1/√(2πσ²)) exp(-(xᵢ-μ)²/(2σ²))
    Log-likelihood: ℓ(μ,σ²) = -n/2 log(2π) - n/2 log(σ²) - Σ(xᵢ-μ)²/(2σ²)

    MLE solutions:
        μ̂ = (1/n) Σxᵢ (sample mean)
        σ̂² = (1/n) Σ(xᵢ-μ̂)² (sample variance)

    Side effects: prints the true vs. estimated parameters and saves a
    two-panel figure (data fit + log-likelihood contours) to 'mle_gaussian.png'.
    """
    print("\n" + "="*60)
    print("2. Maximum Likelihood Estimation (MLE)")
    print("="*60)

    # True parameters of the generating distribution
    true_mu = 2.0
    true_sigma = 1.5

    # Generate sample data (seeded for reproducibility)
    np.random.seed(42)
    n_samples = 100
    data = np.random.normal(true_mu, true_sigma, n_samples)

    # MLE estimates: sample mean and (biased, 1/n) sample std
    mle_mu = np.mean(data)
    mle_sigma = np.std(data, ddof=0)  # ddof=0 for MLE (biased estimator)

    print(f"\nTrue parameters: μ = {true_mu}, σ = {true_sigma}")
    print(f"MLE estimates: μ̂ = {mle_mu:.4f}, σ̂ = {mle_sigma:.4f}")
    print(f"Sample size: n = {n_samples}")

    # Parameter grid for the log-likelihood surface
    mu_range = np.linspace(0, 4, 100)
    sigma_range = np.linspace(0.5, 3, 100)
    MU, SIGMA = np.meshgrid(mu_range, sigma_range)

    # Log-likelihood over the whole grid, vectorized:
    # Σ(xᵢ-μ)² expands to Σxᵢ² - 2μΣxᵢ + nμ², so the surface can be
    # evaluated with NumPy broadcasting instead of a 100x100 Python loop.
    n = len(data)
    sum_x = data.sum()
    sum_x2 = np.sum(data**2)
    sq_dev = sum_x2 - 2 * MU * sum_x + n * MU**2  # Σ(xᵢ-μ)² at each grid μ
    LL = -n/2 * np.log(2*np.pi) - n * np.log(SIGMA) - sq_dev / (2 * SIGMA**2)

    # Visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Plot 1: Data histogram with MLE fit
    ax1.hist(data, bins=20, density=True, alpha=0.6, color='skyblue',
             edgecolor='black', label='Data')

    x_plot = np.linspace(data.min(), data.max(), 200)
    ax1.plot(x_plot, stats.norm.pdf(x_plot, mle_mu, mle_sigma),
             'r-', linewidth=3, label=f'MLE Fit: N({mle_mu:.2f}, {mle_sigma:.2f}²)')
    ax1.plot(x_plot, stats.norm.pdf(x_plot, true_mu, true_sigma),
             'g--', linewidth=2, label=f'True: N({true_mu}, {true_sigma}²)')

    ax1.set_xlabel('x', fontsize=12)
    ax1.set_ylabel('Density', fontsize=12)
    ax1.set_title('MLE Fit to Data', fontsize=13, fontweight='bold')
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    # Plot 2: Log-likelihood contour with MLE vs. true parameters marked
    contour = ax2.contour(MU, SIGMA, LL, levels=20, cmap='viridis')
    ax2.clabel(contour, inline=True, fontsize=8)

    ax2.plot(mle_mu, mle_sigma, 'r*', markersize=20, label='MLE Estimate')
    ax2.plot(true_mu, true_sigma, 'go', markersize=12, label='True Parameters')

    ax2.set_xlabel('μ', fontsize=12)
    ax2.set_ylabel('σ', fontsize=12)
    ax2.set_title('Log-Likelihood Surface', fontsize=13, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('mle_gaussian.png', dpi=150, bbox_inches='tight')
    print("\nSaved MLE visualization to 'mle_gaussian.png'")
188
189
def map_estimation_demo():
    """
    Maximum A Posteriori (MAP) estimation with a Gaussian prior on μ.

    With prior μ ~ N(μ₀, σ₀²) and likelihood x | μ ~ N(μ, σ²), the
    posterior is again Gaussian, μ | x ~ N(μₙ, σₙ²), where

        μₙ = (σ²μ₀ + nσ₀²x̄) / (σ² + nσ₀²)
        σₙ² = (σ²σ₀²) / (σ² + nσ₀²)

    Plots prior/likelihood/posterior for growing sample sizes and saves
    the figure to 'map_estimation.png'.
    """
    print("\n" + "=" * 60)
    print("3. Maximum A Posteriori (MAP) Estimation")
    print("=" * 60)

    true_mu = 5.0       # mean that actually generates the data
    data_sigma = 1.0    # known observation noise

    prior_mu = 3.0      # prior deliberately offset from true_mu
    prior_sigma = 2.0

    print(f"\nPrior belief: μ ~ N({prior_mu}, {prior_sigma}²)")
    print(f"Data distribution: X ~ N(μ, {data_sigma}²)")
    print(f"True mean: μ = {true_mu}")

    # Increasing amounts of data (seeded for reproducibility)
    np.random.seed(42)
    sample_sizes = [1, 5, 20, 100]

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    for panel, n in zip(axes.flatten(), sample_sizes):
        data = np.random.normal(true_mu, data_sigma, n)
        data_mean = np.mean(data)

        # For a Gaussian mean the MLE is simply the sample average.
        mle_mu = data_mean

        # Conjugate-Gaussian posterior: a precision-weighted compromise
        # between the prior mean and the sample mean.
        denom = data_sigma**2 + n * prior_sigma**2
        posterior_mu = (data_sigma**2 * prior_mu + n * prior_sigma**2 * data_mean) / denom
        posterior_sigma = np.sqrt((data_sigma**2 * prior_sigma**2) / denom)

        print(f"\nn = {n} samples:")
        print(f" Sample mean: {data_mean:.4f}")
        print(f" MLE: μ̂ = {mle_mu:.4f}")
        print(f" MAP: μ̂ = {posterior_mu:.4f}")
        print(f" Posterior: N({posterior_mu:.4f}, {posterior_sigma:.4f}²)")

        # Plot prior, likelihood, and posterior over a μ grid
        mu_grid = np.linspace(-2, 10, 500)

        prior_curve = stats.norm.pdf(mu_grid, prior_mu, prior_sigma)
        panel.plot(mu_grid, prior_curve, 'b--', linewidth=2, label='Prior')

        # Likelihood of μ given the data: Gaussian around x̄ with std σ/√n
        like_curve = stats.norm.pdf(mu_grid, data_mean, data_sigma / np.sqrt(n))
        like_curve = like_curve / like_curve.max() * prior_curve.max()  # Scale for visualization
        panel.plot(mu_grid, like_curve, 'g:', linewidth=2, label='Likelihood (scaled)')

        panel.plot(mu_grid, stats.norm.pdf(mu_grid, posterior_mu, posterior_sigma),
                   'r-', linewidth=2, label='Posterior')

        # Mark the point estimates and the true mean
        panel.axvline(prior_mu, color='b', linestyle='--', alpha=0.5, label='Prior mean')
        panel.axvline(mle_mu, color='g', linestyle=':', alpha=0.5, label='MLE')
        panel.axvline(posterior_mu, color='r', linestyle='-', alpha=0.5, label='MAP')
        panel.axvline(true_mu, color='k', linestyle='-', linewidth=2, label='True μ')

        panel.set_xlabel('μ', fontsize=11)
        panel.set_ylabel('Density', fontsize=11)
        panel.set_title(f'n = {n} samples', fontsize=12, fontweight='bold')
        panel.legend(fontsize=8, loc='upper left')
        panel.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('map_estimation.png', dpi=150, bbox_inches='tight')
    print("\nSaved MAP estimation plot to 'map_estimation.png'")
279
280
def bayesian_update_demo():
    """Visualize sequential Bayesian updating of a Gaussian belief.

    Starting from a broad prior on μ, folds in one observation at a time;
    each posterior becomes the prior for the next step. Saves the plot to
    'bayesian_update.png'.
    """
    print("\n" + "=" * 60)
    print("4. Sequential Bayesian Update")
    print("=" * 60)

    true_mu = 5.0       # mean generating the observations
    data_sigma = 1.0    # known observation noise
    prior_mu = 2.0      # initial (deliberately wrong) belief about μ
    prior_sigma = 3.0   # broad prior = weak initial belief

    np.random.seed(42)
    n_updates = 5
    data_points = np.random.normal(true_mu, data_sigma, n_updates)

    print(f"Prior: μ ~ N({prior_mu}, {prior_sigma}²)")
    print(f"True mean: {true_mu}")
    print(f"\nSequential observations: {data_points}")

    fig, ax = plt.subplots(figsize=(12, 7))
    mu_range = np.linspace(-5, 12, 500)

    def conjugate_step(mean, sigma, obs):
        """One Gaussian-conjugate update of belief (mean, sigma) given obs."""
        denom = data_sigma**2 + sigma**2
        upd_mean = (data_sigma**2 * mean + sigma**2 * obs) / denom
        upd_sigma = np.sqrt((data_sigma**2 * sigma**2) / denom)
        return upd_mean, upd_sigma

    # Draw the prior belief first
    current_mu, current_sigma = prior_mu, prior_sigma
    ax.plot(mu_range, stats.norm.pdf(mu_range, current_mu, current_sigma),
            linewidth=3, label=f'Prior: N({current_mu:.2f}, {current_sigma:.2f}²)',
            color='blue')

    # Progressively darker reds for later updates
    colors = plt.cm.Reds(np.linspace(0.3, 0.9, n_updates))

    for i, x in enumerate(data_points):
        # Posterior from this observation becomes the prior for the next
        new_mu, new_sigma = conjugate_step(current_mu, current_sigma, x)

        ax.plot(mu_range, stats.norm.pdf(mu_range, new_mu, new_sigma),
                linewidth=2.5,
                label=f'After x_{i+1}={x:.2f}: N({new_mu:.2f}, {new_sigma:.2f}²)',
                color=colors[i])

        current_mu, current_sigma = new_mu, new_sigma

        print(f"\nUpdate {i+1}: observed x = {x:.4f}")
        print(f" Posterior: N({new_mu:.4f}, {new_sigma:.4f}²)")

    # Mark the true value for reference
    ax.axvline(true_mu, color='black', linestyle='--', linewidth=2,
               label=f'True μ = {true_mu}')

    ax.set_xlabel('μ', fontsize=13)
    ax.set_ylabel('Probability Density', fontsize=13)
    ax.set_title('Sequential Bayesian Update', fontsize=14, fontweight='bold')
    ax.legend(fontsize=9, loc='upper left')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('bayesian_update.png', dpi=150, bbox_inches='tight')
    print("\nSaved Bayesian update plot to 'bayesian_update.png'")
    print(f"\nFinal posterior: N({current_mu:.4f}, {current_sigma:.4f}²)")
    print(f"Distance from true mean: {abs(current_mu - true_mu):.4f}")
347
348
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("Probability Distributions and Statistical Inference")
    print(banner)

    # Run every demonstration in order
    for demo in (plot_common_distributions, mle_gaussian_demo,
                 map_estimation_demo, bayesian_update_demo):
        demo()

    print("\n" + banner)
    print("Key Takeaways:")
    print(banner)
    for takeaway in (
        "1. Gaussian: Most common distribution, Central Limit Theorem",
        "2. MLE: Find parameters that maximize likelihood of observed data",
        "3. MAP: Incorporates prior knowledge, balances prior and likelihood",
        "4. Bayesian update: Posterior from step n becomes prior for step n+1",
        "5. More data → posterior concentrates around true value",
    ):
        print(takeaway)