| | import yaml |
| | import torch |
| | import random |
| | import numpy as np |
| | import os |
| | import sys |
| | import matplotlib.pyplot as plt |
| | from einops import repeat |
| | import cv2 |
| | import time |
| | import torch.nn.functional as F |
| |
|
| |
|
# Public names exported by a star-import of this module.
__all__ = [
    'decode_mask_to_onehot',
    'encode_onehot_to_mask',
    'Logger',
    'get_coords_grid',
    'get_coords_grid_float',
    'draw_bboxes',
    'Infos',
    'inv_normalize_img',
    'make_numpy_img',
    'get_metrics',
]
| |
|
| |
|
class Infos(object):
    """Collect per-batch training state, losses and detection results.

    The object keeps the current progress counters (epoch, batch, lr, fps),
    accumulates loss values and prediction/target boxes over an epoch, and
    turns the accumulated boxes into COCO-style detection metrics.

    Box layout used by the accumulated results (as indexed in
    ``cal_metrics``):
        targets: ``[class_idx, l, t, r, b]``
        preds:   ``[class_idx, score, l, t, r, b]``

    NOTE(review): ``COCO`` and ``COCOeval`` are not imported in the visible
    part of this file; presumably they come from ``pycocotools`` -- confirm.
    """

    def __init__(self, phase, class_names=None):
        """
        Args:
            phase: task identifier; only ``'od'`` is accepted.
            class_names: sequence of class names, indexed by class id.
        """
        assert phase in ['od'], "Error in Infos"
        self.phase = phase
        self.class_names = class_names
        self.register()
        self.pattern = 'train'
        self.epoch_id = 0
        self.max_epoch = 0
        self.batch_id = 0
        self.batch_num = 0
        self.lr = 0
        self.fps_data_load = 0
        self.fps = 0
        self.val_metric = 0
        # Initialised here so print_epoch_state_infos() does not fail with an
        # AttributeError when set_epoch_training_time() was never called.
        self.epoch_training_time = 0

    def set_epoch_training_time(self, data):
        self.epoch_training_time = data

    def set_pattern(self, data):
        self.pattern = data

    def set_epoch_id(self, data):
        self.epoch_id = data

    def set_max_epoch(self, data):
        self.max_epoch = data

    def set_batch_id(self, data):
        self.batch_id = data

    def set_batch_num(self, data):
        self.batch_num = data

    def set_lr(self, data):
        self.lr = data

    def set_fps_data_load(self, data):
        self.fps_data_load = data

    def set_fps(self, data):
        self.fps = data

    def clear_cache(self):
        """Drop all accumulated results and losses."""
        self.register()

    def get_val_metric(self):
        """Return the metric stored by the last ``cal_metrics`` call."""
        return self.val_metric

    def _build_coco_api(self, boxes_per_image, has_score):
        """Build an indexed COCO object from a list of per-image box arrays.

        Args:
            boxes_per_image: list with one ``[n, 5]`` (targets) or ``[n, 6]``
                (predictions) array per image.
            has_score: True when column 1 holds a confidence score, which
                shifts the coordinate columns by one.
        """
        coco_api = COCO()
        images = []
        annotations = []
        xy_start = 2 if has_score else 1
        ann_id = 0
        for image_id, boxes in enumerate(boxes_per_image):
            # Register every image exactly once, even when it has no boxes.
            # Registering only inside the box loop (as before) duplicated
            # entries and silently excluded images without ground truth from
            # the evaluated image ids.
            images.append({'id': image_id})
            for j in range(boxes.shape[0]):
                left_top = boxes[j, xy_start:xy_start + 2]
                width_height = boxes[j, xy_start + 2:xy_start + 4] - left_top
                annotation = {
                    'image_id': image_id,
                    "category_id": int(boxes[j, 0]),
                    # COCO boxes are [x, y, width, height].
                    "bbox": np.hstack([left_top, width_height]),
                    "area": np.prod(width_height),
                    "id": ann_id,
                    "iscrowd": 0
                }
                if has_score:
                    annotation['score'] = boxes[j, 1]
                annotations.append(annotation)
                ann_id += 1
        coco_api.dataset['images'] = images
        coco_api.dataset['annotations'] = annotations
        coco_api.dataset['categories'] = [
            {"id": i, "supercategory": c, "name": c}
            for i, c in enumerate(self.class_names)
        ]
        coco_api.createIndex()
        return coco_api

    def cal_metrics(self):
        """Run the COCO bbox evaluation over the accumulated results.

        Stores the summary vector in ``self.metrics`` and its second entry
        in ``self.val_metric``.
        """
        if self.phase == 'od':
            coco_api_gt = self._build_coco_api(self.result_all['target_all'], has_score=False)
            coco_api_pred = self._build_coco_api(self.result_all['pred_all'], has_score=True)

            coco_eval = COCOeval(coco_api_gt, coco_api_pred, "bbox")
            coco_eval.params.imgIds = coco_api_gt.getImgIds()
            coco_eval.evaluate()
            coco_eval.accumulate()
            stats = coco_eval.summarize()
            if stats is None:
                # Evaluators whose summarize() returns nothing expose the
                # summary through the ``stats`` attribute instead.
                stats = coco_eval.stats
            self.metrics = stats
            self.val_metric = self.metrics[1]

    def print_epoch_state_infos(self, logger):
        """Log the epoch summary, evaluate, and log the 12 COCO metrics."""
        infos_str = 'Pattern: %s Epoch [%d,%d], time: %d loss: %.4f' % \
                    (self.pattern, self.epoch_id, self.max_epoch, self.epoch_training_time,
                     np.mean(self.loss_all['loss']))
        logger.write(infos_str + '\n')
        time_start = time.time()
        self.cal_metrics()
        time_end = time.time()
        logger.write('Pattern: %s Epoch Eval_time: %d\n' % (self.pattern, (time_end - time_start)))

        if self.phase == 'od':
            titleStr = 6 * ['Average Precision'] + 6 * ['Average Recall']
            typeStr = 6 * ['(AP)'] + 6 * ['(AR)']
            iouStr = 12 * ['0.50:0.95']
            iouStr[1] = '0.50'
            iouStr[2] = '0.75'
            areaRng = 3 * ['all'] + ['small', 'medium', 'large'] + 3 * ['all'] + ['small', 'medium', 'large']
            maxDets = 6 * [100] + [1, 10, 100] + 3 * [100]
            infos_str = '{:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}\n'
            for i in range(12):
                logger.write(infos_str.format(titleStr[i], typeStr[i], iouStr[i], areaRng[i], maxDets[i],
                                              self.metrics[i]))

    def save_epoch_state_infos(self, writer):
        """Write the 12 epoch metrics to ``writer`` under ``<pattern>/epoch/``."""
        step = self.epoch_id
        keys = [
            'AP_m_all_100',
            'AP_50_all_100',
            'AP_75_all_100',
            'AP_m_small_100',
            'AP_m_medium_100',
            'AP_m_large_100',
            'AR_m_all_1',
            'AR_m_all_10',
            'AR_m_all_100',
            'AR_m_small_100',
            'AR_m_medium_100',
            'AR_m_large_100',
        ]
        for i, key in enumerate(keys):
            writer.add_scalar('%s/epoch/%s' % (self.pattern, key), self.metrics[i], step)

    def print_batch_state_infos(self, logger):
        """Log one line with the progress counters and the latest loss."""
        infos_str = 'Pattern: %s [%d,%d][%d,%d], lr: %5f, fps_data_load: %.2f, fps: %.2f' % \
                    (self.pattern, self.epoch_id, self.max_epoch, self.batch_id,
                     self.batch_num, self.lr, self.fps_data_load, self.fps)

        infos_str += ', loss: %.4f' % self.loss_all['loss'][-1]
        logger.write(infos_str + '\n')

    def save_batch_state_infos(self, writer):
        """Write the learning rate and the latest value of every loss."""
        step = self.epoch_id * self.batch_num + self.batch_id
        writer.add_scalar('%s/lr' % self.pattern, self.lr, step)
        for key, value in self.loss_all.items():
            writer.add_scalar('%s/%s' % (self.pattern, key), value[-1], step)

    def save_results(self, img_batch, prior_mean, prior_std, vis_dir, *args, **kwargs):
        """Save image / prediction / target visualisations for a random
        ~30% subset (at least one) of the batch.

        The boxes are read from the tail of the accumulated results, so this
        expects the results of ``img_batch`` to be the most recently appended.
        """
        batch_size = img_batch.size(0)
        k = np.clip(int(0.3 * batch_size), a_min=1, a_max=batch_size)
        ids = np.random.choice(range(batch_size), k, replace=False)
        out_dir = os.path.join(vis_dir, self.pattern)
        os.makedirs(out_dir, exist_ok=True)
        for img_id in ids:
            img = img_batch[img_id].detach().cpu()
            # Negative index: position of this image inside the last batch.
            pred = self.result_all['pred_all'][img_id - batch_size]
            target = self.result_all['target_all'][img_id - batch_size]

            img = make_numpy_img(inv_normalize_img(img, prior_mean, prior_std))
            # Keyword arguments: draw_bboxes takes ``color`` before
            # ``class_names``, so the previous positional call swapped them.
            pred_draw = draw_bboxes(img, pred, color=(255, 0, 0), class_names=self.class_names)
            target_draw = draw_bboxes(img, target, color=(0, 255, 0), class_names=self.class_names)

            vis = np.concatenate([img / 255., pred_draw / 255., target_draw / 255.], axis=0)
            vis = np.clip(vis, a_min=0, a_max=1)
            file_name = os.path.join(out_dir, f'{self.epoch_id}_{self.batch_id}_{img_id}.png')
            plt.imsave(file_name, vis)

    def register(self):
        """Reset the result and loss accumulators to empty."""
        self.is_registered_result = False
        self.result_all = {}

        self.is_registered_loss = False
        self.loss_all = {}

    def register_result(self, data: dict):
        """Create an empty list for every key of ``data``."""
        for key in data.keys():
            self.result_all[key] = []
        self.is_registered_result = True

    def append_result(self, data: dict):
        """Extend each result list with the per-image items in ``data``."""
        if not self.is_registered_result:
            self.register_result(data)
        for key, value in data.items():
            self.result_all[key] += value

    def register_loss(self, data: dict):
        """Create an empty list for every key of ``data``."""
        for key in data.keys():
            self.loss_all[key] = []
        self.is_registered_loss = True

    def append_loss(self, data: dict):
        """Append each loss tensor in ``data`` as a detached NumPy value."""
        if not self.is_registered_loss:
            self.register_loss(data)
        for key, value in data.items():
            self.loss_all[key].append(value.detach().cpu().numpy())
