# 08_information_theory.py
"""
Information Theory for Machine Learning

This script demonstrates key information theory concepts:
1. Entropy - measure of uncertainty
2. Cross-entropy and KL divergence
3. Mutual information
4. Connection to ML loss functions
5. ELBO (Evidence Lower Bound) visualization

Author: Math for AI Examples
"""
 13
 14import numpy as np
 15import matplotlib.pyplot as plt
 16from scipy import stats
 17from typing import List, Tuple
 18
 19
 20def entropy_demo():
 21    """
 22    Demonstrate entropy as a measure of uncertainty.
 23
 24    Entropy: H(X) = -Σ p(x) log p(x)
 25
 26    Properties:
 27    - H(X) >= 0 (non-negative)
 28    - H(X) = 0 when X is deterministic
 29    - H(X) is maximum when X is uniform
 30    """
 31    print("\n" + "="*60)
 32    print("1. Entropy - Measure of Uncertainty")
 33    print("="*60)
 34
 35    def compute_entropy(probs: np.ndarray) -> float:
 36        """Compute Shannon entropy (base 2)."""
 37        # Filter out zero probabilities to avoid log(0)
 38        probs = probs[probs > 0]
 39        return -np.sum(probs * np.log2(probs))
 40
 41    # Example 1: Binary random variable
 42    print("\n--- Binary Random Variable ---")
 43    p_values = np.linspace(0.01, 0.99, 100)
 44    entropies = []
 45
 46    for p in p_values:
 47        probs = np.array([p, 1-p])
 48        H = compute_entropy(probs)
 49        entropies.append(H)
 50
 51    print(f"Entropy when p=0.5 (max): {max(entropies):.4f} bits")
 52    print(f"Entropy when p→0 or p→1 (min): ~0 bits")
 53
 54    # Example 2: Different distributions
 55    print("\n--- Comparing Different Distributions ---")
 56
 57    # Uniform distribution (max entropy)
 58    n = 8
 59    uniform = np.ones(n) / n
 60    H_uniform = compute_entropy(uniform)
 61    print(f"Uniform distribution (n={n}): H = {H_uniform:.4f} bits")
 62
 63    # Peaked distribution (low entropy)
 64    peaked = np.array([0.7, 0.1, 0.05, 0.05, 0.03, 0.03, 0.02, 0.02])
 65    H_peaked = compute_entropy(peaked)
 66    print(f"Peaked distribution: H = {H_peaked:.4f} bits")
 67
 68    # Deterministic (zero entropy)
 69    deterministic = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
 70    H_deterministic = compute_entropy(deterministic)
 71    print(f"Deterministic: H = {H_deterministic:.4f} bits")
 72
 73    # Visualization
 74    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
 75
 76    # Plot 1: Binary entropy curve
 77    ax1.plot(p_values, entropies, linewidth=3, color='blue')
 78    ax1.axvline(0.5, color='red', linestyle='--', linewidth=2, label='Maximum at p=0.5')
 79    ax1.set_xlabel('p (probability of outcome 1)', fontsize=12)
 80    ax1.set_ylabel('Entropy (bits)', fontsize=12)
 81    ax1.set_title('Binary Entropy Function', fontsize=13, fontweight='bold')
 82    ax1.grid(True, alpha=0.3)
 83    ax1.legend(fontsize=10)
 84
 85    # Plot 2: Different distributions
 86    x = np.arange(n)
 87    width = 0.25
 88
 89    ax2.bar(x - width, uniform, width, label=f'Uniform (H={H_uniform:.2f})',
 90            alpha=0.7, color='blue')
 91    ax2.bar(x, peaked, width, label=f'Peaked (H={H_peaked:.2f})',
 92            alpha=0.7, color='orange')
 93    ax2.bar(x + width, deterministic, width, label=f'Deterministic (H={H_deterministic:.2f})',
 94            alpha=0.7, color='green')
 95
 96    ax2.set_xlabel('Outcome', fontsize=12)
 97    ax2.set_ylabel('Probability', fontsize=12)
 98    ax2.set_title('Entropy of Different Distributions', fontsize=13, fontweight='bold')
 99    ax2.legend(fontsize=10)
