Keeby-smilyai committed
Commit 5ab1e4b · verified · 1 Parent(s): c91df11

Update train_vlm.py

Files changed (1)
  1. train_vlm.py +9 -17
train_vlm.py CHANGED
@@ -28,7 +28,11 @@ def get_processors_and_model(config):
     return image_processor, tokenizer, model
 
 def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="train[:50]"):
-    dataset = load_dataset("HuggingFaceM4/TextVQA", split=split)
+    # --- THIS IS THE FIX ---
+    # Using the official facebook/textvqa dataset with the required trust_remote_code flag.
+    print(f"Attempting to load dataset 'facebook/textvqa' with trust_remote_code=True...")
+    dataset = load_dataset("facebook/textvqa", split=split, trust_remote_code=True)
+    print("Dataset loaded successfully.")
 
     IMAGE_TOKEN = "<IMAGE>"
     TEXT_MAX_LENGTH = 128
@@ -48,53 +52,41 @@ def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="tr
     prompt = f"USER: {IMAGE_TOKEN}\n{question}\nASSISTANT: The final answer is: {answer}."
 
     full_text = prompt + tokenizer.eos_token
-
-    # Tokenize text part first, up to a max text length
     tokenized = tokenizer(full_text, max_length=TEXT_MAX_LENGTH, truncation=True)
     input_ids = torch.tensor(tokenized.input_ids)
 
-    # --- CRITICAL FIX: Build Labels and Attention Mask for the FINAL sequence length ---
-
-    # 1. Find the location of the image token placeholder
     try:
         image_token_idx = torch.where(input_ids == model.image_token_id)[0][0].item()
-    except IndexError: # If token was truncated out, skip this example
+    except IndexError:
         return None
 
-    # 2. Build the LABELS tensor
     labels = input_ids.clone()
-    # Mask out the prompt part (everything before and including "ASSISTANT:")
     assistant_marker = tokenizer("ASSISTANT:", add_special_tokens=False).input_ids
     for i in range(len(labels) - len(assistant_marker) + 1):
         if (labels[i:i+len(assistant_marker)] == torch.tensor(assistant_marker)).all():
             labels[:i+len(assistant_marker)] = -100
             break
 
-    # Expand labels to the final length, inserting padding for the image patches
     pre_labels = labels[:image_token_idx]
     post_labels = labels[image_token_idx+1:]
-    # The image part of the labels should be all -100 (we don't predict image patches)
     image_labels_pad = torch.full((NUM_IMAGE_PATCHES,), -100, dtype=torch.long)
 
-    # Combine and pad/truncate to FINAL_MAX_LENGTH
     final_labels = torch.cat([pre_labels, image_labels_pad, post_labels], dim=0)
-    final_labels = torch.nn.functional.pad(final_labels, (0, FINAL_MAX_LENGTH - len(final_labels)), value=-100)
+    final_labels = torch.nn.functional.pad(final_labels, (0, FINAL_MAX_LENGTH - len(final_labels)), value=-100)[:FINAL_MAX_LENGTH]
 
-    # 3. Build the ATTENTION MASK in the same way
     attention_mask = torch.ones_like(input_ids)
     pre_mask = attention_mask[:image_token_idx]
     post_mask = attention_mask[image_token_idx+1:]
     image_mask = torch.ones(NUM_IMAGE_PATCHES, dtype=torch.long)
 
     final_attention_mask = torch.cat([pre_mask, image_mask, post_mask], dim=0)
-    final_attention_mask = torch.nn.functional.pad(final_attention_mask, (0, FINAL_MAX_LENGTH - len(final_attention_mask)), value=0)
+    final_attention_mask = torch.nn.functional.pad(final_attention_mask, (0, FINAL_MAX_LENGTH - len(final_attention_mask)), value=0)[:FINAL_MAX_LENGTH]
 
-    # 4. Process the image
     pixel_values = image_processor(image, return_tensors="pt").pixel_values
 
     return {
         "pixel_values": pixel_values.squeeze(0),
-        "input_ids": input_ids, # Keep original input_ids for placeholder finding
+        "input_ids": input_ids,
         "attention_mask": final_attention_mask,
         "labels": final_labels
    }
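
Note on the sequence math in the second hunk: NUM_IMAGE_PATCHES and FINAL_MAX_LENGTH are defined elsewhere in train_vlm.py and are not visible in this diff. A minimal sketch of how such constants are typically related, assuming a ViT-style vision encoder; the concrete numbers below are illustrative assumptions, not values taken from the commit:

    # Illustrative only: assumed values for a ViT-style encoder, not from train_vlm.py.
    IMAGE_SIZE = 224        # assumed input resolution fed to the vision encoder
    PATCH_SIZE = 16         # assumed ViT patch size
    TEXT_MAX_LENGTH = 128   # matches the tokenizer truncation limit in the diff

    # A 224x224 image cut into 16x16 patches yields 14 * 14 = 196 patch embeddings.
    NUM_IMAGE_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2

    # The single <IMAGE> placeholder is replaced by NUM_IMAGE_PATCHES positions, so the
    # final budget is the text budget minus that one token plus the image patches.
    FINAL_MAX_LENGTH = TEXT_MAX_LENGTH - 1 + NUM_IMAGE_PATCHES  # 323 under these assumptions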
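
The prompt-masking loop scans the label ids for the token ids of "ASSISTANT:" and sets everything up to and including that marker to -100, so the loss is computed only on the answer tokens. A toy illustration of the same idea with made-up token ids (not real tokenizer output):

    import torch

    labels = torch.tensor([5, 9, 9, 7, 8, 21, 22, 23])  # pretend [7, 8] encodes "ASSISTANT:"
    assistant_marker = [7, 8]

    for i in range(len(labels) - len(assistant_marker) + 1):
        if (labels[i:i+len(assistant_marker)] == torch.tensor(assistant_marker)).all():
            labels[:i+len(assistant_marker)] = -100
            break

    print(labels)  # tensor([-100, -100, -100, -100, -100, 21, 22, 23])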
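
Because the preprocessing returns None whenever the <IMAGE> placeholder is truncated away, whatever batches these examples has to drop the rejected ones. The collate function below is a sketch of one way to do that with a plain PyTorch DataLoader; it is an assumption for illustration, not the collator train_vlm.py actually uses. It stacks only the fields this commit pads to a fixed FINAL_MAX_LENGTH and leaves input_ids as a list, since their lengths still vary per example:

    import torch
    from torch.utils.data import DataLoader

    def drop_skipped_collate(batch):
        # Discard examples the preprocessor rejected (returned None).
        batch = [ex for ex in batch if ex is not None]
        if not batch:
            return None  # caller should skip a batch where every example was rejected
        return {
            "pixel_values": torch.stack([ex["pixel_values"] for ex in batch]),
            # attention_mask and labels were padded/truncated to FINAL_MAX_LENGTH,
            # so they stack into rectangular tensors directly.
            "attention_mask": torch.stack([ex["attention_mask"] for ex in batch]),
            "labels": torch.stack([ex["labels"] for ex in batch]),
            # input_ids lengths still vary (the placeholder is expanded by the model),
            # so keep them as a per-example list here.
            "input_ids": [ex["input_ids"] for ex in batch],
        }

    # Usage sketch: loader = DataLoader(examples, batch_size=4, collate_fn=drop_skipped_collate)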