import sys
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from PIL import Image
from torch.autograd import grad
from torchvision import transforms, utils
import face_alignment
import lpips

current_dir = os.path.abspath(os.path.dirname(__file__))
sys.path.insert(0, current_dir)

from pixel2style2pixel.models.stylegan2.model import Generator, get_keys
from nets.feature_style_encoder import *
from arcface.iresnet import *
from face_parsing.model import BiSeNet
from ranger import Ranger


def clip_img(x):
    """Clip a StyleGAN-generated image from range (-1, 1) to range (0, 1)."""
    img_tmp = x.clone()[0]
    img_tmp = (img_tmp + 1) / 2
    img_tmp = torch.clamp(img_tmp, 0, 1)
    return [img_tmp.detach().cpu()]
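
# Illustrative usage (mirrors how Trainer.save_image_stylegan2 uses it below):
#   utils.save_image(clip_img(out_img[:1]), 'recon_0.jpg')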


def tensor_byte(x):
    """Size of a tensor in bytes."""
    return x.element_size() * x.nelement()


def count_parameters(net):
    """Print the total number of parameters in a network."""
    s = sum(np.prod(list(mm.size())) for mm in net.parameters())
    print(s)


def stylegan_to_classifier(x, out_size=(224, 224)):
    """Map a StyleGAN image from range (-1, 1) to an ImageNet-normalized classifier input."""
    img_tmp = x.clone()
    img_tmp = torch.clamp(0.5 * img_tmp + 0.5, 0, 1)
    img_tmp = F.interpolate(img_tmp, size=out_size, mode='bilinear')
    # Normalize channel-wise with ImageNet statistics
    img_tmp[:, 0] = (img_tmp[:, 0] - 0.485) / 0.229
    img_tmp[:, 1] = (img_tmp[:, 1] - 0.456) / 0.224
    img_tmp[:, 2] = (img_tmp[:, 2] - 0.406) / 0.225
    return img_tmp


def downscale(x, scale_times=1, mode='bilinear'):
    for i in range(scale_times):
        x = F.interpolate(x, scale_factor=0.5, mode=mode)
    return x


def upscale(x, scale_times=1, mode='bilinear'):
    for i in range(scale_times):
        x = F.interpolate(x, scale_factor=2, mode=mode)
    return x
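
# Shape note: downscale(x, 2) halves H and W twice, so a (B, 3, 1024, 1024)
# tensor becomes (B, 3, 256, 256); upscale(x, 2) reverses the shape change.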


def hist_transform(source_tensor, target_tensor):
    """Histogram matching: reorder source values to follow the target's per-channel histogram."""
    c, h, w = source_tensor.size()
    s_t = source_tensor.view(c, -1)
    t_t = target_tensor.view(c, -1)
    s_t_sorted, s_t_indices = torch.sort(s_t)
    t_t_sorted, t_t_indices = torch.sort(t_t)
    for i in range(c):
        s_t[i, s_t_indices[i]] = t_t_sorted[i]
    return s_t.view(c, h, w)
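
# Illustrative usage for two (3, H, W) tensors of the same size:
#   matched = hist_transform(source.clone(), target)
# The clone() matters: the indexed assignment above writes through a view,
# so the first argument's storage is modified in place.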


def init_weights(m):
    """Initialize conv layers with Xavier uniform and linear layers with uniform weights."""
    if type(m) == nn.Conv2d:
        nn.init.xavier_uniform_(m.weight)
    elif type(m) == nn.Linear:
        nn.init.uniform_(m.weight, 0.0, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.01)


def total_variation(x, delta=1):
    """Total variation, x: tensor of size (B, C, H, W)"""
    out = torch.mean(torch.abs(x[:, :, :, :-delta] - x[:, :, :, delta:])) \
        + torch.mean(torch.abs(x[:, :, :-delta, :] - x[:, :, delta:, :]))
    return out


def vgg_transform(x):
    """Adapt an image in range (0, 1) for a VGG network: RGB to BGR, resize to 224x224, scale to (0, 255)."""
    r, g, b = torch.split(x, 1, 1)
    out = torch.cat((b, g, r), dim=1)
    out = F.interpolate(out, size=(224, 224), mode='bilinear')
    out = out * 255.
    return out


# Warp image with flow
def normalize_axis(x, L):
    """Map 1-based pixel coordinates in [1, L] to grid_sample coordinates in [-1, 1]."""
    return (x - 1 - (L - 1) / 2) * 2 / (L - 1)


def unnormalize_axis(x, L):
    """Inverse of normalize_axis: map [-1, 1] back to 1-based pixel coordinates."""
    return x * (L - 1) / 2 + 1 + (L - 1) / 2
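
# Quick check: normalize_axis(1, L) == -1.0 and normalize_axis(L, L) == 1.0,
# matching F.grid_sample's convention that -1 and +1 are the image borders.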


def torch_flow_to_th_sampling_grid(flow, h_src, w_src, use_cuda=False):
    """Convert a dense flow field to a normalized sampling grid for F.grid_sample."""
    b, c, h_tgt, w_tgt = flow.size()
    # Note: the swapped meshgrid outputs assume a square flow field (h_tgt == w_tgt)
    grid_y, grid_x = torch.meshgrid(torch.arange(1, w_tgt + 1), torch.arange(1, h_tgt + 1))
    disp_x = flow[:, 0, :, :]
    disp_y = flow[:, 1, :, :]
    source_x = grid_x.unsqueeze(0).repeat(b, 1, 1).type_as(flow) + disp_x
    source_y = grid_y.unsqueeze(0).repeat(b, 1, 1).type_as(flow) + disp_y
    source_x_norm = normalize_axis(source_x, w_src)
    source_y_norm = normalize_axis(source_y, h_src)
    sampling_grid = torch.cat((source_x_norm.unsqueeze(3), source_y_norm.unsqueeze(3)), dim=3)
    if use_cuda:
        sampling_grid = sampling_grid.cuda()
    return sampling_grid


def warp_image_torch(image, flow):
    """
    Warp image (tensor, shape=[b, 3, h_src, w_src]) with flow (tensor, shape=[b, 2, h_tgt, w_tgt])
    """
    b, c, h_src, w_src = image.size()
    sampling_grid_torch = torch_flow_to_th_sampling_grid(flow, h_src, w_src)
    warped_image_torch = F.grid_sample(image, sampling_grid_torch)
    return warped_image_torch
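
# Illustrative usage (assumes a square flow field, per the note above):
#   img = torch.rand(1, 3, 64, 64)
#   flow = torch.zeros(1, 2, 64, 64)   # zero flow ~ identity warp
#   warped = warp_image_torch(img, flow)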


class Trainer(nn.Module):
    def __init__(self, config, opts):
        super(Trainer, self).__init__()
        # Load hyperparameters
        self.config = config
        self.device = torch.device(self.config['device'])
        self.scale = int(np.log2(config['resolution'] / config['enc_resolution']))
        self.scale_mode = 'bilinear'
        self.opts = opts
        self.n_styles = 2 * int(np.log2(config['resolution'])) - 2
        self.idx_k = 5
        if 'idx_k' in self.config:
            self.idx_k = self.config['idx_k']
        if 'stylegan_version' in self.config and self.config['stylegan_version'] == 3:
            self.n_styles = 16
        # Networks
        in_channels = 256
        if 'in_c' in self.config:
            in_channels = config['in_c']
        enc_residual = False
        if 'enc_residual' in self.config:
            enc_residual = self.config['enc_residual']
        enc_residual_coeff = False
        if 'enc_residual_coeff' in self.config:
            enc_residual_coeff = self.config['enc_residual_coeff']
        resnet_layers = [4, 5, 6]
        if 'enc_start_layer' in self.config:
            st_l = self.config['enc_start_layer']
            resnet_layers = [st_l, st_l + 1, st_l + 2]
        if 'scale_mode' in self.config:
            self.scale_mode = self.config['scale_mode']
        # Load encoder
        self.stride = (self.config['fs_stride'], self.config['fs_stride'])
        self.enc = fs_encoder_v2(n_styles=self.n_styles, opts=opts, residual=enc_residual,
                                 use_coeff=enc_residual_coeff, resnet_layer=resnet_layers, stride=self.stride)
        # Other nets
        self.StyleGAN = self.init_stylegan(config)
        self.Arcface = iresnet50()
        self.parsing_net = BiSeNet(n_classes=19)
        # Optimizer and scheduler for the latent encoder
        self.enc_params = list(self.enc.parameters())
        if 'freeze_iresnet' in self.config and self.config['freeze_iresnet']:
            self.enc_params = list(self.enc.styles.parameters())
        if 'optimizer' in self.config and self.config['optimizer'] == 'ranger':
            self.enc_opt = Ranger(self.enc_params, lr=config['lr'], betas=(config['beta_1'], config['beta_2']),
                                  weight_decay=config['weight_decay'])
        else:
            self.enc_opt = torch.optim.Adam(self.enc_params, lr=config['lr'], betas=(config['beta_1'], config['beta_2']),
                                            weight_decay=config['weight_decay'])
        self.enc_scheduler = torch.optim.lr_scheduler.StepLR(self.enc_opt, step_size=config['step_size'], gamma=config['gamma'])
        self.fea_avg = None

    def initialize(self, stylegan_model_path, arcface_model_path, parsing_model_path):
        # Load StyleGAN model
        stylegan_state_dict = torch.load(stylegan_model_path, map_location='cpu')
        self.StyleGAN.load_state_dict(get_keys(stylegan_state_dict, 'decoder'), strict=True)
        self.StyleGAN.to(self.device)
        # Get the StyleGAN average latent in W space and the noise inputs
        self.dlatent_avg = stylegan_state_dict['latent_avg'].to(self.device)
        self.noise_inputs = [getattr(self.StyleGAN.noises, f'noise_{i}').to(self.device) for i in range(self.StyleGAN.num_layers)]
        # Load ArcFace weights
        self.Arcface.load_state_dict(torch.load(arcface_model_path))
        self.Arcface.eval()
        # Load face parsing net weights
        self.parsing_net.load_state_dict(torch.load(parsing_model_path))
        self.parsing_net.eval()
        # Load LPIPS net weights (self.loss_fn is required by self.LPIPS below)
        self.loss_fn = lpips.LPIPS(net='alex', spatial=False)
        self.loss_fn.to(self.device)

    def init_stylegan(self, config):
        # StyleGAN2 generator: 1024 output resolution, 512-d latent, 8-layer mapping network
        StyleGAN = Generator(1024, 512, 8)
        return StyleGAN

    def mapping(self, z):
        return self.StyleGAN.get_latent(z).detach()

    def L1loss(self, input, target):
        return nn.L1Loss()(input, target)

    def L2loss(self, input, target):
        return nn.MSELoss()(input, target)

    def CEloss(self, x, target_age):
        return nn.CrossEntropyLoss()(x, target_age)

    def LPIPS(self, input, target, multi_scale=False):
        if multi_scale:
            out = 0
            for k in range(3):
                out += self.loss_fn.forward(downscale(input, k, self.scale_mode), downscale(target, k, self.scale_mode)).mean()
        else:
            out = self.loss_fn.forward(downscale(input, self.scale, self.scale_mode), downscale(target, self.scale, self.scale_mode)).mean()
        return out

    def IDloss(self, input, target):
        # ArcFace expects 112x112 inputs
        x_1 = F.interpolate(input, (112, 112))
        x_2 = F.interpolate(target, (112, 112))
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        if 'multi_layer_idloss' in self.config and self.config['multi_layer_idloss']:
            id_1 = self.Arcface(x_1, return_features=True)
            id_2 = self.Arcface(x_2, return_features=True)
            return sum([1 - cos(id_1[i].flatten(start_dim=1), id_2[i].flatten(start_dim=1)) for i in range(len(id_1))])
        else:
            id_1 = self.Arcface(x_1)
            id_2 = self.Arcface(x_2)
            return 1 - cos(id_1, id_2)

    def landmarkloss(self, input, target):
        cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        x_1 = stylegan_to_classifier(input, out_size=(512, 512))
        x_2 = stylegan_to_classifier(target, out_size=(512, 512))
        out_1 = self.parsing_net(x_1)
        out_2 = self.parsing_net(x_2)
        parsing_loss = sum([1 - cos(out_1[i].flatten(start_dim=1), out_2[i].flatten(start_dim=1)) for i in range(len(out_1))])
        return parsing_loss.mean()

    def feature_match(self, enc_feat, dec_feat, layer_idx=None):
        loss = []
        if layer_idx is None:
            layer_idx = [i for i in range(len(enc_feat))]
        for i in layer_idx:
            loss.append(self.L1loss(enc_feat[i], dec_feat[i]))
        return loss

    def encode(self, img):
        w_recon, fea = self.enc(downscale(img, self.scale, self.scale_mode))
        w_recon = w_recon + self.dlatent_avg
        return w_recon, fea
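
    # Illustrative usage (assumes trainer.initialize(...) has been called and
    # img is a batch on trainer.device):
    #   w_recon, fea = trainer.encode(img)
    # w_recon holds the predicted latents; fea is the feature map injected at
    # generator layer idx_k in get_image below.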

    def get_image(self, w=None, img=None, noise=None, zero_noise_input=True, training_mode=True):
        x_1, n_1 = img, noise
        if x_1 is None:
            x_1, _ = self.StyleGAN([w], input_is_latent=True, noise=n_1)
        w_delta = None
        fea = None
        features = None
        return_features = False
        # Reconstruction
        k = 0
        if 'use_fs_encoder' in self.config and self.config['use_fs_encoder']:
            return_features = True
            k = self.idx_k
            w_recon, fea = self.enc(downscale(x_1, self.scale, self.scale_mode))
            w_recon = w_recon + self.dlatent_avg
            # Inject the encoder feature map at generator layer k only
            features = [None] * k + [fea] + [None] * (17 - k)
        else:
            w_recon = self.enc(downscale(x_1, self.scale, self.scale_mode)) + self.dlatent_avg
        # Generate image
        x_1_recon, fea_recon = self.StyleGAN([w_recon], input_is_latent=True, return_features=True,
                                             features_in=features, feature_scale=min(1.0, 0.0001 * self.n_iter))
        fea_recon = fea_recon[k].detach()
        return [x_1_recon, x_1[:, :3, :, :], w_recon, w_delta, n_1, fea, fea_recon]

    def compute_loss(self, w=None, img=None, noise=None, real_img=None):
        return self.compute_loss_stylegan2(w=w, img=img, noise=noise, real_img=real_img)

    def compute_loss_stylegan2(self, w=None, img=None, noise=None, real_img=None):
        if img is None:
            # Generate synthetic images
            if noise is None:
                noise = [torch.randn(w.size()[:1] + ee.size()[1:]).to(self.device) for ee in self.noise_inputs]
            img, _ = self.StyleGAN([w], input_is_latent=True, noise=noise)
            img = img.detach()
        if img is not None and real_img is not None:
            # Concatenate synthetic and real data
            img = torch.cat([img, real_img], dim=0)
            noise = [torch.cat([ee, ee], dim=0) for ee in noise]
        out = self.get_image(w=w, img=img, noise=noise)
        x_1_recon, x_1, w_recon, w_delta, n_1, fea_1, fea_recon = out
        # Loss weights
        w_l2, w_lpips, w_id = self.config['w']['l2'], self.config['w']['lpips'], self.config['w']['id']
        b = x_1.size(0) // 2
        if 'l2loss_on_real_image' in self.config and self.config['l2loss_on_real_image']:
            b = x_1.size(0)
        self.l2_loss = self.L2loss(x_1_recon[:b], x_1[:b]) if w_l2 > 0 else torch.tensor(0)  # L2 loss only on synthetic data
        # LPIPS
        multiscale_lpips = False if 'multiscale_lpips' not in self.config else self.config['multiscale_lpips']
        self.lpips_loss = self.LPIPS(x_1_recon, x_1, multi_scale=multiscale_lpips).mean() if w_lpips > 0 else torch.tensor(0)
        self.id_loss = self.IDloss(x_1_recon, x_1).mean() if w_id > 0 else torch.tensor(0)
        self.landmark_loss = self.landmarkloss(x_1_recon, x_1) if self.config['w']['landmark'] > 0 else torch.tensor(0)
        if 'use_fs_encoder' in self.config and self.config['use_fs_encoder']:
            # Second reconstruction pass with the encoder feature injected at layer k
            k = self.idx_k
            features = [None] * k + [fea_1] + [None] * (17 - k)
            x_1_recon_2, _ = self.StyleGAN([w_recon], noise=n_1, input_is_latent=True, features_in=features,
                                           feature_scale=min(1.0, 0.0001 * self.n_iter))
            self.lpips_loss += self.LPIPS(x_1_recon_2, x_1, multi_scale=multiscale_lpips).mean() if w_lpips > 0 else torch.tensor(0)
            self.id_loss += self.IDloss(x_1_recon_2, x_1).mean() if w_id > 0 else torch.tensor(0)
            self.landmark_loss += self.landmarkloss(x_1_recon_2, x_1) if self.config['w']['landmark'] > 0 else torch.tensor(0)
        # Downscale images
        x_1 = downscale(x_1, self.scale, self.scale_mode)
        x_1_recon = downscale(x_1_recon, self.scale, self.scale_mode)
        # Total loss
        self.loss = w_l2 * self.l2_loss + w_lpips * self.lpips_loss + w_id * self.id_loss
        if 'f_recon' in self.config['w']:
            self.feature_recon_loss = self.L2loss(fea_1, fea_recon)
            self.loss += self.config['w']['f_recon'] * self.feature_recon_loss
        if 'l1' in self.config['w'] and self.config['w']['l1'] > 0:
            self.l1_loss = self.L1loss(x_1_recon, x_1)
            self.loss += self.config['w']['l1'] * self.l1_loss
        if 'landmark' in self.config['w']:
            self.loss += self.config['w']['landmark'] * self.landmark_loss
        return self.loss
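
    # Note: config['w'] is a dict of loss weights; 'l2', 'lpips', 'id' and
    # 'landmark' are read unconditionally above, while 'f_recon' and 'l1' are
    # optional. Example values here are illustrative, not canonical:
    #   config['w'] = {'l2': 1.0, 'lpips': 0.8, 'id': 0.1, 'landmark': 0.01}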

    def test(self, w=None, img=None, noise=None, zero_noise_input=True, return_latent=False, training_mode=False):
        if 'n_iter' not in self.__dict__.keys():
            self.n_iter = 1e5
        out = self.get_image(w=w, img=img, noise=noise, training_mode=training_mode)
        x_1_recon, x_1, w_recon, w_delta, n_1, fea_1 = out[:6]
        output = [x_1, x_1_recon]
        if return_latent:
            output += [w_recon, fea_1]
        return output

    def log_loss(self, logger, n_iter, prefix='scripts'):
        logger.log_value(prefix + '/l2_loss', self.l2_loss.item(), n_iter + 1)
        logger.log_value(prefix + '/lpips_loss', self.lpips_loss.item(), n_iter + 1)
        logger.log_value(prefix + '/id_loss', self.id_loss.item(), n_iter + 1)
        logger.log_value(prefix + '/total_loss', self.loss.item(), n_iter + 1)
        if 'f_recon' in self.config['w']:
            logger.log_value(prefix + '/feature_recon_loss', self.feature_recon_loss.item(), n_iter + 1)
        if 'l1' in self.config['w'] and self.config['w']['l1'] > 0:
            logger.log_value(prefix + '/l1_loss', self.l1_loss.item(), n_iter + 1)
        if 'landmark' in self.config['w']:
            logger.log_value(prefix + '/landmark_loss', self.landmark_loss.item(), n_iter + 1)

    def save_image(self, log_dir, n_epoch, n_iter, prefix='/scripts/', w=None, img=None, noise=None, training_mode=True):
        return self.save_image_stylegan2(log_dir=log_dir, n_epoch=n_epoch, n_iter=n_iter, prefix=prefix,
                                         w=w, img=img, noise=noise, training_mode=training_mode)

    def save_image_stylegan2(self, log_dir, n_epoch, n_iter, prefix='/scripts/', w=None, img=None, noise=None, training_mode=True):
        os.makedirs(log_dir + prefix, exist_ok=True)
        with torch.no_grad():
            out = self.get_image(w=w, img=img, noise=noise, training_mode=training_mode)
            x_1_recon, x_1, w_recon, w_delta, n_1, fea_1 = out[:6]
            x_1 = downscale(x_1, self.scale, self.scale_mode)
            x_1_recon = downscale(x_1_recon, self.scale, self.scale_mode)
            out_img = torch.cat((x_1, x_1_recon), dim=3)
            # fs: also save a reconstruction with the injected feature map
            if 'use_fs_encoder' in self.config and self.config['use_fs_encoder']:
                k = self.idx_k
                features = [None] * k + [fea_1] + [None] * (17 - k)
                x_1_recon_2, _ = self.StyleGAN([w_recon], noise=n_1, input_is_latent=True, features_in=features,
                                               feature_scale=min(1.0, 0.0001 * self.n_iter))
                x_1_recon_2 = downscale(x_1_recon_2, self.scale, self.scale_mode)
                out_img = torch.cat((x_1, x_1_recon, x_1_recon_2), dim=3)
            utils.save_image(clip_img(out_img[:1]), log_dir + prefix + 'epoch_' + str(n_epoch + 1) + '_iter_' + str(n_iter + 1) + '_0.jpg')
            if out_img.size(0) > 1:
                utils.save_image(clip_img(out_img[1:]), log_dir + prefix + 'epoch_' + str(n_epoch + 1) + '_iter_' + str(n_iter + 1) + '_1.jpg')

    def save_model(self, log_dir):
        torch.save(self.enc.state_dict(), '{:s}/enc.pth.tar'.format(log_dir))

    def save_checkpoint(self, n_epoch, log_dir):
        checkpoint_state = {
            'n_epoch': n_epoch,
            'enc_state_dict': self.enc.state_dict(),
            'enc_opt_state_dict': self.enc_opt.state_dict(),
            'enc_scheduler_state_dict': self.enc_scheduler.state_dict()
        }
        torch.save(checkpoint_state, '{:s}/checkpoint.pth'.format(log_dir))
        # Keep a numbered snapshot every 10 epochs
        if (n_epoch + 1) % 10 == 0:
            torch.save(checkpoint_state, '{:s}/checkpoint'.format(log_dir) + '_' + str(n_epoch + 1) + '.pth')

    def load_model(self, log_dir):
        self.enc.load_state_dict(torch.load('{:s}/enc.pth.tar'.format(log_dir)))

    def load_checkpoint(self, checkpoint_path):
        state_dict = torch.load(checkpoint_path)
        self.enc.load_state_dict(state_dict['enc_state_dict'])
        self.enc_opt.load_state_dict(state_dict['enc_opt_state_dict'])
        self.enc_scheduler.load_state_dict(state_dict['enc_scheduler_state_dict'])
        return state_dict['n_epoch'] + 1

    def update(self, w=None, img=None, noise=None, real_img=None, n_iter=0):
        self.n_iter = n_iter
        self.enc_opt.zero_grad()
        self.compute_loss(w=w, img=img, noise=noise, real_img=real_img).backward()
        self.enc_opt.step()
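

# Minimal training-loop sketch (illustrative; the checkpoint paths and data
# loader are assumptions, only the Trainer calls come from this file):
#
#   trainer = Trainer(config, opts)
#   trainer.initialize('stylegan2.pth', 'arcface.pth', 'parsing.pth')
#   trainer.to(trainer.device)
#   for n_iter, real_img in enumerate(loader):
#       z = torch.randn(real_img.size(0), 512).to(trainer.device)
#       w = trainer.mapping(z)
#       trainer.update(w=w, real_img=real_img.to(trainer.device), n_iter=n_iter)
#       trainer.enc_scheduler.step()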