1"""
2Foundation Models - Knowledge Distillation
3
4Implements knowledge distillation from scratch using PyTorch.
5Demonstrates student-teacher training with soft labels and temperature.
6Shows how to compress large models into smaller ones.
7
8Requires: PyTorch
9"""
10
11import torch
12import torch.nn as nn
13import torch.nn.functional as F
14import numpy as np
15
16
class TeacherModel(nn.Module):
    """Large teacher model (to be distilled).

    A 4-layer MLP: input -> hidden -> hidden -> hidden//2 -> num_classes,
    with dropout between the first three activations.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        # Layers are created in the same order as declared so parameter
        # initialization consumes the RNG identically across runs.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim // 2)
        self.fc4 = nn.Linear(hidden_dim // 2, num_classes)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes)."""
        hidden = self.dropout(F.relu(self.fc1(x)))
        hidden = self.dropout(F.relu(self.fc2(hidden)))
        hidden = F.relu(self.fc3(hidden))
        return self.fc4(hidden)
36
37
class StudentModel(nn.Module):
    """Small student model (distilled from teacher).

    A 2-layer MLP: input -> hidden -> num_classes, no regularization.
    """

    def __init__(self, input_dim, hidden_dim, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes)."""
        return self.fc2(F.relu(self.fc1(x)))
50
51
def distillation_loss(student_logits, teacher_logits, labels, temperature=3.0, alpha=0.7):
    """
    Compute distillation loss.

    L = α * L_soft + (1-α) * L_hard

    where:
    - L_soft: KL divergence between softened student and teacher outputs
    - L_hard: Cross-entropy with true labels
    - T: Temperature for softening

    Args:
        student_logits: Student model outputs (before softmax)
        teacher_logits: Teacher model outputs (before softmax)
        labels: True labels
        temperature: Temperature for soft targets
        alpha: Weight for soft loss (0-1)

    Returns:
        Total distillation loss
    """
    T = temperature

    # Teacher probabilities and student log-probabilities, both softened by T.
    teacher_probs = F.softmax(teacher_logits / T, dim=1)
    student_log_probs = F.log_softmax(student_logits / T, dim=1)

    # KL term; the T^2 factor keeps soft-target gradient magnitudes
    # comparable across temperatures.
    kl_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean') * (T ** 2)

    # Standard cross-entropy against ground-truth labels (at T=1).
    ce_term = F.cross_entropy(student_logits, labels)

    return alpha * kl_term + (1 - alpha) * ce_term
85
86
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
90
91
def train_teacher(model, X_train, y_train, epochs=100, batch_size=32, lr=0.001):
    """Train *model* on (X_train, y_train) with Adam + cross-entropy.

    Each "epoch" draws one random mini-batch of size *batch_size* and takes a
    single optimizer step; full-dataset accuracy is printed every 20 epochs.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    n = X_train.shape[0]

    for epoch in range(epochs):
        model.train()
        # Sample a random mini-batch (without replacement within the batch).
        batch_idx = torch.randperm(n)[:batch_size]
        xb, yb = X_train[batch_idx], y_train[batch_idx]

        opt.zero_grad()
        loss = loss_fn(model(xb), yb)
        loss.backward()
        opt.step()

        # Periodic progress report on the full training set.
        if (epoch + 1) % 20 == 0:
            model.eval()
            with torch.no_grad():
                acc = (model(X_train).argmax(dim=1) == y_train).float().mean()
            print(f"Epoch {epoch+1:3d}: Loss = {loss.item():.4f}, Acc = {acc:.4f}")
118
119
def train_student_standard(model, X_train, y_train, epochs=100, batch_size=32, lr=0.001):
    """Train student model with standard supervised learning (no distillation).

    The optimization loop (Adam + cross-entropy on one random mini-batch per
    epoch, progress printed every 20 epochs) is identical to the teacher's,
    so delegate to ``train_teacher`` instead of duplicating it verbatim.

    Args:
        model: Student network to train in place.
        X_train: Training inputs, shape (num_samples, input_dim).
        y_train: Integer class labels, shape (num_samples,).
        epochs: Number of single-batch optimization steps.
        batch_size: Mini-batch size sampled each step.
        lr: Adam learning rate.
    """
    train_teacher(model, X_train, y_train,
                  epochs=epochs, batch_size=batch_size, lr=lr)
