Stable-X committed on
Commit
4170e92
·
verified ·
1 Parent(s): a541280

Upload 74 files

Browse files
app_fine.py ADDED
@@ -0,0 +1,835 @@
1
+ import gradio as gr
2
+ import spaces
3
+ from gradio_litmodel3d import LitModel3D
4
+
5
+ import os
6
+ import shutil
7
+ os.environ['SPCONV_ALGO'] = 'native'
8
+ from typing import *
9
+ import torch
10
+ import numpy as np
11
+ import imageio
12
+ from easydict import EasyDict as edict
13
+ from PIL import Image
14
+ from trellis.pipelines import TrellisVGGTTo3DPipeline
15
+ from trellis.representations import Gaussian, MeshExtractResult
16
+ from trellis.utils import render_utils, postprocessing_utils
17
+
18
+ from wheels.vggt.vggt.utils.load_fn import load_and_preprocess_images
19
+ from wheels.vggt.vggt.utils.pose_enc import pose_encoding_to_extri_intri
20
+ import open3d as o3d
21
+ from torchvision import transforms as TF
22
+ import cv2  # used below for Rodrigues, solvePnPRansac and erode
23
+ import sys
24
+ sys.path.append("wheels")
25
+ from wheels.mast3r.model import AsymmetricMASt3R
26
+ from wheels.mast3r.fast_nn import fast_reciprocal_NNs
27
+ from wheels.dust3r.dust3r.inference import inference
28
+ from wheels.dust3r.dust3r.utils.image import load_images_new
29
+ from trellis.utils.general_utils import *
30
+ import copy
31
+
32
+ MAX_SEED = np.iinfo(np.int32).max
33
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
34
+ # TMP_DIR = "tmp/Trellis-demo"
35
+ # os.environ['GRADIO_TEMP_DIR'] = 'tmp'
36
+ os.makedirs(TMP_DIR, exist_ok=True)
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+
39
+ def start_session(req: gr.Request):
40
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
41
+ os.makedirs(user_dir, exist_ok=True)
42
+
43
+
44
+ def end_session(req: gr.Request):
45
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
46
+ shutil.rmtree(user_dir)
47
+
48
+ @spaces.GPU
49
+ def preprocess_image(image: Image.Image) -> Image.Image:
50
+ """
51
+ Preprocess the input image for 3D generation.
52
+
53
+ This function is called when a user uploads an image or selects an example.
54
+ It applies background removal and other preprocessing steps necessary for
55
+ optimal 3D model generation.
56
+
57
+ Args:
58
+ image (Image.Image): The input image from the user
59
+
60
+ Returns:
61
+ Image.Image: The preprocessed image ready for 3D generation
62
+ """
63
+ processed_image = pipeline.preprocess_image(image)
64
+ return processed_image
65
+
66
+ @spaces.GPU
67
+ def preprocess_videos(video: str) -> List[Image.Image]:
68
+ """
69
+ Preprocess the input video for multi-image 3D generation.
70
+
71
+ This function is called when a user uploads a video.
72
+ It extracts frames from the video and processes each frame to prepare them
73
+ for the multi-image 3D generation pipeline.
74
+
75
+ Args:
76
+ video (str): The path to the input video file
77
+
78
+ Returns:
79
+ List[Image.Image]: The list of preprocessed images ready for 3D generation
80
+ """
81
+ vid = imageio.get_reader(video, 'ffmpeg')
82
+ fps = vid.get_meta_data()['fps']
83
+ images = []
84
+ for i, frame in enumerate(vid):
85
+ if i % max(int(fps * 1), 1) == 0:
86
+ img = Image.fromarray(frame)
87
+ W, H = img.size
88
+ img = img.resize((int(W / H * 512), 512))
89
+ images.append(img)
90
+ vid.close()
91
+ processed_images = [pipeline.preprocess_image(image) for image in images]
92
+ return processed_images
93
+
94
+ @spaces.GPU
95
+ def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
96
+ """
97
+ Preprocess a list of input images for multi-image 3D generation.
98
+
99
+ This function is called when users upload multiple images in the gallery.
100
+ It processes each image to prepare them for the multi-image 3D generation pipeline.
101
+
102
+ Args:
103
+ images (List[Tuple[Image.Image, str]]): The input images from the gallery
104
+
105
+ Returns:
106
+ List[Image.Image]: The preprocessed images ready for 3D generation
107
+ """
108
+ images = [image[0] for image in images]
109
+ processed_images = [pipeline.preprocess_image(image) for image in images]
110
+ return processed_images
111
+
112
+
113
+ def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
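+ # Pack the generated Gaussian and mesh into plain numpy arrays so they can be stored in gr.State between callbacks.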
114
+ return {
115
+ 'gaussian': {
116
+ **gs.init_params,
117
+ '_xyz': gs._xyz.cpu().numpy(),
118
+ '_features_dc': gs._features_dc.cpu().numpy(),
119
+ '_scaling': gs._scaling.cpu().numpy(),
120
+ '_rotation': gs._rotation.cpu().numpy(),
121
+ '_opacity': gs._opacity.cpu().numpy(),
122
+ },
123
+ 'mesh': {
124
+ 'vertices': mesh.vertices.cpu().numpy(),
125
+ 'faces': mesh.faces.cpu().numpy(),
126
+ },
127
+ }
128
+
129
+
130
+ def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
131
+ gs = Gaussian(
132
+ aabb=state['gaussian']['aabb'],
133
+ sh_degree=state['gaussian']['sh_degree'],
134
+ mininum_kernel_size=state['gaussian']['mininum_kernel_size'],
135
+ scaling_bias=state['gaussian']['scaling_bias'],
136
+ opacity_bias=state['gaussian']['opacity_bias'],
137
+ scaling_activation=state['gaussian']['scaling_activation'],
138
+ )
139
+ gs._xyz = torch.tensor(state['gaussian']['_xyz'], device='cuda')
140
+ gs._features_dc = torch.tensor(state['gaussian']['_features_dc'], device='cuda')
141
+ gs._scaling = torch.tensor(state['gaussian']['_scaling'], device='cuda')
142
+ gs._rotation = torch.tensor(state['gaussian']['_rotation'], device='cuda')
143
+ gs._opacity = torch.tensor(state['gaussian']['_opacity'], device='cuda')
144
+
145
+ mesh = edict(
146
+ vertices=torch.tensor(state['mesh']['vertices'], device='cuda'),
147
+ faces=torch.tensor(state['mesh']['faces'], device='cuda'),
148
+ )
149
+
150
+ return gs, mesh
151
+
152
+
153
+ def get_seed(randomize_seed: bool, seed: int) -> int:
154
+ """
155
+ Get the random seed for generation.
156
+
157
+ This function is called by the generate button to determine whether to use
158
+ a random seed or the user-specified seed value.
159
+
160
+ Args:
161
+ randomize_seed (bool): Whether to generate a random seed
162
+ seed (int): The user-specified seed value
163
+
164
+ Returns:
165
+ int: The seed to use for generation
166
+ """
167
+ return np.random.randint(0, MAX_SEED) if randomize_seed else seed
168
+
169
+ def align_camera(num_frames, extrinsic, intrinsic, rend_extrinsics, rend_intrinsics):
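+ # Initialize the target camera: pick the rendered view whose (VGGT-estimated) rotation is closest to the target image's, reuse its extrinsic, and rescale the predicted focal length to match the rendered intrinsics.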
170
+
171
+ extrinsic_tmp = extrinsic.clone()
172
+ camera_relative = torch.matmul(extrinsic_tmp[:num_frames,:3,:3].permute(0,2,1), extrinsic_tmp[num_frames:,:3,:3])
173
+ camera_relative_angle = torch.acos(((camera_relative[:,0,0] + camera_relative[:,1,1] + camera_relative[:,2,2] - 1) / 2).clamp(-1, 1))
174
+ idx = torch.argmin(camera_relative_angle)
175
+ target_extrinsic = rend_extrinsics[idx:idx+1].clone()
176
+
177
+ focal_x = intrinsic[:num_frames,0,0].mean()
178
+ focal_y = intrinsic[:num_frames,1,1].mean()
179
+ focal = (focal_x + focal_y) / 2
180
+ rend_focal = (rend_intrinsics[0][0,0] + rend_intrinsics[0][1,1]) * 518 / 2
181
+ focal_scale = rend_focal / focal
182
+ target_intrinsic = intrinsic[num_frames:].clone()
183
+ fxy = (target_intrinsic[:,0,0] + target_intrinsic[:,1,1]) / 2 * focal_scale
184
+ target_intrinsic[:,0,0] = fxy
185
+ target_intrinsic[:,1,1] = fxy
186
+ return target_extrinsic, target_intrinsic
187
+
188
+ def refine_pose_mast3r(rend_image_pil, target_image_pil, original_size, fxy, target_extrinsic, rend_depth):
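+ # Refine the target camera pose: match MASt3R descriptors between the rendered and target images, lift the rendered matches to 3D via the rendered depth, and solve PnP-RANSAC for the target camera.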
189
+ images_mast3r = load_images_new([rend_image_pil, target_image_pil], size=512, square_ok=True)
190
+ with torch.no_grad():
191
+ output = inference([tuple(images_mast3r)], mast3r_model, device, batch_size=1, verbose=False)
192
+ view1, pred1 = output['view1'], output['pred1']
193
+ view2, pred2 = output['view2'], output['pred2']
194
+ del output
195
+ desc1, desc2 = pred1['desc'].squeeze(0).detach(), pred2['desc'].squeeze(0).detach()
196
+
197
+ # find 2D-2D matches between the two images
198
+ matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
199
+ device=device, dist='dot', block_size=2**13)
200
+
201
+ # ignore small border around the edge
202
+ H0, W0 = view1['true_shape'][0]
203
+
204
+ valid_matches_im0 = (matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) & (
205
+ matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3)
206
+
207
+ H1, W1 = view2['true_shape'][0]
208
+ valid_matches_im1 = (matches_im1[:, 0] >= 3) & (matches_im1[:, 0] < int(W1) - 3) & (
209
+ matches_im1[:, 1] >= 3) & (matches_im1[:, 1] < int(H1) - 3)
210
+
211
+ valid_matches = valid_matches_im0 & valid_matches_im1
212
+ matches_im0, matches_im1 = matches_im0[valid_matches], matches_im1[valid_matches]
213
+ scale_x = original_size[1] / W0.item()
214
+ scale_y = original_size[0] / H0.item()
215
+ for pixel in matches_im1:
216
+ pixel[0] *= scale_x
217
+ pixel[1] *= scale_y
218
+ for pixel in matches_im0:
219
+ pixel[0] *= scale_x
220
+ pixel[1] *= scale_y
221
+ depth_map = rend_depth[0]
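+ # Back-project every rendered pixel to world coordinates with the pinhole model and the current camera-to-world estimate.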
222
+ fx, fy, cx, cy = fxy.item(), fxy.item(), original_size[1]/2, original_size[0]/2 # Example values for focal lengths and principal point
223
+ K = np.array([
224
+ [fx, 0, cx],
225
+ [0, fy, cy],
226
+ [0, 0, 1]
227
+ ])
228
+ dist_eff = np.array([0,0,0,0], dtype=np.float32)
229
+ predict_c2w_ini = np.linalg.inv(target_extrinsic[0].cpu().numpy())
230
+ predict_w2c_ini = target_extrinsic[0].cpu().numpy()
231
+ initial_rvec, _ = cv2.Rodrigues(predict_c2w_ini[:3,:3].astype(np.float32))
232
+ initial_tvec = predict_c2w_ini[:3,3].astype(np.float32)
233
+ K_inv = np.linalg.inv(K)
234
+ height, width = depth_map.shape
235
+ x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height))
236
+ x_flat = x_coords.flatten()
237
+ y_flat = y_coords.flatten()
238
+ depth_flat = depth_map.flatten()
239
+ x_normalized = (x_flat - K[0, 2]) / K[0, 0]
240
+ y_normalized = (y_flat - K[1, 2]) / K[1, 1]
241
+ X_camera = depth_flat * x_normalized
242
+ Y_camera = depth_flat * y_normalized
243
+ Z_camera = depth_flat
244
+ points_camera = np.vstack((X_camera, Y_camera, Z_camera, np.ones_like(X_camera)))
245
+ points_world = predict_c2w_ini @ points_camera
246
+ X_world = points_world[0, :]
247
+ Y_world = points_world[1, :]
248
+ Z_world = points_world[2, :]
249
+ points_3D = np.vstack((X_world, Y_world, Z_world))
250
+ scene_coordinates_gs = points_3D.reshape(3, original_size[0], original_size[1])
251
+ points_3D_at_pixels = np.zeros((matches_im0.shape[0], 3))
252
+ for i, (x, y) in enumerate(matches_im0):
253
+ points_3D_at_pixels[i] = scene_coordinates_gs[:, y, x]
254
+
255
+ success, rvec, tvec, inliers = cv2.solvePnPRansac(points_3D_at_pixels.astype(np.float32), matches_im1.astype(np.float32), K, \
256
+ dist_eff,rvec=initial_rvec,tvec=initial_tvec, useExtrinsicGuess=True, reprojectionError=1.0,\
257
+ iterationsCount=2000,flags=cv2.SOLVEPNP_EPNP)
258
+ R = perform_rodrigues_transformation(rvec)
259
+ trans = -R.T @ np.matrix(tvec)
260
+ predict_c2w_refine = np.eye(4)
261
+ predict_c2w_refine[:3,:3] = R.T
262
+ predict_c2w_refine[:3,3] = trans.reshape(3)
263
+ target_extrinsic_final = torch.tensor(predict_c2w_refine).inverse().cuda()[None].float()
264
+ return target_extrinsic_final
265
+
266
+ def pointcloud_registration(rend_image_pil, target_image_pil, original_size, fxy, target_extrinsic, rend_depth, target_pointmap):
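+ # Register the VGGT point map to the rendered Gaussian geometry: lift MASt3R matches to 3D in both frames, normalize their scales, and run Open3D correspondence-based RANSAC to get a (scaled) rigid transform and its fitness.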
267
+ images_mast3r = load_images_new([rend_image_pil, target_image_pil], size=512, square_ok=True)
268
+ with torch.no_grad():
269
+ output = inference([tuple(images_mast3r)], mast3r_model, device, batch_size=1, verbose=False)
270
+ view1, pred1 = output['view1'], output['pred1']
271
+ view2, pred2 = output['view2'], output['pred2']
272
+ del output
273
+ desc1, desc2 = pred1['desc'].squeeze(0).detach(), pred2['desc'].squeeze(0).detach()
274
+
275
+ # find 2D-2D matches between the two images
276
+ matches_im0, matches_im1 = fast_reciprocal_NNs(desc1, desc2, subsample_or_initxy1=8,
277
+ device=device, dist='dot', block_size=2**13)
278
+
279
+ # ignore small border around the edge
280
+ H0, W0 = view1['true_shape'][0]
281
+
282
+ valid_matches_im0 = (matches_im0[:, 0] >= 3) & (matches_im0[:, 0] < int(W0) - 3) & (
283
+ matches_im0[:, 1] >= 3) & (matches_im0[:, 1] < int(H0) - 3)
284
+
285
+ H1, W1 = view2['true_shape'][0]
286
+ valid_matches_im1 = (matches_im1[:, 0] >= 3) & (matches_im1[:, 0] < int(W1) - 3) & (
287
+ matches_im1[:, 1] >= 3) & (matches_im1[:, 1] < int(H1) - 3)
288
+
289
+ valid_matches = valid_matches_im0 & valid_matches_im1
290
+ matches_im0, matches_im1 = matches_im0[valid_matches], matches_im1[valid_matches]
291
+ scale_x = original_size[1] / W0.item()
292
+ scale_y = original_size[0] / H0.item()
293
+ for pixel in matches_im1:
294
+ pixel[0] *= scale_x
295
+ pixel[1] *= scale_y
296
+ for pixel in matches_im0:
297
+ pixel[0] *= scale_x
298
+ pixel[1] *= scale_y
299
+ depth_map = rend_depth[0]
300
+ fx, fy, cx, cy = fxy.item(), fxy.item(), original_size[1]/2, original_size[0]/2 # Example values for focal lengths and principal point
301
+ K = np.array([
302
+ [fx, 0, cx],
303
+ [0, fy, cy],
304
+ [0, 0, 1]
305
+ ])
306
+ dist_eff = np.array([0,0,0,0], dtype=np.float32)
307
+ predict_c2w_ini = np.linalg.inv(target_extrinsic[0].cpu().numpy())
308
+ predict_w2c_ini = target_extrinsic[0].cpu().numpy()
309
+ initial_rvec, _ = cv2.Rodrigues(predict_c2w_ini[:3,:3].astype(np.float32))
310
+ initial_tvec = predict_c2w_ini[:3,3].astype(np.float32)
311
+ K_inv = np.linalg.inv(K)
312
+ height, width = depth_map.shape
313
+ x_coords, y_coords = np.meshgrid(np.arange(width), np.arange(height))
314
+ x_flat = x_coords.flatten()
315
+ y_flat = y_coords.flatten()
316
+ depth_flat = depth_map.flatten()
317
+ x_normalized = (x_flat - K[0, 2]) / K[0, 0]
318
+ y_normalized = (y_flat - K[1, 2]) / K[1, 1]
319
+ X_camera = depth_flat * x_normalized
320
+ Y_camera = depth_flat * y_normalized
321
+ Z_camera = depth_flat
322
+ points_camera = np.vstack((X_camera, Y_camera, Z_camera, np.ones_like(X_camera)))
323
+ points_world = predict_c2w_ini @ points_camera
324
+ X_world = points_world[0, :]
325
+ Y_world = points_world[1, :]
326
+ Z_world = points_world[2, :]
327
+ points_3D = np.vstack((X_world, Y_world, Z_world))
328
+ scene_coordinates_gs = points_3D.reshape(3, original_size[0], original_size[1])
329
+ points_3D_at_pixels = np.zeros((matches_im0.shape[0], 3))
330
+ for i, (x, y) in enumerate(matches_im0):
331
+ points_3D_at_pixels[i] = scene_coordinates_gs[:, y, x]
332
+
333
+ points_3D_at_pixels_2 = np.zeros((matches_im1.shape[0], 3))
334
+ for i, (x, y) in enumerate(matches_im1):
335
+ points_3D_at_pixels_2[i] = target_pointmap[:, y, x]
336
+
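+ # Robust scale for each point set: mean distance to the centroid, ignoring the top 1% of distances, used to bring both sets to a common scale before registration.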
337
+ dist_1 = np.linalg.norm(points_3D_at_pixels - points_3D_at_pixels.mean(axis=0), axis=1)
338
+ scale_1 = dist_1[dist_1 < np.percentile(dist_1, 99)].mean()
339
+ dist_2 = np.linalg.norm(points_3D_at_pixels_2 - points_3D_at_pixels_2.mean(axis=0), axis=1)
340
+ scale_2 = dist_2[dist_2 < np.percentile(dist_2, 99)].mean()
341
+ # scale_1 = np.linalg.norm(points_3D_at_pixels - points_3D_at_pixels.mean(axis=0), axis=1).mean()
342
+ # scale_2 = np.linalg.norm(points_3D_at_pixels_2 - points_3D_at_pixels_2.mean(axis=0), axis=1).mean()
343
+ points_3D_at_pixels_2 = points_3D_at_pixels_2 * (scale_1 / scale_2)
344
+ pcd_1 = o3d.geometry.PointCloud()
345
+ pcd_1.points = o3d.utility.Vector3dVector(points_3D_at_pixels)
346
+ pcd_2 = o3d.geometry.PointCloud()
347
+ pcd_2.points = o3d.utility.Vector3dVector(points_3D_at_pixels_2)
348
+ indices = np.arange(points_3D_at_pixels.shape[0])
349
+ correspondences = np.stack([indices, indices], axis=1)
350
+ correspondences = o3d.utility.Vector2iVector(correspondences)
351
+ result = o3d.pipelines.registration.registration_ransac_based_on_correspondence(
352
+ pcd_2,
353
+ pcd_1,
354
+ correspondences,
355
+ 0.03,
356
+ estimation_method=o3d.pipelines.registration.TransformationEstimationPointToPoint(False),
357
+ ransac_n=5,
358
+ criteria=o3d.pipelines.registration.RANSACConvergenceCriteria(10000, 10000),
359
+ )
360
+ transformation_matrix = result.transformation.copy()
361
+ transformation_matrix[:3,:3] = transformation_matrix[:3,:3] * (scale_1 / scale_2)
362
+ return transformation_matrix, result.fitness
363
+
364
+ @spaces.GPU(duration=120)
365
+ def generate_and_extract_glb(
366
+ multiimages: List[Tuple[Image.Image, str]],
367
+ seed: int,
368
+ ss_guidance_strength: float,
369
+ ss_sampling_steps: int,
370
+ slat_guidance_strength: float,
371
+ slat_sampling_steps: int,
372
+ multiimage_algo: Literal["multidiffusion", "stochastic"],
373
+ mesh_simplify: float,
374
+ texture_size: int,
375
+ refine: Literal["Yes", "No"],
376
+ ss_refine: Literal["noise", "deltav", "No"],
377
+ registration_num_frames: int,
378
+ trellis_stage1_lr: float,
379
+ trellis_stage1_start_t: float,
380
+ trellis_stage2_lr: float,
381
+ trellis_stage2_start_t: float,
382
+ req: gr.Request,
383
+ ) -> Tuple[dict, str, str, str]:
384
+ """
385
+ Convert multi-view images into a 3D model and extract a GLB file.
386
+
387
+ Args:
388
+ multiimages (List[Tuple[Image.Image, str]]): The input images (one per view).
+ seed (int): The random seed.
+ ss_guidance_strength (float): The guidance strength for sparse structure generation.
+ ss_sampling_steps (int): The number of sampling steps for sparse structure generation.
+ slat_guidance_strength (float): The guidance strength for structured latent generation.
+ slat_sampling_steps (int): The number of sampling steps for structured latent generation.
+ multiimage_algo (Literal["multidiffusion", "stochastic"]): The algorithm for multi-image generation.
+ mesh_simplify (float): The mesh simplification factor.
+ texture_size (int): The texture resolution.
+ refine (Literal["Yes", "No"]): Whether to run the VGGT/MASt3R-based refinement stage.
+ ss_refine (Literal["noise", "deltav", "No"]): How to refine the sparse structure, if at all.
+ registration_num_frames (int): Number of rendered frames used for camera registration.
+ trellis_stage1_lr (float): Learning rate for sparse structure refinement.
+ trellis_stage1_start_t (float): Start timestep for sparse structure refinement.
+ trellis_stage2_lr (float): Learning rate for appearance refinement.
+ trellis_stage2_start_t (float): Start timestep for appearance refinement.
399
+
400
+ Returns:
401
+ dict: The information of the generated 3D model.
402
+ str: The path to the video of the 3D model.
403
+ str: The path to the extracted GLB file.
404
+ str: The path to the extracted GLB file (for download).
405
+ """
406
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
407
+ image_files = [image[0] for image in multiimages]
408
+
409
+ # Generate 3D model
410
+ outputs, coords, ss_noise = pipeline.run(
411
+ image=image_files,
412
+ seed=seed,
413
+ formats=["gaussian", "mesh"],
414
+ preprocess_image=False,
415
+ sparse_structure_sampler_params={
416
+ "steps": ss_sampling_steps,
417
+ "cfg_strength": ss_guidance_strength,
418
+ },
419
+ slat_sampler_params={
420
+ "steps": slat_sampling_steps,
421
+ "cfg_strength": slat_guidance_strength,
422
+ },
423
+ mode=multiimage_algo,
424
+ )
425
+ if refine == "Yes":
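+ # Refinement stage: estimate per-view cameras and a point map with VGGT, clean and register the points to the generated Gaussians, then re-run the pipeline conditioned on the aligned cameras and points.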
426
+ try:
427
+ images, alphas = load_and_preprocess_images(multiimages)
428
+ images, alphas = images.to(device), alphas.to(device)
429
+ with torch.no_grad():
430
+ with torch.cuda.amp.autocast(dtype=pipeline.VGGT_dtype):
431
+ images = images[None]
432
+ aggregated_tokens_list, ps_idx = pipeline.VGGT_model.aggregator(images)
433
+ # Predict Cameras
434
+ pose_enc = pipeline.VGGT_model.camera_head(aggregated_tokens_list)[-1]
435
+ # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
436
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
437
+ # Predict Point Cloud
438
+ point_map, point_conf = pipeline.VGGT_model.point_head(aggregated_tokens_list, images, ps_idx)
439
+ del aggregated_tokens_list
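+ # Clean the point map: keep foreground (alpha > 0.8) pixels with confident predictions, drop outliers beyond the 98th-percentile radius, then recenter and rescale the points into roughly [-0.5, 0.5].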
440
+ mask = (alphas[:,0,...][...,None] > 0.8)
441
+ conf_threshold = np.percentile(point_conf.cpu().numpy(), 50)
442
+ confidence_mask = (point_conf[0] > conf_threshold) & (point_conf[0] > 1e-5)
443
+ mask = mask & confidence_mask[...,None]
444
+ point_map_by_unprojection = point_map[0]
445
+ point_map_clean = point_map_by_unprojection[mask[...,0]]
446
+ center_point = point_map_clean.mean(0)
447
+ scale = np.percentile((point_map_clean - center_point[None]).norm(dim=-1).cpu().numpy(), 98)
448
+ outlier_mask = (point_map_by_unprojection - center_point[None]).norm(dim=-1) <= scale
449
+ final_mask = mask & outlier_mask[...,None]
450
+ point_map_perframe = (point_map_by_unprojection - center_point[None, None, None]) / (2 * scale)
451
+ point_map_perframe[~final_mask[...,0]] = 127/255
452
+ point_map_perframe = point_map_perframe.permute(0,3,1,2)
453
+ images = images[0].permute(0,2,3,1)
454
+ images[~(alphas[:,0,...][...,None] > 0.8)[...,0]] = 0.
455
+ input_images = images.permute(0,3,1,2).clone()
456
+ vggt_extrinsic = extrinsic[0]
457
+ vggt_extrinsic = torch.cat([vggt_extrinsic, torch.tensor([[[0,0,0,1]]]).repeat(vggt_extrinsic.shape[0], 1, 1).to(vggt_extrinsic)], dim=1)
458
+ vggt_intrinsic = intrinsic[0]
459
+ vggt_intrinsic[:,:2] = vggt_intrinsic[:,:2] / 518
460
+ vggt_extrinsic[:,:3,3] = (torch.matmul(vggt_extrinsic[:,:3,:3], center_point[None,:,None].float())[...,0] + vggt_extrinsic[:,:3,3]) / (2 * scale)
461
+ pointcloud = point_map_perframe.permute(0,2,3,1)[final_mask[...,0]]
462
+ idxs = torch.randperm(pointcloud.shape[0])[:min(50000, pointcloud.shape[0])]
463
+ pcd = o3d.geometry.PointCloud()
464
+ pcd.points = o3d.utility.Vector3dVector(pointcloud[idxs].cpu().numpy())
465
+ cl, ind = pcd.remove_statistical_outlier(nb_neighbors=30, std_ratio=3.0)
466
+ inlier_cloud = pcd.select_by_index(ind)
467
+ outlier_cloud = pcd.select_by_index(ind, invert=True)
468
+ voxel_size = 1/64
469
+ down_pcd = inlier_cloud.voxel_down_sample(voxel_size)
470
+ torch.cuda.empty_cache()
471
+
472
+ video, rend_extrinsics, rend_intrinsics = render_utils.render_multiview(outputs['gaussian'][0], num_frames=registration_num_frames)
473
+ rend_extrinsics = torch.stack(rend_extrinsics, dim=0)
474
+ rend_intrinsics = torch.stack(rend_intrinsics, dim=0)
475
+ target_extrinsics = []
476
+ target_intrinsics = []
477
+ target_transforms = []
478
+ target_fitnesses = []
479
+ for k in range(len(image_files)):
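+ # For each input view: estimate its pose jointly with the rendered views via VGGT, iteratively refine it with MASt3R matches + PnP, keep the candidate with the highest silhouette IoU, and register the view's point map to the rendered geometry.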
480
+ images = torch.stack([TF.ToTensor()(render_image) for render_image in video['color']] + [TF.ToTensor()(image_files[k].convert("RGB"))], dim=0)
481
+ # if len(images) == 0:
482
+ with torch.no_grad():
483
+ with torch.cuda.amp.autocast(dtype=pipeline.VGGT_dtype):
484
+ # predictions = vggt_model(images.cuda())
485
+ aggregated_tokens_list, ps_idx = pipeline.VGGT_model.aggregator(images[None].cuda())
486
+ pose_enc = pipeline.VGGT_model.camera_head(aggregated_tokens_list)[-1]
487
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
488
+ extrinsic, intrinsic = extrinsic[0], intrinsic[0]
489
+ extrinsic = torch.cat([extrinsic, torch.tensor([0,0,0,1])[None,None].repeat(extrinsic.shape[0], 1, 1).to(extrinsic.device)], dim=1)
490
+ del aggregated_tokens_list, ps_idx
491
+
492
+ target_extrinsic, target_intrinsic = align_camera(registration_num_frames, extrinsic, intrinsic, rend_extrinsics, rend_intrinsics)
493
+ fxy = target_intrinsic[:,0,0]
494
+ target_intrinsic_tmp = target_intrinsic.clone()
495
+ target_intrinsic_tmp[:,:2] = target_intrinsic_tmp[:,:2] / 518
496
+
497
+ target_extrinsic_list = [target_extrinsic.clone()]
498
+ iou_list = []
499
+ iterations = 3
500
+ for i in range(iterations + 1):
501
+ j = 0
502
+ rend = render_utils.render_frames(outputs['gaussian'][0], target_extrinsic, target_intrinsic_tmp, {'resolution': 518, 'bg_color': (0, 0, 0)}, need_depth=True)
503
+ rend_image = rend['color'][j] # (518, 518, 3)
504
+ rend_depth = rend['depth'][j] # (3, 518, 518)
505
+
506
+ depth_single = rend_depth[0].astype(np.float32) # (H, W)
507
+ mask = (depth_single != 0).astype(np.uint8) #
508
+ kernel = np.ones((3, 3), np.uint8)
509
+ mask_eroded = cv2.erode(mask, kernel, iterations=3)
510
+ depth_eroded = depth_single * mask_eroded
511
+ rend_depth_eroded = np.stack([depth_eroded]*3, axis=0)
512
+
513
+ rend_image = torch.tensor(rend_image).permute(2,0,1) / 255
514
+ target_image = images[registration_num_frames:].to(target_extrinsic.device)[j]
515
+ original_size = (rend_image.shape[1], rend_image.shape[2])
516
+
517
+ import torchvision
518
+ torchvision.utils.save_image(rend_image, 'rend_image_{}.png'.format(k))
519
+ torchvision.utils.save_image(target_image, 'target_image_{}.png'.format(k))
520
+
521
+ mask_rend = (rend_image.detach().cpu() > 0).any(dim=0)
522
+ mask_target = (target_image.detach().cpu() > 0).any(dim=0)
523
+ intersection = (mask_rend & mask_target).sum().item()
524
+ union = (mask_rend | mask_target).sum().item()
525
+ iou = intersection / union if union > 0 else 0.0
526
+ iou_list.append(iou)
527
+
528
+ if i == iterations:
529
+ break
530
+
531
+ rend_image = rend_image * torch.from_numpy(mask_eroded[None]).to(rend_image.device)
532
+ rend_image_pil = Image.fromarray((rend_image.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8))
533
+ target_image_pil = Image.fromarray((target_image.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8))
534
+ target_extrinsic[j:j+1] = refine_pose_mast3r(rend_image_pil, target_image_pil, original_size, fxy[j:j+1], target_extrinsic[j:j+1], rend_depth_eroded)
535
+ target_extrinsic_list.append(target_extrinsic[j:j+1].clone())  # clone so stored candidates are not overwritten by later in-place updates
536
+
537
+ idx = iou_list.index(max(iou_list))
538
+ target_extrinsic[j:j+1] = target_extrinsic_list[idx]
539
+ target_transform, fitness = pointcloud_registration(rend_image_pil, target_image_pil, original_size, fxy[j:j+1], target_extrinsic[j:j+1], rend_depth_eroded, point_map_perframe[k].cpu().numpy())
540
+ target_transforms.append(target_transform)
541
+ target_fitnesses.append(fitness)
542
+
543
+ target_extrinsics.append(target_extrinsic[j:j+1])
544
+ target_intrinsics.append(target_intrinsic_tmp[j:j+1])
545
+ target_extrinsics = torch.cat(target_extrinsics, dim=0)
546
+ target_intrinsics = torch.cat(target_intrinsics, dim=0)
547
+
548
+ target_fitnesses_filtered = [x for x in target_fitnesses if x < 1]
549
+ idx = target_fitnesses.index(max(target_fitnesses_filtered))
550
+ target_transform = target_transforms[idx]
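+ # Align the downsampled VGGT point cloud to the generated sparse structure: apply the best transform, refine with scaled ICP against the voxel coordinates, and quantize the result to the 64^3 grid.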
551
+ down_pcd_align = copy.deepcopy(down_pcd).transform(target_transform)
552
+ pcd = o3d.geometry.PointCloud()
553
+ pcd.points = o3d.utility.Vector3dVector(coords[:,1:].cpu().numpy() / 64 - 0.5)
554
+ reg_p2p = o3d.pipelines.registration.registration_icp(
555
+ down_pcd_align, pcd, 0.01, np.eye(4),
556
+ o3d.pipelines.registration.TransformationEstimationPointToPoint(with_scaling=True),
557
+ o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration = 10000))
558
+ down_pcd_align_2 = copy.deepcopy(down_pcd_align).transform(reg_p2p.transformation)
559
+ input_points = torch.tensor(np.asarray(down_pcd_align_2.points)).to(extrinsic.device).float()
560
+ input_points = ((input_points + 0.5).clip(0, 1) * 63).to(torch.int32)
561
+
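+ # Re-run generation conditioned on the aligned cameras and registered points; optionally refine the sparse structure ("noise" or "deltav"), otherwise reuse the original coords.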
562
+ outputs = pipeline.run_refine(
563
+ image=image_files,
564
+ ss_learning_rate=trellis_stage1_lr,
565
+ ss_start_t=trellis_stage1_start_t,
566
+ apperance_learning_rate=trellis_stage2_lr,
567
+ apperance_start_t=trellis_stage2_start_t,
568
+ extrinsics=target_extrinsics,
569
+ intrinsics=target_intrinsics,
570
+ ss_noise=ss_noise,
571
+ input_points=input_points,
572
+ ss_refine_type = ss_refine,
573
+ coords=coords if ss_refine == "No" else None,
574
+ seed=seed,
575
+ formats=["mesh", "gaussian"],
576
+ sparse_structure_sampler_params={
577
+ "steps": ss_sampling_steps,
578
+ "cfg_strength": ss_guidance_strength,
579
+ },
580
+ slat_sampler_params={
581
+ "steps": slat_sampling_steps,
582
+ "cfg_strength": slat_guidance_strength,
583
+ },
584
+ mode=multiimage_algo,
585
+ )
586
+ except Exception as e:
587
+ print(f"Error during refinement: {e}")
588
+ # Render video
589
+ # import uuid
590
+ # output_id = str(uuid.uuid4())
591
+ # os.makedirs(f"{TMP_DIR}/{output_id}", exist_ok=True)
592
+ # video_path = f"{TMP_DIR}/{output_id}/preview.mp4"
593
+ # glb_path = f"{TMP_DIR}/{output_id}/mesh.glb"
594
+ video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
595
+ video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
596
+ video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
597
+ video_path = os.path.join(user_dir, 'sample.mp4')
598
+ imageio.mimsave(video_path, video, fps=15)
599
+
600
+ # Extract GLB
601
+ gs = outputs['gaussian'][0]
602
+ mesh = outputs['mesh'][0]
603
+ glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
604
+ glb_path = os.path.join(user_dir, 'sample.glb')
605
+ glb.export(glb_path)
606
+
607
+ # Pack state for optional Gaussian extraction
608
+ state = pack_state(gs, mesh)
609
+
610
+ torch.cuda.empty_cache()
611
+ return state, video_path, glb_path, glb_path
612
+
613
+
614
+ @spaces.GPU
615
+ def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
616
+ """
617
+ Extract a Gaussian splatting file from the generated 3D model.
618
+
619
+ This function is called when the user clicks "Extract Gaussian" button.
620
+ It converts the 3D model state into a .ply file format containing
621
+ Gaussian splatting data for advanced 3D applications.
622
+
623
+ Args:
624
+ state (dict): The state of the generated 3D model containing Gaussian data
625
+ req (gr.Request): Gradio request object for session management
626
+
627
+ Returns:
628
+ Tuple[str, str]: Paths to the extracted Gaussian file (for display and download)
629
+ """
630
+ user_dir = os.path.join(TMP_DIR, str(req.session_hash))
631
+ gs, _ = unpack_state(state)
632
+ gaussian_path = os.path.join(user_dir, 'sample.ply')
633
+ gs.save_ply(gaussian_path)
634
+ torch.cuda.empty_cache()
635
+ return gaussian_path, gaussian_path
636
+
637
+
638
+ def prepare_multi_example() -> List[Image.Image]:
639
+ multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
640
+ images = []
641
+ for case in multi_case:
642
+ _images = []
643
+ for i in range(1, 9):
644
+ if os.path.exists(f'assets/example_multi_image/{case}_{i}.png'):
645
+ img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
646
+ W, H = img.size
647
+ img = img.resize((int(W / H * 512), 512))
648
+ _images.append(np.array(img))
649
+ if len(_images) > 0:
650
+ images.append(Image.fromarray(np.concatenate(_images, axis=1)))
651
+ return images
652
+
653
+
654
+ def split_image(image: Image.Image) -> List[Image.Image]:
655
+ """
656
+ Split a multi-view image into separate view images.
657
+
658
+ This function is called when users select multi-image examples that contain
659
+ multiple views in a single concatenated image. It automatically splits them
660
+ based on alpha channel boundaries and preprocesses each view.
661
+
662
+ Args:
663
+ image (Image.Image): A concatenated image containing multiple views
664
+
665
+ Returns:
666
+ List[Image.Image]: List of individual preprocessed view images
667
+ """
668
+ image = np.array(image)
669
+ alpha = image[..., 3]
670
+ alpha = np.any(alpha>0, axis=0)
671
+ start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
672
+ end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
673
+ images = []
674
+ for s, e in zip(start_pos, end_pos):
675
+ images.append(Image.fromarray(image[:, s:e+1]))
676
+ return [preprocess_image(image) for image in images]
677
+
678
+ # Create interface
679
+ demo = gr.Blocks(
680
+ title="ReconViaGen",
681
+ css="""
682
+ .slider .inner { width: 5px; background: #FFF; }
683
+ .viewport { aspect-ratio: 4/3; }
684
+ .tabs button.selected { font-size: 20px !important; color: crimson !important; }
685
+ h1, h2, h3 { text-align: center; display: block; }
686
+ .md_feedback li { margin-bottom: 0px !important; }
687
+ """
688
+ )
689
+ with demo:
690
+ gr.Markdown("""
691
+ # 💻 ReconViaGen
692
+ <p align="center">
693
+ <a title="Github" href="https://github.com/GAP-LAB-CUHK-SZ/ReconViaGen" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
694
+ <img src="https://img.shields.io/github/stars/GAP-LAB-CUHK-SZ/ReconViaGen?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
695
+ </a>
696
+ <a title="Website" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
697
+ <img src="https://www.obukhov.ai/img/badges/badge-website.svg">
698
+ </a>
699
+ <a title="arXiv" href="https://jiahao620.github.io/reconviagen/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
700
+ <img src="https://www.obukhov.ai/img/badges/badge-pdf.svg">
701
+ </a>
702
+ </p>
703
+
704
+ ✨This is a partial demo; the full model will be released later. Stay tuned!✨
705
+ """)
706
+
707
+ with gr.Row():
708
+ with gr.Column():
709
+ with gr.Tabs() as input_tabs:
710
+ with gr.Tab(label="Input Video or Images", id=0) as multiimage_input_tab:
711
+ input_video = gr.Video(label="Upload Video", interactive=True, height=300)
712
+ image_prompt = gr.Image(label="Image Prompt", format="png", visible=False, image_mode="RGBA", type="pil", height=300)
713
+ multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=300, columns=3)
714
+ gr.Markdown("""
715
+ Input different views of the object in separate images.
716
+ """)
717
+
718
+ with gr.Accordion(label="Generation Settings", open=False):
719
+ seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
720
+ randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
721
+ gr.Markdown("Stage 1: Sparse Structure Generation")
722
+ with gr.Row():
723
+ ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
724
+ ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=30, step=1)
725
+ gr.Markdown("Stage 2: Structured Latent Generation")
726
+ with gr.Row():
727
+ slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
728
+ slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
729
+ multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="multidiffusion")
730
+ refine = gr.Radio(["Yes", "No"], label="Refinement or not", value="Yes")
731
+ ss_refine = gr.Radio(["noise", "deltav", "No"], label="Sparse structure refinement or not", value="No")
732
+ registration_num_frames = gr.Slider(20, 50, label="Number of frames in registration", value=30, step=1)
733
+ trellis_stage1_lr = gr.Slider(1e-4, 1., label="trellis_stage1_lr", value=1e-1, step=5e-4)
734
+ trellis_stage1_start_t = gr.Slider(0., 1., label="trellis_stage1_start_t", value=0.5, step=0.01)
735
+ trellis_stage2_lr = gr.Slider(1e-4, 1., label="trellis_stage2_lr", value=1e-1, step=5e-4)
736
+ trellis_stage2_start_t = gr.Slider(0., 1., label="trellis_stage2_start_t", value=0.5, step=0.01)
737
+
738
+ with gr.Accordion(label="GLB Extraction Settings", open=False):
739
+ mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
740
+ texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
741
+
742
+ generate_btn = gr.Button("Generate & Extract GLB", variant="primary")
743
+ extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
744
+ gr.Markdown("""
745
+ *NOTE: The Gaussian file can be very large (~50 MB); it may take a while to display and download.*
746
+ """)
747
+
748
+ with gr.Column():
749
+ video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
750
+ model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=10.0, height=300)
751
+
752
+ with gr.Row():
753
+ download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
754
+ download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
755
+
756
+ output_buf = gr.State()
757
+
758
+ # Example images at the bottom of the page
759
+ with gr.Row() as multiimage_example:
760
+ examples_multi = gr.Examples(
761
+ examples=prepare_multi_example(),
762
+ inputs=[image_prompt],
763
+ fn=split_image,
764
+ outputs=[multiimage_prompt],
765
+ run_on_click=True,
766
+ examples_per_page=8,
767
+ )
768
+
769
+ # Handlers
770
+ demo.load(start_session)
771
+ demo.unload(end_session)
772
+
773
+ input_video.upload(
774
+ preprocess_videos,
775
+ inputs=[input_video],
776
+ outputs=[multiimage_prompt],
777
+ )
778
+ input_video.clear(
779
+ lambda: tuple([None, None]),
780
+ outputs=[input_video, multiimage_prompt],
781
+ )
782
+ multiimage_prompt.upload(
783
+ preprocess_images,
784
+ inputs=[multiimage_prompt],
785
+ outputs=[multiimage_prompt],
786
+ )
787
+
788
+ generate_btn.click(
789
+ get_seed,
790
+ inputs=[randomize_seed, seed],
791
+ outputs=[seed],
792
+ ).then(
793
+ generate_and_extract_glb,
794
+ inputs=[multiimage_prompt, seed, ss_guidance_strength, ss_sampling_steps,
795
+ slat_guidance_strength, slat_sampling_steps, multiimage_algo,
796
+ mesh_simplify, texture_size, refine, ss_refine, registration_num_frames,
797
+ trellis_stage1_lr, trellis_stage1_start_t, trellis_stage2_lr,
798
+ trellis_stage2_start_t],
799
+ outputs=[output_buf, video_output, model_output, download_glb],
800
+ ).then(
801
+ lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
802
+ outputs=[extract_gs_btn, download_glb],
803
+ )
804
+
805
+ video_output.clear(
806
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False), gr.Button(interactive=False)]),
807
+ outputs=[extract_gs_btn, download_glb, download_gs],
808
+ )
809
+
810
+ extract_gs_btn.click(
811
+ extract_gaussian,
812
+ inputs=[output_buf],
813
+ outputs=[model_output, download_gs],
814
+ ).then(
815
+ lambda: gr.Button(interactive=True),
816
+ outputs=[download_gs],
817
+ )
818
+
819
+ model_output.clear(
820
+ lambda: tuple([gr.Button(interactive=False), gr.Button(interactive=False)]),
821
+ outputs=[download_glb, download_gs],
822
+ )
823
+
824
+
825
+ # Launch the Gradio app
826
+ if __name__ == "__main__":
827
+ pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
828
+ # pipeline = TrellisVGGTTo3DPipeline.from_pretrained("weights/trellis-vggt-v0-1")
829
+ pipeline.cuda()
830
+ pipeline.VGGT_model.cuda()
831
+ pipeline.birefnet_model.cuda()
832
+ pipeline.dreamsim_model.cuda()
833
+ mast3r_model = AsymmetricMASt3R.from_pretrained("naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric").cuda().eval()
834
+ # mast3r_model = AsymmetricMASt3R.from_pretrained("weights/MAST3R/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth").cuda().eval()
835
+ demo.launch()
trellis/pipelines/trellis_image_to_3d.py CHANGED
@@ -19,7 +19,8 @@ from typing import *
19
  from scipy.spatial.transform import Rotation
20
  from transformers import AutoModelForImageSegmentation
21
  import rembg
22
- # from dreamsim import dreamsim
 
23
 
24
  def export_point_cloud(xyz, color):
25
  # Convert tensors to numpy arrays if needed
@@ -475,6 +476,76 @@ class TrellisImageTo3DPipeline(Pipeline):
475
 
476
  return coords
477
 
478
  def encode_slat(
479
  self,
480
  slat: sp.SparseTensor,
@@ -907,6 +978,7 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline):
907
  intrinsics: torch.Tensor,
908
  ss_noise: torch.Tensor,
909
  input_points: torch.Tensor,
 
910
  coords: torch.Tensor = None,
911
  num_samples: int = 1,
912
  seed: int = 42,
@@ -928,7 +1000,10 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline):
928
  ss = ss[None, None]
929
  torch.cuda.empty_cache()
930
  # Sample structured latent
931
- coords = self.sample_sparse_structure_opt(ss_cond, ss, ss_learning_rate, ss_start_t, num_samples, sparse_structure_sampler_params)
932
  torch.cuda.empty_cache()
933
 
934
  # pcd = o3d.geometry.PointCloud()
@@ -987,8 +1062,8 @@ class TrellisVGGTTo3DPipeline(TrellisImageTo3DPipeline):
987
 
988
  new_pipeline._init_image_cond_model(args['image_cond_model'])
989
 
990
- # model, _ = dreamsim(pretrained=True, device=new_pipeline.device, dreamsim_type="dino_vitb16", cache_dir="weights/dreamsim")
991
- # new_pipeline.dreamsim_model = model
992
- # new_pipeline.dreamsim_model.eval()
993
 
994
- return new_pipeline
 
19
  from scipy.spatial.transform import Rotation
20
  from transformers import AutoModelForImageSegmentation
21
  import rembg
22
+ from dreamsim import dreamsim
23
+ from tqdm import tqdm
24
 
25
  def export_point_cloud(xyz, color):
26
  # Convert tensors to numpy arrays if needed
 
476
 
477
  return coords
478
 
479
+ def sample_sparse_structure_opt_noise(
480
+ self,
481
+ cond: dict,
482
+ ss: torch.Tensor,
483
+ ss_learning_rate: float=1e-3,
484
+ num_samples: int = 1,
485
+ sampler_params: dict = {},
486
+ noise: torch.Tensor = None,
487
+ ) -> torch.Tensor:
488
+ """
489
+ Sample sparse structures with the given conditioning.
490
+
491
+ Args:
492
+ cond (dict): The conditioning information.
493
+ num_samples (int): The number of samples to generate.
494
+ sampler_params (dict): Additional parameters for the sampler.
495
+ """
496
+ # Sample occupancy latent
497
+ flow_model = self.models['sparse_structure_flow_model']
498
+ ss = ss.float()
499
+ reso = flow_model.resolution
500
+ if noise is None:
501
+ noise = torch.randn(num_samples, flow_model.in_channels, reso, reso, reso).to(self.device)
502
+ torch.cuda.empty_cache()
503
+ noise = torch.nn.Parameter(noise.to(self.device))
504
+ optimizer = torch.optim.Adam([noise], betas=(0.5, 0.9), lr=ss_learning_rate)
505
+ total_steps = 5
506
+ def cosine_annealing(step, total_steps, start_lr, end_lr):
507
+ return end_lr + 0.5 * (start_lr - end_lr) * (1 + np.cos(np.pi * step / total_steps))
508
+ sampler_params = {**self.sparse_structure_sampler_params, **sampler_params}
509
+ fix_cond = cond['cond'].clone()
510
+ with tqdm(total=total_steps, disable=False, desc='Geometry (opt): optimizing') as pbar:
511
+ for step in range(total_steps):
512
+ optimizer.zero_grad()
513
+ shuffle_idx = torch.randperm(fix_cond.shape[0])
514
+ cond['cond'] = fix_cond[shuffle_idx]
515
+ norm_noise = (noise - noise.mean()) / noise.std()
516
+ ss_slat = self.sparse_structure_sampler.sample_opt(
517
+ flow_model,
518
+ norm_noise,
519
+ **cond,
520
+ **{**self.sparse_structure_sampler_params, **{"steps": 1, "cfg_strength": sampler_params["cfg_strength"]}},
521
+ verbose=False
522
+ ).samples
523
+ ss_decoder = self.models['sparse_structure_decoder']
524
+ logits = F.sigmoid(ss_decoder(ss_slat))
525
+ loss = 1 - (2 * (logits * ss.float()).sum() + 1) / (logits.sum() + ss.float().sum() + 1)
526
+ # loss.backward()
527
+ # optimizer.step()
528
+ # Differentiate w.r.t. noise only, without retaining the whole computation graph (uses less memory than retain_graph=True)
529
+ grads = torch.autograd.grad(loss, noise, retain_graph=False, allow_unused=False)[0]
530
+ # Write the gradient back to noise.grad so the optimizer can use it
531
+ noise.grad = grads
532
+ optimizer.step()
533
+ optimizer.param_groups[0]['lr'] = cosine_annealing(step, total_steps, ss_learning_rate, 1e-5)
534
+ pbar.set_postfix({'loss': loss.item()})
535
+ pbar.update()
536
+
537
+ noise = noise.detach()
538
+ torch.cuda.empty_cache()
539
+ z_s = self.sparse_structure_sampler.sample(
540
+ flow_model,
541
+ noise,
542
+ **cond,
543
+ **sampler_params,
544
+ verbose=True
545
+ ).samples
546
+ coords = torch.argwhere(ss_decoder(z_s)>0)[:, [0, 2, 3, 4]].int()
547
+ return coords
548
+
549
  def encode_slat(
550
  self,
551
  slat: sp.SparseTensor,
 
978
  intrinsics: torch.Tensor,
979
  ss_noise: torch.Tensor,
980
  input_points: torch.Tensor,
981
+ ss_refine_type: str = 'No',
982
  coords: torch.Tensor = None,
983
  num_samples: int = 1,
984
  seed: int = 42,
 
1000
  ss = ss[None, None]
1001
  torch.cuda.empty_cache()
1002
  # Sample structured latent
1003
+ if ss_refine_type == 'noise':
1004
+ coords = self.sample_sparse_structure_opt_noise(ss_cond, ss, ss_learning_rate, num_samples, sparse_structure_sampler_params, ss_noise)
1005
+ elif ss_refine_type == 'deltav':
1006
+ coords = self.sample_sparse_structure_opt(ss_cond, ss, ss_learning_rate, ss_start_t, num_samples, sparse_structure_sampler_params, ss_noise)
1007
  torch.cuda.empty_cache()
1008
 
1009
  # pcd = o3d.geometry.PointCloud()
 
1062
 
1063
  new_pipeline._init_image_cond_model(args['image_cond_model'])
1064
 
1065
+ model, _ = dreamsim(pretrained=True, device=new_pipeline.device, dreamsim_type="dino_vitb16", cache_dir="weights/dreamsim")
1066
+ new_pipeline.dreamsim_model = model
1067
+ new_pipeline.dreamsim_model.eval()
1068
 
1069
+ return new_pipeline