import numpy as np
import skimage
import torch
from hydra import compose
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch import nn
from torch.nn import functional as F
from torchvision.ops import roi_align

from utils.box_ops import boxes_with_scores

from .box_corr import Box_correction
from .prompt_encoder import PromptEncoder
from .query_generator import C_base
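

# CNT: exemplar-based counting/detection model on top of a SAM2 (Hiera) image
# encoder. Exemplar boxes are pooled into prototype embeddings with roi_align,
# the prototypes adapt the image features, and dense centerness + box-offset
# heads produce the detections (optionally refined by Box_correction at
# validation time).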
class CNT(nn.Module):
    def __init__(
        self,
        image_size: int,
        num_objects: int,
        emb_dim: int,
        kernel_dim: int,
        reduction: int,
        zero_shot: bool,
        training: bool,
    ):
        super(CNT, self).__init__()
        self.validate = not training
        self.emb_dim = emb_dim
        self.num_objects = num_objects
        self.reduction = reduction
        self.kernel_dim = kernel_dim
        self.image_size = image_size
        self.zero_shot = zero_shot
        self.pretrain = False

        # torch.hub.set_dir('/d/hpc/projects/FRI/pelhanj/CNT_SAM2/models/')
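
        # Dense prediction heads: a 1-channel objectness/centerness score and a
        # 4-value box regression MLP. Auxiliary copies of both heads are
        # created when not pretraining.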
        self.class_embed = nn.Sequential(nn.Linear(emb_dim, 1), nn.LeakyReLU())
        self.bbox_embed = MLP(emb_dim, emb_dim, 4, 3)
        if not self.pretrain:
            self.class_embed_aux = nn.Sequential(nn.Linear(emb_dim, 1), nn.LeakyReLU())
            self.bbox_embed_aux = MLP(emb_dim, emb_dim, 4, 3)
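
        # Feature adapter (query generator) that fuses the exemplar prototypes
        # with the image embedding and the high-resolution FPN levels.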
        self.adapt_features = C_base(
            transformer_dim=self.emb_dim,
            num_prototype_attn_steps=3,
            num_image_attn_steps=2,
        )
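
        # SAM-style prompt encoder; used here for its dense positional encoding
        # of the image embedding (see get_dense_pe() in forward()).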
        self.sam_prompt_encoder = PromptEncoder(
            embed_dim=self.emb_dim,
            image_embedding_size=(
                self.image_size // self.reduction,
                self.image_size // self.reduction,
            ),
            input_image_size=(self.image_size, self.image_size),
            mask_in_chans=16,
        )
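
        # SAM2 Hiera-B+ image encoder: instantiate it from the Hydra config and
        # load the official checkpoint, stripping the "image_encoder." prefix
        # from the state-dict keys.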
        config_name = '../configs/sam2_hiera_base_plus.yaml'
        cfg = compose(config_name=config_name)
        OmegaConf.resolve(cfg)
        self.backbone = instantiate(cfg.backbone, _recursive_=True)
        checkpoint = torch.hub.load_state_dict_from_url(
            'https://dl.fbaipublicfiles.com/segment_anything_2/072824/'
            + config_name.split('/')[-1].replace('.yaml', '.pt'),
            map_location="cpu"
        )['model']
        state_dict = {k.replace("image_encoder.", ""): v for k, v in checkpoint.items()}
        self.backbone.load_state_dict(state_dict, strict=False)
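
        # Shape encoder: maps each exemplar's box width/height to an
        # emb_dim-dimensional "shape" token that is appended to the appearance
        # prototypes in forward().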
        self.shape_or_objectness = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, emb_dim),
            nn.ReLU(),
            nn.Linear(emb_dim, 1 ** 2 * emb_dim)
        )
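
        # Box refinement module, used only at validation/inference time.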
        if self.validate:
            self.box_correction = Box_correction(reduction, image_size, emb_dim)

    def forward(self, x, bboxes, tiled=False):
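        """
        Args:
            x: batch of input images, shape (bs, 3, H, W).
            bboxes: exemplar boxes per image, shape (bs, num_objects, 4),
                in (x1, y1, x2, y2) input-image coordinates.
            tiled: unused in this forward pass.

        Returns:
            outputs, ref_points, centerness, outputs_coord, and (when not
            pretraining) the same tuple from the auxiliary heads.
        """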
        num_objects = bboxes.size(1) if not self.zero_shot else self.num_objects

        # Extract image features from the SAM2 backbone without gradients.
        with torch.no_grad():
            feats = self.backbone(x)
        src = feats['vision_features']
        bs, c, w, h = src.shape
        # Effective backbone stride, assuming 1024 x 1024 inputs (hard-coded).
        self.reduction = 1024 / w

        # Prepend the batch index to each exemplar box, as expected by roi_align.
        bboxes_roi = torch.cat([
            torch.arange(
                bs, requires_grad=False
            ).to(bboxes.device).repeat_interleave(num_objects).reshape(-1, 1),
            bboxes.flatten(0, 1),
        ], dim=1)
        # Each exemplar is pooled to a single 1x1 prototype token.
        self.kernel_dim = 1
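
        # Pool exemplar appearance prototypes from the coarse image features and
        # from the two higher-resolution FPN levels (spatial_scale is adjusted
        # to each level's stride).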
        exemplars = roi_align(
            src,
            boxes=bboxes_roi, output_size=self.kernel_dim,
            spatial_scale=1.0 / self.reduction, aligned=True
        ).permute(0, 2, 3, 1).reshape(bs, num_objects * self.kernel_dim ** 2, self.emb_dim)

        l1 = feats['backbone_fpn'][0]
        l2 = feats['backbone_fpn'][1]
        exemplars_l1 = roi_align(
            l1,
            boxes=bboxes_roi, output_size=self.kernel_dim,
            spatial_scale=1.0 / self.reduction * 2 * 2, aligned=True
        ).permute(0, 2, 3, 1).reshape(bs, num_objects * self.kernel_dim ** 2, self.emb_dim)
        exemplars_l2 = roi_align(
            l2,
            boxes=bboxes_roi, output_size=self.kernel_dim,
            spatial_scale=1.0 / self.reduction * 2, aligned=True
        ).permute(0, 2, 3, 1).reshape(bs, num_objects * self.kernel_dim ** 2, self.emb_dim)
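
        # Exemplar box widths and heights (in input-image pixels) for the shape
        # embedding.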
        box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device)
        box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0]
        box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1]

        # Encode shape
        shape = self.shape_or_objectness(box_hw).reshape(
            bs, -1, self.emb_dim
        )
        prototype_embeddings = torch.cat([exemplars, shape], dim=1)
        prototype_embeddings_l1 = torch.cat([exemplars_l1, shape], dim=1)
        prototype_embeddings_l2 = torch.cat([exemplars_l2, shape], dim=1)
        hq_prototype_embeddings = [prototype_embeddings_l1, prototype_embeddings_l2]

        # Adapt the image features with the prototypes.
        adapted_f, adapted_f_aux = self.adapt_features(
            image_embeddings=src,
            image_pe=self.sam_prompt_encoder.get_dense_pe(),
            prototype_embeddings=prototype_embeddings,
            hq_features=feats['backbone_fpn'],
            hq_prototypes=hq_prototype_embeddings,
            hq_pos=feats['vision_pos_enc'],
        )

        # Predict a per-location objectness/centerness score and the
        # (l, r, t, b) box offsets from the adapted features.
        bs, c, w, h = adapted_f.shape
        adapted_f = adapted_f.view(bs, self.emb_dim, -1).permute(0, 2, 1)
        centerness = self.class_embed(adapted_f).view(bs, w, h, 1).permute(0, 3, 1, 2)
        outputs_coord = self.bbox_embed(adapted_f).sigmoid().view(bs, w, h, 4).permute(0, 3, 1, 2)
        outputs, ref_points = boxes_with_scores(centerness, outputs_coord, sort=False, validate=self.validate)

        if not self.pretrain:
            # Auxiliary predictions from the auxiliary adapted features.
            adapted_f_aux = adapted_f_aux.view(bs, self.emb_dim, -1).permute(0, 2, 1)
            centerness_aux = self.class_embed_aux(adapted_f_aux).view(bs, w, h, 1).permute(0, 3, 1, 2)
            outputs_coord_aux = self.bbox_embed_aux(adapted_f_aux).sigmoid().view(bs, w, h, 4).permute(0, 3, 1, 2)
            outputs_aux, ref_points_aux = boxes_with_scores(
                centerness_aux, outputs_coord_aux, sort=False, validate=self.validate
            )
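
        # At validation time, refine the predicted boxes with the correction
        # module; otherwise expose the raw box scores ("box_v") as "scores".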
        if self.validate:
            outputs = self.box_correction(feats, outputs, x)
        else:
            for i in range(len(outputs)):
                outputs[i]["scores"] = outputs[i]["box_v"]

        if self.pretrain:
            return outputs, ref_points, centerness, outputs_coord
        else:
            return (
                outputs, ref_points, centerness, outputs_coord,
                (outputs_aux, ref_points_aux, centerness_aux, outputs_coord_aux),
            )


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


def build_model(args):
    assert args.reduction in [4, 8, 16]
    return CNT(
        image_size=args.image_size,
        num_objects=args.num_objects,
        zero_shot=args.zero_shot,
        emb_dim=args.emb_dim,
        reduction=args.reduction,
        kernel_dim=args.kernel_dim,
        training=args.training,
    )
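

# Minimal usage sketch, not part of the original module. It assumes
# argparse-style args with exactly the fields read by build_model(), that Hydra
# has already been initialized so compose() can resolve the SAM2 config, that
# network access is available for the checkpoint download, and that emb_dim
# matches the SAM2 feature dimension (256 here is an assumption), with
# image_size=1024 to match the hard-coded stride computation in forward().
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        image_size=1024,   # assumed input resolution
        num_objects=3,     # number of exemplars used by the zero-shot branch
        zero_shot=False,
        emb_dim=256,       # assumed SAM2 feature dimension
        reduction=16,
        kernel_dim=1,
        training=False,    # validation mode: Box_correction is applied
    )
    model = build_model(args).eval()

    # One 1024x1024 image with three exemplar boxes in (x1, y1, x2, y2) pixels.
    image = torch.rand(1, 3, 1024, 1024)
    boxes = torch.tensor([[[100., 100., 160., 150.],
                           [300., 320., 360., 370.],
                           [500., 540., 560., 590.]]])
    with torch.no_grad():
        outputs, ref_points, centerness, outputs_coord, aux = model(image, boxes)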