09_mcmc_sampling.py

Download
python 309 lines 9.9 KB
  1"""
  2MCMC Sampling and Advanced Sampling Techniques
  3
  4This script demonstrates various sampling methods used in probabilistic ML:
  5- Rejection sampling
  6- Importance sampling
  7- Metropolis-Hastings MCMC
  8- Reparameterization trick (VAE-style)
  9
 10These techniques are fundamental for:
 11- Bayesian inference
 12- Variational Autoencoders (VAE)
 13- Generative models
 14- Monte Carlo estimation
 15"""
 16
 17import numpy as np
 18import matplotlib.pyplot as plt
 19from scipy.stats import norm, gamma, multivariate_normal
 20import torch
 21import torch.nn.functional as F
 22
 23
 24def rejection_sampling(target_pdf, proposal_pdf, proposal_sampler, M, n_samples=10000):
 25    """
 26    Rejection sampling: sample from target distribution using proposal distribution.
 27
 28    Args:
 29        target_pdf: Target probability density function
 30        proposal_pdf: Proposal probability density function
 31        proposal_sampler: Function to sample from proposal distribution
 32        M: Constant such that target_pdf(x) <= M * proposal_pdf(x) for all x
 33        n_samples: Number of samples to generate
 34
 35    Returns:
 36        Array of accepted samples
 37    """
 38    samples = []
 39    n_rejected = 0
 40
 41    while len(samples) < n_samples:
 42        # Sample from proposal distribution
 43        x = proposal_sampler()
 44
 45        # Sample uniform random variable
 46        u = np.random.uniform(0, 1)
 47
 48        # Accept/reject criterion
 49        acceptance_prob = target_pdf(x) / (M * proposal_pdf(x))
 50
 51        if u <= acceptance_prob:
 52            samples.append(x)
 53        else:
 54            n_rejected += 1
 55
 56    acceptance_rate = n_samples / (n_samples + n_rejected)
 57    print(f"Rejection Sampling - Acceptance Rate: {acceptance_rate:.3f}")
 58
 59    return np.array(samples)
 60
 61
 62def importance_sampling(target_pdf, proposal_pdf, proposal_sampler, n_samples=10000):
 63    """
 64    Importance sampling: estimate expectations under target distribution
 65    using samples from proposal distribution.
 66
 67    Returns samples and importance weights.
 68    """
 69    # Sample from proposal
 70    samples = np.array([proposal_sampler() for _ in range(n_samples)])
 71
 72    # Compute importance weights
 73    weights = target_pdf(samples) / proposal_pdf(samples)
 74
 75    # Normalize weights
 76    weights = weights / np.sum(weights)
 77
 78    return samples, weights
 79
 80
 81def metropolis_hastings(target_pdf, initial_state, proposal_std, n_samples=10000, burn_in=1000):
 82    """
 83    Metropolis-Hastings MCMC sampler.
 84
 85    Uses symmetric Gaussian proposal distribution.
 86
 87    Args:
 88        target_pdf: Target probability density (unnormalized is OK)
 89        initial_state: Starting point for the chain
 90        proposal_std: Standard deviation of Gaussian proposal
 91        n_samples: Number of samples to generate
 92        burn_in: Number of initial samples to discard
 93
 94    Returns:
 95        Array of samples, acceptance rate
 96    """
 97    samples = []
 98    current = initial_state
 99    n_accepted = 0
