# train_vlm.py
import torch
from transformers import (
    TrainingArguments, Trainer, DefaultDataCollator,
    AutoTokenizer, AutoImageProcessor
)
from datasets import load_dataset
from custom_vlm import CustomScratchVLM, VLMConfig


def get_processors_and_model(config):
    """Build the image processor, tokenizer, and from-scratch VLM from a VLMConfig."""
    vision_model_name = config.vision_config._name_or_path
    language_model_name = config.language_config._name_or_path
    image_processor = AutoImageProcessor.from_pretrained(vision_model_name)
    tokenizer = AutoTokenizer.from_pretrained(language_model_name)

    # Register a dedicated <IMAGE> placeholder token and make sure padding is defined.
    IMAGE_TOKEN = "<IMAGE>"
    tokenizer.add_special_tokens({"additional_special_tokens": [IMAGE_TOKEN]})
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Size the language-model vocabulary to cover the newly added special token.
    config.language_config.vocab_size = len(tokenizer)
    model = CustomScratchVLM(config)
    model.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    return image_processor, tokenizer, model
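
# NOTE (assumption about custom_vlm): the preprocessing below builds `labels` and
# `attention_mask` for the *expanded* sequence in which the single <IMAGE> token is
# replaced by NUM_IMAGE_PATCHES positions, while `input_ids` keeps the single
# placeholder. This only makes sense if CustomScratchVLM.forward splices projected
# vision-patch embeddings in place of the <IMAGE> embedding before running the
# language model, roughly like this illustrative sketch (not the actual code):
#
#     patch_embeds  = projector(vision_tower(pixel_values))   # (B, N_patches, D)
#     text_embeds   = embed_tokens(input_ids)                  # (B, T, D)
#     inputs_embeds = splice(text_embeds, patch_embeds, image_token_id)
#     outputs = language_model(inputs_embeds=inputs_embeds,
#                              attention_mask=attention_mask, labels=labels)
#
# `vision_tower`, `projector`, and `splice` are hypothetical names used only to
# describe the contract this script expects from custom_vlm.CustomScratchVLM.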


def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="train[:200]"):
    """Load a slice of the VisionArena chat dataset and build VLM training features."""
    print("Loading dataset 'zera09/lmarena-ai_VisionArena-Chat-en'...")
    dataset = load_dataset("zera09/lmarena-ai_VisionArena-Chat-en", split=split)
    print("Dataset loaded successfully.")

    IMAGE_TOKEN = "<IMAGE>"
    TEXT_MAX_LENGTH = 256
    # Assumes the image processor exposes a square input size and a patch_size,
    # e.g. (224 // 16) ** 2 = 196 patches.
    NUM_IMAGE_PATCHES = (image_processor.size['height'] // image_processor.patch_size) ** 2
    # The single <IMAGE> token is expanded to NUM_IMAGE_PATCHES positions.
    FINAL_MAX_LENGTH = TEXT_MAX_LENGTH - 1 + NUM_IMAGE_PATCHES

    def preprocess_function(examples):
        image = examples['image'].convert("RGB")
        # The 'conversation' column is a list of turns, where each turn is itself a
        # list of {"role": ..., "content": ...} dicts; only the first entry of each
        # turn is used here.
        conversation = examples['conversation']

        # Serialize the conversation, injecting the <IMAGE> placeholder into the
        # first user turn.
        full_text = ""
        is_first_user_turn = True
        for turn_list in conversation:
            if not turn_list:
                continue
            turn = turn_list[0]
            role = turn['role'].upper()
            content = turn['content']
            if role == "USER" and is_first_user_turn:
                full_text += f"USER: {IMAGE_TOKEN}\n{content}\n"
                is_first_user_turn = False
            else:
                full_text += f"{role}: {content}\n"
        full_text += tokenizer.eos_token

        tokenized = tokenizer(full_text, max_length=TEXT_MAX_LENGTH, truncation=True)
        input_ids = torch.tensor(tokenized.input_ids)
        try:
            image_token_idx = torch.where(input_ids == model.image_token_id)[0][0].item()
        except IndexError:
            # No <IMAGE> token survived tokenization. `map` cannot drop examples, so
            # emit an empty record and filter it out after mapping.
            return {"pixel_values": [], "input_ids": [], "attention_mask": [], "labels": []}

        # Mask everything except assistant responses: only spans from an "ASSISTANT:"
        # marker up to the next "USER:" marker (or end of sequence) keep their label
        # values; all other positions are set to -100. This assumes the markers
        # tokenize the same way standalone as they do inside the full conversation.
        labels = input_ids.clone()
        assistant_marker_ids = tokenizer("ASSISTANT:", add_special_tokens=False).input_ids
        is_assistant_section = torch.zeros_like(labels, dtype=torch.bool)
        for i in range(len(labels) - len(assistant_marker_ids) + 1):
            if (labels[i:i+len(assistant_marker_ids)] == torch.tensor(assistant_marker_ids)).all():
                end_idx = len(labels)
                user_marker_ids = tokenizer("USER:", add_special_tokens=False).input_ids
                for j in range(i + 1, len(labels) - len(user_marker_ids) + 1):
                    if (labels[j:j+len(user_marker_ids)] == torch.tensor(user_marker_ids)).all():
                        end_idx = j
                        break
                is_assistant_section[i:end_idx] = True
        labels[~is_assistant_section] = -100

        # Expand the single <IMAGE> position to NUM_IMAGE_PATCHES ignored label
        # positions, then pad/truncate to FINAL_MAX_LENGTH.
        pre_labels = labels[:image_token_idx]
        post_labels = labels[image_token_idx+1:]
        image_labels_pad = torch.full((NUM_IMAGE_PATCHES,), -100, dtype=torch.long)
        final_labels = torch.cat([pre_labels, image_labels_pad, post_labels], dim=0)
        final_labels = torch.nn.functional.pad(
            final_labels, (0, FINAL_MAX_LENGTH - len(final_labels)), value=-100
        )[:FINAL_MAX_LENGTH]

        # Build the attention mask for the expanded sequence the same way.
        attention_mask = torch.ones_like(input_ids)
        pre_mask = attention_mask[:image_token_idx]
        post_mask = attention_mask[image_token_idx+1:]
        image_mask = torch.ones(NUM_IMAGE_PATCHES, dtype=torch.long)
        final_attention_mask = torch.cat([pre_mask, image_mask, post_mask], dim=0)
        final_attention_mask = torch.nn.functional.pad(
            final_attention_mask, (0, FINAL_MAX_LENGTH - len(final_attention_mask)), value=0
        )[:FINAL_MAX_LENGTH]

        pixel_values = image_processor(image, return_tensors="pt").pixel_values
        # input_ids keeps the single <IMAGE> placeholder; the model is expected to
        # expand it so that attention_mask and labels line up with the expanded sequence.
        return {
            "pixel_values": pixel_values.squeeze(0),
            "input_ids": input_ids,
            "attention_mask": final_attention_mask,
            "labels": final_labels
        }

    processed_dataset = dataset.map(preprocess_function, remove_columns=list(dataset.column_names))
    return processed_dataset.filter(lambda x: len(x["input_ids"]) > 0)  # drop empty records
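
# For reference, the serialized text handed to the tokenizer looks like the following
# (the conversation content here is invented purely for illustration):
#
#   USER: <IMAGE>
#   What is shown in this picture?
#   ASSISTANT: A dog playing in a park.
#   <eos>   (the exact eos string depends on the tokenizer)
#
# Only the span starting at "ASSISTANT:" contributes to the loss; the user turns and
# the expanded image positions are all masked with -100.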


def train_vlm_stage(stage, output_dir, resume_from=None):
    print(f"🚀 Starting VLM Conversational Training Stage {stage} FROM SCRATCH...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    vlm_config = VLMConfig()
    image_processor, tokenizer, model = get_processors_and_model(vlm_config)
    model.to(device)

    # Each stage trains on its own 200-example slice of the dataset.
    split = f"train[{200*(stage-1)}:{200*stage}]"
    tokenized_dataset = load_and_prepare_dataset(stage, image_processor, tokenizer, model, split=split)

    # Prefer bf16 where supported, otherwise fall back to fp16 on CUDA; TrainingArguments
    # rejects enabling both at once.
    use_bf16 = device == "cuda" and torch.cuda.is_bf16_supported()
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        learning_rate=5e-5,
        fp16=(device == "cuda" and not use_bf16),
        bf16=use_bf16,
        save_strategy="epoch",
        logging_steps=5,
        report_to="none",
        optim="adamw_torch",
        remove_unused_columns=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=DefaultDataCollator(),
    )
    trainer.train(resume_from_checkpoint=resume_from)

    model.save_pretrained(output_dir)
    image_processor.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"✅ Stage {stage} model and processors saved to {output_dir}")