Keeby-smilyai committed
Commit be4d66f · verified · 1 Parent(s): c80f2b2

Update train_vlm.py

Files changed (1)
  1. train_vlm.py +58 -79
train_vlm.py CHANGED
@@ -7,147 +7,126 @@ from transformers import (
 )
 from datasets import load_dataset
 from PIL import Image
-
-# Import our custom VLM architecture
 from custom_vlm import CustomScratchVLM, VLMConfig
 
-# --- Tokenizer and Processor Setup ---
 def get_processors_and_model(config):
-    """Initializes tokenizer, image processor, and the custom VLM."""
-    # Using the sub-model names from our config
     vision_model_name = config.vision_config._name_or_path
     language_model_name = config.language_config._name_or_path
 
-    # 1. Load standard processors for the chosen sub-models
     image_processor = AutoImageProcessor.from_pretrained(vision_model_name)
     tokenizer = AutoTokenizer.from_pretrained(language_model_name)
 
-    # 2. Add a special token for the image placeholder
     IMAGE_TOKEN = "<IMAGE>"
     tokenizer.add_special_tokens({"additional_special_tokens": [IMAGE_TOKEN]})
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    # 3. Update the VLM config with the new vocab size
     config.language_config.vocab_size = len(tokenizer)
-
-    # 4. Instantiate our from-scratch model
     model = CustomScratchVLM(config)
     model.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
 
     return image_processor, tokenizer, model
 
-# --- Dataset and Preprocessing ---
-def load_and_prepare_dataset(stage, image_processor, tokenizer, split="train[:50]"):
+def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="train[:50]"):
     dataset = load_dataset("HuggingFaceM4/TextVQA", split=split)
 
     IMAGE_TOKEN = "<IMAGE>"
-    # This is the number of patch embeddings from the ViT
-    NUM_IMAGE_PATCHES = (image_processor.size['height'] // image_processor.patch_size) * \
-                        (image_processor.size['width'] // image_processor.patch_size)
+    TEXT_MAX_LENGTH = 128
+    NUM_IMAGE_PATCHES = (image_processor.size['height'] // image_processor.patch_size) ** 2
+    FINAL_MAX_LENGTH = TEXT_MAX_LENGTH - 1 + NUM_IMAGE_PATCHES
 
     def preprocess_function(examples):
-        # This function is now much more complex
         image = examples['image'].convert("RGB")
        question = examples.get('question', '')
         answer = examples['answers'][0] if examples.get('answers') else "unknown"
 
-        # Stage-specific formatting
         if stage == 1:
-            prompt = f"USER: {IMAGE_TOKEN}\nQ: {question}\nA: Let's think step by step.\nASSISTANT: I see {answer} in the image. Therefore, the answer is {answer}."
+            prompt = f"USER: {IMAGE_TOKEN}\nQuestion: {question}\nASSISTANT: {answer}"
         elif stage == 2:
-            prompt = f"USER: {IMAGE_TOKEN}\nQ: {question}\nA: [INTERNAL THOUGHT HIDDEN]... Final Answer:\nASSISTANT: {answer}"
-        else: # Stage 3
-            prompt = f"USER: {IMAGE_TOKEN}\nQ: {question}\nA: Think deeply.\nASSISTANT: I revise: '{answer}' is correct. Confidence: 89%."
+            prompt = f"USER: {IMAGE_TOKEN}\nQuestion: {question} Think step-by-step.\nASSISTANT: I think the answer is {answer}."
+        else:
+            prompt = f"USER: {IMAGE_TOKEN}\n{question}\nASSISTANT: The final answer is: {answer}."
 
-        # Tokenize text
         full_text = prompt + tokenizer.eos_token
-        tokenized = tokenizer(full_text, truncation=True, padding="max_length", max_length=256, return_tensors="pt")
 
-        # Prepare labels for causal language modeling
-        labels = tokenized.input_ids.clone()
+        # Tokenize text part first, up to a max text length
+        tokenized = tokenizer(full_text, max_length=TEXT_MAX_LENGTH, truncation=True)
+        input_ids = torch.tensor(tokenized.input_ids)
+
+        # --- CRITICAL FIX: Build Labels and Attention Mask for the FINAL sequence length ---
 
-        # Mask out the user's prompt part in the labels
-        # Find where the assistant's response starts
+        # 1. Find the location of the image token placeholder
         try:
-            assistant_start_marker = "ASSISTANT:"
-            # Find the token IDs for the marker
-            marker_ids = tokenizer(assistant_start_marker, add_special_tokens=False).input_ids
-            # Search for this sequence of IDs in the labels
-            assistant_start_idx = -1
-            for i in range(len(labels[0]) - len(marker_ids) + 1):
-                if (labels[0, i:i+len(marker_ids)] == torch.tensor(marker_ids)).all():
-                    assistant_start_idx = i
-                    break
-            if assistant_start_idx != -1:
-                labels[0, :assistant_start_idx + len(marker_ids)] = -100 # Mask everything before and including the marker
-        except Exception:
-            # If something fails, just mask the first token
-            labels[0, 0] = -100
+            image_token_idx = torch.where(input_ids == model.image_token_id)[0][0].item()
+        except IndexError: # If token was truncated out, skip this example
+            return None
+
+        # 2. Build the LABELS tensor
+        labels = input_ids.clone()
+        # Mask out the prompt part (everything before and including "ASSISTANT:")
+        assistant_marker = tokenizer("ASSISTANT:", add_special_tokens=False).input_ids
+        for i in range(len(labels) - len(assistant_marker) + 1):
+            if (labels[i:i+len(assistant_marker)] == torch.tensor(assistant_marker)).all():
+                labels[:i+len(assistant_marker)] = -100
+                break
+
+        # Expand labels to the final length, inserting padding for the image patches
+        pre_labels = labels[:image_token_idx]
+        post_labels = labels[image_token_idx+1:]
+        # The image part of the labels should be all -100 (we don't predict image patches)
+        image_labels_pad = torch.full((NUM_IMAGE_PATCHES,), -100, dtype=torch.long)
+
+        # Combine and pad/truncate to FINAL_MAX_LENGTH
+        final_labels = torch.cat([pre_labels, image_labels_pad, post_labels], dim=0)
+        final_labels = torch.nn.functional.pad(final_labels, (0, FINAL_MAX_LENGTH - len(final_labels)), value=-100)
+
+        # 3. Build the ATTENTION MASK in the same way
+        attention_mask = torch.ones_like(input_ids)
+        pre_mask = attention_mask[:image_token_idx]
+        post_mask = attention_mask[image_token_idx+1:]
+        image_mask = torch.ones(NUM_IMAGE_PATCHES, dtype=torch.long)
+
+        final_attention_mask = torch.cat([pre_mask, image_mask, post_mask], dim=0)
+        final_attention_mask = torch.nn.functional.pad(final_attention_mask, (0, FINAL_MAX_LENGTH - len(final_attention_mask)), value=0)
 
-        # Process image
+        # 4. Process the image
         pixel_values = image_processor(image, return_tensors="pt").pixel_values
 
-        # The model's forward pass expects the placeholder to be replaced by N patches.
-        # But for input_ids, we only have one placeholder. The attention mask needs to
-        # be expanded to account for the N patches that will replace the single token.
-        image_token_idx = torch.where(tokenized.input_ids == model.image_token_id)[1]
-
-        # Create a new attention mask
-        new_attention_mask = torch.cat([
-            tokenized.attention_mask[:, :image_token_idx],
-            torch.ones(1, NUM_IMAGE_PATCHES, dtype=torch.long),
-            tokenized.attention_mask[:, image_token_idx+1:]
-        ], dim=1)
-
         return {
-            "pixel_values": pixel_values.squeeze(),
-            "input_ids": tokenized.input_ids.squeeze(),
-            "attention_mask": new_attention_mask.squeeze(),
-            "labels": labels.squeeze()
+            "pixel_values": pixel_values.squeeze(0),
+            "input_ids": input_ids, # Keep original input_ids for placeholder finding
+            "attention_mask": final_attention_mask,
+            "labels": final_labels
         }
 
-    return dataset.map(preprocess_function, remove_columns=list(dataset.column_names))
+    processed_dataset = dataset.map(preprocess_function, remove_columns=list(dataset.column_names))
+    return processed_dataset.filter(lambda x: x is not None)
 
-# --- Training ---
 def train_vlm_stage(stage, output_dir, resume_from=None):
     print(f"🚀 Starting VLM Stage {stage} Training FROM SCRATCH...")
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    # 1. Get our custom model and its processors
     vlm_config = VLMConfig()
     image_processor, tokenizer, model = get_processors_and_model(vlm_config)
     model.to(device)
 
-    # 2. Load and prepare dataset using the new pipeline
-    tokenized_dataset = load_and_prepare_dataset(stage, image_processor, tokenizer)
+    tokenized_dataset = load_and_prepare_dataset(stage, image_processor, tokenizer, model)
 
-    is_cuda = (device == "cuda")
     training_args = TrainingArguments(
         output_dir=output_dir,
         per_device_train_batch_size=1,
         gradient_accumulation_steps=8,
-        num_train_epochs=5, # More epochs needed for from-scratch
+        num_train_epochs=5,
         learning_rate=5e-5,
-        fp16=is_cuda,
-        bf16=is_cuda and torch.cuda.is_bf16_supported(),
+        fp16=(device == "cuda"),
+        bf16=(device == "cuda" and torch.cuda.is_bf16_supported()),
         save_strategy="epoch",
-        logging_steps=5,
-        report_to="none",
-        optim="adamw_torch",
+        logging_steps=5, report_to="none", optim="adamw_torch",
         remove_unused_columns=False,
     )
 
-    trainer = Trainer(
-        model=model,
-        args=training_args,
-        train_dataset=tokenized_dataset,
-        data_collator=DefaultDataCollator()
-    )
-
-    print("--- Starting Trainer on From-Scratch Model ---")
+    trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_dataset, data_collator=DefaultDataCollator())
     trainer.train(resume_from_checkpoint=resume_from)
-    print("--- Training Finished ---")
 
     model.save_pretrained(output_dir)
     image_processor.save_pretrained(output_dir)
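
The new preprocessing hinges on one piece of length bookkeeping: input_ids keeps a single <IMAGE> placeholder, but that position is later expanded to NUM_IMAGE_PATCHES patch embeddings inside the model, so labels and attention_mask are built up front at FINAL_MAX_LENGTH = TEXT_MAX_LENGTH - 1 + NUM_IMAGE_PATCHES. The sketch below walks through that expansion on a toy sequence; it is not part of the commit, and the token ids and lengths are made-up values for illustration only.

import torch

TEXT_MAX_LENGTH = 8        # toy value; the script uses 128
NUM_IMAGE_PATCHES = 4      # toy value; the script derives this from the image processor
FINAL_MAX_LENGTH = TEXT_MAX_LENGTH - 1 + NUM_IMAGE_PATCHES

IMAGE_TOKEN_ID = 99        # hypothetical id for the "<IMAGE>" placeholder
# Toy tokenized prompt: [USER, <IMAGE>, question tokens..., "ASSISTANT:", answer, EOS]
input_ids = torch.tensor([5, IMAGE_TOKEN_ID, 11, 12, 13, 20, 30, 2])

# Mask everything up to and including the "ASSISTANT:" marker (index 5 in this toy
# sequence), mirroring the marker search in preprocess_function.
labels = input_ids.clone()
labels[:6] = -100

idx = torch.where(input_ids == IMAGE_TOKEN_ID)[0][0].item()

# Replace the single placeholder position with NUM_IMAGE_PATCHES ignored label positions.
final_labels = torch.cat([
    labels[:idx],
    torch.full((NUM_IMAGE_PATCHES,), -100, dtype=torch.long),
    labels[idx + 1:],
])

# Expand the attention mask the same way; the patch positions are all attended to.
attention_mask = torch.ones_like(input_ids)
final_mask = torch.cat([
    attention_mask[:idx],
    torch.ones(NUM_IMAGE_PATCHES, dtype=torch.long),
    attention_mask[idx + 1:],
])

# One text position was traded for NUM_IMAGE_PATCHES positions, hence the "- 1 +" above.
assert final_labels.numel() == final_mask.numel() == len(input_ids) - 1 + NUM_IMAGE_PATCHES
# The real script then right-pads both tensors to FINAL_MAX_LENGTH (labels with -100, mask with 0).

Note that input_ids itself stays at text length with the single placeholder; the model's forward pass is expected to perform the same expansion on the embedding sequence (as the comments removed in this diff described), while labels and attention_mask already match the expanded length, so every key the default collator batches has a fixed size per example.

The commit also does not show how the three stages are launched. Assuming a separate driver exists, it might look like the following; the output paths and back-to-back stage order are hypothetical.

# Hypothetical driver, not part of this commit.
if __name__ == "__main__":
    train_vlm_stage(stage=1, output_dir="./vlm_stage1")  # plain question/answer prompt
    train_vlm_stage(stage=2, output_dir="./vlm_stage2")  # "Think step-by-step" prompt
    train_vlm_stage(stage=3, output_dir="./vlm_stage3")  # "The final answer is" prompt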