# 29_stable_diffusion_finetune.py (python, 422 lines, 14.1 KB)
  1"""
  2Stable Diffusion Fine-tuning with LoRA (Low-Rank Adaptation)
  3
  4This script demonstrates how to fine-tune a diffusion model using LoRA,
  5a parameter-efficient fine-tuning technique that adds trainable low-rank
  6matrices to attention layers while keeping the base model frozen.
  7
  8Key Concepts:
  9- LoRA (Low-Rank Adaptation): Inject trainable rank decomposition matrices
 10- Diffusion Models: Iterative denoising process for image generation
 11- Parameter-Efficient Fine-tuning: Train only a small subset of parameters
 12- UNet Architecture: Core denoising network in diffusion models
 13- Text-to-Image: Conditioning image generation with text prompts
 14
 15Requirements:
 16    pip install torch torchvision diffusers transformers accelerate safetensors
 17"""
 18
 19import argparse
 20import os
 21from typing import Dict, Optional, Tuple
 22import math
 23
 24import torch
 25import torch.nn as nn
 26import torch.nn.functional as F
 27from torch.utils.data import Dataset, DataLoader
 28import torchvision.transforms as transforms
 29from PIL import Image
 30
 31try:
 32    from diffusers import StableDiffusionPipeline, UNet2DConditionModel, DDPMScheduler
 33    from transformers import CLIPTextModel, CLIPTokenizer
 34except ImportError:
 35    print("Please install required packages:")
 36    print("pip install diffusers transformers accelerate")
 37    exit(1)
 38
 39
 40class LoRALayer(nn.Module):
 41    """
 42    Low-Rank Adaptation layer that can be injected into linear layers.
 43
 44    LoRA decomposes weight updates into two low-rank matrices:
 45    W' = W + BA, where B is (out_features, rank) and A is (rank, in_features)
 46    """
 47
 48    def __init__(self, in_features: int, out_features: int, rank: int = 4, alpha: float = 1.0):
 49        super().__init__()
 50        self.rank = rank
 51        self.alpha = alpha
 52        self.scaling = alpha / rank
 53
 54        # Low-rank matrices
 55        self.lora_A = nn.Parameter(torch.zeros(rank, in_features))
 56        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))
 57
 58        # Initialize A with Kaiming uniform, B with zeros
 59        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
 60        nn.init.zeros_(self.lora_B)
 61
 62    def forward(self, x: torch.Tensor) -> torch.Tensor:
 63        """
 64        Apply LoRA transformation: (BA) * x
 65
 66        Args:
 67            x: Input tensor of shape (..., in_features)
 68
 69        Returns:
 70            LoRA output scaled by alpha/rank
 71        """
 72        # Compute low-rank transformation
 73        result = F.linear(x, self.lora_A)  # (..., rank)
 74        result = F.linear(result, self.lora_B)  # (..., out_features)
 75        return result * self.scaling
 76
 77
 78class LinearWithLoRA(nn.Module):
 79    """Linear layer with optional LoRA adaptation."""
 80
 81    def __init__(self, linear: nn.Linear, rank: int = 4, alpha: float = 1.0):
 82        super().__init__()
 83        self.linear = linear
 84        self.lora = LoRALayer(
 85            linear.in_features, linear.out_features, rank, alpha
 86        )
 87
 88        # Freeze original linear layer
 89        for param in self.linear.parameters():
 90            param.requires_grad = False
 91
 92    def forward(self, x: torch.Tensor) -> torch.Tensor:
 93        # Original linear transformation + LoRA adaptation
 94        return self.linear(x) + self.lora(x)
 95
 96
 97def inject_lora_to_attention(unet: UNet2DConditionModel, rank: int = 4,
 98                             alpha: float = 1.0) -> nn.Module:
 99    """
100    Inject LoRA layers into the attention modules of UNet.
101
102    In Stable Diffusion, attention modules have Q, K, V projection layers
103    and an output projection layer. We add LoRA to all of these.
104
105    Args:
106        unet: UNet2DConditionModel from diffusers
107        rank: Rank of LoRA matrices
108        alpha: Scaling factor for LoRA
109
110    Returns:
111        Modified UNet with LoRA layers
112    """
113    # Freeze all parameters first
114    for param in unet.parameters():
115        param.requires_grad = False
116
117    lora_count = 0
118
119    # Iterate through all attention modules
120    for name, module in unet.named_modules():
121        if "attn" in name and isinstance(module, nn.Linear):
122            # Determine parent module and attribute name
123            parent_name = ".".join(name.split(".")[:-1])
124            attr_name = name.split(".")[-1]
125
126            parent = unet
127            for part in parent_name.split("."):
128                if part:
129                    parent = getattr(parent, part)
130
131            # Replace Linear with LinearWithLoRA
132            original_linear = getattr(parent, attr_name)
133            setattr(parent, attr_name, LinearWithLoRA(original_linear, rank, alpha))
134            lora_count += 1
135
136    print(f"Injected LoRA into {lora_count} linear layers")
137    return unet
138
139
def get_lora_parameters(model: nn.Module) -> list:
    """
    Collect the trainable LoRA parameters of *model*.

    A parameter qualifies when its qualified name contains ``"lora"`` and it
    has ``requires_grad`` set. Parameters are returned in
    ``named_parameters()`` order.
    """
    return [
        tensor
        for qualified_name, tensor in model.named_parameters()
        if "lora" in qualified_name and tensor.requires_grad
    ]
