16. Advanced Vision-Language

16. Advanced Vision-Language

Overview

Vision-Language Models (VLMs) are models that understand both images and text together. This lesson covers state-of-the-art VLM architectures like LLaVA and Qwen-VL, as well as Visual Instruction Tuning techniques.


1. VLM Paradigm

1.1 Evolution

┌──────────────────────────────────────────────────────────────────┐
│                    VLM Evolution                                 │
├──────────────────────────────────────────────────────────────────┤
│                                                                  │
│  2021: CLIP                                                      │
│  - Image-Text contrastive learning                              │
│  - Zero-shot classification                                     │
│                                                                  │
│  2022: Flamingo                                                  │
│  - Inject visual tokens into LLM                                │
│  - Few-shot vision-language learning                            │
│                                                                  │
│  2023: LLaVA                                                     │
│  - Visual Instruction Tuning                                    │
│  - Open-source GPT-4V alternative                               │
│                                                                  │
│  2024: LLaVA-NeXT, Qwen-VL, Phi-3-Vision                        │
│  - High resolution, multi-image, video                          │
│  - Commercial-grade performance                                 │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘

1.2 Architecture Comparison

Model Vision Encoder LLM Connection Method
LLaVA CLIP ViT-L Vicuna/LLaMA Linear Projection
Qwen-VL ViT-G Qwen Cross-Attention
InternVL InternViT InternLM MLP
Phi-3-Vision CLIP ViT Phi-3 Linear
GPT-4V Unknown GPT-4 Unknown

2. LLaVA (Large Language and Vision Assistant)

2.1 Architecture

LLaVA Structure:

Image  CLIP ViT-L/14  Visual Features (576 tokens)
                
         Linear Projection
                
         Visual Tokens
                
[System] [Visual Tokens] [User Query]  LLaMA/Vicuna  Response

Training Stages:
1. Pre-training: Image-Text alignment (CC3M)
2. Fine-tuning: Visual Instruction Tuning (158K)

2.2 Implementation

import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer

class LLaVAModel(nn.Module):
    """LLaVA-style Vision-Language Model"""

    def __init__(
        self,
        vision_encoder: str = "openai/clip-vit-large-patch14",
        llm: str = "lmsys/vicuna-7b-v1.5",
        freeze_vision: bool = True,
        freeze_llm: bool = False
    ):
        super().__init__()

        # Vision Encoder
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_encoder)
        self.vision_hidden_size = self.vision_encoder.config.hidden_size

        # Language Model
        self.llm = LlamaForCausalLM.from_pretrained(llm)
        self.llm_hidden_size = self.llm.config.hidden_size

        # Vision-Language Projection
        self.vision_projection = nn.Linear(
            self.vision_hidden_size,
            self.llm_hidden_size
        )

        # Freeze encoders
        if freeze_vision:
            for param in self.vision_encoder.parameters():
                param.requires_grad = False

        if freeze_llm:
            for param in self.llm.parameters():
                param.requires_grad = False

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """
        Image encoding

        Args:
            images: (B, C, H, W)

        Returns:
            visual_tokens: (B, num_patches, llm_hidden_size)
        """
        # CLIP encoding
        vision_outputs = self.vision_encoder(images)
        image_features = vision_outputs.last_hidden_state  # (B, 257, 1024)

        # Exclude [CLS] token
        image_features = image_features[:, 1:, :]  # (B, 256, 1024)

        # Project to LLM space
        visual_tokens = self.vision_projection(image_features)

        return visual_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        images: torch.Tensor = None,
        image_positions: torch.Tensor = None,
        labels: torch.Tensor = None
    ):
        """
        Forward pass

        Args:
            input_ids: (B, seq_len) text tokens
            attention_mask: (B, seq_len)
            images: (B, C, H, W) images
            image_positions: positions where image tokens should be inserted
            labels: (B, seq_len) for training
        """
        B, seq_len = input_ids.shape

        # Text embeddings
        text_embeds = self.llm.model.embed_tokens(input_ids)

        # Image embeddings
        if images is not None:
            visual_tokens = self.encode_images(images)  # (B, num_patches, hidden)

            # Interleave visual tokens with text
            # Simplified: add images before text
            combined_embeds = torch.cat([visual_tokens, text_embeds], dim=1)

            # Adjust attention mask
            visual_mask = torch.ones(B, visual_tokens.shape[1], device=attention_mask.device)
            combined_mask = torch.cat([visual_mask, attention_mask], dim=1)
        else:
            combined_embeds = text_embeds
            combined_mask = attention_mask

        # LLM forward
        outputs = self.llm(
            inputs_embeds=combined_embeds,
            attention_mask=combined_mask,
            labels=labels,
            return_dict=True
        )

        return outputs


