Keeby-smilyai committed (verified)
Commit c80f2b2 · Parent(s): 6fedc6b

Update custom_vlm.py

Files changed (1):
  custom_vlm.py  +27 -48
custom_vlm.py CHANGED
@@ -6,10 +6,6 @@ from transformers.models.auto.configuration_auto import AutoConfig
 from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
 
 class VLMConfig(PretrainedConfig):
-    """
-    Configuration class for our custom from-scratch Vision Language Model.
-    This holds the configurations for the sub-modules.
-    """
     model_type = "custom_scratch_vlm"
 
     def __init__(
@@ -22,12 +18,10 @@ class VLMConfig(PretrainedConfig):
         self.vision_config = AutoConfig.from_pretrained(vision_model_name)
         self.language_config = AutoConfig.from_pretrained(language_model_name)
         self.projection_dim = projection_dim
-        # Make language model config aware of vocab size change if tokenizer is updated
         self.language_config.vocab_size = kwargs.get("vocab_size", self.language_config.vocab_size)
         super().__init__(**kwargs)
 
 class VLMProjector(nn.Module):
-    """Simple MLP to project vision features into the language model's embedding space."""
     def __init__(self, config: VLMConfig):
         super().__init__()
         self.linear1 = nn.Linear(config.vision_config.hidden_size, config.projection_dim)
@@ -38,23 +32,15 @@ class VLMProjector(nn.Module):
         return self.linear2(self.gelu(self.linear1(x)))
 
 class CustomScratchVLM(PreTrainedModel):
-    """
-    A VLM built from randomly initialized components.
-    """
     config_class = VLMConfig
 
     def __init__(self, config: VLMConfig):
         super().__init__(config)
         print("Initializing model components from scratch using their configurations...")
-        # 1. Initialize models from their CONFIGURATIONS ONLY (random weights)
         self.vision_tower = AutoModel.from_config(config.vision_config)
         self.language_model = AutoModelForCausalLM.from_config(config.language_config)
-
-        # 2. Initialize our custom projector
         self.multi_modal_projector = VLMProjector(config)
-
-        # This will be used to find where image features should be inserted
-        self.image_token_id = -1 # Placeholder, will be set after tokenizer is prepared
+        self.image_token_id = -1
 
     def forward(
         self,
@@ -64,42 +50,29 @@ class CustomScratchVLM(PreTrainedModel):
         labels: torch.LongTensor = None,
         **kwargs
     ):
-        # Step 1: Get image embeddings from the vision tower
         image_features = self.vision_tower(pixel_values).last_hidden_state
-
-        # Step 2: Project image patch embeddings to the language model's input space
         image_embeds = self.multi_modal_projector(image_features)
-
-        # Step 3: Get text embeddings
         text_embeds = self.language_model.get_input_embeddings()(input_ids)
 
-        # Step 4: Find placeholder token indices and replace with image embeddings
-        batch_size = input_ids.shape[0]
-        # Find where the image token placeholder is in the input_ids
-        # It's assumed there is one image token per sequence
-        image_token_indices = torch.where(input_ids == self.image_token_id)
-
         final_embeds = text_embeds.clone()
 
         # Replace each placeholder with the corresponding full sequence of image embeddings
-        for i in range(batch_size):
-            # The start index for replacement in the text embeddings
-            start_idx = image_token_indices[1][i]
-            # The corresponding image embeddings for this item in the batch
-            img_embed_item = image_embeds[i]
+        for i in range(input_ids.shape[0]):
+            image_token_idx = torch.where(input_ids[i] == self.image_token_id)[0]
+            if image_token_idx.numel() == 0: continue # Skip if no image token found
 
-            # Construct the new embedding sequence
-            # 1. Part of text before the image
-            pre_img_embed = final_embeds[i, :start_idx]
-            # 2. Part of text after the image
-            post_img_embed = final_embeds[i, start_idx + 1:] # +1 to skip the placeholder
+            image_token_idx = image_token_idx[0]
 
-            # Concatenate them all
-            final_embeds[i] = torch.cat(
-                [pre_img_embed, img_embed_item, post_img_embed], dim=0
-            )
+            pre_img_embed = final_embeds[i, :image_token_idx]
+            post_img_embed = final_embeds[i, image_token_idx + 1:]
+
+            # Combine parts
+            combined = torch.cat([pre_img_embed, image_embeds[i], post_img_embed], dim=0)
+
+            # Since lengths can vary, we need to ensure it fits back.
+            # The preprocessor now handles creating correctly sized masks/labels.
+            final_embeds[i] = combined
 
-        # Step 5: Pass combined embeddings to the language model
         outputs = self.language_model(
            inputs_embeds=final_embeds,
            attention_mask=attention_mask,
@@ -108,17 +81,23 @@ class CustomScratchVLM(PreTrainedModel):
         )
         return outputs
 
-    def generate(self, pixel_values, prompt_ids, **kwargs):
-        """Custom generate function to handle multimodal input."""
-        # This is a simplified generate function. More robust implementations are complex.
+    def generate(self, pixel_values, input_ids, attention_mask, **kwargs):
+        """Custom generate function to handle multimodal input for inference."""
         self.eval()
         with torch.no_grad():
             image_features = self.vision_tower(pixel_values).last_hidden_state
             image_embeds = self.multi_modal_projector(image_features)
-            text_embeds = self.language_model.get_input_embeddings()(prompt_ids)
-
-            # Combine embeddings (simple concatenation for generation)
+            text_embeds = self.language_model.get_input_embeddings()(input_ids)
+
             inputs_embeds = torch.cat([image_embeds, text_embeds], dim=1)
 
-            output_ids = self.language_model.generate(inputs_embeds=inputs_embeds, **kwargs)
+            # Create a combined attention mask for generation
+            image_attention_mask = torch.ones(image_embeds.shape[:2], dtype=torch.long, device=self.device)
+            combined_attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
+
+            output_ids = self.language_model.generate(
+                inputs_embeds=inputs_embeds,
+                attention_mask=combined_attention_mask,
+                **kwargs
+            )
             return output_ids
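Below is a minimal usage sketch (not part of the commit) showing how the updated generate() signature might be exercised. The checkpoint names, prompt, dummy image tensor, and projection_dim value are illustrative assumptions; only the VLMConfig / CustomScratchVLM interface comes from custom_vlm.py, and generation from inputs_embeds needs a reasonably recent transformers release.

import torch
from transformers import AutoTokenizer

from custom_vlm import VLMConfig, CustomScratchVLM

# Assumed sub-model choices; any vision encoder / causal LM pair should follow the same pattern.
config = VLMConfig(
    vision_model_name="google/vit-base-patch16-224-in21k",  # assumption: illustrative checkpoint
    language_model_name="gpt2",                             # assumption: illustrative checkpoint
    projection_dim=768,                                     # assumption: the repo's default is not shown in this diff
)
model = CustomScratchVLM(config)  # components are randomly initialized from their configs

tokenizer = AutoTokenizer.from_pretrained("gpt2")
enc = tokenizer("Describe the image:", return_tensors="pt")
pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image batch

# generate() now takes the text attention mask explicitly and extends it with ones
# for the prepended image-embedding positions before calling the language model.
output_ids = model.generate(
    pixel_values=pixel_values,
    input_ids=enc.input_ids,
    attention_mask=enc.attention_mask,
    max_new_tokens=20,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Note the asymmetry with forward(): generate() simply prepends the projected image embeddings, so the prompt needs no image placeholder token, whereas forward() splices them in at image_token_id and relies on the preprocessor to keep sequence lengths, masks, and labels consistent.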