1"""
2Foundation Models - LoRA (Low-Rank Adaptation)
3
4Implements LoRA from scratch using PyTorch.
5Demonstrates parameter-efficient fine-tuning through low-rank decomposition.
6Compares full fine-tuning vs LoRA in terms of trainable parameters and performance.
7
8Requires: PyTorch
9"""
10
11import torch
12import torch.nn as nn
13import torch.nn.functional as F
14import numpy as np
15
16
class LoRALayer(nn.Module):
    """
    LoRA (Low-Rank Adaptation) layer.

    Instead of updating W (d_out × d_in), we add ΔW = BA where:
    - B: d_out × r
    - A: r × d_in
    - r << min(d_out, d_in)

    Forward: h = Wx + BAx = Wx + ΔWx
    """

    def __init__(self, in_features, out_features, rank=4, alpha=1.0):
        """
        Initialize LoRA layer.

        Args:
            in_features: Input dimension
            out_features: Output dimension
            rank: Rank of decomposition (r); must be >= 1
            alpha: Scaling factor (typically 1.0 or rank)

        Raises:
            ValueError: If rank < 1 (the α/r scaling would be undefined).
        """
        super().__init__()

        if rank < 1:
            raise ValueError(f"rank must be >= 1, got {rank}")

        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha

        # Original pretrained weight (frozen). Random here; in practice it
        # would be loaded from a pretrained checkpoint.
        self.weight = nn.Parameter(torch.randn(out_features, in_features))
        self.weight.requires_grad = False

        # LoRA matrices (trainable)
        # A: Gaussian initialization, scaled by 1/sqrt(r)
        self.lora_A = nn.Parameter(torch.randn(rank, in_features) / np.sqrt(rank))

        # B: Zero initialization, so ΔW = BA = 0 at the start and the
        # adapted layer initially matches the frozen base layer exactly.
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))

        # Scaling factor applied to the low-rank update
        self.scaling = alpha / rank

    def forward(self, x):
        """
        Forward pass: h = Wx + (BA)x

        Args:
            x: Input tensor (batch_size, in_features)

        Returns:
            Output tensor (batch_size, out_features)
        """
        # Original forward pass (frozen)
        h = F.linear(x, self.weight)

        # LoRA adaptation: x → A → B. Two skinny matmuls instead of
        # materializing the full d_out × d_in update matrix.
        lora_out = F.linear(x, self.lora_A)         # (batch, rank)
        lora_out = F.linear(lora_out, self.lora_B)  # (batch, out_features)

        # Scale and add
        return h + lora_out * self.scaling

    def merge_weights(self):
        """Merge LoRA weights into original weight for inference.

        Afterwards the layer computes plain Wx with no LoRA overhead.
        A and B are zeroed so a subsequent forward pass does not apply
        the update a second time.
        """
        with torch.no_grad():
            # W_merged = W + α/r * BA
            delta_w = self.lora_B @ self.lora_A * self.scaling
            self.weight.data += delta_w

            # Zero out LoRA to avoid double counting
            self.lora_A.zero_()
            self.lora_B.zero_()

    def get_num_params(self):
        """Get parameter counts.

        Returns:
            dict with keys:
                'total':     all parameters in the layer (frozen W + LoRA)
                'trainable': trainable parameters (the LoRA matrices)
                'lora':      LoRA parameters (A and B combined)
                'frozen':    frozen base-weight parameters
        """
        frozen = self.weight.numel()
        lora_params = self.lora_A.numel() + self.lora_B.numel()

        return {
            # BUGFIX: 'total' previously reported only the frozen weight,
            # silently excluding the LoRA parameters themselves.
            'total': frozen + lora_params,
            'trainable': lora_params,
            'lora': lora_params,
            'frozen': frozen,
        }
