1"""
2Stable Diffusion Fine-tuning with LoRA (Low-Rank Adaptation)
3
4This script demonstrates how to fine-tune a diffusion model using LoRA,
5a parameter-efficient fine-tuning technique that adds trainable low-rank
6matrices to attention layers while keeping the base model frozen.
7
8Key Concepts:
9- LoRA (Low-Rank Adaptation): Inject trainable rank decomposition matrices
10- Diffusion Models: Iterative denoising process for image generation
11- Parameter-Efficient Fine-tuning: Train only a small subset of parameters
12- UNet Architecture: Core denoising network in diffusion models
13- Text-to-Image: Conditioning image generation with text prompts
14
15Requirements:
16 pip install torch torchvision diffusers transformers accelerate safetensors
17"""
18
19import argparse
20import os
21from typing import Dict, Optional, Tuple
22import math
23
24import torch
25import torch.nn as nn
26import torch.nn.functional as F
27from torch.utils.data import Dataset, DataLoader
28import torchvision.transforms as transforms
29from PIL import Image
30
try:
    from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DDPMScheduler
    from transformers import CLIPTextModel, CLIPTokenizer
except ImportError:
    print("Please install required packages:")
    print("pip install diffusers transformers accelerate")
    # `exit()` is injected by the `site` module and may be absent (python -S,
    # embedded interpreters, some frozen builds); SystemExit always works.
    raise SystemExit(1)
38
39
class LoRALayer(nn.Module):
    """
    Low-Rank Adaptation layer that can be injected into linear layers.

    LoRA decomposes weight updates into two low-rank matrices:
    W' = W + BA, where B is (out_features, rank) and A is (rank, in_features)
    """

    def __init__(self, in_features: int, out_features: int, rank: int = 4, alpha: float = 1.0):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        # The low-rank delta is scaled by alpha / rank (standard LoRA convention).
        self.scaling = alpha / rank

        # Down-projection A: (rank, in_features); up-projection B: (out_features, rank).
        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))

        # A gets Kaiming-uniform init; B stays zero so the initial delta is
        # exactly zero and training starts from the frozen base behavior.
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the low-rank update B(Ax), scaled by alpha / rank.

        Args:
            x: Input tensor of shape (..., in_features)

        Returns:
            Tensor of shape (..., out_features)
        """
        down = F.linear(x, self.lora_A)   # (..., rank)
        up = F.linear(down, self.lora_B)  # (..., out_features)
        return up * self.scaling
76
77
class LinearWithLoRA(nn.Module):
    """Wraps a frozen nn.Linear and adds a trainable LoRA delta to its output."""

    def __init__(self, linear: nn.Linear, rank: int = 4, alpha: float = 1.0):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

        # Only the LoRA matrices train; the wrapped base layer stays frozen.
        self.linear.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the frozen base projection plus the scaled low-rank delta."""
        base = self.linear(x)
        delta = self.lora(x)
        return base + delta
95
96
def inject_lora_to_attention(unet: UNet2DConditionModel, rank: int = 4,
                             alpha: float = 1.0) -> nn.Module:
    """
    Inject LoRA layers into the attention modules of UNet.

    In Stable Diffusion, attention modules have Q, K, V projection layers
    and an output projection layer. We add LoRA to all of these.

    Args:
        unet: UNet2DConditionModel from diffusers
        rank: Rank of LoRA matrices
        alpha: Scaling factor for LoRA

    Returns:
        Modified UNet with LoRA layers
    """
    # Freeze all parameters first
    for param in unet.parameters():
        param.requires_grad = False

    # Collect targets BEFORE mutating the module tree. Replacing modules while
    # iterating named_modules() is unsafe: the generator can descend into a
    # freshly inserted LinearWithLoRA, see its inner `.linear` child (whose
    # qualified name still contains "attn" and which is still an nn.Linear),
    # and wrap it a second time.
    targets = [
        (name, module)
        for name, module in unet.named_modules()
        if "attn" in name and isinstance(module, nn.Linear)
    ]

    for name, module in targets:
        # Split "a.b.c" into parent path "a.b" and attribute "c";
        # rpartition returns ("", "", name) for a top-level attribute.
        parent_name, _, attr_name = name.rpartition(".")
        parent = unet
        if parent_name:
            for part in parent_name.split("."):
                parent = getattr(parent, part)

        # Replace Linear with LinearWithLoRA
        setattr(parent, attr_name, LinearWithLoRA(module, rank, alpha))

    print(f"Injected LoRA into {len(targets)} linear layers")
    return unet
138
139
def get_lora_parameters(model: nn.Module) -> list:
    """Collect the trainable LoRA parameters (e.g. to hand to an optimizer)."""
    return [
        param
        for name, param in model.named_parameters()
        if "lora" in name and param.requires_grad
    ]
