# custom_vlm.py
import torch
import torch.nn as nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    PretrainedConfig,
    PreTrainedModel,
)

class VLMConfig(PretrainedConfig):
    model_type = "custom_scratch_vlm"

    def __init__(
        self,
        vision_model_name="google/vit-base-patch16-224-in21k",
        language_model_name="gpt2",
        projection_dim=512,
        **kwargs
    ):
        # Fetch the sub-model configurations from the Hub; only the architectures
        # are reused -- all weights are initialized from scratch.
        self.vision_config = AutoConfig.from_pretrained(vision_model_name)
        self.language_config = AutoConfig.from_pretrained(language_model_name)
        self.projection_dim = projection_dim
        # Let a tokenizer-supplied vocab_size override the language model default.
        self.language_config.vocab_size = kwargs.get("vocab_size", self.language_config.vocab_size)
        super().__init__(**kwargs)

class VLMProjector(nn.Module):
    """Two-layer MLP that maps vision-tower hidden states into the language model's embedding space."""

    def __init__(self, config: VLMConfig):
        super().__init__()
        self.linear1 = nn.Linear(config.vision_config.hidden_size, config.projection_dim)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(config.projection_dim, config.language_config.hidden_size)

    def forward(self, x):
        return self.linear2(self.gelu(self.linear1(x)))

class CustomScratchVLM(PreTrainedModel):
    config_class = VLMConfig

    def __init__(self, config: VLMConfig):
        super().__init__(config)
        print("Initializing model components from scratch using their configurations...")
        self.vision_tower = AutoModel.from_config(config.vision_config)
        self.language_model = AutoModelForCausalLM.from_config(config.language_config)
        self.multi_modal_projector = VLMProjector(config)
        # Placeholder value; set this to the tokenizer's image-token id before
        # calling forward(), otherwise no placeholder tokens will be found.
        self.image_token_id = -1

    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.LongTensor,
        labels: torch.LongTensor = None,
        **kwargs
    ):
        # Encode the image and project the patch features into the LM embedding space.
        image_features = self.vision_tower(pixel_values).last_hidden_state
        image_embeds = self.multi_modal_projector(image_features)
        text_embeds = self.language_model.get_input_embeddings()(input_ids)
        final_embeds = text_embeds.clone()

        # Replace each placeholder with the corresponding full sequence of image embeddings.
        for i in range(input_ids.shape[0]):
            image_token_idx = torch.where(input_ids[i] == self.image_token_id)[0]
            if image_token_idx.numel() == 0:
                continue  # Skip rows that contain no image token.
            image_token_idx = image_token_idx[0]
            pre_img_embed = final_embeds[i, :image_token_idx]
            post_img_embed = final_embeds[i, image_token_idx + 1:]
            # Splice the image embeddings in place of the placeholder token.
            combined = torch.cat([pre_img_embed, image_embeds[i], post_img_embed], dim=0)
            # The spliced row must keep the original sequence length; the preprocessor
            # is responsible for sizing input_ids, attention_mask, and labels so that
            # this assignment does not raise a shape mismatch.
            final_embeds[i] = combined

        outputs = self.language_model(
            inputs_embeds=final_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )
        return outputs

    def generate(self, pixel_values, input_ids, attention_mask, **kwargs):
        """Custom generate function to handle multimodal input for inference."""
        self.eval()
        with torch.no_grad():
            # Encode and project the image, then prepend it to the text embeddings.
            image_features = self.vision_tower(pixel_values).last_hidden_state
            image_embeds = self.multi_modal_projector(image_features)
            text_embeds = self.language_model.get_input_embeddings()(input_ids)
            inputs_embeds = torch.cat([image_embeds, text_embeds], dim=1)

            # Build a combined attention mask covering both image and text positions.
            image_attention_mask = torch.ones(image_embeds.shape[:2], dtype=torch.long, device=self.device)
            combined_attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)

            output_ids = self.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=combined_attention_mask,
                **kwargs
            )
        return output_ids
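

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original training pipeline).
# It builds the model from scratch and runs the custom generate() on dummy
# inputs. All tensors below are random placeholders; a real pipeline would use
# the project's tokenizer/image processor, which also supplies the image-token
# id expected by forward(). Assumes network access (or a local cache) for the
# ViT/GPT-2 configs and a transformers version whose generate() accepts
# inputs_embeds for decoder-only models.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = VLMConfig(projection_dim=512)
    model = CustomScratchVLM(config)

    image_size = config.vision_config.image_size
    # Dummy image batch and a short prompt of random token ids.
    pixel_values = torch.randn(1, config.vision_config.num_channels, image_size, image_size)
    input_ids = torch.randint(0, config.language_config.vocab_size, (1, 6))
    attention_mask = torch.ones_like(input_ids)

    output_ids = model.generate(
        pixel_values=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=5,
        do_sample=False,
        pad_token_id=config.language_config.eos_token_id,
    )
    print("Generated token ids:", output_ids)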