103
104
class SimpleLinearModel(nn.Module):
    """Baseline MLP with three standard (fully trainable) Linear layers.

    Represents the full fine-tuning setting that LoRA is compared against.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # ReLU after each hidden layer; final projection stays linear.
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)
118
119
class LoRAModel(nn.Module):
    """Three-layer MLP whose Linear layers are replaced by LoRA layers.

    Only the low-rank (A, B) matrices of each layer are trainable; the
    base weights stay frozen.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, rank=4):
        super().__init__()
        self.fc1 = LoRALayer(input_dim, hidden_dim, rank=rank)
        self.fc2 = LoRALayer(hidden_dim, hidden_dim, rank=rank)
        self.fc3 = LoRALayer(hidden_dim, output_dim, rank=rank)

    def forward(self, x):
        # ReLU after each hidden layer; the output layer is left linear.
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)

    def get_num_params(self):
        """Get total and trainable parameter counts.

        Returns:
            dict with summed 'total', 'trainable' and 'lora' counts plus a
            'layers' entry holding the per-layer stats dicts.
        """
        per_layer = {
            name: getattr(self, name).get_num_params()
            for name in ('fc1', 'fc2', 'fc3')
        }

        summary = {
            key: sum(stats[key] for stats in per_layer.values())
            for key in ('total', 'trainable', 'lora')
        }
        summary['layers'] = per_layer
        return summary
152
153
154# ============================================================
155# Demonstrations
156# ============================================================
157
def demo_parameter_efficiency():
    """Compare trainable-parameter counts: full fine-tuning vs LoRA."""
    print("=" * 60)
    print("DEMO 1: Parameter Efficiency")
    print("=" * 60)

    input_dim, hidden_dim, output_dim = 512, 2048, 128

    # Full fine-tuning baseline: every parameter is trainable.
    standard_model = SimpleLinearModel(input_dim, hidden_dim, output_dim)
    total_standard = sum(p.numel() for p in standard_model.parameters())

    print(f"\nModel architecture: {input_dim} → {hidden_dim} → {hidden_dim} → {output_dim}")
    print(f"\nStandard model (full fine-tuning):")
    print(f" Total parameters: {total_standard:,}")
    print(f" Trainable parameters: {total_standard:,}")

    # Sweep the rank to show how trainable-parameter count scales with r.
    print(f"\nLoRA models:")
    print("-" * 60)

    for r in (2, 4, 8, 16, 32):
        stats = LoRAModel(input_dim, hidden_dim, output_dim, rank=r).get_num_params()
        pct_saved = (1 - stats['trainable'] / total_standard) * 100

        print(f"\nRank {r}:")
        print(f" Total parameters: {stats['total']:,}")
        print(f" Trainable parameters: {stats['trainable']:,}")
        print(f" LoRA parameters: {stats['lora']:,}")
        print(f" Parameter reduction: {pct_saved:.2f}%")
        print(f" Compression ratio: {total_standard/stats['trainable']:.1f}x")
193
194
def demo_lora_layer():
    """Demonstrate single LoRA layer behavior."""
    print("\n" + "=" * 60)
    print("DEMO 2: LoRA Layer Mechanics")
    print("=" * 60)

    # Deliberately tiny dimensions so the matrices are easy to inspect.
    in_dim, out_dim, rank = 8, 4, 2
    layer = LoRALayer(in_dim, out_dim, rank=rank, alpha=1.0)

    print(f"\nLayer: {in_dim} → {out_dim}, rank={rank}")
    print(f"\nWeight matrix W: {layer.weight.shape}")
    print(f"LoRA matrix A: {layer.lora_A.shape}")
    print(f"LoRA matrix B: {layer.lora_B.shape}")

    # Parameter bookkeeping for this single layer.
    stats = layer.get_num_params()
    print(f"\nParameter counts:")
    print(f" Original weight W: {stats['frozen']}")
    print(f" LoRA matrices (A + B): {stats['lora']}")
    print(f" Reduction: {(1 - stats['lora']/stats['frozen']) * 100:.1f}%")

    # Push a small batch through to confirm the output shape.
    x = torch.randn(3, in_dim)
    with torch.no_grad():
        output = layer(x)

    print(f"\nForward pass:")
    print(f" Input shape: {x.shape}")
    print(f" Output shape: {output.shape}")
