Spaces:
Paused
Paused
| import os | |
| import os.path as osp | |
| import gradio as gr | |
| import spaces | |
| import gc | |
| import trimesh | |
| from PIL import Image | |
| import logging as log | |
| from omegaconf import OmegaConf | |
| import random | |
| import numpy as np | |
| import hashlib | |
| import shutil | |
| from typing import Optional | |
| import torch | |
| from torchvision import transforms | |
| from pycg import vis, image | |
| from pycg import render as pycg_render | |
| import sys | |
| sys.path.append('.') | |
| from lib.util.render import BLENDER_PATH | |
| from third_party.PartField.partfield.model_trainer_pvcnn_only_demo import Model | |
| from lib.opt import appearance, self_similarity | |
| from lib.util import generation, common, pointcloud | |
| import third_party.TRELLIS.trellis.models as models | |
| from demos.custom_utils import render_all_views | |
# Repository root: one directory above the folder that contains this file.
PROJECT_ROOT = osp.abspath(osp.join(osp.dirname(__file__), '..'))

# Bundled appearance-mesh examples, keyed by example name. Paths are relative
# to PROJECT_ROOT (kept in sync by hand with the generator/demo).
APP_MESH_EXAMPLES = {
    "B01DA8LC0A": "example_data/appearance_mesh/B01DA8LC0A.glb",
    "B01DJH73Y6": "example_data/appearance_mesh/B01DJH73Y6.glb",
    "B0728KSP33": "example_data/appearance_mesh/B0728KSP33.glb",
    "B07B4YXNR8": "example_data/appearance_mesh/B07B4YXNR8.glb",
    "B07QC84LP1": "example_data/appearance_mesh/B07QC84LP1.glb",
    "B07QFRSC8M": "example_data/appearance_mesh/B07QFRSC8M_zup.glb",
    "B082QC7YKR": "example_data/appearance_mesh/B082QC7YKR_zup.glb",
}

# Reverse lookup: absolute path of an example mesh -> its example name.
APP_MESH_ABS_TO_NAME = {
    osp.abspath(osp.join(PROJECT_ROOT, rel_path)): example_name
    for example_name, rel_path in APP_MESH_EXAMPLES.items()
}

# Give pycg a BLENDER_HOME unless the environment already provides one:
# use the bundled path when it exists on disk, otherwise the bare command name.
if "BLENDER_HOME" not in os.environ:
    os.environ["BLENDER_HOME"] = BLENDER_PATH if osp.exists(BLENDER_PATH) else "blender"

# Root logger at INFO with timestamped records.
log.getLogger().setLevel(log.INFO)
log.basicConfig(
    level=log.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
)

# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# PartField configuration, loaded once at import time (path is relative to
# the current working directory).
partfield_config = 'third_party/PartField/config.yaml'
partfield_cfg = OmegaConf.load(partfield_config)
| # Helper to calc hash | |
def file_sha256(path: str, chunk_size: int = 1 << 20) -> str:
    """Return the SHA-256 hex digest of the file at *path*.

    The file is consumed in pieces of ``chunk_size`` bytes so large meshes do
    not have to fit in memory. When *path* does not exist, the literal string
    ``"nocontent"`` is returned instead of a digest.
    """
    if not osp.exists(path):
        return "nocontent"

    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # An empty read marks end-of-file.
        while block := stream.read(chunk_size):
            digest.update(block)
    return digest.hexdigest()
