16. Vision-Language: Advanced Topics

Overview

Vision-Language Models (VLMs) understand images and text jointly. This lesson covers modern VLM architectures such as LLaVA and Qwen-VL, and the Visual Instruction Tuning technique.

1. The VLM Paradigm

1.1 Evolution
VLM evolution timeline:

2021: CLIP
  - Image-text contrastive learning
  - Enables zero-shot classification

2022: Flamingo
  - Injects visual tokens into a frozen LLM
  - Few-shot vision-language learning

2023: LLaVA
  - Visual Instruction Tuning
  - Open-source alternative to GPT-4V

2024: LLaVA-NeXT, Qwen-VL, Phi-3-Vision
  - High resolution, multi-image, video
  - Production-level performance
1.2 Architecture Comparison
| λͺ¨λΈ | Vision Encoder | LLM | μ°κ²° λ°©μ |
|---|---|---|---|
| LLaVA | CLIP ViT-L | Vicuna/LLaMA | Linear Projection |
| Qwen-VL | ViT-G | Qwen | Cross-Attention |
| InternVL | InternViT | InternLM | MLP |
| Phi-3-Vision | CLIP ViT | Phi-3 | Linear |
| GPT-4V | Unknown | GPT-4 | Unknown |
2. LLaVA (Large Language and Vision Assistant)

2.1 Architecture
LLaVA structure:

Image → CLIP ViT-L/14 → Visual Features (576 tokens)
            ↓
     Linear Projection
            ↓
       Visual Tokens
            ↓
[System] [Visual Tokens] [User Query] → LLaMA/Vicuna → Response

Training stages:
1. Pre-training: image-text alignment (CC3M)
2. Fine-tuning: Visual Instruction Tuning (158K samples)
2.2 Implementation
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, LlamaForCausalLM, LlamaTokenizer
class LLaVAModel(nn.Module):
    """LLaVA-style Vision-Language Model.

    Architecture: CLIP vision encoder -> linear projection into the LLM
    embedding space -> causal LM over [visual tokens; text tokens].
    """

    def __init__(
        self,
        vision_encoder: str = "openai/clip-vit-large-patch14",
        llm: str = "lmsys/vicuna-7b-v1.5",
        freeze_vision: bool = True,
        freeze_llm: bool = False
    ):
        super().__init__()
        # Vision encoder (CLIP ViT)
        self.vision_encoder = CLIPVisionModel.from_pretrained(vision_encoder)
        self.vision_hidden_size = self.vision_encoder.config.hidden_size
        # Language model
        self.llm = LlamaForCausalLM.from_pretrained(llm)
        self.llm_hidden_size = self.llm.config.hidden_size
        # Vision-to-language projection (LLaVA v1 uses a single linear layer)
        self.vision_projection = nn.Linear(
            self.vision_hidden_size,
            self.llm_hidden_size
        )
        # Optionally freeze sub-modules: stage-1 alignment training typically
        # freezes both towers; stage-2 instruction tuning unfreezes the LLM.
        if freeze_vision:
            for param in self.vision_encoder.parameters():
                param.requires_grad = False
        if freeze_llm:
            for param in self.llm.parameters():
                param.requires_grad = False

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """Encode images into LLM-space visual tokens.

        Args:
            images: (B, C, H, W) pixel tensor.
        Returns:
            visual_tokens: (B, num_patches, llm_hidden_size)
        """
        vision_outputs = self.vision_encoder(images)
        # For ViT-L/14 @ 224px this is (B, 257, 1024): 1 [CLS] + 256 patches.
        image_features = vision_outputs.last_hidden_state
        # Drop the [CLS] token; keep patch tokens only.
        image_features = image_features[:, 1:, :]  # (B, 256, 1024)
        # Project into the LLM embedding space.
        visual_tokens = self.vision_projection(image_features)
        return visual_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        images: torch.Tensor = None,
        image_positions: torch.Tensor = None,
        labels: torch.Tensor = None
    ):
        """Forward pass.

        Args:
            input_ids: (B, seq_len) text token ids.
            attention_mask: (B, seq_len) text attention mask.
            images: (B, C, H, W) image batch, optional.
            image_positions: where image tokens should be spliced in
                (unused in this simplified version; kept for API compatibility).
            labels: (B, seq_len) language-modeling targets for training.
        """
        B, seq_len = input_ids.shape
        # Text embeddings
        text_embeds = self.llm.model.embed_tokens(input_ids)
        if images is not None:
            visual_tokens = self.encode_images(images)  # (B, num_patches, hidden)
            num_visual = visual_tokens.shape[1]
            # Simplification: visual tokens are always prepended to the text
            # instead of being inserted at image_positions.
            combined_embeds = torch.cat([visual_tokens, text_embeds], dim=1)
            # Extend the attention mask to cover the visual positions.
            visual_mask = torch.ones(
                B, num_visual,
                device=attention_mask.device, dtype=attention_mask.dtype
            )
            combined_mask = torch.cat([visual_mask, attention_mask], dim=1)
            if labels is not None:
                # BUGFIX: labels must match the extended sequence length.
                # Visual positions get -100 so the loss ignores them.
                ignore = torch.full(
                    (B, num_visual), -100,
                    dtype=labels.dtype, device=labels.device
                )
                labels = torch.cat([ignore, labels], dim=1)
        else:
            combined_embeds = text_embeds
            combined_mask = attention_mask
        # LLM forward on the combined sequence.
        outputs = self.llm(
            inputs_embeds=combined_embeds,
            attention_mask=combined_mask,
            labels=labels,
            return_dict=True
        )
        return outputs
