Keeby-smilyai committed on
Commit 589d16b · verified · 1 Parent(s): 5ab1e4b

Update train_vlm.py

Files changed (1)
  1. train_vlm.py +42 -24
train_vlm.py CHANGED
@@ -27,31 +27,39 @@ def get_processors_and_model(config):
 
     return image_processor, tokenizer, model
 
-def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="train[:50]"):
-    # --- THIS IS THE FIX ---
-    # Using the official facebook/textvqa dataset with the required trust_remote_code flag.
-    print(f"Attempting to load dataset 'facebook/textvqa' with trust_remote_code=True...")
-    dataset = load_dataset("facebook/textvqa", split=split, trust_remote_code=True)
+def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="train[:200]"):
+    # --- USING THE DATASET YOU SPECIFIED ---
+    print("Loading dataset 'zera09/lmarena-ai_VisionArena-Chat-en'...")
+    dataset = load_dataset("zera09/lmarena-ai_VisionArena-Chat-en", split=split)
     print("Dataset loaded successfully.")
 
     IMAGE_TOKEN = "<IMAGE>"
-    TEXT_MAX_LENGTH = 128
+    TEXT_MAX_LENGTH = 256
     NUM_IMAGE_PATCHES = (image_processor.size['height'] // image_processor.patch_size) ** 2
     FINAL_MAX_LENGTH = TEXT_MAX_LENGTH - 1 + NUM_IMAGE_PATCHES
 
     def preprocess_function(examples):
         image = examples['image'].convert("RGB")
-        question = examples.get('question', '')
-        answer = examples['answers'][0] if examples.get('answers') else "unknown"
+        # --- USING THE CONVERSATION FORMAT YOU PROVIDED ---
+        # We select 'conversation_a' and parse it as a list of lists of dicts.
+        conversation = examples['conversation_a']
 
-        if stage == 1:
-            prompt = f"USER: {IMAGE_TOKEN}\nQuestion: {question}\nASSISTANT: {answer}"
-        elif stage == 2:
-            prompt = f"USER: {IMAGE_TOKEN}\nQuestion: {question} Think step-by-step.\nASSISTANT: I think the answer is {answer}."
-        else:
-            prompt = f"USER: {IMAGE_TOKEN}\n{question}\nASSISTANT: The final answer is: {answer}."
-
-        full_text = prompt + tokenizer.eos_token
+        full_text = ""
+        is_first_user_turn = True
+        for turn_list in conversation:
+            if not turn_list: continue
+            turn = turn_list[0]
+
+            role = turn['role'].upper()
+            content = turn['content']
+
+            if role == "USER" and is_first_user_turn:
+                full_text += f"USER: {IMAGE_TOKEN}\n{content}\n"
+                is_first_user_turn = False
+            else:
+                full_text += f"{role}: {content}\n"
+
+        full_text += tokenizer.eos_token
         tokenized = tokenizer(full_text, max_length=TEXT_MAX_LENGTH, truncation=True)
         input_ids = torch.tensor(tokenized.input_ids)
 
@@ -61,12 +69,21 @@ def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="tr
             return None
 
         labels = input_ids.clone()
-        assistant_marker = tokenizer("ASSISTANT:", add_special_tokens=False).input_ids
-        for i in range(len(labels) - len(assistant_marker) + 1):
-            if (labels[i:i+len(assistant_marker)] == torch.tensor(assistant_marker)).all():
-                labels[:i+len(assistant_marker)] = -100
-                break
+        assistant_marker_ids = tokenizer("ASSISTANT:", add_special_tokens=False).input_ids
+        is_assistant_section = torch.zeros_like(labels, dtype=torch.bool)
 
+        for i in range(len(labels) - len(assistant_marker_ids) + 1):
+            if (labels[i:i+len(assistant_marker_ids)] == torch.tensor(assistant_marker_ids)).all():
+                end_idx = len(labels)
+                user_marker_ids = tokenizer("USER:", add_special_tokens=False).input_ids
+                for j in range(i + 1, len(labels) - len(user_marker_ids) + 1):
+                    if (labels[j:j+len(user_marker_ids)] == torch.tensor(user_marker_ids)).all():
+                        end_idx = j
+                        break
+                is_assistant_section[i:end_idx] = True
+
+        labels[~is_assistant_section] = -100
+
         pre_labels = labels[:image_token_idx]
         post_labels = labels[image_token_idx+1:]
         image_labels_pad = torch.full((NUM_IMAGE_PATCHES,), -100, dtype=torch.long)
@@ -95,20 +112,21 @@ def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="tr
     return processed_dataset.filter(lambda x: x is not None)
 
 def train_vlm_stage(stage, output_dir, resume_from=None):
-    print(f"🚀 Starting VLM Stage {stage} Training FROM SCRATCH...")
+    print(f"🚀 Starting VLM Conversational Training Stage {stage} FROM SCRATCH...")
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
     vlm_config = VLMConfig()
     image_processor, tokenizer, model = get_processors_and_model(vlm_config)
     model.to(device)
 
-    tokenized_dataset = load_and_prepare_dataset(stage, image_processor, tokenizer, model)
+    split = f"train[{200*(stage-1)}:{200*stage}]"
+    tokenized_dataset = load_and_prepare_dataset(stage, image_processor, tokenizer, model, split=split)
 
     training_args = TrainingArguments(
         output_dir=output_dir,
         per_device_train_batch_size=1,
         gradient_accumulation_steps=8,
-        num_train_epochs=5,
+        num_train_epochs=3,
         learning_rate=5e-5,
         fp16=(device == "cuda"),
         bf16=(device == "cuda" and torch.cuda.is_bf16_supported()),
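
For reference, here is a minimal, self-contained sketch of what the new preprocessing does to a single record. The hand-written conversation_a record, the whitespace "tokenizer", and the printed output below are illustrative assumptions only; real training uses the model's own tokenizer, appends tokenizer.eos_token, and reads actual dataset rows. The flattening loop and the ASSISTANT-span label masking follow the logic added in this commit.

import torch

IMAGE_TOKEN = "<IMAGE>"

# Toy record shaped like the 'conversation_a' field the diff parses:
# a list of single-element lists, each holding a {'role', 'content'} dict.
conversation_a = [
    [{"role": "user", "content": "What is shown in the picture?"}],
    [{"role": "assistant", "content": "A red bicycle leaning against a wall."}],
    [{"role": "user", "content": "What colour is it?"}],
    [{"role": "assistant", "content": "It is red."}],
]

# Flatten the turns the same way preprocess_function does, prefixing the
# first user turn with the image placeholder token.
full_text = ""
is_first_user_turn = True
for turn_list in conversation_a:
    if not turn_list:
        continue
    turn = turn_list[0]
    role = turn["role"].upper()
    content = turn["content"]
    if role == "USER" and is_first_user_turn:
        full_text += f"USER: {IMAGE_TOKEN}\n{content}\n"
        is_first_user_turn = False
    else:
        full_text += f"{role}: {content}\n"

# Stand-in tokenizer: one integer id per whitespace token, shared across calls,
# so the "ASSISTANT:" and "USER:" marker ids match the way a real tokenizer
# would encode them consistently.
vocab = {}
def encode(text):
    return [vocab.setdefault(tok, len(vocab)) for tok in text.split()]

input_ids = torch.tensor(encode(full_text))
labels = input_ids.clone()

assistant_marker_ids = encode("ASSISTANT:")
user_marker_ids = encode("USER:")
is_assistant_section = torch.zeros_like(labels, dtype=torch.bool)

# Mark every span that starts at an "ASSISTANT:" marker and runs until the
# next "USER:" marker (or the end of the sequence); everything outside those
# spans is set to -100 so the loss is computed only on assistant tokens.
for i in range(len(labels) - len(assistant_marker_ids) + 1):
    if (labels[i:i + len(assistant_marker_ids)] == torch.tensor(assistant_marker_ids)).all():
        end_idx = len(labels)
        for j in range(i + 1, len(labels) - len(user_marker_ids) + 1):
            if (labels[j:j + len(user_marker_ids)] == torch.tensor(user_marker_ids)).all():
                end_idx = j
                break
        is_assistant_section[i:end_idx] = True

labels[~is_assistant_section] = -100

for tok, lab in zip(full_text.split(), labels.tolist()):
    print(f"{tok:<12} label={lab}")

Note also that the new per-stage slicing, split = f"train[{200*(stage-1)}:{200*stage}]", resolves to train[0:200] for stage 1, train[200:400] for stage 2, and train[400:600] for stage 3, so each stage trains on a fresh 200-example slice instead of reusing the default train[:200] window.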