146
147
def train_student_distillation(student, teacher, X_train, y_train, epochs=100,
                               batch_size=32, lr=0.001, temperature=3.0, alpha=0.7):
    """Train *student* with knowledge distillation from a frozen *teacher*.

    Each epoch draws one random mini-batch, combines the soft (KL vs. teacher)
    and hard (cross-entropy vs. labels) losses via ``distillation_loss``, and
    takes one Adam step. Full-dataset accuracy is printed every 20 epochs.
    """
    opt = torch.optim.Adam(student.parameters(), lr=lr)
    n = X_train.shape[0]
    teacher.eval()  # teacher only supplies soft targets; it is never updated

    for epoch in range(epochs):
        student.train()
        batch_idx = torch.randperm(n)[:batch_size]
        xb, yb = X_train[batch_idx], y_train[batch_idx]

        opt.zero_grad()

        student_logits = student(xb)
        # No gradients flow into the teacher.
        with torch.no_grad():
            teacher_logits = teacher(xb)

        loss = distillation_loss(student_logits, teacher_logits, yb,
                                 temperature=temperature, alpha=alpha)
        loss.backward()
        opt.step()

        # Periodic progress report on the full training set.
        if (epoch + 1) % 20 == 0:
            student.eval()
            with torch.no_grad():
                acc = (student(X_train).argmax(dim=1) == y_train).float().mean()
            print(f"Epoch {epoch+1:3d}: Loss = {loss.item():.4f}, Acc = {acc:.4f}")
185
186
187# ============================================================
188# Demonstrations
189# ============================================================
190
def demo_temperature_effect():
    """Demonstrate effect of temperature on softmax."""
    print("=" * 60)
    print("DEMO 1: Temperature Effect on Softmax")
    print("=" * 60)

    # One row of logits with a clear top class.
    logits = torch.tensor([[2.0, 1.0, 0.5, 0.2]])

    print(f"\nLogits: {logits[0].tolist()}\n")

    for T in (1.0, 2.0, 5.0, 10.0):
        probs = F.softmax(logits / T, dim=1)
        # Shannon entropy; the epsilon guards against log(0).
        entropy = -(probs * torch.log(probs + 1e-10)).sum().item()

        print(f"Temperature {T}:")
        print(f"  Probabilities: {probs[0].tolist()}")
        print(f"  Entropy: {entropy:.4f}\n")

    print("Higher temperature → softer (more uniform) distribution")
    print("This reveals more of the teacher's knowledge")
214
215
def demo_soft_vs_hard_labels():
    """Compare soft and hard labels."""
    print("\n" + "=" * 60)
    print("DEMO 2: Soft vs Hard Labels")
    print("=" * 60)

    # Simulated teacher outputs: one confident row, one uncertain row.
    teacher_logits = torch.tensor([
        [3.0, 1.5, 0.8, 0.5],  # High confidence
        [2.0, 1.8, 1.5, 1.0],  # Lower confidence
    ])
    hard_labels = torch.tensor([0, 0])

    print("\nTeacher logits:")
    print(teacher_logits)

    print("\nHard labels (one-hot):")
    for label in hard_labels:
        encoded = torch.zeros(4)
        encoded[label] = 1
        print(f"  {encoded.tolist()}")

    print("\nSoft labels (T=3):")
    for row in F.softmax(teacher_logits / 3.0, dim=1):
        print(f"  {row.tolist()}")

    print("\nSoft labels encode similarity between classes!")
245
246
def demo_basic_distillation():
    """Demonstrate basic knowledge distillation."""
    print("\n" + "=" * 60)
    print("DEMO 3: Basic Knowledge Distillation")
    print("=" * 60)

    # Reproducible synthetic classification data.
    torch.manual_seed(42)
    np.random.seed(42)

    input_dim, num_classes, num_samples = 50, 5, 500
    X_train = torch.randn(num_samples, input_dim)
    y_train = torch.randint(0, num_classes, (num_samples,))

    # Large teacher vs. small student.
    teacher = TeacherModel(input_dim, hidden_dim=256, num_classes=num_classes)
    student = StudentModel(input_dim, hidden_dim=64, num_classes=num_classes)

    teacher_size = count_parameters(teacher)
    student_size = count_parameters(student)
    print(f"\nTeacher parameters: {teacher_size:,}")
    print(f"Student parameters: {student_size:,}")
    print(f"Compression ratio: {teacher_size / student_size:.2f}x\n")

    print("Training teacher model...")
    print("-" * 60)
    train_teacher(teacher, X_train, y_train, epochs=100, batch_size=64)

    # Report teacher accuracy on the training set.
    teacher.eval()
    with torch.no_grad():
        teacher_acc = (teacher(X_train).argmax(dim=1) == y_train).float().mean()
    print(f"\nTeacher final accuracy: {teacher_acc:.4f}")