100    ax2.grid(True, alpha=0.3, axis='y')
101
102    plt.tight_layout()
103    plt.savefig('entropy.png', dpi=150, bbox_inches='tight')
104    print("\nSaved entropy visualization to 'entropy.png'")
105
106
def cross_entropy_kl_divergence():
    """
    Demonstrate cross-entropy and KL divergence.

    Cross-Entropy: H(P,Q) = -Σ p(x) log q(x)
    KL Divergence: D_KL(P||Q) = Σ p(x) log(p(x)/q(x)) = H(P,Q) - H(P)

    KL divergence measures how much Q differs from P.

    Side effects: prints a summary and saves 'cross_entropy_kl.png'.
    """
    print("\n" + "=" * 60)
    print("2. Cross-Entropy and KL Divergence")
    print("=" * 60)

    eps = 1e-10  # numerical guard against log(0)

    def cross_entropy(p: np.ndarray, q: np.ndarray) -> float:
        """Cross-entropy H(P,Q), in nats."""
        return -np.sum(p * np.log(q + eps))

    def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
        """KL divergence D_KL(P||Q), in nats."""
        return np.sum(p * np.log((p + eps) / (q + eps)))

    # True distribution P
    p = np.array([0.1, 0.2, 0.3, 0.25, 0.1, 0.05])

    # Various approximate distributions Q
    q1 = np.array([0.15, 0.2, 0.25, 0.25, 0.1, 0.05])  # Close to P
    q2 = np.array([0.2, 0.2, 0.2, 0.2, 0.1, 0.1])      # Moderately different
    q3 = np.array([0.05, 0.05, 0.1, 0.1, 0.3, 0.4])    # Very different
    q_uniform = np.ones(6) / 6                          # Uniform

    distributions = [
        ('P (true)', p),
        ('Q1 (close)', q1),
        ('Q2 (moderate)', q2),
        ('Q3 (far)', q3),
        ('Uniform', q_uniform)
    ]

    # Entropy of P itself (no epsilon needed: every entry of p is positive).
    H_p = -np.sum(p * np.log(p))

    print(f"\nTrue distribution P: {p}")
    print(f"Entropy H(P) = {H_p:.4f}")
    print("\n" + "-" * 60)

    # Score every approximation against P (index 0 is P itself — skipped).
    results = [(name, q, cross_entropy(p, q), kl_divergence(p, q))
               for name, q in distributions[1:]]

    for name, q, h_pq, d_kl in results:
        print(f"\n{name}: {q}")
        print(f"  Cross-Entropy H(P,Q) = {h_pq:.4f}")
        print(f"  KL Divergence D_KL(P||Q) = {d_kl:.4f}")
        print("  Relationship: H(P,Q) = H(P) + D_KL(P||Q)")
        print(f"              {h_pq:.4f} = {H_p:.4f} + {d_kl:.4f}")

    # Visualization
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Plot 1: P alongside each approximation Q as grouped bars.
    x = np.arange(len(p))
    width = 0.15

    ax1.bar(x - 2 * width, p, width, label='P (true)', alpha=0.8, color='blue')
    for idx, (name, q, _, _) in enumerate(results):
        ax1.bar(x + (idx - 1) * width, q, width, label=name, alpha=0.7)

    ax1.set_xlabel('Outcome', fontsize=12)
    ax1.set_ylabel('Probability', fontsize=12)
    ax1.set_title('True vs Approximate Distributions', fontsize=13, fontweight='bold')
    ax1.legend(fontsize=9)
    ax1.grid(True, alpha=0.3, axis='y')

    # Plot 2: H(P,Q) and D_KL(P||Q) per approximation; H(P) is the floor
    # that cross-entropy can never go below.
    names, _, ces, kls = zip(*results)

    x_pos = np.arange(len(names))
    ax2.bar(x_pos - 0.2, ces, 0.4, label='Cross-Entropy H(P,Q)', alpha=0.8, color='orange')
    ax2.bar(x_pos + 0.2, kls, 0.4, label='KL Divergence D_KL(P||Q)', alpha=0.8, color='red')
    ax2.axhline(H_p, color='blue', linestyle='--', linewidth=2, label='H(P)')

    ax2.set_xticks(x_pos)
    ax2.set_xticklabels(names, rotation=15, ha='right')
    ax2.set_ylabel('Value (nats)', fontsize=12)
    ax2.set_title('Cross-Entropy and KL Divergence', fontsize=13, fontweight='bold')
    ax2.legend(fontsize=10)
    ax2.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('cross_entropy_kl.png', dpi=150, bbox_inches='tight')
    print("\nSaved cross-entropy/KL divergence plot to 'cross_entropy_kl.png'")
