from typing import List, Optional, Tuple, Union
from dataclasses import dataclass

import math
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from addict import Dict
from torch.nn import CrossEntropyLoss

from objectrelator.model.visual_prompt_module.context_cluster import region_pooling
from transformers import AutoConfig, AutoModelForCausalLM
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
from objectrelator.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, \
    DEFAULT_IM_END_TOKEN, SEG_TOKEN_INDEX, CLS_TOKEN_INDEX, REGION_TOKEN_INDEX, REFER_TOKEN_INDEX
from detectron2.structures import Boxes, ImageList, Instances, BitMasks
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from detectron2.modeling.postprocessing import sem_seg_postprocess
from detectron2.utils.memory import retry_if_cuda_oom
from ..mask_decoder.Mask2Former_Simplify.modeling.transformer_decoder.ObjectRelator_decoder import MultiScaleMaskedTransformerDecoder
from ..mask_decoder.Mask2Former_Simplify.modeling.pixel_decoder.msdeformattn import MSDeformAttnPixelDecoder
from ..multimodal_projector.builder import build_vision_projector
from ..multimodal_encoder.swin_trans import build_swin_b, build_swin_l
from ..datasets_mapper.coco_instance_mapper import COCOInstanceNewBaselineDatasetMapper
from ..datasets_mapper.coco_panoptic_mapper import COCOPanopticNewBaselineDatasetMapper
from ..datasets_mapper.coco_semantic_mapper import COCOSemanticNewBaselineDatasetMapper
from objectrelator.model.mask_decoder.mask_criterion.pretrain_criterion import PSALM_criterion, hungarian_matcher_PSALM
from transformers import PhiModel, PhiForCausalLM, PhiConfig


def fuse_multicondition(fuse_method, weight, SEG_embedding, region_embedding_list):
    if SEG_embedding is None:
        return region_embedding_list
    new_region_embedding_list = []
    for seg_embed, region_embedding in zip(SEG_embedding, region_embedding_list):
        if fuse_method == "add":
            if weight is not None:
                fused_region = weight * seg_embed + (1 - weight) * region_embedding
            else:
                fused_region = seg_embed + region_embedding
        elif fuse_method == "concat":
            fused_region = torch.cat([seg_embed.expand(region_embedding.shape[0], -1), region_embedding], dim=-1)
        elif fuse_method == "cross_attention_withoutpara":
            attention_scores = F.softmax(seg_embed @ region_embedding.transpose(-2, -1), dim=-1)
            fused_region = attention_scores.transpose(-2, -1) * region_embedding
        new_region_embedding_list.append(fused_region)
    return new_region_embedding_list
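
# Illustrative sketch (not part of the original file): a minimal, hypothetical use of
# fuse_multicondition, assuming one [SEG] embedding of shape [1, D] per sample and a
# per-sample region-embedding tensor of shape [num_regions, D]. The names below are
# made up for the demo; only the call itself reflects the function defined above.
def _demo_fuse_multicondition():
    D = 8
    SEG_embedding = [torch.randn(1, D)]            # one [SEG] token embedding per sample
    region_embedding_list = [torch.randn(3, D)]    # three visual-prompt region embeddings
    fused = fuse_multicondition("add", weight=0.5, SEG_embedding=SEG_embedding,
                                region_embedding_list=region_embedding_list)
    assert fused[0].shape == (3, D)                # weighted sum broadcasts [1, D] over [3, D]
    return fused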
def deep_copy_input(input_ids, attention_mask, past_key_values, labels, images, vp_images,
                    class_name_embedding_indices, class_name_ids, cls_indices, instances,
                    token_refer_id, refer_embedding_indices, output_attentions):
    def copy_if_tensor(value):
        if value is None or isinstance(value, bool):
            return value  # Directly return None or bool without modification
        elif isinstance(value, torch.Tensor):
            # Copy the tensor and re-enable requires_grad if the original required gradients
            new_value = value.detach().clone()
            if value.requires_grad:
                new_value.requires_grad_(True)
            return new_value
        elif isinstance(value, list):
            # Deep copy each tensor in the list, if applicable
            return [copy_if_tensor(v) for v in value]
        else:
            return value  # If it's not a tensor or list, return it directly

    # Apply the helper function to each input and create the corresponding _exo version
    input_ids_exo = copy_if_tensor(input_ids)
    attention_mask_exo = copy_if_tensor(attention_mask)
    past_key_values_exo = copy_if_tensor(past_key_values)
    labels_exo = copy_if_tensor(labels)
    images_exo = copy_if_tensor(images)
    vp_images_exo = copy_if_tensor(vp_images)
    class_name_embedding_indices_exo = copy_if_tensor(class_name_embedding_indices)
    class_name_ids_exo = copy_if_tensor(class_name_ids)
    cls_indices_exo = copy_if_tensor(cls_indices)
    instances_exo = copy_if_tensor(instances)
    token_refer_id_exo = copy_if_tensor(token_refer_id)
    refer_embedding_indices_exo = copy_if_tensor(refer_embedding_indices)
    output_attentions_exo = copy_if_tensor(output_attentions)

    # Return the copied values as a tuple of *_exo variables
    return (input_ids_exo, attention_mask_exo, past_key_values_exo, labels_exo, images_exo,
            vp_images_exo, class_name_embedding_indices_exo, class_name_ids_exo, cls_indices_exo,
            instances_exo, token_refer_id_exo, refer_embedding_indices_exo, output_attentions_exo)


def XObjAlign(embedding_list1, embedding_list2, sim_type="cos", L2_norm=False):
    # Ensure the lists have the same length and matching per-sample shapes
    assert len(embedding_list1) == len(embedding_list2), "Embedding lists must have the same length."
    for i in range(len(embedding_list1)):
        assert embedding_list1[i].shape == embedding_list2[i].shape

    # Collect a similarity or distance score for each pair
    similarity_scores = []
    for emb1, emb2 in zip(embedding_list1, embedding_list2):
        if L2_norm:
            emb1 = F.normalize(emb1, p=2, dim=-1)
            emb2 = F.normalize(emb2, p=2, dim=-1)
        if sim_type == "cos":
            sim = F.cosine_similarity(emb1, emb2, dim=-1)
            similarity_scores.append(sim.mean())
        elif sim_type == "ecu":
            dist = torch.norm(emb1 - emb2, p=2, dim=-1)
            similarity_scores.append(dist.mean())

    if sim_type == "cos":
        avg_similarity = torch.mean(torch.stack(similarity_scores))
        loss = 1 - avg_similarity
    elif sim_type == "ecu":
        avg_distance = torch.mean(torch.stack(similarity_scores))
        loss = avg_distance
    return loss


class LlavaConfig(PhiConfig):
    model_type = "llava_phi"


@dataclass
class CausalOutputWithMask(CausalLMOutputWithPast):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    loss_mask: Optional[torch.FloatTensor] = None
    loss_dice: Optional[torch.FloatTensor] = None
    loss_SEG_class: Optional[torch.FloatTensor] = None
    loss_class_name_class: Optional[torch.FloatTensor] = None
    loss_region_class: Optional[torch.FloatTensor] = None
    loss_llm: Optional[torch.FloatTensor] = None


@dataclass
class CausalOutputWithMaskSSL(CausalLMOutputWithPast):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    loss_mask: Optional[torch.FloatTensor] = None
    loss_dice: Optional[torch.FloatTensor] = None
    loss_SEG_class: Optional[torch.FloatTensor] = None
    loss_class_name_class: Optional[torch.FloatTensor] = None
    loss_region_class: Optional[torch.FloatTensor] = None
    loss_llm: Optional[torch.FloatTensor] = None
    loss_region_emb_SSL: Optional[torch.FloatTensor] = None
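
# Illustrative sketch (not part of the original file): a minimal, hypothetical call to XObjAlign,
# assuming two lists of per-object embeddings with matching shapes (e.g. ego-view vs. exo-view
# region embeddings). The cosine variant returns 1 - mean cosine similarity as an alignment loss.
def _demo_xobjalign():
    ego = [torch.randn(3, 8), torch.randn(2, 8)]
    exo = [e + 0.1 * torch.randn_like(e) for e in ego]
    loss = XObjAlign(ego, exo, sim_type="cos", L2_norm=True)
    return loss  # scalar tensor in [0, 2]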
class PSALMModel(LlavaMetaModel, PhiModel):
    config_class = LlavaConfig

    def __init__(self, config: PhiConfig, mask_decoder_cfg=None):
        super(PSALMModel, self).__init__(config)
        self.cfg = mask_decoder_cfg
        self.projector_outdim = config.hidden_size
        if hasattr(config, "mm_vision_tower"):
            swin_type = getattr(config, 'swin_type', 'base')
            if swin_type == 'base':
                self.vision_tower = build_swin_b(None)
            else:
                self.vision_tower = build_swin_l(None)
            self.mm_projector = build_vision_projector(config)
            self.vision_tower.image_processor = {}
            self.vision_tower.image_processor['panoptic'] = COCOPanopticNewBaselineDatasetMapper(self.cfg)
            self.vision_tower.image_processor['instance'] = COCOInstanceNewBaselineDatasetMapper(self.cfg)
            self.vision_tower.image_processor['semantic'] = COCOSemanticNewBaselineDatasetMapper(self.cfg)

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        vision_tower = model_args.vision_tower if hasattr(model_args, 'vision_tower') else model_args.mm_vision_tower
        with_norm = model_args.with_norm
        with_layernorm = model_args.with_layernorm
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter if hasattr(model_args, 'pretrain_mm_mlp_adapter') else None
        projector_outdim = self.projector_outdim

        self.config.mm_vision_tower = vision_tower
        swin_type = getattr(model_args, 'swin_type', 'base')
        self.config.swin_type = swin_type
        if swin_type == 'base':
            vision_tower = build_swin_b(vision_tower)
        else:
            print('current visual encoder is swin large')
            vision_tower = build_swin_l(vision_tower)
        self.config.mm_input_embeds = 1536

        if fsdp is not None and len(fsdp) > 0:
            self.vision_tower = [vision_tower]
        else:
            self.vision_tower = vision_tower

        self.config.use_mm_proj = True
        vision_tower.hidden_size = 256
        vision_tower.image_processor = {}
        vision_tower.image_processor['panoptic'] = COCOPanopticNewBaselineDatasetMapper(self.cfg)
        vision_tower.image_processor['instance'] = COCOInstanceNewBaselineDatasetMapper(self.cfg)
        vision_tower.image_processor['semantic'] = COCOSemanticNewBaselineDatasetMapper(self.cfg)
        # if model_args.seg_task == 'instance':
        #     vision_tower.image_processor = COCOInstanceNewBaselineDatasetMapper(self.cfg)
        # else:
        #     vision_tower.image_processor = COCOPanopticNewBaselineDatasetMapper(self.cfg)

        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'conv')
        # print(f'current mm_project_type is {self.config.mm_projector_type}, the output dim is {projector_outdim}')
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.with_norm = with_norm
        self.config.with_layernorm = with_layernorm
        self.config.projector_outdim = projector_outdim
        if not hasattr(self, "mm_projector"):
            self.mm_projector = build_vision_projector(self.config)
        else:
            print('mm_projector already exists, skip init')

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'), strict=False)
            print('loaded mm_projector weights successfully')
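
# ObjectRelator couples the Phi causal LM with the segmentation components above: it owns a
# PSALMModel (Swin vision tower + mm_projector), an LM head, pooling layers for refer/class-name
# tokens, a region_pooling sampler for visual-prompt regions, and a Mask2Former-style pixel
# decoder and transformer predictor that are built in initial_mask_module() below.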
class ObjectRelator(PhiForCausalLM, LlavaMetaForCausalLM):
    config_class = LlavaConfig

    def __init__(self, config, mask_decoder_cfg=None, add_cross_attn=True, cross_attn_index=None):
        super(ObjectRelator, self).__init__(config)
        self.model = PSALMModel(config, mask_decoder_cfg)
        self.init_config = config
        self.mask_decoder_cfg = mask_decoder_cfg
        self.cross_attn_index = cross_attn_index
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        is_train_mask_decode = getattr(config, 'mask_decode_train', False)
        self.is_train_mask_decode = is_train_mask_decode
        self.refer_pooling = nn.AdaptiveAvgPool1d(output_size=1)
        self.class_name_pooling = nn.AdaptiveAvgPool1d(output_size=1)
        self.region_sampler = region_pooling(num_sample_point=256)
        self.region_projector = nn.Linear(config.hidden_size, mask_decoder_cfg.MODEL.MASK_FORMER.HIDDEN_DIM)
        if is_train_mask_decode:
            print('Mask Decoder has been trained, init directly')
            self.initial_mask_module()
        self.post_init()

    def initial_mask_module(self, pretrained_path=None, model_args=None):
        if not self.is_train_mask_decode:
            print('Initialize mask modules...')
            self.config.mask_decode_train = True
        self.seg_query = nn.Parameter(
            torch.zeros([self.mask_decoder_cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES, self.config.hidden_size]))
        self.num_queries = self.mask_decoder_cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        self.num_classes = self.mask_decoder_cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        self.test_topk_per_image = self.mask_decoder_cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        input_shape = self.output_shape()
        self.pixel_decoder = self.pixel_decoder_init(cfg=self.mask_decoder_cfg, input_shape=input_shape)
        self.predictor = self.predictor_init(cfg=self.mask_decoder_cfg)
        self.seg_query_projector = nn.Linear(self.config.hidden_size, self.mask_decoder_cfg.MODEL.MASK_FORMER.HIDDEN_DIM)
        self.SEG_token_projector = nn.Linear(self.config.hidden_size, self.mask_decoder_cfg.MODEL.MASK_FORMER.HIDDEN_DIM)
        self.class_name_projector = nn.Linear(self.config.hidden_size, self.mask_decoder_cfg.MODEL.MASK_FORMER.HIDDEN_DIM)
        self.mask_decoder_training_init(self.mask_decoder_cfg)
        if pretrained_path is not None:
            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            def change_w(weights, old_name, new_name):
                weights[new_name] = weights[old_name]
                weights.pop(old_name)

            if pretrained_path.endswith('.pkl'):
                with open(pretrained_path, 'rb') as f:
                    ckpt = pickle.load(f)
            else:
                ckpt = torch.load(pretrained_path)
            pixel_decoder_weights = get_w(ckpt['model'], 'sem_seg_head.pixel_decoder')
            predictor_weights = get_w(ckpt['model'], 'sem_seg_head.predictor')
            pixel_decoder_weights = {k: torch.tensor(v) for k, v in pixel_decoder_weights.items()}
            predictor_weights = {k: torch.tensor(v) for k, v in predictor_weights.items()}
            # deal with keys whose names differ between the checkpoint and this implementation
            change_w(pixel_decoder_weights, 'adapter_1.weight', 'adapter_1.0.weight')
            change_w(pixel_decoder_weights, 'adapter_1.norm.weight', 'adapter_1.1.weight')
            change_w(pixel_decoder_weights, 'adapter_1.norm.bias', 'adapter_1.1.bias')
            change_w(pixel_decoder_weights, 'layer_1.weight', 'layer_1.0.weight')
            change_w(pixel_decoder_weights, 'layer_1.norm.weight', 'layer_1.1.weight')
            change_w(pixel_decoder_weights, 'layer_1.norm.bias', 'layer_1.1.bias')
            if 'static_query.weight' in predictor_weights:
                change_w(predictor_weights, 'static_query.weight', 'query_feat.weight')
            if predictor_weights['query_embed.weight'].shape[0] == 200:
                predictor_weights['query_embed.weight'] = predictor_weights['query_embed.weight'][:100, :]
            diff_pixel_msg = self.pixel_decoder.load_state_dict(pixel_decoder_weights, strict=False)
            diff_predictor_msg = self.predictor.load_state_dict(predictor_weights, strict=False)

    def get_vision_tower_feature(self, images):
        features = self.get_model().get_vision_tower()(images)
        features_dict = {
            'res2': features[0],
            'res3': features[1],
            'res4': features[2],
            'res5': features[3],
        }
        return features_dict

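    # The next method wires up Mask2Former-style training: a Hungarian matcher over class/mask/dice
    # costs, a loss-weight dict (repeated per decoder layer when deep supervision is on), the
    # PSALM_criterion losses, and per-task inference flags (semantic/instance/panoptic/referring/region).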
    def mask_decoder_training_init(self, cfg):
        # Loss parameters
        deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION
        no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT

        # loss weights
        class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT
        dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT
        mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT
        # boundary_weight = cfg.MODEL.MASK_FORMER.BOUNDARY_WEIGHT

        matcher = hungarian_matcher_PSALM(
            cost_class=class_weight,
            cost_mask=mask_weight,
            cost_dice=dice_weight,
            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
        )
        weight_dict = {"loss_SEG_class": class_weight, "loss_class_name_class": class_weight,
                       "loss_mask": mask_weight, "loss_dice": dice_weight, "loss_region_class": class_weight}
        self.weight_dict = weight_dict
        if deep_supervision:
            dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS
            aux_weight_dict = {}
            for i in range(dec_layers - 1):
                aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()})
            weight_dict.update(aux_weight_dict)
        losses = ["SEG_labels", "class_name_labels", "masks", "region_labels"]
        self.criterion = PSALM_criterion(
            matcher=matcher,
            losses=losses,
            num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS,
            oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO,
            importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO,
            device=self.device
        )
        self.size_divisibility = 32
        if cfg.MODEL.MASK_FORMER.SEG_TASK == 'semantic':
            self.semantic_on = True
            self.instance_on = False
            self.panoptic_on = False
            self.referring_on = False
            self.region_on = False
        elif cfg.MODEL.MASK_FORMER.SEG_TASK == 'instance':
            self.semantic_on = False
            self.instance_on = True
            self.panoptic_on = False
            self.referring_on = False
            self.region_on = False
        elif cfg.MODEL.MASK_FORMER.SEG_TASK == 'panoptic':
            self.semantic_on = True
            self.instance_on = True
            self.panoptic_on = True
            self.referring_on = False
            self.region_on = False
        elif cfg.MODEL.MASK_FORMER.SEG_TASK == 'referring':
            self.semantic_on = False
            self.instance_on = False
            self.panoptic_on = False
            self.referring_on = True
            self.region_on = False
        elif cfg.MODEL.MASK_FORMER.SEG_TASK == 'region':
            self.semantic_on = False
            self.instance_on = False
            self.panoptic_on = False
            self.referring_on = False
            self.region_on = True
        else:
            raise NotImplementedError
        self.sem_seg_postprocess_before_inference = self.instance_on or self.panoptic_on or self.referring_on or self.region_on

    def get_region_embedding(self, hidden_states, region_embedding_masks):
        region_embedding_list = []
        for sample_hidden_states, sample_region_embedding_masks in zip(hidden_states, region_embedding_masks):
            sample_region_embedding = sample_hidden_states[sample_region_embedding_masks.bool()]
            region_embedding_list.append(sample_region_embedding)
        return region_embedding_list

    def SEG_instance_inference(self, SEG_cls, mask_pred):
        # mask_pred has already been resized to the shape of the original input
        image_size = mask_pred.shape[-2:]
        scores = F.sigmoid(SEG_cls)
        scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False)
        mask_pred = mask_pred[topk_indices]
        result = Instances(image_size)
        result.pred_masks = (mask_pred > 0).float()
        result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4))
        mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (
                result.pred_masks.flatten(1).sum(1) + 1e-6)
        result.scores = scores_per_image * mask_scores_per_image
        return result

    def class_name_panoptic_inference(self, SEG_cls, class_name_cls, mask_pred):
        scores, labels = F.softmax(class_name_cls, dim=-1).max(-1)
num_classes = class_name_cls.shape[-1] - 1 mask_pred = mask_pred.sigmoid() object_mask_threshold = 0.8 overlap_threshold = 0.8 keep = labels.ne(num_classes) & (scores > object_mask_threshold) cur_scores = scores[keep] cur_classes = labels[keep] cur_masks = mask_pred[keep] cur_mask_cls = class_name_cls[keep] cur_mask_cls = cur_mask_cls[:, :-1] cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks h, w = cur_masks.shape[-2:] panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device) segments_info = [] current_segment_id = 0 if cur_masks.shape[0] == 0: return panoptic_seg, segments_info else: # take argmax cur_mask_ids = cur_prob_masks.argmax(0) stuff_memory_list = {} for k in range(cur_classes.shape[0]): pred_class = cur_classes[k].item() isthing = self.is_thing_list[pred_class] mask_area = (cur_mask_ids == k).sum().item() original_area = (cur_masks[k] >= 0.5).sum().item() mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5) if mask_area > 0 and original_area > 0 and mask.sum().item() > 0: if mask_area / original_area < overlap_threshold: continue # merge stuff regions if not isthing: if int(pred_class) in stuff_memory_list.keys(): panoptic_seg[mask] = stuff_memory_list[int(pred_class)] continue else: stuff_memory_list[int(pred_class)] = current_segment_id + 1 current_segment_id += 1 panoptic_seg[mask] = current_segment_id segments_info.append( { "id": current_segment_id, "isthing": bool(isthing), "category_id": int(pred_class), } ) return panoptic_seg, segments_info def region_inference(self, region_cls, mask_pred): image_size = mask_pred.shape[-2:] scores = F.sigmoid(region_cls) result = Instances(image_size) result.pred_masks = (mask_pred > 0).float() result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4)) mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / ( result.pred_masks.flatten(1).sum(1) + 1e-6) result.scores = (scores * mask_scores_per_image[None,...].repeat(scores.shape[0],1)).transpose(1,0) return result def class_name_semantic_inference(self, SEG_cls, class_name_cls, mask_pred): mask_cls = F.softmax(class_name_cls, dim=-1)[:, :-1] mask_pred = mask_pred.sigmoid() semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) return semseg def class_name_instance_inference(self, SEG_cls, class_name_cls, mask_pred): image_size = mask_pred.shape[-2:] cls_scores = F.softmax(class_name_cls, dim=-1)[:, :-1] scores = cls_scores num_classes = scores.shape[-1] labels = torch.arange(num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1) scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False) # scores_per_image, topk_indices = scores.flatten(0, 1).topk(5000, sorted=False) labels_per_image = labels[topk_indices] topk_indices = topk_indices // num_classes mask_pred = mask_pred[topk_indices] # if this is panoptic segmentation, we only keep the "thing" classes if self.panoptic_on: keep = torch.zeros_like(scores_per_image).bool() for i, lab in enumerate(labels_per_image): keep[i] = self.is_thing_list[lab] scores_per_image = scores_per_image[keep] labels_per_image = labels_per_image[keep] mask_pred = mask_pred[keep] result = Instances(image_size) # mask (before sigmoid) result.pred_masks = (mask_pred > 0).float() result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4)) # Uncomment the following to get boxes from masks (this is slow) # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes() # calculate average mask prob mask_scores_per_image = 
(mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / ( result.pred_masks.flatten(1).sum(1) + 1e-6) result.scores = scores_per_image * mask_scores_per_image result.pred_classes = labels_per_image return result def encode_images(self, images): image_features = self.get_model().get_vision_tower()(images) # [2,256,64,64] image_features = self.get_model().mm_projector(image_features[-1]) return image_features def predictor_init(self, cfg): in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM hidden_dim = cfg.MODEL.MASK_FORMER.HIDDEN_DIM num_queries = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES nheads = cfg.MODEL.MASK_FORMER.NHEADS dim_feedforward = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1 pre_norm = cfg.MODEL.MASK_FORMER.PRE_NORM mask_dim = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM enforce_input_project = False seg_norm = cfg.MODEL.MASK_FORMER.SEG_NORM seg_proj = cfg.MODEL.MASK_FORMER.SEG_PROJ seg_fuse_score = cfg.MODEL.MASK_FORMER.FUSE_SCORE seg_concat = False predictor = MultiScaleMaskedTransformerDecoder(in_channels, hidden_dim, num_queries, nheads, dim_feedforward, dec_layers, pre_norm, mask_dim, enforce_input_project, seg_norm, seg_concat, seg_proj, seg_fuse_score) return predictor def get_model(self): return self.model def output_shape(self): out_features = self.mask_decoder_cfg.MODEL.SWIN.OUT_FEATURES out_feature_strides = { "res2": 4, "res3": 8, "res4": 16, "res5": 32, } num_features = [int(self.mask_decoder_cfg.MODEL.SWIN.EMBED_DIM * 2 ** i) for i in range(len(self.mask_decoder_cfg.MODEL.SWIN.DEPTHS))] out_feature_channels = { "res2": num_features[0], "res3": num_features[1], "res4": num_features[2], "res5": num_features[3], } backbone_feature_shape = dict() for name in out_features: backbone_feature_shape[name] = Dict( {'channel': out_feature_channels[name], 'stride': out_feature_strides[name]}) return backbone_feature_shape def get_encoder_image(self, images): encode_image_features = self.get_model().get_vision_tower()(images) return encode_image_features def pixel_decoder_init(self, cfg, input_shape): common_stride = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE transformer_dropout = cfg.MODEL.MASK_FORMER.DROPOUT transformer_nheads = cfg.MODEL.MASK_FORMER.NHEADS transformer_dim_feedforward = 1024 transformer_enc_layers = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS conv_dim = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM mask_dim = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM transformer_in_features = cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES # ["res3", "res4", "res5"] pixel_decoder = MSDeformAttnPixelDecoder(input_shape, transformer_dropout, transformer_nheads, transformer_dim_feedforward, transformer_enc_layers, conv_dim, mask_dim, transformer_in_features, common_stride) return pixel_decoder def prepare_targets(self, targets, images): h_pad, w_pad = images.shape[-2:] new_targets = [] for targets_per_image in targets: # pad gt gt_masks = targets_per_image.gt_masks padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device) padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks new_targets.append( { "labels": targets_per_image.gt_classes, "masks": padded_masks, } ) return new_targets def get_special_token(self, SEG, EOS): self.SEG_id = SEG self.EOS_id = EOS def get_class_name_embedding(self, hidden_states, cls_token_indices): class_name_embedding_list = [] for current_hidden_state, current_token_indice in zip(hidden_states, cls_token_indices): class_id = 
torch.unique(current_token_indice) class_id = class_id[class_id != 0] current_class_name_embedding_list = [] for id in class_id: current_class_mask = (current_token_indice == id) current_class_state = current_hidden_state[current_class_mask] current_class_name_embedding_list.append(current_class_state) current_pool_class_name_embedding = [self.class_name_pooling(class_name.transpose(-2, -1)).transpose(-2, -1) for class_name in current_class_name_embedding_list] class_name_embedding_list.append(torch.cat(current_pool_class_name_embedding, dim=0)) return torch.stack(class_name_embedding_list, dim=0) def embed_class_ids(self, class_name_ids, cls_indices): if class_name_ids is None: return None num_class = cls_indices.unique_consecutive() num_class = num_class[num_class >= 0] class_name_ids = [class_name_ids[cls_indices == idx] for idx in num_class] embedded_class_name = [self.get_model().embed_tokens(id) for id in class_name_ids] return embedded_class_name def embed_refer_ids(self, refer_ids): if refer_ids is None: return None embedded_refer = self.get_model().embed_tokens(refer_ids) return embedded_refer def concat_image_seg_cls_embeds(self, input_id, img_feature, label, seg_query, seg_query_mask, class_embed, class_name_embedding_indices,region_embedding_mask=None, region_feature_list=None, refer_embedding_indices=None, refer_embedding=None): image_token_indices = torch.where(input_id == IMAGE_TOKEN_INDEX)[0] seg_query_indices = torch.where(input_id == SEG_TOKEN_INDEX)[0] cls_token_indices = torch.where(input_id == CLS_TOKEN_INDEX)[0] region_token_indices = torch.where(input_id == REGION_TOKEN_INDEX)[0] assert len(image_token_indices) == 1, 'not supporting multi image index' assert len(seg_query_indices) == 1, 'not supporting multi seg index' if class_name_embedding_indices is not None: assert len(cls_token_indices) == len(class_embed), 'the number of tokens and class_embed needs to be same' if region_feature_list is not None: assert len(region_feature_list) == len( region_token_indices), 'the munber of tokens and regions needs to be same' cur_new_input_embeds = [] cur_new_seg_query_mask = [] if label is not None: cur_new_label = [] assert label.shape == input_id.shape else: cur_new_label = None cur_class_name_embedding_indices = [] if class_name_embedding_indices is not None else None cur_refer_embedding_indices = [] if refer_embedding_indices is not None else None if region_embedding_mask is not None: enable_region_mask = True cur_new_region_embedding_mask = [] else: enable_region_mask = False cur_new_region_embedding_mask = None chunks = [] current_chunk = [] for id in input_id: if id >= 0: current_chunk.append(id.item()) else: if current_chunk: chunks.append(torch.tensor(current_chunk, device=input_id.device)) current_chunk = [] chunks.append([id]) if current_chunk: chunks.append(torch.tensor(current_chunk, device=input_id.device)) cls_idx = 0 region_idx = 0 for chunk in chunks: chunk_len = len(chunk) if chunk_len == 1 and chunk[0] == IMAGE_TOKEN_INDEX: cur_new_input_embeds.append(img_feature) cur_new_seg_query_mask.append(torch.zeros(img_feature.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append( torch.full((img_feature.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append( torch.full((img_feature.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if label is not None: cur_new_label.append( torch.full((img_feature.shape[0],), IGNORE_INDEX, 
device=label.device, dtype=label.dtype) ) if enable_region_mask: cur_new_region_embedding_mask.append(torch.zeros(img_feature.shape[0])) elif chunk_len == 1 and chunk[0] == SEG_TOKEN_INDEX: cur_new_input_embeds.append(seg_query) cur_new_seg_query_mask.append(torch.ones(seg_query.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append(torch.full((seg_query.shape[0],), 0, device=label.device, dtype=label.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append(torch.full((seg_query.shape[0],), 0, device=label.device, dtype=label.dtype)) if label is not None: cur_new_label.append( torch.full((seg_query.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype)) if enable_region_mask: cur_new_region_embedding_mask.append(torch.zeros(seg_query.shape[0])) elif chunk_len == 1 and chunk[0] == CLS_TOKEN_INDEX: cls_embed = class_embed[cls_idx] if len(cls_embed.shape) == 1: cls_embed = cls_embed.unsqueeze(0) cls_idx += 1 cur_new_input_embeds.append(cls_embed) cur_new_seg_query_mask.append(torch.zeros(cls_embed.shape[0])) if enable_region_mask: cur_new_region_embedding_mask.append(torch.zeros(cls_embed.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append( torch.full((cls_embed.shape[0],), cls_idx, device=input_id.device, dtype=input_id.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append( torch.full((cls_embed.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if label is not None: cur_new_label.append( torch.full((cls_embed.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype) ) elif chunk_len == 1 and chunk[0] == REGION_TOKEN_INDEX: region_feature = region_feature_list[region_idx] region_idx += 1 cur_new_input_embeds.append(region_feature) cur_new_seg_query_mask.append(torch.zeros(region_feature.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append( torch.full((region_feature.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append( torch.full((region_feature.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if label is not None: cur_new_label.append( torch.full((region_feature.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype) ) if enable_region_mask: cur_new_region_embedding_mask.append(torch.ones(region_feature.shape[0])) elif chunk_len == 1 and chunk[0] == REFER_TOKEN_INDEX: refer_embed = refer_embedding if len(refer_embed.shape) == 1: refer_embed = refer_embed.unsqueeze(0) cur_new_input_embeds.append(refer_embed) cur_new_seg_query_mask.append(torch.zeros(refer_embed.shape[0])) if enable_region_mask: cur_new_region_embedding_mask.append(torch.zeros(refer_embed.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append( torch.full((refer_embed.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append( torch.full((refer_embed.shape[0],), 1, device=input_id.device, dtype=input_id.dtype)) if label is not None: cur_new_label.append( torch.full((refer_embed.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype) ) else: cur_new_input_embeds.append(self.get_model().embed_tokens(input_id[:chunk_len])) cur_new_seg_query_mask.append(seg_query_mask[:chunk_len]) if class_name_embedding_indices is not None: 
cur_class_name_embedding_indices.append(class_name_embedding_indices[:chunk_len]) if refer_embedding_indices is not None: cur_refer_embedding_indices.append(refer_embedding_indices[:chunk_len]) if label is not None: cur_new_label.append(label[:chunk_len]) if enable_region_mask: cur_new_region_embedding_mask.append(region_embedding_mask[:chunk_len]) input_id = input_id[chunk_len:] seg_query_mask = seg_query_mask[chunk_len:] if class_name_embedding_indices is not None: class_name_embedding_indices = class_name_embedding_indices[chunk_len:] if refer_embedding_indices is not None: refer_embedding_indices = refer_embedding_indices[chunk_len:] if label is not None: label = label[chunk_len:] if enable_region_mask: region_embedding_mask = region_embedding_mask[chunk_len:] cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) if label is not None: cur_new_label = [x.to(device=self.device) for x in cur_new_label] cur_new_label = torch.cat(cur_new_label, dim=0) cur_new_seg_query_mask = [x.to(device=self.device) for x in cur_new_seg_query_mask] cur_new_seg_query_mask = torch.cat(cur_new_seg_query_mask, dim=0) if class_name_embedding_indices is not None: cur_class_name_embedding_indices = [x.to(device=self.device) for x in cur_class_name_embedding_indices] cur_class_name_embedding_indices = torch.cat(cur_class_name_embedding_indices, dim=0) if refer_embedding_indices is not None: cur_refer_embedding_indices = [x.to(device=self.device) for x in cur_refer_embedding_indices] cur_refer_embedding_indices = torch.cat(cur_refer_embedding_indices, dim=0) if enable_region_mask: cur_new_region_embedding_mask = [x.to(device=self.device) for x in cur_new_region_embedding_mask] cur_new_region_embedding_mask = torch.cat(cur_new_region_embedding_mask, dim=0) return cur_new_input_embeds, cur_new_label, cur_new_seg_query_mask, cur_class_name_embedding_indices, cur_new_region_embedding_mask, cur_refer_embedding_indices def concat_image_seg_cls_embeds_SSL(self, input_id, img_feature, label, seg_query, seg_query_mask, class_embed, class_name_embedding_indices,region_embedding_mask=None, region_embedding_mask_exo=None, region_feature_list=None, region_feature_list_exo=None, refer_embedding_indices=None, refer_embedding=None): image_token_indices = torch.where(input_id == IMAGE_TOKEN_INDEX)[0] seg_query_indices = torch.where(input_id == SEG_TOKEN_INDEX)[0] cls_token_indices = torch.where(input_id == CLS_TOKEN_INDEX)[0] region_token_indices = torch.where(input_id == REGION_TOKEN_INDEX)[0] assert len(image_token_indices) == 1, 'not supporting multi image index' assert len(seg_query_indices) == 1, 'not supporting multi seg index' if class_name_embedding_indices is not None: assert len(cls_token_indices) == len(class_embed), 'the number of tokens and class_embed needs to be same' if region_feature_list is not None: assert len(region_feature_list) == len( region_token_indices), 'the munber of tokens and regions needs to be same' cur_new_input_embeds = [] cur_new_seg_query_mask = [] if label is not None: cur_new_label = [] assert label.shape == input_id.shape else: cur_new_label = None cur_class_name_embedding_indices = [] if class_name_embedding_indices is not None else None cur_refer_embedding_indices = [] if refer_embedding_indices is not None else None if region_embedding_mask is not None: enable_region_mask = True cur_new_region_embedding_mask = [] cur_new_region_embedding_mask_exo = [] else: enable_region_mask = False 
cur_new_region_embedding_mask = None cur_new_region_embedding_mask_exo = None chunks = [] current_chunk = [] for id in input_id: if id >= 0: current_chunk.append(id.item()) else: if current_chunk: chunks.append(torch.tensor(current_chunk, device=input_id.device)) current_chunk = [] chunks.append([id]) if current_chunk: chunks.append(torch.tensor(current_chunk, device=input_id.device)) cls_idx = 0 region_idx = 0 for chunk in chunks: chunk_len = len(chunk) if chunk_len == 1 and chunk[0] == IMAGE_TOKEN_INDEX: cur_new_input_embeds.append(img_feature) cur_new_seg_query_mask.append(torch.zeros(img_feature.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append( torch.full((img_feature.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append( torch.full((img_feature.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if label is not None: cur_new_label.append( torch.full((img_feature.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype) ) if enable_region_mask: cur_new_region_embedding_mask.append(torch.zeros(img_feature.shape[0])) cur_new_region_embedding_mask_exo.append(torch.zeros(img_feature.shape[0])) elif chunk_len == 1 and chunk[0] == SEG_TOKEN_INDEX: cur_new_input_embeds.append(seg_query) cur_new_seg_query_mask.append(torch.ones(seg_query.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append(torch.full((seg_query.shape[0],), 0, device=label.device, dtype=label.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append(torch.full((seg_query.shape[0],), 0, device=label.device, dtype=label.dtype)) if label is not None: cur_new_label.append( torch.full((seg_query.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype)) if enable_region_mask: cur_new_region_embedding_mask.append(torch.zeros(seg_query.shape[0])) cur_new_region_embedding_mask_exo.append(torch.zeros(seg_query.shape[0])) elif chunk_len == 1 and chunk[0] == CLS_TOKEN_INDEX: cls_embed = class_embed[cls_idx] if len(cls_embed.shape) == 1: cls_embed = cls_embed.unsqueeze(0) cls_idx += 1 cur_new_input_embeds.append(cls_embed) cur_new_seg_query_mask.append(torch.zeros(cls_embed.shape[0])) if enable_region_mask: cur_new_region_embedding_mask.append(torch.zeros(cls_embed.shape[0])) cur_new_region_embedding_mask_exo.append(torch.zeros(cls_embed.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append( torch.full((cls_embed.shape[0],), cls_idx, device=input_id.device, dtype=input_id.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append( torch.full((cls_embed.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if label is not None: cur_new_label.append( torch.full((cls_embed.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype) ) elif chunk_len == 1 and chunk[0] == REGION_TOKEN_INDEX: region_feature = region_feature_list[region_idx] region_feature_exo = region_feature_list_exo[region_idx] region_idx += 1 cur_new_input_embeds.append(region_feature) cur_new_seg_query_mask.append(torch.zeros(region_feature.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append( torch.full((region_feature.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append( torch.full((region_feature.shape[0],), 0, device=input_id.device, 
dtype=input_id.dtype)) if label is not None: cur_new_label.append( torch.full((region_feature.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype) ) if enable_region_mask: cur_new_region_embedding_mask.append(torch.ones(region_feature.shape[0])) #cur_new_region_embedding_mask_exo.append(torch.ones(region_feature.shape[0])) cur_new_region_embedding_mask_exo.append(torch.ones(region_feature_exo.shape[0])) elif chunk_len == 1 and chunk[0] == REFER_TOKEN_INDEX: refer_embed = refer_embedding if len(refer_embed.shape) == 1: refer_embed = refer_embed.unsqueeze(0) cur_new_input_embeds.append(refer_embed) cur_new_seg_query_mask.append(torch.zeros(refer_embed.shape[0])) if enable_region_mask: cur_new_region_embedding_mask.append(torch.zeros(refer_embed.shape[0])) cur_new_region_embedding_mask_exo.append(torch.zeros(refer_embed.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append( torch.full((refer_embed.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append( torch.full((refer_embed.shape[0],), 1, device=input_id.device, dtype=input_id.dtype)) if label is not None: cur_new_label.append( torch.full((refer_embed.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype) ) else: cur_new_input_embeds.append(self.get_model().embed_tokens(input_id[:chunk_len])) cur_new_seg_query_mask.append(seg_query_mask[:chunk_len]) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append(class_name_embedding_indices[:chunk_len]) if refer_embedding_indices is not None: cur_refer_embedding_indices.append(refer_embedding_indices[:chunk_len]) if label is not None: cur_new_label.append(label[:chunk_len]) if enable_region_mask: cur_new_region_embedding_mask.append(region_embedding_mask[:chunk_len]) cur_new_region_embedding_mask_exo.append(region_embedding_mask_exo[:chunk_len]) input_id = input_id[chunk_len:] seg_query_mask = seg_query_mask[chunk_len:] if class_name_embedding_indices is not None: class_name_embedding_indices = class_name_embedding_indices[chunk_len:] if refer_embedding_indices is not None: refer_embedding_indices = refer_embedding_indices[chunk_len:] if label is not None: label = label[chunk_len:] if enable_region_mask: region_embedding_mask = region_embedding_mask[chunk_len:] region_embedding_mask_exo = region_embedding_mask_exo[chunk_len:] cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) if label is not None: cur_new_label = [x.to(device=self.device) for x in cur_new_label] cur_new_label = torch.cat(cur_new_label, dim=0) cur_new_seg_query_mask = [x.to(device=self.device) for x in cur_new_seg_query_mask] cur_new_seg_query_mask = torch.cat(cur_new_seg_query_mask, dim=0) if class_name_embedding_indices is not None: cur_class_name_embedding_indices = [x.to(device=self.device) for x in cur_class_name_embedding_indices] cur_class_name_embedding_indices = torch.cat(cur_class_name_embedding_indices, dim=0) if refer_embedding_indices is not None: cur_refer_embedding_indices = [x.to(device=self.device) for x in cur_refer_embedding_indices] cur_refer_embedding_indices = torch.cat(cur_refer_embedding_indices, dim=0) if enable_region_mask: cur_new_region_embedding_mask = [x.to(device=self.device) for x in cur_new_region_embedding_mask] cur_new_region_embedding_mask = torch.cat(cur_new_region_embedding_mask, dim=0) cur_new_region_embedding_mask_exo = 
[x.to(device=self.device) for x in cur_new_region_embedding_mask_exo] cur_new_region_embedding_mask_exo = torch.cat(cur_new_region_embedding_mask_exo, dim=0) return cur_new_input_embeds, cur_new_label, cur_new_seg_query_mask, cur_class_name_embedding_indices, cur_new_region_embedding_mask, cur_new_region_embedding_mask_exo, cur_refer_embedding_indices def prepare_inputs_labels_for_multimodal( self, input_ids, attention_mask, past_key_values, labels, images, vp_images=None, class_name_embedding_indices=None, class_name_ids=None, cls_indices=None, instances=None, token_refer_id=None, refer_embedding_indices=None ): vision_tower = self.get_vision_tower() seg_query_mask = torch.zeros_like(input_ids) if vision_tower is None or images is None or input_ids.shape[1] == 1: if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[ 1] == 1: attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) return input_ids, attention_mask, past_key_values, None, labels, seg_query_mask if type(images) is list or images.ndim == 5: concat_images = torch.cat([image for image in images], dim=0) image_features = self.encode_images(concat_images) split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) image_features = [x.flatten(0, 1) for x in image_features] else: image_features = self.encode_images(images) expanded_seg_query = self.seg_query.unsqueeze(0).expand(input_ids.shape[0], -1, -1) if (input_ids == REGION_TOKEN_INDEX).sum() != 0 and instances is not None: region_masks_list = [instance.vp_region_masks.tensor for instance in instances] vp_image_features = self.encode_images(vp_images) # [region_features_per_batch: [num_region, 1, dims]], len(region_features) = batch_size # print("region_mask_list:", region_masks_list) # debug # print("type of region_mask_list:", type(region_masks_list)) # print("len of region_mask_list:", len(region_masks_list)) # print("region_masks_list[0] shape:", region_masks_list[0].shape) region_features = self.region_sampler(vp_image_features, region_masks_list, original_dtype=vp_image_features.dtype, return_dtype=vp_image_features.dtype) # print("type of region_features:", len(region_features)) # print("shape of region_features[0]:", region_features[0].shape) region_embedding_masks = torch.zeros_like(input_ids) else: region_features = None region_embedding_masks = None new_input_embeds = [] new_labels = [] if labels is not None else None new_seg_query_masks = [] new_class_name_embedding_indices = [] if class_name_embedding_indices is not None else None new_refer_embedding_indices = [] if refer_embedding_indices is not None else None new_region_embedding_masks = [] if region_features is not None else None for batch_idx, cur_input_ids in enumerate(input_ids): cur_seg_query_mask = seg_query_mask[batch_idx] cur_seg_query = expanded_seg_query[batch_idx] cur_image_feature = image_features[batch_idx] cur_class_name_embedding_indices = class_name_embedding_indices[ batch_idx] if class_name_embedding_indices is not None else None cur_refer_embedding_indices = refer_embedding_indices[ batch_idx] if refer_embedding_indices is not None else None cur_region_feature_list = region_features[batch_idx] if region_features is not None else None cur_region_embedding_mask = region_embedding_masks[batch_idx] if region_features is not None else None if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # 
multimodal LLM, but the current sample is not multimodal cur_input_embeds = self.get_model().embed_tokens(cur_input_ids) # ensure gradients back propagation, not changing cur_input_embeds cur_input_embeds = cur_input_embeds + ( 0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum() new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(labels[batch_idx]) new_seg_query_masks.append(cur_seg_query_mask) # cur_image_idx += 1 continue if labels is not None: cur_label = labels[batch_idx] else: cur_label = None if class_name_ids is not None: cur_class_name_ids = class_name_ids[batch_idx] cur_cls_indices = cls_indices[batch_idx] else: cur_class_name_ids = None cur_cls_indices = None if token_refer_id is not None: cur_token_refer_id = token_refer_id[batch_idx] else: cur_token_refer_id = None cur_class_name_embedding = self.embed_class_ids(cur_class_name_ids, cur_cls_indices) cur_refer_embedding = self.embed_refer_ids(cur_token_refer_id) cur_input_embeds, cur_label, cur_seg_query_mask, cur_class_name_embedding_indices, cur_region_embedding_mask, cur_refer_embedding_indices = self.concat_image_seg_cls_embeds( input_id=cur_input_ids, img_feature=cur_image_feature, label=cur_label, seg_query=cur_seg_query, seg_query_mask=cur_seg_query_mask, class_embed=cur_class_name_embedding, class_name_embedding_indices=cur_class_name_embedding_indices, region_embedding_mask=cur_region_embedding_mask, region_feature_list=cur_region_feature_list, refer_embedding_indices=cur_refer_embedding_indices, refer_embedding=cur_refer_embedding ) assert cur_input_embeds.shape[0] == cur_seg_query_mask.shape[0] new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(cur_label) new_seg_query_masks.append(cur_seg_query_mask) if class_name_embedding_indices is not None: new_class_name_embedding_indices.append(cur_class_name_embedding_indices) if refer_embedding_indices is not None: new_refer_embedding_indices.append(cur_refer_embedding_indices) if new_region_embedding_masks is not None: new_region_embedding_masks.append(cur_region_embedding_mask) if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): max_len = max(x.shape[0] for x in new_input_embeds) new_input_embeds_align = [] for cur_new_embed in new_input_embeds: cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) new_input_embeds_align.append(cur_new_embed) new_input_embeds = torch.stack(new_input_embeds_align, dim=0) if labels is not None: new_labels_align = [] _new_labels = new_labels for cur_new_label in new_labels: cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0) new_labels_align.append(cur_new_label) new_labels = torch.stack(new_labels_align, dim=0) new_seg_query_masks_align = [] for new_seg_query_mask in new_seg_query_masks: new_seg_query_mask = torch.cat( (new_seg_query_mask, torch.zeros((max_len - new_seg_query_mask.shape[0]), dtype=new_seg_query_mask.dtype, device=new_seg_query_mask.device)), dim=0) new_seg_query_masks_align.append(new_seg_query_mask) new_seg_query_masks = torch.stack(new_seg_query_masks_align, dim=0) new_class_name_embedding_indices_align = [] if class_name_embedding_indices is not None: for new_class_name_embedding_indice in new_class_name_embedding_indices: new_class_name_embedding_indice = torch.cat( 
(new_class_name_embedding_indice, torch.zeros((max_len - new_class_name_embedding_indice.shape[0]), dtype=new_class_name_embedding_indice.dtype, device=new_class_name_embedding_indice.device)), dim=0) new_class_name_embedding_indices_align.append(new_class_name_embedding_indice) new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices_align, dim=0) if refer_embedding_indices is not None: new_refer_embedding_indices_align = [] for new_refer_embedding_indice in new_refer_embedding_indices: new_refer_embedding_indice = torch.cat( (new_refer_embedding_indice, torch.zeros((max_len - new_refer_embedding_indice.shape[0]), dtype=new_refer_embedding_indice.dtype, device=new_refer_embedding_indice.device)), dim=0) new_refer_embedding_indices_align.append(new_refer_embedding_indice) new_refer_embedding_indices = torch.stack(new_refer_embedding_indices_align, dim=0) if new_region_embedding_masks is not None: new_region_embedding_masks_align = [] for new_region_embedding_mask in new_region_embedding_masks: new_region_embedding_mask = torch.cat( (new_region_embedding_mask, torch.zeros((max_len - new_region_embedding_mask.shape[0]), dtype=new_region_embedding_mask.dtype, device=new_region_embedding_mask.device)), dim=0) new_region_embedding_masks_align.append(new_region_embedding_mask) new_region_embedding_masks = torch.stack(new_region_embedding_masks_align, dim=0) if attention_mask is not None: new_attention_mask = [] for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) cur_new_attention_mask = torch.cat( (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0) new_attention_mask.append(cur_new_attention_mask) attention_mask = torch.stack(new_attention_mask, dim=0) assert attention_mask.shape == new_labels.shape else: new_input_embeds = torch.stack(new_input_embeds, dim=0) if labels is not None: new_labels = torch.stack(new_labels, dim=0) new_seg_query_masks = torch.stack(new_seg_query_masks, dim=0) if class_name_embedding_indices is not None: new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices, dim=0) if refer_embedding_indices is not None: new_refer_embedding_indices = torch.stack(new_refer_embedding_indices, dim=0) if new_region_embedding_masks is not None: new_region_embedding_masks = torch.stack(new_region_embedding_masks, dim=0) if attention_mask is not None: new_attn_mask_pad_left = torch.full( (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1) assert attention_mask.shape == new_input_embeds.shape[:2] return None, attention_mask, past_key_values, new_input_embeds, new_labels, new_seg_query_masks, new_class_name_embedding_indices, new_region_embedding_masks, new_refer_embedding_indices def prepare_inputs_labels_for_multimodal_EXO( self, input_ids, attention_mask, past_key_values, labels, images, vp_images=None, class_name_embedding_indices=None, class_name_ids=None, cls_indices=None, instances=None, token_refer_id=None, refer_embedding_indices=None ): vision_tower = self.get_vision_tower() seg_query_mask = 
torch.zeros_like(input_ids) if vision_tower is None or images is None or input_ids.shape[1] == 1: if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[ 1] == 1: attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) return input_ids, attention_mask, past_key_values, None, labels, seg_query_mask if type(images) is list or images.ndim == 5: concat_images = torch.cat([image for image in images], dim=0) image_features = self.encode_images(concat_images) split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) image_features = [x.flatten(0, 1) for x in image_features] else: image_features = self.encode_images(images) expanded_seg_query = self.seg_query.unsqueeze(0).expand(input_ids.shape[0], -1, -1) if (input_ids == REGION_TOKEN_INDEX).sum() != 0 and instances is not None: region_masks_list = [instance.gt_masks.detach().clone() for instance in instances] #image_features_2 = image_features.detach().clone().requires_grad_(True) region_features = self.region_sampler(image_features, region_masks_list, original_dtype=image_features.dtype, return_dtype=image_features.dtype) region_embedding_masks = torch.zeros_like(input_ids) else: region_features = None region_embedding_masks = None new_input_embeds = [] new_labels = [] if labels is not None else None new_seg_query_masks = [] new_class_name_embedding_indices = [] if class_name_embedding_indices is not None else None new_refer_embedding_indices = [] if refer_embedding_indices is not None else None new_region_embedding_masks = [] if region_features is not None else None for batch_idx, cur_input_ids in enumerate(input_ids): cur_seg_query_mask = seg_query_mask[batch_idx] cur_seg_query = expanded_seg_query[batch_idx] cur_image_feature = image_features[batch_idx] cur_class_name_embedding_indices = class_name_embedding_indices[ batch_idx] if class_name_embedding_indices is not None else None cur_refer_embedding_indices = refer_embedding_indices[ batch_idx] if refer_embedding_indices is not None else None cur_region_feature_list = region_features[batch_idx] if region_features is not None else None cur_region_embedding_mask = region_embedding_masks[batch_idx] if region_features is not None else None if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # multimodal LLM, but the current sample is not multimodal cur_input_embeds = self.get_model().embed_tokens(cur_input_ids) # ensure gradients back propagation, not changing cur_input_embeds cur_input_embeds = cur_input_embeds + ( 0. 
* self.get_model().mm_projector(vision_tower.dummy_feature)).sum() new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(labels[batch_idx]) new_seg_query_masks.append(cur_seg_query_mask) # cur_image_idx += 1 continue if labels is not None: cur_label = labels[batch_idx] else: cur_label = None if class_name_ids is not None: cur_class_name_ids = class_name_ids[batch_idx] cur_cls_indices = cls_indices[batch_idx] else: cur_class_name_ids = None cur_cls_indices = None if token_refer_id is not None: cur_token_refer_id = token_refer_id[batch_idx] else: cur_token_refer_id = None cur_class_name_embedding = self.embed_class_ids(cur_class_name_ids, cur_cls_indices) cur_refer_embedding = self.embed_refer_ids(cur_token_refer_id) cur_input_embeds, cur_label, cur_seg_query_mask, cur_class_name_embedding_indices, cur_region_embedding_mask, cur_refer_embedding_indices = self.concat_image_seg_cls_embeds( input_id=cur_input_ids, img_feature=cur_image_feature, label=cur_label, seg_query=cur_seg_query, seg_query_mask=cur_seg_query_mask, class_embed=cur_class_name_embedding, class_name_embedding_indices=cur_class_name_embedding_indices, region_embedding_mask=cur_region_embedding_mask, region_feature_list=cur_region_feature_list, refer_embedding_indices=cur_refer_embedding_indices, refer_embedding=cur_refer_embedding ) assert cur_input_embeds.shape[0] == cur_seg_query_mask.shape[0] new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(cur_label) new_seg_query_masks.append(cur_seg_query_mask) if class_name_embedding_indices is not None: new_class_name_embedding_indices.append(cur_class_name_embedding_indices) if refer_embedding_indices is not None: new_refer_embedding_indices.append(cur_refer_embedding_indices) if new_region_embedding_masks is not None: new_region_embedding_masks.append(cur_region_embedding_mask) if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): max_len = max(x.shape[0] for x in new_input_embeds) new_input_embeds_align = [] for cur_new_embed in new_input_embeds: cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) new_input_embeds_align.append(cur_new_embed) new_input_embeds = torch.stack(new_input_embeds_align, dim=0) if labels is not None: new_labels_align = [] _new_labels = new_labels for cur_new_label in new_labels: cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0) new_labels_align.append(cur_new_label) new_labels = torch.stack(new_labels_align, dim=0) new_seg_query_masks_align = [] for new_seg_query_mask in new_seg_query_masks: new_seg_query_mask = torch.cat( (new_seg_query_mask, torch.zeros((max_len - new_seg_query_mask.shape[0]), dtype=new_seg_query_mask.dtype, device=new_seg_query_mask.device)), dim=0) new_seg_query_masks_align.append(new_seg_query_mask) new_seg_query_masks = torch.stack(new_seg_query_masks_align, dim=0) new_class_name_embedding_indices_align = [] if class_name_embedding_indices is not None: for new_class_name_embedding_indice in new_class_name_embedding_indices: new_class_name_embedding_indice = torch.cat( (new_class_name_embedding_indice, torch.zeros((max_len - new_class_name_embedding_indice.shape[0]), dtype=new_class_name_embedding_indice.dtype, device=new_class_name_embedding_indice.device)), dim=0) 
new_class_name_embedding_indices_align.append(new_class_name_embedding_indice) new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices_align, dim=0) if refer_embedding_indices is not None: new_refer_embedding_indices_align = [] for new_refer_embedding_indice in new_refer_embedding_indices: new_refer_embedding_indice = torch.cat( (new_refer_embedding_indice, torch.zeros((max_len - new_refer_embedding_indice.shape[0]), dtype=new_refer_embedding_indice.dtype, device=new_refer_embedding_indice.device)), dim=0) new_refer_embedding_indices_align.append(new_refer_embedding_indice) new_refer_embedding_indices = torch.stack(new_refer_embedding_indices_align, dim=0) if new_region_embedding_masks is not None: new_region_embedding_masks_align = [] for new_region_embedding_mask in new_region_embedding_masks: new_region_embedding_mask = torch.cat( (new_region_embedding_mask, torch.zeros((max_len - new_region_embedding_mask.shape[0]), dtype=new_region_embedding_mask.dtype, device=new_region_embedding_mask.device)), dim=0) new_region_embedding_masks_align.append(new_region_embedding_mask) new_region_embedding_masks = torch.stack(new_region_embedding_masks_align, dim=0) if attention_mask is not None: new_attention_mask = [] for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) cur_new_attention_mask = torch.cat( (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0) new_attention_mask.append(cur_new_attention_mask) attention_mask = torch.stack(new_attention_mask, dim=0) assert attention_mask.shape == new_labels.shape else: new_input_embeds = torch.stack(new_input_embeds, dim=0) if labels is not None: new_labels = torch.stack(new_labels, dim=0) new_seg_query_masks = torch.stack(new_seg_query_masks, dim=0) if class_name_embedding_indices is not None: new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices, dim=0) if refer_embedding_indices is not None: new_refer_embedding_indices = torch.stack(new_refer_embedding_indices, dim=0) if new_region_embedding_masks is not None: new_region_embedding_masks = torch.stack(new_region_embedding_masks, dim=0) if attention_mask is not None: new_attn_mask_pad_left = torch.full( (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1) assert attention_mask.shape == new_input_embeds.shape[:2] return None, attention_mask, past_key_values, new_input_embeds, new_labels, new_seg_query_masks, new_class_name_embedding_indices, new_region_embedding_masks, new_refer_embedding_indices def prepare_inputs_labels_for_multimodal_SSL( self, input_ids, attention_mask, past_key_values, labels, images, vp_images=None, class_name_embedding_indices=None, class_name_ids=None, cls_indices=None, instances=None, token_refer_id=None, refer_embedding_indices=None ): vision_tower = self.get_vision_tower() seg_query_mask = torch.zeros_like(input_ids) if vision_tower is None or images is None or input_ids.shape[1] == 1: if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[ 
1] == 1: attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) return input_ids, attention_mask, past_key_values, None, labels, seg_query_mask if type(images) is list or images.ndim == 5: concat_images = torch.cat([image for image in images], dim=0) image_features = self.encode_images(concat_images) split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) image_features = [x.flatten(0, 1) for x in image_features] else: image_features = self.encode_images(images) expanded_seg_query = self.seg_query.unsqueeze(0).expand(input_ids.shape[0], -1, -1) if (input_ids == REGION_TOKEN_INDEX).sum() != 0 and instances is not None: region_masks_list = [instance.vp_region_masks.tensor for instance in instances] vp_image_features = self.encode_images(vp_images) # [region_features_per_batch: [num_region, 1, dims]], len(region_features) = batch_size region_features = self.region_sampler(vp_image_features, region_masks_list, original_dtype=vp_image_features.dtype, return_dtype=vp_image_features.dtype) region_embedding_masks = torch.zeros_like(input_ids) region_masks_list_exo = [instance.gt_masks for instance in instances] image_features_exo = image_features.detach().clone().requires_grad_(True) region_features_exo = self.region_sampler(image_features_exo, region_masks_list_exo, original_dtype=image_features.dtype, return_dtype=image_features.dtype) region_embedding_masks_exo = torch.zeros_like(input_ids) else: region_features = None region_embedding_masks = None region_features_exo = None region_embedding_masks_exo = None new_input_embeds = [] new_input_embeds_exo = [] new_labels = [] if labels is not None else None new_seg_query_masks = [] new_class_name_embedding_indices = [] if class_name_embedding_indices is not None else None new_refer_embedding_indices = [] if refer_embedding_indices is not None else None new_region_embedding_masks = [] if region_features is not None else None # for exo new_region_embedding_masks_exo = [] if region_features_exo is not None else None for batch_idx, cur_input_ids in enumerate(input_ids): cur_seg_query_mask = seg_query_mask[batch_idx] cur_seg_query = expanded_seg_query[batch_idx] cur_image_feature = image_features[batch_idx] cur_class_name_embedding_indices = class_name_embedding_indices[ batch_idx] if class_name_embedding_indices is not None else None cur_refer_embedding_indices = refer_embedding_indices[ batch_idx] if refer_embedding_indices is not None else None cur_region_feature_list = region_features[batch_idx] if region_features is not None else None cur_region_embedding_mask = region_embedding_masks[batch_idx] if region_features is not None else None # for exo cur_region_feature_list_exo = region_features_exo[batch_idx] if region_features_exo is not None else None cur_region_embedding_mask_exo = region_embedding_masks_exo[batch_idx] if region_features_exo is not None else None if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # multimodal LLM, but the current sample is not multimodal cur_input_embeds = self.get_model().embed_tokens(cur_input_ids) # ensure gradients back propagation, not changing cur_input_embeds cur_input_embeds = cur_input_embeds + ( 0. 
* self.get_model().mm_projector(vision_tower.dummy_feature)).sum() new_input_embeds.append(cur_input_embeds) new_input_embeds_exo.append(cur_input_embeds.detach().clone()) if labels is not None: new_labels.append(labels[batch_idx]) new_seg_query_masks.append(cur_seg_query_mask) # cur_image_idx += 1 continue if labels is not None: cur_label = labels[batch_idx] else: cur_label = None if class_name_ids is not None: cur_class_name_ids = class_name_ids[batch_idx] cur_cls_indices = cls_indices[batch_idx] else: cur_class_name_ids = None cur_cls_indices = None if token_refer_id is not None: cur_token_refer_id = token_refer_id[batch_idx] else: cur_token_refer_id = None cur_class_name_embedding = self.embed_class_ids(cur_class_name_ids, cur_cls_indices) cur_refer_embedding = self.embed_refer_ids(cur_token_refer_id) # save the initial inputs for the exo branch Init_cur_input_ids = cur_input_ids.clone() Init_cur_image_feature = cur_image_feature.clone() Init_cur_label = cur_label.clone() Init_cur_seg_query = cur_seg_query.clone() Init_cur_seg_query_mask = cur_seg_query_mask.clone() if cur_class_name_embedding is not None: Init_cur_class_name_embedding = cur_class_name_embedding.clone() else: Init_cur_class_name_embedding = cur_class_name_embedding if cur_class_name_embedding_indices is not None: Init_cur_class_name_embedding_indices = cur_class_name_embedding_indices.clone() else: Init_cur_class_name_embedding_indices = cur_class_name_embedding_indices if cur_refer_embedding_indices is not None: Init_cur_refer_embedding_indices = cur_refer_embedding_indices.clone() else: Init_cur_refer_embedding_indices = cur_refer_embedding_indices if cur_refer_embedding is not None: Init_cur_refer_embedding = cur_refer_embedding.clone() else: Init_cur_refer_embedding = cur_refer_embedding cur_input_embeds, cur_label, cur_seg_query_mask, cur_class_name_embedding_indices, cur_region_embedding_mask, cur_refer_embedding_indices = self.concat_image_seg_cls_embeds( input_id=cur_input_ids, img_feature=cur_image_feature, label=cur_label, seg_query=cur_seg_query, seg_query_mask=cur_seg_query_mask, class_embed=cur_class_name_embedding, class_name_embedding_indices=cur_class_name_embedding_indices, region_embedding_mask=cur_region_embedding_mask, region_feature_list=cur_region_feature_list, refer_embedding_indices=cur_refer_embedding_indices, refer_embedding=cur_refer_embedding ) cur_input_embeds_exo, cur_label_exo, cur_seg_query_mask_exo, cur_class_name_embedding_indices_exo, cur_region_embedding_mask_exo, cur_refer_embedding_indices_exo = self.concat_image_seg_cls_embeds( input_id=Init_cur_input_ids, img_feature=Init_cur_image_feature, label=Init_cur_label, seg_query=Init_cur_seg_query, seg_query_mask=Init_cur_seg_query_mask, class_embed=Init_cur_class_name_embedding, class_name_embedding_indices=Init_cur_class_name_embedding_indices, region_embedding_mask=cur_region_embedding_mask_exo, region_feature_list=cur_region_feature_list_exo, refer_embedding_indices=Init_cur_refer_embedding_indices, refer_embedding=Init_cur_refer_embedding ) assert cur_input_embeds.shape[0] == cur_seg_query_mask.shape[0] # for exo new_input_embeds.append(cur_input_embeds) new_input_embeds_exo.append(cur_input_embeds_exo) if labels is not None: new_labels.append(cur_label) new_seg_query_masks.append(cur_seg_query_mask) if class_name_embedding_indices is not None: new_class_name_embedding_indices.append(cur_class_name_embedding_indices) if refer_embedding_indices is not None: new_refer_embedding_indices.append(cur_refer_embedding_indices) if
new_region_embedding_masks is not None: new_region_embedding_masks.append(cur_region_embedding_mask) new_region_embedding_masks_exo.append(cur_region_embedding_mask_exo) if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): max_len = max(x.shape[0] for x in new_input_embeds) new_input_embeds_align = [] for cur_new_embed in new_input_embeds: cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) new_input_embeds_align.append(cur_new_embed) new_input_embeds = torch.stack(new_input_embeds_align, dim=0) #for exo new_input_embeds_align_exo = [] for cur_new_embed in new_input_embeds_exo: cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) new_input_embeds_align_exo.append(cur_new_embed) new_input_embeds_exo = torch.stack(new_input_embeds_align_exo, dim=0) if labels is not None: new_labels_align = [] _new_labels = new_labels for cur_new_label in new_labels: cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0) new_labels_align.append(cur_new_label) new_labels = torch.stack(new_labels_align, dim=0) new_seg_query_masks_align = [] for new_seg_query_mask in new_seg_query_masks: new_seg_query_mask = torch.cat( (new_seg_query_mask, torch.zeros((max_len - new_seg_query_mask.shape[0]), dtype=new_seg_query_mask.dtype, device=new_seg_query_mask.device)), dim=0) new_seg_query_masks_align.append(new_seg_query_mask) new_seg_query_masks = torch.stack(new_seg_query_masks_align, dim=0) new_class_name_embedding_indices_align = [] if class_name_embedding_indices is not None: for new_class_name_embedding_indice in new_class_name_embedding_indices: new_class_name_embedding_indice = torch.cat( (new_class_name_embedding_indice, torch.zeros((max_len - new_class_name_embedding_indice.shape[0]), dtype=new_class_name_embedding_indice.dtype, device=new_class_name_embedding_indice.device)), dim=0) new_class_name_embedding_indices_align.append(new_class_name_embedding_indice) new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices_align, dim=0) if refer_embedding_indices is not None: new_refer_embedding_indices_align = [] for new_refer_embedding_indice in new_refer_embedding_indices: new_refer_embedding_indice = torch.cat( (new_refer_embedding_indice, torch.zeros((max_len - new_refer_embedding_indice.shape[0]), dtype=new_refer_embedding_indice.dtype, device=new_refer_embedding_indice.device)), dim=0) new_refer_embedding_indices_align.append(new_refer_embedding_indice) new_refer_embedding_indices = torch.stack(new_refer_embedding_indices_align, dim=0) if new_region_embedding_masks is not None: new_region_embedding_masks_align = [] for new_region_embedding_mask in new_region_embedding_masks: new_region_embedding_mask = torch.cat( (new_region_embedding_mask, torch.zeros((max_len - new_region_embedding_mask.shape[0]), dtype=new_region_embedding_mask.dtype, device=new_region_embedding_mask.device)), dim=0) new_region_embedding_masks_align.append(new_region_embedding_mask) new_region_embedding_masks = torch.stack(new_region_embedding_masks_align, dim=0) # for exo new_region_embedding_masks_align_exo = [] for new_region_embedding_mask in new_region_embedding_masks_exo: new_region_embedding_mask = torch.cat( 
(new_region_embedding_mask, torch.zeros((max_len - new_region_embedding_mask.shape[0]), dtype=new_region_embedding_mask.dtype, device=new_region_embedding_mask.device)), dim=0) new_region_embedding_masks_align_exo.append(new_region_embedding_mask) new_region_embedding_masks_exo = torch.stack(new_region_embedding_masks_align_exo, dim=0) if attention_mask is not None: new_attention_mask = [] for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) cur_new_attention_mask = torch.cat( (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0) new_attention_mask.append(cur_new_attention_mask) attention_mask = torch.stack(new_attention_mask, dim=0) assert attention_mask.shape == new_labels.shape else: new_input_embeds = torch.stack(new_input_embeds, dim=0) new_input_embeds_exo = torch.stack(new_input_embeds_exo, dim=0) if labels is not None: new_labels = torch.stack(new_labels, dim=0) new_seg_query_masks = torch.stack(new_seg_query_masks, dim=0) if class_name_embedding_indices is not None: new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices, dim=0) if refer_embedding_indices is not None: new_refer_embedding_indices = torch.stack(new_refer_embedding_indices, dim=0) if new_region_embedding_masks is not None: new_region_embedding_masks = torch.stack(new_region_embedding_masks, dim=0) new_region_embedding_masks_exo = torch.stack(new_region_embedding_masks_exo, dim=0) if attention_mask is not None: new_attn_mask_pad_left = torch.full( (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1) assert attention_mask.shape == new_input_embeds.shape[:2] return None, attention_mask, past_key_values, new_input_embeds, new_input_embeds_exo, new_labels, new_seg_query_masks, new_class_name_embedding_indices, new_region_embedding_masks, new_region_embedding_masks_exo, new_refer_embedding_indices def get_SEG_embedding(self,hidden_states, refer_embedding_indices): refer_embedding_list = [] for current_hidden_state, current_token_indice in zip(hidden_states, refer_embedding_indices): current_refer_state = current_hidden_state[current_token_indice.bool()] current_pool_refer_state = self.refer_pooling(current_refer_state.transpose(-2, -1)).transpose(-2, -1) refer_embedding_list.append(current_pool_refer_state) return torch.stack(refer_embedding_list, dim=0) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, vp_images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, seg_info=None, class_name_ids=None, class_name_embedding_indices=None, cls_indices=None, random_idx=None, token_refer_id=None, refer_embedding_indices=None, dataset_type=None, ) -> Union[Tuple, 
CausalLMOutputWithPast]: if dataset_type is not None: assert all(item == dataset_type[0] for item in dataset_type), f'this batch contain different dataset_type: {dataset_type}' batch_dataset_type = dataset_type[0] else: batch_dataset_type = [] output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids == SEG_TOKEN_INDEX).sum() != 0: if (input_ids == REGION_TOKEN_INDEX).sum() != 0: instances = [i['instances'] for i in seg_info] else: instances = None input_ids_exo, attention_mask_exo, past_key_values_exo, labels_exo, images_exo, vp_images_exo, class_name_embedding_indices_exo, class_name_ids_exo, cls_indices_exo, instances_exo, token_refer_id_exo, refer_embedding_indices_exo, output_attentions_exo = deep_copy_input( input_ids, attention_mask, past_key_values, labels, images, vp_images, class_name_embedding_indices,class_name_ids, cls_indices, instances, token_refer_id, refer_embedding_indices, output_attentions ) # ego prepare data input_ids, attention_mask, past_key_values, inputs_embeds, labels, seg_query_mask, class_name_embedding_indices, region_embedding_masks, refer_embedding_indices = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images, vp_images, class_name_embedding_indices,class_name_ids, cls_indices, instances, token_refer_id, refer_embedding_indices) input_ids_exo, attention_mask_exo, past_key_values_exo, inputs_embeds_exo, labels_exo, seg_query_mask_exo, class_name_embedding_indices_exo, region_embedding_masks_exo, refer_embedding_indices_exo = self.prepare_inputs_labels_for_multimodal_EXO(input_ids_exo, attention_mask_exo, past_key_values_exo, labels_exo, images_exo, vp_images_exo, class_name_embedding_indices_exo, class_name_ids_exo, cls_indices_exo, instances_exo, token_refer_id_exo, refer_embedding_indices_exo) else: seg_query_mask = None class_name_embedding_indices = None region_embedding_masks = None SEG_token_indices = None input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.mm_conv_prepare_inputs_labels_for_multimodal( input_ids, attention_mask, past_key_values, labels, images) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) # for exo outputs_exo = self.model( input_ids=input_ids_exo, attention_mask=attention_mask_exo, past_key_values=past_key_values_exo, inputs_embeds=inputs_embeds_exo, use_cache=use_cache, output_attentions=output_attentions_exo, output_hidden_states=output_hidden_states, return_dict=return_dict ) hidden_states = outputs.last_hidden_state logits = self.lm_head(hidden_states) hidden_states_exo = outputs_exo.last_hidden_state if class_name_embedding_indices is not None: class_name_embedding = self.get_class_name_embedding(hidden_states, class_name_embedding_indices) class_name_embedding = self.class_name_projector(class_name_embedding) else: class_name_embedding = None if class_name_embedding is not None: class_name_embedding = torch.gather(class_name_embedding,dim=1,index=random_idx.unsqueeze(-1).repeat(1, 1, class_name_embedding.shape[-1])) if region_embedding_masks is not None: region_embedding_list = 
self.get_region_embedding(hidden_states, region_embedding_masks) region_embedding_list = [self.region_projector(region_embedding) for region_embedding in region_embedding_list] # for exo region_embedding_list_exo = self.get_region_embedding(hidden_states_exo, region_embedding_masks_exo) region_embedding_list_exo = [self.region_projector(region_embedding) for region_embedding in region_embedding_list_exo] else: region_embedding_list = None if 'referring' in batch_dataset_type or 'region' in batch_dataset_type: class_name_embedding = None loss_region_emb_SSL = XObjAlign(region_embedding_list, region_embedding_list_exo, sim_type="ecu") #ours loss = None if labels is not None and seg_query_mask is None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) llm_loss = loss_fct(shift_logits, shift_labels) if seg_query_mask is not None: seg_query = self.get_seg_query(hidden_states, seg_query_mask) seg_query = self.seg_query_projector(seg_query) image_features = self.get_vision_tower_feature(images) mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features( image_features) if refer_embedding_indices is not None: SEG_embedding = self.get_SEG_embedding(hidden_states, refer_embedding_indices) SEG_embedding = self.SEG_token_projector(SEG_embedding) else: SEG_embedding = None if 'panoptic' in batch_dataset_type or 'region' in batch_dataset_type: SEG_embedding = None mask_outputs = self.predictor(multi_scale_features, mask_features, None, seg_query, SEG_embedding, class_name_embedding, region_embedding_list) if seg_info is not None: if "instances" in seg_info[0]: gt_instances = [x["instances"].to(self.device) for x in seg_info] # images = ImageList.from_tensors(images, self.size_divisibility) targets = self.prepare_targets(gt_instances, images) else: targets = None mask_losses = self.criterion(mask_outputs, targets) weight_dict = self.weight_dict loss_mask = 0.0 loss_dice = 0.0 loss_SEG_class = 0.0 loss_class_name_class = 0.0 loss_region_class = 0.0 for k in list(mask_losses.keys()): if k in weight_dict: if mask_losses[k] is not None: mask_losses[k] *= weight_dict[k] if '_SEG' in k and mask_losses[k] is not None: loss_SEG_class += mask_losses[k] elif '_name' in k and mask_losses[k] is not None: loss_class_name_class += mask_losses[k] elif '_mask' in k: loss_mask += mask_losses[k] elif '_dice' in k: loss_dice += mask_losses[k] elif '_region' in k and mask_losses[k] is not None: loss_region_class += mask_losses[k] else: mask_losses.pop(k) # adjust the SSL loss weights k_SSL = 1 mask_loss = loss_mask + loss_dice + loss_SEG_class + loss_class_name_class + loss_region_class + k_SSL*loss_region_emb_SSL if isinstance(loss_class_name_class, float): loss_class_name_class = torch.tensor(loss_class_name_class, device=mask_loss.device) if isinstance(loss_SEG_class, float): loss_SEG_class = torch.tensor(loss_SEG_class, device=mask_loss.device) if isinstance(loss_region_class, float): loss_region_class = torch.tensor(loss_region_class, device=mask_loss.device) llm = torch.tensor(0.0, device=mask_loss.device) if labels is not None: # loss = llm_loss + mask_loss loss = mask_loss return CausalOutputWithMaskSSL( loss=loss, logits=logits, 
past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss_mask=loss_mask.detach(), loss_dice=loss_dice.detach(), loss_SEG_class=loss_SEG_class.detach(), loss_class_name_class=loss_class_name_class.detach(), loss_region_class=loss_region_class.detach(), loss_llm=llm.detach(), loss_region_emb_SSL = loss_region_emb_SSL.detach(), ) if labels is not None and seg_query_mask is None: loss_mask = torch.tensor(0.0, device=llm_loss.device) loss_dice = torch.tensor(0.0, device=llm_loss.device) loss_SEG_class = torch.tensor(0.0, device=llm_loss.device) loss_class_name_class = torch.tensor(0.0, device=llm_loss.device) loss_region_class = torch.tensor(0.0, device=llm_loss.device) loss = llm_loss else: return CausalOutputWithMask( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) return CausalOutputWithMask( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss_mask=loss_mask.detach(), loss_dice=loss_dice.detach(), loss_SEG_class=loss_SEG_class.detach(), loss_class_name_class=loss_class_name_class.detach(), loss_region_class=loss_region_class.detach(), loss_llm=llm_loss.detach(), ) def mm_conv_prepare_inputs_labels_for_multimodal( self, input_ids, attention_mask, past_key_values, labels, images ): vision_tower = self.get_vision_tower() if vision_tower is None or images is None or input_ids.shape[1] == 1: if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1: attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) return input_ids, attention_mask, past_key_values, None, labels if type(images) is list or images.ndim == 5: concat_images = torch.cat([image for image in images], dim=0) image_features = self.encode_images(concat_images) split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) image_features = [x.flatten(0, 1) for x in image_features] else: image_features = self.encode_images(images) new_input_embeds = [] new_labels = [] if labels is not None else None cur_image_idx = 0 for batch_idx, cur_input_ids in enumerate(input_ids): if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # multimodal LLM, but the current sample is not multimodal cur_input_embeds = self.get_model().embed_tokens(cur_input_ids) # ensure gradients back propagation, not changing cur_input_embeds cur_input_embeds = cur_input_embeds + (0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum() new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(labels[batch_idx]) cur_image_idx += 1 continue image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] cur_new_input_embeds = [] if labels is not None: cur_labels = labels[batch_idx] cur_new_labels = [] assert cur_labels.shape == cur_input_ids.shape # concat text and image embedding. 
prepare labels, IGNORE_INDEX for image tokens while image_token_indices.numel() > 0: cur_image_features = image_features[cur_image_idx] image_token_start = image_token_indices[0] if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach()) cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start])) cur_new_input_embeds.append(cur_image_features) cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2])) if labels is not None: cur_new_labels.append(cur_labels[:image_token_start]) cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) cur_new_labels.append(cur_labels[image_token_start:image_token_start+1]) cur_labels = cur_labels[image_token_start+2:] else: cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start])) cur_new_input_embeds.append(cur_image_features) if labels is not None: cur_new_labels.append(cur_labels[:image_token_start]) cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) cur_labels = cur_labels[image_token_start+1:] cur_image_idx += 1 if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): cur_input_ids = cur_input_ids[image_token_start+2:] else: cur_input_ids = cur_input_ids[image_token_start+1:] image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] if cur_input_ids.numel() > 0: if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach()) else: cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids)) if labels is not None: cur_new_labels.append(cur_labels) cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) new_input_embeds.append(cur_new_input_embeds) if labels is not None: cur_new_labels = torch.cat(cur_new_labels, dim=0) new_labels.append(cur_new_labels) if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): max_len = max(x.shape[0] for x in new_input_embeds) new_input_embeds_align = [] for cur_new_embed in new_input_embeds: cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) new_input_embeds_align.append(cur_new_embed) new_input_embeds = torch.stack(new_input_embeds_align, dim=0) if labels is not None: new_labels_align = [] _new_labels = new_labels for cur_new_label in new_labels: cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0) new_labels_align.append(cur_new_label) new_labels = torch.stack(new_labels_align, dim=0) if attention_mask is not None: new_attention_mask = [] for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) new_attn_mask_pad_right = 
torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0) new_attention_mask.append(cur_new_attention_mask) attention_mask = torch.stack(new_attention_mask, dim=0) assert attention_mask.shape == new_labels.shape else: new_input_embeds = torch.stack(new_input_embeds, dim=0) if labels is not None: new_labels = torch.stack(new_labels, dim=0) if attention_mask is not None: new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1) assert attention_mask.shape == new_input_embeds.shape[:2] return None, attention_mask, past_key_values, new_input_embeds, new_labels def get_seg_query(self, hidden_states, seg_query_masks): seg_query_list = [] for sample_hidden_state, sample_query_mask in zip(hidden_states, seg_query_masks): if torch.sum(sample_query_mask) == 0: continue unique_query_value = torch.unique(sample_query_mask) unique_query_value = unique_query_value[unique_query_value != 0] for value in unique_query_value: current_query_mask = (sample_query_mask == value) current_query = sample_hidden_state[current_query_mask] seg_query_list.append(current_query) seg_query = torch.stack(seg_query_list, dim=0) return seg_query def eval_seg( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, vp_images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, seg_info=None, class_name_ids=None, class_name_embedding_indices=None, cls_indices=None, token_refer_id=None, refer_embedding_indices=None, is_thing_list=None ): if self.panoptic_on: assert is_thing_list is not None, 'is_thing_list need to be given' self.is_thing_list = is_thing_list output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids == REGION_TOKEN_INDEX).sum() != 0: instances = [i['instances'] for i in seg_info] else: instances = None input_ids, attention_mask, past_key_values, inputs_embeds, labels, seg_query_mask, class_name_embedding_indices, region_embedding_masks, refer_embedding_indices = self.prepare_inputs_labels_for_multimodal( input_ids, attention_mask, past_key_values, labels, images, vp_images, class_name_embedding_indices, class_name_ids, cls_indices, instances, token_refer_id, refer_embedding_indices) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) hidden_states = outputs.last_hidden_state seg_query = self.get_seg_query(hidden_states, seg_query_mask) seg_query = self.seg_query_projector(seg_query) image_features = 
self.get_vision_tower_feature(images) mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features( image_features) if refer_embedding_indices is not None: SEG_embedding = self.get_SEG_embedding(hidden_states, refer_embedding_indices) SEG_embedding = self.SEG_token_projector(SEG_embedding) else: SEG_embedding = None if class_name_embedding_indices is not None: class_name_embedding = self.get_class_name_embedding(hidden_states, class_name_embedding_indices) class_name_embedding = self.class_name_projector(class_name_embedding) else: class_name_embedding = None if region_embedding_masks is not None: region_embedding_list = self.get_region_embedding(hidden_states, region_embedding_masks) region_embedding_list = [self.region_projector(region_embedding) for region_embedding in region_embedding_list] else: region_embedding_list = None mask_outputs = self.predictor(multi_scale_features, mask_features, None, seg_query, SEG_embedding, class_name_embedding, region_embedding_list) SEG_cls_results = mask_outputs['pred_SEG_logits'] class_name_cls_results = mask_outputs['pred_class_name_logits'] mask_pred_results = mask_outputs["pred_masks"] region_cls_results = mask_outputs['pred_region_logits'] images = [x for x in images] images = ImageList.from_tensors(images, self.size_divisibility) mask_pred_results = F.interpolate( mask_pred_results, size=(images.tensor.shape[-2], images.tensor.shape[-1]), mode="bilinear", align_corners=False, ) del mask_outputs processed_results = [] if SEG_cls_results is None: SEG_cls_results = [None] if class_name_cls_results is None: class_name_cls_results = [None] for _seg_info, SEG_cls_result, class_name_cls_result, mask_pred_result, input_per_image, image_size in zip( seg_info, SEG_cls_results, class_name_cls_results, mask_pred_results, seg_info, images.image_sizes ): height = input_per_image.get("height", image_size[0]) width = input_per_image.get("width", image_size[1]) padding_mask = input_per_image.get("padding_mask") non_padding_indices = np.where(~ np.array(padding_mask)) min_y, max_y = np.min(non_padding_indices[0]), np.max(non_padding_indices[0]) min_x, max_x = np.min(non_padding_indices[1]), np.max(non_padding_indices[1]) original_height = max_y - min_y + 1 original_width = max_x - min_x + 1 processed_results.append({}) # gt = _seg_info['instances'].gt_masks if self.sem_seg_postprocess_before_inference: mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( mask_pred_result, [original_height, original_width], height, width ) if SEG_cls_result is not None: SEG_cls_result = SEG_cls_result.to(mask_pred_result) if self.semantic_on: semantic_r = retry_if_cuda_oom(self.class_name_semantic_inference)(None, class_name_cls_result.float(), mask_pred_result.float()) if not self.sem_seg_postprocess_before_inference: semantic_r = retry_if_cuda_oom(sem_seg_postprocess)( semantic_r, [original_height, original_width], height, width ) processed_results[-1]["sem_seg"] = semantic_r if self.instance_on: instance_r = retry_if_cuda_oom(self.class_name_instance_inference)(None, class_name_cls_result.float(), mask_pred_result.float()) processed_results[-1]["instances"] = instance_r if self.panoptic_on: panoptic_r = retry_if_cuda_oom(self.class_name_panoptic_inference)(None, class_name_cls_result.float(), mask_pred_result.float()) processed_results[-1]["panoptic_seg"] = panoptic_r if self.referring_on: instance_r = retry_if_cuda_oom(self.SEG_instance_inference)(SEG_cls_result.float(), mask_pred_result.float()) 
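# SEG_instance_inference (defined further down in this class) keeps the top-k mask proposals ranked by sigmoid SEG logit, binarizes pred_masks at logit 0, and multiplies each score by the mask's mean foreground probability before returning a detectron2 Instances object.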
processed_results[-1]["instances"] = instance_r if self.region_on: gt = _seg_info['instances'].gt_masks gt_result = retry_if_cuda_oom(sem_seg_postprocess)( gt, [original_height, original_width], height, width ) region_cls_results = region_cls_results[0].to(mask_pred_result) instance_r = retry_if_cuda_oom(self.region_inference)(region_cls_results.float(), mask_pred_result.float()) processed_results[-1]["instances"] = instance_r processed_results[-1]["gt"] = gt_result return processed_results def eval_video( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, vp_images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, seg_info=None, class_name_ids=None, class_name_embedding_indices=None, cls_indices=None, token_refer_id=None, refer_embedding_indices=None, is_thing_list=None ): if self.panoptic_on: assert is_thing_list is not None, 'is_thing_list need to be given' self.is_thing_list = is_thing_list output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids == REGION_TOKEN_INDEX).sum() != 0: instances = [i['instances'] for i in seg_info] else: instances = None input_ids, attention_mask, past_key_values, inputs_embeds, labels, seg_query_mask, class_name_embedding_indices, region_embedding_masks, refer_embedding_indices = self.prepare_inputs_labels_for_multimodal( input_ids, attention_mask, past_key_values, labels, images,vp_images, class_name_embedding_indices, class_name_ids, cls_indices, instances, token_refer_id, refer_embedding_indices) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) hidden_states = outputs.last_hidden_state seg_query = self.get_seg_query(hidden_states, seg_query_mask) seg_query = self.seg_query_projector(seg_query) image_features = self.get_vision_tower_feature(images) mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features( image_features) if refer_embedding_indices is not None: SEG_embedding = self.get_SEG_embedding(hidden_states, refer_embedding_indices) SEG_embedding = self.SEG_token_projector(SEG_embedding) else: SEG_embedding = None if class_name_embedding_indices is not None: class_name_embedding = self.get_class_name_embedding(hidden_states, class_name_embedding_indices) class_name_embedding = self.class_name_projector(class_name_embedding) else: class_name_embedding = None if region_embedding_masks is not None: region_embedding_list = self.get_region_embedding(hidden_states, region_embedding_masks) region_embedding_list = [self.region_projector(region_embedding) for region_embedding in region_embedding_list] else: region_embedding_list = None mask_outputs = self.predictor(multi_scale_features, 
mask_features, None, seg_query, SEG_embedding, class_name_embedding, region_embedding_list) SEG_cls_results = mask_outputs['pred_SEG_logits'] class_name_cls_results = mask_outputs['pred_class_name_logits'] mask_pred_results = mask_outputs["pred_masks"] region_cls_results = mask_outputs['pred_region_logits'] images = [x for x in images] images = ImageList.from_tensors(images, self.size_divisibility) mask_pred_results = F.interpolate( mask_pred_results, size=(images.tensor.shape[-2], images.tensor.shape[-1]), mode="bilinear", align_corners=False, ) del mask_outputs processed_results = [] if SEG_cls_results is None: SEG_cls_results = [None] if class_name_cls_results is None: class_name_cls_results = [None] for _seg_info, SEG_cls_result, class_name_cls_result, mask_pred_result, input_per_image, image_size in zip( seg_info, SEG_cls_results, class_name_cls_results, mask_pred_results, seg_info, images.image_sizes ): height = input_per_image.get("height", image_size[0]) width = input_per_image.get("width", image_size[1]) padding_mask = input_per_image.get("padding_mask") non_padding_indices = np.where(~ np.array(padding_mask)) min_y, max_y = np.min(non_padding_indices[0]), np.max(non_padding_indices[0]) min_x, max_x = np.min(non_padding_indices[1]), np.max(non_padding_indices[1]) original_height = max_y - min_y + 1 original_width = max_x - min_x + 1 processed_results.append({}) if self.sem_seg_postprocess_before_inference: mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( mask_pred_result, [original_height, original_width], height, width ) if SEG_cls_result is not None: SEG_cls_result = SEG_cls_result.to(mask_pred_result) if self.semantic_on: semantic_r = retry_if_cuda_oom(self.class_name_semantic_inference)(None, class_name_cls_result.float(), mask_pred_result.float()) if not self.sem_seg_postprocess_before_inference: semantic_r = retry_if_cuda_oom(sem_seg_postprocess)( semantic_r, [original_height, original_width], height, width ) processed_results[-1]["sem_seg"] = semantic_r if self.instance_on: instance_r = retry_if_cuda_oom(self.class_name_instance_inference)(None, class_name_cls_result.float(), mask_pred_result.float()) processed_results[-1]["instances"] = instance_r if self.panoptic_on: panoptic_r = retry_if_cuda_oom(self.class_name_panoptic_inference)(None, class_name_cls_result.float(), mask_pred_result.float()) processed_results[-1]["panoptic_seg"] = panoptic_r # print("self.referring_on",self.referring_on) #debug if self.referring_on: instance_r = retry_if_cuda_oom(self.SEG_instance_inference)(SEG_cls_result.float(), mask_pred_result.float()) processed_results[-1]["instances"] = instance_r if self.region_on: gt = _seg_info['instances'].gt_masks gt_result = retry_if_cuda_oom(sem_seg_postprocess)( gt, [original_height, original_width], height, width ) region_cls_results = region_cls_results[0].to(mask_pred_result) instance_r = retry_if_cuda_oom(self.region_inference)(region_cls_results.float(), mask_pred_result.float()) processed_results[-1]["instances"] = instance_r processed_results[-1]["gt"] = gt_result return processed_results class PSALM(PhiForCausalLM, LlavaMetaForCausalLM): config_class = LlavaConfig def __init__(self, config, mask_decoder_cfg=None, add_cross_attn=True, cross_attn_index=None): super(PSALM, self).__init__(config) self.model = PSALMModel(config, mask_decoder_cfg) self.init_config = config self.mask_decoder_cfg = mask_decoder_cfg self.cross_attn_index = cross_attn_index self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Initialize 
weights and apply final processing is_train_mask_decode = getattr(config, 'mask_decode_train', False) self.is_train_mask_decode = is_train_mask_decode self.refer_pooling = nn.AdaptiveAvgPool1d(output_size=1) self.class_name_pooling = nn.AdaptiveAvgPool1d(output_size=1) self.region_sampler = region_pooling(num_sample_point=256) self.region_projector = nn.Linear(config.hidden_size, mask_decoder_cfg.MODEL.MASK_FORMER.HIDDEN_DIM) if is_train_mask_decode: print('Mask Decoder has been trained, init directly') self.initial_mask_module() self.post_init() def initial_mask_module(self, pretrained_path=None, model_args=None): if not self.is_train_mask_decode: print('Initialize mask modules...') self.config.mask_decode_train = True self.seg_query = nn.Parameter( torch.zeros([self.mask_decoder_cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES, self.config.hidden_size])) self.num_queries = self.mask_decoder_cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES self.num_classes = self.mask_decoder_cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES self.test_topk_per_image = self.mask_decoder_cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES input_shape = self.output_shape() self.pixel_decoder = self.pixel_decoder_init(cfg=self.mask_decoder_cfg, input_shape=input_shape) self.predictor = self.predictor_init(cfg=self.mask_decoder_cfg) self.seg_query_projector = nn.Linear(self.config.hidden_size, self.mask_decoder_cfg.MODEL.MASK_FORMER.HIDDEN_DIM) self.SEG_token_projector = nn.Linear(self.config.hidden_size, self.mask_decoder_cfg.MODEL.MASK_FORMER.HIDDEN_DIM) self.class_name_projector = nn.Linear(self.config.hidden_size, self.mask_decoder_cfg.MODEL.MASK_FORMER.HIDDEN_DIM) self.mask_decoder_training_init(self.mask_decoder_cfg) if pretrained_path is not None: def get_w(weights, keyword): return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} def change_w(weights, old_name, new_name): weights[new_name] = weights[old_name] weights.pop(old_name) if pretrained_path.endswith('.pkl'): with open(pretrained_path, 'rb') as f: ckpt = pickle.load(f) else: ckpt = torch.load(pretrained_path) pixel_decoder_weights = get_w(ckpt['model'],'sem_seg_head.pixel_decoder') predictor_weights = get_w(ckpt['model'],'sem_seg_head.predictor') pixel_decoder_weights = {k: torch.tensor(v) for k, v in pixel_decoder_weights.items()} predictor_weights = {k: torch.tensor(v) for k, v in predictor_weights.items()} #deal some diff keys change_w(pixel_decoder_weights,'adapter_1.weight','adapter_1.0.weight') change_w(pixel_decoder_weights,'adapter_1.norm.weight','adapter_1.1.weight') change_w(pixel_decoder_weights,'adapter_1.norm.bias','adapter_1.1.bias') change_w(pixel_decoder_weights,'layer_1.weight','layer_1.0.weight') change_w(pixel_decoder_weights,'layer_1.norm.weight','layer_1.1.weight') change_w(pixel_decoder_weights,'layer_1.norm.bias','layer_1.1.bias') if 'static_query.weight' in predictor_weights: change_w(predictor_weights,'static_query.weight','query_feat.weight') if predictor_weights['query_embed.weight'].shape[0] == 200: predictor_weights['query_embed.weight'] = predictor_weights['query_embed.weight'][:100,:] diff_pixel_msg = self.pixel_decoder.load_state_dict(pixel_decoder_weights,strict=False) diff_predictor_msg = self.predictor.load_state_dict(predictor_weights,strict=False) # print(diff_predictor_msg) # print(diff_pixel_msg) def get_vision_tower_feature(self, images): features = self.get_model().get_vision_tower()(images) features_dict = { 'res2': features[0], 'res3': features[1], 'res4': features[2], 'res5': features[3], } return features_dict def 
mask_decoder_training_init(self, cfg): # Loss parameters: deep_supervision = cfg.MODEL.MASK_FORMER.DEEP_SUPERVISION no_object_weight = cfg.MODEL.MASK_FORMER.NO_OBJECT_WEIGHT # loss weights class_weight = cfg.MODEL.MASK_FORMER.CLASS_WEIGHT dice_weight = cfg.MODEL.MASK_FORMER.DICE_WEIGHT mask_weight = cfg.MODEL.MASK_FORMER.MASK_WEIGHT # boundary_weight = cfg.MODEL.MASK_FORMER.BOUNDARY_WEIGHT matcher = hungarian_matcher_PSALM( cost_class=class_weight, cost_mask=mask_weight, cost_dice=dice_weight, num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS, ) weight_dict = {"loss_SEG_class": class_weight, "loss_class_name_class": class_weight, "loss_mask": mask_weight, "loss_dice": dice_weight, "loss_region_class": class_weight} self.weight_dict = weight_dict if deep_supervision: dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS aux_weight_dict = {} for i in range(dec_layers - 1): aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) weight_dict.update(aux_weight_dict) losses = ["SEG_labels", "class_name_labels", "masks", "region_labels"] self.criterion = PSALM_criterion( matcher=matcher, losses=losses, num_points=cfg.MODEL.MASK_FORMER.TRAIN_NUM_POINTS, oversample_ratio=cfg.MODEL.MASK_FORMER.OVERSAMPLE_RATIO, importance_sample_ratio=cfg.MODEL.MASK_FORMER.IMPORTANCE_SAMPLE_RATIO, device=self.device ) self.size_divisibility = 32 if cfg.MODEL.MASK_FORMER.SEG_TASK == 'semantic': self.semantic_on = True self.instance_on = False self.panoptic_on = False self.referring_on = False self.region_on = False elif cfg.MODEL.MASK_FORMER.SEG_TASK == 'instance': self.semantic_on = False self.instance_on = True self.panoptic_on = False self.referring_on = False self.region_on = False elif cfg.MODEL.MASK_FORMER.SEG_TASK == 'panoptic': self.semantic_on = True self.instance_on = True self.panoptic_on = True self.referring_on = False self.region_on = False elif cfg.MODEL.MASK_FORMER.SEG_TASK == 'referring': self.semantic_on = False self.instance_on = False self.panoptic_on = False self.referring_on = True self.region_on = False elif cfg.MODEL.MASK_FORMER.SEG_TASK == 'region': self.semantic_on = False self.instance_on = False self.panoptic_on = False self.referring_on = False self.region_on = True else: raise NotImplementedError self.sem_seg_postprocess_before_inference = self.instance_on or self.panoptic_on or self.referring_on or self.region_on def get_region_embedding(self, hidden_states, region_embedding_masks): region_embedding_list = [] for sample_hidden_states, sample_region_embedding_masks in zip(hidden_states, region_embedding_masks): sample_region_embedding = sample_hidden_states[sample_region_embedding_masks.bool()] region_embedding_list.append(sample_region_embedding) return region_embedding_list def SEG_instance_inference(self, SEG_cls, mask_pred): # mask_pred has already been post-processed to the original input resolution image_size = mask_pred.shape[-2:] scores = F.sigmoid(SEG_cls) scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False) mask_pred = mask_pred[topk_indices] result = Instances(image_size) result.pred_masks = (mask_pred > 0).float() result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4)) mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / ( result.pred_masks.flatten(1).sum(1) + 1e-6) result.scores = scores_per_image * mask_scores_per_image return result def class_name_panoptic_inference(self, SEG_cls, class_name_cls, mask_pred): scores, labels = F.softmax(class_name_cls, dim=-1).max(-1)
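# Panoptic inference overview (describes the code below, with the thresholds it hard-codes): the last softmax channel is the "no object" class, so queries whose argmax lands there or whose confidence is <= object_mask_threshold (0.8) are discarded; overlaps between the surviving masks are resolved by a per-pixel argmax over score-weighted mask probabilities; segments that keep less than overlap_threshold (0.8) of their original (>= 0.5 probability) area are dropped; and "stuff" classes are merged into a single segment id via stuff_memory_list.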
num_classes = class_name_cls.shape[-1] - 1 mask_pred = mask_pred.sigmoid() object_mask_threshold = 0.8 overlap_threshold = 0.8 keep = labels.ne(num_classes) & (scores > object_mask_threshold) cur_scores = scores[keep] cur_classes = labels[keep] cur_masks = mask_pred[keep] cur_mask_cls = class_name_cls[keep] cur_mask_cls = cur_mask_cls[:, :-1] cur_prob_masks = cur_scores.view(-1, 1, 1) * cur_masks h, w = cur_masks.shape[-2:] panoptic_seg = torch.zeros((h, w), dtype=torch.int32, device=cur_masks.device) segments_info = [] current_segment_id = 0 if cur_masks.shape[0] == 0: return panoptic_seg, segments_info else: # take argmax cur_mask_ids = cur_prob_masks.argmax(0) stuff_memory_list = {} for k in range(cur_classes.shape[0]): pred_class = cur_classes[k].item() isthing = self.is_thing_list[pred_class] mask_area = (cur_mask_ids == k).sum().item() original_area = (cur_masks[k] >= 0.5).sum().item() mask = (cur_mask_ids == k) & (cur_masks[k] >= 0.5) if mask_area > 0 and original_area > 0 and mask.sum().item() > 0: if mask_area / original_area < overlap_threshold: continue # merge stuff regions if not isthing: if int(pred_class) in stuff_memory_list.keys(): panoptic_seg[mask] = stuff_memory_list[int(pred_class)] continue else: stuff_memory_list[int(pred_class)] = current_segment_id + 1 current_segment_id += 1 panoptic_seg[mask] = current_segment_id segments_info.append( { "id": current_segment_id, "isthing": bool(isthing), "category_id": int(pred_class), } ) return panoptic_seg, segments_info def region_inference(self, region_cls, mask_pred): image_size = mask_pred.shape[-2:] scores = F.sigmoid(region_cls) result = Instances(image_size) result.pred_masks = (mask_pred > 0).float() result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4)) mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / ( result.pred_masks.flatten(1).sum(1) + 1e-6) result.scores = (scores * mask_scores_per_image[None,...].repeat(scores.shape[0],1)).transpose(1,0) return result def class_name_semantic_inference(self, SEG_cls, class_name_cls, mask_pred): mask_cls = F.softmax(class_name_cls, dim=-1)[:, :-1] mask_pred = mask_pred.sigmoid() semseg = torch.einsum("qc,qhw->chw", mask_cls, mask_pred) return semseg def class_name_instance_inference(self, SEG_cls, class_name_cls, mask_pred): image_size = mask_pred.shape[-2:] cls_scores = F.softmax(class_name_cls, dim=-1)[:, :-1] scores = cls_scores num_classes = scores.shape[-1] labels = torch.arange(num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1) scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False) # scores_per_image, topk_indices = scores.flatten(0, 1).topk(5000, sorted=False) labels_per_image = labels[topk_indices] topk_indices = topk_indices // num_classes mask_pred = mask_pred[topk_indices] # if this is panoptic segmentation, we only keep the "thing" classes if self.panoptic_on: keep = torch.zeros_like(scores_per_image).bool() for i, lab in enumerate(labels_per_image): keep[i] = self.is_thing_list[lab] scores_per_image = scores_per_image[keep] labels_per_image = labels_per_image[keep] mask_pred = mask_pred[keep] result = Instances(image_size) # mask (before sigmoid) result.pred_masks = (mask_pred > 0).float() result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4)) # Uncomment the following to get boxes from masks (this is slow) # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes() # calculate average mask prob mask_scores_per_image = 
(mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / ( result.pred_masks.flatten(1).sum(1) + 1e-6) result.scores = scores_per_image * mask_scores_per_image result.pred_classes = labels_per_image return result def encode_images(self, images): image_features = self.get_model().get_vision_tower()(images) image_features = self.get_model().mm_projector(image_features[-1]) return image_features def predictor_init(self, cfg): in_channels = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM hidden_dim = cfg.MODEL.MASK_FORMER.HIDDEN_DIM num_queries = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES nheads = cfg.MODEL.MASK_FORMER.NHEADS dim_feedforward = cfg.MODEL.MASK_FORMER.DIM_FEEDFORWARD dec_layers = cfg.MODEL.MASK_FORMER.DEC_LAYERS - 1 pre_norm = cfg.MODEL.MASK_FORMER.PRE_NORM mask_dim = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM enforce_input_project = False seg_norm = cfg.MODEL.MASK_FORMER.SEG_NORM seg_proj = cfg.MODEL.MASK_FORMER.SEG_PROJ seg_fuse_score = cfg.MODEL.MASK_FORMER.FUSE_SCORE seg_concat = False print(f'current seg concat mode: {seg_concat}, seg_norm: {seg_norm}, seg_proj: {seg_proj}, seg_fuse_score: {seg_fuse_score}') predictor = MultiScaleMaskedTransformerDecoder(in_channels, hidden_dim, num_queries, nheads, dim_feedforward, dec_layers, pre_norm, mask_dim, enforce_input_project, seg_norm, seg_concat, seg_proj, seg_fuse_score) return predictor def get_model(self): return self.model def output_shape(self): out_features = self.mask_decoder_cfg.MODEL.SWIN.OUT_FEATURES out_feature_strides = { "res2": 4, "res3": 8, "res4": 16, "res5": 32, } num_features = [int(self.mask_decoder_cfg.MODEL.SWIN.EMBED_DIM * 2 ** i) for i in range(len(self.mask_decoder_cfg.MODEL.SWIN.DEPTHS))] out_feature_channels = { "res2": num_features[0], "res3": num_features[1], "res4": num_features[2], "res5": num_features[3], } backbone_feature_shape = dict() for name in out_features: backbone_feature_shape[name] = Dict( {'channel': out_feature_channels[name], 'stride': out_feature_strides[name]}) return backbone_feature_shape def get_encoder_image(self, images): encode_image_features = self.get_model().get_vision_tower()(images) return encode_image_features def pixel_decoder_init(self, cfg, input_shape): common_stride = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE transformer_dropout = cfg.MODEL.MASK_FORMER.DROPOUT transformer_nheads = cfg.MODEL.MASK_FORMER.NHEADS transformer_dim_feedforward = 1024 transformer_enc_layers = cfg.MODEL.SEM_SEG_HEAD.TRANSFORMER_ENC_LAYERS conv_dim = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM mask_dim = cfg.MODEL.SEM_SEG_HEAD.MASK_DIM transformer_in_features = cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES # ["res3", "res4", "res5"] pixel_decoder = MSDeformAttnPixelDecoder(input_shape, transformer_dropout, transformer_nheads, transformer_dim_feedforward, transformer_enc_layers, conv_dim, mask_dim, transformer_in_features, common_stride) return pixel_decoder def prepare_targets(self, targets, images): h_pad, w_pad = images.shape[-2:] new_targets = [] for targets_per_image in targets: # pad gt gt_masks = targets_per_image.gt_masks padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device) padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks new_targets.append( { "labels": targets_per_image.gt_classes, "masks": padded_masks, } ) return new_targets def get_special_token(self, SEG, EOS): self.SEG_id = SEG self.EOS_id = EOS def get_class_name_embedding(self, hidden_states, cls_token_indices): class_name_embedding_list = [] for 
current_hidden_state, current_token_indice in zip(hidden_states, cls_token_indices):
            class_id = torch.unique(current_token_indice)
            class_id = class_id[class_id != 0]
            current_class_name_embedding_list = []
            for id in class_id:
                current_class_mask = (current_token_indice == id)
                current_class_state = current_hidden_state[current_class_mask]
                current_class_name_embedding_list.append(current_class_state)
            current_pool_class_name_embedding = [self.class_name_pooling(class_name.transpose(-2, -1)).transpose(-2, -1)
                                                 for class_name in current_class_name_embedding_list]
            class_name_embedding_list.append(torch.cat(current_pool_class_name_embedding, dim=0))
        return torch.stack(class_name_embedding_list, dim=0)

    def embed_class_ids(self, class_name_ids, cls_indices):
        if class_name_ids is None:
            return None
        num_class = cls_indices.unique_consecutive()
        num_class = num_class[num_class >= 0]
        class_name_ids = [class_name_ids[cls_indices == idx] for idx in num_class]
        embedded_class_name = [self.get_model().embed_tokens(id) for id in class_name_ids]
        return embedded_class_name

    def embed_refer_ids(self, refer_ids):
        if refer_ids is None:
            return None
        embedded_refer = self.get_model().embed_tokens(refer_ids)
        return embedded_refer

    def concat_image_seg_cls_embeds(self, input_id, img_feature, label, seg_query, seg_query_mask, class_embed,
                                    class_name_embedding_indices, region_embedding_mask=None, region_feature_list=None,
                                    refer_embedding_indices=None, refer_embedding=None):
        image_token_indices = torch.where(input_id == IMAGE_TOKEN_INDEX)[0]
        seg_query_indices = torch.where(input_id == SEG_TOKEN_INDEX)[0]
        cls_token_indices = torch.where(input_id == CLS_TOKEN_INDEX)[0]
        region_token_indices = torch.where(input_id == REGION_TOKEN_INDEX)[0]
        assert len(image_token_indices) == 1, 'not supporting multi image index'
        assert len(seg_query_indices) == 1, 'not supporting multi seg index'
        if class_name_embedding_indices is not None:
            assert len(cls_token_indices) == len(class_embed), 'the number of tokens and class_embed need to be the same'
        if region_feature_list is not None:
            assert len(region_feature_list) == len(region_token_indices), 'the number of tokens and regions need to be the same'
        cur_new_input_embeds = []
        cur_new_seg_query_mask = []
        if label is not None:
            cur_new_label = []
            assert label.shape == input_id.shape
        else:
            cur_new_label = None
        cur_class_name_embedding_indices = [] if class_name_embedding_indices is not None else None
        cur_refer_embedding_indices = [] if refer_embedding_indices is not None else None
        if region_embedding_mask is not None:
            enable_region_mask = True
            cur_new_region_embedding_mask = []
        else:
            enable_region_mask = False
            cur_new_region_embedding_mask = None
        # split input_id into chunks of ordinary text tokens and single special placeholder tokens (negative indices)
        chunks = []
        current_chunk = []
        for id in input_id:
            if id >= 0:
                current_chunk.append(id.item())
            else:
                if current_chunk:
                    chunks.append(torch.tensor(current_chunk, device=input_id.device))
                    current_chunk = []
                chunks.append([id])
        if current_chunk:
            chunks.append(torch.tensor(current_chunk, device=input_id.device))
        cls_idx = 0
        region_idx = 0
        for chunk in chunks:
            chunk_len = len(chunk)
            if chunk_len == 1 and chunk[0] == IMAGE_TOKEN_INDEX:
                # replace the image placeholder token with the projected image features
                cur_new_input_embeds.append(img_feature)
                cur_new_seg_query_mask.append(torch.zeros(img_feature.shape[0]))
                if class_name_embedding_indices is not None:
                    cur_class_name_embedding_indices.append(
                        torch.full((img_feature.shape[0],), 0, device=input_id.device, dtype=input_id.dtype))
                if refer_embedding_indices is not None:
                    cur_refer_embedding_indices.append(
                        torch.full((img_feature.shape[0],), 0, device=input_id.device, dtype=input_id.dtype))
                if
label is not None: cur_new_label.append( torch.full((img_feature.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype) ) if enable_region_mask: cur_new_region_embedding_mask.append(torch.zeros(img_feature.shape[0])) elif chunk_len == 1 and chunk[0] == SEG_TOKEN_INDEX: cur_new_input_embeds.append(seg_query) cur_new_seg_query_mask.append(torch.ones(seg_query.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append(torch.full((seg_query.shape[0],), 0, device=label.device, dtype=label.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append(torch.full((seg_query.shape[0],), 0, device=label.device, dtype=label.dtype)) if label is not None: cur_new_label.append( torch.full((seg_query.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype)) if enable_region_mask: cur_new_region_embedding_mask.append(torch.zeros(seg_query.shape[0])) elif chunk_len == 1 and chunk[0] == CLS_TOKEN_INDEX: cls_embed = class_embed[cls_idx] if len(cls_embed.shape) == 1: cls_embed = cls_embed.unsqueeze(0) cls_idx += 1 cur_new_input_embeds.append(cls_embed) cur_new_seg_query_mask.append(torch.zeros(cls_embed.shape[0])) if enable_region_mask: cur_new_region_embedding_mask.append(torch.zeros(cls_embed.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append( torch.full((cls_embed.shape[0],), cls_idx, device=input_id.device, dtype=input_id.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append( torch.full((cls_embed.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if label is not None: cur_new_label.append( torch.full((cls_embed.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype) ) elif chunk_len == 1 and chunk[0] == REGION_TOKEN_INDEX: region_feature = region_feature_list[region_idx] region_idx += 1 cur_new_input_embeds.append(region_feature) cur_new_seg_query_mask.append(torch.zeros(region_feature.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append( torch.full((region_feature.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append( torch.full((region_feature.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if label is not None: cur_new_label.append( torch.full((region_feature.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype) ) if enable_region_mask: cur_new_region_embedding_mask.append(torch.ones(region_feature.shape[0])) elif chunk_len == 1 and chunk[0] == REFER_TOKEN_INDEX: refer_embed = refer_embedding if len(refer_embed.shape) == 1: refer_embed = refer_embed.unsqueeze(0) cur_new_input_embeds.append(refer_embed) cur_new_seg_query_mask.append(torch.zeros(refer_embed.shape[0])) if enable_region_mask: cur_new_region_embedding_mask.append(torch.zeros(refer_embed.shape[0])) if class_name_embedding_indices is not None: cur_class_name_embedding_indices.append( torch.full((refer_embed.shape[0],), 0, device=input_id.device, dtype=input_id.dtype)) if refer_embedding_indices is not None: cur_refer_embedding_indices.append( torch.full((refer_embed.shape[0],), 1, device=input_id.device, dtype=input_id.dtype)) if label is not None: cur_new_label.append( torch.full((refer_embed.shape[0],), IGNORE_INDEX, device=label.device, dtype=label.dtype) ) else: cur_new_input_embeds.append(self.get_model().embed_tokens(input_id[:chunk_len])) cur_new_seg_query_mask.append(seg_query_mask[:chunk_len]) if 
class_name_embedding_indices is not None: cur_class_name_embedding_indices.append(class_name_embedding_indices[:chunk_len]) if refer_embedding_indices is not None: cur_refer_embedding_indices.append(refer_embedding_indices[:chunk_len]) if label is not None: cur_new_label.append(label[:chunk_len]) if enable_region_mask: cur_new_region_embedding_mask.append(region_embedding_mask[:chunk_len]) input_id = input_id[chunk_len:] seg_query_mask = seg_query_mask[chunk_len:] if class_name_embedding_indices is not None: class_name_embedding_indices = class_name_embedding_indices[chunk_len:] if refer_embedding_indices is not None: refer_embedding_indices = refer_embedding_indices[chunk_len:] if label is not None: label = label[chunk_len:] if enable_region_mask: region_embedding_mask = region_embedding_mask[chunk_len:] cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) if label is not None: cur_new_label = [x.to(device=self.device) for x in cur_new_label] cur_new_label = torch.cat(cur_new_label, dim=0) cur_new_seg_query_mask = [x.to(device=self.device) for x in cur_new_seg_query_mask] cur_new_seg_query_mask = torch.cat(cur_new_seg_query_mask, dim=0) if class_name_embedding_indices is not None: cur_class_name_embedding_indices = [x.to(device=self.device) for x in cur_class_name_embedding_indices] cur_class_name_embedding_indices = torch.cat(cur_class_name_embedding_indices, dim=0) if refer_embedding_indices is not None: cur_refer_embedding_indices = [x.to(device=self.device) for x in cur_refer_embedding_indices] cur_refer_embedding_indices = torch.cat(cur_refer_embedding_indices, dim=0) if enable_region_mask: cur_new_region_embedding_mask = [x.to(device=self.device) for x in cur_new_region_embedding_mask] cur_new_region_embedding_mask = torch.cat(cur_new_region_embedding_mask, dim=0) return cur_new_input_embeds, cur_new_label, cur_new_seg_query_mask, cur_class_name_embedding_indices, cur_new_region_embedding_mask, cur_refer_embedding_indices def prepare_inputs_labels_for_multimodal( self, input_ids, attention_mask, past_key_values, labels, images, vp_images=None, class_name_embedding_indices=None, class_name_ids=None, cls_indices=None, instances=None, token_refer_id=None, refer_embedding_indices=None ): vision_tower = self.get_vision_tower() seg_query_mask = torch.zeros_like(input_ids) if vision_tower is None or images is None or input_ids.shape[1] == 1: if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[ 1] == 1: attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) return input_ids, attention_mask, past_key_values, None, labels, seg_query_mask if type(images) is list or images.ndim == 5: concat_images = torch.cat([image for image in images], dim=0) image_features = self.encode_images(concat_images) split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) image_features = [x.flatten(0, 1) for x in image_features] else: image_features = self.encode_images(images) expanded_seg_query = self.seg_query.unsqueeze(0).expand(input_ids.shape[0], -1, -1) if (input_ids == REGION_TOKEN_INDEX).sum() != 0 and instances is not None: region_masks_list = [instance.vp_region_masks.tensor for instance in instances] vp_image_features = self.encode_images(vp_images) # [region_features_per_batch: [num_region, 1, dims]], 
len(region_features) = batch_size region_features = self.region_sampler(vp_image_features, region_masks_list, original_dtype=vp_image_features.dtype, return_dtype=vp_image_features.dtype) region_embedding_masks = torch.zeros_like(input_ids) else: region_features = None region_embedding_masks = None new_input_embeds = [] new_labels = [] if labels is not None else None new_seg_query_masks = [] new_class_name_embedding_indices = [] if class_name_embedding_indices is not None else None new_refer_embedding_indices = [] if refer_embedding_indices is not None else None new_region_embedding_masks = [] if region_features is not None else None for batch_idx, cur_input_ids in enumerate(input_ids): cur_seg_query_mask = seg_query_mask[batch_idx] cur_seg_query = expanded_seg_query[batch_idx] cur_image_feature = image_features[batch_idx] cur_class_name_embedding_indices = class_name_embedding_indices[ batch_idx] if class_name_embedding_indices is not None else None cur_refer_embedding_indices = refer_embedding_indices[ batch_idx] if refer_embedding_indices is not None else None cur_region_feature_list = region_features[batch_idx] if region_features is not None else None cur_region_embedding_mask = region_embedding_masks[batch_idx] if region_features is not None else None if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # multimodal LLM, but the current sample is not multimodal cur_input_embeds = self.get_model().embed_tokens(cur_input_ids) # ensure gradients back propagation, not changing cur_input_embeds cur_input_embeds = cur_input_embeds + ( 0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum() new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(labels[batch_idx]) new_seg_query_masks.append(cur_seg_query_mask) # cur_image_idx += 1 continue if labels is not None: cur_label = labels[batch_idx] else: cur_label = None if class_name_ids is not None: cur_class_name_ids = class_name_ids[batch_idx] cur_cls_indices = cls_indices[batch_idx] else: cur_class_name_ids = None cur_cls_indices = None if token_refer_id is not None: cur_token_refer_id = token_refer_id[batch_idx] else: cur_token_refer_id = None cur_class_name_embedding = self.embed_class_ids(cur_class_name_ids, cur_cls_indices) cur_refer_embedding = self.embed_refer_ids(cur_token_refer_id) cur_input_embeds, cur_label, cur_seg_query_mask, cur_class_name_embedding_indices, cur_region_embedding_mask, cur_refer_embedding_indices = self.concat_image_seg_cls_embeds( input_id=cur_input_ids, img_feature=cur_image_feature, label=cur_label, seg_query=cur_seg_query, seg_query_mask=cur_seg_query_mask, class_embed=cur_class_name_embedding, class_name_embedding_indices=cur_class_name_embedding_indices, region_embedding_mask=cur_region_embedding_mask, region_feature_list=cur_region_feature_list, refer_embedding_indices=cur_refer_embedding_indices, refer_embedding=cur_refer_embedding ) assert cur_input_embeds.shape[0] == cur_seg_query_mask.shape[0] new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(cur_label) new_seg_query_masks.append(cur_seg_query_mask) if class_name_embedding_indices is not None: new_class_name_embedding_indices.append(cur_class_name_embedding_indices) if refer_embedding_indices is not None: new_refer_embedding_indices.append(cur_refer_embedding_indices) if new_region_embedding_masks is not None: new_region_embedding_masks.append(cur_region_embedding_mask) if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): max_len = max(x.shape[0] for x in 
new_input_embeds) new_input_embeds_align = [] for cur_new_embed in new_input_embeds: cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) new_input_embeds_align.append(cur_new_embed) new_input_embeds = torch.stack(new_input_embeds_align, dim=0) if labels is not None: new_labels_align = [] _new_labels = new_labels for cur_new_label in new_labels: cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0) new_labels_align.append(cur_new_label) new_labels = torch.stack(new_labels_align, dim=0) new_seg_query_masks_align = [] for new_seg_query_mask in new_seg_query_masks: new_seg_query_mask = torch.cat( (new_seg_query_mask, torch.zeros((max_len - new_seg_query_mask.shape[0]), dtype=new_seg_query_mask.dtype, device=new_seg_query_mask.device)), dim=0) new_seg_query_masks_align.append(new_seg_query_mask) new_seg_query_masks = torch.stack(new_seg_query_masks_align, dim=0) new_class_name_embedding_indices_align = [] if class_name_embedding_indices is not None: for new_class_name_embedding_indice in new_class_name_embedding_indices: new_class_name_embedding_indice = torch.cat( (new_class_name_embedding_indice, torch.zeros((max_len - new_class_name_embedding_indice.shape[0]), dtype=new_class_name_embedding_indice.dtype, device=new_class_name_embedding_indice.device)), dim=0) new_class_name_embedding_indices_align.append(new_class_name_embedding_indice) new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices_align, dim=0) if refer_embedding_indices is not None: new_refer_embedding_indices_align = [] for new_refer_embedding_indice in new_refer_embedding_indices: new_refer_embedding_indice = torch.cat( (new_refer_embedding_indice, torch.zeros((max_len - new_refer_embedding_indice.shape[0]), dtype=new_refer_embedding_indice.dtype, device=new_refer_embedding_indice.device)), dim=0) new_refer_embedding_indices_align.append(new_refer_embedding_indice) new_refer_embedding_indices = torch.stack(new_refer_embedding_indices_align, dim=0) if new_region_embedding_masks is not None: new_region_embedding_masks_align = [] for new_region_embedding_mask in new_region_embedding_masks: new_region_embedding_mask = torch.cat( (new_region_embedding_mask, torch.zeros((max_len - new_region_embedding_mask.shape[0]), dtype=new_region_embedding_mask.dtype, device=new_region_embedding_mask.device)), dim=0) new_region_embedding_masks_align.append(new_region_embedding_mask) new_region_embedding_masks = torch.stack(new_region_embedding_masks_align, dim=0) if attention_mask is not None: new_attention_mask = [] for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) cur_new_attention_mask = torch.cat( (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0) new_attention_mask.append(cur_new_attention_mask) attention_mask = torch.stack(new_attention_mask, dim=0) assert attention_mask.shape == new_labels.shape else: new_input_embeds = torch.stack(new_input_embeds, dim=0) if labels is not None: new_labels = 
torch.stack(new_labels, dim=0) new_seg_query_masks = torch.stack(new_seg_query_masks, dim=0) if class_name_embedding_indices is not None: new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices, dim=0) if refer_embedding_indices is not None: new_refer_embedding_indices = torch.stack(new_refer_embedding_indices, dim=0) if new_region_embedding_masks is not None: new_region_embedding_masks = torch.stack(new_region_embedding_masks, dim=0) if attention_mask is not None: new_attn_mask_pad_left = torch.full( (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1) assert attention_mask.shape == new_input_embeds.shape[:2] return None, attention_mask, past_key_values, new_input_embeds, new_labels, new_seg_query_masks, new_class_name_embedding_indices, new_region_embedding_masks, new_refer_embedding_indices def get_SEG_embedding(self,hidden_states, refer_embedding_indices): refer_embedding_list = [] for current_hidden_state, current_token_indice in zip(hidden_states, refer_embedding_indices): current_refer_state = current_hidden_state[current_token_indice.bool()] current_pool_refer_state = self.refer_pooling(current_refer_state.transpose(-2, -1)).transpose(-2, -1) refer_embedding_list.append(current_pool_refer_state) return torch.stack(refer_embedding_list, dim=0) def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, vp_images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, seg_info=None, class_name_ids=None, class_name_embedding_indices=None, cls_indices=None, random_idx=None, token_refer_id=None, refer_embedding_indices=None, dataset_type=None, ) -> Union[Tuple, CausalLMOutputWithPast]: if dataset_type is not None: assert all(item == dataset_type[0] for item in dataset_type), f'this batch contain different dataset_type: {dataset_type}' batch_dataset_type = dataset_type[0] #print(batch_dataset_type) else: batch_dataset_type = [] output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids == SEG_TOKEN_INDEX).sum() != 0: if (input_ids == REGION_TOKEN_INDEX).sum() != 0: instances = [i['instances'] for i in seg_info] else: instances = None input_ids, attention_mask, past_key_values, inputs_embeds, labels, seg_query_mask, class_name_embedding_indices, region_embedding_masks, refer_embedding_indices = self.prepare_inputs_labels_for_multimodal( input_ids, attention_mask, past_key_values, labels, images, vp_images, class_name_embedding_indices, class_name_ids, cls_indices, instances, token_refer_id, refer_embedding_indices) else: seg_query_mask = None class_name_embedding_indices = None region_embedding_masks = None SEG_token_indices = None input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.mm_conv_prepare_inputs_labels_for_multimodal( input_ids, 
attention_mask, past_key_values, labels, images) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) hidden_states = outputs.last_hidden_state logits = self.lm_head(hidden_states) if class_name_embedding_indices is not None: class_name_embedding = self.get_class_name_embedding(hidden_states, class_name_embedding_indices) class_name_embedding = self.class_name_projector(class_name_embedding) else: class_name_embedding = None if class_name_embedding is not None: class_name_embedding = torch.gather(class_name_embedding,dim=1,index=random_idx.unsqueeze(-1).repeat(1, 1, class_name_embedding.shape[-1])) if region_embedding_masks is not None: region_embedding_list = self.get_region_embedding(hidden_states, region_embedding_masks) region_embedding_list = [self.region_projector(region_embedding) for region_embedding in region_embedding_list] else: region_embedding_list = None if 'referring' in batch_dataset_type or 'region' in batch_dataset_type: class_name_embedding = None loss = None if labels is not None and seg_query_mask is None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() shift_logits = shift_logits.view(-1, self.config.vocab_size) shift_labels = shift_labels.view(-1) # Enable model/pipeline parallelism shift_labels = shift_labels.to(shift_logits.device) llm_loss = loss_fct(shift_logits, shift_labels) if seg_query_mask is not None: seg_query = self.get_seg_query(hidden_states, seg_query_mask) seg_query = self.seg_query_projector(seg_query) image_features = self.get_vision_tower_feature(images) mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features( image_features) if refer_embedding_indices is not None: SEG_embedding = self.get_SEG_embedding(hidden_states, refer_embedding_indices) SEG_embedding = self.SEG_token_projector(SEG_embedding) else: SEG_embedding = None if 'panoptic' in batch_dataset_type or 'region' in batch_dataset_type: SEG_embedding = None mask_outputs = self.predictor(multi_scale_features, mask_features, None, seg_query, SEG_embedding, class_name_embedding, region_embedding_list) if seg_info is not None: if "instances" in seg_info[0]: gt_instances = [x["instances"].to(self.device) for x in seg_info] # images = ImageList.from_tensors(images, self.size_divisibility) targets = self.prepare_targets(gt_instances, images) else: targets = None mask_losses = self.criterion(mask_outputs, targets) weight_dict = self.weight_dict loss_mask = 0.0 loss_dice = 0.0 loss_SEG_class = 0.0 loss_class_name_class = 0.0 loss_region_class = 0.0 for k in list(mask_losses.keys()): if k in weight_dict: if mask_losses[k] is not None: mask_losses[k] *= weight_dict[k] if '_SEG' in k and mask_losses[k] is not None: loss_SEG_class += mask_losses[k] elif '_name' in k and mask_losses[k] is not None: loss_class_name_class += mask_losses[k] elif '_mask' in k: loss_mask += mask_losses[k] elif '_dice' in k: loss_dice += mask_losses[k] elif '_region' in k and mask_losses[k] is not None: loss_region_class += mask_losses[k] else: mask_losses.pop(k) mask_loss = loss_mask + loss_dice + loss_SEG_class + loss_class_name_class + loss_region_class if 
isinstance(loss_class_name_class, float): loss_class_name_class = torch.tensor(loss_class_name_class, device=mask_loss.device) if isinstance(loss_SEG_class, float): loss_SEG_class = torch.tensor(loss_SEG_class, device=mask_loss.device) if isinstance(loss_region_class, float): loss_region_class = torch.tensor(loss_region_class, device=mask_loss.device) llm = torch.tensor(0.0, device=mask_loss.device) if labels is not None: # loss = llm_loss + mask_loss loss = mask_loss return CausalOutputWithMask( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss_mask=loss_mask.detach(), loss_dice=loss_dice.detach(), loss_SEG_class=loss_SEG_class.detach(), loss_class_name_class=loss_class_name_class.detach(), loss_region_class=loss_region_class.detach(), loss_llm=llm.detach(), ) if labels is not None and seg_query_mask is None: loss_mask = torch.tensor(0.0, device=llm_loss.device) loss_dice = torch.tensor(0.0, device=llm_loss.device) loss_SEG_class = torch.tensor(0.0, device=llm_loss.device) loss_class_name_class = torch.tensor(0.0, device=llm_loss.device) loss_region_class = torch.tensor(0.0, device=llm_loss.device) loss = llm_loss else: return CausalOutputWithMask( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, ) return CausalOutputWithMask( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, loss_mask=loss_mask.detach(), loss_dice=loss_dice.detach(), loss_SEG_class=loss_SEG_class.detach(), loss_class_name_class=loss_class_name_class.detach(), loss_region_class=loss_region_class.detach(), loss_llm=llm_loss.detach(), ) def mm_conv_prepare_inputs_labels_for_multimodal( self, input_ids, attention_mask, past_key_values, labels, images ): vision_tower = self.get_vision_tower() if vision_tower is None or images is None or input_ids.shape[1] == 1: if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1: attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) return input_ids, attention_mask, past_key_values, None, labels if type(images) is list or images.ndim == 5: concat_images = torch.cat([image for image in images], dim=0) image_features = self.encode_images(concat_images) split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) image_features = [x.flatten(0, 1) for x in image_features] else: image_features = self.encode_images(images) new_input_embeds = [] new_labels = [] if labels is not None else None cur_image_idx = 0 for batch_idx, cur_input_ids in enumerate(input_ids): if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # multimodal LLM, but the current sample is not multimodal cur_input_embeds = self.get_model().embed_tokens(cur_input_ids) # ensure gradients back propagation, not changing cur_input_embeds cur_input_embeds = cur_input_embeds + (0. 
* self.get_model().mm_projector(vision_tower.dummy_feature)).sum() new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(labels[batch_idx]) cur_image_idx += 1 continue image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] cur_new_input_embeds = [] if labels is not None: cur_labels = labels[batch_idx] cur_new_labels = [] assert cur_labels.shape == cur_input_ids.shape # concat text and image embedding. prepare labels, IGNORE_INDEX for image tokens while image_token_indices.numel() > 0: cur_image_features = image_features[cur_image_idx] image_token_start = image_token_indices[0] if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach()) cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start])) cur_new_input_embeds.append(cur_image_features) cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2])) if labels is not None: cur_new_labels.append(cur_labels[:image_token_start]) cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) cur_new_labels.append(cur_labels[image_token_start:image_token_start+1]) cur_labels = cur_labels[image_token_start+2:] else: cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start])) cur_new_input_embeds.append(cur_image_features) if labels is not None: cur_new_labels.append(cur_labels[:image_token_start]) cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) cur_labels = cur_labels[image_token_start+1:] cur_image_idx += 1 if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): cur_input_ids = cur_input_ids[image_token_start+2:] else: cur_input_ids = cur_input_ids[image_token_start+1:] image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] if cur_input_ids.numel() > 0: if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach()) else: cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids)) if labels is not None: cur_new_labels.append(cur_labels) cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) new_input_embeds.append(cur_new_input_embeds) if labels is not None: cur_new_labels = torch.cat(cur_new_labels, dim=0) new_labels.append(cur_new_labels) # Align embedddings, labels, attn_mask from different sample into a batch if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): max_len = max(x.shape[0] for x in new_input_embeds) new_input_embeds_align = [] for cur_new_embed in new_input_embeds: cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) new_input_embeds_align.append(cur_new_embed) new_input_embeds = torch.stack(new_input_embeds_align, dim=0) if labels is not None: new_labels_align = [] _new_labels = new_labels for cur_new_label in new_labels: cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), 
IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0) new_labels_align.append(cur_new_label) new_labels = torch.stack(new_labels_align, dim=0) if attention_mask is not None: new_attention_mask = [] for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0) new_attention_mask.append(cur_new_attention_mask) attention_mask = torch.stack(new_attention_mask, dim=0) assert attention_mask.shape == new_labels.shape else: new_input_embeds = torch.stack(new_input_embeds, dim=0) if labels is not None: new_labels = torch.stack(new_labels, dim=0) if attention_mask is not None: new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1) assert attention_mask.shape == new_input_embeds.shape[:2] return None, attention_mask, past_key_values, new_input_embeds, new_labels def get_seg_query(self, hidden_states, seg_query_masks): seg_query_list = [] for sample_hidden_state, sample_query_mask in zip(hidden_states, seg_query_masks): if torch.sum(sample_query_mask) == 0: continue unique_query_value = torch.unique(sample_query_mask) unique_query_value = unique_query_value[unique_query_value != 0] for value in unique_query_value: current_query_mask = (sample_query_mask == value) current_query = sample_hidden_state[current_query_mask] seg_query_list.append(current_query) seg_query = torch.stack(seg_query_list, dim=0) return seg_query def eval_seg( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, vp_images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, seg_info=None, class_name_ids=None, class_name_embedding_indices=None, cls_indices=None, token_refer_id=None, refer_embedding_indices=None, is_thing_list=None ): if self.panoptic_on: assert is_thing_list is not None, 'is_thing_list need to be given' self.is_thing_list = is_thing_list output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids == REGION_TOKEN_INDEX).sum() != 0: instances = [i['instances'] for i in seg_info] else: instances = None input_ids, attention_mask, past_key_values, inputs_embeds, labels, seg_query_mask, class_name_embedding_indices, region_embedding_masks, refer_embedding_indices = self.prepare_inputs_labels_for_multimodal( input_ids, attention_mask, past_key_values, labels, images, vp_images, class_name_embedding_indices, 
class_name_ids, cls_indices, instances, token_refer_id, refer_embedding_indices) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) hidden_states = outputs.last_hidden_state seg_query = self.get_seg_query(hidden_states, seg_query_mask) seg_query = self.seg_query_projector(seg_query) image_features = self.get_vision_tower_feature(images) mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features( image_features) if refer_embedding_indices is not None: SEG_embedding = self.get_SEG_embedding(hidden_states, refer_embedding_indices) SEG_embedding = self.SEG_token_projector(SEG_embedding) else: SEG_embedding = None if class_name_embedding_indices is not None: class_name_embedding = self.get_class_name_embedding(hidden_states, class_name_embedding_indices) class_name_embedding = self.class_name_projector(class_name_embedding) else: class_name_embedding = None if region_embedding_masks is not None: region_embedding_list = self.get_region_embedding(hidden_states, region_embedding_masks) region_embedding_list = [self.region_projector(region_embedding) for region_embedding in region_embedding_list] else: region_embedding_list = None mask_outputs = self.predictor(multi_scale_features, mask_features, None, seg_query, SEG_embedding, class_name_embedding, region_embedding_list) SEG_cls_results = mask_outputs['pred_SEG_logits'] class_name_cls_results = mask_outputs['pred_class_name_logits'] mask_pred_results = mask_outputs["pred_masks"] region_cls_results = mask_outputs['pred_region_logits'] images = [x for x in images] images = ImageList.from_tensors(images, self.size_divisibility) mask_pred_results = F.interpolate( mask_pred_results, size=(images.tensor.shape[-2], images.tensor.shape[-1]), mode="bilinear", align_corners=False, ) del mask_outputs processed_results = [] if SEG_cls_results is None: SEG_cls_results = [None] if class_name_cls_results is None: class_name_cls_results = [None] for _seg_info, SEG_cls_result, class_name_cls_result, mask_pred_result, input_per_image, image_size in zip( seg_info, SEG_cls_results, class_name_cls_results, mask_pred_results, seg_info, images.image_sizes ): height = input_per_image.get("height", image_size[0]) width = input_per_image.get("width", image_size[1]) padding_mask = input_per_image.get("padding_mask") non_padding_indices = np.where(~ np.array(padding_mask)) min_y, max_y = np.min(non_padding_indices[0]), np.max(non_padding_indices[0]) min_x, max_x = np.min(non_padding_indices[1]), np.max(non_padding_indices[1]) original_height = max_y - min_y + 1 original_width = max_x - min_x + 1 processed_results.append({}) # gt = _seg_info['instances'].gt_masks if self.sem_seg_postprocess_before_inference: mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( mask_pred_result, [original_height, original_width], height, width ) if SEG_cls_result is not None: SEG_cls_result = SEG_cls_result.to(mask_pred_result) if self.semantic_on: semantic_r = retry_if_cuda_oom(self.class_name_semantic_inference)(None, class_name_cls_result.float(), mask_pred_result.float()) if not self.sem_seg_postprocess_before_inference: semantic_r = retry_if_cuda_oom(sem_seg_postprocess)( semantic_r, [original_height, original_width], height, width ) processed_results[-1]["sem_seg"] = semantic_r if self.instance_on: instance_r = 
retry_if_cuda_oom(self.class_name_instance_inference)(None, class_name_cls_result.float(), mask_pred_result.float()) processed_results[-1]["instances"] = instance_r if self.panoptic_on: panoptic_r = retry_if_cuda_oom(self.class_name_panoptic_inference)(None, class_name_cls_result.float(), mask_pred_result.float()) processed_results[-1]["panoptic_seg"] = panoptic_r if self.referring_on: instance_r = retry_if_cuda_oom(self.SEG_instance_inference)(SEG_cls_result.float(), mask_pred_result.float()) processed_results[-1]["instances"] = instance_r if self.region_on: gt = _seg_info['instances'].gt_masks gt_result = retry_if_cuda_oom(sem_seg_postprocess)( gt, [original_height, original_width], height, width ) region_cls_results = region_cls_results[0].to(mask_pred_result) instance_r = retry_if_cuda_oom(self.region_inference)(region_cls_results.float(), mask_pred_result.float()) processed_results[-1]["instances"] = instance_r processed_results[-1]["gt"] = gt_result return processed_results class PSALMForDAVISEval(PSALM): def eval_seg( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, seg_info=None, class_name_ids=None, class_name_embedding_indices=None, cls_indices=None, token_refer_id=None, refer_embedding_indices=None, is_thing_list=None, vp_images=None ): if self.panoptic_on: assert is_thing_list is not None, 'is_thing_list need to be given' self.is_thing_list = is_thing_list output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids == REGION_TOKEN_INDEX).sum() != 0: instances = [i['instances'] for i in seg_info] else: instances = None input_ids, attention_mask, past_key_values, inputs_embeds, labels, seg_query_mask, class_name_embedding_indices, region_embedding_masks, refer_embedding_indices = self.prepare_inputs_labels_for_multimodal( input_ids, attention_mask, past_key_values, labels, images,vp_images, class_name_embedding_indices, class_name_ids, cls_indices, instances, token_refer_id, refer_embedding_indices) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) hidden_states = outputs.last_hidden_state seg_query = self.get_seg_query(hidden_states, seg_query_mask) seg_query = self.seg_query_projector(seg_query) image_features = self.get_vision_tower_feature(images) mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features( image_features) if refer_embedding_indices is not None: SEG_embedding = self.get_SEG_embedding(hidden_states, refer_embedding_indices) SEG_embedding = self.SEG_token_projector(SEG_embedding) else: SEG_embedding = None if class_name_embedding_indices is not None: class_name_embedding = 
self.get_class_name_embedding(hidden_states, class_name_embedding_indices) class_name_embedding = self.class_name_projector(class_name_embedding) else: class_name_embedding = None if region_embedding_masks is not None: region_embedding_list = self.get_region_embedding(hidden_states, region_embedding_masks) region_embedding_list = [self.region_projector(region_embedding) for region_embedding in region_embedding_list] else: region_embedding_list = None mask_outputs = self.predictor(multi_scale_features, mask_features, None, seg_query, SEG_embedding, class_name_embedding, region_embedding_list) SEG_cls_results = mask_outputs['pred_SEG_logits'] class_name_cls_results = mask_outputs['pred_class_name_logits'] mask_pred_results = mask_outputs["pred_masks"] region_cls_results = mask_outputs['pred_region_logits'] images = [x for x in images] images = ImageList.from_tensors(images, self.size_divisibility) mask_pred_results = F.interpolate( mask_pred_results, size=(images.tensor.shape[-2], images.tensor.shape[-1]), mode="bilinear", align_corners=False, ) del mask_outputs processed_results = [] if SEG_cls_results is None: SEG_cls_results = [None] if class_name_cls_results is None: class_name_cls_results = [None] for _seg_info, SEG_cls_result, class_name_cls_result, mask_pred_result, input_per_image, image_size in zip( seg_info, SEG_cls_results, class_name_cls_results, mask_pred_results, seg_info, images.image_sizes ): height = input_per_image.get("height", image_size[0]) width = input_per_image.get("width", image_size[1]) padding_mask = input_per_image.get("padding_mask") non_padding_indices = np.where(~ np.array(padding_mask)) min_y, max_y = np.min(non_padding_indices[0]), np.max(non_padding_indices[0]) min_x, max_x = np.min(non_padding_indices[1]), np.max(non_padding_indices[1]) original_height = max_y - min_y + 1 original_width = max_x - min_x + 1 processed_results.append({}) # gt = _seg_info['instances'].gt_masks if self.sem_seg_postprocess_before_inference: mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( mask_pred_result, [original_height, original_width], height, width ) # gt_result = retry_if_cuda_oom(sem_seg_postprocess)( # gt, [original_height, original_width], height, width # ) if SEG_cls_result is not None: SEG_cls_result = SEG_cls_result.to(mask_pred_result) if self.semantic_on: semantic_r = retry_if_cuda_oom(self.class_name_semantic_inference)(None, class_name_cls_result.float(), mask_pred_result.float()) if not self.sem_seg_postprocess_before_inference: semantic_r = retry_if_cuda_oom(sem_seg_postprocess)( semantic_r, [original_height, original_width], height, width ) processed_results[-1]["sem_seg"] = semantic_r if self.instance_on: instance_r = retry_if_cuda_oom(self.class_name_instance_inference)(None, class_name_cls_result.float(), mask_pred_result.float()) processed_results[-1]["instances"] = instance_r if self.panoptic_on: panoptic_r = retry_if_cuda_oom(self.class_name_panoptic_inference)(None, class_name_cls_result.float(), mask_pred_result.float()) processed_results[-1]["panoptic_seg"] = panoptic_r if self.referring_on: instance_r = retry_if_cuda_oom(self.SEG_instance_inference)(SEG_cls_result.float(), mask_pred_result.float()) processed_results[-1]["instances"] = instance_r if self.region_on: gt = _seg_info['instances'].gt_masks gt_result = retry_if_cuda_oom(sem_seg_postprocess)( gt, [original_height, original_width], height, width ) region_cls_results = region_cls_results[0].to(mask_pred_result) instance_r = 
retry_if_cuda_oom(self.region_inference)(region_cls_results.float(), mask_pred_result.float()) processed_results[-1]["instances"] = instance_r processed_results[-1]["gt"] = gt_result return processed_results def prepare_inputs_labels_for_multimodal( self, input_ids, attention_mask, past_key_values, labels, images, vp_images=None, class_name_embedding_indices=None, class_name_ids=None, cls_indices=None, instances=None, token_refer_id=None, refer_embedding_indices=None ): vision_tower = self.get_vision_tower() seg_query_mask = torch.zeros_like(input_ids) if vision_tower is None or images is None or input_ids.shape[1] == 1: if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[ 1] == 1: attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) return input_ids, attention_mask, past_key_values, None, labels, seg_query_mask if type(images) is list or images.ndim == 5: concat_images = torch.cat([image for image in images], dim=0) image_features = self.encode_images(concat_images) split_sizes = [image.shape[0] for image in images] image_features = torch.split(image_features, split_sizes, dim=0) image_features = [x.flatten(0, 1) for x in image_features] else: image_features = self.encode_images(images) expanded_seg_query = self.seg_query.unsqueeze(0).expand(input_ids.shape[0], -1, -1) if (input_ids == REGION_TOKEN_INDEX).sum() != 0 and instances is not None: region_masks_list = [instance.vp_region_masks.tensor for instance in instances] vp_image_features = self.encode_images(vp_images) # [region_features_per_batch: [num_region, 1, dims]], len(region_features) = batch_size region_features = self.region_sampler(vp_image_features, region_masks_list, original_dtype=vp_image_features.dtype, return_dtype=vp_image_features.dtype) region_embedding_masks = torch.zeros_like(input_ids) else: region_features = None region_embedding_masks = None new_input_embeds = [] new_labels = [] if labels is not None else None new_seg_query_masks = [] new_class_name_embedding_indices = [] if class_name_embedding_indices is not None else None new_refer_embedding_indices = [] if refer_embedding_indices is not None else None new_region_embedding_masks = [] if region_features is not None else None for batch_idx, cur_input_ids in enumerate(input_ids): cur_seg_query_mask = seg_query_mask[batch_idx] cur_seg_query = expanded_seg_query[batch_idx] cur_image_feature = image_features[batch_idx] cur_class_name_embedding_indices = class_name_embedding_indices[batch_idx] if class_name_embedding_indices is not None else None cur_refer_embedding_indices = refer_embedding_indices[batch_idx] if refer_embedding_indices is not None else None cur_region_feature_list = region_features[batch_idx] if region_features is not None else None cur_region_embedding_mask = region_embedding_masks[batch_idx] if region_features is not None else None if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # multimodal LLM, but the current sample is not multimodal cur_input_embeds = self.get_model().embed_tokens(cur_input_ids) # ensure gradients back propagation, not changing cur_input_embeds cur_input_embeds = cur_input_embeds + ( 0. 
* self.get_model().mm_projector(vision_tower.dummy_feature)).sum() new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(labels[batch_idx]) new_seg_query_masks.append(cur_seg_query_mask) # cur_image_idx += 1 continue if labels is not None: cur_label = labels[batch_idx] else: cur_label = None if class_name_ids is not None: cur_class_name_ids = class_name_ids[batch_idx] cur_cls_indices = cls_indices[batch_idx] else: cur_class_name_ids = None cur_cls_indices = None if token_refer_id is not None: cur_token_refer_id = token_refer_id[batch_idx] else: cur_token_refer_id = None cur_class_name_embedding = self.embed_class_ids(cur_class_name_ids, cur_cls_indices) cur_refer_embedding = self.embed_refer_ids(cur_token_refer_id) cur_input_embeds, cur_label, cur_seg_query_mask, cur_class_name_embedding_indices, cur_region_embedding_mask, cur_refer_embedding_indices = self.concat_image_seg_cls_embeds( input_id=cur_input_ids, img_feature=cur_image_feature, label=cur_label, seg_query=cur_seg_query, seg_query_mask=cur_seg_query_mask, class_embed=cur_class_name_embedding, class_name_embedding_indices=cur_class_name_embedding_indices, region_embedding_mask=cur_region_embedding_mask, region_feature_list=cur_region_feature_list, refer_embedding_indices=cur_refer_embedding_indices, refer_embedding=cur_refer_embedding ) assert cur_input_embeds.shape[0] == cur_seg_query_mask.shape[0] new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(cur_label) new_seg_query_masks.append(cur_seg_query_mask) if class_name_embedding_indices is not None: new_class_name_embedding_indices.append(cur_class_name_embedding_indices) if refer_embedding_indices is not None: new_refer_embedding_indices.append(cur_refer_embedding_indices) if new_region_embedding_masks is not None: new_region_embedding_masks.append(cur_region_embedding_mask) if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): max_len = max(x.shape[0] for x in new_input_embeds) new_input_embeds_align = [] for cur_new_embed in new_input_embeds: cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) new_input_embeds_align.append(cur_new_embed) new_input_embeds = torch.stack(new_input_embeds_align, dim=0) if labels is not None: new_labels_align = [] _new_labels = new_labels for cur_new_label in new_labels: cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0) new_labels_align.append(cur_new_label) new_labels = torch.stack(new_labels_align, dim=0) new_seg_query_masks_align = [] for new_seg_query_mask in new_seg_query_masks: new_seg_query_mask = torch.cat( (new_seg_query_mask, torch.zeros((max_len - new_seg_query_mask.shape[0]),dtype=new_seg_query_mask.dtype, device=new_seg_query_mask.device)), dim=0) new_seg_query_masks_align.append(new_seg_query_mask) new_seg_query_masks = torch.stack(new_seg_query_masks_align, dim=0) new_class_name_embedding_indices_align = [] if class_name_embedding_indices is not None: for new_class_name_embedding_indice in new_class_name_embedding_indices: new_class_name_embedding_indice = torch.cat( (new_class_name_embedding_indice, torch.zeros((max_len - new_class_name_embedding_indice.shape[0]),dtype=new_class_name_embedding_indice.dtype, device=new_class_name_embedding_indice.device)), dim=0) 
new_class_name_embedding_indices_align.append(new_class_name_embedding_indice) new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices_align, dim=0) if refer_embedding_indices is not None: new_refer_embedding_indices_align = [] for new_refer_embedding_indice in new_refer_embedding_indices: new_refer_embedding_indice = torch.cat( (new_refer_embedding_indice, torch.zeros((max_len - new_refer_embedding_indice.shape[0]),dtype=new_refer_embedding_indice.dtype, device=new_refer_embedding_indice.device)), dim=0) new_refer_embedding_indices_align.append(new_refer_embedding_indice) new_refer_embedding_indices = torch.stack(new_refer_embedding_indices_align, dim=0) if new_region_embedding_masks is not None: new_region_embedding_masks_align = [] for new_region_embedding_mask in new_region_embedding_masks: new_region_embedding_mask = torch.cat( (new_region_embedding_mask, torch.zeros((max_len - new_region_embedding_mask.shape[0]),dtype=new_region_embedding_mask.dtype, device=new_region_embedding_mask.device)), dim=0) new_region_embedding_masks_align.append(new_region_embedding_mask) new_region_embedding_masks = torch.stack(new_region_embedding_masks_align, dim=0) if attention_mask is not None: new_attention_mask = [] for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) cur_new_attention_mask = torch.cat( (new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0) new_attention_mask.append(cur_new_attention_mask) attention_mask = torch.stack(new_attention_mask, dim=0) assert attention_mask.shape == new_labels.shape else: new_input_embeds = torch.stack(new_input_embeds, dim=0) if labels is not None: new_labels = torch.stack(new_labels, dim=0) new_seg_query_masks = torch.stack(new_seg_query_masks, dim=0) if class_name_embedding_indices is not None: new_class_name_embedding_indices = torch.stack(new_class_name_embedding_indices, dim=0) if refer_embedding_indices is not None: new_refer_embedding_indices = torch.stack(new_refer_embedding_indices, dim=0) if new_region_embedding_masks is not None: new_region_embedding_masks = torch.stack(new_region_embedding_masks, dim=0) if attention_mask is not None: new_attn_mask_pad_left = torch.full( (attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1) assert attention_mask.shape == new_input_embeds.shape[:2] return None, attention_mask, past_key_values, new_input_embeds, new_labels, new_seg_query_masks, new_class_name_embedding_indices, new_region_embedding_masks, new_refer_embedding_indices def eval_video( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, images: Optional[torch.FloatTensor] = None, vp_images: Optional[torch.FloatTensor] = None, return_dict: Optional[bool] = None, seg_info=None, 
class_name_ids=None, class_name_embedding_indices=None, cls_indices=None, token_refer_id=None, refer_embedding_indices=None, is_thing_list=None ): if self.panoptic_on: assert is_thing_list is not None, 'is_thing_list need to be given' self.is_thing_list = is_thing_list output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if (input_ids == REGION_TOKEN_INDEX).sum() != 0: instances = [i['instances'] for i in seg_info] else: instances = None input_ids, attention_mask, past_key_values, inputs_embeds, labels, seg_query_mask, class_name_embedding_indices, region_embedding_masks, refer_embedding_indices = self.prepare_inputs_labels_for_multimodal( input_ids, attention_mask, past_key_values, labels, images,vp_images, class_name_embedding_indices, class_name_ids, cls_indices, instances, token_refer_id, refer_embedding_indices) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict ) hidden_states = outputs.last_hidden_state seg_query = self.get_seg_query(hidden_states, seg_query_mask) seg_query = self.seg_query_projector(seg_query) image_features = self.get_vision_tower_feature(images) mask_features, transformer_encoder_features, multi_scale_features = self.pixel_decoder.forward_features( image_features) if refer_embedding_indices is not None: SEG_embedding = self.get_SEG_embedding(hidden_states, refer_embedding_indices) SEG_embedding = self.SEG_token_projector(SEG_embedding) else: SEG_embedding = None if class_name_embedding_indices is not None: class_name_embedding = self.get_class_name_embedding(hidden_states, class_name_embedding_indices) class_name_embedding = self.class_name_projector(class_name_embedding) else: class_name_embedding = None if region_embedding_masks is not None: region_embedding_list = self.get_region_embedding(hidden_states, region_embedding_masks) region_embedding_list = [self.region_projector(region_embedding) for region_embedding in region_embedding_list] else: region_embedding_list = None mask_outputs = self.predictor(multi_scale_features, mask_features, None, seg_query, SEG_embedding, class_name_embedding, region_embedding_list) SEG_cls_results = mask_outputs['pred_SEG_logits'] class_name_cls_results = mask_outputs['pred_class_name_logits'] mask_pred_results = mask_outputs["pred_masks"] region_cls_results = mask_outputs['pred_region_logits'] images = [x for x in images] images = ImageList.from_tensors(images, self.size_divisibility) mask_pred_results = F.interpolate( mask_pred_results, size=(images.tensor.shape[-2], images.tensor.shape[-1]), mode="bilinear", align_corners=False, ) del mask_outputs processed_results = [] if SEG_cls_results is None: SEG_cls_results = [None] if class_name_cls_results is None: class_name_cls_results = [None] for _seg_info, SEG_cls_result, class_name_cls_result, mask_pred_result, input_per_image, image_size in zip( seg_info, SEG_cls_results, class_name_cls_results, mask_pred_results, seg_info, images.image_sizes ): height = input_per_image.get("height", image_size[0]) width = input_per_image.get("width", 
image_size[1])
            padding_mask = input_per_image.get("padding_mask")
            non_padding_indices = np.where(~ np.array(padding_mask))
            min_y, max_y = np.min(non_padding_indices[0]), np.max(non_padding_indices[0])
            min_x, max_x = np.min(non_padding_indices[1]), np.max(non_padding_indices[1])
            original_height = max_y - min_y + 1
            original_width = max_x - min_x + 1
            processed_results.append({})
            if self.sem_seg_postprocess_before_inference:
                mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)(
                    mask_pred_result, [original_height, original_width], height, width
                )
                if SEG_cls_result is not None:
                    SEG_cls_result = SEG_cls_result.to(mask_pred_result)
            if self.semantic_on:
                semantic_r = retry_if_cuda_oom(self.class_name_semantic_inference)(None, class_name_cls_result.float(),
                                                                                   mask_pred_result.float())
                if not self.sem_seg_postprocess_before_inference:
                    semantic_r = retry_if_cuda_oom(sem_seg_postprocess)(
                        semantic_r, [original_height, original_width], height, width
                    )
                processed_results[-1]["sem_seg"] = semantic_r
            if self.instance_on:
                instance_r = retry_if_cuda_oom(self.class_name_instance_inference)(None, class_name_cls_result.float(),
                                                                                   mask_pred_result.float())
                processed_results[-1]["instances"] = instance_r
            if self.panoptic_on:
                panoptic_r = retry_if_cuda_oom(self.class_name_panoptic_inference)(None, class_name_cls_result.float(),
                                                                                   mask_pred_result.float())
                processed_results[-1]["panoptic_seg"] = panoptic_r
            if self.referring_on:
                instance_r = retry_if_cuda_oom(self.SEG_instance_inference)(SEG_cls_result.float(),
                                                                            mask_pred_result.float())
                processed_results[-1]["instances"] = instance_r
            if self.region_on:
                gt = _seg_info['instances'].gt_masks
                gt_result = retry_if_cuda_oom(sem_seg_postprocess)(
                    gt, [original_height, original_width], height, width
                )
                region_cls_results = region_cls_results[0].to(mask_pred_result)
                instance_r = retry_if_cuda_oom(self.region_inference)(region_cls_results.float(),
                                                                      mask_pred_result.float())
                processed_results[-1]["instances"] = instance_r
                processed_results[-1]["gt"] = gt_result
        return processed_results


AutoConfig.register("llava_phi", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, PSALMModel)
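
# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the original model code): the
# instance- and region-level inference above ranks each predicted mask by
#   final_score = class_score * mean_foreground_mask_probability,
# where the mean foreground probability averages the sigmoid of the mask
# logits over the pixels predicted as foreground (logit > 0), with a 1e-6
# term to avoid division by zero for empty masks. The toy tensors below only
# demonstrate that formula; their names, shapes, and values are made up and
# are not used anywhere else in this module.
if __name__ == "__main__":
    _toy_mask_logits = torch.randn(3, 8, 8)   # (num_queries, H, W) raw mask logits
    _toy_class_scores = torch.rand(3)         # one classification score per query
    _fg = (_toy_mask_logits > 0).float()      # binarized foreground prediction
    _mean_fg_prob = (_toy_mask_logits.sigmoid().flatten(1) * _fg.flatten(1)).sum(1) / (
        _fg.flatten(1).sum(1) + 1e-6)
    print("toy per-mask scores:", (_toy_class_scores * _mean_fg_prob).tolist())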