# 03_lora.py
"""
Foundation Models - LoRA (Low-Rank Adaptation)

Implements LoRA from scratch using PyTorch.
Demonstrates parameter-efficient fine-tuning through low-rank decomposition.
Compares full fine-tuning vs LoRA in terms of trainable parameters and performance.

Requires: PyTorch
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

 17class LoRALayer(nn.Module):
 18    """
 19    LoRA (Low-Rank Adaptation) layer.
 20
 21    Instead of updating W (d_out × d_in), we add ΔW = BA where:
 22    - B: d_out × r
 23    - A: r × d_in
 24    - r << min(d_out, d_in)
 25
 26    Forward: h = Wx + BAx = Wx + ΔWx
 27    """
 28
 29    def __init__(self, in_features, out_features, rank=4, alpha=1.0):
 30        """
 31        Initialize LoRA layer.
 32
 33        Args:
 34            in_features: Input dimension
 35            out_features: Output dimension
 36            rank: Rank of decomposition (r)
 37            alpha: Scaling factor (typically 1.0 or rank)
 38        """
 39        super().__init__()
 40
 41        self.in_features = in_features
 42        self.out_features = out_features
 43        self.rank = rank
 44        self.alpha = alpha
 45
 46        # Original pretrained weight (frozen)
 47        self.weight = nn.Parameter(torch.randn(out_features, in_features))
 48        self.weight.requires_grad = False
 49
 50        # LoRA matrices (trainable)
 51        # A: Gaussian initialization
 52        self.lora_A = nn.Parameter(torch.randn(rank, in_features) / np.sqrt(rank))
 53
 54        # B: Zero initialization (starts with identity)
 55        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
 56
 57        # Scaling factor
 58        self.scaling = alpha / rank
 59
 60    def forward(self, x):
 61        """
 62        Forward pass: h = Wx + (BA)x
 63
 64        Args:
 65            x: Input tensor (batch_size, in_features)
 66
 67        Returns:
 68            Output tensor (batch_size, out_features)
 69        """
 70        # Original forward pass (frozen)
 71        h = F.linear(x, self.weight)
 72
 73        # LoRA adaptation: x → A → B
 74        lora_out = F.linear(x, self.lora_A)  # (batch, rank)
 75        lora_out = F.linear(lora_out, self.lora_B)  # (batch, out_features)
 76
 77        # Scale and add
 78        return h + lora_out * self.scaling
 79
 80    def merge_weights(self):
 81        """Merge LoRA weights into original weight for inference."""
 82        with torch.no_grad():
 83            # W_merged = W + α/r * BA
 84            delta_w = self.lora_B @ self.lora_A * self.scaling
 85            self.weight.data += delta_w
 86
 87            # Zero out LoRA to avoid double counting
 88            self.lora_A.zero_()
 89            self.lora_B.zero_()
 90
 91    def get_num_params(self):
 92        """Get parameter counts."""
 93        total = self.weight.numel()
 94        lora_params = self.lora_A.numel() + self.lora_B.numel()
 95        trainable = lora_params
 96
 97        return {
 98            'total': total,
 99            'trainable': trainable,
100            'lora': lora_params,
101            'frozen': total,
102        }
103
104
class SimpleLinearModel(nn.Module):
    """Baseline MLP built from ordinary, fully trainable Linear layers."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Two hidden layers plus a linear readout.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # ReLU after each hidden layer; no activation on the output.
        h1 = F.relu(self.fc1(x))
        h2 = F.relu(self.fc2(h1))
        return self.fc3(h2)
