# train_vlm.py
import os
import torch
from transformers import (
    AutoProcessor, LlavaForConditionalGeneration,
    TrainingArguments, Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
from PIL import Image
import requests
from io import BytesIO
# --- Dataset ---
def load_vqa_dataset(stage, split="train[:5000]"):  # limit the split size for the demo
    # TextVQA (or OCR-VQA) provides text-grounded VQA pairs that suit CoT-style formatting
    dataset = load_dataset("HuggingFaceM4/TextVQA", split=split)

    def format_stage1(example):
        # Stage 1: explicit step-by-step reasoning in the assistant target
        question = example['question']
        answer = example['answers'][0] if example['answers'] else "unknown"
        return {
            "messages": [
                {"role": "user", "content": f"<image>\nQ: {question}\nA: Let's think step by step."},
                {"role": "assistant", "content": f"I see {answer} in the image. Therefore, the answer is {answer}."}
            ],
            "image": example['image']
        }

    def format_stage2(example):
        # Stage 2: hidden reasoning in the prompt; the full reasoning trace is kept as separate labels
        question = example['question']
        answer = example['answers'][0] if example['answers'] else "unknown"
        return {
            "messages": [
                {"role": "user", "content": f"<image>\nQ: {question}\nA: [INTERNAL THOUGHT HIDDEN]... Final Answer:"},
                {"role": "assistant", "content": answer}
            ],
            "labels_full": f"I analyzed regions and detected '{answer}'. Final Answer: {answer}",
            "image": example['image']
        }

    def format_stage3(example):
        # Stage 3: self-reflection and revision with a stated confidence
        question = example['question']
        answer = example['answers'][0] if example['answers'] else "unknown"
        return {
            "messages": [
                {"role": "user", "content": f"<image>\nQ: {question}\nA: Think deeply, reflect, and revise if needed."},
                {"role": "assistant", "content": f"Initial thought: maybe '{answer}'. But checking object positions and text... I revise: '{answer}' is correct. Confidence: 89%."}
            ],
            "image": example['image']
        }

    if stage == 1:
        return dataset.map(format_stage1, remove_columns=dataset.column_names)
    elif stage == 2:
        return dataset.map(format_stage2, remove_columns=dataset.column_names)
    elif stage == 3:
        return dataset.map(format_stage3, remove_columns=dataset.column_names)
    else:
        raise ValueError(f"Unknown stage: {stage}")
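
# Quick sanity check (illustrative, assuming the TextVQA split above is reachable):
#   >>> sample = load_vqa_dataset(stage=1, split="train[:10]")[0]
#   >>> print(sample["messages"][0]["content"])   # user prompt with <image> placeholder
#   >>> print(sample["messages"][1]["content"])   # step-by-step assistant target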
# --- Training ---
def train_vlm_stage(stage, model_name, output_dir, resume_from=None):
    print(f"Starting VLM Stage {stage} Training...")

    processor = AutoProcessor.from_pretrained(model_name)
    model = LlavaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )

    # LoRA for the VLM: target the language-model attention projections and the multimodal projector
    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "multi_modal_projector"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = prepare_model_for_kbit_training(model)  # freezes base weights, enables gradient checkpointing
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    dataset = load_vqa_dataset(stage)
    def process_and_tokenize(example):
        # Accept either a PIL image or an image URL
        image = example["image"]
        if isinstance(image, str):
            image = Image.open(requests.get(image, stream=True).raw)
        messages = example["messages"]

        if stage == 2:
            # Stage 2: inputs use the masked prompt, labels carry the full reasoning trace
            text_input = processor.apply_chat_template(messages, tokenize=False)
            text_labels = example["labels_full"]
            inputs = processor(text=text_input, images=image, return_tensors="pt",
                               padding=True, truncation=True, max_length=512)
            labels = processor(text=text_labels, images=image, return_tensors="pt",
                               padding=True, truncation=True, max_length=512)
            inputs = {k: v.squeeze(0) for k, v in inputs.items()}
            inputs["labels"] = labels["input_ids"].squeeze(0)
        else:
            # Stages 1 and 3: standard causal-LM objective on the full conversation
            text = processor.apply_chat_template(messages, tokenize=False)
            inputs = processor(text=text, images=image, return_tensors="pt",
                               padding=True, truncation=True, max_length=512)
            inputs = {k: v.squeeze(0) for k, v in inputs.items()}
            inputs["labels"] = inputs["input_ids"].clone()
        return inputs

    tokenized_dataset = dataset.map(process_and_tokenize, remove_columns=dataset.column_names, batched=False)
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,   # VLMs are memory-heavy
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        learning_rate=2e-4,
        bf16=True,                       # match the bfloat16 model weights
        save_steps=200,
        save_total_limit=2,
        logging_steps=10,
        report_to="none",
        optim="paged_adamw_8bit",
        lr_scheduler_type="cosine",
        warmup_steps=50,
        remove_unused_columns=False,     # critical for image inputs
    )
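
    # NOTE (assumption, not part of the original script): `lambda x: x` does not build
    # model-ready batches, so this minimal collator sketch pads the text fields and
    # stacks pixel_values instead.
    pad_id = processor.tokenizer.pad_token_id or 0  # fall back to 0 if no pad token is set
    def pad_collate(features):
        def to_tensor(x):
            return x if isinstance(x, torch.Tensor) else torch.tensor(x)
        batch = {}
        for key, pad_value in (("input_ids", pad_id), ("attention_mask", 0), ("labels", -100)):
            seqs = [to_tensor(f[key]) for f in features]
            batch[key] = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=pad_value)
        batch["pixel_values"] = torch.stack([to_tensor(f["pixel_values"]) for f in features])
        return batch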
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=pad_collate,  # pad text fields and stack image tensors per batch
    )
    trainer.train(resume_from_checkpoint=resume_from)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)
    print(f"Stage {stage} saved to {output_dir}")
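
# Illustrative usage (an assumption, not part of the original script): run the three
# curriculum stages from the same base checkpoint. The model id below is only an example;
# merging each stage's LoRA adapter back into the base model between stages is left out here.
if __name__ == "__main__":
    base_model = "llava-hf/llava-1.5-7b-hf"  # example checkpoint (assumption)
    for s in (1, 2, 3):
        train_vlm_stage(stage=s, model_name=base_model, output_dir=f"./vlm_stage{s}")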