147
148
def save_lora_weights(model: nn.Module, save_path: str):
    """Save only LoRA weights (not the full model)."""
    # Keep just the LoRA entries, moved to CPU so the file loads anywhere.
    lora_state_dict = {
        name: param.cpu()
        for name, param in model.named_parameters()
        if "lora" in name
    }

    torch.save(lora_state_dict, save_path)
    print(f"LoRA weights saved to {save_path}")
    size_mb = os.path.getsize(save_path) / 1024 / 1024
    print(f"File size: {size_mb:.2f} MB")
159
160
def load_lora_weights(model: nn.Module, load_path: str):
    """Load LoRA weights into the model.

    Args:
        model: Module that already has LoRA layers injected (key names must
               match those produced by ``save_lora_weights``).
        load_path: Path to the saved LoRA state dict.
    """
    # map_location="cpu" lets checkpoints saved on GPU load on CPU-only hosts;
    # the caller moves the model to the target device afterwards.
    lora_state_dict = torch.load(load_path, map_location="cpu")

    # strict=False: the checkpoint holds only LoRA keys, not the full model.
    model.load_state_dict(lora_state_dict, strict=False)
    print(f"LoRA weights loaded from {load_path}")
168
169
class DummyDataset(Dataset):
    """
    Dummy dataset for demonstration purposes.

    In practice, you would use:
    - Custom image-caption pairs
    - Datasets from HuggingFace (e.g., lambdalabs/pokemon-blip-captions)
    - Your own domain-specific data
    """

    def __init__(self, size: int = 100, img_size: int = 512):
        self.size = size
        self.img_size = img_size

        # Fixed prompt pool; items cycle through it by index.
        self.prompts = [
            "a beautiful landscape with mountains",
            "a cute cat sitting on a chair",
            "abstract digital art",
            "a futuristic city at night",
            "a serene beach at sunset"
        ]

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        """Return a (random image tensor, prompt) pair for the given index."""
        # Cycle deterministically through the prompt pool.
        prompt = self.prompts[idx % len(self.prompts)]

        # Random image stands in for a real loaded/preprocessed image.
        image = torch.randn(3, self.img_size, self.img_size)

        return image, prompt
204
205
def train_step(unet: nn.Module, noise_scheduler: DDPMScheduler,
               text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer,
               optimizer: torch.optim.Optimizer, batch: Tuple,
               device: torch.device) -> float:
    """
    Single training step for diffusion model fine-tuning.

    The training objective is to predict the noise added to the image.

    Args:
        unet: Denoising network (with LoRA layers injected).
        noise_scheduler: Scheduler used to add noise for sampled timesteps.
        text_encoder: Frozen CLIP text encoder producing prompt embeddings.
        tokenizer: CLIP tokenizer matching the text encoder.
        optimizer: Optimizer over the trainable (LoRA) parameters.
        batch: (images, prompts) pair as produced by the DataLoader.
        device: Device to run the step on.

    Returns:
        The batch loss as a Python float.
    """
    images, prompts = batch
    images = images.to(device)

    # Encode text prompts (padded/truncated to CLIP's 77-token context).
    text_inputs = tokenizer(
        prompts, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    )
    # [0] selects the per-token hidden states used for cross-attention.
    text_embeddings = text_encoder(text_inputs.input_ids.to(device))[0]

    # Sample random timesteps, one per example in the batch.
    batch_size = images.shape[0]
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps,
        (batch_size,), device=device
    ).long()

    # Add noise to images (forward diffusion at the sampled timesteps).
    # NOTE(review): Stable Diffusion's UNet normally operates on 4-channel VAE
    # latents (e.g. 4x64x64), but this demo passes 3-channel pixel images
    # straight in; a real training loop must encode with the VAE first — confirm.
    noise = torch.randn_like(images)
    noisy_images = noise_scheduler.add_noise(images, noise, timesteps)

    # Predict noise with UNet (epsilon-prediction objective).
    model_output = unet(noisy_images, timesteps, text_embeddings).sample

    # Compute loss (MSE between predicted and actual noise)
    loss = F.mse_loss(model_output, noise)

    # Backward pass (only LoRA parameters are updated)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
248
249
def generate_image(pipeline: StableDiffusionPipeline, prompt: str,
                   save_path: str, num_inference_steps: int = 50):
    """Run the pipeline on a prompt and write the first resulting image to disk."""
    # Inference only — no gradients needed.
    with torch.no_grad():
        result = pipeline(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=7.5
        )
        image = result.images[0]

    image.save(save_path)
    print(f"Generated image saved to {save_path}")
262
263
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """Count model parameters, optionally restricted to trainable ones."""
    params = model.parameters()
    if trainable_only:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