| |
|
| |
|
| | |
def draw_bboxes(img, bboxes, color=(255, 0, 0), class_names=None, is_show_score=True):
    '''
    Draw bounding boxes (and optionally a small label) on a copy of ``img``.

    Args:
        img: image as a NumPy array or torch tensor; it is cast to uint8 and
            copied, so the input is not modified.
        bboxes: [n, 5], class_idx, l, t, r, b
                [n, 6], class_idx, score, l, t, r, b
        color: colour of the box outline.
        class_names: optional sequence indexed by class_idx; when omitted the
            class index itself is used as the label text.
        is_show_score: when True, draw a filled label patch with the class
            name (and the score for 6-column boxes) at the top-left corner.
    Returns:
        uint8 NumPy image with the boxes drawn.
    '''
    assert img is not None, "In draw_bboxes, img is None"
    if torch.is_tensor(img):
        img = img.cpu().numpy()
    img = img.astype(np.uint8).copy()

    if torch.is_tensor(bboxes):
        bboxes = bboxes.cpu().numpy()
    for bbox in bboxes:
        bbox = np.asarray(bbox)
        class_idx = int(bbox[0])
        # Fall back to the index so the label is always defined; previously
        # ``class_name`` was unbound when class_names was None.
        class_name = class_names[class_idx] if class_names else str(class_idx)
        has_score = len(bbox) == 6
        score = bbox[1] if has_score else None
        coords = bbox[2:] if has_score else bbox[1:]
        # ``np.int`` was removed from NumPy; plain Python ints also keep the
        # cv2 point arguments portable across OpenCV versions.
        left, top, right, bottom = (int(v) for v in coords[:4])
        if is_show_score:
            cv2.rectangle(img, pt1=(left - 2, top - 15), pt2=(left + 15, top + 1),
                          color=(0, 0, 255), thickness=-1)
            if has_score:
                label = '%s:%.2f' % (class_name, score)
            else:
                label = '%s' % class_name
            cv2.putText(img, text=label, org=(left - 1, top - 7),
                        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.2, color=(255, 255, 255), thickness=1)
        cv2.rectangle(img, pt1=(left, top), pt2=(right, bottom), color=color, thickness=2)
    return img
