# train_vlm.py
import os
import torch
from transformers import (
    TrainingArguments, Trainer, DefaultDataCollator,
    AutoTokenizer, AutoImageProcessor
)
from datasets import load_dataset
from PIL import Image
from custom_vlm import CustomScratchVLM, VLMConfig


def get_processors_and_model(config):
    vision_model_name = config.vision_config._name_or_path
    language_model_name = config.language_config._name_or_path
    image_processor = AutoImageProcessor.from_pretrained(vision_model_name)
    tokenizer = AutoTokenizer.from_pretrained(language_model_name)

    # Register the <IMAGE> placeholder so the text stream can mark where the
    # vision features are spliced in.
    IMAGE_TOKEN = "<IMAGE>"
    tokenizer.add_special_tokens({"additional_special_tokens": [IMAGE_TOKEN]})
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # The model is built after the tokenizer is extended so its embedding table
    # already covers the new special token.
    config.language_config.vocab_size = len(tokenizer)
    model = CustomScratchVLM(config)
    model.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    return image_processor, tokenizer, model
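

# Hedged, optional sanity check of the wiring above; it is not part of training.
# VLMConfig and CustomScratchVLM come from custom_vlm.py, which is not shown here,
# so this is only a sketch of how the helper is expected to behave.
def check_image_token_wiring():
    config = VLMConfig()
    _, tokenizer, model = get_processors_and_model(config)
    image_id = tokenizer.convert_tokens_to_ids("<IMAGE>")
    assert model.image_token_id == image_id, "model and tokenizer disagree on <IMAGE>"
    # get_processors_and_model sets the config vocab size to the extended tokenizer length.
    assert config.language_config.vocab_size == len(tokenizer)
    print(f"<IMAGE> token id: {image_id}, extended vocab size: {len(tokenizer)}")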


def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="train[:200]"):
    print("Loading dataset 'zera09/lmarena-ai_VisionArena-Chat-en'...")
    dataset = load_dataset("zera09/lmarena-ai_VisionArena-Chat-en", split=split)
    print("Dataset loaded successfully.")

    IMAGE_TOKEN = "<IMAGE>"
    TEXT_MAX_LENGTH = 256
    # Number of vision patches the model splices in place of the single <IMAGE>
    # token (e.g. a 224 px image with 16 px patches gives 14 * 14 = 196 patches).
    NUM_IMAGE_PATCHES = (image_processor.size['height'] // image_processor.patch_size) ** 2
    # Final sequence length once the one-token placeholder is replaced by the patches.
    FINAL_MAX_LENGTH = TEXT_MAX_LENGTH - 1 + NUM_IMAGE_PATCHES
    def preprocess_function(examples):
        image = examples['image'].convert("RGB")

        # The 'conversation' field is a list of turns; each turn is itself a
        # one-element list holding a {'role', 'content'} dict.
        conversation = examples['conversation']
        full_text = ""
        is_first_user_turn = True
        for turn_list in conversation:
            if not turn_list:
                continue
            turn = turn_list[0]
            role = turn['role'].upper()
            content = turn['content']
            if role == "USER" and is_first_user_turn:
                # Prepend the image placeholder to the first user turn only.
                full_text += f"USER: {IMAGE_TOKEN}\n{content}\n"
                is_first_user_turn = False
            else:
                full_text += f"{role}: {content}\n"
        full_text += tokenizer.eos_token

        tokenized = tokenizer(full_text, max_length=TEXT_MAX_LENGTH, truncation=True)
        input_ids = torch.tensor(tokenized.input_ids)
        try:
            image_token_idx = torch.where(input_ids == model.image_token_id)[0][0].item()
        except IndexError:
            # No <IMAGE> token in the tokenized text (e.g. the conversation had no
            # user turn). A non-batched Dataset.map call must return a dict with a
            # consistent schema, so return a dummy row flagged invalid and drop it
            # after mapping instead of returning None.
            side = image_processor.size['height']  # square input assumed, as in NUM_IMAGE_PATCHES
            return {
                "pixel_values": torch.zeros(3, side, side),
                "input_ids": torch.zeros(TEXT_MAX_LENGTH, dtype=torch.long),
                "attention_mask": torch.zeros(FINAL_MAX_LENGTH, dtype=torch.long),
                "labels": torch.full((FINAL_MAX_LENGTH,), -100, dtype=torch.long),
                "valid": False,
            }
        labels = input_ids.clone()

        # Supervise only the assistant turns: everything outside an "ASSISTANT:" span
        # (up to the next "USER:" marker) is set to -100. Note: this assumes the markers
        # tokenize to the same ids standalone as they do inside the full prompt, which
        # holds for most tokenizers.
        assistant_marker = torch.tensor(tokenizer("ASSISTANT:", add_special_tokens=False).input_ids)
        user_marker = torch.tensor(tokenizer("USER:", add_special_tokens=False).input_ids)
        is_assistant_section = torch.zeros_like(labels, dtype=torch.bool)
        for i in range(len(labels) - len(assistant_marker) + 1):
            if (labels[i:i + len(assistant_marker)] == assistant_marker).all():
                end_idx = len(labels)
                for j in range(i + 1, len(labels) - len(user_marker) + 1):
                    if (labels[j:j + len(user_marker)] == user_marker).all():
                        end_idx = j
                        break
                is_assistant_section[i:end_idx] = True
        labels[~is_assistant_section] = -100

        # Expand the single <IMAGE> position into NUM_IMAGE_PATCHES positions
        # (all ignored by the loss), then pad to the fixed final length.
        pre_labels = labels[:image_token_idx]
        post_labels = labels[image_token_idx + 1:]
        image_labels_pad = torch.full((NUM_IMAGE_PATCHES,), -100, dtype=torch.long)
        final_labels = torch.cat([pre_labels, image_labels_pad, post_labels], dim=0)
        final_labels = torch.nn.functional.pad(
            final_labels, (0, FINAL_MAX_LENGTH - len(final_labels)), value=-100
        )[:FINAL_MAX_LENGTH]

        # Same expansion for the attention mask: real text and image patches get 1,
        # trailing padding gets 0.
        attention_mask = torch.ones_like(input_ids)
        pre_mask = attention_mask[:image_token_idx]
        post_mask = attention_mask[image_token_idx + 1:]
        image_mask = torch.ones(NUM_IMAGE_PATCHES, dtype=torch.long)
        final_attention_mask = torch.cat([pre_mask, image_mask, post_mask], dim=0)
        final_attention_mask = torch.nn.functional.pad(
            final_attention_mask, (0, FINAL_MAX_LENGTH - len(final_attention_mask)), value=0
        )[:FINAL_MAX_LENGTH]

        # Pad input_ids to the fixed text length so that, once the model swaps the
        # placeholder for NUM_IMAGE_PATCHES vision embeddings, the sequence length
        # lines up with the FINAL_MAX_LENGTH labels and attention mask built above.
        final_input_ids = torch.nn.functional.pad(
            input_ids, (0, TEXT_MAX_LENGTH - len(input_ids)), value=tokenizer.pad_token_id
        )

        pixel_values = image_processor(image, return_tensors="pt").pixel_values
        return {
            "pixel_values": pixel_values.squeeze(0),
            "input_ids": final_input_ids,
            "attention_mask": final_attention_mask,
            "labels": final_labels,
            "valid": True,
        }
    processed_dataset = dataset.map(preprocess_function, remove_columns=list(dataset.column_names))
    # Drop the rows that preprocess_function flagged as invalid.
    processed_dataset = processed_dataset.filter(lambda x: x["valid"])
    return processed_dataset.remove_columns("valid")
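

# Hedged debugging helper, not called during training: prints the outer length of
# each field of one preprocessed example so the fixed lengths above (TEXT_MAX_LENGTH
# and FINAL_MAX_LENGTH) can be checked quickly. The argument is whatever
# load_and_prepare_dataset returned.
def inspect_first_example(processed_dataset):
    example = processed_dataset[0]
    for key, value in example.items():
        size = len(value) if isinstance(value, (list, tuple)) else value
        print(f"{key}: {size}")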


def train_vlm_stage(stage, output_dir, resume_from=None):
    print(f"🚀 Starting VLM Conversational Training Stage {stage} FROM SCRATCH...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    vlm_config = VLMConfig()
    image_processor, tokenizer, model = get_processors_and_model(vlm_config)
    model.to(device)

    # Each stage trains on its own 200-example slice of the dataset.
    split = f"train[{200 * (stage - 1)}:{200 * stage}]"
    tokenized_dataset = load_and_prepare_dataset(stage, image_processor, tokenizer, model, split=split)

    # fp16 and bf16 are mutually exclusive in TrainingArguments, so prefer bf16
    # where the GPU supports it and fall back to fp16 otherwise.
    use_bf16 = device == "cuda" and torch.cuda.is_bf16_supported()
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        learning_rate=5e-5,
        fp16=(device == "cuda" and not use_bf16),
        bf16=use_bf16,
        save_strategy="epoch",
        logging_steps=5,
        report_to="none",
        optim="adamw_torch",
        remove_unused_columns=False,
    )

    trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_dataset, data_collator=DefaultDataCollator())
    trainer.train(resume_from_checkpoint=resume_from)

    model.save_pretrained(output_dir)
    image_processor.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"✅ Stage {stage} model and processors saved to {output_dir}")