Spaces:
Runtime error
Runtime error
| import torch | |
| from collections import OrderedDict | |
| from torch.nn import utils, functional as F | |
| from torch.optim import Adam, SGD | |
| from torch.autograd import Variable | |
| from torch.backends import cudnn | |
| from model import build_model, weights_init | |
| import scipy.misc as sm | |
| import numpy as np | |
| import os | |
| import torchvision.utils as vutils | |
| import cv2 | |
| import torch.nn.functional as F | |
| import math | |
| import time | |
| import sys | |
| import PIL.Image | |
| import scipy.io | |
| import os | |
| import logging | |
# Small constant kept at module level for numerical guards.
EPSILON = 1e-8

from dataset import get_loader

# Backbone selector handed to build_model(); Solver.build_model branches on
# the values 'vgg' and 'resnet'.
base_model_cfg = 'resnet'

# Optimisation hyper-parameters, looked up by key from the Solver class.
p = OrderedDict([
    ('lr_bone', 5e-5),     # Learning rate resnet:5e-5, vgg:2e-5
    ('lr_branch', 0.025),  # Learning rate
    ('wd', 0.0005),        # Weight decay
    ('momentum', 0.90),    # Momentum
])

# Epoch indices after which the learning rate is multiplied by 0.1.
lr_decay_epoch = [15, 24]  # [6, 9], now x3 #15

# Update the weights once in 'nAveGrad' forward passes.
nAveGrad = 10

# Print the running losses every 'showEvery' iterations.
showEvery = 50