class VisualInstructionDataset:
    """Dataset for Visual Instruction Tuning.

    Expected JSON record format:
        {
            "image": "path/to/image.jpg",
            "conversations": [
                {"from": "human", "value": "<image>\nDescribe this image."},
                {"from": "gpt", "value": "This image shows..."}
            ]
        }
    """

    INSTRUCTION_TEMPLATES = [
        "Describe this image in detail.",
        "What can you see in this image?",
        "Explain what is happening in this picture.",
        "<question>",  # VQA
    ]

    def __init__(self, data_path: str):
        """Load all instruction records from a JSON file."""
        import json
        with open(data_path, 'r') as fh:
            self.data = json.load(fh)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        from PIL import Image
        record = self.data[idx]
        # Load the image and normalize to RGB.
        image = Image.open(record['image']).convert('RGB')
        # Turn 0 is the human prompt, turn 1 the assistant reply.
        turns = record['conversations']
        return {
            'image': image,
            'human': turns[0]['value'],
            'assistant': turns[1]['value'],
        }
2.3 LLaVA-NeXT Improvements
class LLaVANeXTConfig:
    """
    LLaVA-NeXT improvements over LLaVA:
    1. High-resolution input via AnyRes
    2. Stronger vision encoders (e.g. SigLIP)
    3. Larger LLM backbones (Llama 3, Qwen 2)
    """

    # AnyRes: candidate (height, width) grids an image may be mapped to.
    SUPPORTED_RESOLUTIONS = [
        (336, 336),
        (672, 336),
        (336, 672),
        (672, 672),
        (1008, 336),
        (336, 1008),
    ]

    @staticmethod
    def select_best_resolution(image_size: tuple, resolutions: list):
        """Pick the candidate whose aspect ratio is closest to the image's.

        Args:
            image_size: (height, width) of the input image.
            resolutions: list of (height, width) candidates.
        Returns:
            The best-matching (height, width) candidate.
        """
        img_h, img_w = image_size
        target_ratio = img_w / img_h
        # min() keeps the first candidate on ties, which matches a manual
        # best-so-far scan using a strict '<' comparison.
        return min(
            resolutions,
            key=lambda res: abs(target_ratio - res[1] / res[0]),
        )
