# 08_distillation.py
"""
Foundation Models - Knowledge Distillation

Implements knowledge distillation from scratch using PyTorch.
Demonstrates student-teacher training with soft labels and temperature.
Shows how to compress large models into smaller ones.

Requires: PyTorch, NumPy
"""
 10
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
 15
 16
 17class TeacherModel(nn.Module):
 18    """Large teacher model (to be distilled)."""
 19
 20    def __init__(self, input_dim, hidden_dim, num_classes):
 21        super().__init__()
 22        self.fc1 = nn.Linear(input_dim, hidden_dim)
 23        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
 24        self.fc3 = nn.Linear(hidden_dim, hidden_dim // 2)
 25        self.fc4 = nn.Linear(hidden_dim // 2, num_classes)
 26        self.dropout = nn.Dropout(0.3)
 27
 28    def forward(self, x):
 29        x = F.relu(self.fc1(x))
 30        x = self.dropout(x)
 31        x = F.relu(self.fc2(x))
 32        x = self.dropout(x)
 33        x = F.relu(self.fc3(x))
 34        x = self.fc4(x)
 35        return x
 36
 37
 38class StudentModel(nn.Module):
 39    """Small student model (distilled from teacher)."""
 40
 41    def __init__(self, input_dim, hidden_dim, num_classes):
 42        super().__init__()
 43        self.fc1 = nn.Linear(input_dim, hidden_dim)
 44        self.fc2 = nn.Linear(hidden_dim, num_classes)
 45
 46    def forward(self, x):
 47        x = F.relu(self.fc1(x))
 48        x = self.fc2(x)
 49        return x
 50
 51
 52def distillation_loss(student_logits, teacher_logits, labels, temperature=3.0, alpha=0.7):
 53    """
 54    Compute distillation loss.
 55
 56    L = α * L_soft + (1-α) * L_hard
 57
 58    where:
 59    - L_soft: KL divergence between softened student and teacher outputs
 60    - L_hard: Cross-entropy with true labels
 61    - T: Temperature for softening
 62
 63    Args:
 64        student_logits: Student model outputs (before softmax)
 65        teacher_logits: Teacher model outputs (before softmax)
 66        labels: True labels
 67        temperature: Temperature for soft targets
 68        alpha: Weight for soft loss (0-1)
 69
 70    Returns:
 71        Total distillation loss
 72    """
 73    # Soft targets (with temperature)
 74    soft_teacher = F.softmax(teacher_logits / temperature, dim=1)
 75    soft_student = F.log_softmax(student_logits / temperature, dim=1)
 76
 77    # KL divergence loss (scaled by T^2)
 78    soft_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (temperature ** 2)
 79
 80    # Hard targets (standard cross-entropy)
 81    hard_loss = F.cross_entropy(student_logits, labels)
 82
 83    # Combined loss
 84    return alpha * soft_loss + (1 - alpha) * hard_loss
 85
 86
 87def count_parameters(model):
 88    """Count trainable parameters in model."""
 89    return sum(p.numel() for p in model.parameters() if p.requires_grad)
 90
 91
 92def train_teacher(model, X_train, y_train, epochs=100, batch_size=32, lr=0.001):
 93    """Train teacher model on data."""
 94    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
 95    criterion = nn.CrossEntropyLoss()
 96
 97    num_samples = X_train.shape[0]
 98
 99    for epoch in range(epochs):
100        model.train()
101        indices = torch.randperm(num_samples)[:batch_size]
102
103        X_batch = X_train[indices]
104        y_batch = y_train[indices]
105
106        optimizer.zero_grad()
107        outputs = model(X_batch)
108        loss = criterion(outputs, y_batch)
109        loss.backward()
110        optimizer.step()
111
112        if (epoch + 1) % 20 == 0:
113            model.eval()
114            with torch.no_grad():
115                preds = model(X_train).argmax(dim=1)
116                acc = (preds == y_train).float().mean()
117            print(f"Epoch {epoch+1:3d}: Loss = {loss.item():.4f}, Acc = {acc:.4f}")
118
119
120def train_student_standard(model, X_train, y_train, epochs=100, batch_size=32, lr=0.001):
121    """Train student model with standard supervised learning (no distillation)."""
122    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
123    criterion = nn.CrossEntropyLoss()
124
125    num_samples = X_train.shape[0]
126
127    for epoch in range(epochs):
128        model.train()
129        indices = torch.randperm(num_samples)[:batch_size]
130
131        X_batch = X_train[indices]
132        y_batch = y_train[indices]
133
134        optimizer.zero_grad()
135        outputs = model(X_batch)
136        loss = criterion(outputs, y_batch)
137        loss.backward()
138        optimizer.step()
139
140        if (epoch + 1) % 20 == 0:
141            model.eval()
142            with torch.no_grad():
143                preds = model(X_train).argmax(dim=1)
144                acc = (preds == y_train).float().mean()
145            print(f"Epoch {epoch+1:3d}: Loss = {loss.item():.4f}, Acc = {acc:.4f}")
146
147
148def train_student_distillation(student, teacher, X_train, y_train, epochs=100,
149                                batch_size=32, lr=0.001, temperature=3.0, alpha=0.7):
150    """Train student model with knowledge distillation."""
151    optimizer = torch.optim.Adam(student.parameters(), lr=lr)
152
153    num_samples = X_train.shape[0]
154    teacher.eval()
155
156    for epoch in range(epochs):
157        student.train()
158        indices = torch.randperm(num_samples)[:batch_size]
159
160        X_batch = X_train[indices]
161        y_batch = y_train[indices]
162
163        optimizer.zero_grad()
164
165        # Get student outputs
166        student_logits = student(X_batch)
167
168        # Get teacher outputs (no gradients)
169        with torch.no_grad():
170            teacher_logits = teacher(X_batch)
171
172        # Distillation loss
173        loss = distillation_loss(student_logits, teacher_logits, y_batch,
174                                  temperature=temperature, alpha=alpha)
175
176        loss.backward()
177        optimizer.step()
178
179        if (epoch + 1) % 20 == 0:
180            student.eval()
181            with torch.no_grad():
182                preds = student(X_train).argmax(dim=1)
183                acc = (preds == y_train).float().mean()
184            print(f"Epoch {epoch+1:3d}: Loss = {loss.item():.4f}, Acc = {acc:.4f}")
185
186
# ============================================================
# Demonstrations
# ============================================================
190
191def demo_temperature_effect():
192    """Demonstrate effect of temperature on softmax."""
193    print("=" * 60)
194    print("DEMO 1: Temperature Effect on Softmax")
195    print("=" * 60)
196
197    # Logits with clear winner
198    logits = torch.tensor([[2.0, 1.0, 0.5, 0.2]])
199
200    temperatures = [1.0, 2.0, 5.0, 10.0]
201
202    print(f"\nLogits: {logits[0].tolist()}\n")
203
204    for T in temperatures:
205        probs = F.softmax(logits / T, dim=1)
206        entropy = -(probs * torch.log(probs + 1e-10)).sum().item()
207
208        print(f"Temperature {T}:")
209        print(f"  Probabilities: {probs[0].tolist()}")
210        print(f"  Entropy: {entropy:.4f}\n")
211
212    print("Higher temperature → softer (more uniform) distribution")
213    print("This reveals more of the teacher's knowledge")
214
215
def demo_soft_vs_hard_labels():
    """Contrast one-hot targets with temperature-softened (T=3) teacher probabilities."""
    print("\n" + "=" * 60)
    print("DEMO 2: Soft vs Hard Labels")
    print("=" * 60)

    # Two simulated teacher outputs: one peaked, one nearly flat.
    teacher_logits = torch.tensor([
        [3.0, 1.5, 0.8, 0.5],  # High confidence
        [2.0, 1.8, 1.5, 1.0],  # Lower confidence
    ])
    hard_labels = torch.tensor([0, 0])

    print("\nTeacher logits:")
    print(teacher_logits)

    print("\nHard labels (one-hot):")
    one_hot_rows = F.one_hot(hard_labels, num_classes=4).float()
    for row in one_hot_rows:
        print(f"  {row.tolist()}")

    print("\nSoft labels (T=3):")
    for row in F.softmax(teacher_logits / 3.0, dim=1):
        print(f"  {row.tolist()}")

    print("\nSoft labels encode similarity between classes!")
245
246
def demo_basic_distillation():
    """Train a teacher on synthetic data, then distill it into a smaller student.

    Labels are drawn independently of the features, and accuracy is measured
    on the training set, so the reported numbers reflect how much each model
    can fit rather than generalisation.
    """
    print("\n" + "=" * 60)
    print("DEMO 3: Basic Knowledge Distillation")
    print("=" * 60)

    # Generate synthetic data
    torch.manual_seed(42)
    np.random.seed(42)

    input_dim = 50
    num_classes = 5
    num_samples = 500

    X_train = torch.randn(num_samples, input_dim)
    y_train = torch.randint(0, num_classes, (num_samples,))

    # Create models
    teacher = TeacherModel(input_dim, hidden_dim=256, num_classes=num_classes)
    student = StudentModel(input_dim, hidden_dim=64, num_classes=num_classes)

    teacher_params = count_parameters(teacher)
    student_params = count_parameters(student)
    print(f"\nTeacher parameters: {teacher_params:,}")
    print(f"Student parameters: {student_params:,}")
    print(f"Compression ratio: {teacher_params / student_params:.2f}x\n")

    # Train teacher
    print("Training teacher model...")
    print("-" * 60)
    train_teacher(teacher, X_train, y_train, epochs=100, batch_size=64)

    # Evaluate teacher
    teacher.eval()
    with torch.no_grad():
        teacher_preds = teacher(X_train).argmax(dim=1)
        teacher_acc = (teacher_preds == y_train).float().mean()
    print(f"\nTeacher final accuracy: {teacher_acc:.4f}")

    # Distill the trained teacher into the student created above.
    print("\nDistilling student model (T=3, α=0.7)...")
    print("-" * 60)
    train_student_distillation(student, teacher, X_train, y_train,
                               epochs=100, batch_size=64,
                               temperature=3.0, alpha=0.7)

    # Evaluate student
    student.eval()
    with torch.no_grad():
        student_preds = student(X_train).argmax(dim=1)
        student_acc = (student_preds == y_train).float().mean()
    print(f"\nStudent final accuracy: {student_acc:.4f}")
283
284
def demo_student_comparison():
    """Train one teacher, then two identical students: plain CE vs distillation.

    Prints training-set accuracy for all three models and the difference
    between the distilled and the standard student.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    print("\n" + heavy_rule)
    print("DEMO 4: Student Training Comparison")
    print(heavy_rule)

    # Synthetic data (seeded so all three models see the same set).
    torch.manual_seed(42)
    input_dim, num_classes, num_samples = 50, 5, 500
    X_train = torch.randn(num_samples, input_dim)
    y_train = torch.randint(0, num_classes, (num_samples,))

    def train_accuracy(model):
        """Switch ``model`` to eval mode and return its accuracy on X_train."""
        model.eval()
        with torch.no_grad():
            return (model(X_train).argmax(dim=1) == y_train).float().mean()

    teacher = TeacherModel(input_dim, hidden_dim=256, num_classes=num_classes)
    print("Training teacher...")
    train_teacher(teacher, X_train, y_train, epochs=100, batch_size=64, lr=0.001)
    teacher_acc = train_accuracy(teacher)
    print(f"Teacher accuracy: {teacher_acc:.4f}\n")

    print(light_rule)
    print("Student 1: Standard training (no distillation)")
    print(light_rule)
    student1 = StudentModel(input_dim, hidden_dim=64, num_classes=num_classes)
    train_student_standard(student1, X_train, y_train, epochs=100, batch_size=64, lr=0.001)
    student1_acc = train_accuracy(student1)

    print("\n" + light_rule)
    print("Student 2: Knowledge distillation (T=3, α=0.7)")
    print(light_rule)
    student2 = StudentModel(input_dim, hidden_dim=64, num_classes=num_classes)
    train_student_distillation(student2, teacher, X_train, y_train,
                               epochs=100, batch_size=64, lr=0.001,
                               temperature=3.0, alpha=0.7)
    student2_acc = train_accuracy(student2)

    print("\n" + heavy_rule)
    print("Comparison:")
    print(heavy_rule)
    print(f"Teacher accuracy:              {teacher_acc:.4f}")
    print(f"Student (standard):            {student1_acc:.4f}")
    print(f"Student (distillation):        {student2_acc:.4f}")
    print(f"Improvement from distillation: {(student2_acc - student1_acc):.4f}")
342
343
def demo_hyperparameter_tuning():
    """Sweep the distillation temperature (alpha fixed) and alpha (T fixed).

    One teacher is trained once; every setting distills a fresh 50-unit
    student and prints its training-set accuracy.
    """
    print("\n" + "=" * 60)
    print("DEMO 5: Hyperparameter Tuning")
    print("=" * 60)

    # Synthetic data shared by every run.
    torch.manual_seed(42)
    input_dim, num_classes, num_samples = 40, 4, 400
    X_train = torch.randn(num_samples, input_dim)
    y_train = torch.randint(0, num_classes, (num_samples,))

    teacher = TeacherModel(input_dim, hidden_dim=200, num_classes=num_classes)
    train_teacher(teacher, X_train, y_train, epochs=80, batch_size=64, lr=0.001)

    def distill_and_score(temperature, alpha):
        """Distill a fresh student with the given settings; return its train accuracy."""
        student = StudentModel(input_dim, hidden_dim=50, num_classes=num_classes)
        train_student_distillation(student, teacher, X_train, y_train,
                                   epochs=80, batch_size=64, lr=0.001,
                                   temperature=temperature, alpha=alpha)
        student.eval()
        with torch.no_grad():
            return (student(X_train).argmax(dim=1) == y_train).float().mean()

    rule = "-" * 60

    print("\n" + rule)
    print("Testing different temperatures (α=0.7):")
    print(rule)
    for T in (1.0, 2.0, 4.0, 8.0):
        acc = distill_and_score(T, 0.7)
        print(f"T={T}: Final accuracy = {acc:.4f}")

    print("\n" + rule)
    print("Testing different alpha values (T=3.0):")
    print(rule)
    for alpha in (0.0, 0.3, 0.5, 0.7, 0.9, 1.0):
        acc = distill_and_score(3.0, alpha)
        print(f"α={alpha}: Final accuracy = {acc:.4f}")
392
393
if __name__ == "__main__":
    banner = "=" * 60

    print("\n" + banner)
    print("Foundation Models: Knowledge Distillation")
    print(banner)

    # Run every demonstration in order.
    for demo in (
        demo_temperature_effect,
        demo_soft_vs_hard_labels,
        demo_basic_distillation,
        demo_student_comparison,
        demo_hyperparameter_tuning,
    ):
        demo()

    takeaways = (
        "1. Distillation: Compress large model → small model",
        "2. Soft labels: Encode class similarities, not just winner",
        "3. Temperature: Controls softness of distribution",
        "4. Loss: α × L_soft + (1-α) × L_hard",
        "5. Typical: T=3-5, α=0.5-0.9",
        "6. Student learns from teacher's mistakes and uncertainties",
        "7. Can achieve similar performance with much smaller model",
    )

    print("\n" + banner)
    print("Key Takeaways:")
    print(banner)
    for takeaway in takeaways:
        print(takeaway)
    print(banner)