import numpy as np
import torch
import pandas as pd
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import mean_squared_error as compare_mse
from skimage.metrics import structural_similarity as compare_ssim

from model.OneRestore import OneRestore
from model.Embedder import Embedder


def load_embedder_ckpt(device, freeze_model=False, ckpt_name=None,
                       combine_type=['clear', 'low', 'haze', 'rain', 'snow',
                                     'low_haze', 'low_rain', 'low_snow', 'haze_rain',
                                     'haze_snow', 'low_haze_rain', 'low_haze_snow']):
    if ckpt_name is not None:
        # Deserialize on CPU when no GPU is available.
        if torch.cuda.is_available():
            model_info = torch.load(ckpt_name)
        else:
            model_info = torch.load(ckpt_name, map_location=torch.device('cpu'))
        print('==> loading existing Embedder model:', ckpt_name)
        model = Embedder(combine_type)
        model.load_state_dict(model_info)
        model.to(device)
    else:
        print('==> Initialize Embedder model.')
        model = Embedder(combine_type)
        model.to(device)

    if freeze_model:
        freeze(model)

    return model
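
# Usage sketch (the checkpoint path here is a hypothetical example, not a
# file shipped with this module):
#   device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#   embedder = load_embedder_ckpt(device, ckpt_name='ckpts/embedder_model.tar')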


def load_restore_ckpt(device, freeze_model=False, ckpt_name=None):
    if ckpt_name is not None:
        if torch.cuda.is_available():
            model_info = torch.load(ckpt_name)
        else:
            model_info = torch.load(ckpt_name, map_location=torch.device('cpu'))
        print('==> loading existing OneRestore model:', ckpt_name)
        model = OneRestore().to(device)
        model.load_state_dict(model_info)
    else:
        print('==> Initialize OneRestore model.')
        model = OneRestore().to(device)
        model = torch.nn.DataParallel(model).to(device)

    if freeze_model:
        freeze(model)

    total = sum(param.nelement() for param in model.parameters())
    print('Number of OneRestore parameters: %.2fM' % (total / 1e6))
    return model
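
# Inference sketch (checkpoint path and call signature are assumptions; the
# exact inputs depend on OneRestore.forward):
#   model = load_restore_ckpt(device, freeze_model=True, ckpt_name='ckpts/onerestore_model.tar')
#   with torch.no_grad():
#       restored = model(text_embedding, degraded_image)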


def load_restore_ckpt_with_optim(device, local_rank=None, freeze_model=False, ckpt_name=None, lr=None):
    if ckpt_name is not None:
        if torch.cuda.is_available():
            model_info = torch.load(ckpt_name)
        else:
            model_info = torch.load(ckpt_name, map_location=torch.device('cpu'))
        print('==> loading existing OneRestore model:', ckpt_name)
        model = OneRestore().to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr) if lr is not None else None
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True) if local_rank is not None else model

        if local_rank is not None:
            model.load_state_dict(model_info['state_dict'])
        else:
            # Strip the 'module.' prefix left by (Distributed)DataParallel so
            # the weights load into the bare model.
            weights_dict = {}
            for k, v in model_info['state_dict'].items():
                new_k = k.replace('module.', '') if 'module' in k else k
                weights_dict[new_k] = v
            model.load_state_dict(weights_dict)

        optimizer = torch.optim.Adam(model.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch']
    else:
        print('==> Initialize OneRestore model.')
        model = OneRestore().to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank,
            find_unused_parameters=True) if local_rank is not None else torch.nn.DataParallel(model)
        cur_epoch = 0

    if freeze_model:
        freeze(model)

    total = sum(param.nelement() for param in model.parameters())
    print('Number of OneRestore parameters: %.2fM' % (total / 1e6))
    return model, optimizer, cur_epoch
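
# Resume-training sketch (path, lr, loop bounds, and train_one_epoch are
# illustrative assumptions):
#   model, optimizer, cur_epoch = load_restore_ckpt_with_optim(
#       device, local_rank=None, ckpt_name='ckpts/last.tar', lr=1e-4)
#   for epoch in range(cur_epoch, num_epochs):
#       train_one_epoch(model, optimizer)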


def load_embedder_ckpt_with_optim(device, args,
                                  combine_type=['clear', 'low', 'haze', 'rain', 'snow',
                                                'low_haze', 'low_rain', 'low_snow', 'haze_rain',
                                                'haze_snow', 'low_haze_rain', 'low_haze_snow']):
    print('Init embedder')

    # seed
    if args.seed == -1:
        args.seed = np.random.randint(1, 10000)
    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    print('Training embedder seed:', seed)

    # embedder model
    embedder = Embedder(combine_type).to(device)

    if args.pre_weight == '':
        optimizer = torch.optim.Adam(embedder.parameters(), lr=args.lr)
        cur_epoch = 1
    else:
        try:
            # Deserialize on CPU when no GPU is available.
            if torch.cuda.is_available():
                embedder_info = torch.load(f'{args.check_dir}/{args.pre_weight}')
            else:
                embedder_info = torch.load(f'{args.check_dir}/{args.pre_weight}',
                                           map_location=torch.device('cpu'))
            embedder.load_state_dict(embedder_info['state_dict'])
            optimizer = torch.optim.Adam(embedder.parameters(), lr=args.lr)
            optimizer.load_state_dict(embedder_info['optimizer'])
            cur_epoch = embedder_info['epoch'] + 1
        except Exception:
            # Fall back to a fresh optimizer so all return values stay defined.
            print('Pre-trained model loading error!')
            optimizer = torch.optim.Adam(embedder.parameters(), lr=args.lr)
            cur_epoch = 1
    return embedder, optimizer, cur_epoch, device


def freeze_text_embedder(m):
    """Freeze the text-embedding and first MLP parameters of module m."""
    m.eval()
    for name, para in m.named_parameters():
        if name in ('embedder.weight', 'mlp.0.weight', 'mlp.0.bias'):
            print(name)
            para.requires_grad = False
            para.grad = None


class AverageMeter(object):
    """Computes and stores the average and current value."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
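
# Usage sketch: track a running loss average over an epoch
# (loader, model, and criterion are assumed to exist):
#   losses = AverageMeter()
#   for x, y in loader:
#       loss = criterion(model(x), y)
#       losses.update(loss.item(), n=x.size(0))
#   print('epoch loss: %.4f' % losses.avg)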


def data_process(data, args, device):
    combine_type = args.degr_type
    b, n, c, w, h = data.size()

    # Slice 0 along dim 1 is the clear (positive) image.
    pos_data = data[:, 0, :, :, :]

    inp_data = torch.zeros((b, c, w, h))
    inp_class = []
    neg_data = torch.zeros((b, n - 2, c, w, h))

    # Pick one degraded version per sample as the input; the remaining
    # degraded versions become negatives.
    index = np.random.randint(1, n, (b))
    for i in range(b):
        k = 0
        for j in range(n):
            if j == 0:
                continue
            elif index[i] == j:
                inp_class.append(combine_type[index[i]])
                inp_data[i, :, :, :] = data[i, index[i], :, :, :]
            else:
                neg_data[i, k, :, :, :] = data[i, j, :, :, :]
                k = k + 1

    return pos_data.to(device), [inp_data.to(device), inp_class], neg_data.to(device)


def print_args(argspar):
    print("\nParameter Print")
    for p, v in argspar.__dict__.items():
        print('\t{}: {}'.format(p, v))
    print('\n')


def adjust_learning_rate(optimizer, epoch, lr_update_freq):
    # Halve the learning rate every lr_update_freq epochs (skipping epoch 0).
    if epoch and epoch % lr_update_freq == 0:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] / 2
    return optimizer
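
# Example schedule: with lr_update_freq=10 the learning rate is halved at
# epochs 10, 20, 30, ...
#   for epoch in range(1, num_epochs + 1):
#       optimizer = adjust_learning_rate(optimizer, epoch, lr_update_freq=10)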


def tensor_metric(img, imclean, model, data_range=1):
    # `model` selects the metric: 'PSNR', 'MSE', or 'SSIM'.
    img_cpu = img.data.cpu().numpy().astype(np.float32).transpose(0, 2, 3, 1)
    imgclean = imclean.data.cpu().numpy().astype(np.float32).transpose(0, 2, 3, 1)

    SUM = 0
    for i in range(img_cpu.shape[0]):
        if model == 'PSNR':
            SUM += compare_psnr(imgclean[i, :, :, :], img_cpu[i, :, :, :], data_range=data_range)
        elif model == 'MSE':
            SUM += compare_mse(imgclean[i, :, :, :], img_cpu[i, :, :, :])
        elif model == 'SSIM':
            SUM += compare_ssim(imgclean[i, :, :, :], img_cpu[i, :, :, :], data_range=data_range, multichannel=True)
            # On newer scikit-image versions the `multichannel` keyword is
            # removed; replace the line above with:
            # SUM += compare_ssim(imgclean[i, :, :, :], img_cpu[i, :, :, :], data_range=data_range, channel_axis=-1)
        else:
            raise ValueError('Unknown metric: %s (expected PSNR, MSE, or SSIM)' % model)

    return SUM / img_cpu.shape[0]
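
# Usage sketch: score a restored batch against ground truth; both tensors
# are assumed to be NCHW and scaled to [0, 1]:
#   psnr = tensor_metric(restored, clear, 'PSNR', data_range=1)
#   ssim = tensor_metric(restored, clear, 'SSIM', data_range=1)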


def save_checkpoint(stateF, checkpoint, epoch, psnr_t1, ssim_t1, psnr_t2, ssim_t2, filename='model.tar'):
    torch.save(stateF, checkpoint + 'OneRestore_model_%d_%.4f_%.4f_%.4f_%.4f.tar'
               % (epoch, psnr_t1, ssim_t1, psnr_t2, ssim_t2))


def load_excel(x):
    data1 = pd.DataFrame(x)
    writer = pd.ExcelWriter('./metric_result.xlsx')
    data1.to_excel(writer, sheet_name='PSNR-SSIM', float_format='%.5f')
    # On recent pandas, ExcelWriter.save() is gone; close() writes the file.
    writer.close()


def freeze(m):
    """Freezes module m."""
    m.eval()
    for p in m.parameters():
        p.requires_grad = False
        p.grad = None
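
# Usage sketch: freeze the whole embedder so only OneRestore trains
# (variable names are illustrative):
#   freeze(embedder)
#   trainable = [p for p in model.parameters() if p.requires_grad]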