Spaces:
Build error
Build error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.nn.utils.spectral_norm as spectral_norm | |
| from modules.eg3ds.models.networks_stylegan2 import SynthesisBlock | |
| from modules.eg3ds.models.superresolution import SynthesisBlockNoUp | |
| from modules.eg3ds.models.superresolution import SuperresolutionHybrid8XDC | |
| from modules.real3d.facev2v_warp.model import WarpBasedTorsoModelMediaPipe as torso_model_v1 | |
| from modules.real3d.facev2v_warp.model2 import WarpBasedTorsoModelMediaPipe as torso_model_v2 | |
| from utils.commons.hparams import hparams | |
| from utils.commons.image_utils import dilate, erode | |
class SuperresolutionHybrid8XDC_Warp(SuperresolutionHybrid8XDC):
    """Hybrid 8x super-resolution head that also renders the torso and background.

    On top of ``SuperresolutionHybrid8XDC`` (whose ``block0``, ``block1``,
    ``input_resolution`` and ``sr_antialias`` are used here) this adds:

    * a warp-based torso model chosen by ``hparams['torso_model_version']``
      (``'v1'`` or ``'v2'``, default ``'v1'``);
    * encoders for the torso model's deformed hidden features and for the
      reference background image;
    * layers that fuse head, torso and background features, working on inputs
      resized to 256x256, before ``block1`` produces the output.

    Behaviour is driven by the global ``hparams``: ``weight_fuse`` (default
    True), ``htbsr_head_weight_fuse_mode`` (``'v1'``/``'v2'``/``'v3'``, read
    when ``weight_fuse`` is on) and ``htbsr_head_threshold``.
    """

    def __init__(self, channels, img_resolution, sr_num_fp16_res, sr_antialias, **block_kwargs):
        super().__init__(channels, img_resolution, sr_num_fp16_res, sr_antialias, **block_kwargs)
        torso_version = hparams.get("torso_model_version", "v1")
        if torso_version == 'v1':
            self.torso_model = torso_model_v1('standard')
        elif torso_version == 'v2':
            self.torso_model = torso_model_v2('standard')
        else:
            raise NotImplementedError(f"Unsupported torso_model_version: {torso_version!r}")

        # 1x1 projection of the torso model's 64-channel 'deformed_torso_hid' to 256 channels.
        self.torso_encoder = nn.Sequential(
            nn.Conv2d(64, 256, 1, 1, padding=0),
        )
        # Encodes the 3-channel reference background image into 256 feature channels.
        self.bg_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(64, 256, 3, 1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(256, 256, 3, 1, padding=1),
        )

        if hparams.get("weight_fuse", True):
            # Person features (256 ch) are concatenated with background features (256 ch).
            fuse_in_dim = 512
            if hparams['htbsr_head_weight_fuse_mode'] not in ['v1']:
                # Extra head/torso fusion layers, only used by the 'v2'/'v3' fuse modes.
                # Alpha predictor input: head rgb (3) + head weights (1) + torso rgb (3).
                self.head_torso_alpha_predictor = nn.Sequential(
                    nn.Conv2d(3 + 1 + 3, 32, 3, 1, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(32, 32, 3, 1, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(32, 1, 3, 1, padding=1),
                    nn.Sigmoid(),
                )
                # Fuses the alpha-gated head (256 ch) and torso (256 ch) features.
                self.fuse_head_torso_convs = nn.Sequential(
                    nn.Conv2d(256 + 256, 256, 3, 1, padding=1),
                    nn.LeakyReLU(),
                    nn.Conv2d(256, 256, 3, 1, padding=1),
                )
                self.head_torso_block = SynthesisBlockNoUp(256, 256, w_dim=512, resolution=256,
                    img_channels=3, is_last=False, use_fp16=False, conv_clamp=None, **block_kwargs)
        else:
            # Head, torso and background features (256 ch each) are concatenated.
            fuse_in_dim = 768
        self.fuse_fg_bg_convs = nn.Sequential(
            nn.Conv2d(fuse_in_dim, 64, 1, 1, padding=0),
            nn.LeakyReLU(),
            nn.Conv2d(64, 256, 3, 1, padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(256, 256, 3, 1, padding=1),
        )

    def _resize(self, tensor, size):
        """Bilinearly resize a 4-D tensor to ``size`` x ``size`` (antialias per ``self.sr_antialias``)."""
        return torch.nn.functional.interpolate(tensor, size=(size, size), mode='bilinear',
                                               align_corners=False, antialias=self.sr_antialias)

    def _prepare_inputs(self, rgb, x, ws, ref_torso_rgb, ref_bg_rgb, weights_img, block_kwargs):
        """Resize the inputs and run ``block0`` (the head super-resolution branch).

        ``x`` and ``rgb`` are resized to ``self.input_resolution`` when ``x`` has a
        different width. Then ``rgb`` (after that resize), the detached
        ``weights_img`` and both reference images are resized to 256x256.

        Returns:
            ``(x, rgb, rgb_256, weights_256, ref_torso_rgb_256, ref_bg_rgb_256)``,
            where ``x``/``rgb`` are the outputs of ``block0``.
        """
        weights_img = weights_img.detach()
        if x.shape[-1] != self.input_resolution:
            x = self._resize(x, self.input_resolution)
            rgb = self._resize(rgb, self.input_resolution)
        rgb_256 = self._resize(rgb, 256)
        weights_256 = self._resize(weights_img, 256)
        ref_torso_rgb_256 = self._resize(ref_torso_rgb, 256)
        ref_bg_rgb_256 = self._resize(ref_bg_rgb, 256)
        # SR branch (author's note: 128x128 head image -> 256x256 head image).
        x, rgb = self.block0(x, rgb, ws, **block_kwargs)
        return x, rgb, rgb_256, weights_256, ref_torso_rgb_256, ref_bg_rgb_256

    def _composite_over_background(self, x, rgb, x_bg, ref_bg_rgb_256, head_alpha, head_threshold,
                                   facev2v_ret, ws, block_kwargs):
        """Composite the person over the reference background and run ``block1``.

        Builds the head occlusion as a copy of ``head_alpha`` with values above
        ``head_threshold`` set to 1. It adds the torso model's ``'occlusion_2'``
        map (resized to 256x256) and clamps the sum to [0, 1]. The resulting
        person occlusion blends ``rgb`` with ``ref_bg_rgb_256`` and gates the
        person/background features before ``fuse_fg_bg_convs`` and ``block1``.

        Returns:
            ``(x, rgb)`` as produced by ``block1``.
        """
        head_occlusion = head_alpha.clone()
        head_occlusion[head_occlusion > head_threshold] = 1.
        torso_occlusion = self._resize(facev2v_ret['occlusion_2'], 256)
        person_occlusion = (torso_occlusion + head_occlusion).clamp_(0, 1)
        rgb = rgb * person_occlusion + ref_bg_rgb_256 * (1 - person_occlusion)
        # TODO (author): consider dropping the occlusion weighting from this
        # concatenation, or truncating the occlusion.
        x = torch.cat([x * person_occlusion, x_bg * (1 - person_occlusion)], dim=1)
        x = self.fuse_fg_bg_convs(x)
        return self.block1(x, rgb, ws, **block_kwargs)

    def forward(self, rgb, x, ws, ref_torso_rgb, ref_bg_rgb, weights_img, segmap, kp_s, kp_d, target_torso_mask=None, **block_kwargs):
        """Super-resolve the head and fuse it with the warped torso and the background.

        Args:
            rgb: head RGB image fed to ``block0``.
            x: head feature map fed to ``block0``. ``rgb`` and ``x`` are resized to
                ``self.input_resolution`` when ``x`` has a different width.
            ws: latent codes; only the last one (``ws[:, -1:]``) is used, expanded 3 times.
            ref_torso_rgb: reference torso image, resized to 256x256 for the torso model.
            ref_bg_rgb: reference background image, resized to 256x256.
            weights_img: head weight map; detached and resized to 256x256.
            segmap, kp_s, kp_d: passed unchanged to the torso model.
            target_torso_mask: optional, passed to the torso model.
            **block_kwargs: passed to the synthesis blocks.

        Returns:
            ``(rgb, facev2v_ret)``: the image produced by ``block1`` and the torso
            model's result dict.

        Raises:
            NotImplementedError: for an unsupported ``torso_model_version`` or
                ``htbsr_head_weight_fuse_mode``.
        """
        ws = ws[:, -1:, :].expand([rgb.shape[0], 3, -1])
        x, rgb, rgb_256, weights_256, ref_torso_rgb_256, ref_bg_rgb_256 = self._prepare_inputs(
            rgb, x, ws, ref_torso_rgb, ref_bg_rgb, weights_img, block_kwargs)

        torso_version = hparams.get("torso_model_version", "v1")
        if torso_version == 'v1':
            rgb_torso, facev2v_ret = self.torso_model.forward(
                ref_torso_rgb_256, segmap, kp_s, kp_d, rgb_256.detach(),
                cal_loss=True, target_torso_mask=target_torso_mask)
        elif torso_version == 'v2':
            # The v2 torso model additionally receives the head weights.
            rgb_torso, facev2v_ret = self.torso_model.forward(
                ref_torso_rgb_256, segmap, kp_s, kp_d, rgb_256.detach(), weights_256.detach(),
                cal_loss=True, target_torso_mask=target_torso_mask)
        else:
            # Fail loudly instead of hitting an UnboundLocalError on rgb_torso below.
            raise NotImplementedError(f"Unsupported torso_model_version: {torso_version!r}")
        x_torso = self.torso_encoder(facev2v_ret['deformed_torso_hid'])
        x_bg = self.bg_encoder(ref_bg_rgb_256)

        if not hparams.get("weight_fuse", True):
            # Plain channel-wise concatenation of head, torso and background features.
            x = torch.cat([x, x_torso, x_bg], dim=1)
            x = self.fuse_fg_bg_convs(x)
            x, rgb = self.block1(x, None, ws, **block_kwargs)
            return rgb, facev2v_ret

        fuse_mode = hparams['htbsr_head_weight_fuse_mode']
        if fuse_mode == 'v1':
            # Alpha-blend head and torso (both rgb and features) with the head weights.
            head_alpha = weights_256
            rgb = rgb * head_alpha + rgb_torso * (1 - head_alpha)
            x = x * head_alpha + x_torso * (1 - head_alpha)
            head_threshold = hparams['htbsr_head_threshold']
        elif fuse_mode in ('v2', 'v3'):
            if fuse_mode == 'v2':
                head_alpha = weights_256
            else:
                # v3: refine the head mask with a small network whose input is
                # (head rgb, head weights, torso rgb), RGBs mapped from [-1, 1] to [0, 1].
                alpha_inp = torch.cat([rgb.clamp(-1, 1) / 2 + 0.5, weights_256,
                                       rgb_torso.clamp(-1, 1) / 2 + 0.5], dim=1)
                # Clone so the in-place edit below does not touch the predictor's output.
                head_alpha = self.head_torso_alpha_predictor(alpha_inp).clone()
                # Cap the predicted alpha at the head weights (element-wise minimum).
                above = head_alpha > weights_256
                head_alpha[above] = weights_256[above]
            # Fuse head/torso features by alpha-gated concatenation instead of the
            # direct alpha blend used in 'v1'.
            rgb = rgb * head_alpha + rgb_torso * (1 - head_alpha)
            x = torch.cat([x * head_alpha, x_torso * (1 - head_alpha)], dim=1)
            x = self.fuse_head_torso_convs(x)
            x, rgb = self.head_torso_block(x, rgb, ws, **block_kwargs)
            # Author's note: once the weights are trained to match the mask, this
            # threshold no longer needs hand-tuning to a very small value; 0.7 works.
            head_threshold = hparams['htbsr_head_threshold']
            if fuse_mode == 'v3' and not self.training:
                # At eval time raise the threshold to at least the 5th percentile of the
                # alpha values above 0.05 (taken over the whole batch).
                confident = head_alpha[head_alpha > 0.05]
                # quantile() rejects empty and non-float inputs.
                if confident.numel() > 0:
                    head_threshold = max(confident.float().quantile(0.05), head_threshold)
        else:
            # Abandoned experiments (author's notes):
            # v4 fused head/torso hidden features by plain concatenation; results were poor.
            # v5 try1 also removed the alpha from the concatenation when processing x,
            #    which drove the occlusion straight to 1, so it was dropped.
            # v5 try2 thresholded the torso too for a sharper rgb; may cause dark
            #    fringes around the torso.
            raise NotImplementedError(f"Unsupported htbsr_head_weight_fuse_mode: {fuse_mode!r}")

        x, rgb = self._composite_over_background(
            x, rgb, x_bg, ref_bg_rgb_256, head_alpha, head_threshold, facev2v_ret, ws, block_kwargs)
        return rgb, facev2v_ret

    def infer_forward_stage1(self, rgb, x, ws, ref_torso_rgb, ref_bg_rgb, weights_img, segmap, kp_s, kp_d, **block_kwargs):
        """First half of a two-stage inference pass.

        Runs the same input preparation and ``block0`` as ``forward``, then the
        torso model's ``infer_forward_stage1``. Stores the tensors that
        ``infer_forward_stage2`` needs in the returned dict under
        ``'ref_bg_rgb_256'``, ``'weights_256'``, ``'x'``, ``'ws'`` and ``'rgb'``.

        Returns:
            The torso model's stage-1 result dict, augmented as described above.
        """
        ws = ws[:, -1:, :].repeat(1, 3, 1)
        x, rgb, rgb_256, weights_256, ref_torso_rgb_256, ref_bg_rgb_256 = self._prepare_inputs(
            rgb, x, ws, ref_torso_rgb, ref_bg_rgb, weights_img, block_kwargs)
        # NOTE(review): unlike forward(), this does not branch on torso_model_version
        # (forward passes weights_256 to the 'v2' torso model); confirm the v2
        # model's infer_forward_stage1 signature.
        facev2v_ret = self.torso_model.infer_forward_stage1(ref_torso_rgb_256, segmap, kp_s, kp_d, rgb_256.detach(), cal_loss=True)
        facev2v_ret['ref_bg_rgb_256'] = ref_bg_rgb_256
        facev2v_ret['weights_256'] = weights_256
        facev2v_ret['x'] = x
        facev2v_ret['ws'] = ws
        facev2v_ret['rgb'] = rgb
        return facev2v_ret

    def infer_forward_stage2(self, facev2v_ret, **block_kwargs):
        """Second half of two-stage inference; consumes the dict from ``infer_forward_stage1``.

        Returns:
            ``(rgb, facev2v_ret)``: the image produced by ``block1`` and the dict.
        """
        # Read the stage-1 tensors before the torso model touches the dict.
        x = facev2v_ret['x']
        ws = facev2v_ret['ws']
        rgb = facev2v_ret['rgb']
        ref_bg_rgb_256 = facev2v_ret['ref_bg_rgb_256']
        weights_256 = facev2v_ret['weights_256']
        rgb_torso = self.torso_model.infer_forward_stage2(facev2v_ret)
        x_torso = self.torso_encoder(facev2v_ret['deformed_torso_hid'])
        x_bg = self.bg_encoder(ref_bg_rgb_256)
        if hparams.get("weight_fuse", True):
            # NOTE(review): always uses the 'v1' additive head/torso blend with a
            # hard-coded head threshold of 0.5, ignoring htbsr_head_weight_fuse_mode
            # and htbsr_head_threshold used by forward(); confirm this is intended.
            rgb = rgb * weights_256 + rgb_torso * (1 - weights_256)
            x = x * weights_256 + x_torso * (1 - weights_256)
            x, rgb = self._composite_over_background(
                x, rgb, x_bg, ref_bg_rgb_256, weights_256, 0.5, facev2v_ret, ws, block_kwargs)
        else:
            x = torch.cat([x, x_torso, x_bg], dim=1)
            x = self.fuse_fg_bg_convs(x)
            x, rgb = self.block1(x, None, ws, **block_kwargs)
        return rgb, facev2v_ret
if __name__ == '__main__':
    # Ad-hoc smoke test for SuperresolutionHybrid8XDC_Warp; the .cuda() calls
    # require a CUDA device.
    # NOTE(review): this cannot run as written. forward() also requires
    # weights_img, segmap, kp_s and kp_d, which are not passed below, so the
    # call raises TypeError. The positional arguments also look misaligned:
    # ref_rgb lands in the ref_torso_rgb slot and ref_torso_rgb in the
    # ref_bg_rgb slot. Constructing the model reads
    # hparams['htbsr_head_weight_fuse_mode'] when weight_fuse is unset or true,
    # so presumably hparams must be loaded first. TODO: confirm the expected
    # segmap/keypoint shapes from the torso model and fix the inputs.
    model = SuperresolutionHybrid8XDC_Warp(32,512,0, False)
    model.cuda()
    rgb = torch.randn([4, 3, 128, 128]).cuda()
    x = torch.randn([4, 32, 128, 128]).cuda()
    ws = torch.randn([4, 14, 512]).cuda()
    ref_rgb = torch.randn([4, 3, 128, 128]).cuda()
    ref_torso_rgb = torch.randn([4, 3, 128, 128]).cuda()
    y = model(rgb, x, ws, ref_rgb, ref_torso_rgb)
    print(" ")