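"""Training and validation loops for point-prompted SAM-style segmentation
models (SAM, MobileSAM, EfficientSAM), plus prompt-reshaping helpers for
EfficientSAM."""
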
import argparse
import os
import shutil
import sys
import tempfile
import time
from collections import OrderedDict
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from einops import rearrange
from monai.inferers import sliding_window_inference
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric  # required by dice_metric below
from monai.transforms import AsDiscrete
from PIL import Image
from skimage import io
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score
from tensorboardX import SummaryWriter
# from dataset import *
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

import cfg
import models.sam.utils.transforms as samtrans
import pytorch_ssim
# from models.discriminatorlayer import discriminator
from conf import settings
from utils import *
# from lucent.modelzoo.util import get_model_layers
# from lucent.optvis import render, param, transform, objectives
# from lucent.modelzoo import inceptionv1
args = cfg.parse_args()

GPUdevice = torch.device('cuda', args.gpu_device)
# pos_weight = 2 doubles the BCEWithLogitsLoss penalty on positive (foreground) pixels
pos_weight = torch.ones([1]).cuda(device=GPUdevice) * 2
criterion_G = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
seed = torch.randint(1, 11, (args.b, 7))

torch.backends.cudnn.benchmark = True
loss_function = DiceCELoss(to_onehot_y=True, softmax=True)
scaler = torch.cuda.amp.GradScaler()
max_iterations = settings.EPOCH
# 14-class one-hot post-processing (e.g. a multi-organ label set)
post_label = AsDiscrete(to_onehot=14)
post_pred = AsDiscrete(argmax=True, to_onehot=14)
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []
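
# The MONAI objects above (post_label / post_pred / dice_metric) are not called in
# this file. A minimal sketch of how they are typically chained in a MONAI
# validation step is below (`logits` / `labels` are hypothetical names;
# decollate_batch is from monai.data):
#
#     from monai.data import decollate_batch
#     outs = [post_pred(t) for t in decollate_batch(logits)]
#     gts = [post_label(t) for t in decollate_batch(labels)]
#     dice_metric(y_pred=outs, y=gts)
#     mean_dice = dice_metric.aggregate().item()
#     dice_metric.reset()
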
def train_sam(args, net: nn.Module, optimizer, train_loader,
              epoch, writer, schedulers=None, vis=50):
    hard = 0
    epoch_loss = 0
    ind = 0
    # train mode
    net.train()
    optimizer.zero_grad()

    GPUdevice = torch.device('cuda:' + str(args.gpu_device))

    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G
    with tqdm(total=len(train_loader), desc=f'Epoch {epoch}', unit='img') as pbar:
        for pack in train_loader:
            # torch.cuda.empty_cache()
            imgs = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masks = pack['label'].to(dtype=torch.float32, device=GPUdevice)
            # for k,v in pack['image_meta_dict'].items():
            #     print(k)
            if 'pt' not in pack:
                # no stored prompt: sample a click prompt from the ground truth
                imgs, pt, masks = generate_click_prompt(imgs, masks)
            else:
                pt = pack['pt']
                point_labels = pack['p_label']
            name = pack['image_meta_dict']['filename_or_obj']

            if args.thd:
                # 3D volumes: fold the depth dimension into the batch dimension
                imgs, pt, masks = generate_click_prompt(imgs, masks)

                pt = rearrange(pt, 'b n d -> (b d) n')
                imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
                masks = rearrange(masks, 'b c h w d -> (b d) c h w')

                imgs = imgs.repeat(1, 3, 1, 1)
                point_labels = torch.ones(imgs.size(0))

                imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

            showp = pt

            mask_type = torch.float32
            ind += 1
            b_size, c, w, h = imgs.size()
            longsize = w if w >= h else h

            if point_labels.clone().flatten()[0] != -1:
                # point_coords = samtrans.ResizeLongestSide(longsize).apply_coords(pt, (h, w))
                point_coords = pt
                coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                if len(point_labels.shape) == 1:  # only one point prompt
                    coords_torch, labels_torch, showp = coords_torch[None, :, :], labels_torch[None, :], showp[None, :, :]
                pt = (coords_torch, labels_torch)
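            # `pt` is now the (coords, labels) tuple that SAM's prompt encoder
            # expects; `showp` keeps the raw coordinates for visualization.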
            '''init'''
            if hard:
                true_mask_ave = (true_mask_ave > 0.5).float()
                # true_mask_ave = cons_tensor(true_mask_ave)
            # imgs = imgs.to(dtype=mask_type, device=GPUdevice)

            '''Train'''
            if args.mod == 'sam_adpt':
                # adapter fine-tuning: freeze everything in the image encoder
                # except the Adapter layers
                for n, value in net.image_encoder.named_parameters():
                    if "Adapter" not in n:
                        value.requires_grad = False
                    else:
                        value.requires_grad = True
            elif args.mod == 'sam_lora' or args.mod == 'sam_adalora':
                from models.common import loralib as lora
                lora.mark_only_lora_as_trainable(net.image_encoder)
                if args.mod == 'sam_adalora':
                    # Initialize the RankAllocator
                    rankallocator = lora.RankAllocator(
                        net.image_encoder, lora_r=4, target_rank=8,
                        init_warmup=500, final_warmup=1500, mask_interval=10,
                        total_step=3000, beta1=0.85, beta2=0.85,
                    )
            else:
                # full fine-tuning of the image encoder
                for n, value in net.image_encoder.named_parameters():
                    value.requires_grad = True
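            # Note: the freeze flags and the AdaLoRA RankAllocator are re-created
            # on every batch here; in a longer-running setup one would typically
            # set them up once before the epoch loop (behavior is kept as-is).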
            imge = net.image_encoder(imgs)

            # the prompt encoder stays frozen
            with torch.no_grad():
                if args.net == 'sam' or args.net == 'mobile_sam':
                    se, de = net.prompt_encoder(
                        points=pt,
                        boxes=None,
                        masks=None,
                    )
                elif args.net == "efficient_sam":
                    coords_torch, labels_torch = transform_prompt(coords_torch, labels_torch, h, w)
                    se = net.prompt_encoder(
                        coords=coords_torch,
                        labels=labels_torch,
                    )

            if args.net == 'sam':
                pred, _ = net.mask_decoder(
                    image_embeddings=imge,
                    image_pe=net.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=(args.multimask_output > 1),
                )
            elif args.net == 'mobile_sam':
                pred, _ = net.mask_decoder(
                    image_embeddings=imge,
                    image_pe=net.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    dense_prompt_embeddings=de,
                    multimask_output=False,
                )
            elif args.net == "efficient_sam":
                se = se.view(
                    se.shape[0],
                    1,
                    se.shape[1],
                    se.shape[2],
                )
                pred, _ = net.mask_decoder(
                    image_embeddings=imge,
                    image_pe=net.prompt_encoder.get_dense_pe(),
                    sparse_prompt_embeddings=se,
                    multimask_output=False,
                )

            # Resize to the desired output size
            pred = F.interpolate(pred, size=(args.out_size, args.out_size))
            loss = lossfunc(pred, masks)

            pbar.set_postfix(**{'loss (batch)': loss.item()})
            epoch_loss += loss.item()

            # nn.utils.clip_grad_value_(net.parameters(), 0.1)
            if args.mod == 'sam_adalora':
                # AdaLoRA adds an orthogonality regularizer and updates the
                # rank allocation after each optimizer step
                (loss + lora.compute_orth_regu(net, regu_weight=0.1)).backward()
                optimizer.step()
                rankallocator.update_and_mask(net, ind)
            else:
                loss.backward()
                optimizer.step()
            optimizer.zero_grad()

            '''vis images'''
            if vis:
                if ind % vis == 0:
                    namecat = 'Train'
                    for na in name[:2]:
                        namecat = namecat + na.split('/')[-1].split('.')[0] + '+'
                    vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False, points=showp)

            pbar.update()

    # returns the last batch's loss; `epoch_loss` is accumulated but not returned
    return loss
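
# A minimal usage sketch (hypothetical: `get_network` and `get_dataloaders` are
# assumptions, not part of this file; the optimizer choice is illustrative):
#
#     net = get_network(args).to(GPUdevice)
#     optimizer = optim.Adam(net.parameters(), lr=1e-4)
#     writer = SummaryWriter()
#     train_loader, val_loader = get_dataloaders(args)
#     for epoch in range(settings.EPOCH):
#         loss = train_sam(args, net, optimizer, train_loader, epoch, writer)
#         val_loss, metrics = validation_sam(args, val_loader, epoch, net)
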
def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
    # eval mode
    net.eval()

    mask_type = torch.float32
    n_val = len(val_loader)  # number of batches
    ave_res, mix_res = (0, 0, 0, 0), (0,) * args.multimask_output * 2
    rater_res = [(0, 0, 0, 0) for _ in range(6)]
    tot = 0
    hard = 0
    threshold = (0.1, 0.3, 0.5, 0.7, 0.9)
    GPUdevice = torch.device('cuda:' + str(args.gpu_device))
    device = GPUdevice

    if args.thd:
        lossfunc = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
    else:
        lossfunc = criterion_G
    with tqdm(total=n_val, desc='Validation round', unit='batch', leave=False) as pbar:
        for ind, pack in enumerate(val_loader):
            imgsw = pack['image'].to(dtype=torch.float32, device=GPUdevice)
            masksw = pack['label'].to(dtype=torch.float32, device=GPUdevice)
            # for k,v in pack['image_meta_dict'].items():
            #     print(k)
            if 'pt' not in pack or args.thd:
                imgsw, ptw, masksw = generate_click_prompt(imgsw, masksw)
            else:
                ptw = pack['pt']
                point_labels = pack['p_label']
            name = pack['image_meta_dict']['filename_or_obj']

            buoy = 0
            # evaluate the volume in chunks of evl_ch slices along the last axis
            if args.evl_chunk:
                evl_ch = int(args.evl_chunk)
            else:
                evl_ch = int(imgsw.size(-1))

            while (buoy + evl_ch) <= imgsw.size(-1):
                if args.thd:
                    pt = ptw[:, :, buoy: buoy + evl_ch]
                else:
                    pt = ptw

                imgs = imgsw[..., buoy:buoy + evl_ch]
                masks = masksw[..., buoy:buoy + evl_ch]
                buoy += evl_ch

                if args.thd:
                    pt = rearrange(pt, 'b n d -> (b d) n')
                    imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
                    masks = rearrange(masks, 'b c h w d -> (b d) c h w')
                    imgs = imgs.repeat(1, 3, 1, 1)
                    point_labels = torch.ones(imgs.size(0))

                    imgs = torchvision.transforms.Resize((args.image_size, args.image_size))(imgs)
                    masks = torchvision.transforms.Resize((args.out_size, args.out_size))(masks)

                showp = pt

                mask_type = torch.float32
                ind += 1
                b_size, c, w, h = imgs.size()
                longsize = w if w >= h else h

                if point_labels.clone().flatten()[0] != -1:
                    # point_coords = samtrans.ResizeLongestSide(longsize).apply_coords(pt, (h, w))
                    point_coords = pt
                    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=GPUdevice)
                    labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=GPUdevice)
                    if len(point_labels.shape) == 1:  # only one point prompt
                        coords_torch, labels_torch, showp = coords_torch[None, :, :], labels_torch[None, :], showp[None, :, :]
                    pt = (coords_torch, labels_torch)

                '''init'''
                if hard:
                    true_mask_ave = (true_mask_ave > 0.5).float()
                    # true_mask_ave = cons_tensor(true_mask_ave)
                imgs = imgs.to(dtype=mask_type, device=GPUdevice)
                '''test'''
                with torch.no_grad():
                    imge = net.image_encoder(imgs)
                    if args.net == 'sam' or args.net == 'mobile_sam':
                        se, de = net.prompt_encoder(
                            points=pt,
                            boxes=None,
                            masks=None,
                        )
                    elif args.net == "efficient_sam":
                        coords_torch, labels_torch = transform_prompt(coords_torch, labels_torch, h, w)
                        se = net.prompt_encoder(
                            coords=coords_torch,
                            labels=labels_torch,
                        )
                    if args.net == 'sam':
                        pred, _ = net.mask_decoder(
                            image_embeddings=imge,
                            image_pe=net.prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings=se,
                            dense_prompt_embeddings=de,
                            multimask_output=(args.multimask_output > 1),
                        )
                    elif args.net == 'mobile_sam':
                        pred, _ = net.mask_decoder(
                            image_embeddings=imge,
                            image_pe=net.prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings=se,
                            dense_prompt_embeddings=de,
                            multimask_output=False,
                        )
                    elif args.net == "efficient_sam":
                        se = se.view(
                            se.shape[0],
                            1,
                            se.shape[1],
                            se.shape[2],
                        )
                        pred, _ = net.mask_decoder(
                            image_embeddings=imge,
                            image_pe=net.prompt_encoder.get_dense_pe(),
                            sparse_prompt_embeddings=se,
                            multimask_output=False,
                        )

                    # Resize to the desired output size
                    pred = F.interpolate(pred, size=(args.out_size, args.out_size))
                    tot += lossfunc(pred, masks)
                '''vis images'''
                if ind % args.vis == 0:
                    namecat = 'Test'
                    for na in name[:2]:
                        img_name = na.split('/')[-1].split('.')[0]
                        namecat = namecat + img_name + '+'
                    vis_image(imgs, pred, masks, os.path.join(args.path_helper['sample_path'], namecat + 'epoch+' + str(epoch) + '.jpg'), reverse=False, points=showp)

                temp = eval_seg(pred, masks, threshold)
                mix_res = tuple([sum(a) for a in zip(mix_res, temp)])

            pbar.update()

    if args.evl_chunk:
        # each batch contributed one forward pass per chunk
        n_val = n_val * (imgsw.size(-1) // evl_ch)

    return tot / n_val, tuple([a / n_val for a in mix_res])
def transform_prompt(coord, label, h, w):
    coord = coord.transpose(0, 1)
    label = label.transpose(0, 1)

    coord = coord.unsqueeze(1)
    label = label.unsqueeze(1)

    batch_size, max_num_queries, num_pts, _ = coord.shape
    rescaled_batched_points = get_rescaled_pts(coord, h, w)

    decoder_max_num_input_points = 6
    if num_pts > decoder_max_num_input_points:
        rescaled_batched_points = rescaled_batched_points[
            :, :, :decoder_max_num_input_points, :
        ]
        label = label[
            :, :, :decoder_max_num_input_points
        ]
    elif num_pts < decoder_max_num_input_points:
        rescaled_batched_points = F.pad(
            rescaled_batched_points,
            (0, 0, 0, decoder_max_num_input_points - num_pts),
            value=-1.0,
        )
        label = F.pad(
            label,
            (0, decoder_max_num_input_points - num_pts),
            value=-1.0,
        )

    rescaled_batched_points = rescaled_batched_points.reshape(
        batch_size * max_num_queries, decoder_max_num_input_points, 2
    )
    label = label.reshape(
        batch_size * max_num_queries, decoder_max_num_input_points
    )

    return rescaled_batched_points, label
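
# Shape sketch for transform_prompt as called above (the EfficientSAM prompt
# encoder interface is an assumption): coord (1, N, 2) / label (1, N) come in,
# points are rescaled to a 1024-pixel frame, then truncated or padded (pad
# value -1 marks "no point") to exactly 6 points per query, and returned as
# (N, 6, 2) / (N, 6).
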
def get_rescaled_pts(batched_points: torch.Tensor, input_h: int, input_w: int):
    # map point coordinates into a 1024x1024 frame; negative (padding)
    # coordinates are kept as -1
    return torch.stack(
        [
            torch.where(
                batched_points[..., 0] >= 0,
                batched_points[..., 0] * 1024 / input_w,
                -1.0,
            ),
            torch.where(
                batched_points[..., 1] >= 0,
                batched_points[..., 1] * 1024 / input_h,
                -1.0,
            ),
        ],
        dim=-1,
    )
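
# Minimal smoke test for the prompt helpers (a sketch; note that running this
# file directly also executes the module-level setup above, which needs a CUDA
# device and whatever defaults cfg.parse_args() provides):
if __name__ == "__main__":
    pts = torch.tensor([[[10.0, 20.0]], [[-1.0, -1.0]]])  # (B=2, N=1, 2)
    print(get_rescaled_pts(pts, input_h=256, input_w=256))  # [[40., 80.]], [[-1., -1.]]
    coords = torch.rand(1, 4, 2) * 256  # (1, N=4, 2) point prompts
    labels = torch.ones(1, 4)           # (1, N=4) foreground labels
    c, l = transform_prompt(coords, labels, 256, 256)
    print(c.shape, l.shape)  # -> torch.Size([4, 6, 2]) torch.Size([4, 6])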