1"""
2Information Theory for Machine Learning
3
4This script demonstrates key information theory concepts:
51. Entropy - measure of uncertainty
62. Cross-entropy and KL divergence
73. Mutual information
84. Connection to ML loss functions
95. ELBO (Evidence Lower Bound) visualization
10
11Author: Math for AI Examples
12"""
13
14import numpy as np
15import matplotlib.pyplot as plt
16from scipy import stats
17from typing import List, Tuple
18
19
def entropy_demo():
    """
    Demonstrate entropy as a measure of uncertainty.

    Entropy: H(X) = -Σ p(x) log p(x)

    Properties:
    - H(X) >= 0 (non-negative)
    - H(X) = 0 when X is deterministic
    - H(X) is maximum when X is uniform

    Prints entropy values for several distributions and saves a
    two-panel visualization to 'entropy.png'.
    """
    print("\n" + "="*60)
    print("1. Entropy - Measure of Uncertainty")
    print("="*60)

    def compute_entropy(probs: np.ndarray) -> float:
        """Compute Shannon entropy in bits (base-2 logarithm)."""
        # Drop zero entries: lim p->0 of p*log p is 0, and log2(0) would
        # emit a warning / produce -inf.
        probs = probs[probs > 0]
        return float(-np.sum(probs * np.log2(probs)))

    # Example 1: Binary random variable, H(p) = -p log p - (1-p) log(1-p)
    print("\n--- Binary Random Variable ---")
    p_values = np.linspace(0.01, 0.99, 100)
    entropies = [compute_entropy(np.array([p, 1 - p])) for p in p_values]

    print(f"Entropy when p=0.5 (max): {max(entropies):.4f} bits")
    print("Entropy when p→0 or p→1 (min): ~0 bits")

    # Example 2: Different distributions over the same 8 outcomes
    print("\n--- Comparing Different Distributions ---")

    # Uniform distribution (maximum entropy: log2(n) bits)
    n = 8
    uniform = np.ones(n) / n
    H_uniform = compute_entropy(uniform)
    print(f"Uniform distribution (n={n}): H = {H_uniform:.4f} bits")

    # Peaked distribution (low entropy)
    peaked = np.array([0.7, 0.1, 0.05, 0.05, 0.03, 0.03, 0.02, 0.02])
    H_peaked = compute_entropy(peaked)
    print(f"Peaked distribution: H = {H_peaked:.4f} bits")

    # Deterministic distribution (zero entropy)
    deterministic = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
    H_deterministic = compute_entropy(deterministic)
    print(f"Deterministic: H = {H_deterministic:.4f} bits")

    # Visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Plot 1: Binary entropy curve, maximized at p = 0.5
    ax1.plot(p_values, entropies, linewidth=3, color='blue')
    ax1.axvline(0.5, color='red', linestyle='--', linewidth=2, label='Maximum at p=0.5')
    ax1.set_xlabel('p (probability of outcome 1)', fontsize=12)
    ax1.set_ylabel('Entropy (bits)', fontsize=12)
    ax1.set_title('Binary Entropy Function', fontsize=13, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.legend(fontsize=10)

    # Plot 2: Side-by-side bars for the three 8-outcome distributions
    x = np.arange(n)
    width = 0.25

    ax2.bar(x - width, uniform, width, label=f'Uniform (H={H_uniform:.2f})',
            alpha=0.7, color='blue')
    ax2.bar(x, peaked, width, label=f'Peaked (H={H_peaked:.2f})',
            alpha=0.7, color='orange')
    ax2.bar(x + width, deterministic, width, label=f'Deterministic (H={H_deterministic:.2f})',
            alpha=0.7, color='green')

    ax2.set_xlabel('Outcome', fontsize=12)
    ax2.set_ylabel('Probability', fontsize=12)
    ax2.set_title('Entropy of Different Distributions', fontsize=13, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('entropy.png', dpi=150, bbox_inches='tight')
    # Close the figure so repeated calls / the full script don't accumulate
    # open figures (the script never calls plt.show()).
    plt.close(fig)
    print("\nSaved entropy visualization to 'entropy.png'")
105
106
def cross_entropy_kl_divergence():
    """
    Demonstrate cross-entropy and KL divergence.

    Cross-Entropy: H(P,Q) = -Σ p(x) log q(x)
    KL Divergence: D_KL(P||Q) = Σ p(x) log(p(x)/q(x)) = H(P,Q) - H(P)

    KL divergence measures how much Q differs from P (in nats here,
    since natural log is used throughout this function).

    Saves a two-panel visualization to 'cross_entropy_kl.png'.
    """
    print("\n" + "="*60)
    print("2. Cross-Entropy and KL Divergence")
    print("="*60)

    def cross_entropy(p: np.ndarray, q: np.ndarray) -> float:
        """Compute cross-entropy H(P,Q) in nats; epsilon guards log(0)."""
        return float(-np.sum(p * np.log(q + 1e-10)))

    def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
        """Compute KL divergence D_KL(P||Q) in nats.

        Zero-probability entries of P contribute 0 by convention
        (lim p->0 of p*log(p/q) = 0), so they are masked out instead of
        being perturbed with an epsilon inside the log.
        """
        terms = np.where(p > 0, p * np.log(p / (q + 1e-10)), 0.0)
        return float(np.sum(terms))

    # True distribution P
    p = np.array([0.1, 0.2, 0.3, 0.25, 0.1, 0.05])

    # Various approximate distributions Q, increasingly far from P
    q1 = np.array([0.15, 0.2, 0.25, 0.25, 0.1, 0.05])  # Close to P
    q2 = np.array([0.2, 0.2, 0.2, 0.2, 0.1, 0.1])      # Moderately different
    q3 = np.array([0.05, 0.05, 0.1, 0.1, 0.3, 0.4])    # Very different
    q_uniform = np.ones(6) / 6                          # Uniform

    distributions = [
        ('P (true)', p),
        ('Q1 (close)', q1),
        ('Q2 (moderate)', q2),
        ('Q3 (far)', q3),
        ('Uniform', q_uniform)
    ]

    # Entropy of P in nats (p has no zero entries, so log is safe)
    H_p = -np.sum(p * np.log(p))

    print(f"\nTrue distribution P: {p}")
    print(f"Entropy H(P) = {H_p:.4f}")
    print("\n" + "-"*60)

    results = []
    for name, q in distributions[1:]:  # Skip P itself
        ce = cross_entropy(p, q)
        kl = kl_divergence(p, q)
        results.append((name, q, ce, kl))

        print(f"\n{name}: {q}")
        print(f"  Cross-Entropy H(P,Q) = {ce:.4f}")
        print(f"  KL Divergence D_KL(P||Q) = {kl:.4f}")
        print(f"  Relationship: H(P,Q) = H(P) + D_KL(P||Q)")
        print(f"  {ce:.4f} = {H_p:.4f} + {kl:.4f}")

    # Visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Plot 1: P alongside each approximation Q
    x = np.arange(len(p))
    width = 0.15

    ax1.bar(x - 2*width, p, width, label='P (true)', alpha=0.8, color='blue')
    for i, (name, q, _, _) in enumerate(results):
        ax1.bar(x + (i-1)*width, q, width, label=name, alpha=0.7)

    ax1.set_xlabel('Outcome', fontsize=12)
    ax1.set_ylabel('Probability', fontsize=12)
    ax1.set_title('True vs Approximate Distributions', fontsize=13, fontweight='bold')
    ax1.legend(fontsize=9)
    ax1.grid(True, alpha=0.3, axis='y')

    # Plot 2: H(P,Q) and D_KL(P||Q) per approximation; dashed line is the
    # floor H(P), since H(P,Q) = H(P) + D_KL(P||Q) >= H(P).
    names = [r[0] for r in results]
    ces = [r[2] for r in results]
    kls = [r[3] for r in results]

    x_pos = np.arange(len(names))
    ax2.bar(x_pos - 0.2, ces, 0.4, label='Cross-Entropy H(P,Q)', alpha=0.8, color='orange')
    ax2.bar(x_pos + 0.2, kls, 0.4, label='KL Divergence D_KL(P||Q)', alpha=0.8, color='red')
    ax2.axhline(H_p, color='blue', linestyle='--', linewidth=2, label='H(P)')

    ax2.set_xticks(x_pos)
    ax2.set_xticklabels(names, rotation=15, ha='right')
    ax2.set_ylabel('Value (nats)', fontsize=12)
    ax2.set_title('Cross-Entropy and KL Divergence', fontsize=13, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('cross_entropy_kl.png', dpi=150, bbox_inches='tight')
    plt.close(fig)  # avoid accumulating open figures across demos
    print("\nSaved cross-entropy/KL divergence plot to 'cross_entropy_kl.png'")
200
201
def mutual_information_demo():
    """
    Demonstrate mutual information.

    Mutual Information: I(X;Y) = H(X) + H(Y) - H(X,Y)
                               = H(X) - H(X|Y)
                               = D_KL(P(X,Y) || P(X)P(Y))

    Measures how much knowing Y reduces uncertainty about X (in nats).
    Saves a three-panel heatmap visualization to 'mutual_information.png'.
    """
    print("\n" + "="*60)
    print("3. Mutual Information")
    print("="*60)

    def entropy_joint(pxy: np.ndarray) -> float:
        """Compute entropy -Σ p log p over all entries (joint or marginal)."""
        return float(-np.sum(pxy * np.log(pxy + 1e-10)))

    def mutual_information(pxy: np.ndarray) -> float:
        """Compute mutual information I(X;Y) from a joint table pxy."""
        # Marginals: rows are X, columns are Y
        px = pxy.sum(axis=1)
        py = pxy.sum(axis=0)
        # Outer product gives the independence baseline P(X)P(Y)
        px_py = px[:, np.newaxis] * py[np.newaxis, :]

        # I(X;Y) = D_KL(P(X,Y) || P(X)P(Y))
        return float(np.sum(pxy * np.log((pxy + 1e-10) / (px_py + 1e-10))))

    # Example 1: Independent variables (I = 0)
    print("\n--- Independent Variables ---")
    px = np.array([0.5, 0.5])
    py = np.array([0.3, 0.7])
    pxy_indep = px[:, np.newaxis] * py[np.newaxis, :]

    mi_indep = mutual_information(pxy_indep)
    print("P(X,Y) = P(X)P(Y) (independent)")
    print(f"Mutual Information I(X;Y) = {mi_indep:.6f} ≈ 0")

    # Example 2: Perfectly correlated (I = H(X) = H(Y))
    print("\n--- Perfectly Correlated Variables ---")
    pxy_perfect = np.array([[0.3, 0.0], [0.0, 0.7]])

    mi_perfect = mutual_information(pxy_perfect)
    # Entropy of the marginal P(X) = [0.3, 0.7]; entropy_joint works on
    # any probability array, so pass the 1-D marginal directly.
    h_x = entropy_joint(pxy_perfect.sum(axis=1))
    print("X = Y (perfect correlation)")
    print(f"Mutual Information I(X;Y) = {mi_perfect:.4f}")
    print(f"H(X) = {h_x:.4f} (I(X;Y) = H(X) when perfectly correlated)")

    # Example 3: Partial dependence
    print("\n--- Partially Dependent Variables ---")
    pxy_partial = np.array([[0.25, 0.05], [0.15, 0.55]])

    mi_partial = mutual_information(pxy_partial)
    print(f"Mutual Information I(X;Y) = {mi_partial:.4f}")
    print("(Between 0 and H(X), indicating partial dependence)")

    # Visualization: heatmaps of the three joint distributions
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    scenarios = [
        ('Independent\nI(X;Y) ≈ 0', pxy_indep, mi_indep),
        ('Partially Dependent\nI(X;Y) = {:.3f}'.format(mi_partial), pxy_partial, mi_partial),
        ('Perfect Correlation\nI(X;Y) = {:.3f}'.format(mi_perfect), pxy_perfect, mi_perfect)
    ]

    for ax, (title, pxy, mi) in zip(axes, scenarios):
        im = ax.imshow(pxy, cmap='YlOrRd', aspect='auto', vmin=0, vmax=0.7)
        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        ax.set_xlabel('Y', fontsize=11)
        ax.set_ylabel('X', fontsize=11)
        ax.set_title(title, fontsize=11, fontweight='bold')

        # Annotate cells with probabilities; white text on dark cells
        for i in range(2):
            for j in range(2):
                ax.text(j, i, f'{pxy[i,j]:.2f}', ha='center', va='center',
                       color='white' if pxy[i,j] > 0.35 else 'black', fontsize=13)

        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.savefig('mutual_information.png', dpi=150, bbox_inches='tight')
    plt.close(fig)  # avoid accumulating open figures across demos
    print("\nSaved mutual information plot to 'mutual_information.png'")
285
286
def ml_loss_functions():
    """
    Connect information theory to ML loss functions.

    Binary Classification:
    - Cross-entropy loss = -[y log(ŷ) + (1-y) log(1-ŷ)]
    - Minimizing cross-entropy = Maximizing likelihood

    Multi-class Classification:
    - Cross-entropy loss = -Σ y_c log(ŷ_c)
    - Same as negative log-likelihood

    Prints example loss values and saves a plot of per-sample loss vs
    predicted probability to 'ml_loss_functions.png'.
    """
    print("\n" + "="*60)
    print("4. Connection to ML Loss Functions")
    print("="*60)

    # Binary classification example
    print("\n--- Binary Classification ---")

    y_true = np.array([1, 0, 1, 1, 0])            # True labels
    y_pred = np.array([0.9, 0.1, 0.8, 0.7, 0.2])  # Predicted probabilities

    # Binary cross-entropy loss, averaged over samples; the epsilon
    # guards against log(0) for saturated predictions.
    bce_loss = -np.mean(y_true * np.log(y_pred + 1e-10) +
                        (1 - y_true) * np.log(1 - y_pred + 1e-10))

    print(f"True labels: {y_true}")
    print(f"Predictions: {y_pred}")
    print(f"Binary Cross-Entropy Loss: {bce_loss:.4f}")

    # Multi-class classification example
    print("\n--- Multi-class Classification ---")

    # 3 samples, 4 classes (one-hot targets)
    y_true_mc = np.array([
        [1, 0, 0, 0],  # Class 0
        [0, 1, 0, 0],  # Class 1
        [0, 0, 1, 0]   # Class 2
    ])

    y_pred_mc = np.array([
        [0.7, 0.2, 0.05, 0.05],
        [0.1, 0.6, 0.2, 0.1],
        [0.1, 0.1, 0.7, 0.1]
    ])

    # Categorical cross-entropy: sum over classes, mean over samples.
    # With one-hot targets this reduces to -log(ŷ at the true class).
    cce_loss = -np.mean(np.sum(y_true_mc * np.log(y_pred_mc + 1e-10), axis=1))

    print(f"Categorical Cross-Entropy Loss: {cce_loss:.4f}")
    print("\nInterpretation:")
    print("  Lower loss = predictions closer to true distribution")
    print("  Minimizing cross-entropy = Maximizing likelihood")
    print("  Cross-entropy > entropy of true distribution")

    # Visualization: per-sample loss as a function of predicted probability
    probs = np.linspace(0.01, 0.99, 100)
    loss_correct = -np.log(probs)        # When prediction matches true label
    loss_incorrect = -np.log(1 - probs)  # When prediction doesn't match

    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(probs, loss_correct, linewidth=3, label='Correct prediction (y=1, ŷ=p)',
            color='green')
    ax.plot(probs, loss_incorrect, linewidth=3, label='Incorrect prediction (y=0, ŷ=p)',
            color='red')

    ax.set_xlabel('Predicted Probability (ŷ)', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title('Binary Cross-Entropy Loss', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 5)  # clip the asymptote near p→0 / p→1 for readability

    plt.tight_layout()
    plt.savefig('ml_loss_functions.png', dpi=150, bbox_inches='tight')
    plt.close(fig)  # avoid accumulating open figures across demos
    print("\nSaved ML loss functions plot to 'ml_loss_functions.png'")
364
365
def elbo_visualization():
    """
    Demonstrate ELBO (Evidence Lower Bound) for variational inference.

    ELBO: log p(x) >= E_q[log p(x,z)] - E_q[log q(z)]
                    = E_q[log p(x|z)] - D_KL(q(z)||p(z))

    Maximizing ELBO = Minimizing KL(q||p) while maximizing likelihood.

    Uses *simulated* training curves (exponential decay/growth plus noise)
    purely for illustration; no actual model is trained. Saves the plot to
    'elbo_visualization.png'.
    """
    print("\n" + "="*60)
    print("5. ELBO (Evidence Lower Bound)")
    print("="*60)

    print("\nVariational Inference Setup:")
    print("  True posterior: p(z|x) - intractable")
    print("  Approximate: q(z) ≈ p(z|x)")
    print("\nELBO Decomposition:")
    print("  log p(x) = ELBO + D_KL(q(z)||p(z|x))")
    print("  ELBO = E_q[log p(x|z)] - D_KL(q(z)||p(z))")
    print("\nMaximizing ELBO:")
    print("  1. Minimizes KL divergence to true posterior")
    print("  2. Maximizes expected log-likelihood")

    # Simulate ELBO components over a training run
    n_iterations = 200

    # Fixed seed so the "training curves" are reproducible
    np.random.seed(42)
    # Reconstruction loss decays exponentially toward zero with noise
    reconstruction_loss = 100 * np.exp(-np.linspace(0, 3, n_iterations))
    reconstruction_loss += np.random.randn(n_iterations) * 2

    # KL term grows toward an asymptote as q moves away from the prior
    kl_term = 50 * (1 - np.exp(-np.linspace(0, 2, n_iterations)))
    kl_term += np.random.randn(n_iterations) * 1.5

    # ELBO is the negative total loss (training minimizes loss = -ELBO)
    elbo = -(reconstruction_loss + kl_term)

    # Visualization
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    # Plot 1: the two loss components and their sum (negative ELBO)
    iterations = np.arange(n_iterations)

    ax1.plot(iterations, reconstruction_loss, linewidth=2, label='Reconstruction Loss (negative log p(x|z))',
             color='red')
    ax1.plot(iterations, kl_term, linewidth=2, label='KL Divergence D_KL(q||p)',
             color='blue')
    ax1.plot(iterations, reconstruction_loss + kl_term, linewidth=2.5,
             label='Total Loss (negative ELBO)', color='black', linestyle='--')

    ax1.set_xlabel('Iteration', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('ELBO Components During Training', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    # Plot 2: ELBO itself (the quantity being maximized)
    ax2.plot(iterations, elbo, linewidth=3, color='green')
    ax2.set_xlabel('Iteration', fontsize=12)
    ax2.set_ylabel('ELBO (to maximize)', fontsize=12)
    ax2.set_title('Evidence Lower Bound (ELBO)', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3)

    # Annotate the late-training region of the curve
    ax2.text(150, elbo[150], 'ELBO increases →\nbetter approximation',
             fontsize=11, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.savefig('elbo_visualization.png', dpi=150, bbox_inches='tight')
    plt.close(fig)  # avoid accumulating open figures across demos
    print("\nSaved ELBO visualization to 'elbo_visualization.png'")

    print("\nFinal values:")
    print(f"  Reconstruction Loss: {reconstruction_loss[-1]:.2f}")
    print(f"  KL Divergence: {kl_term[-1]:.2f}")
    print(f"  ELBO: {elbo[-1]:.2f}")
440
441
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("Information Theory for Machine Learning")
    print(banner)

    # Run every demonstration in order.
    demos = (
        entropy_demo,
        cross_entropy_kl_divergence,
        mutual_information_demo,
        ml_loss_functions,
        elbo_visualization,
    )
    for demo in demos:
        demo()

    print("\n" + banner)
    print("Key Takeaways:")
    print(banner)
    takeaways = (
        "1. Entropy: Measures uncertainty/information content",
        "2. Cross-Entropy: Used as loss function in classification",
        "3. KL Divergence: Measures difference between distributions",
        "4. Mutual Information: Measures dependence between variables",
        "5. ELBO: Fundamental to variational inference (VAE, etc.)",
        "6. Minimizing cross-entropy = Maximizing likelihood",
    )
    for line in takeaways:
        print(line)