def anyres_processing(image, base_resolution=336):
    """
    AnyRes image processing.

    Splits a high-resolution image into base-resolution tiles (for local
    detail) and prepends a downscaled copy of the whole image (for global
    context).

    Args:
        image: PIL image.
        base_resolution: side length of the global view and of each tile.
    Returns:
        [global_image, tile_0, tile_1, ...] as PIL images.
    """
    from PIL import Image

    # 1. Global view: the whole image squeezed to the base resolution.
    global_view = image.resize((base_resolution, base_resolution))

    # 2. Local views: ceil-divide the image into a grid of tiles.
    width, height = image.size
    cols = -(-width // base_resolution)   # ceil division
    rows = -(-height // base_resolution)

    tiles = []
    for row in range(rows):
        for col in range(cols):
            x0 = col * base_resolution
            y0 = row * base_resolution
            x1 = min(x0 + base_resolution, width)
            y1 = min(y0 + base_resolution, height)
            crop = image.crop((x0, y0, x1, y1))
            # Edge tiles may be smaller than the base resolution; paste onto
            # a black canvas so every tile has identical dimensions.
            canvas = Image.new('RGB', (base_resolution, base_resolution))
            canvas.paste(crop, (0, 0))
            tiles.append(canvas)

    # [global_image] + [tile1, tile2, ...]
    return [global_view] + tiles
3. Qwen-VL

3.1 Architecture

Qwen-VL highlights:
1. Vision encoder: ViT-bigG (1.9B params)
2. High resolution: 448×448 (variable)
3. Grounding support: bounding-box output
4. Strong OCR: excels at in-image text recognition

Input format:
<img>image_path</img> User question
<ref>object name</ref><box>(x1,y1),(x2,y2)</box>
3.2 Usage Example
from transformers import AutoModelForCausalLM, AutoTokenizer
def use_qwen_vl():
    """Qwen-VL usage examples: basic VQA, grounding, multi-image comparison."""
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen-VL-Chat",
        trust_remote_code=True,
        torch_dtype=torch.float16
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(
        "Qwen/Qwen-VL-Chat",
        trust_remote_code=True
    )

    # --- Basic VQA ---
    vqa_query = tokenizer.from_list_format([
        {'image': 'path/to/image.jpg'},
        {'text': 'What is in this image?'},
    ])
    response, history = model.chat(tokenizer, query=vqa_query, history=None)
    print(response)

    # --- Grounding: locate objects and emit their bounding boxes ---
    grounding_query = tokenizer.from_list_format([
        {'image': 'path/to/image.jpg'},
        {'text': 'Find all the cats in this image and output their bounding boxes.'},
    ])
    response, history = model.chat(tokenizer, query=grounding_query, history=None)
    # Output format: <ref>cat</ref><box>(100,200),(300,400)</box>

    # --- Multi-image comparison ---
    compare_query = tokenizer.from_list_format([
        {'image': 'image1.jpg'},
        {'image': 'image2.jpg'},
        {'text': 'What is the difference between these two images?'},
    ])
    response, history = model.chat(tokenizer, query=compare_query, history=None)
    return response
4. Visual Instruction Tuning

4.1 Data Generation
class VisualInstructionGenerator:
    """Generate visual-instruction training data with a teacher VLM (GPT-4V)."""

    def __init__(self, teacher_model="gpt-4-vision-preview"):
        from openai import OpenAI
        self.client = OpenAI()
        self.teacher_model = teacher_model

    def generate_conversation(
        self,
        image_path: str,
        task_type: str = "detailed_description"
    ):
        """Produce one (question, answer) record for an image via the teacher.

        Args:
            image_path: path to the image file.
            task_type: which instruction style to generate.
        Returns:
            dict with "image", "task", "question", "answer".
        """
        import base64

        # Read and base64-encode the image for the API payload.
        with open(image_path, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode()

        task_prompts = {
            "detailed_description": "Describe this image in detail.",
            "reasoning": "What conclusions can you draw from this image? Explain your reasoning.",
            "conversation": "Generate a multi-turn conversation about this image.",
            "creative": "Write a creative story inspired by this image."
        }
        # Unknown task types fall back to a detailed description.
        prompt = task_prompts.get(task_type, task_prompts["detailed_description"])

        response = self.client.chat.completions.create(
            model=self.teacher_model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded}"}}
                    ]
                }
            ],
            max_tokens=1024
        )
        return {
            "image": image_path,
            "task": task_type,
            "question": prompt,
            "answer": response.choices[0].message.content
        }

    def generate_dataset(
        self,
        image_paths: list,
        output_path: str,
        tasks: list = None
    ):
        """Build a dataset over many images and tasks, then dump it as JSON."""
        import json
        from tqdm import tqdm

        if tasks is None:
            tasks = ["detailed_description", "reasoning", "conversation"]

        dataset = []
        for image_path in tqdm(image_paths):
            for task in tasks:
                try:
                    dataset.append(self.generate_conversation(image_path, task))
                except Exception as e:
                    # Best-effort: report the failure and continue.
                    print(f"Error processing {image_path}: {e}")

        with open(output_path, 'w') as fh:
            json.dump(dataset, fh, indent=2)
        return dataset
4.2 Training Strategy
from transformers import Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model
def finetune_vlm(train_dataset=None, data_collator=None):
    """Fine-tune a LLaVA-style VLM with LoRA adapters on the language model.

    Args:
        train_dataset: visual-instruction dataset to train on (required).
        data_collator: batch collator; defaults to vlm_data_collator.

    Returns:
        The Trainer instance after training completes.

    Raises:
        ValueError: if no training dataset is provided.
    """
    # BUGFIX: the original body referenced an undefined module-level name
    # `train_dataset`; it is now an explicit parameter.
    if train_dataset is None:
        raise ValueError("finetune_vlm requires a train_dataset")
    if data_collator is None:
        data_collator = vlm_data_collator

    # Model: frozen vision tower, trainable LLM.
    model = LLaVAModel(
        freeze_vision=True,
        freeze_llm=False
    )

    # LoRA on the attention projections for parameter-efficient tuning.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        lora_dropout=0.05,
    )
    model.llm = get_peft_model(model.llm, lora_config)

    # Training configuration.
    training_args = TrainingArguments(
        output_dir="./llava-finetuned",
        per_device_train_batch_size=4,
        gradient_accumulation_steps=8,  # effective batch size 32 per device
        num_train_epochs=1,
        learning_rate=2e-5,
        warmup_ratio=0.03,
        lr_scheduler_type="cosine",
        bf16=True,
        logging_steps=10,
        save_steps=500,
        dataloader_num_workers=4,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )
    trainer.train()
    return trainer