269
270
def main(args):
    """End-to-end LoRA demo: load SD components, inject LoRA, train on dummy
    data, save the adapter weights, generate samples, and reload the adapter.

    Args:
        args: Parsed CLI namespace (epochs, batch_size, lr, num_samples,
              lora_rank, lora_alpha, output_dir).
    """
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load pretrained Stable Diffusion model
    print("\nLoading pretrained Stable Diffusion model...")
    print("Note: First run will download ~4GB of model weights")

    # NOTE(review): this Hub repo id may no longer be hosted under this name —
    # confirm availability or swap in a mirror before running.
    model_id = "runwayml/stable-diffusion-v1-5"

    # Load components separately for fine-tuning
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

    # Freeze text encoder (we only fine-tune UNet)
    text_encoder.requires_grad_(False)
    text_encoder.to(device)

    # Inject LoRA into UNet attention layers
    print("\nInjecting LoRA layers...")
    unet = inject_lora_to_attention(unet, rank=args.lora_rank, alpha=args.lora_alpha)
    unet.to(device)

    # Print parameter statistics (LoRA should be a tiny fraction of the UNet)
    total_params = count_parameters(unet, trainable_only=False)
    trainable_params = count_parameters(unet, trainable_only=True)
    print(f"\nUNet Parameters:")
    print(f"  Total: {total_params:,}")
    print(f"  Trainable (LoRA): {trainable_params:,}")
    print(f"  Percentage trainable: {100 * trainable_params / total_params:.2f}%")

    # Setup training
    dataset = DummyDataset(size=args.num_samples, img_size=512)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Optimize ONLY the LoRA parameters; everything else stays frozen.
    lora_params = get_lora_parameters(unet)
    optimizer = torch.optim.AdamW(lora_params, lr=args.lr)

    # Training loop
    print("\n" + "="*60)
    print("Starting LoRA Fine-tuning")
    print("="*60)
    print("Note: This is a demonstration with dummy data.")
    print("For real fine-tuning, replace DummyDataset with actual image-caption pairs.")

    unet.train()
    for epoch in range(args.epochs):
        epoch_loss = 0.0

        for batch_idx, batch in enumerate(dataloader):
            loss = train_step(
                unet, noise_scheduler, text_encoder,
                tokenizer, optimizer, batch, device
            )
            epoch_loss += loss

            # Progress report every 10 batches.
            if (batch_idx + 1) % 10 == 0:
                print(f"Epoch [{epoch+1}/{args.epochs}], "
                      f"Batch [{batch_idx+1}/{len(dataloader)}], "
                      f"Loss: {loss:.4f}")

        avg_loss = epoch_loss / len(dataloader)
        print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

    # Save LoRA weights (adapter only — a few MB, not the full ~GB model)
    os.makedirs(args.output_dir, exist_ok=True)
    lora_path = os.path.join(args.output_dir, "lora_weights.pt")
    save_lora_weights(unet, lora_path)

    # Generate sample images with fine-tuned model
    print("\n" + "="*60)
    print("Generating Sample Images")
    print("="*60)

    # Create pipeline with fine-tuned UNet (other components reloaded from the hub)
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_id,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        safety_checker=None  # Disable for faster inference
    )
    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=True)

    # Test prompts
    test_prompts = [
        "a beautiful landscape with mountains",
        "a cute cat sitting on a chair",
        "abstract digital art with vibrant colors"
    ]

    for idx, prompt in enumerate(test_prompts):
        save_path = os.path.join(args.output_dir, f"generated_{idx+1}.png")
        print(f"\nGenerating: '{prompt}'")
        generate_image(pipeline, prompt, save_path, num_inference_steps=30)

    # Demonstrate loading LoRA weights
    print("\n" + "="*60)
    print("Demonstrating LoRA Weight Loading")
    print("="*60)

    # Create a fresh UNet and load LoRA weights
    # (LoRA must be injected with the SAME rank/alpha so state-dict keys match.)
    unet_fresh = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    unet_fresh = inject_lora_to_attention(unet_fresh, rank=args.lora_rank, alpha=args.lora_alpha)
    load_lora_weights(unet_fresh, lora_path)

    print("\nFine-tuning complete!")
    print(f"LoRA weights saved to: {lora_path}")
    print(f"Generated images saved to: {args.output_dir}")
384
385
if __name__ == "__main__":
    # CLI entry point: parse arguments, warn on CPU-only hosts, then run main().
    parser = argparse.ArgumentParser(
        description="Fine-tune Stable Diffusion with LoRA"
    )

    # Training arguments
    parser.add_argument("--epochs", type=int, default=5,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Batch size (use 1 for limited VRAM)")
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="Learning rate")
    parser.add_argument("--num_samples", type=int, default=50,
                        help="Number of training samples (dummy data)")

    # LoRA arguments
    parser.add_argument("--lora_rank", type=int, default=4,
                        help="Rank of LoRA matrices (lower = fewer parameters)")
    parser.add_argument("--lora_alpha", type=float, default=1.0,
                        help="LoRA scaling factor")

    # Output arguments
    parser.add_argument("--output_dir", type=str, default="./lora_output",
                        help="Directory to save LoRA weights and generated images")

    args = parser.parse_args()

    # Validate CUDA availability for Stable Diffusion; ask the user before
    # proceeding on CPU, where inference/training is impractically slow.
    if not torch.cuda.is_available():
        print("\nWARNING: CUDA not available. Stable Diffusion is extremely slow on CPU.")
        print("This script is designed for GPU usage. Proceed at your own risk.")
        response = input("Continue anyway? (y/n): ")
        if response.lower() != 'y':
            exit(0)

    main(args)