class VisualInstructionDataset:
    """Visual Instruction Tuning Dataset"""

    INSTRUCTION_TEMPLATES = [
        "Describe this image in detail.",
        "What can you see in this image?",
        "Explain what is happening in this picture.",
        "<question>",  # VQA
    ]

    def __init__(self, data_path: str):
        """
        Data format:
        {
            "image": "path/to/image.jpg",
            "conversations": [
                {"from": "human", "value": "<image>\nDescribe this image."},
                {"from": "gpt", "value": "This image shows..."}
            ]
        }
        """
        import json
        with open(data_path, 'r') as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Load image
        from PIL import Image
        image = Image.open(item['image']).convert('RGB')

        # Construct conversation
        conversations = item['conversations']
        human_input = conversations[0]['value']
        assistant_output = conversations[1]['value']

        return {
            'image': image,
            'human': human_input,
            'assistant': assistant_output
        }

2.3 LLaVA-NeXT Improvements

class LLaVANeXTConfig:
    """
    LLaVA-NeXT Improvements

    1. High-resolution support (AnyRes)
    2. Better Vision Encoder (SigLIP)
    3. Larger LLM (Llama 3, Qwen 2)
    """

    # AnyRes: handle various resolutions
    SUPPORTED_RESOLUTIONS = [
        (336, 336),
        (672, 336),
        (336, 672),
        (672, 672),
        (1008, 336),
        (336, 1008),
    ]

    @staticmethod
    def select_best_resolution(image_size: tuple, resolutions: list):
        """Select best resolution for image"""
        img_h, img_w = image_size
        img_ratio = img_w / img_h

        best_res = None
        best_ratio_diff = float('inf')

        for res in resolutions:
            res_ratio = res[1] / res[0]
            ratio_diff = abs(img_ratio - res_ratio)

            if ratio_diff < best_ratio_diff:
                best_ratio_diff = ratio_diff
                best_res = res

        return best_res


def anyres_processing(image, base_resolution=336):
    """
    AnyRes image processing

    Split high-resolution image into base resolution tiles
    + downscaled full image
    """
    from PIL import Image
    import torch

    # 1. Resize full image (global context)
    global_image = image.resize((base_resolution, base_resolution))

    # 2. Split into tiles (local details)
    W, H = image.size
    num_tiles_w = (W + base_resolution - 1) // base_resolution
    num_tiles_h = (H + base_resolution - 1) // base_resolution

    tiles = []
    for i in range(num_tiles_h):
        for j in range(num_tiles_w):
            left = j * base_resolution
            top = i * base_resolution
            right = min(left + base_resolution, W)
            bottom = min(top + base_resolution, H)

            tile = image.crop((left, top, right, bottom))
            # Padding
            padded_tile = Image.new('RGB', (base_resolution, base_resolution))
            padded_tile.paste(tile, (0, 0))
            tiles.append(padded_tile)

    # [global_image] + [tile1, tile2, ...]
    all_images = [global_image] + tiles

    return all_images

3. Qwen-VL

3.1 Architecture

Qwen-VL Features:

1. Vision Encoder: ViT-bigG (1.9B params)
2. High resolution: 448×448 (variable)
3. Grounding support: bounding box output
4. OCR strength: excellent text recognition

Input format:
<img>image_path</img> User question
<ref>object name</ref><box>(x1,y1),(x2,y2)</box>

3.2 Usage Example

from transformers import AutoModelForCausalLM, AutoTokenizer

def use_qwen_vl():
    """Using Qwen-VL"""

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen-VL-Chat",
        trust_remote_code=True,
        torch_dtype=torch.float16
    ).to("cuda")

    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen-VL-Chat",
        trust_remote_code=True
    )

    # Basic VQA
    query = tokenizer.from_list_format([
        {'image': 'path/to/image.jpg'},
        {'text': 'What is in this image?'},
    ])

    response, history = model.chat(tokenizer, query=query, history=None)
    print(response)

    # Grounding (find object locations)
    query = tokenizer.from_list_format([
        {'image': 'path/to/image.jpg'},
        {'text': 'Find all the cats in this image and output their bounding boxes.'},
    ])

    response, history = model.chat(tokenizer, query=query, history=None)
    # Output: <ref>cat</ref><box>(100,200),(300,400)</box>

    # Multiple images
    query = tokenizer.from_list_format([
        {'image': 'image1.jpg'},
        {'image': 'image2.jpg'},
        {'text': 'What is the difference between these two images?'},
    ])

    response, history = model.chat(tokenizer, query=query, history=None)

    return response

4. Visual Instruction Tuning

4.1 Data Generation

