# custom_vlm.py
import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM


class VLMConfig(PretrainedConfig):
    """Configuration pairing a vision backbone with a causal language model."""

    model_type = "custom_scratch_vlm"

    def __init__(
        self,
        vision_model_name="google/vit-base-patch16-224-in21k",
        language_model_name="gpt2",
        projection_dim=512,
        **kwargs
    ):
        self.vision_config = AutoConfig.from_pretrained(vision_model_name)
        self.language_config = AutoConfig.from_pretrained(language_model_name)
        self.projection_dim = projection_dim
        self.language_config.vocab_size = kwargs.get("vocab_size", self.language_config.vocab_size)
        super().__init__(**kwargs)


class VLMProjector(nn.Module):
    """Two-layer MLP mapping vision hidden states into the language model's embedding space."""

    def __init__(self, config: VLMConfig):
        super().__init__()
        self.linear1 = nn.Linear(config.vision_config.hidden_size, config.projection_dim)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(config.projection_dim, config.language_config.hidden_size)

    def forward(self, x):
        return self.linear2(self.gelu(self.linear1(x)))


class CustomScratchVLM(PreTrainedModel):
    config_class = VLMConfig

    def __init__(self, config: VLMConfig):
        super().__init__(config)
        print("Initializing model components from scratch using their configurations...")
        self.vision_tower = AutoModel.from_config(config.vision_config)
        self.language_model = AutoModelForCausalLM.from_config(config.language_config)
        self.multi_modal_projector = VLMProjector(config)
        # Sentinel value; set this to the tokenizer's image-token id before training/inference.
        self.image_token_id = -1

    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.LongTensor,
        labels: torch.LongTensor = None,
        **kwargs
    ):
        # Encode the image and project the patch features into the language embedding space.
        image_features = self.vision_tower(pixel_values).last_hidden_state
        image_embeds = self.multi_modal_projector(image_features)

        # Embed the text. The image placeholder id may not be a valid vocabulary index
        # (it defaults to -1), so mask it before the lookup; those positions are
        # overwritten with image embeddings below.
        safe_input_ids = input_ids.masked_fill(input_ids == self.image_token_id, 0)
        text_embeds = self.language_model.get_input_embeddings()(safe_input_ids)

        final_embeds = text_embeds.clone()
        num_image_tokens = image_embeds.shape[1]

        # Replace the block of image placeholder positions with the projected image embeddings.
        # The preprocessor is expected to insert exactly `num_image_tokens` consecutive
        # placeholders per image and to build attention masks/labels of matching length,
        # so the combined sequence keeps the original length.
        for i in range(input_ids.shape[0]):
            image_token_idx = torch.where(input_ids[i] == self.image_token_id)[0]
            if image_token_idx.numel() == 0:
                continue  # Skip if no image token found
            start = image_token_idx[0]
            pre_img_embed = final_embeds[i, :start]
            post_img_embed = final_embeds[i, start + num_image_tokens:]
            # Combine parts
            final_embeds[i] = torch.cat([pre_img_embed, image_embeds[i], post_img_embed], dim=0)

        outputs = self.language_model(
            inputs_embeds=final_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )
        return outputs

    def generate(self, pixel_values, input_ids, attention_mask, **kwargs):
        """Custom generate function to handle multimodal input for inference.

        Image embeddings are prepended to the text embeddings and the attention
        mask is extended accordingly before delegating to the language model's
        `generate`.
        """
        self.eval()
        with torch.no_grad():
            image_features = self.vision_tower(pixel_values).last_hidden_state
            image_embeds = self.multi_modal_projector(image_features)
            text_embeds = self.language_model.get_input_embeddings()(input_ids)
            inputs_embeds = torch.cat([image_embeds, text_embeds], dim=1)

            # Create a combined attention mask covering the image and text positions
            image_attention_mask = torch.ones(
                image_embeds.shape[:2], dtype=torch.long, device=self.device
            )
            combined_attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)

            output_ids = self.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=combined_attention_mask,
                **kwargs
            )
        return output_ids
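

# --- Usage sketch (illustrative only) -----------------------------------------
# A minimal smoke test, assuming a preprocessor that registers an "<image>"
# special token and repeats it once per ViT patch so sequence lengths line up.
# The token name, the 197-position count (ViT-base/16 at 224x224, incl. [CLS]),
# and the random pixel values are assumptions for illustration, not part of the
# model definition above.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = VLMConfig()
    model = CustomScratchVLM(config)

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
    model.language_model.resize_token_embeddings(len(tokenizer))
    model.image_token_id = tokenizer.convert_tokens_to_ids("<image>")

    num_image_tokens = 197  # must match the vision tower's patch sequence length
    prompt = "<image>" * num_image_tokens + " A photo of"
    enc = tokenizer(prompt, return_tensors="pt")
    pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for a preprocessed image

    outputs = model(
        input_ids=enc["input_ids"],
        pixel_values=pixel_values,
        attention_mask=enc["attention_mask"],
    )
    print(outputs.logits.shape)  # (batch, seq_len, vocab_size)
    # Generation can be exercised the same way via model.generate(...), given a
    # transformers version whose generate() accepts inputs_embeds for decoder-only models.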