200
201
def mutual_information_demo():
    """
    Demonstrate mutual information.

    Mutual Information: I(X;Y) = H(X) + H(Y) - H(X,Y)
                                = H(X) - H(X|Y)
                                = D_KL(P(X,Y) || P(X)P(Y))

    Measures how much knowing Y reduces uncertainty about X.

    Side effects: prints a summary and saves 'mutual_information.png'.
    """
    print("\n" + "=" * 60)
    print("3. Mutual Information")
    print("=" * 60)

    eps = 1e-10  # numerical guard against log(0) on zero-probability cells

    def entropy_joint(pxy: np.ndarray) -> float:
        """Joint entropy H(X,Y), in nats."""
        return -np.sum(pxy * np.log(pxy + eps))

    def mutual_information(pxy: np.ndarray) -> float:
        """Mutual information I(X;Y) = D_KL(P(X,Y) || P(X)P(Y)), in nats."""
        # Marginals from the joint table.
        px = pxy.sum(axis=1)
        py = pxy.sum(axis=0)
        # Outer product builds the product-of-marginals table P(X)P(Y).
        independent = np.outer(px, py)
        return np.sum(pxy * np.log((pxy + eps) / (independent + eps)))

    # Example 1: Independent variables (I = 0)
    print("\n--- Independent Variables ---")
    px = np.array([0.5, 0.5])
    py = np.array([0.3, 0.7])
    pxy_indep = np.outer(px, py)

    mi_indep = mutual_information(pxy_indep)
    print("P(X,Y) = P(X)P(Y)  (independent)")
    print(f"Mutual Information I(X;Y) = {mi_indep:.6f} ≈ 0")

    # Example 2: Perfectly correlated (I = H(X) = H(Y))
    print("\n--- Perfectly Correlated Variables ---")
    pxy_perfect = np.array([[0.3, 0.0], [0.0, 0.7]])

    mi_perfect = mutual_information(pxy_perfect)
    # Marginal entropy H(X): entropy of the row sums of the joint table.
    h_x = entropy_joint(pxy_perfect.sum(axis=1))
    print("X = Y  (perfect correlation)")
    print(f"Mutual Information I(X;Y) = {mi_perfect:.4f}")
    print(f"H(X) = {h_x:.4f}  (I(X;Y) = H(X) when perfectly correlated)")

    # Example 3: Partial dependence
    print("\n--- Partially Dependent Variables ---")
    pxy_partial = np.array([[0.25, 0.05], [0.15, 0.55]])

    mi_partial = mutual_information(pxy_partial)
    print(f"Mutual Information I(X;Y) = {mi_partial:.4f}")
    print("(Between 0 and H(X), indicating partial dependence)")

    # Visualization: one heatmap of the joint table per scenario.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    scenarios = [
        ('Independent\nI(X;Y) ≈ 0', pxy_indep, mi_indep),
        (f'Partially Dependent\nI(X;Y) = {mi_partial:.3f}', pxy_partial, mi_partial),
        (f'Perfect Correlation\nI(X;Y) = {mi_perfect:.3f}', pxy_perfect, mi_perfect)
    ]

    for ax, (title, pxy, _) in zip(axes, scenarios):
        im = ax.imshow(pxy, cmap='YlOrRd', aspect='auto', vmin=0, vmax=0.7)
        ax.set_xticks([0, 1])
        ax.set_yticks([0, 1])
        ax.set_xlabel('Y', fontsize=11)
        ax.set_ylabel('X', fontsize=11)
        ax.set_title(title, fontsize=11, fontweight='bold')

        # Annotate every cell with its probability; white text on dark cells.
        for row, col in np.ndindex(2, 2):
            value = pxy[row, col]
            ax.text(col, row, f'{value:.2f}', ha='center', va='center',
                    color='white' if value > 0.35 else 'black', fontsize=13)

        plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.savefig('mutual_information.png', dpi=150, bbox_inches='tight')
    print("\nSaved mutual information plot to 'mutual_information.png'")