147
148
def save_lora_weights(model: nn.Module, save_path: str):
    """
    Save only LoRA weights (not the full model).

    Every parameter whose qualified name contains ``"lora"`` is stored as a
    detached CPU tensor, keyed by that name, via ``torch.save``. The parent
    directory of *save_path* is created if it does not exist.

    Args:
        model: Model containing LoRA parameters.
        save_path: Destination file path.
    """
    # detach() stores plain tensors rather than Parameter objects that still
    # carry autograd state. param.cpu() alone is a no-op for CPU parameters.
    lora_state_dict = {
        name: param.detach().cpu()
        for name, param in model.named_parameters()
        if "lora" in name
    }

    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    torch.save(lora_state_dict, save_path)
    print(f"LoRA weights saved to {save_path}")
    print(f"File size: {os.path.getsize(save_path) / 1024 / 1024:.2f} MB")
159
160
def load_lora_weights(model: nn.Module, load_path: str):
    """
    Load LoRA weights into the model.

    The file is read onto the CPU, then copied into *model*'s existing
    parameters by ``load_state_dict``, which handles device placement. Non-LoRA
    parameters missing from the file are expected and ignored.

    Args:
        model: Model that already has LoRA layers injected with matching names.
        load_path: File written by ``save_lora_weights``.

    Raises:
        RuntimeError: If the file contains keys the model does not have, or the
            model has LoRA parameters the file does not provide. In that case
            the adapters would silently keep their initial values. Matching keys
            are already copied by the time this is raised.
    """
    # weights_only=True refuses arbitrary pickled objects; a LoRA checkpoint
    # only needs tensors.
    lora_state_dict = torch.load(load_path, map_location="cpu", weights_only=True)

    result = model.load_state_dict(lora_state_dict, strict=False)
    missing_lora = [key for key in result.missing_keys if "lora" in key]
    if result.unexpected_keys or missing_lora:
        raise RuntimeError(
            f"LoRA checkpoint {load_path} does not match the model: "
            f"unexpected keys {list(result.unexpected_keys)}, "
            f"missing LoRA keys {missing_lora}"
        )
    print(f"LoRA weights loaded from {load_path}")
