Spaces:
Runtime error
Runtime error
| from typing import * | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from torchvision import transforms | |
| from PIL import Image | |
| import trimesh | |
| import os | |
| import random | |
| import trellis.modules.sparse as sp | |
| from trellis.models.sparse_structure_vae import * | |
| from contextlib import contextmanager | |
| import sys | |
| sys.path.append("wheels/vggt") | |
| from wheels.vggt.vggt.models.vggt import VGGT | |
| from typing import * | |
| from scipy.spatial.transform import Rotation | |
| from transformers import AutoModelForImageSegmentation | |
| import rembg | |
def export_point_cloud(xyz, color):
    """
    Build a ``trimesh.PointCloud`` from point positions and per-point colors.

    Args:
        xyz: Point positions. Torch tensors are detached and moved to CPU
            before being handed to trimesh.
        color: Per-point colors, expected as floats in [0, 1] (they are scaled
            by 255). Torch tensors are detached and moved to CPU first.

    Returns:
        trimesh.PointCloud: Point cloud with ``uint8`` colors.
    """
    if isinstance(xyz, torch.Tensor):
        xyz = xyz.detach().cpu().numpy()
    if isinstance(color, torch.Tensor):
        color = color.detach().cpu().numpy()
    # Clip before the uint8 cast: converting floats outside [0, 255] to uint8
    # is not well defined in NumPy, so slightly out-of-range inputs (e.g. 1.01
    # from a network output) could otherwise wrap around to dark colors.
    color = (np.clip(color, 0.0, 1.0) * 255).astype(np.uint8)
    return trimesh.PointCloud(vertices=xyz, colors=color)
def normalize_trimesh(mesh):
    """
    Center a mesh on its centroid and scale it so its largest extent is 1.

    The mesh is modified in place and the same object is returned.

    Args:
        mesh (trimesh.Trimesh): Mesh to normalize. Only ``vertices``,
            ``centroid`` and ``extents`` are used.

    Returns:
        trimesh.Trimesh: ``mesh``, translated so its centroid is at the origin
        and divided by its largest bounding-box extent. A degenerate mesh whose
        extents are all zero is only translated (dividing would produce NaN).
    """
    # Translate first; extents are translation-invariant, so reading them
    # afterwards yields the same scale.
    mesh.vertices -= mesh.centroid
    scale = max(mesh.extents)
    if scale > 0:
        mesh.vertices /= scale
    return mesh
def random_sample_rotation(rotation_factor: float = 1.0) -> np.ndarray:
    """
    Sample a random 3x3 rotation matrix from random Euler angles.

    Three values are drawn from NumPy's global RNG (``np.random.rand``), each
    mapped to ``[0, 2*pi / rotation_factor)``, and interpreted as scipy
    ``'zyx'`` Euler angles (angle about z first, then y, then x).

    Args:
        rotation_factor (float): Divides the angle range; larger values give
            smaller rotations.

    Returns:
        np.ndarray: A ``(3, 3)`` rotation matrix.
    """
    samples = np.random.rand(3)
    samples = samples * np.pi * 2 / rotation_factor
    rot = Rotation.from_euler('zyx', samples)
    return rot.as_matrix()
