# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
MapAnything V2 - 3D reconstruction system (Chinese-language UI)

- Multi-view 3D reconstruction
- Depth estimation and normal computation
- Distance measurement
"""

import gc
import os
import shutil
import sys
import time
from datetime import datetime

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from pillow_heif import register_heif_opener

register_heif_opener()

sys.path.append("mapanything/")

from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
from mapanything.utils.hf_utils.css_and_html import (
    GRADIO_CSS,
    MEASURE_INSTRUCTIONS_HTML,
    get_acknowledgements_html,
    get_description_html,
    get_gradio_theme,
    get_header_html,
)
from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
from mapanything.utils.hf_utils.visual_util import predictions_to_glb
from mapanything.utils.image import load_images, rgb


def get_logo_base64():
    """Convert the WAI logo to base64 for embedding in HTML."""
    import base64

    logo_path = "examples/WAI-Logo/wai_logo.png"
    try:
        with open(logo_path, "rb") as img_file:
            img_data = img_file.read()
        base64_str = base64.b64encode(img_data).decode()
        return f"data:image/png;base64,{base64_str}"
    except FileNotFoundError:
        return None


# MapAnything configuration
high_level_config = {
    "path": "configs/train.yaml",
    "hf_model_name": "facebook/map-anything",
    "model_str": "mapanything",
    "config_overrides": [
        "machine=aws",
        "model=mapanything",
        "model/task=images_only",
        "model.encoder.uses_torch_hub=false",
    ],
    "checkpoint_name": "model.safetensors",
    "config_name": "config.json",
    "trained_with_amp": True,
    "trained_with_amp_dtype": "bf16",
    "data_norm_type": "dinov2",
    "patch_size": 14,
    "resolution": 518,
}

# Initialize model - this will be done on GPU when needed
model = None
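# Note on the resolution settings above (our reading of the config; the
# actual preprocessing lives in load_images and the model code): a
# DINOv2-style encoder consumes images in multiples of patch_size (14),
# and 518 = 37 * 14, so each view is handled as a 37x37 grid of patches.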
# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def run_model(
    target_dir,
    apply_mask=True,
    mask_edges=True,
    filter_black_bg=False,
    filter_white_bg=False,
    progress=gr.Progress(),
):
    """
    Run the MapAnything model on images in the 'target_dir/images' folder and
    return predictions.
    """
    global model

    start_time = time.time()
    print(f"Processing images from {target_dir}")

    # Device check
    progress(0, desc="🔧 初始化设备...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Initialize the model lazily on first use
    progress(0.05, desc="📥 加载模型... (~5秒)")
    if model is None:
        model = initialize_mapanything_model(high_level_config, device)
    else:
        model = model.to(device)
    model.eval()

    # Load images using MapAnything's load_images function
    progress(0.15, desc="📷 加载图片... (~2秒)")
    print("Loading images...")
    image_folder_path = os.path.join(target_dir, "images")
    views = load_images(image_folder_path)
    print(f"Loaded {len(views)} images")

    if len(views) == 0:
        raise ValueError("No images found. Check your upload.")

    # Run model inference
    num_images = len(views)
    estimated_time = num_images * 3  # rough estimate: ~3 seconds per image
    progress(
        0.2,
        desc=f"🚀 运行3D重建... ({num_images}张图片,预计{estimated_time}秒)",
    )
    print("Running inference...")
    inference_start = time.time()
    outputs = model.infer(
        views,
        apply_mask=apply_mask,
        mask_edges=mask_edges,
        memory_efficient_inference=False,
    )
    inference_time = time.time() - inference_start

    # Convert predictions to the format expected by visualization
    progress(0.6, desc=f"🔄 处理预测结果... (推理耗时: {inference_time:.1f}秒)")
    predictions = {}

    # Initialize lists for the required keys
    extrinsic_list = []
    intrinsic_list = []
    world_points_list = []
    depth_maps_list = []
    images_list = []
    final_mask_list = []

    # Loop through the outputs
    for i, pred in enumerate(outputs):
        if i % max(1, len(outputs) // 5) == 0:
            progress(
                0.6 + (i / len(outputs)) * 0.25,
                desc=f"🔄 处理视图 {i + 1}/{len(outputs)}...",
            )

        # Extract data from predictions
        depthmap_torch = pred["depth_z"][0].squeeze(-1)  # (H, W)
        intrinsics_torch = pred["intrinsics"][0]  # (3, 3)
        camera_pose_torch = pred["camera_poses"][0]  # (4, 4)

        # Compute new pts3d using depth, intrinsics, and camera pose
        pts3d_computed, valid_mask = depthmap_to_world_frame(
            depthmap_torch, intrinsics_torch, camera_pose_torch
        )

        # Convert to numpy arrays for visualization. If the prediction has no
        # mask key, fall back to an all-True mask the size of the depth map.
        if "mask" in pred:
            mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
        else:
            mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)

        # Combine with the valid-depth mask
        mask = mask & valid_mask.cpu().numpy()

        image = pred["img_no_norm"][0].cpu().numpy()

        # Append to lists
        extrinsic_list.append(camera_pose_torch.cpu().numpy())
        intrinsic_list.append(intrinsics_torch.cpu().numpy())
        world_points_list.append(pts3d_computed.cpu().numpy())
        depth_maps_list.append(depthmap_torch.cpu().numpy())
        images_list.append(image)
        final_mask_list.append(mask)

    # Convert lists to numpy arrays with the required shapes:
    # extrinsic: (S, 4, 4) - batch of camera poses (cam-to-world)
    predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
    # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices
    predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
    # world_points: (S, H, W, 3) - batch of 3D world points
    predictions["world_points"] = np.stack(world_points_list, axis=0)

    # depth: (S, H, W, 1) or (S, H, W) - batch of depth maps
    depth_maps = np.stack(depth_maps_list, axis=0)
    # Add a channel dimension if needed to match the (S, H, W, 1) format
    if len(depth_maps.shape) == 3:
        depth_maps = depth_maps[..., np.newaxis]
    predictions["depth"] = depth_maps

    # images: (S, H, W, 3) - batch of input images
    predictions["images"] = np.stack(images_list, axis=0)
    # final_mask: (S, H, W) - batch of final masks for filtering
    predictions["final_mask"] = np.stack(final_mask_list, axis=0)

    # Process data for the visualization tabs (depth, normal, measure)
    progress(0.85, desc="🎨 生成深度图与法线图...")
    processed_data = process_predictions_for_visualization(
        predictions, views, high_level_config, filter_black_bg, filter_white_bg
    )

    # Clean up
    progress(0.95, desc="🧹 清理内存...")
    torch.cuda.empty_cache()

    total_time = time.time() - start_time
    progress(1.0, desc=f"✅ 完成!总耗时: {total_time:.1f}秒")
    print(f"Total processing time: {total_time:.2f} seconds")

    return predictions, processed_data
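
# Illustrative sketch (not used by the app): the back-projection that
# depthmap_to_world_frame performs, as a minimal numpy version. It assumes a
# standard pinhole model and a 4x4 camera-to-world pose; the real utility's
# conventions may differ.
def _example_depth_to_world(depth, intrinsics, camera_pose):
    """depth: (H, W), intrinsics: (3, 3), camera_pose: (4, 4) cam-to-world."""
    H, W = depth.shape
    u, v = np.meshgrid(np.arange(W), np.arange(H))
    # Unproject pixels to camera-frame rays: K^-1 @ [u, v, 1]^T, scaled by depth
    pix = np.stack([u, v, np.ones_like(u)], axis=-1).reshape(-1, 3).T  # (3, H*W)
    rays = np.linalg.inv(intrinsics) @ pix
    pts_cam = rays * depth.reshape(1, -1)  # scale each ray by its depth
    # Lift to homogeneous coordinates and transform into the world frame
    pts_h = np.vstack([pts_cam, np.ones((1, pts_cam.shape[1]))])
    pts_world = (camera_pose @ pts_h)[:3].T.reshape(H, W, 3)
    valid = depth > 0  # non-positive depth marks invalid pixels
    return pts_world, valid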
def update_view_selectors(processed_data):
    """Update the view selector dropdowns based on the available views."""
    if processed_data is None or len(processed_data) == 0:
        choices = ["View 1"]
    else:
        num_views = len(processed_data)
        choices = [f"View {i + 1}" for i in range(num_views)]

    return (
        gr.Dropdown(choices=choices, value=choices[0]),  # depth_view_selector
        gr.Dropdown(choices=choices, value=choices[0]),  # normal_view_selector
        gr.Dropdown(choices=choices, value=choices[0]),  # measure_view_selector
    )


def get_view_data_by_index(processed_data, view_index):
    """Get view data by index, clamping out-of-range indices to 0."""
    if processed_data is None or len(processed_data) == 0:
        return None

    view_keys = list(processed_data.keys())
    if view_index < 0 or view_index >= len(view_keys):
        view_index = 0

    return processed_data[view_keys[view_index]]


def update_depth_view(processed_data, view_index):
    """Update the depth view for a specific view index."""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None or view_data["depth"] is None:
        return None

    return colorize_depth(view_data["depth"], mask=view_data.get("mask"))


def update_normal_view(processed_data, view_index):
    """Update the normal view for a specific view index."""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None or view_data["normal"] is None:
        return None

    return colorize_normal(view_data["normal"], mask=view_data.get("mask"))


def update_measure_view(processed_data, view_index):
    """Update the measure view for a specific view index, with mask overlay."""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None:
        return None, []  # image, measure_points

    # Get the base image
    image = view_data["image"].copy()

    # Ensure the image is in uint8 format
    if image.dtype != np.uint8:
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)

    # Apply an overlay if a mask is available: masked-out areas (False values)
    # are tinted light red so they read as non-interactive
    if view_data["mask"] is not None:
        mask = view_data["mask"]
        invalid_mask = ~mask  # areas where the mask is False
        if invalid_mask.any():
            # Light red overlay (RGB: 255, 220, 220), blended at 50% opacity
            overlay_color = np.array([255, 220, 220], dtype=np.uint8)
            alpha = 0.5
            for c in range(3):  # RGB channels
                image[:, :, c] = np.where(
                    invalid_mask,
                    (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
                    image[:, :, c],
                ).astype(np.uint8)

    return image, []


def navigate_depth_view(processed_data, current_selector_value, direction):
    """Navigate the depth view (direction: -1 for previous, +1 for next)."""
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None

    # Parse the current view number
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except Exception:
        current_view = 0

    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views
    new_selector_value = f"View {new_view + 1}"

    depth_vis = update_depth_view(processed_data, new_view)
    return new_selector_value, depth_vis


def navigate_normal_view(processed_data, current_selector_value, direction):
    """Navigate the normal view (direction: -1 for previous, +1 for next)."""
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None

    # Parse the current view number
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except Exception:
        current_view = 0

    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views
    new_selector_value = f"View {new_view + 1}"

    normal_vis = update_normal_view(processed_data, new_view)
    return new_selector_value, normal_vis
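
# Illustrative sketch (not wired into the app): the "View N" label parsing
# above is repeated in every navigate_* function; a shared helper like this
# hypothetical one would centralize it. Kept here only as an example.
def _example_parse_view_index(selector_value, num_views):
    """Map a 'View N' label to a 0-based index, wrapping into range."""
    try:
        index = int(selector_value.split()[1]) - 1
    except Exception:
        index = 0
    # Modulo keeps prev/next navigation cyclic, e.g. "View 1" - 1 -> last view
    return index % max(1, num_views)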
def navigate_measure_view(processed_data, current_selector_value, direction):
    """Navigate the measure view (direction: -1 for previous, +1 for next)."""
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None, []

    # Parse the current view number
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except Exception:
        current_view = 0

    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views
    new_selector_value = f"View {new_view + 1}"

    measure_image, measure_points = update_measure_view(processed_data, new_view)
    return new_selector_value, measure_image, measure_points


def populate_visualization_tabs(processed_data):
    """Populate the depth, normal, and measure tabs with processed data."""
    if processed_data is None or len(processed_data) == 0:
        return None, None, None, []

    # Use the update functions so mask filtering is applied from the start
    depth_vis = update_depth_view(processed_data, 0)
    normal_vis = update_normal_view(processed_data, 0)
    measure_img, _ = update_measure_view(processed_data, 0)

    return depth_vis, normal_vis, measure_img, []


# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(unified_upload, s_time_interval=1.0):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or frames extracted from videos into it.
    Return (target_dir, image_paths).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique folder name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle uploaded files (both images and videos) ---
    if unified_upload is not None:
        for file_data in unified_upload:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = str(file_data)

            file_ext = os.path.splitext(file_path)[1].lower()

            # Check if it's a video file
            video_extensions = [
                ".mp4",
                ".avi",
                ".mov",
                ".mkv",
                ".wmv",
                ".flv",
                ".webm",
                ".m4v",
                ".3gp",
            ]

            if file_ext in video_extensions:
                # Handle as video: sample one frame every s_time_interval seconds
                vs = cv2.VideoCapture(file_path)
                fps = vs.get(cv2.CAP_PROP_FPS)
                # Guard against fps * interval < 1, which would make the
                # modulo below divide by zero
                frame_interval = max(1, int(fps * s_time_interval))

                count = 0
                video_frame_num = 0
                while True:
                    gotit, frame = vs.read()
                    if not gotit:
                        break
                    count += 1
                    if count % frame_interval == 0:
                        # Use the original filename as a prefix for frames
                        base_name = os.path.splitext(os.path.basename(file_path))[0]
                        image_path = os.path.join(
                            target_dir_images, f"{base_name}_{video_frame_num:06}.png"
                        )
                        cv2.imwrite(image_path, frame)
                        image_paths.append(image_path)
                        video_frame_num += 1
                vs.release()
                print(
                    f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}"
                )
            else:
                # Handle as image. Convert HEIC/HEIF to JPEG for better
                # gallery compatibility.
                if file_ext in [".heic", ".heif"]:
                    try:
                        with Image.open(file_path) as img:
                            # Convert to RGB if necessary (HEIC can use other
                            # color modes)
                            if img.mode not in ("RGB", "L"):
                                img = img.convert("RGB")

                            # Create the JPEG filename
                            base_name = os.path.splitext(
                                os.path.basename(file_path)
                            )[0]
                            dst_path = os.path.join(
                                target_dir_images, f"{base_name}.jpg"
                            )

                            # Save as JPEG with high quality
                            img.save(dst_path, "JPEG", quality=95)
                            image_paths.append(dst_path)
                            print(
                                f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}"
                            )
                    except Exception as e:
                        print(f"Error converting HEIC file {file_path}: {e}")
                        # Fall back to copying the file as-is
                        dst_path = os.path.join(
                            target_dir_images, os.path.basename(file_path)
                        )
                        shutil.copy(file_path, dst_path)
                        image_paths.append(dst_path)
                else:
                    # Regular image files - copy as-is
                    dst_path = os.path.join(
                        target_dir_images, os.path.basename(file_path)
                    )
                    shutil.copy(file_path, dst_path)
                    image_paths.append(dst_path)

    # Sort the final images for the gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(
        f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths
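
# Illustrative usage sketch (hypothetical file names; not run by the app):
# handle_uploads accepts a mixed list of image and video paths and returns
# the staging folder plus the sorted image paths inside it.
#
#   target_dir, image_paths = handle_uploads(
#       ["scene/room.jpg", "scene/walkthrough.mp4"],
#       s_time_interval=2.0,  # sample one frame every 2 seconds of video
#   )
#   print(target_dir)   # e.g. input_images_20250101_120000_000000
#   print(image_paths)  # extracted frames + copied images, sorted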
# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(unified_upload, s_time_interval=1.0):
    """
    Whenever the user uploads or changes files, handle them immediately and
    show them in the gallery. Returns (reconstruction_output, target_dir,
    image_paths, log_message); if nothing is uploaded, returns all Nones.
    """
    if not unified_upload:
        return None, None, None, None
    target_dir, image_paths = handle_uploads(unified_upload, s_time_interval)
    return (
        None,
        target_dir,
        image_paths,
        "上传完成。点击「开始重建」进行3D处理",
    )


# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def gradio_demo(
    target_dir,
    frame_filter="All",
    show_cam=True,
    filter_black_bg=False,
    filter_white_bg=False,
    apply_mask=True,
    show_mesh=True,
    progress=gr.Progress(),
):
    """
    Perform reconstruction using the already-created target_dir/images.
    """
    if not os.path.isdir(target_dir) or target_dir == "None":
        return (
            None,
            "❌ 未找到有效的目标目录,请先上传文件",
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )

    progress(0, desc="🔄 准备重建...")
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Prepare the frame_filter dropdown
    target_dir_images = os.path.join(target_dir, "images")
    all_files = (
        sorted(os.listdir(target_dir_images))
        if os.path.isdir(target_dir_images)
        else []
    )
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files
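    # Example of the labels built above (hypothetical filenames):
    #   ["All", "0: frame_000000.png", "1: frame_000001.png", ...]
    # presumably so filter_by_frames can map a label back to a frame index.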
    progress(0.05, desc=f"🚀 运行 MapAnything 模型... ({len(all_files)}张图片)")
    print("Running MapAnything model...")
    with torch.no_grad():
        predictions, processed_data = run_model(
            target_dir, apply_mask, True, filter_black_bg, filter_white_bg, progress
        )

    # Save predictions
    progress(0.92, desc="💾 保存预测结果...")
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)

    # Handle a None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name that encodes the visualization parameters
    progress(0.93, desc="🏗️ 生成3D模型文件...")
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )

    # Convert predictions to GLB
    glbscene = predictions_to_glb(
        predictions,
        filter_by_frames=frame_filter,
        show_cam=show_cam,
        mask_black_bg=filter_black_bg,
        mask_white_bg=filter_white_bg,
        as_mesh=show_mesh,  # Use the show_mesh parameter
    )
    glbscene.export(file_obj=glbfile)

    # Cleanup
    progress(0.96, desc="🧹 清理内存...")
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Total time: {total_time:.2f} seconds")
    log_msg = f"✅ 重建成功 ({len(all_files)} 帧,耗时 {total_time:.1f}秒)"

    # Populate the visualization tabs with processed data
    progress(0.98, desc="🎨 生成可视化...")
    depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
        processed_data
    )

    # Update the view selectors based on the available views
    depth_selector, normal_selector, measure_selector = update_view_selectors(
        processed_data
    )

    progress(1.0, desc="✅ 全部完成!")
    return (
        glbfile,
        log_msg,
        gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
        processed_data,
        depth_vis,
        normal_vis,
        measure_img,
        "",  # measure_text (empty initially)
        depth_selector,
        normal_selector,
        measure_selector,
    )


# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def colorize_depth(depth_map, mask=None):
    """Convert a depth map to a colorized visualization, with optional mask."""
    if depth_map is None:
        return None

    # Normalize depth to the 0-1 range using robust percentiles
    depth_normalized = depth_map.copy()
    valid_mask = depth_normalized > 0

    # Apply the additional mask if provided (for background filtering)
    if mask is not None:
        valid_mask = valid_mask & mask

    if valid_mask.sum() > 0:
        valid_depths = depth_normalized[valid_mask]
        p5 = np.percentile(valid_depths, 5)
        p95 = np.percentile(valid_depths, 95)
        # Guard against a degenerate range (e.g. constant depth)
        if p95 - p5 > 1e-8:
            depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (
                p95 - p5
            )

    # Apply the colormap
    import matplotlib.pyplot as plt

    colormap = plt.cm.turbo_r
    colored = colormap(depth_normalized)
    colored = (colored[:, :, :3] * 255).astype(np.uint8)

    # Set invalid pixels to white
    colored[~valid_mask] = [255, 255, 255]

    return colored


def colorize_normal(normal_map, mask=None):
    """Convert a normal map to a colorized visualization, with optional mask."""
    if normal_map is None:
        return None

    # Create a copy for modification
    normal_vis = normal_map.copy()

    # Apply the mask if provided; zeroed areas become mid-grey after the
    # [-1, 1] -> [0, 1] remapping below
    if mask is not None:
        invalid_mask = ~mask
        normal_vis[invalid_mask] = [0, 0, 0]

    # Remap normals from [-1, 1] to [0, 1] for visualization
    normal_vis = (normal_vis + 1.0) / 2.0
    normal_vis = (normal_vis * 255).astype(np.uint8)

    return normal_vis
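
# Minimal sketch of the robust normalization colorize_depth applies, on a
# synthetic depth map (assumption: values in meters, 0 marking invalid pixels).
def _example_percentile_normalization():
    depth = np.array([[0.0, 1.0, 2.0], [3.0, 4.0, 50.0]])  # 50.0 is an outlier
    valid = depth > 0
    p5, p95 = np.percentile(depth[valid], [5, 95])
    # Clipping to the 5th-95th percentile keeps one far outlier from
    # flattening the contrast of every other pixel
    normalized = np.clip((depth - p5) / (p95 - p5), 0.0, 1.0)
    return normalized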
def process_predictions_for_visualization(
    predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
):
    """Extract depth, normals, and 3D points from predictions for visualization."""
    processed_data = {}

    # Process each view
    for view_idx, view in enumerate(views):
        # Get the image
        image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])

        # Get the predicted points
        pred_pts3d = predictions["world_points"][view_idx]

        # Initialize the data for this view
        view_data = {
            "image": image[0],
            "points3d": pred_pts3d,
            "depth": None,
            "normal": None,
            "mask": None,
        }

        # Start with the final mask from the predictions
        mask = predictions["final_mask"][view_idx].copy()

        # Apply black background filtering if enabled
        if filter_black_bg:
            # Get the image colors (ensure they're in the 0-255 range)
            view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
            # Filter out black background pixels (sum of RGB < 16)
            black_bg_mask = view_colors.sum(axis=2) >= 16
            mask = mask & black_bg_mask

        # Apply white background filtering if enabled
        if filter_white_bg:
            # Get the image colors (ensure they're in the 0-255 range)
            view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
            # Filter out white background pixels (all RGB > 240)
            white_bg_mask = ~(
                (view_colors[:, :, 0] > 240)
                & (view_colors[:, :, 1] > 240)
                & (view_colors[:, :, 2] > 240)
            )
            mask = mask & white_bg_mask

        view_data["mask"] = mask
        view_data["depth"] = predictions["depth"][view_idx].squeeze()

        normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
        view_data["normal"] = normals

        processed_data[view_idx] = view_data

    return processed_data


def reset_measure(processed_data):
    """Reset the measure points."""
    if processed_data is None or len(processed_data) == 0:
        return None, [], ""

    # Return the first view's image
    first_view = list(processed_data.values())[0]
    return first_view["image"], [], ""
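
# Illustrative sketch (an assumption about what points_to_normals computes;
# the real utility may differ): surface normals from a (H, W, 3) point map
# via the cross product of horizontal and vertical neighbor differences.
def _example_normals_from_points(pts):
    """pts: (H, W, 3) world points. Returns unit normals (zero on the border)."""
    normals = np.zeros_like(pts)
    # Finite differences along the image grid approximate the tangent vectors
    dx = pts[1:-1, 2:] - pts[1:-1, :-2]
    dy = pts[2:, 1:-1] - pts[:-2, 1:-1]
    n = np.cross(dx, dy)
    # Normalize, avoiding division by zero on degenerate neighborhoods
    norm = np.linalg.norm(n, axis=-1, keepdims=True)
    normals[1:-1, 1:-1] = n / np.maximum(norm, 1e-8)
    return normals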
def measure(
    processed_data, measure_points, current_view_selector, event: gr.SelectData
):
    """Handle distance-measurement clicks on images."""
    try:
        print(f"Measure function called with selector: {current_view_selector}")

        if processed_data is None or len(processed_data) == 0:
            return None, [], "No data available"

        # Use the currently selected view instead of always using the first view
        try:
            current_view_index = int(current_view_selector.split()[1]) - 1
        except Exception:
            current_view_index = 0

        print(f"Using view index: {current_view_index}")

        # Get the view data safely
        if current_view_index < 0 or current_view_index >= len(processed_data):
            current_view_index = 0
        view_keys = list(processed_data.keys())
        current_view = processed_data[view_keys[current_view_index]]

        if current_view is None:
            return None, [], "No view data available"

        point2d = event.index[0], event.index[1]
        print(f"Clicked point: {point2d}")

        # Ignore clicks that land in a masked (invalid) area
        if (
            current_view["mask"] is not None
            and 0 <= point2d[1] < current_view["mask"].shape[0]
            and 0 <= point2d[0] < current_view["mask"].shape[1]
        ):
            if not current_view["mask"][point2d[1], point2d[0]]:
                print(f"Clicked point {point2d} is in masked area, ignoring click")
                # Always return the image with the mask overlay
                masked_image, _ = update_measure_view(
                    processed_data, current_view_index
                )
                return (
                    masked_image,
                    measure_points,
                    "Cannot measure on masked areas (shown in red)",
                )

        measure_points.append(point2d)

        # Get the image with the mask overlay and ensure it's valid
        image, _ = update_measure_view(processed_data, current_view_index)
        if image is None:
            return None, [], "No image available"

        image = image.copy()
        points3d = current_view["points3d"]

        # Ensure the image is in uint8 format for cv2 drawing operations
        try:
            if image.dtype != np.uint8:
                if image.max() <= 1.0:
                    # Image is in the [0, 1] range, convert to [0, 255]
                    image = (image * 255).astype(np.uint8)
                else:
                    # Image is already in the [0, 255] range
                    image = image.astype(np.uint8)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return None, [], f"Image conversion error: {e}"

        # Draw circles at the clicked points
        try:
            for p in measure_points:
                if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
                    image = cv2.circle(
                        image, p, radius=5, color=(255, 0, 0), thickness=2
                    )
        except Exception as e:
            print(f"Drawing error: {e}")
            return None, [], f"Drawing error: {e}"

        depth_text = ""
        try:
            for i, p in enumerate(measure_points):
                if (
                    current_view["depth"] is not None
                    and 0 <= p[1] < current_view["depth"].shape[0]
                    and 0 <= p[0] < current_view["depth"].shape[1]
                ):
                    d = current_view["depth"][p[1], p[0]]
                    depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
                else:
                    # Fall back to the Z coordinate of the 3D points if depth
                    # is not available
                    if (
                        points3d is not None
                        and 0 <= p[1] < points3d.shape[0]
                        and 0 <= p[0] < points3d.shape[1]
                    ):
                        z = points3d[p[1], p[0], 2]
                        depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
        except Exception as e:
            print(f"Depth text error: {e}")
            depth_text = f"Error computing depth: {e}\n"

        if len(measure_points) == 2:
            try:
                point1, point2 = measure_points

                # Draw the connecting line
                if (
                    0 <= point1[0] < image.shape[1]
                    and 0 <= point1[1] < image.shape[0]
                    and 0 <= point2[0] < image.shape[1]
                    and 0 <= point2[1] < image.shape[0]
                ):
                    image = cv2.line(
                        image, point1, point2, color=(255, 0, 0), thickness=2
                    )

                # Compute the 3D distance between the two clicked points
                distance_text = "- **Distance: Unable to compute**"
                if (
                    points3d is not None
                    and 0 <= point1[1] < points3d.shape[0]
                    and 0 <= point1[0] < points3d.shape[1]
                    and 0 <= point2[1] < points3d.shape[0]
                    and 0 <= point2[0] < points3d.shape[1]
                ):
                    try:
                        p1_3d = points3d[point1[1], point1[0]]
                        p2_3d = points3d[point2[1], point2[0]]
                        distance = np.linalg.norm(p1_3d - p2_3d)
                        distance_text = f"- **Distance: {distance:.2f}m**"
                    except Exception as e:
                        print(f"Distance computation error: {e}")
                        distance_text = f"- **Distance computation error: {e}**"

                measure_points = []
                text = depth_text + distance_text
                print(f"Measurement complete: {text}")
                return [image, measure_points, text]
            except Exception as e:
                print(f"Final measurement error: {e}")
                return None, [], f"Measurement error: {e}"
        else:
            print(f"Single point measurement: {depth_text}")
            return [image, measure_points, depth_text]

    except Exception as e:
        print(f"Overall measure function error: {e}")
        return None, [], f"Measure function error: {e}"


def clear_fields():
    """
    Clear the 3D viewer, the stored target_dir, and the gallery.
    """
    return None


def update_log():
    """
    Display a quick log message while waiting.
    """
    return "加载和重建中..."
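
# Minimal sketch of the measurement math above: with a per-pixel world-point
# map, the metric distance between two clicked pixels is just the Euclidean
# norm of the difference of their 3D points.
def _example_pixel_distance(points3d, click_a, click_b):
    """points3d: (H, W, 3); clicks are (x, y) pixel tuples as Gradio reports them."""
    p1 = points3d[click_a[1], click_a[0]]  # note: indexed as [row=y, col=x]
    p2 = points3d[click_b[1], click_b[0]]
    return float(np.linalg.norm(p1 - p2))  # meters, if the point map is metric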
def update_visualization(
    target_dir,
    frame_filter,
    show_cam,
    is_example,
    filter_black_bg=False,
    filter_white_bg=False,
    show_mesh=True,
):
    """
    Reload the saved predictions from npz, create (or reuse) the GLB for the
    new parameters, and return it for the 3D viewer. If is_example == "True",
    skip.
    """
    # If it's an example click, skip as requested
    if is_example == "True":
        return (
            gr.update(),
            "没有可用的重建。请先点击重建按钮。",
        )

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return (
            gr.update(),
            "没有可用的重建。请先点击重建按钮。",
        )

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return (
            gr.update(),
            f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
        )

    loaded = np.load(predictions_path, allow_pickle=True)
    predictions = {key: loaded[key] for key in loaded.keys()}

    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )

    # Only rebuild the GLB if this parameter combination isn't cached yet
    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            filter_by_frames=frame_filter,
            show_cam=show_cam,
            mask_black_bg=filter_black_bg,
            mask_white_bg=filter_white_bg,
            as_mesh=show_mesh,
        )
        glbscene.export(file_obj=glbfile)

    return (
        glbfile,
        "可视化已更新",
    )


def update_all_views_on_filter_change(
    target_dir,
    filter_black_bg,
    filter_white_bg,
    processed_data,
    depth_view_selector,
    normal_view_selector,
    measure_view_selector,
):
    """
    Update all individual view tabs when the background-filtering checkboxes
    change. This regenerates the processed data with the new filtering and
    updates all views.
    """
    # Check that we have a valid target directory and saved predictions
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return processed_data, None, None, None, []

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return processed_data, None, None, None, []

    try:
        # Load the original predictions and views
        loaded = np.load(predictions_path, allow_pickle=True)
        predictions = {key: loaded[key] for key in loaded.keys()}

        # Load images using MapAnything's load_images function
        image_folder_path = os.path.join(target_dir, "images")
        views = load_images(image_folder_path)

        # Regenerate the processed data with the new filtering settings
        new_processed_data = process_predictions_for_visualization(
            predictions, views, high_level_config, filter_black_bg, filter_white_bg
        )

        # Get the current view indices
        try:
            depth_view_idx = (
                int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
            )
        except Exception:
            depth_view_idx = 0
        try:
            normal_view_idx = (
                int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
            )
        except Exception:
            normal_view_idx = 0
        try:
            measure_view_idx = (
                int(measure_view_selector.split()[1]) - 1
                if measure_view_selector
                else 0
            )
        except Exception:
            measure_view_idx = 0

        # Update all views with the newly filtered data
        depth_vis = update_depth_view(new_processed_data, depth_view_idx)
        normal_vis = update_normal_view(new_processed_data, normal_view_idx)
        measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)

        return new_processed_data, depth_vis, normal_vis, measure_img, []

    except Exception as e:
        print(f"Error updating views on filter change: {e}")
        return processed_data, None, None, None, []
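
# Minimal sketch of the caching pattern used above: predictions are a dict of
# numpy arrays, so np.savez / np.load round-trips them through a single .npz
# file keyed by array name (the path here is hypothetical).
def _example_npz_round_trip():
    predictions = {"depth": np.zeros((2, 4, 4, 1)), "intrinsic": np.eye(3)[None]}
    np.savez("/tmp/predictions_example.npz", **predictions)
    loaded = np.load("/tmp/predictions_example.npz", allow_pickle=True)
    # np.load returns a lazy NpzFile; rebuild a plain dict of arrays
    return {key: loaded[key] for key in loaded.keys()}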
# -------------------------------------------------------------------------
# Example scene functions
# -------------------------------------------------------------------------
def get_scene_info(examples_dir):
    """Get information about the scenes in the examples directory."""
    import glob

    scenes = []
    if not os.path.exists(examples_dir):
        return scenes

    for scene_folder in sorted(os.listdir(examples_dir)):
        scene_path = os.path.join(examples_dir, scene_folder)
        if os.path.isdir(scene_path):
            # Find all image files in the scene folder
            image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
            image_files = []
            for ext in image_extensions:
                image_files.extend(glob.glob(os.path.join(scene_path, ext)))
                image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))

            if image_files:
                # Sort the images and use the first one as the thumbnail
                image_files = sorted(image_files)
                first_image = image_files[0]
                num_images = len(image_files)

                scenes.append(
                    {
                        "name": scene_folder,
                        "path": scene_path,
                        "thumbnail": first_image,
                        "num_images": num_images,
                        "image_files": image_files,
                    }
                )

    return scenes


def load_example_scene(scene_name, examples_dir="examples"):
    """Load a scene from the examples directory."""
    scenes = get_scene_info(examples_dir)

    # Find the selected scene
    selected_scene = None
    for scene in scenes:
        if scene["name"] == scene_name:
            selected_scene = scene
            break

    if selected_scene is None:
        return None, None, None, "Scene not found"

    # The unified upload system accepts plain file paths, so pass the scene's
    # image paths straight through
    file_objects = list(selected_scene["image_files"])

    # Create the target directory and copy images via the unified upload system
    target_dir, image_paths = handle_uploads(file_objects, 1.0)

    return (
        None,  # Clear reconstruction output
        target_dir,  # Set target directory
        image_paths,  # Set gallery
        f"已加载场景 '{scene_name}'({selected_scene['num_images']} 张图片)。点击「开始重建」进行3D处理。",
    )


# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
theme = get_gradio_theme()

# Custom CSS to keep the UI from shifting around
CUSTOM_CSS = GRADIO_CSS + """
/* Keep components from stretching the layout */
.gradio-container {
    max-width: 100% !important;
}
/* Fix the Gallery height */
.gallery-container {
    max-height: 350px !important;
    overflow-y: auto !important;
}
/* Fix the File component height */
.file-preview {
    max-height: 200px !important;
    overflow-y: auto !important;
}
/* Fix the Video component height */
.video-container {
    max-height: 300px !important;
}
/* Keep Textboxes from expanding indefinitely */
.textbox-container {
    max-height: 100px !important;
}
/* Keep the Tabs content area stable */
.tab-content {
    min-height: 550px !important;
}
/* Enhanced file upload area */
.file-upload-enhanced {
    position: relative;
}
/* Reduce spacing between the Accordion and the content below it */
.accordion {
    margin-bottom: 10px !important;
    padding-bottom: 0 !important;
}
/* Compact styling for the example-scene area */
.accordion > .label-wrap {
    margin-bottom: 5px !important;
}
/* Info box */
.info-box {
    background-color: #E3F2FD !important;
    border-left: 4px solid #2196F3 !important;
    padding: 15px !important;
    margin-bottom: 15px !important;
    border-radius: 4px !important;
}
"""

with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything - 3D重建系统") as demo:
    # State variables for the tabbed interface
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")
    processed_data_state = gr.State(value=None)
    measure_points_state = gr.State(value=[])
    current_view_index = gr.State(value=0)  # Track the current view for navigation

    # JavaScript hook for clipboard-paste support
    PASTE_JS = """
    """
    gr.HTML(PASTE_JS)

    # Styled page header
    gr.HTML("""
多视图3D重建 | 深度估计 | 法线计算 | 距离测量