100
101    for i in range(n_samples + burn_in):
102        # Propose new state (symmetric Gaussian)
103        proposed = current + np.random.normal(0, proposal_std, size=current.shape)
104
105        # Compute acceptance ratio
106        acceptance_ratio = target_pdf(proposed) / target_pdf(current)
107
108        # Accept/reject
109        if np.random.uniform(0, 1) < acceptance_ratio:
110            current = proposed
111            n_accepted += 1
112
113        # Store sample after burn-in
114        if i >= burn_in:
115            samples.append(current.copy())
116
117    acceptance_rate = n_accepted / (n_samples + burn_in)
118    print(f"Metropolis-Hastings - Acceptance Rate: {acceptance_rate:.3f}")
119
120    return np.array(samples), acceptance_rate
121
122
def reparameterization_trick_demo():
    """
    Show how the reparameterization trick lets gradients reach the parameters
    of a Gaussian latent distribution, as in a VAE.

    Drawing z ~ N(mu, sigma^2) directly is not differentiable with respect to
    mu and sigma. Instead, draw eps ~ N(0, I) and form z = mu + sigma * eps,
    with sigma = exp(0.5 * log_var). The sample is then a differentiable
    function of the parameters.

    Runs one backward pass through a toy loss (mean squared sample magnitude
    plus 0.1 times the analytic KL term). It then prints the parameters, their
    gradients and the KL value.

    Returns:
        numpy array of shape (1000, 2) with the drawn latent samples.
    """
    # Latent Gaussian parameters we want gradients for
    mu = torch.tensor([2.0, -1.0], requires_grad=True)
    log_var = torch.tensor([0.5, 1.0], requires_grad=True)

    # Parameter-free noise, then a deterministic, differentiable transform
    n_draws, latent_dim = 1000, mu.shape[0]
    sigma = (0.5 * log_var).exp()
    noise = torch.randn(n_draws, latent_dim)
    z_samples = mu + sigma * noise

    # Closed-form KL( N(mu, sigma^2) || N(0, 1) ), summed over latent dimensions
    kl_loss = -0.5 * (1 + log_var - mu ** 2 - log_var.exp()).sum()

    # Stand-in reconstruction term: mean squared magnitude of the samples
    recon_loss = (z_samples ** 2).mean()

    # Gradients reach mu and log_var through z_samples and kl_loss
    (recon_loss + 0.1 * kl_loss).backward()

    print("\nReparameterization Trick (VAE-style):")
    print(f"Latent mean (mu): {mu.detach().numpy()}")
    print(f"Latent log_var: {log_var.detach().numpy()}")
    print(f"Gradient of mu: {mu.grad.numpy()}")
    print(f"Gradient of log_var: {log_var.grad.numpy()}")
    print(f"KL divergence: {kl_loss.item():.4f}")

    return z_samples.detach().numpy()
167
168
def _finish_panel(ax, title, xlabel, ylabel, legend=True):
    """Apply the title/axis-label/legend/grid styling shared by every comparison panel."""
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    if legend:
        ax.legend()
    ax.grid(True, alpha=0.3)


def visualize_sampling_comparison(save_path='/tmp/mcmc_sampling_comparison.png'):
    """
    Compare different sampling methods on a mixture of Gaussians.

    Runs rejection sampling, importance sampling and Metropolis-Hastings
    against the same two-component 1-D Gaussian mixture. Saves a 2x2 figure:
    one histogram panel per method, plus the first 500 steps of the MCMC trace.

    Args:
        save_path: Output PNG path. Defaults to the previously hard-coded
            location; pass another path (e.g. under ``tempfile.gettempdir()``)
            on systems without ``/tmp``.
    """
    # Target: mixture of two Gaussians
    def target_pdf(x):
        return 0.6 * norm.pdf(x, loc=-2, scale=0.8) + 0.4 * norm.pdf(x, loc=3, scale=1.2)

    # Proposal: single broad Gaussian
    proposal_mean = 0.5
    proposal_std = 3.0

    def proposal_pdf(x):
        return norm.pdf(x, loc=proposal_mean, scale=proposal_std)

    def proposal_sampler():
        return np.random.normal(proposal_mean, proposal_std)

    # One evaluation grid, used both to bound M and to draw the true density
    x = np.linspace(-6, 8, 1000)
    true_pdf = target_pdf(x)

    # Envelope constant for rejection sampling: the grid maximum of
    # target/proposal, inflated by 10% because the grid may miss the true sup.
    M = np.max(true_pdf / proposal_pdf(x)) * 1.1

    # 1. Rejection Sampling
    print("\n=== Rejection Sampling ===")
    rejection_samples = rejection_sampling(target_pdf, proposal_pdf, proposal_sampler, M, n_samples=5000)

    # 2. Importance Sampling
    print("\n=== Importance Sampling ===")
    importance_samples, weights = importance_sampling(target_pdf, proposal_pdf, proposal_sampler, n_samples=5000)

    # 3. Metropolis-Hastings
    print("\n=== Metropolis-Hastings MCMC ===")
    mh_samples, _ = metropolis_hastings(target_pdf, initial_state=np.array([0.0]),
                                        proposal_std=2.0, n_samples=5000, burn_in=500)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Panel 1: rejection sampling
    ax = axes[0, 0]
    ax.hist(rejection_samples, bins=50, density=True, alpha=0.6, label='Samples')
    ax.plot(x, true_pdf, 'r-', linewidth=2, label='True PDF')
    _finish_panel(ax, 'Rejection Sampling', 'x', 'Density')

    # Panel 2: raw proposal draws vs. importance-weighted draws. The weights
    # are passed as-is: with density=True the histogram is renormalized, so
    # any overall scale on the weights cancels.
    ax = axes[0, 1]
    ax.hist(importance_samples, bins=50, density=True, alpha=0.3, label='Proposal Samples')
    ax.hist(importance_samples, bins=50, weights=weights,
            density=True, alpha=0.6, label='Weighted Samples')
    ax.plot(x, true_pdf, 'r-', linewidth=2, label='True PDF')
    _finish_panel(ax, 'Importance Sampling', 'x', 'Density')

    # Panel 3: Metropolis-Hastings histogram
    ax = axes[1, 0]
    ax.hist(mh_samples.flatten(), bins=50, density=True, alpha=0.6, label='MCMC Samples')
    ax.plot(x, true_pdf, 'r-', linewidth=2, label='True PDF')
    _finish_panel(ax, 'Metropolis-Hastings MCMC', 'x', 'Density')

    # Panel 4: MCMC trace, to eyeball mixing between the two modes
    ax = axes[1, 1]
    ax.plot(mh_samples[:500], alpha=0.7)
    _finish_panel(ax, 'MCMC Trace (first 500 samples)', 'Iteration', 'Sample Value', legend=False)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"\nPlot saved to {save_path}")
    plt.close(fig)
