Spaces:
Build error
Build error
| import sys | |
| import random | |
| import torch | |
| import torch.nn as nn | |
| from .point import Point | |
| from .polygon import Polygon | |
| from .scribble import Scribble | |
| from .circle import Circle | |
| from modeling.utils import configurable | |
class ShapeSampler(nn.Module):
    """Draw a random prompt shape for ground-truth instance masks.

    Each selected ground-truth mask is paired with one "drawer" object picked
    from ``shape_candidate`` (weighted by ``shape_prob``); the drawer's
    ``draw(mask, box)`` output becomes the random shape for that mask.
    """

    # NOTE(review): `configurable` comes from `modeling.utils`; it is assumed
    # to follow the detectron2 contract, i.e. `ShapeSampler(cfg, **kwargs)` is
    # routed through `from_config` while explicit keyword construction keeps
    # working - confirm against `modeling/utils`.
    @configurable
    def __init__(self, max_candidate=1, shape_prob=None, shape_candidate=None, is_train=True):
        """Store the sampling configuration.

        Args:
            max_candidate: Maximum number of masks sampled per call when
                ``is_train`` is true.
            shape_prob: Sampling weight for each entry of ``shape_candidate``.
                ``None`` means an empty list.
            shape_candidate: Drawer objects exposing ``draw(mask, box)``.
                ``None`` means an empty list.
            is_train: Whether masks are randomly sub-sampled in ``forward``.
        """
        super().__init__()
        self.max_candidate = max_candidate
        # Fresh lists per instance: a literal `[]` default would be shared by
        # every sampler built without these arguments.
        self.shape_prob = [] if shape_prob is None else shape_prob
        self.shape_candidate = [] if shape_candidate is None else shape_candidate
        self.is_train = is_train

    @classmethod
    def from_config(cls, cfg, is_train=True, mode=None):
        """Translate ``cfg['STROKE_SAMPLER']`` into ``__init__`` keyword arguments.

        Args:
            cfg: Mapping with a ``'STROKE_SAMPLER'`` entry providing
                ``'MAX_CANDIDATE'``, ``'CANDIDATE_PROBS'`` and
                ``'CANDIDATE_NAMES'``.
            is_train: Forwarded to the sampler and, unless ``mode`` is
                ``'hack_train'``, to every drawer.
            mode: ``'hack_train'`` builds the drawers in training mode and
                keeps the configured probabilities. Otherwise, when
                ``is_train`` is false, ``mode`` must be one of the candidate
                names and becomes the only shape that can be sampled.

        Returns:
            dict: Keyword arguments for ``__init__``.

        Raises:
            ValueError: If ``is_train`` is false, ``mode`` is not
                ``'hack_train'`` and ``mode`` is not a configured candidate.
        """
        sampler_cfg = cfg['STROKE_SAMPLER']
        max_candidate = sampler_cfg['MAX_CANDIDATE']
        candidate_probs = sampler_cfg['CANDIDATE_PROBS']
        candidate_names = sampler_cfg['CANDIDATE_NAMES']

        if mode == 'hack_train':
            drawer_is_train = True
        else:
            drawer_is_train = is_train
            if not is_train:
                if mode not in candidate_names:
                    raise ValueError(
                        f"mode {mode!r} is not one of the configured "
                        f"candidate names {list(candidate_names)!r}"
                    )
                # Evaluation uses exactly one shape type: put all of the
                # probability mass on the requested candidate.
                candidate_probs = [0.0] * len(candidate_names)
                candidate_probs[candidate_names.index(mode)] = 1.0

        # Drawer classes are resolved by name on this module, so every
        # candidate name must be importable here.
        this_module = sys.modules[__name__]
        candidate_classes = [
            getattr(this_module, class_name)(cfg, drawer_is_train)
            for class_name in candidate_names
        ]

        return {
            "max_candidate": max_candidate,
            "shape_prob": candidate_probs,
            "shape_candidate": candidate_classes,
            "is_train": is_train,
        }

    def forward(self, instances):
        """Sample a random shape for (a subset of) the ground-truth masks.

        Args:
            instances: Object exposing ``gt_masks.tensor`` and
                ``gt_boxes.tensor``, indexed by instance along dim 0.

        Returns:
            dict: ``'gt_masks'`` (selected masks), ``'rand_shape'`` (stacked
            drawn shapes as bool), ``'types'`` (``repr`` of the drawer used
            per mask, or ``'none'`` when nothing was drawn) and ``'sampler'``
            (this module). When there are no masks, a single all-false
            placeholder is returned for ``'gt_masks'`` and ``'rand_shape'``
            with ``types == ['none']`` and no ``'sampler'`` entry.
        """
        masks = instances.gt_masks.tensor
        boxes = instances.gt_boxes.tensor

        if len(masks) == 0:
            spatial_size = masks.shape[-2:]
            gt_masks = torch.zeros(spatial_size, dtype=torch.bool)
            rand_masks = torch.zeros(spatial_size, dtype=torch.bool)
            return {'gt_masks': gt_masks[None, :], 'rand_shape': torch.stack([rand_masks]), 'types': ['none']}

        if self.is_train:
            # Random subset of at most `max_candidate` instances. Indexing
            # with a list yields a copy, so the writes below do not touch
            # `instances` on this path.
            indices = list(range(len(masks)))
            random.shuffle(indices)
            chosen = indices[:self.max_candidate]
            candidate_mask = masks[chosen]
            candidate_box = boxes[chosen]
        else:
            # NOTE(review): these alias the input tensors, so zeroing a mask
            # below also modifies `instances.gt_masks.tensor` in place during
            # evaluation - confirm callers expect that.
            candidate_mask = masks
            candidate_box = boxes

        draw_funcs = random.choices(self.shape_candidate, weights=self.shape_prob, k=len(candidate_mask))
        rand_shapes = [
            drawer.draw(mask, box)
            for drawer, mask, box in zip(draw_funcs, candidate_mask, candidate_box)
        ]
        types = [repr(drawer) for drawer in draw_funcs]

        # A drawer may produce an empty shape; blank the matching ground
        # truth and mark the sample as 'none' so the pair stays consistent.
        for i, shape in enumerate(rand_shapes):
            if shape.sum() == 0:
                candidate_mask[i] = 0
                types[i] = 'none'

        # candidate_mask: (c,h,w), bool. rand_shape: (c, iter, h, w), bool. types: list(c)
        return {'gt_masks': candidate_mask, 'rand_shape': torch.stack(rand_shapes).bool(), 'types': types, 'sampler': self}
def build_shape_sampler(cfg, **kwargs):
    """Create a ``ShapeSampler`` from ``cfg``.

    ``cfg`` is handed to ``ShapeSampler`` as its first positional argument and
    every keyword argument is forwarded unchanged.

    NOTE(review): passing ``cfg`` positionally presumably relies on
    ``ShapeSampler.__init__`` being config-aware (e.g. wrapped by a
    ``configurable``-style decorator) - confirm.
    """
    sampler = ShapeSampler(cfg, **kwargs)
    return sampler