| # @spaces.GPU() | |
def init_partfield(obj_path):
    """Build the PartField model for *obj_path* and load its checkpoint.

    All RNGs (torch, ``random``, numpy) are seeded with 0 first. The model is
    moved to the module-level ``device``, weights are read from
    ``partfield_cfg.continue_ckpt`` and loaded non-strictly; any missing or
    unexpected keys are printed. The model is returned in eval mode.
    """
    for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
        seed_fn(0)

    model = Model(partfield_cfg, obj_path).to(device)

    checkpoint = torch.load(
        partfield_cfg.continue_ckpt, map_location=device, weights_only=False
    )
    # Checkpoints may either wrap the weights under "state_dict" or be the
    # weight mapping itself.
    raw_state = checkpoint.get("state_dict", checkpoint)

    # NOTE(review): str.replace drops *every* "model." substring, not only a
    # leading prefix - confirm this matches the checkpoint's key layout.
    cleaned_state = {}
    for key, value in raw_state.items():
        cleaned_state[key.replace("model.", "")] = value

    missing, unexpected = model.load_state_dict(cleaned_state, strict=False)
    if missing:
        print("[load_partfield_model] Missing keys:", missing)
    if unexpected:
        print("[load_partfield_model] Unexpected keys:", unexpected)

    model.eval()
    return model