285
286
def ml_loss_functions():
    """
    Connect information theory to ML loss functions.

    Binary Classification:
    - Cross-entropy loss = -[y log(ŷ) + (1-y) log(1-ŷ)]
    - Minimizing cross-entropy = Maximizing likelihood

    Multi-class Classification:
    - Cross-entropy loss = -Σ y_c log(ŷ_c)
    - Same as negative log-likelihood

    Side effects: prints a summary and saves 'ml_loss_functions.png'.
    """
    print("\n" + "=" * 60)
    print("4. Connection to ML Loss Functions")
    print("=" * 60)

    eps = 1e-10  # numerical guard against log(0)

    # Binary classification example
    print("\n--- Binary Classification ---")

    y_true = np.array([1, 0, 1, 1, 0])  # True labels
    y_pred = np.array([0.9, 0.1, 0.8, 0.7, 0.2])  # Predicted probabilities

    # Per-sample log-likelihood of the observed label, averaged and negated.
    log_likelihood = (y_true * np.log(y_pred + eps)
                      + (1 - y_true) * np.log(1 - y_pred + eps))
    bce_loss = -log_likelihood.mean()

    print(f"True labels: {y_true}")
    print(f"Predictions: {y_pred}")
    print(f"Binary Cross-Entropy Loss: {bce_loss:.4f}")

    # Multi-class classification example
    print("\n--- Multi-class Classification ---")

    # 3 samples, 4 classes, targets one-hot encoded.
    y_true_mc = np.array([
        [1, 0, 0, 0],  # Class 0
        [0, 1, 0, 0],  # Class 1
        [0, 0, 1, 0]   # Class 2
    ])

    y_pred_mc = np.array([
        [0.7, 0.2, 0.05, 0.05],
        [0.1, 0.6, 0.2, 0.1],
        [0.1, 0.1, 0.7, 0.1]
    ])

    # One-hot targets select the log-probability of the true class per row.
    per_example_ll = np.sum(y_true_mc * np.log(y_pred_mc + eps), axis=1)
    cce_loss = -per_example_ll.mean()

    print(f"Categorical Cross-Entropy Loss: {cce_loss:.4f}")
    print("\nInterpretation:")
    print("  Lower loss = predictions closer to true distribution")
    print("  Minimizing cross-entropy = Maximizing likelihood")
    print("  Cross-entropy > entropy of true distribution")

    # Visualization: how prediction confidence drives the per-sample loss.
    p_grid = np.linspace(0.01, 0.99, 100)
    loss_when_right = -np.log(p_grid)       # true label is 1, predicting p
    loss_when_wrong = -np.log(1 - p_grid)   # true label is 0, predicting p

    fig, ax = plt.subplots(figsize=(10, 6))

    ax.plot(p_grid, loss_when_right, linewidth=3, label='Correct prediction (y=1, ŷ=p)',
            color='green')
    ax.plot(p_grid, loss_when_wrong, linewidth=3, label='Incorrect prediction (y=0, ŷ=p)',
            color='red')

    ax.set_xlabel('Predicted Probability (ŷ)', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    ax.set_title('Binary Cross-Entropy Loss', fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 5)

    plt.tight_layout()
    plt.savefig('ml_loss_functions.png', dpi=150, bbox_inches='tight')
    print("\nSaved ML loss functions plot to 'ml_loss_functions.png'")
364
365
def elbo_visualization():
    """
    Demonstrate ELBO (Evidence Lower Bound) for variational inference.

    ELBO: log p(x) >= E_q[log p(x,z)] - E_q[log q(z)]
                    = E_q[log p(x|z)] - D_KL(q(z)||p(z))

    Maximizing ELBO = Minimizing KL(q||p) while maximizing likelihood.

    No model is actually trained here: the loss curves below are synthetic,
    shaped to look like a typical VAE training run.

    Side effects: prints a summary and saves 'elbo_visualization.png'.
    """
    print("\n" + "="*60)
    print("5. ELBO (Evidence Lower Bound)")
    print("="*60)

    print("\nVariational Inference Setup:")
    print("  True posterior: p(z|x) - intractable")
    print("  Approximate: q(z) ≈ p(z|x)")
    print("\nELBO Decomposition:")
    print("  log p(x) = ELBO + D_KL(q(z)||p(z|x))")
    print("  ELBO = E_q[log p(x|z)] - D_KL(q(z)||p(z))")
    print("\nMaximizing ELBO:")
    print("  1. Minimizes KL divergence to true posterior")
    print("  2. Maximizes expected log-likelihood")

    # Simulate ELBO during training
    n_iterations = 200

    # ELBO components (simulated). Fixed seed keeps the figure reproducible;
    # note the two randn() calls below consume the RNG stream in order, so
    # their relative order must not be changed.
    np.random.seed(42)
    # Reconstruction loss: exponential decay from 100 toward 0, plus noise.
    reconstruction_loss = 100 * np.exp(-np.linspace(0, 3, n_iterations))
    reconstruction_loss += np.random.randn(n_iterations) * 2

    # KL term: grows from 0 and saturates near 50, with smaller noise.
    kl_divergence = 50 * (1 - np.exp(-np.linspace(0, 2, n_iterations)))
    kl_divergence += np.random.randn(n_iterations) * 1.5

    elbo = -(reconstruction_loss + kl_divergence)  # Negative because we're minimizing loss

    # Visualization: components on top, resulting ELBO curve below.
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10))

    # Plot 1: ELBO components
    iterations = np.arange(n_iterations)

    ax1.plot(iterations, reconstruction_loss, linewidth=2, label='Reconstruction Loss (negative log p(x|z))',
             color='red')
    ax1.plot(iterations, kl_divergence, linewidth=2, label='KL Divergence D_KL(q||p)',
             color='blue')
    ax1.plot(iterations, reconstruction_loss + kl_divergence, linewidth=2.5,
             label='Total Loss (negative ELBO)', color='black', linestyle='--')

    ax1.set_xlabel('Iteration', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('ELBO Components During Training', fontsize=14, fontweight='bold')
    ax1.legend(fontsize=10)
    ax1.grid(True, alpha=0.3)

    # Plot 2: ELBO (to be maximized)
    ax2.plot(iterations, elbo, linewidth=3, color='green')
    ax2.set_xlabel('Iteration', fontsize=12)
    ax2.set_ylabel('ELBO (to maximize)', fontsize=12)
    ax2.set_title('Evidence Lower Bound (ELBO)', fontsize=14, fontweight='bold')
    ax2.grid(True, alpha=0.3)

    # Annotate the late-training region, anchored on the curve itself.
    ax2.text(150, elbo[150], 'ELBO increases →\nbetter approximation',
             fontsize=11, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.savefig('elbo_visualization.png', dpi=150, bbox_inches='tight')
    print("\nSaved ELBO visualization to 'elbo_visualization.png'")

    print(f"\nFinal values:")
    print(f"  Reconstruction Loss: {reconstruction_loss[-1]:.2f}")
    print(f"  KL Divergence: {kl_divergence[-1]:.2f}")
    print(f"  ELBO: {elbo[-1]:.2f}")
440
441
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("Information Theory for Machine Learning")
    print(banner)

    # Run demonstrations in order; each one prints and saves its own figure.
    demos = (
        entropy_demo,
        cross_entropy_kl_divergence,
        mutual_information_demo,
        ml_loss_functions,
        elbo_visualization,
    )
    for demo in demos:
        demo()

    print("\n" + banner)
    print("Key Takeaways:")
    print(banner)
    takeaways = (
        "1. Entropy: Measures uncertainty/information content",
        "2. Cross-Entropy: Used as loss function in classification",
        "3. KL Divergence: Measures difference between distributions",
        "4. Mutual Information: Measures dependence between variables",
        "5. ELBO: Fundamental to variational inference (VAE, etc.)",
        "6. Minimizing cross-entropy = Maximizing likelihood",
    )
    for line in takeaways:
        print(line)