| |
|
| |
|
def get_coords_grid(h_end, w_end, h_start=0, w_start=0, h_steps=None, w_steps=None, is_normalize=False):
    """Return a ``[2, H, W]`` tensor of grid coordinates.

    Channel 0 holds the x (width) coordinate and channel 1 the y (height)
    coordinate of every grid cell. When the step counts are omitted the
    grid is sampled at unit spacing between start and end (inclusive).
    With ``is_normalize`` the coordinates are divided by ``h_end`` / ``w_end``.
    """
    n_rows = int(h_end - h_start) + 1 if h_steps is None else h_steps
    n_cols = int(w_end - w_start) + 1 if w_steps is None else w_steps

    ys = torch.linspace(h_start, h_end, n_rows)
    xs = torch.linspace(w_start, w_end, n_cols)
    if is_normalize:
        ys = ys / h_end
        xs = xs / w_end
    grid_y, grid_x = torch.meshgrid(ys, xs)
    # x first, then y.
    return torch.stack((grid_x, grid_y), dim=0)
| |
|
| |
|
def get_coords_grid_float(ht, wd, scale, is_normalize=False):
    """Return a ``[2, ht, wd]`` grid of interior sample positions.

    Each axis is sampled at ``n + 2`` evenly spaced points over
    ``[0, scale]`` and the two end points are dropped. Channel 0 is x,
    channel 1 is y. With ``is_normalize`` the positions are divided by
    ``scale``.
    """
    ys = torch.linspace(0, scale, ht + 2)
    xs = torch.linspace(0, scale, wd + 2)
    if is_normalize:
        ys, xs = ys / scale, xs / scale
    inner_y = ys[1:-1]
    inner_x = xs[1:-1]
    grid_y, grid_x = torch.meshgrid(inner_y, inner_x)
    return torch.stack((grid_x, grid_y), dim=0)
