# train_vlm.py
import os
import torch
from transformers import (
    TrainingArguments, Trainer, DefaultDataCollator,
    AutoTokenizer, AutoImageProcessor
)
from datasets import load_dataset
from PIL import Image
from custom_vlm import CustomScratchVLM, VLMConfig

def get_processors_and_model(config):
    """Build the image processor, tokenizer, and a freshly initialized VLM from the config."""
    vision_model_name = config.vision_config._name_or_path
    language_model_name = config.language_config._name_or_path
    image_processor = AutoImageProcessor.from_pretrained(vision_model_name)
    tokenizer = AutoTokenizer.from_pretrained(language_model_name)

    # Register the placeholder token that marks where image patches are spliced in.
    IMAGE_TOKEN = "<IMAGE>"
    tokenizer.add_special_tokens({"additional_special_tokens": [IMAGE_TOKEN]})
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Grow the vocabulary so the language head covers the newly added special token.
    config.language_config.vocab_size = len(tokenizer)
    model = CustomScratchVLM(config)
    model.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    return image_processor, tokenizer, model

def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="train[:200]"):
    # --- USING THE DATASET YOU SPECIFIED ---
    print("Loading dataset 'zera09/lmarena-ai_VisionArena-Chat-en'...")
    dataset = load_dataset("zera09/lmarena-ai_VisionArena-Chat-en", split=split)
    print("Dataset loaded successfully.")

    IMAGE_TOKEN = "<IMAGE>"
    TEXT_MAX_LENGTH = 256
    # Number of patch embeddings the model substitutes for the single <IMAGE> token.
    # Falls back to the vision config if the image processor does not expose patch_size.
    patch_size = getattr(image_processor, "patch_size", None) or model.config.vision_config.patch_size
    NUM_IMAGE_PATCHES = (image_processor.size['height'] // patch_size) ** 2
    # Expanded sequence length: the single <IMAGE> token becomes NUM_IMAGE_PATCHES embeddings.
    FINAL_MAX_LENGTH = TEXT_MAX_LENGTH - 1 + NUM_IMAGE_PATCHES

    def preprocess_function(examples):
        image = examples['image'].convert("RGB")

        # --- USING THE CONVERSATION FORMAT YOU PROVIDED ---
        # The 'conversation' column is a list of turns, each turn a list of role/content dicts.
        conversation = examples['conversation']
        full_text = ""
        is_first_user_turn = True
        for turn_list in conversation:
            if not turn_list:
                continue
            turn = turn_list[0]
            role = turn['role'].upper()
            content = turn['content']
            if role == "USER" and is_first_user_turn:
                # Prepend the image placeholder to the first user turn only.
                full_text += f"USER: {IMAGE_TOKEN}\n{content}\n"
                is_first_user_turn = False
            else:
                full_text += f"{role}: {content}\n"
        full_text += tokenizer.eos_token

        # Pad to TEXT_MAX_LENGTH so every example reaches FINAL_MAX_LENGTH after image expansion.
        tokenized = tokenizer(full_text, max_length=TEXT_MAX_LENGTH, truncation=True, padding="max_length")
        input_ids = torch.tensor(tokenized.input_ids)
        attention_mask = torch.tensor(tokenized.attention_mask)

        try:
            image_token_idx = torch.where(input_ids == model.image_token_id)[0][0].item()
        except IndexError:
            # The <IMAGE> token was truncated away; return a placeholder that is filtered out below.
            return {
                "pixel_values": torch.zeros(3, image_processor.size['height'], image_processor.size['width']),
                "input_ids": torch.zeros(TEXT_MAX_LENGTH, dtype=torch.long),
                "attention_mask": torch.zeros(FINAL_MAX_LENGTH, dtype=torch.long),
                "labels": torch.full((FINAL_MAX_LENGTH,), -100, dtype=torch.long),
                "is_valid": False,
            }

        # Supervise only the assistant turns: mask everything else (including padding) with -100.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        assistant_marker_ids = torch.tensor(tokenizer("ASSISTANT:", add_special_tokens=False).input_ids)
        user_marker_ids = torch.tensor(tokenizer("USER:", add_special_tokens=False).input_ids)
        is_assistant_section = torch.zeros_like(labels, dtype=torch.bool)
        for i in range(len(labels) - len(assistant_marker_ids) + 1):
            if (labels[i:i + len(assistant_marker_ids)] == assistant_marker_ids).all():
                # Supervise from this "ASSISTANT:" marker up to the next "USER:" marker (or sequence end).
                end_idx = len(labels)
                for j in range(i + 1, len(labels) - len(user_marker_ids) + 1):
                    if (labels[j:j + len(user_marker_ids)] == user_marker_ids).all():
                        end_idx = j
                        break
                is_assistant_section[i:end_idx] = True
        labels[~is_assistant_section] = -100

        # Expand labels and attention mask at the <IMAGE> position to cover the image patches.
        pre_labels = labels[:image_token_idx]
        post_labels = labels[image_token_idx + 1:]
        image_labels_pad = torch.full((NUM_IMAGE_PATCHES,), -100, dtype=torch.long)
        final_labels = torch.cat([pre_labels, image_labels_pad, post_labels], dim=0)
        final_labels = torch.nn.functional.pad(
            final_labels, (0, FINAL_MAX_LENGTH - len(final_labels)), value=-100
        )[:FINAL_MAX_LENGTH]

        pre_mask = attention_mask[:image_token_idx]
        post_mask = attention_mask[image_token_idx + 1:]
        image_mask = torch.ones(NUM_IMAGE_PATCHES, dtype=torch.long)
        final_attention_mask = torch.cat([pre_mask, image_mask, post_mask], dim=0)
        final_attention_mask = torch.nn.functional.pad(
            final_attention_mask, (0, FINAL_MAX_LENGTH - len(final_attention_mask)), value=0
        )[:FINAL_MAX_LENGTH]

        pixel_values = image_processor(image, return_tensors="pt").pixel_values
        return {
            "pixel_values": pixel_values.squeeze(0),
            "input_ids": input_ids,
            "attention_mask": final_attention_mask,
            "labels": final_labels,
            "is_valid": True,
        }

    processed_dataset = dataset.map(preprocess_function, remove_columns=list(dataset.column_names))
    # Drop placeholder examples whose <IMAGE> token was lost to truncation.
    processed_dataset = processed_dataset.filter(lambda x: x["is_valid"])
    return processed_dataset.remove_columns("is_valid")

def train_vlm_stage(stage, output_dir, resume_from=None):
    print(f"🚀 Starting VLM Conversational Training Stage {stage} FROM SCRATCH...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    vlm_config = VLMConfig()
    image_processor, tokenizer, model = get_processors_and_model(vlm_config)
    model.to(device)

    # Each stage consumes the next 200-example slice of the dataset.
    split = f"train[{200 * (stage - 1)}:{200 * stage}]"
    tokenized_dataset = load_and_prepare_dataset(stage, image_processor, tokenizer, model, split=split)

    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        learning_rate=5e-5,
        # Enable exactly one mixed-precision mode: bf16 when the GPU supports it, otherwise fp16.
        fp16=(device == "cuda" and not torch.cuda.is_bf16_supported()),
        bf16=(device == "cuda" and torch.cuda.is_bf16_supported()),
        save_strategy="epoch",
        logging_steps=5,
        report_to="none",
        optim="adamw_torch",
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=DefaultDataCollator(),
    )
    trainer.train(resume_from_checkpoint=resume_from)

    model.save_pretrained(output_dir)
    image_processor.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"✅ Stage {stage} model and processors saved to {output_dir}")