# Spaces: Paused
| import argparse | |
| import os | |
| import torch | |
| import trimesh | |
| from cube3d.inference.engine import Engine, EngineFast | |
| from cube3d.inference.utils import normalize_bbox, select_device | |
| from cube3d.mesh_utils.postprocessing import ( | |
| PYMESHLAB_AVAILABLE, | |
| create_pymeshset, | |
| postprocess_mesh, | |
| save_mesh, | |
| ) | |
| from cube3d.renderer import renderer | |
| from cube3d.training.dataset import LegosTestDataset, LegosDataset | |
| from torch.utils.data.dataloader import DataLoader | |
| from cube3d.training.utils import normalize_bboxs | |
def generate_mesh(
    engine,
    prompt,
    output_dir,
    output_name,
    resolution_base=8.0,
    disable_postprocess=False,
    top_p=None,
    bounding_box_xyz=None,
):
    """Generate a mesh for one text prompt and export it as an OBJ file.

    Args:
        engine: Inference engine exposing ``t2s`` (``Engine`` / ``EngineFast``).
        prompt: Text prompt; wrapped into a one-element batch for ``t2s``.
        output_dir: Directory the OBJ is written into. This function does not
            create it.
        output_name: File stem; the result is ``<output_dir>/<output_name>.obj``.
        resolution_base: Forwarded unchanged to ``engine.t2s``.
        disable_postprocess: When True, skip pymeshlab post-processing.
        top_p: Sampling threshold forwarded to ``engine.t2s`` (the CLI help
            describes ``None`` as deterministic generation).
        bounding_box_xyz: Optional bounding box forwarded to ``engine.t2s``.

    Returns:
        Path of the exported OBJ file.
    """
    batch_out = engine.t2s(
        [prompt],
        use_kv_cache=True,
        resolution_base=resolution_base,
        top_p=top_p,
        bounding_box_xyz=bounding_box_xyz,
    )
    # A single prompt was submitted, so only the first batch entry is used.
    vertices = batch_out[0][0]
    faces = batch_out[0][1]
    obj_path = os.path.join(output_dir, f"{output_name}.obj")

    if not PYMESHLAB_AVAILABLE:
        # Fallback path: plain trimesh export, no post-processing possible.
        print(
            "WARNING: pymeshlab is not available, using trimesh to export obj and skipping optional post processing."
        )
        trimesh.Trimesh(vertices, faces).export(obj_path)
        return obj_path

    mesh_set = create_pymeshset(vertices, faces)
    if not disable_postprocess:
        # Aim for 10% of the raw face count, but never below 10k faces.
        target_face_num = max(10000, int(faces.shape[0] * 0.1))
        print(f"Postprocessing mesh to {target_face_num} faces")
        postprocess_mesh(mesh_set, target_face_num, obj_path)
    save_mesh(mesh_set, obj_path)
    return obj_path