| |
|
| |
|
def get_coords_vector_float(len, scale, is_normalize=False):
    """Return a ``[2, len, 1]`` tensor of interior positions along one axis.

    ``len + 2`` evenly spaced points are taken over ``[0, scale]`` and the
    end points dropped. Channel 0 is all zeros, channel 1 holds the
    positions. With ``is_normalize`` the positions are divided by ``scale``.
    """
    samples = torch.linspace(0, scale, len + 2)
    if is_normalize:
        samples = samples / scale
    positions, zeros = torch.meshgrid(samples[1:-1], torch.tensor([0.]))
    return torch.stack((zeros, positions), dim=0)
| |
|
| |
|
class Logger(object):
    """Append messages to a log file and optionally echo them to stdout.

    The file is opened in append mode on construction. Call ``close()`` (or
    use the instance as a context manager) to release the file handle.
    """

    def __init__(self, filename="Default.log", is_terminal_show=True):
        """
        Args:
            filename: path of the log file (opened in append mode).
            is_terminal_show: also write every message to ``sys.stdout``.
        """
        self.is_terminal_show = is_terminal_show
        # Always defined so attribute access is safe in both modes.
        self.terminal = sys.stdout if self.is_terminal_show else None
        self.log = open(filename, "a")

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def write(self, message):
        """Write ``message`` to the enabled sinks and flush immediately."""
        if self.is_terminal_show:
            self.terminal.write(message)
        self.log.write(message)
        self.flush()

    def flush(self):
        """Flush the terminal (when enabled) and the log file."""
        if self.is_terminal_show:
            self.terminal.flush()
        self.log.flush()

    def close(self):
        """Close the log file; safe to call more than once."""
        if not self.log.closed:
            self.log.close()
| |
|
| |
|
class ParamsParser:
    """Expose the top-level keys of a YAML file as attributes.

    Unknown keys evaluate to ``None`` instead of raising.
    """

    def __init__(self, project_file):
        """
        Args:
            project_file: path of the YAML file to load.
        """
        # ``with`` closes the handle; an empty file loads as None, which is
        # mapped to an empty dict so attribute lookups keep working.
        with open(project_file) as f:
            self.params = yaml.safe_load(f) or {}

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails. Looking up
        # ``params`` here means it was never set (e.g. during copy or
        # unpickling); delegating to self.params would recurse forever.
        # Special (dunder) names are protocol probes, never config keys.
        if item == 'params' or (item.startswith('__') and item.endswith('__')):
            raise AttributeError(item)
        return self.params.get(item, None)
| |
|
| |
|
def get_all_dict(dict_infos: dict) -> dict:
    """Flatten a nested dict into a single-level dict of its leaf values.

    Keys of nested dicts are merged into the result without any prefix; the
    key that held the nested dict is dropped. When the same leaf key occurs
    more than once, the value encountered last wins.
    """
    flat = {}
    for key, value in dict_infos.items():
        if isinstance(value, dict):
            # update() accepts any hashable key, unlike the previous
            # ``dict(..., **nested)`` which required string keys.
            flat.update(get_all_dict(value))
        else:
            flat[key] = value
    return flat
| |
|
| |
|
def make_numpy_img(tensor_data):
    """Convert an image tensor to an ``H x W x 3`` NumPy array.

    Accepted layouts, checked in this order: ``H x W`` and ``1 x H x W``
    (the single plane is repeated three times), ``3 x H x W`` (moved to
    channel-last), and tensors whose third dimension is already 3 (returned
    unchanged). Anything else raises ``Exception``.
    """
    if len(tensor_data.shape) == 2:
        plane = tensor_data.unsqueeze(2)
        channel_last = torch.cat([plane] * 3, dim=2)
    elif tensor_data.size(0) == 1:
        plane = tensor_data.permute((1, 2, 0))
        channel_last = torch.cat([plane] * 3, dim=2)
    elif tensor_data.size(0) == 3:
        channel_last = tensor_data.permute((1, 2, 0))
    elif tensor_data.size(2) == 3:
        channel_last = tensor_data
    else:
        raise Exception('tensor_data apply to make_numpy_img error')

    return channel_last.detach().cpu().numpy()