# Directory that receives the preview images written during training.
tmp_path = 'tmp_see'
class Solver(object):
    """Build, train and run the network returned by ``build_model``.

    The network is stored as ``self.net_bone``. A forward pass returns three
    sequences of logit maps (``up_edge``, ``up_sal``, ``up_sal_f``); training
    supervises the first with the edge labels and the other two with the
    saliency labels, while inference writes the sigmoid of the last fused map
    to disk.

    Args:
        train_loader: iterable of dict batches with the keys ``'sal_image'``,
            ``'sal_label'`` and ``'sal_edge'``; its ``dataset`` must support
            ``len()``.
        test_loader: iterable of dict batches with the keys ``'image'`` and
            ``'name'``.
        config: option namespace; the attributes read here are ``cuda``,
            ``mode``, ``visdom``, ``pre_trained``, ``model``, ``load_bone``,
            ``vgg``, ``resnet``, ``save_fold``, ``batch_size``, ``epoch`` and
            ``epoch_save``.
        save_fold: directory under which ``test`` writes its predictions.
    """

    def __init__(self, train_loader, test_loader, config, save_fold=None):
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.config = config
        self.save_fold = save_fold
        # Per-channel mean scaled to [0, 1]; not read anywhere else in this
        # class (presumably the ImageNet mean -- confirm against the dataset).
        self.mean = torch.Tensor([123.68, 116.779, 103.939]).view(3, 1, 1) / 255.
        if config.visdom:
            # NOTE(review): Viz_visdom is not imported by this file; enabling
            # config.visdom raises NameError unless it is provided elsewhere.
            self.visual = Viz_visdom("trueUnify", 1)
        self.build_model()
        if self.config.pre_trained:
            # The network attribute is `net_bone`; `self.net` never existed.
            self.net_bone.load_state_dict(self._load_weights(self.config.pre_trained))
        if config.mode == 'train':
            # Kept open for the lifetime of the solver; nothing in this class
            # writes to it or closes it.
            self.log_output = open("%s/logs/log.txt" % config.save_fold, 'w')
        else:
            print('Loading pre-trained model from %s...' % self.config.model)
            self.net_bone.load_state_dict(self._load_weights(self.config.model))
            self.net_bone.eval()

    def _load_weights(self, path):
        """Deserialize a checkpoint, remapping it to CPU when CUDA is disabled."""
        # map_location=None is torch.load's default, so GPU runs are unchanged.
        map_location = None if self.config.cuda else 'cpu'
        return torch.load(path, map_location=map_location)

    def _make_optimizer(self):
        """Create a fresh Adam optimizer over the trainable parameters."""
        trainable = [param for param in self.net_bone.parameters() if param.requires_grad]
        return Adam(trainable, lr=self.lr_bone, weight_decay=p['wd'])

    def print_network(self, model, name):
        """Print ``name``, the module structure and its total parameter count."""
        num_params = sum(param.numel() for param in model.parameters())
        print(name)
        print(model)
        print("The number of parameters: {}".format(num_params))

    def get_params(self, base_lr):
        """Return optimizer parameter groups, one per direct child module.

        The child named ``'loss_weight'`` gets the branch learning rate; every
        other group inherits the optimizer default. ``base_lr`` is accepted
        for interface compatibility but is not used.
        """
        ml = []
        for name, module in self.net_bone.named_children():
            print(name)
            group = {'params': module.parameters()}
            if name == 'loss_weight':
                group['lr'] = p['lr_branch']
            ml.append(group)
        return ml

    # build the network
    def build_model(self):
        """Instantiate the network, load initial weights and create the optimizer."""
        self.net_bone = build_model(base_model_cfg)
        if self.config.cuda:
            self.net_bone = self.net_bone.cuda()
        self.net_bone.eval()  # use_global_stats = True
        self.net_bone.apply(weights_init)
        if self.config.mode == 'train':
            if self.config.load_bone == '':
                # No full checkpoint given: initialise only the backbone.
                if base_model_cfg == 'vgg':
                    self.net_bone.base.load_pretrained_model(self._load_weights(self.config.vgg))
                elif base_model_cfg == 'resnet':
                    self.net_bone.base.load_state_dict(self._load_weights(self.config.resnet))
            else:
                self.net_bone.load_state_dict(self._load_weights(self.config.load_bone))
        self.lr_bone = p['lr_bone']
        self.lr_branch = p['lr_branch']
        self.optimizer_bone = self._make_optimizer()
        self.print_network(self.net_bone, 'trueUnify bone part')

    # update the learning rate
    def update_lr(self, rate):
        """Multiply the learning rate of every optimizer parameter group by ``rate``."""
        # The only optimizer this class creates is `optimizer_bone`.
        for param_group in self.optimizer_bone.param_groups:
            param_group['lr'] = param_group['lr'] * rate

    def test(self, test_mode=0):
        """Run inference on ``self.test_loader`` and save one PNG per batch.

        Predictions are written to ``<save_fold>/EGNet_ResNet50/<name>.png``.
        ``test_mode`` is accepted for interface compatibility but is not used.
        """
        time_t = 0.0
        name_t = 'EGNet_ResNet50/'
        out_dir = os.path.join(self.save_fold, name_t)
        os.makedirs(out_dir, exist_ok=True)
        self.config.test_fold = self.save_fold
        for data_batch in self.test_loader:
            print(self.config.test_fold)
            images, name = data_batch['image'], data_batch['name'][0]
            with torch.no_grad():
                if self.config.cuda:
                    images = images.cuda()
                print(images.size())
                time_start = time.time()
                up_edge, up_sal, up_sal_f = self.net_bone(images)
                if self.config.cuda:
                    # Wait for the queued kernels so the timing is meaningful;
                    # on a CPU-only run the call would raise, so it is skipped.
                    torch.cuda.synchronize()
                time_end = time.time()
                print(time_end - time_start)
                time_t = time_t + time_end - time_start
                pred = np.squeeze(torch.sigmoid(up_sal_f[-1]).detach().cpu().numpy())
                multi_fuse = 255 * pred
                # splitext copes with extensions of any length (e.g. '.jpeg').
                stem = os.path.splitext(name)[0]
                cv2.imwrite(os.path.join(out_dir, stem + '.png'), multi_fuse)
        print("--- %s seconds ---" % (time_t))
        print('Test Done!')

    # training phase
    def train(self):
        """Train ``self.net_bone`` with gradient accumulation.

        Gradients from ``nAveGrad`` batches are accumulated before each
        optimizer step. Running losses are printed every ``showEvery``
        iterations, preview images are saved every 200 iterations, and
        checkpoints are written every ``config.epoch_save`` epochs plus once
        at the end.
        """
        iter_num = len(self.train_loader.dataset) // self.config.batch_size
        aveGrad = 0
        os.makedirs(tmp_path, exist_ok=True)
        # Losses are summed over all elements, then divided by this factor so
        # that the accumulated gradient is averaged over samples and passes.
        loss_scale = nAveGrad * self.config.batch_size
        for epoch in range(self.config.epoch):
            r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0
            # NOTE(review): this clears gradients, but `aveGrad` is not reset,
            # so the first step of an epoch may average fewer passes.
            self.net_bone.zero_grad()
            for i, data_batch in enumerate(self.train_loader):
                sal_image = data_batch['sal_image']
                sal_label = data_batch['sal_label']
                sal_edge = data_batch['sal_edge']
                if sal_image.size()[2:] != sal_label.size()[2:]:
                    print("Skip this batch")
                    continue
                if self.config.cuda:
                    sal_image, sal_label, sal_edge = sal_image.cuda(), sal_label.cuda(), sal_edge.cuda()
                up_edge, up_sal, up_sal_f = self.net_bone(sal_image)
                # edge part
                edge_loss = sum(bce2d_new(ix, sal_edge, reduction='sum') for ix in up_edge) / loss_scale
                r_edge_loss += edge_loss.item()
                # sal part
                sal_loss1 = sum(F.binary_cross_entropy_with_logits(ix, sal_label, reduction='sum')
                                for ix in up_sal)
                sal_loss2 = sum(F.binary_cross_entropy_with_logits(ix, sal_label, reduction='sum')
                                for ix in up_sal_f)
                sal_loss = (sal_loss1 + sal_loss2) / loss_scale
                r_sal_loss += sal_loss.item()
                loss = sal_loss + edge_loss
                r_sum_loss += loss.item()
                loss.backward()
                aveGrad += 1
                if aveGrad % nAveGrad == 0:
                    self.optimizer_bone.step()
                    self.optimizer_bone.zero_grad()
                    aveGrad = 0
                if i % showEvery == 0:
                    print('epoch: [%2d/%2d], iter: [%5d/%5d] || Edge : %10.4f || Sal : %10.4f || Sum : %10.4f' % (
                        epoch, self.config.epoch, i, iter_num,
                        r_edge_loss * loss_scale / showEvery,
                        r_sal_loss * loss_scale / showEvery,
                        r_sum_loss * loss_scale / showEvery))
                    print('Learning rate: ' + str(self.lr_bone))
                    r_edge_loss, r_sal_loss, r_sum_loss = 0, 0, 0
                if i % 200 == 0:
                    vutils.save_image(torch.sigmoid(up_sal_f[-1].detach()), tmp_path+'/iter%d-sal-0.jpg' % i, normalize=True, padding=0)
                    vutils.save_image(sal_image.detach(), tmp_path+'/iter%d-sal-data.jpg' % i, padding=0)
                    vutils.save_image(sal_label.detach(), tmp_path+'/iter%d-sal-target.jpg' % i, padding=0)
            if (epoch + 1) % self.config.epoch_save == 0:
                torch.save(self.net_bone.state_dict(), '%s/models/epoch_%d_bone.pth' % (self.config.save_fold, epoch + 1))
            if epoch in lr_decay_epoch:
                self.lr_bone = self.lr_bone * 0.1
                # Rebuilding the optimizer also discards Adam's running
                # moment estimates; kept as in the original schedule.
                self.optimizer_bone = self._make_optimizer()
        torch.save(self.net_bone.state_dict(), '%s/models/final_bone.pth' % self.config.save_fold)
