| | import torch |
| | import os |
| | from collections import OrderedDict |
| |
|
def freeze(model):
    """Disable gradient tracking for every parameter of *model*.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).
    """
    for param in model.parameters():
        param.requires_grad_(False)
| |
|
def unfreeze(model):
    """Enable gradient tracking for every parameter of *model*.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).
    """
    for param in model.parameters():
        param.requires_grad_(True)
| |
|
def is_frozen(model):
    """Report whether *model* has at least one non-trainable parameter.

    Returns:
        ``True`` as soon as any parameter has ``requires_grad`` set to
        ``False``; ``False`` when every parameter is trainable, which
        includes a model that has no parameters at all.
    """
    return any(not param.requires_grad for param in model.parameters())
| |
|
def save_checkpoint(model_dir, state, session):
    """Serialize *state* to ``model_epoch_<epoch>_<session>.pth`` in *model_dir*.

    Args:
        model_dir: Directory the checkpoint file is written to.
        state: Mapping to save; its ``'epoch'`` entry is used in the file name.
        session: Value appended to the file name after the epoch.
    """
    filename = f"model_epoch_{state['epoch']}_{session}.pth"
    torch.save(state, os.path.join(model_dir, filename))
| |
|
def load_checkpoint(model, weights, strict=True):
    """Load the ``"state_dict"`` entry of a checkpoint file into *model*.

    The weights are first loaded as stored. If that raises a
    ``RuntimeError`` (what ``load_state_dict`` raises on key or shape
    mismatches), a leading ``'module.'`` is removed from each key that has
    one and the load is retried.

    Args:
        model: Module receiving the weights.
        weights: Path (or file-like object) accepted by ``torch.load``.
        strict: Forwarded to ``model.load_state_dict``.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: If the weights fit neither as stored nor with the
            ``'module.'`` prefix removed.

    NOTE(review): with ``strict=False`` the first attempt does not raise on
    key mismatches, so the prefix-stripping retry is not reached for a
    prefixed checkpoint -- confirm whether callers rely on that.
    """
    prefix = 'module.'
    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    state_dict = checkpoint["state_dict"]

    try:
        # Copy into a plain OrderedDict, as before, so only keys and tensors
        # are handed to the model.
        model.load_state_dict(OrderedDict(state_dict), strict=strict)
    except RuntimeError:
        # Only strip a genuine leading prefix; keys that merely contain
        # 'module.' somewhere else (e.g. 'submodule.weight') stay intact.
        stripped = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
        model.load_state_dict(stripped, strict=strict)
| |
|
def load_checkpoint_multigpu(model, weights):
    """Load a checkpoint whose keys carry a leading ``'module.'`` into *model*.

    The ``"state_dict"`` entry of the checkpoint is read, a leading
    ``'module.'`` is removed from every key that has one, and the result is
    loaded with the default (strict) ``load_state_dict`` behaviour.

    Args:
        model: Module receiving the weights.
        weights: Path (or file-like object) accepted by ``torch.load``.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: If the de-prefixed keys do not match the model.
    """
    prefix = 'module.'
    # Deserialize onto the CPU, like the other loaders in this file, so the
    # checkpoint can be read on a host without the device it was saved from;
    # load_state_dict copies the values into the model's own parameters.
    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    state_dict = checkpoint["state_dict"]

    # Strip the prefix only where it is present instead of blindly dropping
    # the first seven characters of every key.
    new_state_dict = OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )
    model.load_state_dict(new_state_dict)
| |
|
def load_start_epoch(weights):
    """Return the ``"epoch"`` entry stored in the checkpoint at *weights*.

    The checkpoint is deserialized onto the CPU.
    """
    return torch.load(weights, map_location=torch.device('cpu'))["epoch"]
| |
|
def load_optim(optimizer, weights):
    """Restore *optimizer* from a checkpoint and return its learning rate.

    Args:
        optimizer: Optimizer whose state is overwritten in place from the
            checkpoint's ``'optimizer'`` entry.
        weights: Path (or file-like object) accepted by ``torch.load``.

    Returns:
        The ``'lr'`` of the last parameter group after the state is loaded.

    Raises:
        KeyError: If the checkpoint has no ``'optimizer'`` entry.
    """
    checkpoint = torch.load(weights, map_location=torch.device('cpu'))
    optimizer.load_state_dict(checkpoint['optimizer'])
    # Read the last group directly rather than looping just to keep the
    # final value of the loop variable.
    return optimizer.param_groups[-1]['lr']
| |
|
def get_arch(opt):
    """Instantiate the restoration network selected by ``opt.arch``.

    Args:
        opt: Options object. ``arch`` chooses the network
            (``'ShadowFormer'`` or ``'DenseSR'``); ``train_ps``,
            ``embed_dim``, ``win_size``, ``token_projection`` and
            ``token_mlp`` are forwarded to its constructor.

    Returns:
        The constructed model.

    Raises:
        Exception: If ``opt.arch`` names neither supported network.
    """
    # Imported here, not at module level, exactly as the function always did.
    from model import ShadowFormer, DenseSR

    arch = opt.arch
    print('You choose '+arch+'...')

    # Both networks take the same keyword arguments, so a lookup table
    # replaces the duplicated constructor calls.
    constructors = {'ShadowFormer': ShadowFormer, 'DenseSR': DenseSR}
    build = constructors.get(arch)
    if build is None:
        raise Exception("Arch error!")

    return build(
        img_size=opt.train_ps,
        embed_dim=opt.embed_dim,
        win_size=opt.win_size,
        token_projection=opt.token_projection,
        token_mlp=opt.token_mlp,
    )
| |
|
| |
|
def window_partition(x, win_size):
    """Split a ``(B, C, H, W)`` tensor into square, non-overlapping windows.

    Args:
        x: 4-D tensor laid out as ``(B, C, H, W)``.
        win_size: Side length of each window. ``H`` and ``W`` are expected
            to be multiples of it; otherwise the reshape fails.

    Returns:
        Tensor of shape ``(B * (H // win_size) * (W // win_size), C,
        win_size, win_size)``, with windows ordered by batch, then window
        row, then window column.
    """
    batch, channels, height, width = x.shape
    rows, cols = height // win_size, width // win_size

    # Work channels-last so each window's pixels can be grouped together.
    channels_last = x.permute(0, 2, 3, 1)
    tiled = channels_last.reshape(batch, rows, win_size, cols, win_size, channels)

    # Bring the two window-grid axes next to the batch axis, then fold all
    # three into a single leading "window" axis.
    windows = tiled.permute(0, 1, 3, 2, 4, 5).reshape(-1, win_size, win_size, channels)

    # Back to channels-first for the caller.
    return windows.permute(0, 3, 1, 2)
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
def distributed_concat(var, num_total):
    """Gather *var* from every process and return the first *num_total* results.

    Args:
        var: Tensor contributed by the calling process. A 0-dim tensor is
            viewed as shape ``(1,)`` before gathering.
        num_total: Number of gathered tensors to keep from the front of the
            per-process list.

    Returns:
        A list of tensors (one per process, truncated to ``num_total``
        entries). The tensors are not concatenated.

    NOTE(review): relies on ``torch.distributed`` being initialised and on
    CUDA being available (receive buffers are moved with ``.cuda()``) --
    confirm against the launcher.
    """
    if var.dim() == 0:
        var = var.view(1)

    world_size = torch.distributed.get_world_size()

    # One receive buffer per process, shaped like the local tensor.
    gathered = []
    for _ in range(world_size):
        gathered.append(torch.zeros_like(var).cuda())

    torch.distributed.all_gather(gathered, var)

    return gathered[:num_total]
| |
|