| |
|
| |
|
def print_infos(logger, writer, infos: dict):
    """Log one progress line and write the scalars it contains.

    The first eight values of ``infos`` (in insertion order) fill the fixed
    part of the line; every further item is appended as ``, key: value``
    and also written to ``writer`` under ``<pattern>/<key>``. The keys
    ``pattern``, ``lr``, ``epoch_id``, ``batch_num`` and ``batch_id`` are
    read by name.
    """
    names = list(infos.keys())
    vals = list(infos.values())
    extras = list(zip(names[8:], vals[8:]))

    line = 'Pattern: %s [%d,%d][%d,%d], lr: %5f, fps_data_load: %.2f, fps: %.2f' % tuple(vals[:8])
    line += ''.join(f', {name}: {val:.4f}' for name, val in extras)
    logger.write(line + '\n')

    # Global step: batches already seen in earlier epochs plus the current one.
    step = infos['epoch_id'] * infos['batch_num'] + infos['batch_id']
    writer.add_scalar('%s/lr' % infos['pattern'], infos['lr'], step)
    for name, val in extras:
        writer.add_scalar('%s/%s' % (infos['pattern'], name), val, step)
| |
|
| |
|
def invert_affine(origin_imgs, preds, pattern='train', input_size=512):
    """Rescale predicted boxes back to the size of the original images.

    Only acts when ``pattern == 'val'``; otherwise ``preds`` is returned
    untouched. The ``rois`` arrays are modified in place.

    Args:
        origin_imgs: sequence of images whose ``shape`` is ``(h, w, c)``.
        preds: sequence of dicts with a ``'rois'`` array; columns 0 and 2
            are scaled by the width ratio, columns 1 and 3 by the height
            ratio.
        pattern: only ``'val'`` triggers the rescaling.
        input_size: side length the boxes are currently expressed in
            (presumably the network input size -- confirm). Defaults to the
            previously hard-coded 512.
    Returns:
        ``preds`` (the same object that was passed in).
    """
    if pattern != 'val':
        return preds
    for i, pred in enumerate(preds):
        rois = pred['rois']
        if len(rois) == 0:
            continue
        old_h, old_w, _ = origin_imgs[i].shape
        rois[:, [0, 2]] = rois[:, [0, 2]] / (input_size / old_w)
        rois[:, [1, 3]] = rois[:, [1, 3]] / (input_size / old_h)
    return preds
| |
|
| |
|
def save_output_infos(input, output, vis_dir, pattern, epoch_id, batch_id):
    """Save flow visualisations for a random subset of the batch.

    For every sampled item the two original images, the second image warped
    by the first flow prediction, and the min-max normalised flow magnitude
    are stacked vertically and written as a JPEG.

    Args:
        input: dict with ``'ori_img1'`` and ``'ori_img2'`` image batches.
        output: ``(flows, pf1s, pf2s)``; only ``flows[0]`` is used here.
        vis_dir, pattern: the image is written to ``<vis_dir>/<pattern>/``.
        epoch_id, batch_id: used in the file name.

    NOTE(review): ``flow_to_warp`` and ``resample`` are not defined in the
    visible part of this file.
    """
    flows, _pf1s, _pf2s = output
    batch_size = len(flows[0])
    # Roughly 20% of the batch, at least 2 (capped at the batch size).
    k = np.clip(int(0.2 * batch_size), a_min=2, a_max=batch_size)
    ids = np.random.choice(range(batch_size), k, replace=False)
    for img_id in ids:
        device = flows[0].device
        img1 = input['ori_img1'][img_id:img_id + 1].to(device)
        img2 = input['ori_img2'][img_id:img_id + 1].to(device)

        flow = flows[0][img_id:img_id + 1]
        warps = flow_to_warp(flow)

        warped_img2 = resample(img2, warps)

        ori_img1 = make_numpy_img(img1[0]) / 255.
        ori_img2 = make_numpy_img(img2[0]) / 255.
        warped_img2 = make_numpy_img(warped_img2[0]) / 255.
        flow_amplitude = torch.sqrt(flow[0, 0:1, ...] ** 2 + flow[0, 1:2, ...] ** 2)
        flow_amplitude = make_numpy_img(flow_amplitude)
        amp_min = np.min(flow_amplitude)
        amp_max = np.max(flow_amplitude)
        # The epsilon keeps a constant flow field from dividing by zero.
        flow_amplitude = (flow_amplitude - amp_min) / (amp_max - amp_min + 1e-10)

        vis = np.concatenate([ori_img1, ori_img2, warped_img2, flow_amplitude], axis=0)
        vis = np.clip(vis, a_min=0, a_max=1)
        # Include img_id so each sampled item gets its own file; without it
        # every iteration overwrote the same ``<epoch>_<batch>.jpg``.
        file_name = os.path.join(vis_dir, pattern, f'{epoch_id}_{batch_id}_{img_id}.jpg')
        plt.imsave(file_name, vis)