168
169
class DummyDataset(Dataset):
    """
    Dummy dataset for demonstration purposes.

    Yields ``(image, prompt)`` pairs. Each image is a (3, img_size, img_size)
    tensor of uniform noise in [-1, 1), the value range Stable Diffusion
    training pipelines conventionally normalize real images to before VAE
    encoding. Prompts cycle through a fixed list by index.

    In practice, you would use:
    - Custom image-caption pairs
    - Datasets from HuggingFace (e.g., lambdalabs/pokemon-blip-captions)
    - Your own domain-specific data
    """

    def __init__(self, size: int = 100, img_size: int = 512):
        self.size = size
        self.img_size = img_size

        # Dummy prompts
        self.prompts = [
            "a beautiful landscape with mountains",
            "a cute cat sitting on a chair",
            "abstract digital art",
            "a futuristic city at night",
            "a serene beach at sunset"
        ]

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        # Random image (in practice, load real images). torch.rand is in [0, 1);
        # map it to [-1, 1) to match normalized real images, instead of
        # unbounded randn values.
        image = torch.rand(3, self.img_size, self.img_size) * 2 - 1

        # Deterministic prompt for this index, cycling through the list.
        prompt = self.prompts[idx % len(self.prompts)]

        return image, prompt
204
205
def train_step(unet: nn.Module, noise_scheduler: DDPMScheduler,
               text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer,
               optimizer: torch.optim.Optimizer, batch: Tuple,
               device: torch.device) -> float:
    """
    Single training step for diffusion model fine-tuning.

    A random timestep is drawn per sample and noise is added by the scheduler.
    The UNet is then regressed (MSE) onto the target implied by the
    scheduler's ``prediction_type``:
    - "epsilon" (the default when the config has no such key): the added noise.
    - "v_prediction": the scheduler's velocity.
    - "sample": the clean input.

    Args:
        unet: Denoising network; called as ``unet(noisy, timesteps, text_embeddings)``.
        noise_scheduler: Supplies ``num_train_timesteps``, ``add_noise`` and
            optionally ``prediction_type`` / ``get_velocity``.
        text_encoder: Frozen text encoder; its first output is the conditioning.
        tokenizer: Tokenizer for ``text_encoder``; prompts are padded/truncated
            to 77 tokens.
        optimizer: Optimizer over the trainable (LoRA) parameters.
        batch: ``(model_inputs, prompts)``. ``model_inputs`` are noised and fed
            to the UNet unchanged, so they must already be in the UNet's input
            space (for Stable Diffusion: scaled VAE latents, not RGB pixels).
        device: Device to run the step on.

    Returns:
        The scalar loss value of this step.

    Raises:
        ValueError: If the scheduler's ``prediction_type`` is not supported.
    """
    model_inputs, prompts = batch
    model_inputs = model_inputs.to(device)

    # Encode text prompts
    text_inputs = tokenizer(
        prompts, padding="max_length", max_length=77,
        truncation=True, return_tensors="pt"
    )
    # The text encoder is frozen, so skip building an autograd graph for it.
    with torch.no_grad():
        text_embeddings = text_encoder(text_inputs.input_ids.to(device))[0]

    # Sample random timesteps
    batch_size = model_inputs.shape[0]
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps,
        (batch_size,), device=device
    ).long()

    # Add noise to the inputs
    noise = torch.randn_like(model_inputs)
    noisy_inputs = noise_scheduler.add_noise(model_inputs, noise, timesteps)

    # The regression target must match the parameterization the base model was
    # trained with; always using `noise` is only correct for "epsilon".
    prediction_type = getattr(noise_scheduler.config, "prediction_type", "epsilon")
    if prediction_type == "epsilon":
        target = noise
    elif prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(model_inputs, noise, timesteps)
    elif prediction_type == "sample":
        target = model_inputs
    else:
        raise ValueError(f"Unsupported prediction_type: {prediction_type!r}")

    # Predict with the UNet
    model_output = unet(noisy_inputs, timesteps, text_embeddings).sample

    # Reduce in fp32 so a half-precision output does not lose precision in the
    # loss. This is a no-op for fp32 tensors.
    loss = F.mse_loss(model_output.float(), target.float())

    # Backward pass (only parameters with requires_grad, i.e. LoRA, get updated)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()
248
249
def generate_image(pipeline: StableDiffusionPipeline, prompt: str,
                  save_path: str, num_inference_steps: int = 50):
    """
    Render one image for *prompt* and write it to *save_path*.

    The pipeline is invoked with classifier-free guidance scale 7.5 and
    autograd disabled. The first image of the result is saved and the path is
    reported on stdout.
    """
    with torch.no_grad():
        result = pipeline(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=7.5
        )

    first_image = result.images[0]
    first_image.save(save_path)
    print(f"Generated image saved to {save_path}")