def partfield_pipeline_predict(obj_path, output_dir, uid_tag):
    """Run PartField on *obj_path* and save the predicted part planes.

    Args:
        obj_path: Mesh path handed to the PartField model constructor.
        output_dir: Directory for the ``.npy`` result (created if missing).
        uid_tag: Tag embedded in the output file name; used instead of any
            identifier the model derives itself.

    Returns:
        Path of the saved ``part_feat_<uid_tag>_batch_part_plane.npy`` file.
    """
    log.info(f"Extracting PartField feature planes for {uid_tag}...")
    gr.Info(f"Extracting PartField feature planes for {uid_tag}...")

    partfield_model = init_partfield(obj_path)

    # Seed *after* building the model: init_partfield() re-seeds every RNG
    # with 0, so seeding before it (as was done previously) meant the
    # configured seed never took effect for data loading / prediction.
    seed = int(partfield_cfg.seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    try:
        dataloader = partfield_model.predict_dataloader()
        batch = next(iter(dataloader))

        # Only request CUDA autocast when actually running on CUDA; on CPU
        # the context is explicitly disabled (full precision, no warning).
        use_cuda_autocast = device.type == "cuda"
        with torch.no_grad(), torch.autocast(
            device_type="cuda", dtype=torch.float16, enabled=use_cuda_autocast
        ):
            batch = {
                k: (v.to(device) if torch.is_tensor(v) else v)
                for k, v in batch.items()
            }
            part_planes, _ = partfield_model.predict_step(batch, batch_idx=0)

        os.makedirs(output_dir, exist_ok=True)
        # Use the explicit uid_tag instead of the one from the model
        partfield_save_path = f'{output_dir}/part_feat_{uid_tag}_batch_part_plane.npy'
        print(f"SAVING PART FIELD TO: {partfield_save_path}")
        np.save(partfield_save_path, part_planes)
    finally:
        # Release the model even when inference fails. Collect garbage first
        # so the CUDA allocator can actually hand the freed blocks back.
        del partfield_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return partfield_save_path
class GuideFlow3dPipeline:
    """Preprocess a structure mesh and run appearance / self-similarity optimization.

    Every intermediate artefact is written below the caller-supplied
    ``output_dir``. A step is skipped when a byte-identical copy of the input
    mesh (compared by SHA-256) and that step's output are already there.

    Both ``run_*`` methods return ``(mesh_path, video_path)`` on success and
    ``(None, None)`` on any failure.
    """

    def __init__(self):
        # Set by from_pretrained(). The methods read num_views, feature_name,
        # latent_name, dinov2_repo and enc_pretrained from it.
        self.cfg = None

    def from_pretrained(self, config):
        """Attach *config* to the pipeline and return the pipeline itself."""
        self.cfg = config
        return self

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _list_rendered_views(render_dir):
        """Return the sorted paths of all .png/.jpg/.jpeg files in *render_dir*."""
        return sorted(
            osp.join(render_dir, f)
            for f in os.listdir(render_dir)
            if f.lower().endswith((".png", ".jpg", ".jpeg"))
        )

    @staticmethod
    def _find_partfield_planes(partfield_dir, uid_tag):
        """Return file names in *partfield_dir* holding saved planes for *uid_tag*."""
        prefix = f"part_feat_{uid_tag}"
        return [
            f for f in os.listdir(partfield_dir)
            if f.startswith(prefix) and f.endswith("_batch_part_plane.npy")
        ]

    @staticmethod
    def _save_rgb_copy(src_path, dst_path):
        """Save an RGB copy of the image at *src_path* to *dst_path*."""
        # The context manager closes the source file handle deterministically.
        with Image.open(src_path) as im:
            im.convert('RGB').save(dst_path)

    @staticmethod
    def _release_gpu_memory():
        """Collect garbage, then return cached CUDA blocks to the driver."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _hydrate_example_cache(appearance_mesh, output_dir):
        """Copy precomputed artefacts into *output_dir* for a bundled example mesh.

        Does nothing unless *appearance_mesh* is one of the known examples and
        ``<PROJECT_ROOT>/all_outputs/<example_name>`` exists.
        """
        example_name = APP_MESH_ABS_TO_NAME.get(osp.abspath(appearance_mesh))
        if example_name is None:
            return
        cache_src = osp.join(PROJECT_ROOT, "all_outputs", example_name)
        if not osp.exists(cache_src):
            return
        # Copying a tree onto itself would raise; nothing to hydrate then.
        if osp.abspath(cache_src) == osp.abspath(output_dir):
            return

        log.info(f"Hydrating appearance data from cache: {example_name}")
        # Always copy (merging into existing folders) so this output folder
        # holds the example's data. "voxels" and "partfield" are shared with
        # the structure mesh; distinct file names keep the merge safe.
        for sub in ("app_renders", "voxels", "features", "latents", "partfield"):
            src = osp.join(cache_src, sub)
            if osp.exists(src):
                shutil.copytree(src, osp.join(output_dir, sub), dirs_exist_ok=True)

        # Copying the cached input lets the caller's hash check hit the cache.
        src_input = osp.join(cache_src, "app_mesh_input.glb")
        if osp.exists(src_input):
            shutil.copy2(src_input, osp.join(output_dir, "app_mesh_input.glb"))

    # --------------------------------------------------------------- preprocess

    def preprocess(
        self,
        structure_mesh: str,
        convert_yup_to_zup: bool,
        output_dir: str,
    ) -> Optional[dict]:
        """Prepare the structure mesh: z-up copy, renders, voxels, PartField planes.

        Args:
            structure_mesh: Path to the input mesh; must end in ``.glb``.
            convert_yup_to_zup: Convert the mesh from y-up to z-up first.
            output_dir: Working directory for all artefacts (created if missing).

        Returns:
            A dict with keys ``struct_mesh``, ``render_out``,
            ``partfield_structure_predictions_save_path`` and ``voxel_dir``,
            or ``None`` when the input is not a ``.glb`` or rendering
            produced no views.
        """
        log.info("Loading structure mesh...")
        gr.Info("Loading structure mesh...")
        if not structure_mesh.endswith('.glb'):
            log.error("Meshes must be in .glb format")
            return None

        # The cache files below are written straight into output_dir.
        os.makedirs(output_dir, exist_ok=True)

        # The cache is valid only if the stored input copy is byte-identical.
        # NOTE(review): convert_yup_to_zup is not part of the cache key.
        current_struct_hash = file_sha256(structure_mesh)
        cached_input_copy_path = osp.join(output_dir, "struct_mesh_input.glb")
        cached_struct_hash = None
        if osp.exists(cached_input_copy_path):
            cached_struct_hash = file_sha256(cached_input_copy_path)
        use_struct_cache = (cached_struct_hash == current_struct_hash)
        print(
            f"Use struct cache: {use_struct_cache}",
            f"Current struct hash: {current_struct_hash}",
            f"Cached input copy hash: {cached_struct_hash}",
            f"Input Structure mesh: {structure_mesh}",
            f"Checking Structure hash path at: {cached_input_copy_path}"
        )

        struct_mesh_zup_path = osp.join(output_dir, "struct_mesh_zup.glb")
        if use_struct_cache and osp.exists(struct_mesh_zup_path):
            log.info("Using cached structure mesh (z-up).")
            struct_mesh = trimesh.load(struct_mesh_zup_path, force="mesh")
        else:
            log.info("Cache miss or mismatch. Regenerating structure mesh...")
            # 1. Keep an exact copy of the input for future hash checks.
            shutil.copy2(structure_mesh, cached_input_copy_path)
            # Hash file for folder scanning.
            with open(osp.join(output_dir, "struct_mesh.hash"), "w") as f:
                f.write(current_struct_hash)
            # 2. Load, optionally re-orient, and store the working mesh.
            struct_mesh = trimesh.load(structure_mesh, force='mesh')
            if convert_yup_to_zup:
                struct_mesh = pointcloud.convert_mesh_yup_to_zup(struct_mesh)
            struct_mesh.export(struct_mesh_zup_path)

        # The structure mesh is rendered with a tenth of the configured views.
        num_struct_views = self.cfg.num_views // 10
        log.info(f"Rendering structure mesh for {num_struct_views} views...")
        gr.Info(f"Rendering structure mesh for {num_struct_views} views...")
        struct_render_dir = osp.join(output_dir, 'struct_renders')
        common.ensure_dir(struct_render_dir)
        struct_mesh_ply_path = osp.join(struct_render_dir, "mesh.ply")
        struct_transforms_path = osp.join(struct_render_dir, "transforms.json")
        if use_struct_cache and osp.exists(struct_mesh_ply_path) and osp.exists(struct_transforms_path):
            log.info("Using cached structure renders.")
            out_renderviews = self._list_rendered_views(struct_render_dir)
        else:
            out_renderviews = render_all_views(
                struct_mesh_zup_path,
                struct_render_dir,
                num_views=num_struct_views,
                num_workers=None
            )
        if not out_renderviews:
            log.error("Structure rendering failed! Aborting pipeline.")
            return None

        voxel_dir = osp.join(output_dir, 'voxels')
        common.ensure_dir(voxel_dir)
        log.info("Voxelizing structure mesh...")
        gr.Info("Voxelizing structure mesh...")
        struct_voxels_path = osp.join(voxel_dir, "struct_voxels.ply")
        if use_struct_cache and osp.exists(struct_voxels_path):
            log.info("Using cached structure voxels.")
        else:
            # Voxelization reads the mesh.ply inside the render directory.
            pointcloud.voxelize_mesh(
                struct_mesh_ply_path,
                save_path=struct_voxels_path,
            )

        log.info("Extracting Structure Mesh PartField feature planes...")
        gr.Info("Extracting Structure Mesh PartField feature planes...")
        partfield_dir = osp.join(output_dir, 'partfield')
        common.ensure_dir(partfield_dir)
        existing = self._find_partfield_planes(partfield_dir, "struct_mesh_zup")
        if use_struct_cache and existing:
            partfield_save_path = osp.join(partfield_dir, existing[0])
            log.info(f"Using cached Structure PartField at {partfield_save_path}")
        else:
            print("PREDICTING STRUCTURE PART FIELD...")
            partfield_save_path = partfield_pipeline_predict(
                struct_mesh_zup_path,
                partfield_dir,
                "struct_mesh_zup"
            )

        return {
            "struct_mesh": struct_mesh,
            "render_out": out_renderviews,
            "partfield_structure_predictions_save_path": partfield_save_path,
            "voxel_dir": voxel_dir
        }

    # ----------------------------------------------------------- run_appearance

    def run_appearance(
        self,
        structure_mesh: str,
        convert_target_yup_to_zup: bool,
        convert_appearance_yup_to_zup: bool,
        output_dir: str,
        appearance_mesh: str,
        appearance_image: str,
    ) -> "tuple[Optional[str], Optional[str]]":
        """Transfer appearance from *appearance_mesh* onto *structure_mesh*.

        Args:
            structure_mesh: Path to the target structure mesh (``.glb``).
            convert_target_yup_to_zup: Re-orient the structure mesh to z-up.
            convert_appearance_yup_to_zup: Re-orient the appearance mesh to z-up.
            output_dir: Working directory for all artefacts.
            appearance_mesh: Path to the appearance mesh (``.glb``, must exist).
            appearance_image: Optional reference image path. When falsy, an
                existing ``app_image.png`` is reused or one is rendered.

        Returns:
            ``(out_app.glb path, out_gaussian_app.mp4 path)`` on success,
            ``(None, None)`` on any failure. Every failure path returns a
            2-tuple so callers can always unpack the result.
        """
        prep = self.preprocess(
            structure_mesh=structure_mesh,
            convert_yup_to_zup=convert_target_yup_to_zup,
            output_dir=output_dir,
        )
        if prep is None:
            # preprocess() already logged the cause; later stages need its output.
            log.error("Structure preprocessing failed; skipping appearance-guided optimization.")
            return None, None

        # Keep Blender's cache inside the working directory.
        blender_cache_dir = osp.join(output_dir, "blender_cache")
        os.makedirs(blender_cache_dir, exist_ok=True)
        os.environ["XDG_CACHE_HOME"] = blender_cache_dir

        log.info("Running appearance-guided optimization...")
        gr.Info("Running appearance-guided optimization...")

        log.info("Loading appearance mesh...")
        gr.Info("Loading appearance mesh...")
        if not appearance_mesh.endswith('.glb'):
            log.error("Meshes must be in .glb format")
            return None, None
        if not osp.exists(appearance_mesh):
            log.error(f"Appearance mesh not found: {appearance_mesh}")
            return None, None

        # Bundled examples ship precomputed artefacts; pull them in first.
        self._hydrate_example_cache(appearance_mesh, output_dir)

        # Cache is valid only if the stored input copy is byte-identical.
        current_app_hash = file_sha256(appearance_mesh)
        cached_app_input_path = osp.join(output_dir, "app_mesh_input.glb")
        cached_app_hash = None
        if osp.exists(cached_app_input_path):
            cached_app_hash = file_sha256(cached_app_input_path)
        use_app_cache = (cached_app_hash == current_app_hash)
        print(f"Current app hash: {current_app_hash}")
        print(f"Cached app input hash: {cached_app_hash}")
        print(f"Use app cache: {use_app_cache}")

        app_mesh_path = osp.join(output_dir, "app_mesh.glb")
        app_mesh_zup_path = osp.join(output_dir, "app_mesh_zup.glb")
        if use_app_cache and osp.exists(app_mesh_zup_path):
            log.info("Using cached appearance mesh (z-up).")
            app_mesh = trimesh.load(app_mesh_zup_path, force="mesh")
        else:
            # Cache miss: remember the input, then build both mesh variants.
            shutil.copy2(appearance_mesh, cached_app_input_path)
            with open(osp.join(output_dir, "app_mesh.hash"), "w") as f:
                f.write(current_app_hash)
            app_mesh = trimesh.load(appearance_mesh, force="mesh")
            app_mesh.export(app_mesh_path)
            if convert_appearance_yup_to_zup:
                app_mesh = pointcloud.convert_mesh_yup_to_zup(app_mesh)
            app_mesh.export(app_mesh_zup_path)

        log.info("Loading appearance image...")
        gr.Info("Loading appearance image...")
        app_image_path = osp.join(output_dir, 'app_image.png')
        if appearance_image:
            self._save_rgb_copy(appearance_image, app_image_path)
        elif not osp.exists(app_image_path):
            # No image supplied and none saved yet: render a grey preview.
            # NOTE(review): an app_image.png left by an earlier run is reused
            # even if the mesh changed, and app_mesh.glb is assumed to exist
            # on a cache hit - confirm both are acceptable.
            mesh = vis.from_file(app_mesh_path, load_obj_textures=True)
            mesh.paint_uniform_color([0.5, 0.5, 0.5])
            scene = pycg_render.Scene(up_axis='+Y')
            scene.add_object(mesh)
            scene.quick_camera(w=512, h=512, pitch_angle=30, plane_angle=-45.0, fov=40)
            pycg_render.ThemeDiffuseShadow(None, sun_tilt_right=0.0, sun_tilt_back=0.0, sun_angle=60.0).apply_to(scene)
            rendering = scene.render_blender(quality=512)
            rendering = image.alpha_compositing(rendering, image.solid(rendering.shape[1], rendering.shape[0]))
            image.write(app_image_path, rendering)

        # Existing features make the (expensive) appearance renders unnecessary.
        # NOTE(review): this check does not consult use_app_cache, so features
        # left by a different mesh in the same output_dir would be reused.
        features_dir = osp.join(output_dir, "features", self.cfg.feature_name)
        has_dinov2_features = osp.exists(features_dir) and len(os.listdir(features_dir)) > 0

        app_render_dir = osp.join(output_dir, 'app_renders')
        common.ensure_dir(app_render_dir)
        app_mesh_ply_path = osp.join(app_render_dir, "mesh.ply")
        app_transforms_path = osp.join(app_render_dir, "transforms.json")

        if has_dinov2_features:
            log.info("DinoV2 features found. Skipping appearance rendering.")
            gr.Info("DinoV2 features found. Skipping appearance rendering.")
            # Voxelization reads mesh.ply; create it if it is not there yet.
            if not osp.exists(app_mesh_ply_path):
                app_mesh.export(app_mesh_ply_path)
        else:
            log.info(f"Rendering appearance mesh for {self.cfg.num_views} views...")
            gr.Info(f"Rendering appearance mesh for {self.cfg.num_views} views...")
            if use_app_cache and osp.exists(app_mesh_ply_path) and osp.exists(app_transforms_path):
                log.info("Using cached appearance renders.")
                out_renderviews = self._list_rendered_views(app_render_dir)
            else:
                out_renderviews = render_all_views(
                    app_mesh_zup_path,
                    app_render_dir,
                    num_views=self.cfg.num_views,
                    num_workers=None
                )
            if not out_renderviews:
                log.info("Appearance rendering failed!")
                gr.Warning("Appearance rendering failed!")
                return None, None

        log.info("Voxelizing appearance mesh...")
        gr.Info("Voxelizing appearance mesh...")
        app_voxel_dir = osp.join(output_dir, "voxels")
        common.ensure_dir(app_voxel_dir)
        app_voxels_path = osp.join(app_voxel_dir, "app_voxels.ply")
        if use_app_cache and osp.exists(app_voxels_path):
            log.info("Using cached appearance voxels.")
        else:
            pointcloud.voxelize_mesh(
                app_mesh_ply_path,
                save_path=app_voxels_path,
            )

        log.info("Extracting DinoV2 features...")
        gr.Info("Extracting DinoV2 features...")
        common.ensure_dir(features_dir)
        if has_dinov2_features or (use_app_cache and os.listdir(features_dir)):
            log.info("Using cached DINOv2 features.")
        else:
            dinov2_model = torch.hub.load(self.cfg.dinov2_repo, self.cfg.feature_name)
            dinov2_model.eval().cuda()
            transform = transforms.Compose([transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
            generation.extract_feature(output_dir, dinov2_model, transform)
            # Drop the reference first; only then can the cache be released.
            del dinov2_model
            self._release_gpu_memory()

        log.info("Extracting SLAT latent...")
        gr.Info("Extracting SLAT latent...")
        latents_dir = osp.join(output_dir, "latents", self.cfg.latent_name)
        common.ensure_dir(latents_dir)
        if use_app_cache and os.listdir(latents_dir):
            log.info("Using cached SLAT latent.")
        else:
            encoder = models.from_pretrained(self.cfg.enc_pretrained).eval().cuda()
            generation.get_latent(output_dir, self.cfg.feature_name, self.cfg.latent_name, encoder)
            del encoder
            self._release_gpu_memory()

        log.info("Extracting Appearance Mesh PartField feature planes...")
        gr.Info("Extracting Appearance Mesh PartField feature planes...")
        app_partfield_dir = osp.join(output_dir, "partfield")
        common.ensure_dir(app_partfield_dir)
        existing_app_pf = self._find_partfield_planes(app_partfield_dir, "app_mesh_zup")
        if use_app_cache and existing_app_pf:
            appearance_partfield_save_path = osp.join(
                app_partfield_dir, existing_app_pf[0]
            )
            log.info(
                f"Using cached Appearance PartField at {appearance_partfield_save_path}"
            )
        else:
            partfield_pipeline_predict(
                app_mesh_zup_path,
                app_partfield_dir,
                "app_mesh_zup"
            )

        # The optimizer reads its inputs from, and writes results to, output_dir.
        appearance.optimize_appearance(self.cfg, output_dir)

        output_mesh_path = osp.join(output_dir, 'out_app.glb')
        output_video_path = osp.join(output_dir, 'out_gaussian_app.mp4')
        if not osp.exists(output_mesh_path) or not osp.exists(output_video_path):
            log.error(f"Output mesh or video not found at {output_mesh_path} or {output_video_path}")
            return None, None
        return output_mesh_path, output_video_path

    # ------------------------------------------------------ run_self_similarity

    def run_self_similarity(
        self,
        structure_mesh: str,
        convert_target_yup_to_zup: bool,
        output_dir: str,
        app_type: str,
        appearance_text: Optional[str] = None,
        appearance_image: Optional[str] = None,
    ) -> "tuple[Optional[str], Optional[str]]":
        """Run self-similarity-guided optimization on *structure_mesh*.

        Args:
            structure_mesh: Path to the target structure mesh (``.glb``).
            convert_target_yup_to_zup: Re-orient the structure mesh to z-up.
            output_dir: Working directory for all artefacts.
            app_type: ``'text'`` selects *appearance_text* as the condition;
                any other value selects *appearance_image*.
            appearance_text: Text condition (used when ``app_type == 'text'``).
            appearance_image: Image path (used otherwise). For
                ``app_type == 'image'`` an RGB copy is saved as ``app_image.png``.

        Returns:
            ``(out_sim.glb path, out_gaussian_sim.mp4 path)`` on success,
            ``(None, None)`` on any failure.
        """
        prep = self.preprocess(
            structure_mesh=structure_mesh,
            convert_yup_to_zup=convert_target_yup_to_zup,
            output_dir=output_dir,
        )
        if prep is None:
            # preprocess() already logged the cause; later stages need its output.
            log.error("Structure preprocessing failed; skipping similarity-guided optimization.")
            return None, None

        log.info("Running similarity-guided optimization...")
        gr.Info("Running similarity-guided optimization...")
        if app_type == 'image' and appearance_image:
            self._save_rgb_copy(appearance_image, osp.join(output_dir, 'app_image.png'))

        app = appearance_text if app_type == 'text' else appearance_image
        self_similarity.optimize_self_similarity(self.cfg, app, app_type, output_dir)

        output_mesh_path = osp.join(output_dir, 'out_sim.glb')
        output_video_path = osp.join(output_dir, 'out_gaussian_sim.mp4')
        if not osp.exists(output_mesh_path) or not osp.exists(output_video_path):
            log.error(f"Output mesh or video not found at {output_mesh_path} or {output_video_path}")
            return None, None
        return output_mesh_path, output_video_path
def main():
    """Entry-point placeholder; the visible body performs no work."""
    pass