# custom_vlm.py
import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)

class VLMConfig(PretrainedConfig):
    model_type = "custom_scratch_vlm"

    def __init__(
        self,
        vision_model_name="google/vit-base-patch16-224-in21k",
        language_model_name="gpt2",
        projection_dim=512,
        **kwargs
    ):
        self.vision_config = AutoConfig.from_pretrained(vision_model_name)
        self.language_config = AutoConfig.from_pretrained(language_model_name)
        self.projection_dim = projection_dim
        # Allow the caller to override the language model's vocabulary size
        # (e.g. after adding an image placeholder token to the tokenizer).
        self.language_config.vocab_size = kwargs.get("vocab_size", self.language_config.vocab_size)
        super().__init__(**kwargs)

class VLMProjector(nn.Module):
    def __init__(self, config: VLMConfig):
        super().__init__()
        self.linear1 = nn.Linear(config.vision_config.hidden_size, config.projection_dim)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(config.projection_dim, config.language_config.hidden_size)

    def forward(self, x):
        return self.linear2(self.gelu(self.linear1(x)))

class CustomScratchVLM(PreTrainedModel):
    config_class = VLMConfig
    
    def __init__(self, config: VLMConfig):
        super().__init__(config)
        print("Initializing model components from scratch using their configurations...")
        self.vision_tower = AutoModel.from_config(config.vision_config)
        self.language_model = AutoModelForCausalLM.from_config(config.language_config)
        self.multi_modal_projector = VLMProjector(config)
        # Sentinel value; must be overwritten with the tokenizer id of the image
        # placeholder token (e.g. "<image>") before forward() is called.
        self.image_token_id = -1

    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.LongTensor,
        labels: torch.LongTensor = None,
        **kwargs
    ):
        image_features = self.vision_tower(pixel_values).last_hidden_state
        image_embeds = self.multi_modal_projector(image_features)
        text_embeds = self.language_model.get_input_embeddings()(input_ids)

        final_embeds = text_embeds.clone()

        # Splice the projected image embeddings into each sequence in place of
        # its image placeholder tokens. This assumes the preprocessor expanded
        # the placeholder into exactly one token per image patch, so the
        # sequence length (and the attention mask/labels) already accounts for
        # the image.
        for i in range(input_ids.shape[0]):
            image_token_positions = torch.where(input_ids[i] == self.image_token_id)[0]
            if image_token_positions.numel() == 0:
                continue  # No image placeholder in this sequence; keep the text embeddings as-is.

            start = image_token_positions[0]
            end = image_token_positions[-1] + 1

            pre_img_embed = final_embeds[i, :start]
            post_img_embed = final_embeds[i, end:]

            # The placeholder span and image_embeds[i] have the same length, so
            # the concatenation fits back into the original sequence.
            final_embeds[i] = torch.cat([pre_img_embed, image_embeds[i], post_img_embed], dim=0)

        outputs = self.language_model(
            inputs_embeds=final_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )
        return outputs

    def generate(self, pixel_values, input_ids, attention_mask, **kwargs):
        """Custom generate function to handle multimodal input for inference."""
        self.eval()
        with torch.no_grad():
            image_features = self.vision_tower(pixel_values).last_hidden_state
            image_embeds = self.multi_modal_projector(image_features)
            text_embeds = self.language_model.get_input_embeddings()(input_ids)
            
            inputs_embeds = torch.cat([image_embeds, text_embeds], dim=1)
            
            # Create a combined attention mask for generation
            image_attention_mask = torch.ones(image_embeds.shape[:2], dtype=torch.long, device=self.device)
            combined_attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
            
            output_ids = self.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=combined_attention_mask,
                **kwargs
            )
        return output_ids
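
# --- Minimal usage sketch (illustrative; not part of the original module) ---
# Shows one way to wire the model to a GPT-2 tokenizer. The "<image>" token,
# the dummy tensors, and the manual placeholder expansion below stand in for
# the external preprocessor referenced in forward(); treat them as assumptions,
# not as the canonical pipeline.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = VLMConfig()
    model = CustomScratchVLM(config)

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
    model.language_model.resize_token_embeddings(len(tokenizer))
    model.image_token_id = tokenizer.convert_tokens_to_ids("<image>")

    pixel_values = torch.randn(1, 3, 224, 224)
    num_patches = model.vision_tower(pixel_values).last_hidden_state.shape[1]

    # Training-style call: expand the placeholder to one token per image patch
    # so forward() can splice the image embeddings in without changing lengths.
    prompt = "<image>" * num_patches + " A photo of a cat."
    enc = tokenizer(prompt, return_tensors="pt")
    labels = enc["input_ids"].clone()
    labels[labels == model.image_token_id] = -100  # no LM loss on image positions
    outputs = model(
        input_ids=enc["input_ids"],
        pixel_values=pixel_values,
        attention_mask=enc["attention_mask"],
        labels=labels,
    )
    print("loss:", outputs.loss.item())

    # Inference-style call: generate() prepends the image embeddings itself,
    # so the prompt contains no placeholder tokens here.
    gen_enc = tokenizer("Describe the image:", return_tensors="pt")
    output_ids = model.generate(
        pixel_values=pixel_values,
        input_ids=gen_enc["input_ids"],
        attention_mask=gen_enc["attention_mask"],
        max_new_tokens=20,
        pad_token_id=tokenizer.eos_token_id,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))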