class LoRAModel(nn.Module):
    """Model with LoRA layers for parameter-efficient fine-tuning."""

    def __init__(self, input_dim, hidden_dim, output_dim, rank=4):
        super().__init__()
        # Same topology as SimpleLinearModel, but every layer is a
        # frozen base weight plus a trainable low-rank adapter.
        self.fc1 = LoRALayer(input_dim, hidden_dim, rank=rank)
        self.fc2 = LoRALayer(hidden_dim, hidden_dim, rank=rank)
        self.fc3 = LoRALayer(hidden_dim, output_dim, rank=rank)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

    def get_num_params(self):
        """Aggregate per-layer parameter statistics across fc1–fc3."""
        per_layer = {
            name: getattr(self, name).get_num_params()
            for name in ('fc1', 'fc2', 'fc3')
        }

        # Sum each statistic over the three layers.
        return {
            'total': sum(s['total'] for s in per_layer.values()),
            'trainable': sum(s['trainable'] for s in per_layer.values()),
            'lora': sum(s['lora'] for s in per_layer.values()),
            'layers': per_layer,
        }
# ============================================================
# Demonstrations
# ============================================================

def demo_parameter_efficiency():
    """Compare trainable-parameter counts: full fine-tuning vs LoRA.

    Builds one standard MLP, then several LoRA models of increasing
    rank, and reports parameter reduction and compression ratio.
    """
    print("=" * 60)
    print("DEMO 1: Parameter Efficiency")
    print("=" * 60)

    input_dim = 512
    hidden_dim = 2048
    output_dim = 128

    # Standard model: every parameter is updated during fine-tuning.
    standard_model = SimpleLinearModel(input_dim, hidden_dim, output_dim)
    total_standard = sum(p.numel() for p in standard_model.parameters())

    # Bug fix: the original f-string concatenated the layer sizes with no
    # separator (printing e.g. "5122048..."); insert arrows between them.
    print(f"\nModel architecture: {input_dim} → {hidden_dim} → {hidden_dim} → {output_dim}")
    print(f"\nStandard model (full fine-tuning):")
    print(f"  Total parameters: {total_standard:,}")
    print(f"  Trainable parameters: {total_standard:,}")

    # LoRA models with different ranks.
    print(f"\nLoRA models:")
    print("-" * 60)

    for rank in [2, 4, 8, 16, 32]:
        lora_model = LoRAModel(input_dim, hidden_dim, output_dim, rank=rank)
        stats = lora_model.get_num_params()

        # Fraction of the full model's parameters that no longer train.
        reduction = (1 - stats['trainable'] / total_standard) * 100

        print(f"\nRank {rank}:")
        print(f"  Total parameters: {stats['total']:,}")
        print(f"  Trainable parameters: {stats['trainable']:,}")
        print(f"  LoRA parameters: {stats['lora']:,}")
        print(f"  Parameter reduction: {reduction:.2f}%")
        print(f"  Compression ratio: {total_standard/stats['trainable']:.1f}x")


def demo_lora_layer():
    """Demonstrate single LoRA layer behavior.

    Shows the matrix shapes, the parameter-count savings, and one
    forward pass through a small LoRA layer.
    """
    print("\n" + "=" * 60)
    print("DEMO 2: LoRA Layer Mechanics")
    print("=" * 60)

    # Small example so the shapes are easy to read.
    in_dim = 8
    out_dim = 4
    rank = 2

    layer = LoRALayer(in_dim, out_dim, rank=rank, alpha=1.0)

    # Bug fix: the original f-string printed the dimensions with no
    # separator (e.g. "84"); insert an arrow between them.
    print(f"\nLayer: {in_dim} → {out_dim}, rank={rank}")
    print(f"\nWeight matrix W: {layer.weight.shape}")
    print(f"LoRA matrix A: {layer.lora_A.shape}")
    print(f"LoRA matrix B: {layer.lora_B.shape}")

    # Parameter counts: LoRA needs far fewer trainable values than W.
    stats = layer.get_num_params()
    print(f"\nParameter counts:")
    print(f"  Original weight W: {stats['frozen']}")
    print(f"  LoRA matrices (A + B): {stats['lora']}")
    print(f"  Reduction: {(1 - stats['lora']/stats['frozen']) * 100:.1f}%")

    # Forward pass to confirm input/output shapes.
    batch_size = 3
    x = torch.randn(batch_size, in_dim)

    with torch.no_grad():
        output = layer(x)

    print(f"\nForward pass:")
    print(f"  Input shape: {x.shape}")
    print(f"  Output shape: {output.shape}")