| from scipy.ndimage import binary_dilation | |
def voxelize_trimesh(mesh, resolution=(64, 64, 64), stride=4):
    """
    Voxelize a mesh with ``stride``-times supersampling, then OR-pool it down.

    The mesh is subdivided so that no edge is longer than half a fine-voxel
    pitch. It is then voxelized at pitch ``1 / (max(resolution) * stride)``:
    binvox with ``exact=True`` is tried first, and trimesh's default voxelizer
    is the fallback. Finally, each ``stride`` x ``stride`` x ``stride`` block
    of fine voxels is reduced to one coarse voxel, which is occupied if any
    fine voxel in the block is.

    Args:
        mesh (trimesh.Trimesh): The input mesh; only ``vertices`` and ``faces``
            are read.
        resolution (tuple): Only ``max(resolution)`` is used, as the number of
            coarse voxels per unit length (coarse pitch ``1 / max(resolution)``).
        stride (int): Supersampling factor per axis. Default is 4.

    Returns:
        np.ndarray: Boolean occupancy grid. Its shape follows from the mesh's
        size relative to the pitch: each axis is the fine grid's size divided
        (floor) by the supersampling factor. It is not forced to equal
        ``resolution``.

    Raises:
        Exception: Any error raised by mesh subdivision is printed and re-raised.
    """
    target_density = max(resolution)
    anti_aliasing_density = target_density * stride
    anti_aliasing_edge_length = 1.0 / anti_aliasing_density
    anti_aliasing_max_edge_for_subdivision = anti_aliasing_edge_length / 2

    # Subdivide the mesh for the higher-resolution voxelization.
    try:
        new_vertices, new_faces = trimesh.remesh.subdivide_to_size(
            mesh.vertices, mesh.faces, anti_aliasing_max_edge_for_subdivision
        )
        subdivided_mesh = trimesh.Trimesh(vertices=new_vertices, faces=new_faces)
    except Exception as e:
        print(f"Unexpected error during mesh subdivision for anti-aliasing: {e}")
        raise

    # Voxelize at the fine pitch; `except Exception` (not a bare except) so
    # KeyboardInterrupt/SystemExit are not mistaken for a binvox failure.
    try:
        high_res_voxel_grid = subdivided_mesh.voxelized(
            pitch=anti_aliasing_edge_length, method="binvox", exact=True
        )
    except Exception:
        print("Voxelization using 'binvox' method failed for anti-aliasing")
        high_res_voxel_grid = subdivided_mesh.voxelized(pitch=anti_aliasing_edge_length)
        print("Falling back to default voxelization method for anti-aliasing.")
    high_res = high_res_voxel_grid.matrix.astype(bool)

    # OR-pool factor^3 blocks. Trailing fine voxels that do not fill a whole
    # block are dropped (floor division), then each block gets its own axes.
    factor = int(anti_aliasing_density / target_density)
    nx, ny, nz = (dim // factor for dim in high_res.shape)
    cropped = high_res[: nx * factor, : ny * factor, : nz * factor]
    blocks = cropped.reshape(nx, factor, ny, factor, nz, factor)
    return blocks.any(axis=(1, 3, 5))
def get_occupied_coordinates(voxel_grid, device: str = 'cuda:0'):
    """
    Convert a dense occupancy grid into sparse integer coordinates.

    Args:
        voxel_grid (np.ndarray): Array whose truthy entries mark occupied voxels.
        device: Device the result is moved to. The default ``'cuda:0'``
            matches the previous hard-coded behavior.

    Returns:
        torch.Tensor: ``int32`` tensor of shape ``(N, 1 + voxel_grid.ndim)``.
        Column 0 is all zeros (a leading batch index); the remaining columns
        are the occupied indices shifted by +1, in ``np.argwhere`` (row-major)
        order.
    """
    occupied_indices = np.argwhere(voxel_grid)
    # int32 rather than int8: int8 only holds values up to 127, so indices
    # >= 127 would silently wrap to negative numbers after the +1 shift.
    coords = torch.tensor(occupied_indices, dtype=torch.int32)
    batch_index = torch.zeros(coords.shape[0], 1, dtype=torch.int32)
    # NOTE(review): the +1 shift is preserved from the original; presumably the
    # consumer expects 1-based / padded coordinates — confirm.
    coords = torch.cat([batch_index, coords + 1], dim=1)
    return coords.to(device)
| from .base import Pipeline | |
| from . import samplers | |
| from ..modules import sparse as sp | |
class TrellisImageTo3DPipeline(Pipeline):
    """
    Pipeline for inferring Trellis image-to-3D models.

    Generation runs in two stages. First, a sparse-structure flow model and
    decoder produce occupied voxel coordinates. Second, a structured-latent
    (SLAT) flow model samples features on those coordinates. The features are
    then decoded into the requested output formats.

    Args:
        models (dict[str, nn.Module]): The models to use in the pipeline.
        sparse_structure_sampler (samplers.Sampler): The sampler for the sparse structure.
        slat_sampler (samplers.Sampler): The sampler for the structured latent.
        slat_normalization (dict): The normalization parameters for the structured
            latent; its ``'mean'`` and ``'std'`` entries are read.
        image_cond_model (str): The name of the image conditioning model, loaded as
            a DINOv2 ``torch.hub`` entry point.
    """
    # Side length (pixels) images are resized to before the conditioning encoder.
    default_image_resolution = 518

    def __init__(
        self,
        models: dict[str, nn.Module] = None,
        sparse_structure_sampler: samplers.Sampler = None,
        slat_sampler: samplers.Sampler = None,
        slat_normalization: dict = None,
        image_cond_model: str = None,
    ):
        # Constructing without models yields an empty shell; `from_pretrained`
        # relies on this and then replaces the instance __dict__ wholesale.
        if models is None:
            return
        super().__init__(models)
        self.sparse_structure_sampler = sparse_structure_sampler
        self.slat_sampler = slat_sampler
        self.sparse_structure_sampler_params = {}
        self.slat_sampler_params = {}
        self.slat_normalization = slat_normalization
        self._init_image_cond_model(image_cond_model)

    @staticmethod
    def from_pretrained(path: str) -> "TrellisImageTo3DPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either local path or a Hugging Face repository.

        Returns:
            TrellisImageTo3DPipeline: Pipeline whose samplers, sampler params,
            SLAT normalization and image-conditioning model are configured from
            the loaded ``_pretrained_args``.
        """
        # The base loader builds the models; its state is transplanted onto an
        # instance of this class.
        pipeline = super(TrellisImageTo3DPipeline, TrellisImageTo3DPipeline).from_pretrained(path)
        new_pipeline = TrellisImageTo3DPipeline()
        new_pipeline.__dict__ = pipeline.__dict__
        args = pipeline._pretrained_args

        new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
        new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']

        new_pipeline.slat_sampler = getattr(samplers, args['slat_sampler']['name'])(**args['slat_sampler']['args'])
        new_pipeline.slat_sampler_params = args['slat_sampler']['params']

        new_pipeline.slat_normalization = args['slat_normalization']

        new_pipeline._init_image_cond_model(args['image_cond_model'])

        return new_pipeline

    def _init_image_cond_model(self, name: str):
        """
        Initialize the image conditioning model.

        Loads DINOv2 entry point ``name`` from the local ``torch.hub`` cache
        (``<hub dir>/facebookresearch_dinov2_main``). If that fails, it falls back
        to fetching ``facebookresearch/dinov2``. The model is stored in
        ``self.models['image_cond_model']`` in eval mode. The matching input
        transform only normalizes with ImageNet mean/std, so callers must
        already supply resized float images in [0, 1].
        """
        try:
            dinov2_model = torch.hub.load(os.path.join(torch.hub.get_dir(), 'facebookresearch_dinov2_main'), name, source='local', pretrained=True)
        except Exception:
            # Local cache missing or unusable: fall back to the remote hub repo.
            dinov2_model = torch.hub.load('facebookresearch/dinov2', name, pretrained=True)
        dinov2_model.eval()
        self.models['image_cond_model'] = dinov2_model
        transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.image_cond_model_transform = transform

    def preprocess_image(self, input: Image.Image, resolution=518, no_background=True, recenter=True) -> Image.Image:
        """
        Preprocess the input image using BiRefNet for background removal.

        If ``input`` is RGBA with a non-opaque alpha channel, that alpha is used
        as the mask. Otherwise the image is converted to RGB, downscaled so its
        longer side is at most 1024 px, and masked with ``_get_birefnet_mask``.
        The image is then cropped around the foreground (alpha > 0.8 * 255),
        padded to a square with transparent pixels, and bilinearly resized to
        ``resolution`` x ``resolution``.

        Args:
            input (Image.Image): Image to preprocess.
            resolution (int): Side length of the returned square image.
            no_background (bool): If True, RGB values are zeroed wherever the
                resized alpha is not above 0.8. If False, the crop center is
                clamped toward the image interior and the crop size is limited
                by the distance to the image borders.
            recenter (bool): If True, center the crop on the foreground bounding
                box (with a 10% margin); otherwise center it on the image.

        Returns:
            Image.Image: An RGBA ``resolution`` x ``resolution`` image. If no
            foreground pixel is found, the (possibly downscaled) input is
            returned as RGB without cropping or resizing.
        """
        # If the image already carries a meaningful alpha channel, use it directly.
        has_alpha = False
        if input.mode == 'RGBA':
            alpha = np.array(input)[:, :, -1]
            if not np.all(alpha == 255):
                has_alpha = True
        if has_alpha:
            output = input
        else:
            input = input.convert('RGB')
            max_size = max(input.size)
            scale = min(1, 1024 / max_size)
            if scale < 1:
                input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)

            # Get a 0/1 mask using BiRefNet and write it into the alpha channel.
            mask = self._get_birefnet_mask(input)
            input_rgba = input.convert('RGBA')
            input_array = np.array(input_rgba)
            input_array[:, :, 3] = mask * 255
            output = Image.fromarray(input_array)

        output_np = np.array(output)
        alpha = output_np[:, :, 3]

        # Bounding box (x_min, y_min, x_max, y_max) of clearly opaque pixels.
        bbox = np.argwhere(alpha > 0.8 * 255)
        if len(bbox) == 0:  # No foreground detected.
            return input.convert('RGB')
        bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
        center = [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2]
        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
        size = int(size * 1.1)
        height, width = alpha.shape
        if not recenter:
            # Keep the image center; the crop must still contain the whole bbox.
            center = [width / 2, height / 2]
            size = max(bbox[2] - bbox[0],
                       bbox[3] - bbox[1],
                       (bbox[2] - width / 2) * 2,
                       (width / 2 - bbox[0]) * 2,
                       (height / 2 - bbox[1]) * 2,
                       (bbox[3] - height / 2) * 2)

        # Calculate and apply crop bbox.
        if not no_background:
            if height > width:
                center[0] = width / 2
                if center[1] < width / 2:
                    center[1] = width / 2
                elif center[1] > height - width / 2:
                    center[1] = height - width / 2
            else:
                center[1] = height / 2
                if center[0] < height / 2:
                    center[0] = height / 2
                elif center[0] > width - height / 2:
                    center[0] = width - height / 2
            # NOTE(review): half-distances to the borders are compared with the
            # full `size`; when `size` is the smallest term the crop becomes
            # 2 * size. Confirm whether `size / 2` was intended.
            size = min(center[0], center[1], input.width - center[0], input.height - center[1], size) * 2

        bbox = (
            int(center[0] - size // 2),
            int(center[1] - size // 2),
            int(center[0] + size // 2),
            int(center[1] + size // 2)
        )

        # Ensure bbox is within image bounds.
        bbox = (
            max(0, bbox[0]),
            max(0, bbox[1]),
            min(output.width, bbox[2]),
            min(output.height, bbox[3])
        )

        output = output.crop(bbox)

        # Pad the (possibly non-square) crop to a square with transparent pixels.
        width, height = output.size
        if width > height:
            new_height = width
            padding = (width - height) // 2
            padded_output = Image.new('RGBA', (width, new_height), (0, 0, 0, 0))
            padded_output.paste(output, (0, padding))
        else:
            new_width = height
            padding = (height - width) // 2
            padded_output = Image.new('RGBA', (new_width, height), (0, 0, 0, 0))
            padded_output.paste(output, (padding, 0))

        # Resize with torch bilinear interpolation (values in [0, 1]).
        padded_output = torch.from_numpy(np.array(padded_output).astype(np.float32)) / 255
        padded_output = F.interpolate(padded_output.unsqueeze(0).permute(0, 3, 1, 2), (resolution, resolution), mode='bilinear', align_corners=False)[0].permute(1, 2, 0)

        output = padded_output.cpu().numpy()
        if no_background:
            output = np.dstack((
                output[:, :, :3] * (output[:, :, 3:4] > 0.8),  # RGB zeroed outside the alpha mask
                output[:, :, 3]  # Original alpha channel
            ))
        output = Image.fromarray((output * 255).astype(np.uint8), mode='RGBA')

        return output

    def _get_birefnet_mask(self, image: Image.Image) -> np.ndarray:
        """
        Get a binary foreground mask for ``image`` using BiRefNet.

        The image is resized to 1024x1024, normalized with ImageNet mean/std and
        run through ``self.birefnet_model``. The sigmoid of the last output is
        resized back to ``image.size`` and thresholded at pixel value 128.

        NOTE(review): ``birefnet_model`` is not assigned by this class's
        ``from_pretrained``; make sure it is set before calling this.

        Returns:
            np.ndarray: ``uint8`` array of 0/1 values with shape ``(H, W)`` of ``image``.
        """
        image_size = (1024, 1024)
        transform_image = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        input_images = transform_image(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            preds = self.birefnet_model(input_images)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(image.size)
        mask_np = np.array(mask)

        return (mask_np > 128).astype(np.uint8)

    def encode_image(self, image: Union[torch.Tensor, list[Image.Image]], w_layernorm=True) -> torch.Tensor:
        """
        Encode the image.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image to encode.
                A batched ``(B, C, H, W)`` tensor is resized to
                ``default_image_resolution``. A list of PIL images is resized,
                converted to RGB in [0, 1] and stacked.
            w_layernorm (bool): Apply a parameter-free layer norm over the last
                (feature) dimension.

        Returns:
            torch.Tensor: The encoded features (the model's ``'x_prenorm'`` output).

        Raises:
            ValueError: If ``image`` is neither a tensor nor a list.
        """
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
            image = F.interpolate(image, self.default_image_resolution, mode='bilinear', align_corners=False)
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            image = [i.resize((self.default_image_resolution, self.default_image_resolution), Image.LANCZOS) for i in image]
            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image).to(self.device)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        image = self.image_cond_model_transform(image).to(self.device)
        # is_training=True makes the model return a dict, from which the
        # pre-norm tokens are taken.
        features = self.models['image_cond_model'](image, is_training=True)['x_prenorm']
        if w_layernorm:
            features = F.layer_norm(features, features.shape[-1:])
        return features

    def get_cond(self, image: Union[torch.Tensor, list[Image.Image]]) -> dict:
        """
        Get the conditioning information for the model.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): The image prompts.

        Returns:
            dict: ``'cond'`` (encoded features) and ``'neg_cond'`` (zeros of the same shape).
        """
        cond = self.encode_image(image)
        neg_cond = torch.zeros_like(cond)
        return {
            'cond': cond,
            'neg_cond': neg_cond,
        }

    def sample_sparse_structure(
        self,
        cond: dict,
        num_samples: int = 1,
        sampler_params: dict = {},
        noise: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Sample sparse structures with the given conditioning.

        Args:
            cond (dict): The conditioning information (passed to the sampler as keyword arguments).
            num_samples (int): The number of samples to generate (ignored if ``noise`` is given).
            sampler_params (dict): Additional parameters for the sampler; they
                override ``self.sparse_structure_sampler_params``. Not mutated.
            noise (torch.Tensor): Optional initial noise; drawn from the global RNG if omitted.

        Returns:
            torch.Tensor: ``int32`` coordinates of decoder outputs > 0, keeping
            argwhere columns 0, 2, 3, 4. Presumably the output is
            ``(B, C, D, H, W)``, so the rows are ``(batch, x, y, z)`` and the
            channel index is dropped — confirm.
        """
        # Sample occupancy latent.
        flow_model = self.models['sparse_structure_flow_model']
        reso = flow_model.resolution
        if noise is None:
            noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
        sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
        z_s = self.sparse_structure_sampler.sample(
            flow_model,
            noise,
            **cond,
            **sampler_params,
            verbose=True
        ).samples

        # Decode occupancy latent.
        decoder = self.models['sparse_structure_decoder']
        coords = torch.argwhere(decoder(z_s) > 0)[:, [0, 2, 3, 4]].int()

        return coords

    def encode_slat(
        self,
        slat: sp.SparseTensor,
    ):
        """
        Encode a sparse tensor with ``self.models['slat_encoder']``.

        Returns:
            dict: ``{'slat': <encoder output>}``, computed with ``sample_posterior=False``.
        """
        ret = {}
        slat = self.models['slat_encoder'](slat, sample_posterior=False)
        ret['slat'] = slat
        return ret

    def decode_slat(
        self,
        slat: sp.SparseTensor,
        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
    ) -> dict:
        """
        Decode the structured latent.

        Args:
            slat (sp.SparseTensor): The structured latent.
            formats (List[str]): The formats to decode the structured latent to
                (any of ``'mesh'``, ``'gaussian'``, ``'radiance_field'``). Not mutated.

        Returns:
            dict: ``'slat'`` plus one entry per requested format.
        """
        ret = {}
        ret['slat'] = slat
        if 'gaussian' in formats:
            ret['gaussian'] = self.models['slat_decoder_gs'](slat)
        if 'mesh' in formats:
            ret['mesh'] = self.models['slat_decoder_mesh'](slat)
        if 'radiance_field' in formats:
            ret['radiance_field'] = self.models['slat_decoder_rf'](slat)
        return ret

    def sample_slat(
        self,
        cond: dict,
        coords: torch.Tensor,
        sampler_params: dict = {},
    ) -> sp.SparseTensor:
        """
        Sample structured latent with the given conditioning.

        Args:
            cond (dict): The conditioning information.
            coords (torch.Tensor): The coordinates of the sparse structure.
            sampler_params (dict): Additional parameters for the sampler; they
                override ``self.slat_sampler_params``. Not mutated.

        Returns:
            sp.SparseTensor: The sampled latent, denormalized with
            ``slat_normalization`` (``slat * std + mean``).
        """
        flow_model = self.models['slat_flow_model']
        noise = sp.SparseTensor(
            feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
            coords=coords,
        )
        sampler_params = {**self.slat_sampler_params, **sampler_params}
        slat = self.slat_sampler.sample(
            flow_model,
            noise,
            **cond,
            **sampler_params,
            verbose=True
        ).samples

        std = torch.tensor(self.slat_normalization['std'])[None].to(slat.device)
        mean = torch.tensor(self.slat_normalization['mean'])[None].to(slat.device)
        slat = slat * std + mean

        return slat

    def get_input(self, batch_data):
        """
        Build a training triple ``(targets, cond, noise)`` from a batch.

        Reads ``batch_data['source_image']``, ``['target_feats']`` and
        ``['target_coords']``. Targets are normalized with ``slat_normalization``.
        With probability 0.5 the image condition is replaced by zeros
        (presumably classifier-free-guidance dropout — confirm). Noise has the
        shape of ``target_feats`` and uses the same coordinates.
        """
        std = torch.tensor(self.slat_normalization['std'])[None].to(self.device)
        mean = torch.tensor(self.slat_normalization['mean'])[None].to(self.device)
        images = batch_data['source_image']
        cond = self.encode_image(images)
        if random.random() > 0.5:
            cond = torch.zeros_like(cond)
        target_feats = batch_data['target_feats']
        target_coords = batch_data['target_coords']
        targets = sp.SparseTensor(target_feats, target_coords).to(self.device)
        targets = (targets - mean) / std
        noise = sp.SparseTensor(
            feats=torch.randn_like(target_feats).to(self.device),
            coords=target_coords.to(self.device),
        )
        return targets, cond, noise

    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        """Run the SLAT flow model on ``(x, t, cond)``."""
        # NOTE(review): `self.slat_flow_model` is never assigned in this class
        # (elsewhere the model is read from `self.models['slat_flow_model']`);
        # confirm the base class or training code provides this attribute.
        return self.slat_flow_model(x, t, cond)

    @contextmanager
    def inject_sampler_multi_image(
        self,
        sampler_name: str,
        num_images: int,
        num_steps: int,
        mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
    ):
        """
        Inject a sampler with multiple images as condition.

        Use as a context manager. The sampler's ``_inference_model`` is
        replaced on entry and restored on exit, even if sampling raises.

        Args:
            sampler_name (str): Attribute name of the sampler to inject
                (e.g. ``'sparse_structure_sampler'`` or ``'slat_sampler'``).
            num_images (int): The number of images to condition on.
            num_steps (int): The number of steps to run the sampler for; used to
                build the per-call image schedule in ``'stochastic'`` mode.
            mode: ``'stochastic'`` conditions each call on a single image,
                cycling through them. ``'multidiffusion'`` averages the
                predictions over all images on every call and applies CFG
                against ``neg_cond`` inside ``cfg_interval``.

        Raises:
            ValueError: If ``mode`` is not supported. The sampler is left untouched.
        """
        sampler = getattr(self, sampler_name)

        if mode == 'stochastic':
            if num_images > num_steps:
                print(f"\033[93mWarning: number of conditioning images is greater than number of steps for {sampler_name}. "
                      "This may lead to performance degradation.\033[0m")

            # Call k uses image k % num_images; each call consumes one entry.
            cond_indices = (np.arange(num_steps) % num_images).tolist()

            def _new_inference_model(self, model, x_t, t, cond, **kwargs):
                cond_idx = cond_indices.pop(0)
                cond_i = cond[cond_idx:cond_idx+1]
                return self._old_inference_model(model, x_t, t, cond=cond_i, **kwargs)

        elif mode == 'multidiffusion':
            from .samplers import FlowEulerSampler

            def _new_inference_model(self, model, x_t, t, cond, neg_cond, cfg_strength, cfg_interval, **kwargs):
                if cfg_interval[0] <= t <= cfg_interval[1]:
                    preds = []
                    for i in range(len(cond)):
                        preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs))
                    pred = sum(preds) / len(preds)
                    neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs)
                    return (1 + cfg_strength) * pred - cfg_strength * neg_pred
                else:
                    preds = []
                    for i in range(len(cond)):
                        preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i+1], **kwargs))
                    pred = sum(preds) / len(preds)
                    return pred

        else:
            raise ValueError(f"Unsupported mode: {mode}")

        # Patch only after the mode is validated, so a bad mode cannot leave a
        # stale `_old_inference_model` on the sampler.
        sampler._old_inference_model = sampler._inference_model
        sampler._inference_model = _new_inference_model.__get__(sampler, type(sampler))
        try:
            yield
        finally:
            sampler._inference_model = sampler._old_inference_model
            delattr(sampler, '_old_inference_model')

    def run_multi_image(
        self,
        images: List[Image.Image],
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: dict = {},
        slat_sampler_params: dict = {},
        formats: List[str] = ['mesh', 'gaussian', 'radiance_field'],
        preprocess_image: bool = True,
        mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
    ):
        """
        Run the pipeline with multiple images as condition.

        Args:
            images (List[Image.Image]): The multi-view images of the assets.
            num_samples (int): The number of samples to generate.
            seed (int): Seed passed to ``torch.manual_seed`` after encoding.
            sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
            slat_sampler_params (dict): Additional parameters for the structured latent sampler.
            formats (List[str]): Output formats passed to ``decode_slat``.
            preprocess_image (bool): Whether to preprocess the images.
            mode: Multi-image conditioning mode, see ``inject_sampler_multi_image``.

        Returns:
            dict: Output of ``decode_slat``.
        """
        if preprocess_image:
            images = [self.preprocess_image(image) for image in images]
        cond = self.get_cond(images)
        # A single negative condition is shared across all images.
        cond['neg_cond'] = cond['neg_cond'][:1]
        torch.manual_seed(seed)
        flow_model = self.models['sparse_structure_flow_model']
        reso = flow_model.resolution
        noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
        ss_steps = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}.get('steps')
        with self.inject_sampler_multi_image('sparse_structure_sampler', len(images), ss_steps, mode=mode):
            coords = self.sample_sparse_structure(cond, num_samples, sparse_structure_sampler_params, noise)
        slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps')
        with self.inject_sampler_multi_image('slat_sampler', len(images), slat_steps, mode=mode):
            slat = self.sample_slat(cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)

    def run(
        self,
        image: Image.Image,
        ref_image: Image.Image = None,
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: dict = {},
        slat_sampler_params: dict = {},
        formats: List[str] = ['mesh'],
        preprocess_image: bool = True,
        init_mesh: trimesh.Trimesh = None,
        coords: torch.Tensor = None,
        normalize_init_mesh: bool = False,
        init_resolution: int = 62,
        init_stride: int = 4
    ) -> dict:
        """
        Run the pipeline.

        Args:
            image (Image.Image): The image prompt.
            ref_image (Image.Image): Optional second image. When given, the
                condition is the average of both images' encoded features.
                Only ``image`` goes through ``preprocess_image``.
            num_samples (int): The number of samples to generate.
            seed (int): Seed passed to ``torch.manual_seed`` before sampling.
            sparse_structure_sampler_params (dict): Additional parameters for the sparse structure sampler.
            slat_sampler_params (dict): Additional parameters for the structured latent sampler.
            formats (List[str]): Output formats passed to ``decode_slat``.
            preprocess_image (bool): Whether to preprocess ``image``.
            coords (torch.Tensor): Optional precomputed sparse-structure
                coordinates; when given, sparse-structure sampling is skipped.
            init_mesh, normalize_init_mesh, init_resolution, init_stride:
                Accepted for interface compatibility; not used here.

        Returns:
            dict: Output of ``decode_slat``.
        """
        if preprocess_image:
            image = self.preprocess_image(image)
        if ref_image is not None:
            cond = self.encode_image([image, ref_image])
            neg_cond = torch.zeros_like(cond[0:1])
            sparse_cond = slat_cond = {
                'cond': 0.5 * cond[0:1] + 0.5 * cond[1:2],
                'neg_cond': neg_cond,
            }
        else:
            sparse_cond = slat_cond = self.get_cond([image])
        torch.manual_seed(seed)
        if coords is None:
            coords = self.sample_sparse_structure(sparse_cond, num_samples, sparse_structure_sampler_params)
        slat = self.sample_slat(slat_cond, coords, slat_sampler_params)
        return self.decode_slat(slat, formats)

    def configure_optimizers(self):
        """Return an AdamW optimizer (lr 1e-4, no weight decay) over the SLAT flow model's parameters."""
        # NOTE(review): relies on a `slat_flow_model` attribute (see `forward`).
        params = list(self.slat_flow_model.parameters())
        opt = torch.optim.AdamW(params, lr=1e-4, weight_decay=0.0)
        return opt