def bce2d_new(input, target, reduction=None):
    """Class-balanced binary cross-entropy on logits.

    Pixels whose target is exactly 1 are weighted by the fraction of
    negative pixels, and pixels whose target is exactly 0 by 1.1 times the
    fraction of positive pixels. Any other target value gets weight 0 and
    therefore does not contribute to the loss.

    Args:
        input: tensor of logits.
        target: tensor of labels with the same size as ``input``.
        reduction: ``'none'``, ``'mean'`` or ``'sum'``, forwarded to
            ``F.binary_cross_entropy_with_logits``. ``None`` is treated as
            ``'mean'`` (the PyTorch default), because PyTorch itself rejects
            ``None`` with a ValueError.

    Returns:
        The weighted loss tensor, reduced as requested.

    Raises:
        AssertionError: if ``input`` and ``target`` differ in size.
    """
    assert(input.size() == target.size())
    pos = torch.eq(target, 1).float()
    neg = torch.eq(target, 0).float()
    num_pos = torch.sum(pos)
    num_neg = torch.sum(neg)
    # With no pixel at exactly 0 or 1 the count is zero; clamping keeps the
    # weights at 0 (loss 0) instead of turning them into NaN via 0/0.
    num_total = torch.clamp(num_pos + num_neg, min=1)
    alpha = num_neg / num_total
    beta = 1.1 * num_pos / num_total
    # target pixel = 1 -> weight alpha
    # target pixel = 0 -> weight beta
    weights = alpha * pos + beta * neg
    if reduction is None:
        reduction = 'mean'
    return F.binary_cross_entropy_with_logits(input, target, weights, reduction=reduction)