# train_vlm.py
import os
import torch
from transformers import (
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
    AutoTokenizer,
    AutoImageProcessor,
)
from datasets import load_dataset
from PIL import Image
from custom_vlm import CustomScratchVLM, VLMConfig

# Placeholder token marking where the image patch embeddings are spliced into the text sequence.
IMAGE_TOKEN = "<image>"


def get_processors_and_model(config):
    vision_model_name = config.vision_config._name_or_path
    language_model_name = config.language_config._name_or_path
    image_processor = AutoImageProcessor.from_pretrained(vision_model_name)
    tokenizer = AutoTokenizer.from_pretrained(language_model_name)

    # Register the image placeholder as a special token and make sure a pad token exists.
    tokenizer.add_special_tokens({"additional_special_tokens": [IMAGE_TOKEN]})
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    config.language_config.vocab_size = len(tokenizer)
    model = CustomScratchVLM(config)
    model.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    return image_processor, tokenizer, model


def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="train[:200]"):
    print("Loading dataset 'zera09/lmarena-ai_VisionArena-Chat-en'...")
    dataset = load_dataset("zera09/lmarena-ai_VisionArena-Chat-en", split=split)
    print("Dataset loaded successfully.")

    TEXT_MAX_LENGTH = 256
    # Assumes the image processor exposes a square `size` and a `patch_size`; the single image
    # placeholder token is later expanded into one position per image patch.
    NUM_IMAGE_PATCHES = (image_processor.size['height'] // image_processor.patch_size) ** 2
    FINAL_MAX_LENGTH = TEXT_MAX_LENGTH - 1 + NUM_IMAGE_PATCHES

    def preprocess_function(examples):
        image = examples['image'].convert("RGB")

        # Each example's 'conversation' field is a list of turn lists; each inner list holds a
        # single dict with 'role' and 'content'.
        conversation = examples['conversation']
        full_text = ""
        is_first_user_turn = True
        for turn_list in conversation:
            if not turn_list:
                continue
            turn = turn_list[0]
            role = turn['role'].upper()
            content = turn['content']
            if role == "USER" and is_first_user_turn:
                # Prepend the image placeholder to the first user turn only.
                full_text += f"USER: {IMAGE_TOKEN}\n{content}\n"
                is_first_user_turn = False
            else:
                full_text += f"{role}: {content}\n"
        full_text += tokenizer.eos_token

        tokenized = tokenizer(
            full_text,
            max_length=TEXT_MAX_LENGTH,
            truncation=True,
            padding="max_length",
        )
        input_ids = torch.tensor(tokenized.input_ids)
        attention_mask = torch.tensor(tokenized.attention_mask)

        try:
            image_token_idx = torch.where(input_ids == model.image_token_id)[0][0].item()
        except IndexError:
            # The image placeholder was truncated away; emit an empty row so it can be filtered out below.
            return {"pixel_values": [], "input_ids": [], "attention_mask": [], "labels": []}

        # Supervise only the assistant turns: every token outside an "ASSISTANT: ..." span is masked with -100.
        labels = input_ids.clone()
        assistant_marker_ids = tokenizer("ASSISTANT:", add_special_tokens=False).input_ids
        user_marker_ids = tokenizer("USER:", add_special_tokens=False).input_ids
        is_assistant_section = torch.zeros_like(labels, dtype=torch.bool)
        for i in range(len(labels) - len(assistant_marker_ids) + 1):
            if (labels[i:i + len(assistant_marker_ids)] == torch.tensor(assistant_marker_ids)).all():
                end_idx = len(labels)
                for j in range(i + 1, len(labels) - len(user_marker_ids) + 1):
                    if (labels[j:j + len(user_marker_ids)] == torch.tensor(user_marker_ids)).all():
                        end_idx = j
                        break
                is_assistant_section[i:end_idx] = True
        labels[~is_assistant_section] = -100
        labels[attention_mask == 0] = -100  # never train on padding tokens

        # Replace the single image placeholder with NUM_IMAGE_PATCHES ignored label positions so the
        # labels line up with the patch-expanded sequence the model sees.
        pre_labels = labels[:image_token_idx]
        post_labels = labels[image_token_idx + 1:]
        image_labels_pad = torch.full((NUM_IMAGE_PATCHES,), -100, dtype=torch.long)
        final_labels = torch.cat([pre_labels, image_labels_pad, post_labels], dim=0)
        final_labels = torch.nn.functional.pad(
            final_labels, (0, FINAL_MAX_LENGTH - len(final_labels)), value=-100
        )[:FINAL_MAX_LENGTH]

        # Expand the attention mask the same way: one attended position per image patch.
        pre_mask = attention_mask[:image_token_idx]
        post_mask = attention_mask[image_token_idx + 1:]
        image_mask = torch.ones(NUM_IMAGE_PATCHES, dtype=torch.long)
        final_attention_mask = torch.cat([pre_mask, image_mask, post_mask], dim=0)
        final_attention_mask = torch.nn.functional.pad(
            final_attention_mask, (0, FINAL_MAX_LENGTH - len(final_attention_mask)), value=0
        )[:FINAL_MAX_LENGTH]

        pixel_values = image_processor(image, return_tensors="pt").pixel_values
        return {
            "pixel_values": pixel_values.squeeze(0),
            "input_ids": input_ids,
            "attention_mask": final_attention_mask,
            "labels": final_labels,
        }

    processed_dataset = dataset.map(preprocess_function, remove_columns=list(dataset.column_names))
    # Drop rows whose image placeholder was lost to truncation.
    return processed_dataset.filter(lambda x: len(x["input_ids"]) > 0)


def train_vlm_stage(stage, output_dir, resume_from=None):
    print(f"🚀 Starting VLM Conversational Training Stage {stage} FROM SCRATCH...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    vlm_config = VLMConfig()
    image_processor, tokenizer, model = get_processors_and_model(vlm_config)
    model.to(device)

    # Each stage trains on its own 200-example slice of the dataset.
    split = f"train[{200 * (stage - 1)}:{200 * stage}]"
    tokenized_dataset = load_and_prepare_dataset(stage, image_processor, tokenizer, model, split=split)

    use_bf16 = device == "cuda" and torch.cuda.is_bf16_supported()
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        learning_rate=5e-5,
        fp16=(device == "cuda" and not use_bf16),  # fp16 and bf16 must not both be enabled
        bf16=use_bf16,
        save_strategy="epoch",
        logging_steps=5,
        report_to="none",
        optim="adamw_torch",
        remove_unused_columns=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=DefaultDataCollator(),
    )
    trainer.train(resume_from_checkpoint=resume_from)

    model.save_pretrained(output_dir)
    image_processor.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"✅ Stage {stage} model and processors saved to {output_dir}")
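

# Illustrative usage sketch: the stage number and output path below are placeholder assumptions,
# not values prescribed by the script; adjust them to your own setup.
if __name__ == "__main__":
    train_vlm_stage(stage=1, output_dir="./vlm_stage_1_output")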