230
231
def demo_training():
    """Demonstrate training with LoRA.

    Trains a fully fine-tuned model and a LoRA model on the same random
    mini-batches of a synthetic regression task and compares their losses.
    """
    print("\n" + "=" * 60)
    print("DEMO 3: Training Comparison")
    print("=" * 60)

    # Toy regression task
    input_dim = 64
    hidden_dim = 256
    output_dim = 10
    num_samples = 1000

    # Synthetic data: random targets, so neither model has an inherent edge.
    X = torch.randn(num_samples, input_dim)
    y = torch.randn(num_samples, output_dim)

    # Standard model: every parameter is optimized.
    standard_model = SimpleLinearModel(input_dim, hidden_dim, output_dim)
    optimizer_std = torch.optim.Adam(standard_model.parameters(), lr=0.001)

    # LoRA model
    lora_model = LoRAModel(input_dim, hidden_dim, output_dim, rank=8)
    # Only optimize LoRA parameters (the base weights are frozen)
    lora_params = [p for p in lora_model.parameters() if p.requires_grad]
    optimizer_lora = torch.optim.Adam(lora_params, lr=0.001)

    print(f"\nTraining on {num_samples} samples...")

    # Training loop
    num_epochs = 50
    batch_size = 32

    std_losses = []
    lora_losses = []

    for epoch in range(num_epochs):
        # One shared mini-batch per epoch so both models see identical data
        # and the loss comparison is apples-to-apples.
        idx = torch.randperm(num_samples)[:batch_size]
        X_batch, y_batch = X[idx], y[idx]

        # Standard model step
        optimizer_std.zero_grad()
        pred = standard_model(X_batch)
        loss = F.mse_loss(pred, y_batch)
        loss.backward()
        optimizer_std.step()
        std_losses.append(loss.item())

        # LoRA model step
        optimizer_lora.zero_grad()
        pred = lora_model(X_batch)
        loss = F.mse_loss(pred, y_batch)
        loss.backward()
        optimizer_lora.step()
        lora_losses.append(loss.item())

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:3d}: "
                  f"Standard loss = {std_losses[-1]:.4f}, "
                  f"LoRA loss = {lora_losses[-1]:.4f}")

    # Final evaluation. BUGFIX: this evaluates on the same data the models
    # were trained on — there is no held-out split in this toy demo — so the
    # label no longer calls it "test loss".
    with torch.no_grad():
        std_pred = standard_model(X)
        lora_pred = lora_model(X)

        std_final_loss = F.mse_loss(std_pred, y).item()
        lora_final_loss = F.mse_loss(lora_pred, y).item()

    print(f"\nFinal loss on the full training set:")
    print(f" Standard model: {std_final_loss:.4f}")
    print(f" LoRA model: {lora_final_loss:.4f}")
303
304
def demo_rank_impact():
    """Study impact of LoRA rank on capacity."""
    print("\n" + "=" * 60)
    print("DEMO 4: Impact of Rank")
    print("=" * 60)

    input_dim, hidden_dim, output_dim = 128, 512, 16

    # Synthetic training data
    num_samples = 500
    X = torch.randn(num_samples, input_dim)
    y = torch.randn(num_samples, output_dim)

    print(f"\nTraining models with different ranks...")
    print("-" * 60)

    results = []
    for rank in (1, 2, 4, 8, 16, 32):
        model = LoRAModel(input_dim, hidden_dim, output_dim, rank=rank)
        trainable = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable, lr=0.001)

        # Quick training: 100 steps on random mini-batches of 64.
        for _ in range(100):
            batch_idx = torch.randperm(num_samples)[:64]
            optimizer.zero_grad()
            batch_loss = F.mse_loss(model(X[batch_idx]), y[batch_idx])
            batch_loss.backward()
            optimizer.step()

        # Evaluate on the full dataset.
        with torch.no_grad():
            final_loss = F.mse_loss(model(X), y).item()

        results.append((rank, model.get_num_params()['trainable'], final_loss))

    # Print results as a small table.
    print(f"\n{'Rank':<8} {'Params':<12} {'Loss':<10}")
    print("-" * 60)
    for rank, params, loss in results:
        print(f"{rank:<8} {params:<12,} {loss:<10.4f}")