def vlm_data_collator(features):
    """Collate per-sample dicts into stacked batch tensors for VLM training."""
    # Map each output batch key to the per-sample key it is stacked from
    # (note 'images' in the batch comes from 'image' in each sample).
    field_map = {
        'input_ids': 'input_ids',
        'attention_mask': 'attention_mask',
        'images': 'image',
        'labels': 'labels',
    }
    return {
        batch_key: torch.stack([sample[sample_key] for sample in features])
        for batch_key, sample_key in field_map.items()
    }
5. Evaluation Benchmarks

5.1 Major Benchmarks

Key VLM evaluation benchmarks:
1. VQA-v2: general visual QA
2. GQA: structured-reasoning QA
3. TextVQA: understanding text inside images
4. POPE: hallucination evaluation
5. MME: composite of 14 subtasks
6. MMBench: 20 ability dimensions
7. SEED-Bench: 19K multiple-choice questions

5.2 Evaluation Code
def evaluate_vlm(model, dataset_name: str = "vqav2"):
    """Dispatch VLM evaluation to the benchmark-specific routine.

    Args:
        model: the VLM under evaluation.
        dataset_name: one of "vqav2", "textvqa", "pope".

    Returns:
        The benchmark score returned by the specific evaluator.

    Raises:
        ValueError: for an unsupported dataset name (the original silently
        fell through and returned None).
    """
    if dataset_name == "vqav2":
        return evaluate_vqa_v2(model)
    elif dataset_name == "textvqa":
        return evaluate_textvqa(model)
    elif dataset_name == "pope":
        return evaluate_pope(model)
    raise ValueError(f"Unknown evaluation dataset: {dataset_name!r}")
def evaluate_pope(model):
    """
    POPE: Polling-based Object Probing Evaluation.

    Measures hallucination via binary questions of the form
    "Is there a [object] in the image?".
    """
    from datasets import load_dataset

    pope = load_dataset("lmms-lab/POPE")

    total = 0
    correct = 0
    for sample in pope['test']:
        question = sample['question']  # e.g. "Is there a dog in the image?"
        gold = sample['answer']        # "yes" or "no"

        # Collapse the model's free-form output to a binary yes/no answer.
        output = model.generate(sample['image'], question)
        predicted = "yes" if "yes" in output.lower() else "no"

        total += 1
        if predicted == gold:
            correct += 1

    accuracy = correct / total
    print(f"POPE Accuracy: {accuracy:.4f}")
    return accuracy
6. Practical Applications

6.1 Document Understanding
def document_understanding():
    """Document-understanding workflow built on an OCR-strong VLM."""
    model = load_qwen_vl()  # OCR-strong model

    def analyze_document_page(image_path: str, questions: list):
        """Ask a list of questions about one document-page image."""
        return [
            {
                'question': question,
                'answer': model.generate(f"<img>{image_path}</img>{question}"),
            }
            for question in questions
        ]

    # Example questions
    questions = [
        "What is the title of this document?",
        "Summarize the main points.",
        "Extract all dates mentioned.",
        "What tables are present? Describe their contents.",
    ]
    results = analyze_document_page("document_page.png", questions)
def chart_understanding(model=None, chart_image=None):
    """Run a battery of chart-analysis prompts through a VLM.

    Args:
        model: VLM exposing generate(image, prompt) -> str (required).
        chart_image: the chart image, in whatever form the model expects.

    Returns:
        List of model responses, one per prompt (printed as Q/A pairs too).

    Raises:
        ValueError: if no model is provided.
    """
    # BUGFIX: the original referenced undefined globals `model` and
    # `chart_image`; both are now explicit parameters.
    if model is None:
        raise ValueError("chart_understanding requires a model")

    prompts = [
        "What type of chart is this?",
        "What is the trend shown in this chart?",
        "What are the maximum and minimum values?",
        "Describe the relationship between X and Y.",
    ]

    # Analyze the chart with the VLM, echoing each Q/A pair.
    responses = []
    for prompt in prompts:
        response = model.generate(chart_image, prompt)
        print(f"Q: {prompt}")
        print(f"A: {response}\n")
        responses.append(response)
    return responses
References

Papers
- Liu et al. (2023). "Visual Instruction Tuning" (LLaVA)
- Liu et al. (2024). "LLaVA-NeXT: Improved reasoning, OCR, and world knowledge"
- Bai et al. (2023). "Qwen-VL: A Versatile Vision-Language Model"