class VisualInstructionGenerator:
    """Visual Instruction data generator"""

    def __init__(self, teacher_model="gpt-4-vision-preview"):
        from openai import OpenAI
        self.client = OpenAI()
        self.teacher_model = teacher_model

    def generate_conversation(
        self,
        image_path: str,
        task_type: str = "detailed_description"
    ):
        """Generate training data with GPT-4V"""
        import base64

        # Encode image
        with open(image_path, "rb") as f:
            image_data = base64.b64encode(f.read()).decode()

        task_prompts = {
            "detailed_description": "Describe this image in detail.",
            "reasoning": "What conclusions can you draw from this image? Explain your reasoning.",
            "conversation": "Generate a multi-turn conversation about this image.",
            "creative": "Write a creative story inspired by this image."
        }

        prompt = task_prompts.get(task_type, task_prompts["detailed_description"])

        response = self.client.chat.completions.create(
            model=self.teacher_model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}}
                    ]
                }
            ],
            max_tokens=1024
        )

        return {
            "image": image_path,
            "task": task_type,
            "question": prompt,
            "answer": response.choices[0].message.content
        }

    def generate_dataset(
        self,
        image_paths: list,
        output_path: str,
        tasks: list = None
    ):
        """Generate large-scale dataset"""
        import json
        from tqdm import tqdm

        if tasks is None:
            tasks = ["detailed_description", "reasoning", "conversation"]

        dataset = []

        for image_path in tqdm(image_paths):
            for task in tasks:
                try:
                    data = self.generate_conversation(image_path, task)
                    dataset.append(data)
                except Exception as e:
                    print(f"Error processing {image_path}: {e}")

        with open(output_path, 'w') as f:
            json.dump(dataset, f, indent=2)

        return dataset

4.2 Training Strategy

from transformers import Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model

def finetune_vlm():
    """VLM Fine-tuning"""

    # Load model
    model = LLaVAModel(
        freeze_vision=True,  # Freeze vision encoder
        freeze_llm=False     # Fine-tune LLM
    )

    # Apply LoRA (efficient training)
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        lora_dropout=0.05,
    )

    model.llm = get_peft_model(model.llm, lora_config)

    # Training setup
    training_args = TrainingArguments(
        output_dir="./llava-finetuned",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        learning_rate=2e-5,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        bf16=True,
        logging_steps=10,
        save_steps=500,
        dataloader_num_workers=4,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=vlm_data_collator,
    )

    trainer.train()


def vlm_data_collator(features):
    """VLM data collator"""
    batch = {
        'input_ids': torch.stack([f['input_ids'] for f in features]),
        'attention_mask': torch.stack([f['attention_mask'] for f in features]),
        'images': torch.stack([f['image'] for f in features]),
        'labels': torch.stack([f['labels'] for f in features]),
    }
    return batch

5. Evaluation Benchmarks

5.1 Major Benchmarks

VLM Evaluation Benchmarks:

1. VQA-v2: General Visual QA
2. GQA: Structural reasoning QA
3. TextVQA: Understanding text in images
4. POPE: Hallucination evaluation
5. MME: 14 subtask suite
6. MMBench: 20 capability evaluation
7. SEED-Bench: 19K multiple choice problems

5.2 Evaluation Code

def evaluate_vlm(model, dataset_name: str = "vqav2"):
    """VLM evaluation"""

    if dataset_name == "vqav2":
        return evaluate_vqa_v2(model)
    elif dataset_name == "textvqa":
        return evaluate_textvqa(model)
    elif dataset_name == "pope":
        return evaluate_pope(model)


def evaluate_pope(model):
    """
    POPE: Polling-based Object Probing Evaluation

    Hallucination evaluation: "Is there a [object] in the image?"
    """
    from datasets import load_dataset

    dataset = load_dataset("lmms-lab/POPE")

    correct = 0
    total = 0

    for item in dataset['test']:
        image = item['image']
        question = item['question']  # "Is there a dog in the image?"
        answer = item['answer']      # "yes" or "no"

        # Model prediction
        prediction = model.generate(image, question)
        pred_answer = "yes" if "yes" in prediction.lower() else "no"

        if pred_answer == answer:
            correct += 1
        total += 1

    accuracy = correct / total
    print(f"POPE Accuracy: {accuracy:.4f}")

    return accuracy

6. Practical Applications

6.1 Document Understanding

def document_understanding():
    """Document understanding application"""

    model = load_qwen_vl()  # OCR strength

    # PDF page analysis
    def analyze_document_page(image_path: str, questions: list):
        results = []

        for question in questions:
            query = f"<img>{image_path}</img>{question}"
            answer = model.generate(query)
            results.append({
                'question': question,
                'answer': answer
            })

        return results

    # Example questions
    questions = [
        "What is the title of this document?",
        "Summarize the main points.",
        "Extract all dates mentioned.",
        "What tables are present? Describe their contents.",
    ]

    results = analyze_document_page("document_page.png", questions)


def chart_understanding():
    """Chart/graph understanding"""

    prompts = [
        "What type of chart is this?",
        "What is the trend shown in this chart?",
        "What are the maximum and minimum values?",
        "Describe the relationship between X and Y.",
    ]

    # Chart analysis with VLM
    for prompt in prompts:
        response = model.generate(chart_image, prompt)
        print(f"Q: {prompt}")
        print(f"A: {response}\n")

References

Papers

  • Liu et al. (2023). "Visual Instruction Tuning" (LLaVA)
  • Liu et al. (2024). "LLaVA-NeXT: Improved reasoning, OCR, and world knowledge"
  • Bai et al. (2023). "Qwen-VL: A Versatile Vision-Language Model"

Models

to navigate between lessons