# train_vlm.py
import os
import torch
from transformers import (
    AutoProcessor, LlavaForConditionalGeneration,
    TrainingArguments, Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
from PIL import Image
import requests
from io import BytesIO
# --- Dataset ---
def load_vqa_dataset(stage, split="train[:5000]"):  # limit for demo
    # Use OCR-VQA or TextVQA for CoT-friendly VQA
    dataset = load_dataset("HuggingFaceM4/TextVQA", split=split)

    def format_stage1(example):
        question = example['question']
        answer = example['answers'][0] if example['answers'] else "unknown"
        return {
            "messages": [
                {"role": "user", "content": f"<image>\nQ: {question}\nA: Let's think step by step."},
                {"role": "assistant", "content": f"I see {answer} in the image. Therefore, the answer is {answer}."}
            ],
            "image": example['image']
        }

    def format_stage2(example):
        question = example['question']
        answer = example['answers'][0] if example['answers'] else "unknown"
        return {
            "messages": [
                {"role": "user", "content": f"<image>\nQ: {question}\nA: [INTERNAL THOUGHT HIDDEN]... Final Answer:"},
                {"role": "assistant", "content": answer}
            ],
            "labels_full": f"I analyzed regions and detected '{answer}'. Final Answer: {answer}",
            "image": example['image']
        }

    def format_stage3(example):
        question = example['question']
        answer = example['answers'][0] if example['answers'] else "unknown"
        return {
            "messages": [
                {"role": "user", "content": f"<image>\nQ: {question}\nA: Think deeply, reflect, and revise if needed."},
                {"role": "assistant", "content": f"Initial thought: maybe '{answer}'. But checking object positions and text... I revise: '{answer}' is correct. Confidence: 89%."}
            ],
            "image": example['image']
        }
    if stage == 1:
        return dataset.map(format_stage1, remove_columns=dataset.column_names)
    elif stage == 2:
        return dataset.map(format_stage2, remove_columns=dataset.column_names)
    elif stage == 3:
        return dataset.map(format_stage3, remove_columns=dataset.column_names)
    else:
        raise ValueError(f"Unknown stage: {stage} (expected 1, 2, or 3)")
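
# Illustrative usage sketch (hypothetical helper call, not required by the
# training loop below): inspect one formatted record before launching training.
# The split string and the 'question' / 'answers' field names follow the
# assumptions made in load_vqa_dataset above.
# >>> sample_ds = load_vqa_dataset(stage=1, split="train[:8]")
# >>> sample_ds[0]["messages"][0]["content"]  # "<image>\nQ: ...\nA: Let's think step by step."
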
# --- Training ---
def train_vlm_stage(stage, model_name, output_dir, resume_from=None):
    print(f"🚀 Starting VLM Stage {stage} Training...")
    processor = AutoProcessor.from_pretrained(model_name)
    model = LlavaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    # LoRA for VLM: target the attention projections plus the multimodal
    # projector's inner Linear layers (linear_1 / linear_2). PEFT can only
    # wrap leaf Linear modules, not the "multi_modal_projector" container.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "linear_1", "linear_2"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    dataset = load_vqa_dataset(stage)
    def process_and_tokenize(example):
        # Images may arrive as PIL objects or as URL strings; fetch if needed.
        image = example["image"]
        if isinstance(image, str):
            image = Image.open(BytesIO(requests.get(image, timeout=30).content))
        image = image.convert("RGB")
        messages = example["messages"]
        # Note: processors that expand <image> into patch tokens may need a larger max_length.
        if stage == 2:
            # Input: masked prompt; labels: the full (hidden) reasoning text.
            # Pad both to the same fixed length so input_ids and labels line up
            # and examples can be stacked into batches.
            text_input = processor.apply_chat_template(messages, tokenize=False)
            text_labels = example["labels_full"]
            inputs = processor(text=text_input, images=image, return_tensors="pt",
                               padding="max_length", truncation=True, max_length=512)
            # Labels contain no <image> placeholder, so tokenize text only.
            labels = processor.tokenizer(text_labels, return_tensors="pt",
                                         padding="max_length", truncation=True, max_length=512)
            inputs = {k: v.squeeze(0) for k, v in inputs.items()}
            inputs["labels"] = labels["input_ids"].squeeze(0)
        else:
            text = processor.apply_chat_template(messages, tokenize=False)
            inputs = processor(text=text, images=image, return_tensors="pt",
                               padding="max_length", truncation=True, max_length=512)
            inputs = {k: v.squeeze(0) for k, v in inputs.items()}
            inputs["labels"] = inputs["input_ids"].clone()
        return inputs

    tokenized_dataset = dataset.map(process_and_tokenize, remove_columns=dataset.column_names, batched=False)
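
    # Minimal collator sketch: datasets.map() stores features as Python lists,
    # so rebuild tensors and stack them into a batch. Assumes every example was
    # padded to the same max_length above and pixel_values share one shape;
    # dtype casting is left to the Trainer / autocast.
    def collate_fn(features):
        batch = {}
        for key in features[0].keys():
            batch[key] = torch.stack([torch.as_tensor(f[key]) for f in features])
        return batch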
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,  # VLMs are heavy
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        learning_rate=2e-4,
        bf16=True,  # model was loaded in bfloat16, so train in bf16 rather than fp16
        save_steps=200,
        save_total_limit=2,
        logging_steps=10,
        report_to="none",
        optim="paged_adamw_8bit",
        lr_scheduler_type="cosine",
        warmup_steps=50,
        remove_unused_columns=False,  # critical for image inputs
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=collate_fn,  # stack per-example tensors into a batch
    )
    trainer.train(resume_from_checkpoint=resume_from)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)
    print(f"✅ Stage {stage} saved to {output_dir}")
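

# Entry-point sketch: the base model id "llava-hf/llava-1.5-7b-hf" and the
# output directories below are illustrative assumptions, not values taken from
# this script; adjust them (or chain stages from a previous checkpoint) as needed.
if __name__ == "__main__":
    base_model = "llava-hf/llava-1.5-7b-hf"  # hypothetical base checkpoint
    for stage in (1, 2, 3):
        train_vlm_stage(
            stage,
            model_name=base_model,                 # or a previous stage's merged checkpoint
            output_dir=f"./vlm_cot_stage{stage}",  # hypothetical output path
        )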