262
263
def count_parameters(model: nn.Module, trainable_only: bool = True) -> int:
    """
    Return the total number of scalar elements in *model*'s parameters.

    With ``trainable_only`` (the default), only parameters with
    ``requires_grad`` set are counted.
    """
    candidates = model.parameters()
    if trainable_only:
        candidates = (tensor for tensor in candidates if tensor.requires_grad)
    return sum(tensor.numel() for tensor in candidates)
269
270
def _encode_images_to_latents(vae, images: torch.Tensor) -> torch.Tensor:
    """
    Encode pixel-space images into the scaled latent space the SD UNet denoises.

    Args:
        vae: Frozen diffusers AutoencoderKL from the same checkpoint as the UNet.
        images: Image batch, presumably (batch, 3, H, W) with values in [-1, 1].

    Returns:
        A sample from the VAE posterior, multiplied by the VAE scaling factor.
    """
    with torch.no_grad():
        latents = vae.encode(images).latent_dist.sample()
    # If the config does not define scaling_factor, fall back to the SD v1.x
    # value.
    scaling_factor = getattr(vae.config, "scaling_factor", 0.18215)
    return latents * scaling_factor


def main(args):
    """
    Fine-tune the SD v1.5 UNet with LoRA on dummy data, then sample and reload.

    Steps:
    1. Load the tokenizer, text encoder, VAE, UNet and DDPM scheduler.
    2. Inject LoRA and train on images encoded to VAE latents.
    3. Save the LoRA weights and generate sample images.
    4. Reload the weights into a fresh UNet and check they match.
    """
    # Local import: only main() needs the VAE, to turn images into latents.
    from diffusers import AutoencoderKL

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load pretrained Stable Diffusion model
    print("\nLoading pretrained Stable Diffusion model...")
    print("Note: First run will download ~4GB of model weights")

    model_id = "runwayml/stable-diffusion-v1-5"

    # Load components separately for fine-tuning
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

    # Freeze text encoder (we only fine-tune UNet)
    text_encoder.requires_grad_(False)
    text_encoder.to(device)

    # The SD UNet denoises 4-channel VAE latents, not 3-channel RGB pixels, so
    # every image batch is encoded before it reaches train_step. The VAE stays
    # frozen.
    vae.requires_grad_(False)
    vae.eval()
    vae.to(device)

    # Inject LoRA into UNet attention layers
    print("\nInjecting LoRA layers...")
    unet = inject_lora_to_attention(unet, rank=args.lora_rank, alpha=args.lora_alpha)
    unet.to(device)

    # Print parameter statistics
    total_params = count_parameters(unet, trainable_only=False)
    trainable_params = count_parameters(unet, trainable_only=True)
    print(f"\nUNet Parameters:")
    print(f"  Total: {total_params:,}")
    print(f"  Trainable (LoRA): {trainable_params:,}")
    print(f"  Percentage trainable: {100 * trainable_params / total_params:.2f}%")

    # Setup training
    dataset = DummyDataset(size=args.num_samples, img_size=512)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    lora_params = get_lora_parameters(unet)
    optimizer = torch.optim.AdamW(lora_params, lr=args.lr)

    # Training loop
    print("\n" + "="*60)
    print("Starting LoRA Fine-tuning")
    print("="*60)
    print("Note: This is a demonstration with dummy data.")
    print("For real fine-tuning, replace DummyDataset with actual image-caption pairs.")

    unet.train()
    for epoch in range(args.epochs):
        epoch_loss = 0.0

        for batch_idx, (images, prompts) in enumerate(dataloader):
            latents = _encode_images_to_latents(vae, images.to(device))
            loss = train_step(
                unet, noise_scheduler, text_encoder,
                tokenizer, optimizer, (latents, prompts), device
            )
            epoch_loss += loss

            if (batch_idx + 1) % 10 == 0:
                print(f"Epoch [{epoch+1}/{args.epochs}], "
                      f"Batch [{batch_idx+1}/{len(dataloader)}], "
                      f"Loss: {loss:.4f}")

        # max(..., 1) guards against an empty dataloader (e.g. --num_samples 0).
        avg_loss = epoch_loss / max(len(dataloader), 1)
        print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

    # Save LoRA weights
    os.makedirs(args.output_dir, exist_ok=True)
    lora_path = os.path.join(args.output_dir, "lora_weights.pt")
    save_lora_weights(unet, lora_path)

    # Generate sample images with fine-tuned model
    print("\n" + "="*60)
    print("Generating Sample Images")
    print("="*60)

    # The training loop left the UNet in train mode; switch to inference mode
    # before sampling.
    unet.eval()

    # Create pipeline with fine-tuned UNet (reusing the already-loaded VAE)
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_id,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        unet=unet,
        safety_checker=None  # Disable for faster inference
    )
    pipeline = pipeline.to(device)
    pipeline.set_progress_bar_config(disable=True)

    # Test prompts
    test_prompts = [
        "a beautiful landscape with mountains",
        "a cute cat sitting on a chair",
        "abstract digital art with vibrant colors"
    ]

    for idx, prompt in enumerate(test_prompts):
        save_path = os.path.join(args.output_dir, f"generated_{idx+1}.png")
        print(f"\nGenerating: '{prompt}'")
        generate_image(pipeline, prompt, save_path, num_inference_steps=30)

    # Demonstrate loading LoRA weights
    print("\n" + "="*60)
    print("Demonstrating LoRA Weight Loading")
    print("="*60)

    # Create a fresh UNet and load LoRA weights
    unet_fresh = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
    unet_fresh = inject_lora_to_attention(unet_fresh, rank=args.lora_rank, alpha=args.lora_alpha)
    load_lora_weights(unet_fresh, lora_path)

    # Confirm the round trip actually restored the trained adapter values.
    fresh_params = dict(unet_fresh.named_parameters())
    weights_match = all(
        name in fresh_params
        and torch.equal(fresh_params[name].detach().cpu(), param.detach().cpu())
        for name, param in unet.named_parameters()
        if "lora" in name
    )
    print(f"Reloaded LoRA weights match trained weights: {weights_match}")

    print("\nFine-tuning complete!")
    print(f"LoRA weights saved to: {lora_path}")
    print(f"Generated images saved to: {args.output_dir}")
