1"""
2MCMC Sampling and Advanced Sampling Techniques
3
4This script demonstrates various sampling methods used in probabilistic ML:
5- Rejection sampling
6- Importance sampling
7- Metropolis-Hastings MCMC
8- Reparameterization trick (VAE-style)
9
10These techniques are fundamental for:
11- Bayesian inference
12- Variational Autoencoders (VAE)
13- Generative models
14- Monte Carlo estimation
15"""
16
17import numpy as np
18import matplotlib.pyplot as plt
19from scipy.stats import norm, gamma, multivariate_normal
20import torch
21import torch.nn.functional as F
22
23
def rejection_sampling(target_pdf, proposal_pdf, proposal_sampler, M, n_samples=10000):
    """
    Rejection sampling: sample from a target distribution using a proposal distribution.

    Args:
        target_pdf: Target probability density function
        proposal_pdf: Proposal probability density function
        proposal_sampler: Function to sample from the proposal distribution
        M: Constant such that target_pdf(x) <= M * proposal_pdf(x) for all x
        n_samples: Number of samples to generate

    Returns:
        Array of accepted samples
    """
    samples = []
    n_rejected = 0

    while len(samples) < n_samples:
        # Sample from the proposal distribution
        x = proposal_sampler()

        # Sample a uniform random variable
        u = np.random.uniform(0, 1)

        # Accept/reject criterion: accept with probability p(x) / (M * q(x))
        acceptance_prob = target_pdf(x) / (M * proposal_pdf(x))

        if u <= acceptance_prob:
            samples.append(x)
        else:
            n_rejected += 1

    acceptance_rate = n_samples / (n_samples + n_rejected)
    print(f"Rejection Sampling - Acceptance Rate: {acceptance_rate:.3f}")

    return np.array(samples)


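# A vectorized variant (a minimal sketch we add for illustration; the batched
# sampler signature proposal_sampler_batch(n) is our assumption and is not used
# by the pipeline below). For normalized densities the acceptance probability
# is exactly 1/M, so roughly M rounds of n_samples draws are needed.
def rejection_sampling_vectorized(target_pdf, proposal_pdf, proposal_sampler_batch,
                                  M, n_samples=10000):
    """Batched rejection sampling; proposal_sampler_batch(n) returns n draws."""
    samples = np.empty(0)
    while samples.size < n_samples:
        x = proposal_sampler_batch(n_samples)   # batch of proposal draws
        u = np.random.uniform(size=x.size)      # one uniform threshold per draw
        samples = np.concatenate([samples, x[u <= target_pdf(x) / (M * proposal_pdf(x))]])
    return samples[:n_samples]

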
def importance_sampling(target_pdf, proposal_pdf, proposal_sampler, n_samples=10000):
    """
    Importance sampling: estimate expectations under the target distribution
    using samples from the proposal distribution.

    Returns samples and self-normalized importance weights.
    """
    # Sample from the proposal
    samples = np.array([proposal_sampler() for _ in range(n_samples)])

    # Compute importance weights w_i = p(x_i) / q(x_i)
    weights = target_pdf(samples) / proposal_pdf(samples)

    # Self-normalize the weights so they sum to 1
    weights = weights / np.sum(weights)

    return samples, weights


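# Diagnostic sketch (our addition; the helper name is ours): the Kish effective
# sample size of self-normalized weights is 1 / sum(w_i^2). It approaches
# n_samples when the proposal matches the target, and collapses toward 1 when a
# few weights dominate, signalling a poor proposal.
def effective_sample_size(normalized_weights):
    """Kish effective sample size for self-normalized importance weights."""
    return 1.0 / np.sum(normalized_weights ** 2)

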
def metropolis_hastings(target_pdf, initial_state, proposal_std, n_samples=10000, burn_in=1000):
    """
    Metropolis-Hastings MCMC sampler.

    Uses a symmetric Gaussian proposal distribution, so the Hastings correction
    cancels and the acceptance ratio reduces to a ratio of target densities.

    Args:
        target_pdf: Target probability density (unnormalized is OK)
        initial_state: Starting point for the chain
        proposal_std: Standard deviation of the Gaussian proposal
        n_samples: Number of samples to generate
        burn_in: Number of initial samples to discard

    Returns:
        Array of samples, acceptance rate
    """
    samples = []
    current = initial_state
    n_accepted = 0

    for i in range(n_samples + burn_in):
        # Propose a new state (symmetric Gaussian random walk)
        proposed = current + np.random.normal(0, proposal_std, size=current.shape)

        # Acceptance ratio; since u <= 1, comparing u < ratio is equivalent
        # to accepting with probability min(1, ratio)
        acceptance_ratio = target_pdf(proposed) / target_pdf(current)

        # Accept/reject
        if np.random.uniform(0, 1) < acceptance_ratio:
            current = proposed
            n_accepted += 1

        # Store sample after burn-in
        if i >= burn_in:
            samples.append(current.copy())

    acceptance_rate = n_accepted / (n_samples + burn_in)
    print(f"Metropolis-Hastings - Acceptance Rate: {acceptance_rate:.3f}")

    return np.array(samples), acceptance_rate


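# Mixing diagnostic (a minimal sketch we add for illustration; the helper name
# is ours): the lag-k sample autocorrelation of the chain. Slowly decaying
# autocorrelation means the chain yields far fewer effectively independent
# draws than its raw length suggests.
def chain_autocorrelation(chain, max_lag=50):
    """Sample autocorrelation of a 1-D MCMC chain for lags 0..max_lag."""
    x = np.asarray(chain, dtype=float).ravel()
    x = x - x.mean()
    var = np.dot(x, x) / x.size
    return np.array([1.0 if k == 0 else np.dot(x[:-k], x[k:]) / (x.size * var)
                     for k in range(max_lag + 1)])

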
def reparameterization_trick_demo():
    """
    Reparameterization trick used in Variational Autoencoders (VAE).

    Instead of sampling z ~ N(mu, sigma^2) directly, we:
    1. Sample epsilon ~ N(0, 1)
    2. Compute z = mu + sigma * epsilon

    This moves the randomness into epsilon, so gradients can flow through
    mu and sigma.
    """
    # VAE latent space parameters
    mu = torch.tensor([2.0, -1.0], requires_grad=True)
    log_var = torch.tensor([0.5, 1.0], requires_grad=True)

    # Reparameterization trick
    def sample_latent(mu, log_var, n_samples=1000):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn(n_samples, mu.shape[0])
        z = mu + std * eps
        return z

    # Sample latent variables
    z_samples = sample_latent(mu, log_var)

    # Compute a simple loss (e.g., reconstruction + KL divergence)
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, 1)
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

    # Dummy reconstruction loss
    recon_loss = torch.mean(z_samples.pow(2))

    total_loss = recon_loss + 0.1 * kl_loss

    # Backpropagation works because of the reparameterization
    total_loss.backward()

    print("\nReparameterization Trick (VAE-style):")
    print(f"Latent mean (mu): {mu.detach().numpy()}")
    print(f"Latent log_var: {log_var.detach().numpy()}")
    print(f"Gradient of mu: {mu.grad.numpy()}")
    print(f"Gradient of log_var: {log_var.grad.numpy()}")
    print(f"KL divergence: {kl_loss.item():.4f}")

    return z_samples.detach().numpy()


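# Contrast (an illustrative sketch we add; not called by the pipeline below):
# without reparameterization, gradients w.r.t. distribution parameters require
# the score-function (REINFORCE) estimator, grad E[f(z)] = E[f(z) * grad log p(z)],
# which is unbiased but typically has much higher variance.
def score_function_gradient_demo(n_samples=1000):
    """Estimate d/dmu E[sum(z^2)] for z ~ N(mu, I) via the score-function trick."""
    mu = torch.tensor([2.0, -1.0], requires_grad=True)
    dist = torch.distributions.Normal(mu, 1.0)
    z = dist.sample((n_samples,))                   # no gradient path through z
    f = z.pow(2).sum(dim=-1)                        # quantity whose mean we want
    surrogate = (dist.log_prob(z).sum(dim=-1) * f.detach()).mean()
    surrogate.backward()                            # populates mu.grad
    return mu.grad                                  # analytic answer is 2 * mu

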
def visualize_sampling_comparison():
    """
    Compare different sampling methods on a mixture of Gaussians.
    """
    # Target: mixture of two Gaussians
    def target_pdf(x):
        return 0.6 * norm.pdf(x, loc=-2, scale=0.8) + 0.4 * norm.pdf(x, loc=3, scale=1.2)

    # Proposal: single Gaussian
    proposal_mean = 0.5
    proposal_std = 3.0

    def proposal_pdf(x):
        return norm.pdf(x, loc=proposal_mean, scale=proposal_std)

    def proposal_sampler():
        return np.random.normal(proposal_mean, proposal_std)

    # Find M for rejection sampling: the max density ratio over a grid,
    # padded by 10% for safety
    x_range = np.linspace(-6, 8, 1000)
    M = np.max(target_pdf(x_range) / proposal_pdf(x_range)) * 1.1

    # 1. Rejection Sampling
    print("\n=== Rejection Sampling ===")
    rejection_samples = rejection_sampling(target_pdf, proposal_pdf, proposal_sampler, M, n_samples=5000)

    # 2. Importance Sampling
    print("\n=== Importance Sampling ===")
    importance_samples, weights = importance_sampling(target_pdf, proposal_pdf, proposal_sampler, n_samples=5000)

    # 3. Metropolis-Hastings
    print("\n=== Metropolis-Hastings MCMC ===")
    mh_samples, _ = metropolis_hastings(target_pdf, initial_state=np.array([0.0]),
                                        proposal_std=2.0, n_samples=5000, burn_in=500)

    # Visualization
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # True distribution
    x = np.linspace(-6, 8, 1000)
    true_pdf = target_pdf(x)

    # Plot 1: Rejection Sampling
    axes[0, 0].hist(rejection_samples, bins=50, density=True, alpha=0.6, label='Samples')
    axes[0, 0].plot(x, true_pdf, 'r-', linewidth=2, label='True PDF')
    axes[0, 0].set_title('Rejection Sampling')
    axes[0, 0].set_xlabel('x')
    axes[0, 0].set_ylabel('Density')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)

    # Plot 2: Importance Sampling
    axes[0, 1].hist(importance_samples, bins=50, density=True, alpha=0.3, label='Proposal Samples')
    # Weighted histogram (density=True renormalizes, so the self-normalized
    # weights can be passed directly)
    axes[0, 1].hist(importance_samples, bins=50, weights=weights,
                    density=True, alpha=0.6, label='Weighted Samples')
    axes[0, 1].plot(x, true_pdf, 'r-', linewidth=2, label='True PDF')
    axes[0, 1].set_title('Importance Sampling')
    axes[0, 1].set_xlabel('x')
    axes[0, 1].set_ylabel('Density')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)

    # Plot 3: Metropolis-Hastings
    axes[1, 0].hist(mh_samples.flatten(), bins=50, density=True, alpha=0.6, label='MCMC Samples')
    axes[1, 0].plot(x, true_pdf, 'r-', linewidth=2, label='True PDF')
    axes[1, 0].set_title('Metropolis-Hastings MCMC')
    axes[1, 0].set_xlabel('x')
    axes[1, 0].set_ylabel('Density')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # Plot 4: MCMC Trace
    axes[1, 1].plot(mh_samples[:500], alpha=0.7)
    axes[1, 1].set_title('MCMC Trace (first 500 samples)')
    axes[1, 1].set_xlabel('Iteration')
    axes[1, 1].set_ylabel('Sample Value')
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('/tmp/mcmc_sampling_comparison.png', dpi=150, bbox_inches='tight')
    print("\nPlot saved to /tmp/mcmc_sampling_comparison.png")
    plt.close()


def visualize_reparameterization():
    """
    Visualize samples drawn via the reparameterization trick.
    """
    z_samples = reparameterization_trick_demo()

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # 2D scatter
    axes[0].scatter(z_samples[:, 0], z_samples[:, 1], alpha=0.3, s=10)
    axes[0].scatter([2.0], [-1.0], c='red', s=100, marker='x', linewidths=3, label='Mean (μ)')
    axes[0].set_title('Latent Space Samples (Reparameterization Trick)')
    axes[0].set_xlabel('z_1')
    axes[0].set_ylabel('z_2')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Marginal distributions
    axes[1].hist(z_samples[:, 0], bins=30, alpha=0.6, label='z_1', density=True)
    axes[1].hist(z_samples[:, 1], bins=30, alpha=0.6, label='z_2', density=True)
    axes[1].set_title('Marginal Distributions')
    axes[1].set_xlabel('Value')
    axes[1].set_ylabel('Density')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('/tmp/reparameterization_trick.png', dpi=150, bbox_inches='tight')
    print("Plot saved to /tmp/reparameterization_trick.png")
    plt.close()


if __name__ == "__main__":
    print("=" * 60)
    print("MCMC Sampling and Advanced Sampling Techniques")
    print("=" * 60)

    # Set random seeds for reproducibility
    np.random.seed(42)
    torch.manual_seed(42)

    # Run sampling comparisons
    visualize_sampling_comparison()

    # Demonstrate the reparameterization trick
    print("\n" + "=" * 60)
    visualize_reparameterization()

    print("\n" + "=" * 60)
    print("Summary:")
    print("- Rejection sampling: Simple but inefficient when M is large (acceptance rate ~ 1/M)")
    print("- Importance sampling: Useful for estimating expectations under a target density")
    print("- Metropolis-Hastings: MCMC method; converges to the target distribution asymptotically")
    print("- Reparameterization: Enables gradient-based optimization (VAE)")
    print("=" * 60)