def generate_ldr(
    engine,
    prompt,
    inputs_ids,
    output_dir,
    output_name,
    resolution_base=8.0,
    disable_postprocess=False,
    top_p=None,
    bounding_box_xyz=None,
    idx=None,
):
    """Run ``engine.t2l`` for one dataloader batch and return its result.

    Args:
        engine: Inference engine exposing ``t2l``.
        prompt: Prompt value forwarded as-is (not wrapped in a list). The
            script's main loop passes ``batch['prompt']`` from a DataLoader,
            so it is presumably already batched.
        inputs_ids: Forwarded to ``t2l`` as ``inputs_ids`` (the main loop
            passes the batch's ``target`` tensor).
        output_dir: Unused here; kept so the signature mirrors
            ``generate_mesh``.
        output_name: Unused here; kept for the same reason.
        resolution_base: Forwarded unchanged to ``engine.t2l``.
        disable_postprocess: Unused here.
        top_p: Sampling threshold forwarded to ``engine.t2l``.
        bounding_box_xyz: Bounding box forwarded to ``engine.t2l``.
        idx: Sample index forwarded to ``engine.t2l``.

    Returns:
        Whatever ``engine.t2l`` returns. This function itself writes nothing
        to disk.
    """
    t2l_options = {
        "inputs_ids": inputs_ids,
        "use_kv_cache": True,
        "resolution_base": resolution_base,
        "top_p": top_p,
        "bounding_box_xyz": bounding_box_xyz,
        "idx": idx,
    }
    return engine.t2l(prompt, **t2l_options)
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="cube shape generation script") | |
| parser.add_argument( | |
| "--config-path", | |
| type=str, | |
| default="cube3d/configs/open_model_v0.5.yaml", | |
| help="Path to the configuration YAML file.", | |
| ) | |
| parser.add_argument( | |
| "--data-dir", | |
| type=str, | |
| required=True, | |
| help="Path to the input dataset file.", | |
| ) | |
| parser.add_argument( | |
| "--output-dir", | |
| type=str, | |
| default="outputs/", | |
| help="Path to the output directory to store .obj and .gif files", | |
| ) | |
| parser.add_argument( | |
| "--gpt-ckpt-path", | |
| type=str, | |
| required=True, | |
| help="Path to the main GPT checkpoint file.", | |
| ) | |
| parser.add_argument( | |
| "--shape-ckpt-path", | |
| type=str, | |
| required=True, | |
| help="Path to the shape encoder/decoder checkpoint file.", | |
| ) | |
| parser.add_argument( | |
| "--save-gpt-ckpt-path", | |
| type=str, | |
| required=True, | |
| help="Path to the save adaption GPT checkpoint file.", | |
| ) | |
| parser.add_argument( | |
| "--fast-inference", | |
| help="Use optimized inference", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--prompt", | |
| type=str, | |
| required=True, | |
| help="Text prompt for generating a 3D mesh", | |
| ) | |
| parser.add_argument( | |
| "--top-p", | |
| type=float, | |
| default=None, | |
| help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.", | |
| ) | |
| parser.add_argument( | |
| "--bounding-box-xyz", | |
| nargs=3, | |
| type=float, | |
| help="Three float values for x, y, z bounding box", | |
| default=None, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--render-gif", | |
| help="Render a turntable gif of the mesh", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--disable-postprocessing", | |
| help="Disable postprocessing on the mesh. This will result in a mesh with more faces.", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--resolution-base", | |
| type=float, | |
| default=8.0, | |
| help="Resolution base for the shape decoder.", | |
| ) | |
| args = parser.parse_args() | |
| os.makedirs(args.output_dir, exist_ok=True) | |
| device = select_device() | |
| print(f"Using device: {device}") | |
| # Initialize engine based on fast_inference flag | |
| if args.fast_inference: | |
| print( | |
| "Using cuda graphs, this will take some time to warmup and capture the graph." | |
| ) | |
| engine = EngineFast( | |
| args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device | |
| ) | |
| print("Compiled the graph.") | |
| else: | |
| engine = Engine( | |
| args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, args.save_gpt_ckpt_path, device=device | |
| ) | |
| if args.bounding_box_xyz is not None: | |
| args.bounding_box_xyz = normalize_bbox(tuple(args.bounding_box_xyz)) | |
| # Generate meshes based on input source | |
| # obj_path = generate_mesh( | |
| # engine, | |
| # args.prompt, | |
| # args.output_dir, | |
| # "output", | |
| # args.resolution_base, | |
| # args.disable_postprocessing, | |
| # args.top_p, | |
| # args.bounding_box_xyz, | |
| # ) | |
| test_dataset = LegosDataset(args) | |
| batch_size = 1 | |
| x_num = 213 | |
| y_num = 217 | |
| z_num = 529 | |
| # setup the dataloader | |
| data_loader = DataLoader( | |
| test_dataset, | |
| shuffle=False, | |
| batch_size=batch_size, | |
| ) | |
| data_iter = iter(data_loader) | |
| for idx in range(len(test_dataset)): | |
| batch = next(data_iter) | |
| prompt, targets, box = batch['prompt'], batch['target'].to(device), batch['bbox'] | |
| ldr = generate_ldr( | |
| engine, | |
| prompt, | |
| targets, | |
| args.output_dir, | |
| "output", | |
| args.resolution_base, | |
| args.disable_postprocessing, | |
| args.top_p, | |
| #args.bounding_box_xyz, | |
| normalize_bboxs(box.float(), [x_num-1, y_num-1, z_num-1]), | |
| idx | |
| ) | |
| # if idx>4: | |
| # break |