# train_vlm.py
import torch
from transformers import (
    AutoProcessor, BitsAndBytesConfig, LlavaForConditionalGeneration,
    TrainingArguments, Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
from PIL import Image
import requests
from io import BytesIO


# --- Dataset ---
def load_vqa_dataset(stage, split="train[:5000]"):  # limit for demo
    # Use OCR-VQA or TextVQA for CoT-friendly VQA
    dataset = load_dataset("HuggingFaceM4/TextVQA", split=split)

    def to_messages(user_text, assistant_text):
        # Structured content so the processor's chat template can insert the
        # image placeholder token itself.
        return [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": user_text},
            ]},
            {"role": "assistant", "content": [
                {"type": "text", "text": assistant_text},
            ]},
        ]

    def format_stage1(example):
        question = example["question"]
        answer = example["answers"][0] if example["answers"] else "unknown"
        return {
            "messages": to_messages(
                f"Q: {question}\nA: Let's think step by step.",
                f"I see {answer} in the image. Therefore, the answer is {answer}.",
            ),
            "image": example["image"],
        }

    def format_stage2(example):
        question = example["question"]
        answer = example["answers"][0] if example["answers"] else "unknown"
        return {
            "messages": to_messages(
                f"Q: {question}\nA: [INTERNAL THOUGHT HIDDEN]... Final Answer:",
                answer,
            ),
            "labels_full": f"I analyzed regions and detected '{answer}'. Final Answer: {answer}",
            "image": example["image"],
        }

    def format_stage3(example):
        question = example["question"]
        answer = example["answers"][0] if example["answers"] else "unknown"
        return {
            "messages": to_messages(
                f"Q: {question}\nA: Think deeply, reflect, and revise if needed.",
                f"Initial thought: maybe '{answer}'. But checking object positions and text... "
                f"I revise: '{answer}' is correct. Confidence: 89%.",
            ),
            "image": example["image"],
        }

    formatters = {1: format_stage1, 2: format_stage2, 3: format_stage3}
    if stage not in formatters:
        raise ValueError(f"stage must be 1, 2, or 3, got {stage}")
    return dataset.map(formatters[stage], remove_columns=dataset.column_names)


# --- Training ---
def train_vlm_stage(stage, model_name, output_dir, resume_from=None):
    print(f"Starting VLM Stage {stage} training...")
    processor = AutoProcessor.from_pretrained(model_name)

    # Load the base model in 4-bit (QLoRA-style); this is what the k-bit
    # preparation and the paged 8-bit optimizer below are designed for.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = LlavaForConditionalGeneration.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )

    # LoRA for VLM: target the attention projections plus the inner Linear
    # layers (linear_1 / linear_2) of the multi-modal projector; the projector
    # module itself is not a Linear layer and cannot be wrapped directly.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "linear_1", "linear_2"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )
    model = prepare_model_for_kbit_training(model)  # freeze base weights, enable gradient checkpointing
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    dataset = load_vqa_dataset(stage)

    def process_and_tokenize(example):
        # TextVQA yields PIL images; handle URL strings as a fallback.
        image = example["image"]
        if isinstance(image, str):
            image = Image.open(BytesIO(requests.get(image, timeout=30).content))
        image = image.convert("RGB")
        messages = example["messages"]

        # Render the conversation with the chat template, then tokenize it
        # together with the image. max_length leaves room for the expanded
        # image tokens (~576 for LLaVA-1.5).
        text = processor.apply_chat_template(messages, tokenize=False)
        inputs = processor(text=text, images=image, return_tensors="pt",
                           truncation=True, max_length=1024)
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        if stage == 2:
            # Input: masked prompt; labels: the full (hidden) reasoning trace.
            # Tokenize the label text on its own, then pad/truncate it to the
            # input length so the causal-LM loss sees aligned shapes.
            label_ids = processor.tokenizer(
                example["labels_full"], return_tensors="pt",
                truncation=True, max_length=1024,
            )["input_ids"].squeeze(0)
            seq_len = inputs["input_ids"].shape[0]
            if label_ids.shape[0] < seq_len:
                padding = torch.full((seq_len - label_ids.shape[0],), -100, dtype=label_ids.dtype)
                label_ids = torch.cat([label_ids, padding])
            inputs["labels"] = label_ids[:seq_len]
        else:
            # Standard supervised fine-tuning on the full rendered conversation.
            inputs["labels"] = inputs["input_ids"].clone()
        return inputs

    tokenized_dataset = dataset.map(process_and_tokenize, remove_columns=dataset.column_names, batched=False)
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,  # VLMs are heavy
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        learning_rate=2e-4,
        bf16=True,  # match the bfloat16 compute dtype used above
        save_steps=200,
        save_total_limit=2,
        logging_steps=10,
        report_to="none",
        optim="paged_adamw_8bit",
        lr_scheduler_type="cosine",
        warmup_steps=50,
        remove_unused_columns=False,  # critical for image inputs
    )

    def collate_fn(features):
        # Pad text fields to the longest example in the batch; stack image
        # tensors and cast them to the model's bfloat16 compute dtype.
        pad_id = processor.tokenizer.pad_token_id if processor.tokenizer.pad_token_id is not None else 0
        max_len = max(len(f["input_ids"]) for f in features)

        def pad(seq, value):
            return list(seq) + [value] * (max_len - len(seq))

        return {
            "input_ids": torch.tensor([pad(f["input_ids"], pad_id) for f in features], dtype=torch.long),
            "attention_mask": torch.tensor([pad(f["attention_mask"], 0) for f in features], dtype=torch.long),
            "labels": torch.tensor([pad(f["labels"], -100) for f in features], dtype=torch.long),
            "pixel_values": torch.tensor([f["pixel_values"] for f in features], dtype=torch.bfloat16),
        }

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=collate_fn,  # pads variable-length text and batches image tensors
    )
    trainer.train(resume_from_checkpoint=resume_from)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)
    print(f"Stage {stage} saved to {output_dir}")
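

# --- Example driver ---
# A minimal sketch of how the three stages could be run back to back. The base
# checkpoint name and output directories below are illustrative assumptions,
# not values from the original script; any LLaVA-style checkpoint whose
# processor ships a chat template should work.
if __name__ == "__main__":
    base_model = "llava-hf/llava-1.5-7b-hf"  # assumed example checkpoint
    for stage in (1, 2, 3):
        # Each stage trains fresh LoRA adapters on the same base model and
        # writes them (plus the processor) to its own output directory.
        train_vlm_stage(stage, base_model, output_dir=f"./vlm_stage{stage}")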