384
385
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Fine-tune Stable Diffusion with LoRA"
    )

    # Training arguments
    parser.add_argument("--epochs", type=int, default=5,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Batch size (use 1 for limited VRAM)")
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="Learning rate")
    parser.add_argument("--num_samples", type=int, default=50,
                        help="Number of training samples (dummy data)")

    # LoRA arguments
    parser.add_argument("--lora_rank", type=int, default=4,
                        help="Rank of LoRA matrices (lower = fewer parameters)")
    parser.add_argument("--lora_alpha", type=float, default=1.0,
                        help="LoRA scaling factor")

    # Output arguments
    parser.add_argument("--output_dir", type=str, default="./lora_output",
                        help="Directory to save LoRA weights and generated images")

    args = parser.parse_args()

    # Reject values that would otherwise only fail after the multi-GB model
    # download (parser.error prints usage and exits with status 2).
    if args.lora_rank < 1:
        parser.error("--lora_rank must be >= 1")
    if args.batch_size < 1:
        parser.error("--batch_size must be >= 1")

    # Validate CUDA availability for Stable Diffusion
    if not torch.cuda.is_available():
        print("\nWARNING: CUDA not available. Stable Diffusion is extremely slow on CPU.")
        print("This script is designed for GPU usage. Proceed at your own risk.")
        try:
            response = input("Continue anyway? (y/n): ")
        except EOFError:
            # Non-interactive run (stdin closed): treat it as "no" instead of
            # crashing with a traceback.
            response = ""
        if response.strip().lower() not in ("y", "yes"):
            # SystemExit is always available, unlike the site-provided exit().
            raise SystemExit(0)

    main(args)