|
|
from statistics import mode |
|
|
from fvcore.common.config import CfgNode |
|
|
import numpy as np |
|
|
import os |
|
|
import cv2 |
|
|
import glob |
|
|
import tqdm |
|
|
from PIL import Image |
|
|
from PIL import ImageOps |
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import functional as F |
|
|
from modeling.MaskFormerModel import MaskFormerModel |
|
|
from utils.misc import load_parallal_model |
|
|
from utils.misc import ADEVisualize |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Segmentation():
    """Semantic-segmentation inference wrapper around MaskFormerModel.

    Loads pretrained weights, preprocesses images (aspect-preserving
    resize plus center padding to a multiple of the network stride),
    runs the model, and post-processes predicted masks back to the
    original image resolution.
    """

    def __init__(self, cfg, model=None):
        """
        Args:
            cfg: fvcore CfgNode with MODEL / TEST / INPUT / DATASETS sections
                and a ``local_rank`` attribute selecting the CUDA device.
            model: optional pre-built model. When None, a MaskFormerModel is
                created and weights are loaded from cfg.MODEL.PRETRAINED_WEIGHTS.
        """
        self.cfg = cfg
        self.num_queries = cfg.MODEL.MASK_FORMER.NUM_OBJECT_QUERIES
        self.size_divisibility = cfg.MODEL.MASK_FORMER.SIZE_DIVISIBILITY
        self.num_classes = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES
        self.device = torch.device("cuda", cfg.local_rank)

        # Network stride: padded input dims must be a multiple of 2**5 = 32.
        self.padding_constant = 2 ** 5
        self.test_dir = cfg.TEST.TEST_DIR
        self.output_dir = cfg.TEST.SAVE_DIR
        self.imgMaxSize = cfg.INPUT.CROP.MAX_SIZE
        self.pixel_mean = np.array(cfg.DATASETS.PIXEL_MEAN)
        self.pixel_std = np.array(cfg.DATASETS.PIXEL_STD)
        self.visualize = ADEVisualize()
        self.model = None

        pretrain_weights = cfg.MODEL.PRETRAINED_WEIGHTS
        if model is not None:
            self.model = model
        elif os.path.exists(pretrain_weights):
            self.model = MaskFormerModel(cfg, is_init=False)
            self.load_model(pretrain_weights)
        else:
            # NOTE(review): self.model stays None here, so forward() will fail
            # until valid weights are supplied — kept as print-only to
            # preserve existing best-effort behavior.
            print(f'please check weights file: {cfg.MODEL.PRETRAINED_WEIGHTS}')

    def load_model(self, pretrain_weights):
        """Load a training checkpoint and move the model to self.device.

        The checkpoint is expected to contain 'model', 'lr' and 'epoch' keys.
        Side effects: sets self.last_lr / self.start_epoch and puts the model
        in eval mode.
        """
        # BUGFIX: was hard-coded to map_location='cuda:0', which is wrong when
        # cfg.local_rank selects a different GPU; map to the configured device.
        state_dict = torch.load(pretrain_weights, map_location=self.device)

        ckpt_dict = state_dict['model']
        self.last_lr = state_dict['lr']
        self.start_epoch = state_dict['epoch']
        # load_parallal_model strips/adapts DataParallel prefixes (project util).
        self.model = load_parallal_model(self.model, ckpt_dict)
        self.model = self.model.to(self.device)
        self.model.eval()
        # typo fix: "mode" -> "model"
        print("loaded pretrain model:{}".format(pretrain_weights))

    def img_transform(self, img):
        """Normalize an HxWxC image to float32 and transpose to CxHxW.

        Scales to [0, 1], then applies the dataset mean/std normalization.
        """
        img = np.float32(np.array(img)) / 255.
        img = (img - self.pixel_mean) / self.pixel_std
        img = img.transpose((2, 0, 1))
        return img

    def round2nearest_multiple(self, x, p):
        """Round x up to the nearest positive multiple of p (x > 0)."""
        return ((x - 1) // p + 1) * p

    def get_img_ratio(self, img_size, target_size):
        """Return the scale factor that fits img_size inside target_size
        while preserving aspect ratio.

        Args:
            img_size: (w, h) or [w, h] of the source image.
            target_size: (w, h) bound to fit into.
        """
        img_rate = np.max(img_size) / np.min(img_size)
        target_rate = np.max(target_size) / np.min(target_size)
        if img_rate > target_rate:
            # Image is more elongated than the target box: the long side
            # is the limiting dimension.
            ratio = max(target_size) / max(img_size)
        else:
            ratio = min(target_size) / min(img_size)
        return ratio

    def resize_padding(self, img, outsize, Interpolation=None):
        """Resize a PIL image to fit inside outsize and center-pad with zeros.

        Args:
            img: PIL Image.
            outsize: (target_w, target_h).
            Interpolation: PIL resampling filter; defaults to Image.BILINEAR.
                (Resolved lazily so the class can be defined without PIL.)

        Returns:
            (padded_image, [left, top, right, bottom]) — the padding amounts
            are needed later to undo the padding in postprocess().
        """
        if Interpolation is None:
            Interpolation = Image.BILINEAR
        w, h = img.size
        target_w, target_h = outsize[0], outsize[1]
        ratio = self.get_img_ratio([w, h], outsize)
        ow, oh = round(w * ratio), round(h * ratio)
        img = img.resize((ow, oh), Interpolation)
        dh, dw = target_h - oh, target_w - ow
        top, bottom = dh // 2, dh - (dh // 2)
        left, right = dw // 2, dw - (dw // 2)
        img = ImageOps.expand(img, border=(left, top, right, bottom), fill=0)
        return img, [left, top, right, bottom]

    # NOTE: a second, byte-identical definition of get_img_ratio used to live
    # here; it silently overrode the first and has been removed.

    def image_preprocess(self, img):
        """Resize + pad a numpy RGB image and build the model input tensor.

        Returns:
            (input_tensor, transformer_info) where input_tensor is a
            1xCxHxW float tensor on self.device and transformer_info records
            padding, scale and padded input size for postprocess().
        """
        img_height, img_width = img.shape[0], img.shape[1]
        this_scale = self.get_img_ratio((img_width, img_height), self.imgMaxSize)
        target_width = img_width * this_scale
        target_height = img_height * this_scale
        # Pad target dims up to a multiple of the network stride.
        input_width = int(self.round2nearest_multiple(target_width, self.padding_constant))
        input_height = int(self.round2nearest_multiple(target_height, self.padding_constant))

        img, padding_info = self.resize_padding(Image.fromarray(img), (input_width, input_height))
        img = self.img_transform(img)

        transformer_info = {'padding_info': padding_info, 'scale': this_scale, 'input_size': (input_height, input_width)}
        input_tensor = torch.from_numpy(img).float().unsqueeze(0).to(self.device)
        return input_tensor, transformer_info

    def semantic_inference(self, mask_cls, mask_pred):
        """Combine per-query class scores and masks into per-class score maps.

        Drops class index 0 of the softmax (presumably the no-object logit —
        TODO confirm against training labels), so output channel c maps to
        class c+1.

        Args:
            mask_cls: (B, Q, C+1) classification logits.
            mask_pred: (B, Q, H, W) mask logits.

        Returns:
            (B, C, H, W) numpy array of class score maps.
        """
        mask_cls = F.softmax(mask_cls, dim=-1)[..., 1:]
        mask_pred = mask_pred.sigmoid()
        semseg = torch.einsum("bqc,bqhw->bchw", mask_cls, mask_pred)
        return semseg.cpu().numpy()

    def postprocess(self, pred_mask, transformer_info, target_size):
        """Strip the padding from a label mask and resize it to target_size.

        Args:
            pred_mask: (H, W) integer label mask at padded input resolution.
            transformer_info: dict produced by image_preprocess().
            target_size: (width, height) of the original image.
        """
        oh, ow = pred_mask.shape[0], pred_mask.shape[1]
        padding_info = transformer_info['padding_info']

        left, top, right, bottom = padding_info
        mask = pred_mask[top: oh - bottom, left: ow - right]
        # NOTE(review): uint8 cast assumes fewer than 256 classes — confirm
        # against cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES. Nearest-neighbor keeps
        # labels discrete.
        mask = cv2.resize(mask.astype(np.uint8), dsize=target_size, interpolation=cv2.INTER_NEAREST)
        return mask

    @torch.no_grad()
    def forward(self, img_list=None):
        """Run segmentation over img_list (or every .jpg/.png in test_dir).

        Args:
            img_list: optional list of image paths; when None or empty, the
                configured TEST.TEST_DIR is globbed instead.

        Returns:
            List of HxW label masks (numpy), one per input image, at each
            image's original resolution.
        """
        if img_list is None or len(img_list) == 0:
            img_list = glob.glob(self.test_dir + '/*.[jp][pn]g')
        mask_images = []
        for image_path in tqdm.tqdm(img_list):
            img = Image.open(image_path).convert('RGB')
            img_height, img_width = img.size[1], img.size[0]
            input_tensor, transformer_info = self.image_preprocess(np.array(img))

            outputs = self.model(input_tensor)
            mask_cls_results = outputs["pred_logits"]
            mask_pred_results = outputs["pred_masks"]

            # Upsample mask logits to the padded input resolution before
            # combining with class scores.
            mask_pred_results = F.interpolate(
                mask_pred_results,
                size=(input_tensor.shape[-2], input_tensor.shape[-1]),
                mode="bilinear",
                align_corners=False,
            )
            pred_masks = self.semantic_inference(mask_cls_results, mask_pred_results)
            # argmax over the class channel -> (H, W) label mask for batch 0.
            mask_img = np.argmax(pred_masks, axis=1)[0]
            mask_img = self.postprocess(mask_img, transformer_info, (img_width, img_height))
            mask_images.append(mask_img)
        return mask_images

    def render_image(self, img, mask_img, output_path=None):
        """Overlay / save the predicted mask using the ADE visualizer."""
        self.visualize.show_result(img, mask_img, output_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|