def demo_training():
    """Compare training dynamics of full fine-tuning vs LoRA on a toy task.

    Both models are optimized on the same mini-batches of a synthetic
    regression problem; the LoRA model's optimizer only sees the
    trainable (low-rank) parameters.
    """
    print("\n" + "=" * 60)
    print("DEMO 3: Training Comparison")
    print("=" * 60)

    # Toy regression task dimensions.
    input_dim = 64
    hidden_dim = 256
    output_dim = 10
    num_samples = 1000

    # Synthetic data: random targets, so neither model has an edge.
    X = torch.randn(num_samples, input_dim)
    y = torch.randn(num_samples, output_dim)

    # Standard model: every parameter is trainable.
    standard_model = SimpleLinearModel(input_dim, hidden_dim, output_dim)
    optimizer_std = torch.optim.Adam(standard_model.parameters(), lr=0.001)

    # LoRA model: only the low-rank adapter matrices are optimized.
    lora_model = LoRAModel(input_dim, hidden_dim, output_dim, rank=8)
    lora_params = [p for p in lora_model.parameters() if p.requires_grad]
    optimizer_lora = torch.optim.Adam(lora_params, lr=0.001)

    print(f"\nTraining on {num_samples} samples...")

    num_epochs = 50
    batch_size = 32

    std_losses = []
    lora_losses = []

    for epoch in range(num_epochs):
        # Draw one shared mini-batch so both models see identical data.
        idx = torch.randperm(num_samples)[:batch_size]
        X_batch, y_batch = X[idx], y[idx]

        # Standard model step.
        optimizer_std.zero_grad()
        pred = standard_model(X_batch)
        loss = F.mse_loss(pred, y_batch)
        loss.backward()
        optimizer_std.step()
        std_losses.append(loss.item())

        # LoRA model step on the same batch.
        optimizer_lora.zero_grad()
        pred = lora_model(X_batch)
        loss = F.mse_loss(pred, y_batch)
        loss.backward()
        optimizer_lora.step()
        lora_losses.append(loss.item())

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:3d}: "
                  f"Standard loss = {std_losses[-1]:.4f}, "
                  f"LoRA loss = {lora_losses[-1]:.4f}")

    # Final evaluation over the full dataset. Note this is the training
    # data — there is no held-out split in this toy demo.
    with torch.no_grad():
        std_pred = standard_model(X)
        lora_pred = lora_model(X)

        std_final_loss = F.mse_loss(std_pred, y).item()
        lora_final_loss = F.mse_loss(lora_pred, y).item()

    # Bug fix: the original label said "Final test loss", but the
    # evaluation above runs on the training data.
    print(f"\nFinal loss (on training data):")
    print(f"  Standard model: {std_final_loss:.4f}")
    print(f"  LoRA model: {lora_final_loss:.4f}")


def demo_rank_impact():
    """Study impact of LoRA rank on capacity."""
    print("\n" + "=" * 60)
    print("DEMO 4: Impact of Rank")
    print("=" * 60)

    input_dim = 128
    hidden_dim = 512
    output_dim = 16

    # Synthetic regression data shared by every run.
    num_samples = 500
    X = torch.randn(num_samples, input_dim)
    y = torch.randn(num_samples, output_dim)

    print(f"\nTraining models with different ranks...")
    print("-" * 60)

    results = []
    for rank in (1, 2, 4, 8, 16, 32):
        model = LoRAModel(input_dim, hidden_dim, output_dim, rank=rank)
        trainable = [p for p in model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable, lr=0.001)

        # Short training run: 100 optimizer steps on 64-sample batches.
        for _step in range(100):
            batch_idx = torch.randperm(num_samples)[:64]
            optimizer.zero_grad()
            batch_loss = F.mse_loss(model(X[batch_idx]), y[batch_idx])
            batch_loss.backward()
            optimizer.step()

        # Evaluate on the full dataset.
        with torch.no_grad():
            final_loss = F.mse_loss(model(X), y).item()

        results.append(
            (rank, model.get_num_params()['trainable'], final_loss)
        )

    # Tabulate rank vs trainable parameters vs achieved loss.
    print(f"\n{'Rank':<8} {'Params':<12} {'Loss':<10}")
    print("-" * 60)
    for rank, params, loss in results:
        print(f"{rank:<8} {params:<12,} {loss:<10.4f}")


