| | import argparse |
| | import math |
| | import sys |
| | sys.path.append("..") |
| | import numpy as np |
| | import os |
| | import torch |
| |
|
| | import trimesh |
| |
|
| | from datasets import Object_Occ,Scale_Shift_Rotate |
| | from models import get_model |
| | from pathlib import Path |
| | import open3d as o3d |
| | from configs.config_utils import CONFIG |
| | import tqdm |
| | from util import misc |
| | from datasets.taxonomy import synthetic_arkit_category_combined |
| |
|
# Per-model subfolder (under <data>/other_data/<category>/) that latents go to.
LATENT_SUBDIR = "9_triplane_kl25_64"


def parse_args(argv=None):
    """Parse command-line options for triplane-latent extraction.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser('', add_help=False)
    parser.add_argument('--configs', type=str, required=True)
    # --ae-pth and --category were always needed; without them the script
    # used to die later with an opaque error (torch.load(None) / len(None)).
    parser.add_argument('--ae-pth', type=str, required=True)
    parser.add_argument("--category", nargs='+', type=str, required=True,
                        help='category names, or the single value "all"')
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument("--batch_size", default=1, type=int)
    parser.add_argument("--data-pth", default="../data", type=str)
    parser.add_argument("--num-aug", default=5, type=int,
                        help="passes over each split; each pass saves one "
                             "latent file per sample")
    parser.add_argument("--num-workers", default=10, type=int,
                        help="DataLoader worker processes per split")
    return parser.parse_args(argv)


def build_dataloaders(dataset_config, category, batch_size, num_workers):
    """Build rank-sharded, non-shuffled DataLoaders for the train and val splits.

    Returns:
        ``[train_loader, val_loader]``.
    """
    # One transform instance is shared by both splits, as before.
    transform = Scale_Shift_Rotate(rot_shift_surface=True, use_scale=True)
    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    loaders = []
    for split in ("train", "val"):
        dataset = Object_Occ(dataset_config['data_path'], split=split,
                             categories=category,
                             transform=transform, sampling=True,
                             num_samples=1024, return_surface=True,
                             surface_sampling=True,
                             surface_size=dataset_config['surface_size'],
                             replica=1)
        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=num_tasks, rank=global_rank, shuffle=False)
        loaders.append(torch.utils.data.DataLoader(
            dataset, sampler=sampler,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=False,
            drop_last=False,
        ))
    return loaders


def load_model(model_config, ckpt_path, device):
    """Build the autoencoder from ``model_config`` and load ``ckpt_path['model']``.

    Returns:
        The model in eval mode, cast to float32 and moved to ``device``.
    """
    model = get_model(model_config)
    # Deserialize onto the CPU first: by default torch.load moves tensors back
    # to the device they were saved from, so every rank would allocate there.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint['model'])
    return model.eval().float().to(device)


def downsample_latents(means, logvars):
    """Halve the spatial resolution of a diagonal-Gaussian latent.

    With ``scale_factor=0.5`` and ``align_corners=False``, bilinear
    interpolation averages each 2x2 neighbourhood (for even H and W). If the
    four cells are treated as independent Gaussians, their average has
    variance ``sum(var) / 16 == mean(var) / 4``, hence the division by 4.

    Args:
        means: 4-D tensor (bilinear mode requires 4-D), presumably (B, C, H, W).
        logvars: Log-variances with the same shape as ``means``.

    Returns:
        ``(means, logvars)`` at half the spatial resolution.
    """
    down = torch.nn.functional.interpolate
    means = down(means, scale_factor=0.5, mode="bilinear", align_corners=False)
    variances = down(torch.exp(logvars), scale_factor=0.5, mode="bilinear",
                     align_corners=False) / 4
    return means, torch.log(variances)


def save_next_index(folder, **arrays):
    """Save ``arrays`` to the first free ``triplane_feat_<i>.npz`` in ``folder``.

    The search starts at the number of entries already in ``folder``, so a
    folder holding exactly files 0..k-1 gets index k as before. Each file is
    created with mode ``"xb"``, so an existing file is never overwritten. This
    matters when a stray file skews the count, or when two ranks write the
    same model concurrently. DistributedSampler repeats a few samples to pad
    ranks to equal length when ``drop_last=False``.

    Returns:
        The path that was written.
    """
    index = len(os.listdir(folder))
    while True:
        path = os.path.join(folder, "triplane_feat_%d.npz" % index)
        try:
            handle = open(path, "xb")
        except FileExistsError:
            index += 1
            continue
        with handle:
            np.savez_compressed(handle, **arrays)
        return path


def main(argv=None):
    """Encode every train/val sample ``--num-aug`` times and save the latents.

    Each sample yields one ``.npz`` per pass with keys ``mean``, ``logvar``
    (both downsampled 2x) and ``tran_mat``. Files go to
    ``<data>/other_data/<category>/LATENT_SUBDIR/<model_id>/``.
    """
    args = parse_args(argv)
    misc.init_distributed_mode(args)
    device = torch.device(args.device)

    config = CONFIG(args.configs)
    dataset_config = config.config['dataset']
    dataset_config['data_path'] = args.data_pth

    if len(args.category) == 1 and args.category[0] == "all":
        category = synthetic_arkit_category_combined["all"]
    else:
        category = args.category
    dataloader_list = build_dataloaders(dataset_config, category,
                                        args.batch_size, args.num_workers)
    output_dir = os.path.join(dataset_config['data_path'], "other_data")

    model = load_model(config.config['model'], args.ae_pth, device)

    with torch.no_grad():
        # Repeated passes re-run the dataset transform on every sample
        # (presumably random augmentation; verify Scale_Shift_Rotate).
        for _ in range(args.num_aug):
            for dataloader in dataloader_list:
                for data_batch in tqdm.tqdm(dataloader):
                    surface = data_batch['surface'].to(device, non_blocking=True)
                    _, _, means, logvars = model.encode(surface)
                    means, logvars = downsample_latents(means, logvars)

                    # One device->host copy per batch rather than per sample.
                    means_np = means.float().cpu().numpy()
                    logvars_np = logvars.float().cpu().numpy()
                    tran_mats = data_batch['tran_mat'].float().cpu().numpy()

                    for mean, logvar, tran_mat, cat, model_id in zip(
                            means_np, logvars_np, tran_mats,
                            data_batch['category'], data_batch['model_id']):
                        output_folder = os.path.join(output_dir, cat,
                                                     LATENT_SUBDIR, model_id)
                        Path(output_folder).mkdir(parents=True, exist_ok=True)
                        save_next_index(output_folder, mean=mean,
                                        logvar=logvar, tran_mat=tran_mat)


if __name__ == "__main__":
    main()
| |
|