# custom_vlm.py
from typing import Optional

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel
from transformers.models.auto.configuration_auto import AutoConfig
from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM
class VLMConfig(PretrainedConfig):
    model_type = "custom_scratch_vlm"

    def __init__(
        self,
        vision_model_name="google/vit-base-patch16-224-in21k",
        language_model_name="gpt2",
        projection_dim=512,
        **kwargs
    ):
        # Note: building the sub-configs via from_pretrained means a config
        # reloaded from disk will re-fetch these defaults rather than restore
        # saved nested configs; acceptable for from-scratch experiments.
        self.vision_config = AutoConfig.from_pretrained(vision_model_name)
        self.language_config = AutoConfig.from_pretrained(language_model_name)
        self.projection_dim = projection_dim
        self.language_config.vocab_size = kwargs.get("vocab_size", self.language_config.vocab_size)
        super().__init__(**kwargs)
class VLMProjector(nn.Module):
    """Two-layer MLP that maps vision-tower features into the LM embedding space."""

    def __init__(self, config: VLMConfig):
        super().__init__()
        self.linear1 = nn.Linear(config.vision_config.hidden_size, config.projection_dim)
        self.gelu = nn.GELU()
        self.linear2 = nn.Linear(config.projection_dim, config.language_config.hidden_size)

    def forward(self, x):
        return self.linear2(self.gelu(self.linear1(x)))
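
# Illustrative shape check (assumes ViT-base defaults: 197 patch tokens,
# hidden size 768, and GPT-2 hidden size 768, with projection_dim=512):
#   vision features [B, 197, 768] -> linear1 -> [B, 197, 512]
#   -> gelu -> linear2 -> [B, 197, 768], ready to splice into text embeddings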
class CustomScratchVLM(PreTrainedModel):
    config_class = VLMConfig

    def __init__(self, config: VLMConfig):
        super().__init__(config)
        print("Initializing model components from scratch using their configurations...")
        # from_config builds randomly initialized weights (no pretrained download)
        self.vision_tower = AutoModel.from_config(config.vision_config)
        self.language_model = AutoModelForCausalLM.from_config(config.language_config)
        self.multi_modal_projector = VLMProjector(config)
        # Sentinel id the preprocessor uses to mark image positions in input_ids;
        # -1 can never collide with a real vocabulary id.
        self.image_token_id = -1
    def forward(
        self,
        input_ids: torch.LongTensor,
        pixel_values: torch.FloatTensor,
        attention_mask: torch.LongTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ):
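        # Expected layout (illustrative, produced by the preprocessor): one
        # image_token_id per image patch, so sequence lengths already match:
        #   input_ids      = [bos, -1, -1, ..., -1, "Describe", "the", ...]
        #   attention_mask = [1,    1,  1, ...,  1,  1,          1,   ...]
        #   labels         = [-100, -100, ..., -100, answer token ids, ...]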
        image_features = self.vision_tower(pixel_values).last_hidden_state
        image_embeds = self.multi_modal_projector(image_features)
        # Clamp the -1 placeholders to a valid index before the embedding
        # lookup; those positions are overwritten with image embeddings below.
        text_embeds = self.language_model.get_input_embeddings()(input_ids.clamp(min=0))
        final_embeds = text_embeds.clone()
        num_image_tokens = image_embeds.shape[1]
        # Replace the run of placeholder positions with the projected image
        # embeddings. The preprocessor reserves exactly num_image_tokens slots,
        # so the spliced sequence has the same length as the original.
        for i in range(input_ids.shape[0]):
            image_token_idx = torch.where(input_ids[i] == self.image_token_id)[0]
            if image_token_idx.numel() == 0:
                continue  # Text-only sample: nothing to splice in
            start = image_token_idx[0]
            pre_img_embed = final_embeds[i, :start]
            post_img_embed = final_embeds[i, start + num_image_tokens:]
            final_embeds[i] = torch.cat([pre_img_embed, image_embeds[i], post_img_embed], dim=0)
        outputs = self.language_model(
            inputs_embeds=final_embeds,
            attention_mask=attention_mask,
            labels=labels,
            **kwargs
        )
        return outputs
    def generate(self, pixel_values, input_ids, attention_mask, **kwargs):
        """Custom generate function to handle multimodal input for inference."""
        self.eval()
        with torch.no_grad():
            image_features = self.vision_tower(pixel_values).last_hidden_state
            image_embeds = self.multi_modal_projector(image_features)
            text_embeds = self.language_model.get_input_embeddings()(input_ids)
            # Prepend the image embeddings to the prompt embeddings
            inputs_embeds = torch.cat([image_embeds, text_embeds], dim=1)
            # Extend the attention mask to cover the prepended image positions
            image_attention_mask = torch.ones(image_embeds.shape[:2], dtype=torch.long, device=self.device)
            combined_attention_mask = torch.cat([image_attention_mask, attention_mask], dim=1)
            # With inputs_embeds (and no input_ids), decoder-only generate
            # returns only the newly generated token ids.
            output_ids = self.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=combined_attention_mask,
                **kwargs
            )
        return output_ids
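
# Minimal usage sketch (illustrative; not part of the original file). Uses a
# random image tensor so the snippet is self-contained; with untrained weights
# the output is gibberish, so this only checks that the wiring runs end to end.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    config = VLMConfig()
    model = CustomScratchVLM(config)

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    prompt = tokenizer("Describe the image:", return_tensors="pt")
    pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for a processed image

    output_ids = model.generate(
        pixel_values=pixel_values,
        input_ids=prompt.input_ids,
        attention_mask=prompt.attention_mask,
        max_new_tokens=20,
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))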