def demo_weight_merging():
    """Demonstrate merging LoRA weights for inference."""
    print("\n" + "=" * 60)
    print("DEMO 5: Weight Merging")
    print("=" * 60)

    lora_layer = LoRALayer(512, 256, rank=8)
    sample = torch.randn(4, 512)

    def _report(title, out):
        # Print weight/LoRA norms plus an output slice for comparison.
        print(f"\n{title}:")
        print(f"  Weight norm: {lora_layer.weight.norm().item():.4f}")
        print(f"  LoRA A norm: {lora_layer.lora_A.norm().item():.4f}")
        print(f"  LoRA B norm: {lora_layer.lora_B.norm().item():.4f}")
        print(f"  Output sample: {out[0, :5]}")

    # Output before merging.
    with torch.no_grad():
        output_before = lora_layer(sample)
    _report("Before merging", output_before)

    # Fold BA into the base weight and zero the adapters.
    lora_layer.merge_weights()

    # Output after merging.
    with torch.no_grad():
        output_after = lora_layer(sample)
    _report("After merging", output_after)

    # The merged layer must reproduce the unmerged outputs.
    max_diff = (output_before - output_after).abs().max().item()
    print(f"\nMax difference: {max_diff:.6e}")
    print(f"Outputs are equivalent: {max_diff < 1e-5}")


def demo_adapter_composition():
    """Demonstrate composing multiple LoRA adapters."""
    print("\n" + "=" * 60)
    print("DEMO 6: Adapter Composition")
    print("=" * 60)

    print("\nScenario: Fine-tune same base model for different tasks")

    input_dim = 256
    output_dim = 128

    # One frozen base weight shared by every task adapter.
    base_weight = torch.randn(output_dim, input_dim)

    # Each task gets its own LoRA layer on top of the shared base.
    adapters = {}
    for task in ('task_A', 'task_B', 'task_C'):
        task_layer = LoRALayer(input_dim, output_dim, rank=4)
        task_layer.weight.data = base_weight.clone()  # Share base
        adapters[task] = task_layer

    probe = torch.randn(1, input_dim)

    # Base model output with no adapter applied.
    print(f"\nBase model:")
    with torch.no_grad():
        base_output = F.linear(probe, base_weight)
        print(f"  Output norm: {base_output.norm().item():.4f}")

    # Each adapter perturbs the shared base by its own ΔW.
    print(f"\nWith task-specific adapters:")
    for task, task_layer in adapters.items():
        with torch.no_grad():
            adapted = task_layer(probe)
            delta = (adapted - base_output).norm().item()
            print(f"  {task}: output norm = {adapted.norm().item():.4f}, "
                  f"delta = {delta:.4f}")


if __name__ == "__main__":
    banner = "=" * 60
    print("\n" + banner)
    print("Foundation Models: LoRA (Low-Rank Adaptation)")
    print(banner)

    # Run every demonstration in order.
    for demo in (
        demo_parameter_efficiency,
        demo_lora_layer,
        demo_training,
        demo_rank_impact,
        demo_weight_merging,
        demo_adapter_composition,
    ):
        demo()

    print("\n" + banner)
    print("Key Takeaways:")
    print(banner)
    for takeaway in (
        "1. LoRA decomposes weight updates as ΔW = BA (low-rank)",
        "2. Reduces trainable params by 100-1000x vs full fine-tuning",
        "3. Rank r controls capacity/efficiency tradeoff",
        "4. Can merge adapters into base weights for inference",
        "5. Enables multi-task learning with shared base model",
    ):
        print(takeaway)
    print(banner)