import logging

import numpy as np
import torch
import torch.distributions as dists
import torch.nn.functional as F
from torchvision.utils import save_image

from models.archs.fcn_arch import FCNHead, MultiHeadFCNHead
from models.archs.shape_attr_embedding_arch import ShapeAttrEmbedding
from models.archs.transformer_arch import TransformerMultiHead
from models.archs.unet_arch import ShapeUNet, UNet
from models.archs.vqgan_arch import (Decoder, DecoderRes, Encoder,
                                     VectorQuantizer,
                                     VectorQuantizerSpatialTextureAware,
                                     VectorQuantizerTexture)

logger = logging.getLogger('base')
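
# This module defines inference-only sampling models:
#   * BaseSampleModel        -- wraps the pretrained hierarchical VQVAE
#                               (top/bottom decoders and quantizers), the
#                               top->bottom index-prediction network, the
#                               segmentation-mask tokenizer, and the
#                               transformer sampler, and implements the
#                               coarse-to-fine sampling loop.
#   * SampleFromParsingModel -- samples images conditioned on a given human
#                               parsing map and texture mask.
#   * SampleFromPoseModel    -- first predicts a parsing map from a DensePose
#                               input and shape attributes, then samples.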

class BaseSampleModel():
    """Base Model"""

    def __init__(self, opt):
        self.opt = opt
        self.device = torch.device(opt['device'])

        # hierarchical VQVAE
        self.decoder = Decoder(
            in_channels=opt['top_in_channels'],
            resolution=opt['top_resolution'],
            z_channels=opt['top_z_channels'],
            ch=opt['top_ch'],
            out_ch=opt['top_out_ch'],
            num_res_blocks=opt['top_num_res_blocks'],
            attn_resolutions=opt['top_attn_resolutions'],
            ch_mult=opt['top_ch_mult'],
            dropout=opt['top_dropout'],
            resamp_with_conv=True,
            give_pre_end=False).to(self.device)
        self.top_quantize = VectorQuantizerTexture(
            1024, opt['embed_dim'], beta=0.25).to(self.device)
        self.top_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                                   opt['top_z_channels'],
                                                   1).to(self.device)
        self.load_top_pretrain_models()

        self.bot_decoder_res = DecoderRes(
            in_channels=opt['bot_in_channels'],
            resolution=opt['bot_resolution'],
            z_channels=opt['bot_z_channels'],
            ch=opt['bot_ch'],
            num_res_blocks=opt['bot_num_res_blocks'],
            ch_mult=opt['bot_ch_mult'],
            dropout=opt['bot_dropout'],
            give_pre_end=False).to(self.device)
        self.bot_quantize = VectorQuantizerSpatialTextureAware(
            opt['bot_n_embed'],
            opt['embed_dim'],
            beta=0.25,
            spatial_size=opt['bot_codebook_spatial_size']).to(self.device)
        self.bot_post_quant_conv = torch.nn.Conv2d(opt['embed_dim'],
                                                   opt['bot_z_channels'],
                                                   1).to(self.device)
        self.load_bot_pretrain_network()

        # top -> bot prediction
        self.index_pred_guidance_encoder = UNet(
            in_channels=opt['index_pred_encoder_in_channels']).to(self.device)
        self.index_pred_decoder = MultiHeadFCNHead(
            in_channels=opt['index_pred_fc_in_channels'],
            in_index=opt['index_pred_fc_in_index'],
            channels=opt['index_pred_fc_channels'],
            num_convs=opt['index_pred_fc_num_convs'],
            concat_input=opt['index_pred_fc_concat_input'],
            dropout_ratio=opt['index_pred_fc_dropout_ratio'],
            num_classes=opt['index_pred_fc_num_classes'],
            align_corners=opt['index_pred_fc_align_corners'],
            num_head=18).to(self.device)
        self.load_index_pred_network()

        # VAE for segmentation mask
        self.segm_encoder = Encoder(
            ch=opt['segm_ch'],
            num_res_blocks=opt['segm_num_res_blocks'],
            attn_resolutions=opt['segm_attn_resolutions'],
            ch_mult=opt['segm_ch_mult'],
            in_channels=opt['segm_in_channels'],
            resolution=opt['segm_resolution'],
            z_channels=opt['segm_z_channels'],
            double_z=opt['segm_double_z'],
            dropout=opt['segm_dropout']).to(self.device)
        self.segm_quantizer = VectorQuantizer(
            opt['segm_n_embed'],
            opt['segm_embed_dim'],
            beta=0.25,
            sane_index_shape=True).to(self.device)
        self.segm_quant_conv = torch.nn.Conv2d(opt['segm_z_channels'],
                                               opt['segm_embed_dim'],
                                               1).to(self.device)
        self.load_pretrained_segm_token()

        # define sampler
        self.sampler_fn = TransformerMultiHead(
            codebook_size=opt['codebook_size'],
            segm_codebook_size=opt['segm_codebook_size'],
            texture_codebook_size=opt['texture_codebook_size'],
            bert_n_emb=opt['bert_n_emb'],
            bert_n_layers=opt['bert_n_layers'],
            bert_n_head=opt['bert_n_head'],
            block_size=opt['block_size'],
            latent_shape=opt['latent_shape'],
            embd_pdrop=opt['embd_pdrop'],
            resid_pdrop=opt['resid_pdrop'],
            attn_pdrop=opt['attn_pdrop'],
            num_head=opt['num_head']).to(self.device)
        self.load_sampler_pretrained_network()

        self.shape = tuple(opt['latent_shape'])
        self.mask_id = opt['codebook_size']
        self.sample_steps = opt['sample_steps']
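
    # The checkpoints below are loaded onto the CPU; load_state_dict() then
    # copies the weights into the modules, which already live on self.device,
    # and each loaded module is switched to eval mode.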
    def load_top_pretrain_models(self):
        # load pretrained vqgan
        top_vae_checkpoint = torch.load(
            self.opt['top_vae_path'], map_location=torch.device('cpu'))
        self.decoder.load_state_dict(
            top_vae_checkpoint['decoder'], strict=True)
        self.top_quantize.load_state_dict(
            top_vae_checkpoint['quantize'], strict=True)
        self.top_post_quant_conv.load_state_dict(
            top_vae_checkpoint['post_quant_conv'], strict=True)
        self.decoder.eval()
        self.top_quantize.eval()
        self.top_post_quant_conv.eval()

    def load_bot_pretrain_network(self):
        checkpoint = torch.load(
            self.opt['bot_vae_path'], map_location=torch.device('cpu'))
        self.bot_decoder_res.load_state_dict(
            checkpoint['bot_decoder_res'], strict=True)
        self.decoder.load_state_dict(checkpoint['decoder'], strict=True)
        self.bot_quantize.load_state_dict(
            checkpoint['bot_quantize'], strict=True)
        self.bot_post_quant_conv.load_state_dict(
            checkpoint['bot_post_quant_conv'], strict=True)
        self.bot_decoder_res.eval()
        self.decoder.eval()
        self.bot_quantize.eval()
        self.bot_post_quant_conv.eval()

    def load_pretrained_segm_token(self):
        # load pretrained vqgan for segmentation mask
        segm_token_checkpoint = torch.load(
            self.opt['segm_token_path'], map_location=torch.device('cpu'))
        self.segm_encoder.load_state_dict(
            segm_token_checkpoint['encoder'], strict=True)
        self.segm_quantizer.load_state_dict(
            segm_token_checkpoint['quantize'], strict=True)
        self.segm_quant_conv.load_state_dict(
            segm_token_checkpoint['quant_conv'], strict=True)
        self.segm_encoder.eval()
        self.segm_quantizer.eval()
        self.segm_quant_conv.eval()

    def load_index_pred_network(self):
        checkpoint = torch.load(
            self.opt['pretrained_index_network'],
            map_location=torch.device('cpu'))
        self.index_pred_guidance_encoder.load_state_dict(
            checkpoint['guidance_encoder'], strict=True)
        self.index_pred_decoder.load_state_dict(
            checkpoint['index_decoder'], strict=True)
        self.index_pred_guidance_encoder.eval()
        self.index_pred_decoder.eval()

    def load_sampler_pretrained_network(self):
        checkpoint = torch.load(
            self.opt['pretrained_sampler'], map_location=torch.device('cpu'))
        self.sampler_fn.load_state_dict(checkpoint, strict=True)
        self.sampler_fn.eval()
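
    # Predict the bottom-level (fine) token indices from the decoded
    # top-level feature. The texture mask is downsampled to the 32x16 token
    # grid; one prediction head per texture codebook (18 in total) produces
    # logits, and each head's argmax is written only at positions whose
    # texture token selects that codebook. Uncovered positions stay -1.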
    def bot_index_prediction(self, feature_top, texture_mask):
        self.index_pred_guidance_encoder.eval()
        self.index_pred_decoder.eval()

        texture_tokens = F.interpolate(
            texture_mask, (32, 16), mode='nearest').view(self.batch_size,
                                                         -1).long()
        texture_mask_flatten = texture_tokens.view(-1)
        min_encodings_indices_list = [
            torch.full(
                texture_mask_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=texture_mask_flatten.device) for _ in range(18)
        ]
        with torch.no_grad():
            feature_enc = self.index_pred_guidance_encoder(feature_top)
            memory_logits_list = self.index_pred_decoder(feature_enc)
            for codebook_idx, memory_logits in enumerate(memory_logits_list):
                region_of_interest = texture_mask_flatten == codebook_idx
                if torch.sum(region_of_interest) > 0:
                    memory_indices_pred = memory_logits.argmax(dim=1).view(-1)
                    min_encodings_indices_list[codebook_idx][
                        region_of_interest] = memory_indices_pred[
                            region_of_interest]

        min_encodings_indices_return_list = [
            min_encodings_indices.view((1, 32, 16))
            for min_encodings_indices in min_encodings_indices_list
        ]

        return min_encodings_indices_return_list
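
    # Coarse-to-fine generation: sample top-level indices with the
    # transformer, look up their (texture-aware) codebook entries, predict
    # the bottom-level indices from that feature, and decode both levels
    # into an image in [0, 1]. If save_dir/img_name are not given, the
    # decoded image of the current sample is returned directly.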
    def sample_and_refine(self, save_dir=None, img_name=None):
        # sample 32x16 feature indices
        sampled_top_indices_list = self.sample_fn(
            temp=1, sample_steps=self.sample_steps)

        for sample_idx in range(self.batch_size):
            sample_indices = [
                sampled_indices_cur[sample_idx:sample_idx + 1]
                for sampled_indices_cur in sampled_top_indices_list
            ]
            top_quant = self.top_quantize.get_codebook_entry(
                sample_indices, self.texture_mask[sample_idx:sample_idx + 1],
                (sample_indices[0].size(0), self.shape[0], self.shape[1],
                 self.opt['top_z_channels']))
            top_quant = self.top_post_quant_conv(top_quant)

            bot_indices_list = self.bot_index_prediction(
                top_quant, self.texture_mask[sample_idx:sample_idx + 1])

            quant_bot = self.bot_quantize.get_codebook_entry(
                bot_indices_list, self.texture_mask[sample_idx:sample_idx + 1],
                (bot_indices_list[0].size(0), bot_indices_list[0].size(1),
                 bot_indices_list[0].size(2),
                 self.opt['bot_z_channels']))  # .permute(0, 3, 1, 2)
            quant_bot = self.bot_post_quant_conv(quant_bot)
            bot_dec_res = self.bot_decoder_res(quant_bot)

            dec = self.decoder(top_quant, bot_h=bot_dec_res)

            dec = ((dec + 1) / 2)
            dec = dec.clamp_(0, 1)

            if save_dir is None and img_name is None:
                return dec
            else:
                save_image(
                    dec,
                    f'{save_dir}/{img_name[sample_idx]}',
                    nrow=1,
                    padding=4)
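
    # Iterative mask-and-predict sampling of the top-level indices. All
    # positions start as mask_id; at each step t (from sample_steps down to
    # 1) still-masked positions are unmasked with probability 1/t, the
    # transformer predicts per-codebook logits conditioned on the
    # segmentation and texture tokens, and the newly unmasked positions are
    # filled with sampled indices, offset by 1024 * codebook_idx so that all
    # codebooks share one contiguous index range.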
    def sample_fn(self, temp=1.0, sample_steps=None):
        self.sampler_fn.eval()

        x_t = torch.ones((self.batch_size, np.prod(self.shape)),
                         device=self.device).long() * self.mask_id
        unmasked = torch.zeros_like(x_t, device=self.device).bool()
        sample_steps = list(range(1, sample_steps + 1))

        texture_tokens = F.interpolate(
            self.texture_mask, (32, 16),
            mode='nearest').view(self.batch_size, -1).long()
        texture_mask_flatten = texture_tokens.view(-1)

        # min_encodings_indices_list is later used to visualize the image
        min_encodings_indices_list = [
            torch.full(
                texture_mask_flatten.size(),
                fill_value=-1,
                dtype=torch.long,
                device=texture_mask_flatten.device) for _ in range(18)
        ]

        for t in reversed(sample_steps):
            t = torch.full((self.batch_size, ),
                           t,
                           device=self.device,
                           dtype=torch.long)

            # where to unmask
            changes = torch.rand(
                x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
            # don't unmask somewhere already unmasked
            changes = torch.bitwise_xor(changes,
                                        torch.bitwise_and(changes, unmasked))
            # update mask with changes
            unmasked = torch.bitwise_or(unmasked, changes)

            x_0_logits_list = self.sampler_fn(
                x_t, self.segm_tokens, texture_tokens, t=t)

            changes_flatten = changes.view(-1)
            ori_shape = x_t.shape  # [b, h*w]
            x_t = x_t.view(-1)  # [b*h*w]
            for codebook_idx, x_0_logits in enumerate(x_0_logits_list):
                if torch.sum(texture_mask_flatten[changes_flatten] ==
                             codebook_idx) > 0:
                    # scale by temperature
                    x_0_logits = x_0_logits / temp
                    x_0_dist = dists.Categorical(logits=x_0_logits)
                    x_0_hat = x_0_dist.sample().long()
                    x_0_hat = x_0_hat.view(-1)

                    # only replace the changed indices belonging to this
                    # codebook_idx
                    changes_segm = torch.bitwise_and(
                        changes_flatten, texture_mask_flatten == codebook_idx)

                    # x_t is fed back into the transformer, so its indices
                    # must lie in one contiguous range across all codebooks
                    x_t[changes_segm] = x_0_hat[
                        changes_segm] + 1024 * codebook_idx
                    min_encodings_indices_list[codebook_idx][
                        changes_segm] = x_0_hat[changes_segm]

            x_t = x_t.view(ori_shape)  # [b, h*w]

        min_encodings_indices_return_list = [
            min_encodings_indices.view(ori_shape)
            for min_encodings_indices in min_encodings_indices_list
        ]

        self.sampler_fn.train()

        return min_encodings_indices_return_list
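
    # One-hot encode the parsing map, encode it with the segmentation
    # encoder, project it with segm_quant_conv, and quantize it to discrete
    # tokens used as conditioning for the transformer sampler.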
    def get_quantized_segm(self, segm):
        segm_one_hot = F.one_hot(
            segm.squeeze(1).long(),
            num_classes=self.opt['segm_num_segm_classes']).permute(
                0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
        encoded_segm_mask = self.segm_encoder(segm_one_hot)
        encoded_segm_mask = self.segm_quant_conv(encoded_segm_mask)
        _, _, [_, _, segm_tokens] = self.segm_quantizer(encoded_segm_mask)

        return segm_tokens
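

# SampleFromParsingModel expects each batch from the data loader to provide
# 'segm' (the parsing map), 'texture_mask', and 'img_name'; it tokenizes the
# parsing map once per batch and then runs sample_and_refine under
# torch.no_grad().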
class SampleFromParsingModel(BaseSampleModel):
    """SampleFromParsing model.
    """

    def feed_data(self, data):
        self.segm = data['segm'].to(self.device)
        self.texture_mask = data['texture_mask'].to(self.device)
        self.batch_size = self.segm.size(0)

        self.segm_tokens = self.get_quantized_segm(self.segm)
        self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)

    def inference(self, data_loader, save_dir):
        for _, data in enumerate(data_loader):
            img_name = data['img_name']
            self.feed_data(data)
            with torch.no_grad():
                self.sample_and_refine(save_dir, img_name)


class SampleFromPoseModel(BaseSampleModel):
    """SampleFromPose model.
    """

    def __init__(self, opt):
        super().__init__(opt)
        # pose-to-parsing
        self.shape_attr_embedder = ShapeAttrEmbedding(
            dim=opt['shape_embedder_dim'],
            out_dim=opt['shape_embedder_out_dim'],
            cls_num_list=opt['shape_attr_class_num']).to(self.device)
        self.shape_parsing_encoder = ShapeUNet(
            in_channels=opt['shape_encoder_in_channels']).to(self.device)
        self.shape_parsing_decoder = FCNHead(
            in_channels=opt['shape_fc_in_channels'],
            in_index=opt['shape_fc_in_index'],
            channels=opt['shape_fc_channels'],
            num_convs=opt['shape_fc_num_convs'],
            concat_input=opt['shape_fc_concat_input'],
            dropout_ratio=opt['shape_fc_dropout_ratio'],
            num_classes=opt['shape_fc_num_classes'],
            align_corners=opt['shape_fc_align_corners'],
        ).to(self.device)
        self.load_shape_generation_models()

        self.palette = [[0, 0, 0], [255, 250, 250], [220, 220, 220],
                        [250, 235, 215], [255, 250, 205], [211, 211, 211],
                        [70, 130, 180], [127, 255, 212], [0, 100, 0],
                        [50, 205, 50], [255, 255, 0], [245, 222, 179],
                        [255, 140, 0], [255, 0, 0], [16, 78, 139],
                        [144, 238, 144], [50, 205, 174], [50, 155, 250],
                        [160, 140, 88], [213, 140, 88], [90, 140, 90],
                        [185, 210, 205], [130, 165, 180], [225, 141, 151]]

    def load_shape_generation_models(self):
        checkpoint = torch.load(
            self.opt['pretrained_parsing_gen'],
            map_location=torch.device('cpu'))

        self.shape_attr_embedder.load_state_dict(
            checkpoint['embedder'], strict=True)
        self.shape_attr_embedder.eval()

        self.shape_parsing_encoder.load_state_dict(
            checkpoint['encoder'], strict=True)
        self.shape_parsing_encoder.eval()

        self.shape_parsing_decoder.load_state_dict(
            checkpoint['decoder'], strict=True)
        self.shape_parsing_decoder.eval()

    def feed_data(self, data):
        self.pose = data['densepose'].to(self.device)
        self.batch_size = self.pose.size(0)

        self.shape_attr = data['shape_attr'].to(self.device)
        self.upper_fused_attr = data['upper_fused_attr'].to(self.device)
        self.lower_fused_attr = data['lower_fused_attr'].to(self.device)
        self.outer_fused_attr = data['outer_fused_attr'].to(self.device)

    def inference(self, data_loader, save_dir):
        for _, data in enumerate(data_loader):
            img_name = data['img_name']
            self.feed_data(data)
            with torch.no_grad():
                self.generate_parsing_map()
                self.generate_quantized_segm()
                self.generate_texture_map()
                self.sample_and_refine(save_dir, img_name)

    def generate_parsing_map(self):
        with torch.no_grad():
            attr_embedding = self.shape_attr_embedder(self.shape_attr)
            pose_enc = self.shape_parsing_encoder(self.pose, attr_embedding)
            seg_logits = self.shape_parsing_decoder(pose_enc)
        self.segm = seg_logits.argmax(dim=1)
        self.segm = self.segm.unsqueeze(1)

    def generate_quantized_segm(self):
        self.segm_tokens = self.get_quantized_segm(self.segm)
        self.segm_tokens = self.segm_tokens.view(self.batch_size, -1)
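
    # Build the texture mask from the predicted parsing map: regions whose
    # parsing label is in upper_cls, lower_cls, or outer_cls are filled with
    # the corresponding fused texture attribute (+1, so that 0 remains the
    # "no texture" background). An attribute value of 17 is treated as
    # unspecified and leaves the region at 0.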
    def generate_texture_map(self):
        upper_cls = [1., 4.]
        lower_cls = [3., 5., 21.]
        outer_cls = [2.]

        mask_batch = []
        for idx in range(self.batch_size):
            mask = torch.zeros_like(self.segm[idx])
            upper_fused_attr = self.upper_fused_attr[idx]
            lower_fused_attr = self.lower_fused_attr[idx]
            outer_fused_attr = self.outer_fused_attr[idx]
            if upper_fused_attr != 17:
                for cls in upper_cls:
                    mask[self.segm[idx] == cls] = upper_fused_attr + 1

            if lower_fused_attr != 17:
                for cls in lower_cls:
                    mask[self.segm[idx] == cls] = lower_fused_attr + 1

            if outer_fused_attr != 17:
                for cls in outer_cls:
                    mask[self.segm[idx] == cls] = outer_fused_attr + 1
            mask_batch.append(mask)
        self.texture_mask = torch.stack(mask_batch, dim=0).to(torch.float32)

    def feed_pose_data(self, pose_img):
        # for ui demo
        self.pose = pose_img.to(self.device)
        self.batch_size = self.pose.size(0)

    def feed_shape_attributes(self, shape_attr):
        # for ui demo
        self.shape_attr = shape_attr.to(self.device)

    def feed_texture_attributes(self, texture_attr):
        # for ui demo
        self.upper_fused_attr = texture_attr[0].unsqueeze(0).to(self.device)
        self.lower_fused_attr = texture_attr[1].unsqueeze(0).to(self.device)
        self.outer_fused_attr = texture_attr[2].unsqueeze(0).to(self.device)
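
    # Colorize a predicted parsing map for visualization: each label in
    # result[0] is mapped to its RGB color from self.palette, yielding an
    # (H, W, 3) uint8 image. The BGR conversion is left commented out.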
    def palette_result(self, result):
        seg = result[0]
        palette = np.array(self.palette)
        assert palette.shape[1] == 3
        assert len(palette.shape) == 2
        color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
        for label, color in enumerate(palette):
            color_seg[seg == label, :] = color
        # convert to BGR
        # color_seg = color_seg[..., ::-1]
        return color_seg
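

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original pipeline code): how
# SampleFromPoseModel is typically driven in a demo, assuming `opt` has been
# loaded from the project's config and `pose_img`, `shape_attr`, and
# `texture_attr` have been prepared by the caller. Names and shapes below are
# illustrative only.
#
#   model = SampleFromPoseModel(opt)
#
#   model.feed_pose_data(pose_img)               # batched DensePose tensor
#   model.feed_shape_attributes(shape_attr)      # shape attribute tensor
#   model.feed_texture_attributes(texture_attr)  # upper/lower/outer attributes
#
#   with torch.no_grad():
#       model.generate_parsing_map()        # pose + shape attrs -> parsing map
#       model.generate_quantized_segm()     # parsing map -> segm tokens
#       model.generate_texture_map()        # parsing map + attrs -> texture mask
#       result = model.sample_and_refine()  # decoded image tensor in [0, 1]
# ---------------------------------------------------------------------------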