Update train_vlm.py
train_vlm.py  +58 -79

train_vlm.py  CHANGED
@@ -7,147 +7,126 @@ from transformers import (
 )
 from datasets import load_dataset
 from PIL import Image
-
-# Import our custom VLM architecture
 from custom_vlm import CustomScratchVLM, VLMConfig

-# --- Tokenizer and Processor Setup ---
 def get_processors_and_model(config):
-    """Initializes tokenizer, image processor, and the custom VLM."""
-    # Using the sub-model names from our config
     vision_model_name = config.vision_config._name_or_path
     language_model_name = config.language_config._name_or_path

-    # 1. Load standard processors for the chosen sub-models
     image_processor = AutoImageProcessor.from_pretrained(vision_model_name)
     tokenizer = AutoTokenizer.from_pretrained(language_model_name)

-    # 2. Add a special token for the image placeholder
     IMAGE_TOKEN = "<IMAGE>"
     tokenizer.add_special_tokens({"additional_special_tokens": [IMAGE_TOKEN]})
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

-    # 3. Update the VLM config with the new vocab size
     config.language_config.vocab_size = len(tokenizer)
-
-    # 4. Instantiate our from-scratch model
     model = CustomScratchVLM(config)
     model.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

     return image_processor, tokenizer, model

-
-def load_and_prepare_dataset(stage, image_processor, tokenizer, split="train[:50]"):
     dataset = load_dataset("HuggingFaceM4/TextVQA", split=split)

     IMAGE_TOKEN = "<IMAGE>"
-
-    NUM_IMAGE_PATCHES = (image_processor.size['height'] // image_processor.patch_size)
-

     def preprocess_function(examples):
-        # This function is now much more complex
         image = examples['image'].convert("RGB")
         question = examples.get('question', '')
         answer = examples['answers'][0] if examples.get('answers') else "unknown"

-        # Stage-specific formatting
         if stage == 1:
-            prompt = f"USER: {IMAGE_TOKEN}\
         elif stage == 2:
-            prompt = f"USER: {IMAGE_TOKEN}\
-        else:
-            prompt = f"USER: {IMAGE_TOKEN}\

-        # Tokenize text
         full_text = prompt + tokenizer.eos_token
-        tokenized = tokenizer(full_text, truncation=True, padding="max_length", max_length=256, return_tensors="pt")

-        #
-

-        #
-        # Find where the assistant's response starts
         try:
-
-
-
-
-
-
-
-
-
-            if
-                labels[
-
-
-

-        # Process image
         pixel_values = image_processor(image, return_tensors="pt").pixel_values

-        # The model's forward pass expects the placeholder to be replaced by N patches.
-        # But for input_ids, we only have one placeholder. The attention mask needs to
-        # be expanded to account for the N patches that will replace the single token.
-        image_token_idx = torch.where(tokenized.input_ids == model.image_token_id)[1]
-
-        # Create a new attention mask
-        new_attention_mask = torch.cat([
-            tokenized.attention_mask[:, :image_token_idx],
-            torch.ones(1, NUM_IMAGE_PATCHES, dtype=torch.long),
-            tokenized.attention_mask[:, image_token_idx+1:]
-        ], dim=1)
-
         return {
-            "pixel_values": pixel_values.squeeze(),
-            "input_ids":
-            "attention_mask":
-            "labels":
         }

-

-# --- Training ---
 def train_vlm_stage(stage, output_dir, resume_from=None):
     print(f"🚀 Starting VLM Stage {stage} Training FROM SCRATCH...")
     device = "cuda" if torch.cuda.is_available() else "cpu"

-    # 1. Get our custom model and its processors
     vlm_config = VLMConfig()
     image_processor, tokenizer, model = get_processors_and_model(vlm_config)
     model.to(device)

-
-    tokenized_dataset = load_and_prepare_dataset(stage, image_processor, tokenizer)

-    is_cuda = (device == "cuda")
     training_args = TrainingArguments(
         output_dir=output_dir,
         per_device_train_batch_size=1,
         gradient_accumulation_steps=8,
-        num_train_epochs=5,
         learning_rate=5e-5,
-        fp16=
-        bf16=
         save_strategy="epoch",
-        logging_steps=5,
-        report_to="none",
-        optim="adamw_torch",
         remove_unused_columns=False,
     )

-    trainer = Trainer(
-        model=model,
-        args=training_args,
-        train_dataset=tokenized_dataset,
-        data_collator=DefaultDataCollator()
-    )
-
-    print("--- Starting Trainer on From-Scratch Model ---")
     trainer.train(resume_from_checkpoint=resume_from)
-    print("--- Training Finished ---")

     model.save_pretrained(output_dir)
     image_processor.save_pretrained(output_dir)
 )
 from datasets import load_dataset
 from PIL import Image
 from custom_vlm import CustomScratchVLM, VLMConfig

 def get_processors_and_model(config):
     vision_model_name = config.vision_config._name_or_path
     language_model_name = config.language_config._name_or_path

     image_processor = AutoImageProcessor.from_pretrained(vision_model_name)
     tokenizer = AutoTokenizer.from_pretrained(language_model_name)

     IMAGE_TOKEN = "<IMAGE>"
     tokenizer.add_special_tokens({"additional_special_tokens": [IMAGE_TOKEN]})
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

     config.language_config.vocab_size = len(tokenizer)
     model = CustomScratchVLM(config)
     model.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)

     return image_processor, tokenizer, model
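Note: custom_vlm.py itself is not part of this diff. get_processors_and_model() only assumes that VLMConfig exposes vision_config / language_config sub-configs carrying _name_or_path and vocab_size, and that CustomScratchVLM accepts an image_token_id attribute. A minimal sketch of that assumed interface (the default checkpoints below are placeholders, not values from this Space):

# Sketch only: the shape of VLMConfig that get_processors_and_model() relies on.
from transformers import AutoConfig, PretrainedConfig

class VLMConfig(PretrainedConfig):
    model_type = "custom_scratch_vlm"

    def __init__(self,
                 vision_model="google/vit-base-patch16-224-in21k",   # assumption
                 language_model="gpt2",                              # assumption
                 **kwargs):
        super().__init__(**kwargs)
        # Sub-configs keep _name_or_path, so the training script can load the matching
        # AutoImageProcessor / AutoTokenizer, and language_config.vocab_size can be
        # bumped after the <IMAGE> token is added.
        self.vision_config = AutoConfig.from_pretrained(vision_model)
        self.language_config = AutoConfig.from_pretrained(language_model)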
|
+def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="train[:50]"):
     dataset = load_dataset("HuggingFaceM4/TextVQA", split=split)

     IMAGE_TOKEN = "<IMAGE>"
+    TEXT_MAX_LENGTH = 128
+    NUM_IMAGE_PATCHES = (image_processor.size['height'] // image_processor.patch_size) ** 2
+    FINAL_MAX_LENGTH = TEXT_MAX_LENGTH - 1 + NUM_IMAGE_PATCHES
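The length bookkeeping above: the single <IMAGE> placeholder is dropped (the -1) and replaced by one position per vision patch. With a 224x224 input and 16x16 patches (illustrative numbers, not fixed by this diff), the arithmetic works out as:

# Illustrative numbers only (224x224 input and 16x16 patches are assumptions).
TEXT_MAX_LENGTH = 128
NUM_IMAGE_PATCHES = (224 // 16) ** 2                          # 14 * 14 = 196
FINAL_MAX_LENGTH = TEXT_MAX_LENGTH - 1 + NUM_IMAGE_PATCHES    # 128 - 1 + 196 = 323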
def preprocess_function(examples):
|
|
|
|
| 39 |
image = examples['image'].convert("RGB")
|
| 40 |
question = examples.get('question', '')
|
| 41 |
answer = examples['answers'][0] if examples.get('answers') else "unknown"
|
| 42 |
|
|
|
|
| 43 |
if stage == 1:
|
| 44 |
+
prompt = f"USER: {IMAGE_TOKEN}\nQuestion: {question}\nASSISTANT: {answer}"
|
| 45 |
elif stage == 2:
|
| 46 |
+
prompt = f"USER: {IMAGE_TOKEN}\nQuestion: {question} Think step-by-step.\nASSISTANT: I think the answer is {answer}."
|
| 47 |
+
else:
|
| 48 |
+
prompt = f"USER: {IMAGE_TOKEN}\n{question}\nASSISTANT: The final answer is: {answer}."
|
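For reference, the stage-1 template above renders one sample into a single supervised string; the question and answer here are made up, not taken from TextVQA:

IMAGE_TOKEN = "<IMAGE>"
question = "what is written on the bus?"   # made-up example
answer = "downtown"                        # made-up example
print(f"USER: {IMAGE_TOKEN}\nQuestion: {question}\nASSISTANT: {answer}")
# USER: <IMAGE>
# Question: what is written on the bus?
# ASSISTANT: downtown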

         full_text = prompt + tokenizer.eos_token

+        # Tokenize text part first, up to a max text length
+        tokenized = tokenizer(full_text, max_length=TEXT_MAX_LENGTH, truncation=True)
+        input_ids = torch.tensor(tokenized.input_ids)
+
+        # --- CRITICAL FIX: Build Labels and Attention Mask for the FINAL sequence length ---

+        # 1. Find the location of the image token placeholder
         try:
+            image_token_idx = torch.where(input_ids == model.image_token_id)[0][0].item()
+        except IndexError:  # If token was truncated out, skip this example
+            return None
+
+        # 2. Build the LABELS tensor
+        labels = input_ids.clone()
+        # Mask out the prompt part (everything before and including "ASSISTANT:")
+        assistant_marker = tokenizer("ASSISTANT:", add_special_tokens=False).input_ids
+        for i in range(len(labels) - len(assistant_marker) + 1):
+            if (labels[i:i+len(assistant_marker)] == torch.tensor(assistant_marker)).all():
+                labels[:i+len(assistant_marker)] = -100
+                break
+
+        # Expand labels to the final length, inserting padding for the image patches
+        pre_labels = labels[:image_token_idx]
+        post_labels = labels[image_token_idx+1:]
+        # The image part of the labels should be all -100 (we don't predict image patches)
+        image_labels_pad = torch.full((NUM_IMAGE_PATCHES,), -100, dtype=torch.long)
+
+        # Combine and pad/truncate to FINAL_MAX_LENGTH
+        final_labels = torch.cat([pre_labels, image_labels_pad, post_labels], dim=0)
+        final_labels = torch.nn.functional.pad(final_labels, (0, FINAL_MAX_LENGTH - len(final_labels)), value=-100)
+
+        # 3. Build the ATTENTION MASK in the same way
+        attention_mask = torch.ones_like(input_ids)
+        pre_mask = attention_mask[:image_token_idx]
+        post_mask = attention_mask[image_token_idx+1:]
+        image_mask = torch.ones(NUM_IMAGE_PATCHES, dtype=torch.long)
+
+        final_attention_mask = torch.cat([pre_mask, image_mask, post_mask], dim=0)
+        final_attention_mask = torch.nn.functional.pad(final_attention_mask, (0, FINAL_MAX_LENGTH - len(final_attention_mask)), value=0)
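The -100 value is the index that torch.nn.functional.cross_entropy ignores by default (and that Transformers models conventionally use), so, assuming CustomScratchVLM computes its loss in the usual way, the masked prompt tokens and the image-patch positions contribute nothing to the loss. A tiny self-contained check:

import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)                 # 4 positions, toy vocab of 10
targets = torch.tensor([-100, -100, 3, 7])  # first two positions are masked out
masked = F.cross_entropy(logits, targets)   # ignore_index defaults to -100
manual = F.cross_entropy(logits[2:], targets[2:])
assert torch.allclose(masked, manual)       # only the unmasked positions count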

+        # 4. Process the image
         pixel_values = image_processor(image, return_tensors="pt").pixel_values

         return {
+            "pixel_values": pixel_values.squeeze(0),
+            "input_ids": input_ids,  # Keep original input_ids for placeholder finding
+            "attention_mask": final_attention_mask,
+            "labels": final_labels
         }

+    processed_dataset = dataset.map(preprocess_function, remove_columns=list(dataset.column_names))
+    return processed_dataset.filter(lambda x: x is not None)
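The returned input_ids stay at text length while attention_mask and labels are already built at FINAL_MAX_LENGTH; that only lines up if CustomScratchVLM.forward splices the projected patch embeddings in place of the single <IMAGE> token. Since custom_vlm.py is not in this diff, here is only a toy illustration of that assumed splice:

import torch

# Toy shapes; the real splice lives in custom_vlm.py, which this diff does not touch.
D, T, NUM_IMAGE_PATCHES, image_token_id = 8, 6, 4, 99
input_ids = torch.tensor([[1, 2, 99, 3, 4, 5]])        # one <IMAGE> placeholder (id 99)
text_embeds = torch.randn(1, T, D)                     # token embeddings
image_embeds = torch.randn(1, NUM_IMAGE_PATCHES, D)    # projected vision patches

idx = (input_ids[0] == image_token_id).nonzero()[0].item()
inputs_embeds = torch.cat(
    [text_embeds[:, :idx], image_embeds, text_embeds[:, idx + 1:]], dim=1
)
assert inputs_embeds.shape[1] == T - 1 + NUM_IMAGE_PATCHES  # matches the FINAL_MAX_LENGTH bookkeeping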
 def train_vlm_stage(stage, output_dir, resume_from=None):
     print(f"🚀 Starting VLM Stage {stage} Training FROM SCRATCH...")
     device = "cuda" if torch.cuda.is_available() else "cpu"

     vlm_config = VLMConfig()
     image_processor, tokenizer, model = get_processors_and_model(vlm_config)
     model.to(device)

+    tokenized_dataset = load_and_prepare_dataset(stage, image_processor, tokenizer, model)

     training_args = TrainingArguments(
         output_dir=output_dir,
         per_device_train_batch_size=1,
         gradient_accumulation_steps=8,
+        num_train_epochs=5,
         learning_rate=5e-5,
+        fp16=(device == "cuda" and not torch.cuda.is_bf16_supported()),  # at most one of fp16/bf16 may be True
+        bf16=(device == "cuda" and torch.cuda.is_bf16_supported()),
         save_strategy="epoch",
+        logging_steps=5, report_to="none", optim="adamw_torch",
         remove_unused_columns=False,
     )

+    trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_dataset, data_collator=DefaultDataCollator())
     trainer.train(resume_from_checkpoint=resume_from)

     model.save_pretrained(output_dir)
     image_processor.save_pretrained(output_dir)
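The diff does not show how train_vlm_stage() is invoked; a hypothetical entry point (the stage number and output directory are placeholders) would look like:

if __name__ == "__main__":
    # Placeholder arguments; the Space's actual launcher is not shown in this diff.
    train_vlm_stage(stage=1, output_dir="./vlm_stage1_scratch")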