353
354
def demo_weight_merging():
    """Demonstrate merging LoRA weights for inference.

    After training, W can be replaced by W + (α/r)·BA so inference pays no
    extra cost for LoRA. This demo checks that the merged layer produces
    the same outputs as the unmerged one.
    """
    print("\n" + "=" * 60)
    print("DEMO 5: Weight Merging")
    print("=" * 60)

    # Create layer
    layer = LoRALayer(512, 256, rank=8)

    # BUGFIX: a freshly initialized layer has B = 0, so ΔW = BA = 0 and the
    # merge would be a no-op — the equivalence check below would pass
    # vacuously. Simulate a trained adapter by giving B nonzero values.
    with torch.no_grad():
        layer.lora_B.copy_(torch.randn_like(layer.lora_B) * 0.1)

    # Random input
    x = torch.randn(4, 512)

    # Output before merging
    with torch.no_grad():
        output_before = layer(x)

    print(f"\nBefore merging:")
    print(f" Weight norm: {layer.weight.norm().item():.4f}")
    print(f" LoRA A norm: {layer.lora_A.norm().item():.4f}")
    print(f" LoRA B norm: {layer.lora_B.norm().item():.4f}")
    print(f" Output sample: {output_before[0, :5]}")

    # Merge weights (folds BA into W and zeroes A, B)
    layer.merge_weights()

    # Output after merging
    with torch.no_grad():
        output_after = layer(x)

    print(f"\nAfter merging:")
    print(f" Weight norm: {layer.weight.norm().item():.4f}")
    print(f" LoRA A norm: {layer.lora_A.norm().item():.4f}")
    print(f" LoRA B norm: {layer.lora_B.norm().item():.4f}")
    print(f" Output sample: {output_after[0, :5]}")

    # Check equivalence. Tolerance is 1e-4 rather than 1e-5: with a genuinely
    # nonzero ΔW the merged and unmerged paths accumulate float32 rounding
    # differently, so a tiny nonzero difference is expected.
    diff = (output_before - output_after).abs().max().item()
    print(f"\nMax difference: {diff:.6e}")
    print(f"Outputs are equivalent: {diff < 1e-4}")
394
395
def demo_adapter_composition():
    """Demonstrate composing multiple LoRA adapters.

    Several tasks share one frozen base weight; each task gets its own
    small (A, B) pair, so switching tasks means swapping tiny adapters
    rather than whole models.
    """
    print("\n" + "=" * 60)
    print("DEMO 6: Adapter Composition")
    print("=" * 60)

    print("\nScenario: Fine-tune same base model for different tasks")

    input_dim = 256
    output_dim = 128

    # Shared base weight
    base_weight = torch.randn(output_dim, input_dim)

    # Create adapters for different tasks
    adapters = {}
    for task in ['task_A', 'task_B', 'task_C']:
        layer = LoRALayer(input_dim, output_dim, rank=4)
        layer.weight.data = base_weight.clone()  # Share base
        # BUGFIX: B is zero-initialized, so a fresh adapter is a no-op and
        # every task would print delta = 0 against the base output. Give
        # each adapter distinct nonzero B values to mimic task-specific
        # fine-tuning.
        with torch.no_grad():
            layer.lora_B.copy_(torch.randn_like(layer.lora_B) * 0.1)
        adapters[task] = layer

    # Test input
    x = torch.randn(1, input_dim)

    print(f"\nBase model:")
    with torch.no_grad():
        base_output = F.linear(x, base_weight)
    print(f" Output norm: {base_output.norm().item():.4f}")

    # Each adapter shifts the shared base output in its own direction.
    print(f"\nWith task-specific adapters:")
    for task, adapter in adapters.items():
        with torch.no_grad():
            output = adapter(x)
        diff = (output - base_output).norm().item()
        print(f" {task}: output norm = {output.norm().item():.4f}, "
              f"delta = {diff:.4f}")
432
433
if __name__ == "__main__":
    banner = "=" * 60

    print("\n" + banner)
    print("Foundation Models: LoRA (Low-Rank Adaptation)")
    print(banner)

    # Run every demo in order.
    for demo in (
        demo_parameter_efficiency,
        demo_lora_layer,
        demo_training,
        demo_rank_impact,
        demo_weight_merging,
        demo_adapter_composition,
    ):
        demo()

    print("\n" + banner)
    print("Key Takeaways:")
    print(banner)
    takeaways = (
        "LoRA decomposes weight updates as ΔW = BA (low-rank)",
        "Reduces trainable params by 100-1000x vs full fine-tuning",
        "Rank r controls capacity/efficiency tradeoff",
        "Can merge adapters into base weights for inference",
        "Enables multi-task learning with shared base model",
    )
    for number, point in enumerate(takeaways, start=1):
        print(f"{number}. {point}")
    print(banner)