283
284
def demo_student_comparison():
    """Compare student trained with and without distillation."""
    print("\n" + "=" * 60)
    print("DEMO 4: Student Training Comparison")
    print("=" * 60)

    # Reproducible synthetic data.
    torch.manual_seed(42)
    input_dim, num_classes, num_samples = 50, 5, 500
    X_train = torch.randn(num_samples, input_dim)
    y_train = torch.randint(0, num_classes, (num_samples,))

    def _train_accuracy(net):
        # Full-dataset accuracy with the network in eval mode.
        net.eval()
        with torch.no_grad():
            return (net(X_train).argmax(dim=1) == y_train).float().mean()

    # Baseline: a well-trained teacher.
    teacher = TeacherModel(input_dim, hidden_dim=256, num_classes=num_classes)
    print("Training teacher...")
    train_teacher(teacher, X_train, y_train, epochs=100, batch_size=64, lr=0.001)
    teacher_acc = _train_accuracy(teacher)
    print(f"Teacher accuracy: {teacher_acc:.4f}\n")

    # Student 1: plain supervised training.
    print("-" * 60)
    print("Student 1: Standard training (no distillation)")
    print("-" * 60)
    student1 = StudentModel(input_dim, hidden_dim=64, num_classes=num_classes)
    train_student_standard(student1, X_train, y_train, epochs=100, batch_size=64, lr=0.001)
    student1_acc = _train_accuracy(student1)

    # Student 2: distilled from the teacher.
    print("\n" + "-" * 60)
    print("Student 2: Knowledge distillation (T=3, α=0.7)")
    print("-" * 60)
    student2 = StudentModel(input_dim, hidden_dim=64, num_classes=num_classes)
    train_student_distillation(student2, teacher, X_train, y_train,
                               epochs=100, batch_size=64, lr=0.001,
                               temperature=3.0, alpha=0.7)
    student2_acc = _train_accuracy(student2)

    # Side-by-side summary.
    print("\n" + "=" * 60)
    print("Comparison:")
    print("=" * 60)
    print(f"Teacher accuracy: {teacher_acc:.4f}")
    print(f"Student (standard): {student1_acc:.4f}")
    print(f"Student (distillation): {student2_acc:.4f}")
    print(f"Improvement from distillation: {(student2_acc - student1_acc):.4f}")
342
343
def demo_hyperparameter_tuning():
    """Study effect of temperature and alpha."""
    print("\n" + "=" * 60)
    print("DEMO 5: Hyperparameter Tuning")
    print("=" * 60)

    # Reproducible synthetic data.
    torch.manual_seed(42)
    input_dim, num_classes, num_samples = 40, 4, 400
    X_train = torch.randn(num_samples, input_dim)
    y_train = torch.randint(0, num_classes, (num_samples,))

    # Shared teacher for all student runs.
    teacher = TeacherModel(input_dim, hidden_dim=200, num_classes=num_classes)
    train_teacher(teacher, X_train, y_train, epochs=80, batch_size=64, lr=0.001)

    def _distilled_accuracy(T, alpha):
        # Fresh student per setting so runs do not contaminate each other.
        student = StudentModel(input_dim, hidden_dim=50, num_classes=num_classes)
        train_student_distillation(student, teacher, X_train, y_train,
                                   epochs=80, batch_size=64, lr=0.001,
                                   temperature=T, alpha=alpha)
        student.eval()
        with torch.no_grad():
            return (student(X_train).argmax(dim=1) == y_train).float().mean()

    print("\n" + "-" * 60)
    print("Testing different temperatures (α=0.7):")
    print("-" * 60)
    for T in [1.0, 2.0, 4.0, 8.0]:
        acc = _distilled_accuracy(T, 0.7)
        print(f"T={T}: Final accuracy = {acc:.4f}")

    print("\n" + "-" * 60)
    print("Testing different alpha values (T=3.0):")
    print("-" * 60)
    for alpha in [0.0, 0.3, 0.5, 0.7, 0.9, 1.0]:
        acc = _distilled_accuracy(3.0, alpha)
        print(f"α={alpha}: Final accuracy = {acc:.4f}")
392
393
if __name__ == "__main__":
    print("\n" + "=" * 60)
    print("Foundation Models: Knowledge Distillation")
    print("=" * 60)

    # Run every demo in order.
    for demo in (demo_temperature_effect,
                 demo_soft_vs_hard_labels,
                 demo_basic_distillation,
                 demo_student_comparison,
                 demo_hyperparameter_tuning):
        demo()

    print("\n" + "=" * 60)
    print("Key Takeaways:")
    print("=" * 60)
    takeaways = (
        "1. Distillation: Compress large model → small model",
        "2. Soft labels: Encode class similarities, not just winner",
        "3. Temperature: Controls softness of distribution",
        "4. Loss: α × L_soft + (1-α) × L_hard",
        "5. Typical: T=3-5, α=0.5-0.9",
        "6. Student learns from teacher's mistakes and uncertainties",
        "7. Can achieve similar performance with much smaller model",
    )
    for line in takeaways:
        print(line)
    print("=" * 60)