import argparse
import glob
import os

import numpy as np
import torch
import yaml
from PIL import Image
from tqdm import tqdm
from torchvision import transforms, utils

from utils.datasets import *
from utils.functions import *
from trainer import *
# Runtime configuration: cuDNN autotuning speeds up fixed-size inference.
# Deterministic mode and anomaly detection are training-time debug aids and are left off here.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
Image.MAX_IMAGE_PIXELS = None  # allow arbitrarily large input images
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='001', help='name of the config file (without extension)')
parser.add_argument('--pretrained_model_path', type=str, default='./pretrained_models/143_enc.pth', help='pretrained encoder checkpoint')
parser.add_argument('--stylegan_model_path', type=str, default='./pixel2style2pixel/pretrained_models/psp_ffhq_encode.pt', help='pretrained StyleGAN2 (pSp) model')
parser.add_argument('--arcface_model_path', type=str, default='./pretrained_models/backbone.pth', help='pretrained ArcFace model')
parser.add_argument('--parsing_model_path', type=str, default='./pretrained_models/79999_iter.pth', help='pretrained face parsing model')
parser.add_argument('--log_path', type=str, default='./logs/', help='log file path')
parser.add_argument('--resume', action='store_true', help='resume from checkpoint')
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint file path')
parser.add_argument('--checkpoint_noiser', type=str, default='', help='noiser checkpoint file path')
# argparse's type=bool treats any non-empty string as True, so expose this as a store_true flag
parser.add_argument('--multigpu', action='store_true', help='use multiple gpus')
parser.add_argument('--input_path', type=str, default='./test/', help='evaluation data file path')
parser.add_argument('--save_path', type=str, default='./output/image/', help='output data save path')
opts = parser.parse_args()
log_dir = os.path.join(opts.log_path, opts.config) + '/'
with open('./configs/' + opts.config + '.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Initialize trainer and load the pretrained sub-networks
trainer = Trainer(config, opts)
trainer.initialize(opts.stylegan_model_path, opts.arcface_model_path, opts.parsing_model_path)
trainer.to(device)

# Load the encoder weights once and switch to evaluation mode
state_dict = torch.load(opts.pretrained_model_path, map_location=device)
trainer.enc.load_state_dict(state_dict)
trainer.enc.eval()
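# Preprocessing: resize to the generator's 1024x1024 resolution and normalize RGB values to [-1, 1]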
img_to_tensor = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# simple inference
image_dir = opts.input_path
save_dir = opts.save_path
os.makedirs(save_dir, exist_ok=True)
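# Inference only: no gradients are needed, which saves memory and compute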
with torch.no_grad():
    # Collect and sort all .jpg/.png images in the input directory
    img_list = [glob.glob1(image_dir, ext) for ext in ['*.jpg', '*.png']]
    img_list = [item for sublist in img_list for item in sublist]
    img_list.sort()
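    # Encode each image and save its reconstruction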
    for i, img_name in enumerate(tqdm(img_list)):
        # convert('RGB') guards against grayscale/RGBA inputs, which would break the 3-channel normalization
        image_A = img_to_tensor(Image.open(os.path.join(image_dir, img_name)).convert('RGB')).unsqueeze(0).to(device)
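        # trainer.test returns a list whose last two entries are the latent code and feature map;
        # output[1] holds the reconstructed image that is saved below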
        output = trainer.test(img=image_A, return_latent=True)
        feature = output.pop()
        latent = output.pop()
        # np.save(os.path.join(save_dir, 'latent_code_%d.npy' % i), latent.cpu().numpy())
        utils.save_image(clip_img(output[1]), os.path.join(save_dir, img_name))
        if i > 1000:  # stop after roughly 1000 images
            break