# train_vlm.py
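# Trains CustomScratchVLM (the from-scratch vision-language model defined in custom_vlm.py)
# on conversational image/text data, processed and trained in stages of 200 examples each.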
import os
import torch
from transformers import (
    TrainingArguments, Trainer, DefaultDataCollator,
    AutoTokenizer, AutoImageProcessor
)
from datasets import load_dataset
from PIL import Image
from custom_vlm import CustomScratchVLM, VLMConfig

def get_processors_and_model(config):
    vision_model_name = config.vision_config._name_or_path
    language_model_name = config.language_config._name_or_path
    
    image_processor = AutoImageProcessor.from_pretrained(vision_model_name)
    tokenizer = AutoTokenizer.from_pretrained(language_model_name)

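    # Register a dedicated <IMAGE> placeholder token; the language model's vocab size is
    # taken from the tokenizer afterwards, so the new token gets an embedding row.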
    IMAGE_TOKEN = "<IMAGE>"
    tokenizer.add_special_tokens({"additional_special_tokens": [IMAGE_TOKEN]})
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        
    config.language_config.vocab_size = len(tokenizer)
    model = CustomScratchVLM(config)
    model.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    
    return image_processor, tokenizer, model

def load_and_prepare_dataset(stage, image_processor, tokenizer, model, split="train[:200]"):
    # --- Load the conversational image/text dataset ---
    print("Loading dataset 'zera09/lmarena-ai_VisionArena-Chat-en'...")
    dataset = load_dataset("zera09/lmarena-ai_VisionArena-Chat-en", split=split)
    print("Dataset loaded successfully.")
    
    IMAGE_TOKEN = "<IMAGE>"
    TEXT_MAX_LENGTH = 256
    # Fall back to the vision config when the image processor does not expose `patch_size`
    # (assumes the vision tower emits (image_size / patch_size)^2 patch embeddings).
    patch_size = getattr(image_processor, "patch_size", None) or model.config.vision_config.patch_size
    NUM_IMAGE_PATCHES = (image_processor.size['height'] // patch_size) ** 2
    FINAL_MAX_LENGTH = TEXT_MAX_LENGTH - 1 + NUM_IMAGE_PATCHES  # one <IMAGE> token expands to NUM_IMAGE_PATCHES positions

    def preprocess_function(examples):
        image = examples['image'].convert("RGB")
        # --- Parse the conversation format ---
        # 'conversation' is a list of turns; each turn is itself a list holding a single
        # {'role': ..., 'content': ...} dict, so we read the first element of each.
        conversation = examples['conversation']
        
        full_text = ""
        is_first_user_turn = True
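        # Flatten the turns into "ROLE: content" lines, prepending the <IMAGE> placeholder
        # to the first user turn so the image precedes the first question.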
        for turn_list in conversation:
            if not turn_list: continue
            turn = turn_list[0]
            
            role = turn['role'].upper()
            content = turn['content']
            
            if role == "USER" and is_first_user_turn:
                full_text += f"USER: {IMAGE_TOKEN}\n{content}\n"
                is_first_user_turn = False
            else:
                full_text += f"{role}: {content}\n"
        
        full_text += tokenizer.eos_token
        tokenized = tokenizer(full_text, max_length=TEXT_MAX_LENGTH, truncation=True)
        input_ids = torch.tensor(tokenized.input_ids)

        # `datasets.map` cannot handle a None return, so flag examples without an
        # <IMAGE> token using empty fields and drop them in the filter step below.
        image_positions = torch.where(input_ids == model.image_token_id)[0]
        if len(image_positions) == 0:
            return {"pixel_values": [], "input_ids": [], "attention_mask": [], "labels": []}
        image_token_idx = image_positions[0].item()

        labels = input_ids.clone()
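        # Loss masking: keep labels only inside assistant spans, i.e. from each
        # "ASSISTANT:" marker up to the next "USER:" marker (or the end of the sequence);
        # every other position is set to -100 and ignored by the loss.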
        assistant_marker_ids = tokenizer("ASSISTANT:", add_special_tokens=False).input_ids
        is_assistant_section = torch.zeros_like(labels, dtype=torch.bool)
        
        for i in range(len(labels) - len(assistant_marker_ids) + 1):
            if (labels[i:i+len(assistant_marker_ids)] == torch.tensor(assistant_marker_ids)).all():
                end_idx = len(labels)
                user_marker_ids = tokenizer("USER:", add_special_tokens=False).input_ids
                for j in range(i + 1, len(labels) - len(user_marker_ids) + 1):
                    if (labels[j:j+len(user_marker_ids)] == torch.tensor(user_marker_ids)).all():
                        end_idx = j
                        break
                is_assistant_section[i:end_idx] = True
        
        labels[~is_assistant_section] = -100

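        # Splice placeholder entries for the image: each of the NUM_IMAGE_PATCHES slots that
        # will replace the single <IMAGE> token gets label -100 and attention-mask 1.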
        pre_labels = labels[:image_token_idx]
        post_labels = labels[image_token_idx+1:]
        image_labels_pad = torch.full((NUM_IMAGE_PATCHES,), -100, dtype=torch.long)
        
        final_labels = torch.cat([pre_labels, image_labels_pad, post_labels], dim=0)
        final_labels = torch.nn.functional.pad(final_labels, (0, FINAL_MAX_LENGTH - len(final_labels)), value=-100)[:FINAL_MAX_LENGTH]

        attention_mask = torch.ones_like(input_ids)
        pre_mask = attention_mask[:image_token_idx]
        post_mask = attention_mask[image_token_idx+1:]
        image_mask = torch.ones(NUM_IMAGE_PATCHES, dtype=torch.long)
        
        final_attention_mask = torch.cat([pre_mask, image_mask, post_mask], dim=0)
        final_attention_mask = torch.nn.functional.pad(final_attention_mask, (0, FINAL_MAX_LENGTH - len(final_attention_mask)), value=0)[:FINAL_MAX_LENGTH]

        pixel_values = image_processor(image, return_tensors="pt").pixel_values

        # Right-pad the (un-expanded) text ids to TEXT_MAX_LENGTH so that, once the model
        # swaps the single <IMAGE> token for NUM_IMAGE_PATCHES patch positions, the sequence
        # length lines up with final_attention_mask / final_labels (both FINAL_MAX_LENGTH).
        input_ids = torch.nn.functional.pad(input_ids, (0, TEXT_MAX_LENGTH - len(input_ids)), value=tokenizer.pad_token_id)

        return {
            "pixel_values": pixel_values.squeeze(0),
            "input_ids": input_ids,
            "attention_mask": final_attention_mask,
            "labels": final_labels
        }

    processed_dataset = dataset.map(preprocess_function, remove_columns=list(dataset.column_names))
    # Drop the examples flagged as invalid above (map never yields None rows).
    return processed_dataset.filter(lambda x: len(x["input_ids"]) > 0)

def train_vlm_stage(stage, output_dir, resume_from=None):
    print(f"🚀 Starting VLM Conversational Training Stage {stage} FROM SCRATCH...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    vlm_config = VLMConfig()
    image_processor, tokenizer, model = get_processors_and_model(vlm_config)
    model.to(device)

    split = f"train[{200*(stage-1)}:{200*stage}]"
    tokenized_dataset = load_and_prepare_dataset(stage, image_processor, tokenizer, model, split=split)
    
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        num_train_epochs=3,
        learning_rate=5e-5,
        # fp16 and bf16 are mutually exclusive in TrainingArguments; prefer bf16 where supported.
        bf16=(device == "cuda" and torch.cuda.is_bf16_supported()),
        fp16=(device == "cuda" and not torch.cuda.is_bf16_supported()),
        save_strategy="epoch",
        logging_steps=5, report_to="none", optim="adamw_torch",
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=DefaultDataCollator(),
    )
    trainer.train(resume_from_checkpoint=resume_from)

    model.save_pretrained(output_dir)
    image_processor.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"✅ Stage {stage} model and processors saved to {output_dir}")