def zero_module(module):
    """
    Overwrite every parameter of ``module`` with zeros, in place.

    Autograd does not record the write. Parameters keep their
    ``requires_grad`` flag.

    Returns:
        The very same ``module`` instance, so the call can wrap a constructor inline.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline):
    """
    Multi-view variant of ``TrellisImageTo3DPipeline``.

    The sparse-structure stage is conditioned on VGGT aggregator tokens
    together with DINOv2 image features. The SLAT stage uses the per-view image
    features through ``inject_sampler_multi_image``. ``from_pretrained`` also
    attaches:

    - ``VGGT_model``: a VGGT model with its prediction heads removed.
    - ``VGGT_dtype``: the autocast dtype for VGGT.
    - ``birefnet_model``: the BiRefNet model used for background masks.
    """

    def get_ss_cond(self, image_cond: torch.Tensor, aggregated_tokens_list: List, num_samples: int) -> dict:
        """
        Get the sparse-structure conditioning information for the model.

        Args:
            image_cond (torch.Tensor): Encoded image features (``run`` passes
                DINOv2 tokens with the first five tokens of each view removed).
            aggregated_tokens_list (List): Token list produced by ``vggt_feat``.
            num_samples (int): Currently unused.

        Returns:
            dict: ``'cond'`` and ``'neg_cond'`` (zeros of the same shape).
        """
        # NOTE(review): `sparse_structure_vggt_cond` is not assigned in this
        # file; presumably the loaded pipeline provides it — confirm.
        cond = self.sparse_structure_vggt_cond(aggregated_tokens_list, image_cond)
        neg_cond = torch.zeros_like(cond)
        return {
            'cond': cond,
            'neg_cond': neg_cond,
        }

    def vggt_feat(self, image: Union[torch.Tensor, list[Image.Image]]) -> Tuple[List, torch.Tensor]:
        """
        Run the VGGT aggregator over a set of views.

        Args:
            image (Union[torch.Tensor, list[Image.Image]]): A batched
                ``(B, C, H, W)`` tensor, which is resized to
                ``default_image_resolution``. Alternatively, a list of PIL images,
                which are resized, converted to RGB in [0, 1] and stacked. All
                views are fed as one sequence (a leading axis of size 1 is added).

        Returns:
            tuple: ``(aggregated_tokens_list, image)``, where ``image`` is the
            prepared tensor without the added leading axis.

        Raises:
            ValueError: If ``image`` is neither a tensor nor a list.
        """
        if isinstance(image, torch.Tensor):
            assert image.ndim == 4, "Image tensor should be batched (B, C, H, W)"
            image = F.interpolate(image, self.default_image_resolution, mode='bilinear', align_corners=False)
        elif isinstance(image, list):
            assert all(isinstance(i, Image.Image) for i in image), "Image list should be list of PIL images"
            image = [i.resize((self.default_image_resolution, self.default_image_resolution), Image.LANCZOS) for i in image]
            image = [np.array(i.convert('RGB')).astype(np.float32) / 255 for i in image]
            image = [torch.from_numpy(i).permute(2, 0, 1).float() for i in image]
            image = torch.stack(image).to(self.device)
        else:
            raise ValueError(f"Unsupported type of image: {type(image)}")

        # Feature extraction only: no gradients. `torch.autocast('cuda', ...)`
        # replaces the deprecated `torch.cuda.amp.autocast(...)`.
        with torch.no_grad():
            with torch.autocast(device_type='cuda', dtype=self.VGGT_dtype):
                aggregated_tokens_list, _ = self.VGGT_model.aggregator(image[None])
        return aggregated_tokens_list, image

    def run(
        self,
        image: Union[torch.Tensor, list[Image.Image]],
        coords: torch.Tensor = None,
        num_samples: int = 1,
        seed: int = 42,
        sparse_structure_sampler_params: dict = {},
        slat_sampler_params: dict = {},
        formats: List[str] = ['mesh'],
        preprocess_image: bool = True,
        mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
    ):
        """
        Reconstruct an asset from one or more views.

        Args:
            image: Views of the asset (see ``vggt_feat`` for accepted forms).
            coords: Currently ignored. Coordinates are always re-sampled by the
                sparse-structure stage.
            num_samples (int): Number of sparse-structure noise samples.
            seed (int): Seed passed to ``torch.manual_seed`` at the start.
            sparse_structure_sampler_params (dict): Overrides for the sparse structure sampler.
            slat_sampler_params (dict): Overrides for the structured latent sampler.
            formats (List[str]): Output formats passed to ``decode_slat``.
            preprocess_image (bool): Currently ignored; images are used as given.
            mode: Multi-image conditioning mode for the SLAT stage.

        Returns:
            dict: Output of ``decode_slat``.
        """
        torch.manual_seed(seed)
        aggregated_tokens_list, _ = self.vggt_feat(image)
        b, n, _, _ = aggregated_tokens_list[0].shape
        image_features = self.encode_image(image)
        # Use the encoder's actual feature width rather than a hard-coded 1024,
        # so other DINOv2 sizes also work. For ViT-L it is 1024, as before.
        feat_dim = image_features.shape[-1]
        image_cond = image_features.reshape(b, n, -1, feat_dim)

        # Sparse structure: skip the first five tokens of each view
        # (presumably the class + register tokens — confirm).
        ss_flow_model = self.models['sparse_structure_flow_model']
        ss_cond = self.get_ss_cond(image_cond[:, :, 5:], aggregated_tokens_list, num_samples)
        ss_sampler_params = {**self.sparse_structure_sampler_params, **sparse_structure_sampler_params}
        reso = ss_flow_model.resolution
        ss_noise = torch.randn(num_samples, ss_flow_model.in_channels, reso, reso, reso).to(self.device)
        ss_slat = self.sparse_structure_sampler.sample(
            ss_flow_model,
            ss_noise,
            **ss_cond,
            **ss_sampler_params,
            verbose=True
        ).samples
        decoder = self.models['sparse_structure_decoder']
        coords = torch.argwhere(decoder(ss_slat) > 0)[:, [0, 2, 3, 4]].int()

        # Structured latent: one condition per view, a single shared negative.
        per_view_cond = image_cond.reshape(n, -1, feat_dim)
        cond = {
            'cond': per_view_cond,
            'neg_cond': torch.zeros_like(per_view_cond)[:1],
        }
        slat_steps = {**self.slat_sampler_params, **slat_sampler_params}.get('steps')
        with self.inject_sampler_multi_image('slat_sampler', len(image), slat_steps, mode=mode):
            slat = self.sample_slat(cond, coords, slat_sampler_params)

        return self.decode_slat(slat, formats)

    @staticmethod
    def from_pretrained(path: str) -> "TrellisVGGTTo3DPipeline":
        """
        Load a pretrained model.

        Args:
            path (str): The path to the model. Can be either local path or a Hugging Face repository.

        Returns:
            TrellisVGGTTo3DPipeline: Pipeline with the parent's configuration,
            plus a VGGT aggregator (``"Stable-X/vggt-object-v0-1"``) and BiRefNet
            (``'ZhengPeng7/BiRefNet'``), both moved to the pipeline's device in
            eval mode.
        """
        # The parent loader already configures the samplers, sampler params,
        # SLAT normalization and DINOv2 model from `_pretrained_args`; they are
        # not rebuilt here, which avoids loading DINOv2 a second time.
        pipeline = super(TrellisVGGTTo3DPipeline, TrellisVGGTTo3DPipeline).from_pretrained(path)
        new_pipeline = TrellisVGGTTo3DPipeline()
        new_pipeline.__dict__ = pipeline.__dict__

        new_pipeline.VGGT_dtype = torch.float32
        VGGT_model = VGGT.from_pretrained("Stable-X/vggt-object-v0-1")
        new_pipeline.VGGT_model = VGGT_model.to(new_pipeline.device)
        # Only the aggregator is used (see `vggt_feat`); the prediction heads are
        # dropped, presumably to save memory.
        del new_pipeline.VGGT_model.depth_head
        del new_pipeline.VGGT_model.track_head
        del new_pipeline.VGGT_model.camera_head
        del new_pipeline.VGGT_model.point_head
        new_pipeline.VGGT_model.eval()

        new_pipeline.birefnet_model = AutoModelForImageSegmentation.from_pretrained(
            'ZhengPeng7/BiRefNet',
            trust_remote_code=True
        ).to(new_pipeline.device)
        new_pipeline.birefnet_model.eval()

        return new_pipeline