252
253
def visualize_reparameterization(save_path='/tmp/reparameterization_trick.png'):
    """
    Visualize samples from reparameterization trick.

    Runs ``reparameterization_trick_demo`` and plots the latent samples it
    returns. The left panel is a 2-D scatter with the latent mean marked; the
    right panel holds normalized histograms of each latent coordinate.

    Args:
        save_path: Output PNG path. Defaults to the previously hard-coded
            location so existing callers are unaffected.
    """
    z_samples = reparameterization_trick_demo()

    # NOTE: must match the ``mu`` initialized inside reparameterization_trick_demo.
    # That function returns only the samples, so the value is repeated here.
    latent_mean = (2.0, -1.0)

    fig, (ax_scatter, ax_marginal) = plt.subplots(1, 2, figsize=(12, 5))

    # 2D scatter of the joint samples
    ax_scatter.scatter(z_samples[:, 0], z_samples[:, 1], alpha=0.3, s=10)
    ax_scatter.scatter([latent_mean[0]], [latent_mean[1]], c='red', s=100, marker='x',
                       linewidths=3, label='Mean (μ)')
    ax_scatter.set_title('Latent Space Samples (Reparameterization Trick)')
    ax_scatter.set_xlabel('z_1')
    ax_scatter.set_ylabel('z_2')
    ax_scatter.legend()
    ax_scatter.grid(True, alpha=0.3)

    # Marginal distributions of each latent coordinate
    ax_marginal.hist(z_samples[:, 0], bins=30, alpha=0.6, label='z_1', density=True)
    ax_marginal.hist(z_samples[:, 1], bins=30, alpha=0.6, label='z_2', density=True)
    ax_marginal.set_title('Marginal Distributions')
    ax_marginal.set_xlabel('Value')
    ax_marginal.set_ylabel('Density')
    ax_marginal.legend()
    ax_marginal.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"Plot saved to {save_path}")
    plt.close(fig)
284
285
if __name__ == "__main__":
    banner = "=" * 60
    summary = (
        "- Rejection sampling: Simple but can be inefficient if M is large",
        "- Importance sampling: Useful for expectation estimation",
        "- Metropolis-Hastings: MCMC method, eventually converges to target",
        "- Reparameterization: Enables gradient-based optimization (VAE)",
    )

    print(banner)
    print("MCMC Sampling and Advanced Sampling Techniques")
    print(banner)

    # Seed both RNGs: numpy drives the three samplers, torch drives the VAE demo.
    np.random.seed(42)
    torch.manual_seed(42)

    # Run sampling comparisons
    visualize_sampling_comparison()

    # Demonstrate reparameterization trick
    print("\n" + banner)
    visualize_reparameterization()

    print("\n" + banner)
    print("Summary:")
    for line in summary:
        print(line)
    print(banner)