import torch
from torch import nn

from models.Encoders import ClipBlendingModel, PostProcessModel
from models.Net import Net
from utils.bicubic import BicubicDownSample
from utils.image_utils import DilateErosion
from utils.save_utils import save_gen_image, save_latents


class Blending(nn.Module):
    """
    Module for transferring the desired hair color and post-processing the result.
    """

    def __init__(self, opts, net=None):
        super().__init__()
        self.opts = opts
        if net is None:
            self.net = Net(self.opts)
        else:
            self.net = net

        blending_checkpoint = torch.load(self.opts.blending_checkpoint)
        self.blending_encoder = ClipBlendingModel(blending_checkpoint.get('clip', "ViT-B/32"))
        self.blending_encoder.load_state_dict(blending_checkpoint['model_state_dict'], strict=False)
        self.blending_encoder.to(self.opts.device).eval()

        self.post_process = PostProcessModel().to(self.opts.device).eval()
        self.post_process.load_state_dict(torch.load(self.opts.pp_checkpoint)['model_state_dict'])

        self.dilate_erosion = DilateErosion(dilate_erosion=self.opts.smooth, device=self.opts.device)
        self.downsample_256 = BicubicDownSample(factor=4)

    def blend_images(self, align_shape, align_color, name_to_embed, **kwargs):
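        """
        Blend the aligned hairstyle with the desired hair color, then post-process.

        align_shape: dict holding 'latent_F_align', the aligned F-space features.
        align_color: dict holding 'HM_X', the hair mask after color transfer.
        name_to_embed: per-image ('face', 'shape', 'color') embeddings with
            'image_norm_256', 'mask' and S-space latents 'S', as read below.
        Returns the final image as a tensor in [0, 1].
        """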
        I_1 = name_to_embed['face']['image_norm_256']
        I_2 = name_to_embed['shape']['image_norm_256']
        I_3 = name_to_embed['color']['image_norm_256']
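
        # Dilated (D) / eroded (E) hair masks for the face (1) and color (3) images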
        mask_de = self.dilate_erosion.hair_from_mask(
            torch.cat([name_to_embed[x]['mask'] for x in ['face', 'color']], dim=0)
        )
        HM_1D, _ = mask_de[0][0].unsqueeze(0), mask_de[1][0].unsqueeze(0)
        HM_3D, HM_3E = mask_de[0][1].unsqueeze(0), mask_de[1][1].unsqueeze(0)

        latent_S_1, latent_F_align = name_to_embed['face']['S'], align_shape['latent_F_align']
        HM_X = align_color['HM_X']
        latent_S_3 = name_to_embed['color']['S']
        HM_XD, _ = self.dilate_erosion.mask(HM_X)
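        # Target region left uncovered by all three dilated hair masks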
        target_mask = (1 - HM_1D) * (1 - HM_3D) * (1 - HM_XD)

        # Blending: skip the encoder when face, shape and color are literally
        # the same image; otherwise keep the first 6 (coarse) style vectors of
        # the face and blend the remaining ones with the color image's latents
        if I_1 is not I_3 or I_1 is not I_2:
            S_blend_6_18 = self.blending_encoder(latent_S_1[:, 6:], latent_S_3[:, 6:],
                                                 I_1 * target_mask, I_3 * HM_3E)
            S_blend = torch.cat((latent_S_1[:, :6], S_blend_6_18), dim=1)
        else:
            S_blend = latent_S_1
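
        # Render the blended image: S_blend drives generator layers 4-8 on top
        # of the aligned F-space features from the shape alignment step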
        I_blend, _ = self.net.generator([S_blend], input_is_latent=True, return_latents=False,
                                        start_layer=4, end_layer=8, layer_in=latent_F_align)
        I_blend_256 = self.downsample_256(I_blend)

        # Post-process: predict refined S/F codes from the original face and the
        # blended result, then render the final image from layer 5 onward
        S_final, F_final = self.post_process(I_1, I_blend_256)
        I_final, _ = self.net.generator([S_final], input_is_latent=True, return_latents=False,
                                        start_layer=5, end_layer=8, layer_in=F_final)

        if self.opts.save_all:
            exp_name = kwargs.get('exp_name') or ""
            output_dir = self.opts.save_all_dir / exp_name
            save_gen_image(output_dir, 'Blending', 'blending.png', I_blend)
            save_latents(output_dir, 'Blending', 'blending.npz', S_blend=S_blend)
            save_gen_image(output_dir, 'Final', 'final.png', I_final)
            save_latents(output_dir, 'Final', 'final.npz', S_final=S_final, F_final=F_final)
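
        # The generator outputs values in [-1, 1]; map them to [0, 1] for saving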
        final_image = ((I_final[0] + 1) / 2).clip(0, 1)
        return final_image
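

# Usage sketch (illustrative, not part of the module). Assumes an `opts`
# namespace exposing the attributes read above (device, blending_checkpoint,
# pp_checkpoint, smooth, save_all, save_all_dir) and inputs produced by the
# embedding and alignment stages; the attribute values below are hypothetical.
#
#   from argparse import Namespace
#   opts = Namespace(device='cuda', blending_checkpoint='pretrained/blending.pth',
#                    pp_checkpoint='pretrained/pp_model.pth', smooth=5,
#                    save_all=False, save_all_dir=None)
#   blending = Blending(opts)
#   with torch.no_grad():
#       final_image = blending.blend_images(align_shape, align_color, name_to_embed)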