Keeby-smilyai committed
Commit 2626b5f · verified · 1 Parent(s): f746dfb

Create train_vlm.py

Files changed (1):
  train_vlm.py +136 -0
train_vlm.py ADDED
@@ -0,0 +1,136 @@
# train_vlm.py
import os
import torch
from transformers import (
    AutoProcessor, LlavaForConditionalGeneration,
    TrainingArguments, Trainer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
from PIL import Image
import requests
from io import BytesIO

# --- Dataset ---
def load_vqa_dataset(stage, split="train[:5000]"):  # limit for demo
    # Use OCR-VQA or TextVQA for CoT-friendly VQA
    dataset = load_dataset("HuggingFaceM4/TextVQA", split=split)

    def format_stage1(example):
        question = example['question']
        answer = example['answers'][0] if example['answers'] else "unknown"
        return {
            "messages": [
                {"role": "user", "content": f"<image>\nQ: {question}\nA: Let's think step by step."},
                {"role": "assistant", "content": f"I see {answer} in the image. Therefore, the answer is {answer}."}
            ],
            "image": example['image']
        }

    def format_stage2(example):
        question = example['question']
        answer = example['answers'][0] if example['answers'] else "unknown"
        return {
            "messages": [
                {"role": "user", "content": f"<image>\nQ: {question}\nA: [INTERNAL THOUGHT HIDDEN]... Final Answer:"},
                {"role": "assistant", "content": answer}
            ],
            "labels_full": f"I analyzed regions and detected '{answer}'. Final Answer: {answer}",
            "image": example['image']
        }

    def format_stage3(example):
        question = example['question']
        answer = example['answers'][0] if example['answers'] else "unknown"
        return {
            "messages": [
                {"role": "user", "content": f"<image>\nQ: {question}\nA: Think deeply, reflect, and revise if needed."},
                {"role": "assistant", "content": f"Initial thought: maybe '{answer}'. But checking object positions and text... I revise: '{answer}' is correct. Confidence: 89%."}
            ],
            "image": example['image']
        }

    if stage == 1:
        return dataset.map(format_stage1, remove_columns=dataset.column_names)
    elif stage == 2:
        return dataset.map(format_stage2, remove_columns=dataset.column_names)
    elif stage == 3:
        return dataset.map(format_stage3, remove_columns=dataset.column_names)
    else:
        raise ValueError(f"Unknown stage: {stage}")

# --- Training ---
def train_vlm_stage(stage, model_name, output_dir, resume_from=None):
    print(f"🚀 Starting VLM Stage {stage} Training...")

    processor = AutoProcessor.from_pretrained(model_name)
    # Some checkpoints ship without a pad token; fall back to EOS so fixed-length padding works.
    if processor.tokenizer.pad_token is None:
        processor.tokenizer.pad_token = processor.tokenizer.eos_token

    model = LlavaForConditionalGeneration.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )

    # LoRA for VLM — target vision & language projections
    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        # Attention projections plus the projector's linear layers; PEFT wraps
        # leaf linear modules, so target "linear_1"/"linear_2" rather than the
        # "multi_modal_projector" container.
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "linear_1", "linear_2"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    model = prepare_model_for_kbit_training(model)
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    dataset = load_vqa_dataset(stage)

    def process_and_tokenize(example):
        # Images may arrive as URLs (strings) or as already-decoded PIL images.
        if isinstance(example["image"], str):
            image = Image.open(requests.get(example["image"], stream=True).raw)
        else:
            image = example["image"]
        messages = example["messages"]

        if stage == 2:
            # Input: masked prompt; labels: the full hidden reasoning.
            text_input = processor.apply_chat_template(messages, tokenize=False)
            text_labels = example["labels_full"]
            # Pad both sides to the same fixed length so input_ids and labels align position-wise.
            inputs = processor(text=text_input, images=image, return_tensors="pt",
                               padding="max_length", truncation=True, max_length=512)
            labels = processor.tokenizer(text_labels, return_tensors="pt",
                                         padding="max_length", truncation=True, max_length=512)
            inputs = {k: v.squeeze(0) for k, v in inputs.items()}
            label_ids = labels["input_ids"].squeeze(0)
            label_ids[labels["attention_mask"].squeeze(0) == 0] = -100  # ignore padding in the loss
            inputs["labels"] = label_ids
        else:
            text = processor.apply_chat_template(messages, tokenize=False)
            inputs = processor(text=text, images=image, return_tensors="pt",
                               padding="max_length", truncation=True, max_length=512)
            inputs = {k: v.squeeze(0) for k, v in inputs.items()}
            labels = inputs["input_ids"].clone()
            labels[inputs["attention_mask"] == 0] = -100  # ignore padding in the loss
            inputs["labels"] = labels

        return inputs

    tokenized_dataset = dataset.map(process_and_tokenize, remove_columns=dataset.column_names, batched=False)

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=2,  # VLMs are heavy
        gradient_accumulation_steps=8,
        num_train_epochs=1,
        learning_rate=2e-4,
        bf16=True,  # match the bfloat16 weights loaded above
        save_steps=200,
        save_total_limit=2,
        logging_steps=10,
        report_to="none",
        optim="paged_adamw_8bit",
        lr_scheduler_type="cosine",
        warmup_steps=50,
        remove_unused_columns=False,  # critical for image inputs
    )

    def collate_fn(batch):
        # `datasets.map` stores the processed tensors as plain Python lists,
        # so rebuild stacked tensors here instead of passing raw examples through.
        collated = {key: torch.tensor([ex[key] for ex in batch]) for key in batch[0]}
        # Match the model's bfloat16 compute dtype for image inputs.
        collated["pixel_values"] = collated["pixel_values"].to(torch.bfloat16)
        return collated

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=collate_fn,
    )

    trainer.train(resume_from_checkpoint=resume_from)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)
    print(f"✅ Stage {stage} saved to {output_dir}")
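
Usage sketch (not part of this commit): a minimal driver showing how the three curriculum stages could be invoked. The checkpoint name llava-hf/llava-1.5-7b-hf and the output directories are illustrative assumptions, not values from this file.

# run_stages.py (hypothetical driver, not part of train_vlm.py)
from train_vlm import train_vlm_stage

if __name__ == "__main__":
    base_model = "llava-hf/llava-1.5-7b-hf"  # assumed LLaVA-style checkpoint

    # Stage 1: explicit step-by-step reasoning traces
    train_vlm_stage(stage=1, model_name=base_model, output_dir="vlm_stage1")

    # Stage 2: hidden reasoning, short final answers
    train_vlm_stage(stage=2, model_name=base_model, output_dir="vlm_stage2")

    # Stage 3: reflect-and-revise answers; pass resume_from to continue an interrupted run
    train_vlm_stage(stage=3, model_name=base_model, output_dir="vlm_stage3", resume_from=None)

Each call trains an independent LoRA adapter on top of the same base checkpoint and saves the adapter plus the processor to its own output directory.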