| |
|
| |
|
def inv_normalize_img(img, prior_mean=(0, 0, 0), prior_std=(1, 1, 1)):
    """Undo mean/std normalisation and rescale to the 0-255 range.

    Args:
        img: tensor whose dimension 0 is the channel axis (mean and std are
            reshaped to ``(img.size(0), 1, 1)`` before broadcasting).
        prior_mean: per-channel mean that was subtracted.
        prior_std: per-channel std that was divided by.
    Returns:
        Float tensor ``(img * std + mean) * 255`` clamped to ``[0, 255]``.
    """
    # Immutable tuple defaults replace the former shared list defaults.
    n_channels = img.size(0)
    mean = torch.tensor(prior_mean, dtype=torch.float).to(img.device).view(n_channels, 1, 1)
    std = torch.tensor(prior_std, dtype=torch.float).to(img.device).view(n_channels, 1, 1)
    img = img * std + mean
    img = img * 255.
    img = torch.clamp(img, min=0, max=255)
    return img
| |
|
| |
|
def save_seg_output_infos(input, output, vis_dir, pattern, epoch_id, batch_id, prior_mean, prior_std):
    """Save image / prediction / target stacks for a random batch subset.

    Args:
        input: dict with ``'img'`` (normalised images) and ``'label'``
            (one-hot targets, decoded with ``encode_onehot_to_mask``).
        output: network output; the arg-max over dimension 1 is taken as the
            predicted label map.
        vis_dir, pattern: the image is written to ``<vis_dir>/<pattern>/``.
        epoch_id, batch_id: used in the file name.
        prior_mean, prior_std: normalisation statistics to undo on the image.
    """
    pred_label = torch.argmax(output, 1)
    batch_size = len(pred_label)
    # Roughly 20% of the batch, at least one. The upper bound is the batch
    # size; it previously used len(pred_label[0]), a spatial dimension.
    k = np.clip(int(0.2 * batch_size), a_min=1, a_max=batch_size)
    ids = np.random.choice(range(batch_size), k, replace=False)
    for img_id in ids:
        img = input['img'][img_id].to(pred_label.device)
        target = input['label'][img_id].to(pred_label.device)

        img = make_numpy_img(inv_normalize_img(img, prior_mean, prior_std)) / 255.
        target = make_numpy_img(encode_onehot_to_mask(target))
        pred = make_numpy_img(pred_label[img_id])

        vis = np.concatenate([img, pred, target], axis=0)
        vis = np.clip(vis, a_min=0, a_max=1)
        # Include img_id so each sampled item gets its own file; without it
        # every iteration overwrote the same ``<epoch>_<batch>.jpg``.
        file_name = os.path.join(vis_dir, pattern, f'{epoch_id}_{batch_id}_{img_id}.jpg')
        plt.imsave(file_name, vis)
| |
|
| |
|
def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of the given networks.

    Parameters:
        nets (network or list of networks) -- a single network is wrapped
            into a list; ``None`` entries are skipped
        requires_grad (bool) -- value assigned to each parameter's
            ``requires_grad`` flag
    """
    net_list = nets if isinstance(nets, list) else [nets]
    for net in net_list:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad
| |
|
| |
|
def boolean_string(s):
    """Map the strings ``'True'`` / ``'False'`` to the matching bool.

    Any other value raises ``ValueError``.
    """
    if s in {'False', 'True'}:
        return s == 'True'
    raise ValueError('Not a valid boolean string')
| |
|
| |
|
def cpt_pxl_cls_acc(pred_idx, target):
    """Return the fraction of elements where prediction and target agree.

    Both tensors are flattened and compared as integers.
    """
    flat_pred = pred_idx.reshape(-1).int()
    flat_target = target.reshape(-1).int()
    matches = flat_pred == flat_target
    return matches.float().mean()
| |
|
| |
|
def cpt_batch_psnr(img, img_gt, PIXEL_MAX):
    """Return the mean PSNR over a batch of 4-D tensors.

    The mean squared error is taken per sample over dimensions 1-3, turned
    into ``20 * log10(PIXEL_MAX / sqrt(mse))`` and averaged over the batch.
    """
    squared_error = (img - img_gt) ** 2
    mse_per_sample = torch.mean(squared_error, dim=[1, 2, 3])
    rmse_per_sample = torch.sqrt(mse_per_sample)
    psnr_per_sample = 20 * torch.log10(PIXEL_MAX / rmse_per_sample)
    return torch.mean(psnr_per_sample)
| |
|
| |
|
def cpt_psnr(img, img_gt, PIXEL_MAX):
    """Return ``20 * log10(PIXEL_MAX / sqrt(mse))`` for two NumPy images."""
    squared_error = (img - img_gt) ** 2
    rmse = np.sqrt(np.mean(squared_error))
    return 20 * np.log10(PIXEL_MAX / rmse)
| |
|
| |
|
def cpt_rgb_ssim(img, img_gt):
    """Return the SSIM averaged over the first three channels.

    Both inputs are passed through ``clip_01`` first; the images are indexed
    channel-last (``img[:, :, c]``).

    NOTE(review): ``clip_01`` and ``sk_cpt_ssim`` are not defined in the
    visible part of this file.
    """
    clipped = clip_01(img)
    clipped_gt = clip_01(img_gt)
    total = 0
    for channel in range(3):
        total = total + sk_cpt_ssim(clipped[:, :, channel], clipped_gt[:, :, channel])
    return total / 3.0
| |
|
| |
|
def cpt_ssim(img, img_gt):
    """Return the SSIM of two images after passing both through ``clip_01``.

    NOTE(review): ``clip_01`` and ``sk_cpt_ssim`` are not defined in the
    visible part of this file.
    """
    clipped = clip_01(img)
    clipped_gt = clip_01(img_gt)
    score = sk_cpt_ssim(clipped, clipped_gt)
    return score
| |
|
| |
|
def decode_mask_to_onehot(mask, n_class):
    '''
    Expand an integer label mask into a one-hot tensor.

    mask : BxWxH or WxH
    n_class : n
    return : BxnxWxH or nxWxH (float tensor on the device of ``mask``)

    Labels outside ``[0, n_class)`` produce an all-zero vector.
    '''
    assert len(mask.shape) in [2, 3], "decode_mask_to_onehot error!"
    # Remember the input rank before adding the batch axis; checking it
    # afterwards (as before) could never be true, so a WxH input came back
    # as 1xnxWxH instead of the documented nxWxH.
    is_single = len(mask.shape) == 2
    if is_single:
        mask = mask.unsqueeze(0)
    onehot = torch.zeros((mask.size(0), n_class, mask.size(1), mask.size(2)), device=mask.device)
    for i in range(n_class):
        onehot[:, i, ...] = mask == i
    if is_single:
        onehot = onehot.squeeze(0)
    return onehot
| |
|
| |
|
def encode_onehot_to_mask(onehot):
    '''
    Collapse a one-hot tensor to a label mask by arg-max over the class axis.

    onehot: tensor, BxnxWxH or nxWxH
    output: tensor, BxWxH or WxH
    '''
    n_dims = len(onehot.shape)
    assert n_dims in [3, 4], "encode_onehot_to_mask error!"
    # The class axis is third from the end: 0 for nxWxH, 1 for BxnxWxH.
    class_axis = n_dims - 3
    return torch.argmax(onehot, dim=class_axis)
| |
|
| |
|
def decode(pred, target=None, *args, **kwargs):
    """
    Turn raw detection-head outputs into fused per-image box arrays.

    Args:
        pred: big_cls_1(0), big_reg_1, small_cls_1(2), small_reg_1, big_cls_2(4), big_reg_2, small_cls_2(6), small_reg_2
            Entries 4 and 6 are soft-maxed over dim 1 and channel 1 is used
            as the confidence; the box regressions are ``pred[1] + pred[5]``
            and ``pred[3] + pred[7]``.
        target: [[n,5], [n,5]] list of tensor
        **kwargs:
            phase: only 'od' is supported
            img_size: pair used to normalise box columns 0/2 and 1/3 before
                fusion and to scale them back afterwards
            prior_box_wh: pair multiplying even / odd regression channels
            conf_thres: confidence threshold for keeping a grid cell
            iou_thres, conf_type: forwarded to ``weighted_boxes_fusion``

    Returns:
        dict with ``"pred_all"`` (list of ``[m, 6]`` arrays: label, score,
        4 box values) and ``"target_all"`` (list of NumPy targets, or None).

    Raises:
        ValueError: for an unsupported ``phase``.

    NOTE(review): ``weighted_boxes_fusion`` is not defined in the visible
    part of this file. The grid sizes (32 and 64 cells) and strides (8 and
    4) are hard-coded -- presumably for a 256-pixel input; confirm.
    """
    phase = kwargs['phase']
    img_size = kwargs['img_size']
    if phase != 'od':
        # Previously any other phase fell through to the return statement
        # and failed there with an UnboundLocalError.
        raise ValueError("decode only supports phase 'od', got %r" % (phase,))

    prior_box_wh = kwargs['prior_box_wh']
    conf_thres = kwargs['conf_thres']
    iou_thres = kwargs['iou_thres']
    conf_type = kwargs['conf_type']
    pred_conf_32_2 = F.softmax(pred[4], dim=1)[:, 1, ...]
    pred_conf_64_2 = F.softmax(pred[6], dim=1)[:, 1, ...]
    obj_mask_32_2 = pred_conf_32_2 > conf_thres
    obj_mask_64_2 = pred_conf_64_2 > conf_thres

    pre_loc_32_2 = pred[1] + pred[5]
    pre_loc_32_2[:, 0::2, ...] *= prior_box_wh[0]
    pre_loc_32_2[:, 1::2, ...] *= prior_box_wh[1]
    x_y_grid = get_coords_grid(31, 31, 0, 0)
    x_y_grid *= 8
    # (x, y) repeated twice so the offset applies to all four box channels.
    x_y_grid = torch.cat([x_y_grid, x_y_grid], dim=0)
    pre_loc_32_2 += x_y_grid.to(pre_loc_32_2.device)

    pre_loc_64_2 = pred[3] + pred[7]
    pre_loc_64_2[:, 0::2, ...] *= prior_box_wh[0]
    pre_loc_64_2[:, 1::2, ...] *= prior_box_wh[1]
    x_y_grid_2 = get_coords_grid(63, 63, 0, 0)
    x_y_grid_2 *= 4
    x_y_grid_2 = torch.cat([x_y_grid_2, x_y_grid_2], dim=0)
    # Move the grid to the tensor it is added to (was pre_loc_32_2.device).
    pre_loc_64_2 += x_y_grid_2.to(pre_loc_64_2.device)

    pred_all = []
    for i in range(pre_loc_32_2.size(0)):
        score_32 = pred_conf_32_2[i][obj_mask_32_2[i]]
        score_64 = pred_conf_64_2[i][obj_mask_64_2[i]]

        loc_32 = pre_loc_32_2[i].permute((1, 2, 0))[obj_mask_32_2[i]]
        loc_64 = pre_loc_64_2[i].permute((1, 2, 0))[obj_mask_64_2[i]]

        score_list = torch.cat((score_32, score_64), dim=0).detach().cpu().numpy()
        boxes_list = torch.cat((loc_32, loc_64), dim=0).detach().cpu().numpy()
        boxes_list[:, 0::2] /= img_size[0]
        boxes_list[:, 1::2] /= img_size[1]
        # Every detection gets label 1.
        label_list = np.ones_like(score_list)

        # NOTE(review): keeps the first 150 candidates in grid order, not
        # the 150 highest-scoring ones -- confirm this is intended.
        boxes_list = boxes_list[:150, :]
        score_list = score_list[:150]
        label_list = label_list[:150]
        boxes, scores, labels = weighted_boxes_fusion([boxes_list], [score_list], [label_list], weights=None,
                                                      iou_thr=iou_thres, conf_type=conf_type)
        boxes[:, 0::2] *= img_size[0]
        boxes[:, 1::2] *= img_size[1]
        pred_boxes = np.concatenate((labels.reshape(-1, 1), scores.reshape(-1, 1), boxes), axis=1)
        pred_all.append(pred_boxes)
    if target is not None:
        target_all = [x.cpu().numpy() for x in target]
    else:
        target_all = None
    return {"pred_all": pred_all, "target_all": target_all}
| |
|
| |
|
| |
|
def get_metrics(phase, pred, target):
    '''
    Compute per-class segmentation metrics over a batch.

    phase: only 'seg' is supported
    pred: logits, tensor, nBatch*nClass*W*H
    target: labels, tensor, nBatch*nClass*W*H
    return: (IoU, OA, F1_score), each reduced over dims 0, 2 and 3
    raises: ValueError for an unsupported phase
    '''
    if phase != 'seg':
        # Previously any other phase fell through to the return statement
        # and failed there with an UnboundLocalError.
        raise ValueError("get_metrics only supports phase 'seg', got %r" % (phase,))

    pred = torch.argmax(pred.detach(), dim=1)
    pred = decode_mask_to_onehot(pred, target.size(1))

    # Per-class counts of positives in the ground truth and the prediction.
    gt_pos_sum = torch.sum(target == 1, dim=(0, 2, 3))
    pred_pos_sum = torch.sum(pred == 1, dim=(0, 2, 3))
    true_pos_sum = torch.sum((target == 1) * (pred == 1), dim=(0, 2, 3))

    # The small epsilon avoids 0/0 for classes that never occur.
    precision = true_pos_sum / (pred_pos_sum + 1e-15)
    recall = true_pos_sum / (gt_pos_sum + 1e-15)

    IoU = true_pos_sum / (pred_pos_sum + gt_pos_sum - true_pos_sum + 1e-15)

    # Accuracy: one minus the share of elements that are false positives or
    # false negatives for the class.
    OA = 1 - (pred_pos_sum + gt_pos_sum - 2 * true_pos_sum) / torch.sum(target >= 0, dim=(0, 2, 3))

    F1_score = 2 * precision * recall / (precision + recall + 1e-15)
    return IoU, OA, F1_score
| |
|
| |
|