# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
MapAnything V2 - 3D reconstruction system (Chinese-language UI)
- Multi-view 3D reconstruction
- Depth estimation and normal computation
- Distance measurement
"""
import gc
import os
import shutil
import sys
import time
from datetime import datetime

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from pillow_heif import register_heif_opener

register_heif_opener()

sys.path.append("mapanything/")

from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
from mapanything.utils.hf_utils.css_and_html import (
    GRADIO_CSS,
    MEASURE_INSTRUCTIONS_HTML,
    get_acknowledgements_html,
    get_description_html,
    get_gradio_theme,
    get_header_html,
)
from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
from mapanything.utils.hf_utils.visual_util import predictions_to_glb
from mapanything.utils.image import load_images, rgb
def get_logo_base64():
    """Convert the WAI logo to base64 for embedding in HTML."""
    import base64

    logo_path = "examples/WAI-Logo/wai_logo.png"
    try:
        with open(logo_path, "rb") as img_file:
            img_data = img_file.read()
        base64_str = base64.b64encode(img_data).decode()
        return f"data:image/png;base64,{base64_str}"
    except FileNotFoundError:
        return None
# MapAnything configuration
high_level_config = {
    "path": "configs/train.yaml",
    "hf_model_name": "facebook/map-anything",
    "model_str": "mapanything",
    "config_overrides": [
        "machine=aws",
        "model=mapanything",
        "model/task=images_only",
        "model.encoder.uses_torch_hub=false",
    ],
    "checkpoint_name": "model.safetensors",
    "config_name": "config.json",
    "trained_with_amp": True,
    "trained_with_amp_dtype": "bf16",
    "data_norm_type": "dinov2",
    "patch_size": 14,
    "resolution": 518,
}

# The model is initialized lazily, on the GPU, the first time run_model needs it
model = None
# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
def run_model(
    target_dir,
    apply_mask=True,
    mask_edges=True,
    filter_black_bg=False,
    filter_white_bg=False,
    progress=gr.Progress(),
):
    """
    Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
    """
    global model
    import torch  # Ensure torch is available in function scope

    start_time = time.time()
    print(f"Processing images from {target_dir}")

    # Device check
    progress(0, desc="🔧 初始化设备...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Initialize model if not already done
    progress(0.05, desc="📥 加载模型... (~5秒)")
    if model is None:
        model = initialize_mapanything_model(high_level_config, device)
    else:
        model = model.to(device)
    model.eval()

    # Load images using MapAnything's load_images function
    progress(0.15, desc="📷 加载图片... (~2秒)")
    print("Loading images...")
    image_folder_path = os.path.join(target_dir, "images")
    views = load_images(image_folder_path)
    print(f"Loaded {len(views)} images")
    if len(views) == 0:
        raise ValueError("No images found. Check your upload.")

    # Run model inference
    num_images = len(views)
    estimated_time = num_images * 3  # rough estimate: ~3 seconds per image
    progress(0.2, desc=f"🚀 运行3D重建... ({num_images}张图片,预计{estimated_time}秒)")
    print("Running inference...")
    inference_start = time.time()
    outputs = model.infer(
        views,
        apply_mask=apply_mask,
        mask_edges=mask_edges,  # pass the parameter through instead of hard-coding True
        memory_efficient_inference=False,
    )
    inference_time = time.time() - inference_start

    # Convert predictions to the format expected by visualization
    progress(0.6, desc=f"🔄 处理预测结果... (推理耗时: {inference_time:.1f}秒)")
    predictions = {}

    # Initialize lists for the required keys
    extrinsic_list = []
    intrinsic_list = []
    world_points_list = []
    depth_maps_list = []
    images_list = []
    final_mask_list = []

    # Loop through the outputs
    for i, pred in enumerate(outputs):
        if i % max(1, len(outputs) // 5) == 0:
            progress(
                0.6 + (i / len(outputs)) * 0.25,
                desc=f"🔄 处理视图 {i + 1}/{len(outputs)}...",
            )

        # Extract data from predictions
        depthmap_torch = pred["depth_z"][0].squeeze(-1)  # (H, W)
        intrinsics_torch = pred["intrinsics"][0]  # (3, 3)
        camera_pose_torch = pred["camera_poses"][0]  # (4, 4)

        # Compute new pts3d using depth, intrinsics, and camera pose
        pts3d_computed, valid_mask = depthmap_to_world_frame(
            depthmap_torch, intrinsics_torch, camera_pose_torch
        )

        # Convert to numpy arrays for visualization. If no mask key exists in
        # pred, fall back to an all-True mask shaped like the depth map.
        if "mask" in pred:
            mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
        else:
            mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)

        # Combine with valid depth mask
        mask = mask & valid_mask.cpu().numpy()

        image = pred["img_no_norm"][0].cpu().numpy()

        # Append to lists
        extrinsic_list.append(camera_pose_torch.cpu().numpy())
        intrinsic_list.append(intrinsics_torch.cpu().numpy())
        world_points_list.append(pts3d_computed.cpu().numpy())
        depth_maps_list.append(depthmap_torch.cpu().numpy())
        images_list.append(image)
        final_mask_list.append(mask)

    # Convert lists to numpy arrays with the required shapes
    # extrinsic: (S, 4, 4) - batch of camera poses
    predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
    # intrinsic: (S, 3, 3) - batch of camera intrinsic matrices
    predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
    # world_points: (S, H, W, 3) - batch of 3D world points
    predictions["world_points"] = np.stack(world_points_list, axis=0)

    # depth: (S, H, W, 1) - batch of depth maps
    depth_maps = np.stack(depth_maps_list, axis=0)
    # Add channel dimension if needed to match the (S, H, W, 1) format
    if len(depth_maps.shape) == 3:
        depth_maps = depth_maps[..., np.newaxis]
    predictions["depth"] = depth_maps

    # images: (S, H, W, 3) - batch of input images
    predictions["images"] = np.stack(images_list, axis=0)
    # final_mask: (S, H, W) - batch of final masks for filtering
    predictions["final_mask"] = np.stack(final_mask_list, axis=0)

    # Process data for visualization tabs (depth, normal, measure)
    progress(0.85, desc="🎨 生成深度图与法线图...")
    processed_data = process_predictions_for_visualization(
        predictions, views, high_level_config, filter_black_bg, filter_white_bg
    )

    # Clean up
    progress(0.95, desc="🧹 清理内存...")
    torch.cuda.empty_cache()

    total_time = time.time() - start_time
    progress(1.0, desc=f"✅ 完成!总耗时: {total_time:.1f}秒")
    print(f"Total processing time: {total_time:.2f} seconds")
    return predictions, processed_data
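
# For reference, a minimal sketch of the pinhole back-projection that
# depthmap_to_world_frame presumably performs (an assumption based on the
# standard camera model, not the library's actual implementation): pixels are
# unprojected with K^-1, scaled by depth, then mapped to the world frame with
# the 4x4 camera-to-world pose. This helper is illustrative and unused.
def _backproject_depth_sketch(depth, K, cam_pose):
    """depth: (H, W); K: (3, 3); cam_pose: (4, 4) camera-to-world."""
    H, W = depth.shape
    u, v = np.meshgrid(np.arange(W), np.arange(H))  # pixel grid
    pix = np.stack([u, v, np.ones_like(u)], axis=-1).reshape(-1, 3).T  # (3, H*W)
    cam_pts = np.linalg.inv(K) @ pix * depth.reshape(1, -1)  # camera-frame rays * depth
    world = cam_pose[:3, :3] @ cam_pts + cam_pose[:3, 3:4]  # rotate + translate
    return world.T.reshape(H, W, 3), depth > 0  # world points and validity mask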
def update_view_selectors(processed_data):
    """Update view selector dropdowns based on available views"""
    if processed_data is None or len(processed_data) == 0:
        choices = ["View 1"]
    else:
        num_views = len(processed_data)
        choices = [f"View {i + 1}" for i in range(num_views)]
    return (
        gr.Dropdown(choices=choices, value=choices[0]),  # depth_view_selector
        gr.Dropdown(choices=choices, value=choices[0]),  # normal_view_selector
        gr.Dropdown(choices=choices, value=choices[0]),  # measure_view_selector
    )


def get_view_data_by_index(processed_data, view_index):
    """Get view data by index, handling bounds"""
    if processed_data is None or len(processed_data) == 0:
        return None
    view_keys = list(processed_data.keys())
    if view_index < 0 or view_index >= len(view_keys):
        view_index = 0
    return processed_data[view_keys[view_index]]


def update_depth_view(processed_data, view_index):
    """Update depth view for a specific view index"""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None or view_data["depth"] is None:
        return None
    return colorize_depth(view_data["depth"], mask=view_data.get("mask"))


def update_normal_view(processed_data, view_index):
    """Update normal view for a specific view index"""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None or view_data["normal"] is None:
        return None
    return colorize_normal(view_data["normal"], mask=view_data.get("mask"))
def update_measure_view(processed_data, view_index):
    """Update measure view for a specific view index with mask overlay"""
    view_data = get_view_data_by_index(processed_data, view_index)
    if view_data is None:
        return None, []  # image, measure_points

    # Get the base image
    image = view_data["image"].copy()

    # Ensure image is in uint8 format
    if image.dtype != np.uint8:
        if image.max() <= 1.0:
            image = (image * 255).astype(np.uint8)
        else:
            image = image.astype(np.uint8)

    # Apply mask overlay if a mask is available; masked areas (False values)
    # are tinted so users can see where measurement is disabled
    if view_data["mask"] is not None:
        mask = view_data["mask"]
        invalid_mask = ~mask  # areas where mask is False
        if invalid_mask.any():
            # Light red overlay (RGB: 255, 220, 220)
            overlay_color = np.array([255, 220, 220], dtype=np.uint8)
            # Blend the overlay with the image at 50% opacity
            alpha = 0.5
            for c in range(3):  # RGB channels
                image[:, :, c] = np.where(
                    invalid_mask,
                    (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
                    image[:, :, c],
                ).astype(np.uint8)
    return image, []
def navigate_depth_view(processed_data, current_selector_value, direction):
    """Navigate depth view (direction: -1 for previous, +1 for next)"""
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None
    # Parse the current view number
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
        current_view = 0
    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views
    new_selector_value = f"View {new_view + 1}"
    depth_vis = update_depth_view(processed_data, new_view)
    return new_selector_value, depth_vis


def navigate_normal_view(processed_data, current_selector_value, direction):
    """Navigate normal view (direction: -1 for previous, +1 for next)"""
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None
    # Parse the current view number
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
        current_view = 0
    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views
    new_selector_value = f"View {new_view + 1}"
    normal_vis = update_normal_view(processed_data, new_view)
    return new_selector_value, normal_vis


def navigate_measure_view(processed_data, current_selector_value, direction):
    """Navigate measure view (direction: -1 for previous, +1 for next)"""
    if processed_data is None or len(processed_data) == 0:
        return "View 1", None, []
    # Parse the current view number
    try:
        current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
        current_view = 0
    num_views = len(processed_data)
    new_view = (current_view + direction) % num_views
    new_selector_value = f"View {new_view + 1}"
    measure_image, measure_points = update_measure_view(processed_data, new_view)
    return new_selector_value, measure_image, measure_points
def populate_visualization_tabs(processed_data):
    """Populate the depth, normal, and measure tabs with processed data"""
    if processed_data is None or len(processed_data) == 0:
        return None, None, None, []
    # Use the update functions so mask/background filtering is applied from the start
    depth_vis = update_depth_view(processed_data, 0)
    normal_vis = update_normal_view(processed_data, 0)
    measure_img, _ = update_measure_view(processed_data, 0)
    return depth_vis, normal_vis, measure_img, []
# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(unified_upload, s_time_interval=1.0):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or extracted frames from video into it. Return (target_dir, image_paths).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique folder name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"input_images_{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle uploaded files (both images and videos) ---
    if unified_upload is not None:
        for file_data in unified_upload:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = str(file_data)
            file_ext = os.path.splitext(file_path)[1].lower()

            # Check if it's a video file
            video_extensions = [
                ".mp4",
                ".avi",
                ".mov",
                ".mkv",
                ".wmv",
                ".flv",
                ".webm",
                ".m4v",
                ".3gp",
            ]
            if file_ext in video_extensions:
                # Handle as video
                vs = cv2.VideoCapture(file_path)
                fps = vs.get(cv2.CAP_PROP_FPS)
                # Frames between samples; guard against fps == 0 or a tiny
                # interval, which would otherwise cause a modulo-by-zero below
                frame_interval = max(1, int(fps * s_time_interval))
                count = 0
                video_frame_num = 0
                while True:
                    gotit, frame = vs.read()
                    if not gotit:
                        break
                    count += 1
                    if count % frame_interval == 0:
                        # Use the original filename as a prefix for frames
                        base_name = os.path.splitext(os.path.basename(file_path))[0]
                        image_path = os.path.join(
                            target_dir_images, f"{base_name}_{video_frame_num:06}.png"
                        )
                        cv2.imwrite(image_path, frame)
                        image_paths.append(image_path)
                        video_frame_num += 1
                vs.release()
                print(
                    f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}"
                )
            elif file_ext in [".heic", ".heif"]:
                # Convert HEIC to JPEG for better gallery compatibility
                try:
                    with Image.open(file_path) as img:
                        # Convert to RGB if necessary (HEIC can have different color modes)
                        if img.mode not in ("RGB", "L"):
                            img = img.convert("RGB")
                        # Create JPEG filename
                        base_name = os.path.splitext(os.path.basename(file_path))[0]
                        dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
                        # Save as JPEG with high quality
                        img.save(dst_path, "JPEG", quality=95)
                        image_paths.append(dst_path)
                        print(
                            f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}"
                        )
                except Exception as e:
                    print(f"Error converting HEIC file {file_path}: {e}")
                    # Fall back to copying as is
                    dst_path = os.path.join(
                        target_dir_images, os.path.basename(file_path)
                    )
                    shutil.copy(file_path, dst_path)
                    image_paths.append(dst_path)
            else:
                # Regular image files - copy as is
                dst_path = os.path.join(
                    target_dir_images, os.path.basename(file_path)
                )
                shutil.copy(file_path, dst_path)
                image_paths.append(dst_path)

    # Sort final images for the gallery
    image_paths = sorted(image_paths)
    end_time = time.time()
    print(
        f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths
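
# Quick sanity check for the sampling math above: at 30 fps with a 1.0 s
# interval, every 30th decoded frame is kept, so a 10-second clip yields
# roughly 10 frames. A minimal sketch (hypothetical helper, not used by the app):
def _expected_frame_count(fps, duration_s, interval_s):
    """Approximate number of frames handle_uploads extracts from a video."""
    frame_interval = max(1, int(fps * interval_s))  # same guard as above
    total_frames = int(fps * duration_s)
    return total_frames // frame_interval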
# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
    """
    Whenever the user uploads or changes files, immediately handle them and
    show them in the gallery. Returns (clear_3d, target_dir, image_paths,
    log_message); if nothing is uploaded, all outputs are None.
    """
    if not input_video and not input_images:
        return None, None, None, None
    # handle_uploads expects a single list of files, so merge both inputs
    uploads = list(input_images or [])
    if input_video:
        uploads.append(input_video)
    target_dir, image_paths = handle_uploads(uploads, s_time_interval)
    return (
        None,
        target_dir,
        image_paths,
        "上传完成。点击「开始重建」进行3D处理",
    )
# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
def gradio_demo(
    target_dir,
    frame_filter="All",
    show_cam=True,
    filter_black_bg=False,
    filter_white_bg=False,
    apply_mask=True,
    show_mesh=True,
    progress=gr.Progress(),
):
    """
    Perform reconstruction using the already-created target_dir/images.
    """
    if not os.path.isdir(target_dir) or target_dir == "None":
        return (
            None,
            "❌ 未找到有效的目标目录,请先上传文件",
            None, None, None, None, None, None, None, None, None,
        )

    progress(0, desc="🔄 准备重建...")
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Prepare the frame_filter dropdown
    target_dir_images = os.path.join(target_dir, "images")
    all_files = (
        sorted(os.listdir(target_dir_images))
        if os.path.isdir(target_dir_images)
        else []
    )
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    progress(0.05, desc=f"🚀 运行 MapAnything 模型... ({len(all_files)}张图片)")
    print("Running MapAnything model...")
    with torch.no_grad():
        predictions, processed_data = run_model(
            target_dir, apply_mask, True, filter_black_bg, filter_white_bg, progress
        )

    # Save predictions
    progress(0.92, desc="💾 保存预测结果...")
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    np.savez(prediction_save_path, **predictions)

    # Handle None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name that encodes the visualization parameters, so each
    # parameter combination is cached and can be reused without re-exporting
    progress(0.93, desc="🏗️ 生成3D模型文件...")
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )

    # Convert predictions to GLB
    glbscene = predictions_to_glb(
        predictions,
        filter_by_frames=frame_filter,
        show_cam=show_cam,
        mask_black_bg=filter_black_bg,
        mask_white_bg=filter_white_bg,
        as_mesh=show_mesh,
    )
    glbscene.export(file_obj=glbfile)

    # Cleanup
    progress(0.96, desc="🧹 清理内存...")
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    total_time = end_time - start_time
    print(f"Total time: {total_time:.2f} seconds")
    log_msg = f"✅ 重建成功 ({len(all_files)} 帧,耗时 {total_time:.1f}秒)"

    # Populate visualization tabs with processed data
    progress(0.98, desc="🎨 生成可视化...")
    depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
        processed_data
    )

    # Update view selectors based on available views
    depth_selector, normal_selector, measure_selector = update_view_selectors(
        processed_data
    )

    progress(1.0, desc="✅ 全部完成!")
    return (
        glbfile,
        log_msg,
        gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
        processed_data,
        depth_vis,
        normal_vis,
        measure_img,
        "",  # measure_text (empty initially)
        depth_selector,
        normal_selector,
        measure_selector,
    )
# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def colorize_depth(depth_map, mask=None):
    """Convert a depth map to a colorized visualization with an optional mask"""
    if depth_map is None:
        return None

    # Normalize depth to the 0-1 range
    depth_normalized = depth_map.copy()
    valid_mask = depth_normalized > 0

    # Apply an additional mask if provided (for background filtering)
    if mask is not None:
        valid_mask = valid_mask & mask

    if valid_mask.sum() > 0:
        valid_depths = depth_normalized[valid_mask]
        p5 = np.percentile(valid_depths, 5)
        p95 = np.percentile(valid_depths, 95)
        # Guard against a degenerate range (e.g., near-constant depth)
        denom = max(p95 - p5, 1e-6)
        depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / denom

    # Apply colormap
    import matplotlib.pyplot as plt

    colormap = plt.cm.turbo_r
    colored = colormap(depth_normalized)
    colored = (colored[:, :, :3] * 255).astype(np.uint8)

    # Set invalid pixels to white
    colored[~valid_mask] = [255, 255, 255]
    return colored
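
# The 5th/95th-percentile stretch above is a robust normalization: outlier
# depths (e.g., sky or noisy edges) no longer dominate the color range.
# A toy illustration with hypothetical values:
#   depths = [1.0, 2.0, 3.0, 50.0]  ->  p5 ~ 1.15, p95 ~ 42.95
#   (depths - p5) / (p95 - p5) maps the bulk of the scene across [0, 1],
# whereas plain min-max scaling would squash the 1-3 m range into the bottom
# ~4% of the colormap because of the single 50 m outlier.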
def colorize_normal(normal_map, mask=None):
    """Convert a normal map to a colorized visualization with an optional mask"""
    if normal_map is None:
        return None

    # Create a copy for modification
    normal_vis = normal_map.copy()

    # Apply mask if provided (set masked areas to [0, 0, 0], which becomes grey after normalization)
    if mask is not None:
        invalid_mask = ~mask
        normal_vis[invalid_mask] = [0, 0, 0]  # Set invalid areas to zero

    # Normalize normals to the [0, 1] range for visualization
    normal_vis = (normal_vis + 1.0) / 2.0
    normal_vis = (normal_vis * 255).astype(np.uint8)
    return normal_vis
def process_predictions_for_visualization(
    predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
):
    """Extract depth, normal, and 3D points from predictions for visualization"""
    processed_data = {}

    # Process each view
    for view_idx, view in enumerate(views):
        # Get image
        image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])

        # Get predicted points
        pred_pts3d = predictions["world_points"][view_idx]

        # Initialize data for this view
        view_data = {
            "image": image[0],
            "points3d": pred_pts3d,
            "depth": None,
            "normal": None,
            "mask": None,
        }

        # Start with the final mask from predictions
        mask = predictions["final_mask"][view_idx].copy()

        # Apply black background filtering if enabled
        if filter_black_bg:
            # Get the image colors (ensure they're in the 0-255 range)
            view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
            # Filter out black background pixels (sum of RGB < 16)
            black_bg_mask = view_colors.sum(axis=2) >= 16
            mask = mask & black_bg_mask

        # Apply white background filtering if enabled
        if filter_white_bg:
            # Get the image colors (ensure they're in the 0-255 range)
            view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
            # Filter out white background pixels (all RGB > 240)
            white_bg_mask = ~(
                (view_colors[:, :, 0] > 240)
                & (view_colors[:, :, 1] > 240)
                & (view_colors[:, :, 2] > 240)
            )
            mask = mask & white_bg_mask

        view_data["mask"] = mask
        view_data["depth"] = predictions["depth"][view_idx].squeeze()
        normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
        view_data["normal"] = normals

        processed_data[view_idx] = view_data
    return processed_data
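
# For reference, a minimal sketch of how normals can be derived from a 3D
# point map, assuming points_to_normals works along these lines (cross
# products of local tangent vectors; the real implementation may differ).
# This helper is illustrative and unused.
def _normals_from_points_sketch(pts3d):
    """pts3d: (H, W, 3) -> unit normals (H, W, 3) via finite differences."""
    dx = np.gradient(pts3d, axis=1)  # tangent along image x
    dy = np.gradient(pts3d, axis=0)  # tangent along image y
    n = np.cross(dx, dy)  # surface normal = cross product of the tangents
    norm = np.linalg.norm(n, axis=-1, keepdims=True)
    return n / np.clip(norm, 1e-8, None)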
def reset_measure(processed_data):
    """Reset measure points"""
    if processed_data is None or len(processed_data) == 0:
        return None, [], ""
    # Return the first view image
    first_view = list(processed_data.values())[0]
    return first_view["image"], [], ""
def measure(
    processed_data, measure_points, current_view_selector, event: gr.SelectData
):
    """Handle measurement clicks on images"""
    try:
        print(f"Measure function called with selector: {current_view_selector}")
        if processed_data is None or len(processed_data) == 0:
            return None, [], "No data available"

        # Use the currently selected view instead of always using the first view
        try:
            current_view_index = int(current_view_selector.split()[1]) - 1
        except (AttributeError, IndexError, ValueError):
            current_view_index = 0
        print(f"Using view index: {current_view_index}")

        # Get view data safely
        if current_view_index < 0 or current_view_index >= len(processed_data):
            current_view_index = 0
        view_keys = list(processed_data.keys())
        current_view = processed_data[view_keys[current_view_index]]
        if current_view is None:
            return None, [], "No view data available"

        point2d = event.index[0], event.index[1]
        print(f"Clicked point: {point2d}")

        # Check whether the clicked point is in a masked area (prevent interaction)
        if (
            current_view["mask"] is not None
            and 0 <= point2d[1] < current_view["mask"].shape[0]
            and 0 <= point2d[0] < current_view["mask"].shape[1]
        ):
            # Check if the point is in a masked (invalid) area
            if not current_view["mask"][point2d[1], point2d[0]]:
                print(f"Clicked point {point2d} is in masked area, ignoring click")
                # Always return the image with the mask overlay
                masked_image, _ = update_measure_view(
                    processed_data, current_view_index
                )
                return (
                    masked_image,
                    measure_points,
                    '<span style="color: red; font-weight: bold;">Cannot measure on masked areas (shown with the overlay tint)</span>',
                )

        measure_points.append(point2d)

        # Get the image with the mask overlay and ensure it's valid
        image, _ = update_measure_view(processed_data, current_view_index)
        if image is None:
            return None, [], "No image available"
        image = image.copy()
        points3d = current_view["points3d"]

        # Ensure image is in uint8 format for proper cv2 operations
        try:
            if image.dtype != np.uint8:
                if image.max() <= 1.0:
                    # Image is in [0, 1] range, convert to [0, 255]
                    image = (image * 255).astype(np.uint8)
                else:
                    # Image is already in [0, 255] range
                    image = image.astype(np.uint8)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return None, [], f"Image conversion error: {e}"

        # Draw circles for points
        try:
            for p in measure_points:
                if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
                    image = cv2.circle(
                        image, p, radius=5, color=(255, 0, 0), thickness=2
                    )
        except Exception as e:
            print(f"Drawing error: {e}")
            return None, [], f"Drawing error: {e}"

        depth_text = ""
        try:
            for i, p in enumerate(measure_points):
                if (
                    current_view["depth"] is not None
                    and 0 <= p[1] < current_view["depth"].shape[0]
                    and 0 <= p[0] < current_view["depth"].shape[1]
                ):
                    d = current_view["depth"][p[1], p[0]]
                    depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
                else:
                    # Use the Z coordinate of the 3D points if depth is not available
                    if (
                        points3d is not None
                        and 0 <= p[1] < points3d.shape[0]
                        and 0 <= p[0] < points3d.shape[1]
                    ):
                        z = points3d[p[1], p[0], 2]
                        depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
        except Exception as e:
            print(f"Depth text error: {e}")
            depth_text = f"Error computing depth: {e}\n"

        if len(measure_points) == 2:
            try:
                point1, point2 = measure_points
                # Draw line
                if (
                    0 <= point1[0] < image.shape[1]
                    and 0 <= point1[1] < image.shape[0]
                    and 0 <= point2[0] < image.shape[1]
                    and 0 <= point2[1] < image.shape[0]
                ):
                    image = cv2.line(
                        image, point1, point2, color=(255, 0, 0), thickness=2
                    )

                # Compute 3D distance
                distance_text = "- **Distance: Unable to compute**"
                if (
                    points3d is not None
                    and 0 <= point1[1] < points3d.shape[0]
                    and 0 <= point1[0] < points3d.shape[1]
                    and 0 <= point2[1] < points3d.shape[0]
                    and 0 <= point2[0] < points3d.shape[1]
                ):
                    try:
                        p1_3d = points3d[point1[1], point1[0]]
                        p2_3d = points3d[point2[1], point2[0]]
                        distance = np.linalg.norm(p1_3d - p2_3d)
                        distance_text = f"- **Distance: {distance:.2f}m**"
                    except Exception as e:
                        print(f"Distance computation error: {e}")
                        distance_text = f"- **Distance computation error: {e}**"

                measure_points = []
                text = depth_text + distance_text
                print(f"Measurement complete: {text}")
                return [image, measure_points, text]
            except Exception as e:
                print(f"Final measurement error: {e}")
                return None, [], f"Measurement error: {e}"
        else:
            print(f"Single point measurement: {depth_text}")
            return [image, measure_points, depth_text]
    except Exception as e:
        print(f"Overall measure function error: {e}")
        return None, [], f"Measure function error: {e}"
def clear_fields():
    """
    Clear the 3D viewer ahead of a new reconstruction.
    """
    return None


def update_log():
    """
    Display a quick log message while waiting.
    """
    return "加载和重建中..."
def update_visualization(
    target_dir,
    frame_filter,
    show_cam,
    is_example,
    filter_black_bg=False,
    filter_white_bg=False,
    show_mesh=True,
):
    """
    Reload saved predictions from npz, create (or reuse) the GLB for the new
    parameters, and return it for the 3D viewer. If is_example == "True", skip.
    """
    # If it's an example click, skip as requested
    if is_example == "True":
        return (
            gr.update(),
            "没有可用的重建。请先点击重建按钮。",
        )
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return (
            gr.update(),
            "没有可用的重建。请先点击重建按钮。",
        )

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return (
            gr.update(),
            f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
        )

    loaded = np.load(predictions_path, allow_pickle=True)
    predictions = {key: loaded[key] for key in loaded.keys()}

    glbfile = os.path.join(
        target_dir,
        f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
    )
    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            filter_by_frames=frame_filter,
            show_cam=show_cam,
            mask_black_bg=filter_black_bg,
            mask_white_bg=filter_white_bg,
            as_mesh=show_mesh,
        )
        glbscene.export(file_obj=glbfile)
    return (
        glbfile,
        "可视化已更新",
    )
def update_all_views_on_filter_change(
    target_dir,
    filter_black_bg,
    filter_white_bg,
    processed_data,
    depth_view_selector,
    normal_view_selector,
    measure_view_selector,
):
    """
    Update all individual view tabs when the background filtering checkboxes
    change. This regenerates the processed data with the new filtering and
    updates all views.
    """
    # Check that we have a valid target directory and predictions
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return processed_data, None, None, None, []
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return processed_data, None, None, None, []

    try:
        # Load the original predictions and views
        loaded = np.load(predictions_path, allow_pickle=True)
        predictions = {key: loaded[key] for key in loaded.keys()}

        # Load images using MapAnything's load_images function
        image_folder_path = os.path.join(target_dir, "images")
        views = load_images(image_folder_path)

        # Regenerate processed data with the new filtering settings
        new_processed_data = process_predictions_for_visualization(
            predictions, views, high_level_config, filter_black_bg, filter_white_bg
        )

        # Get the current view indices
        try:
            depth_view_idx = (
                int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
            )
        except (AttributeError, IndexError, ValueError):
            depth_view_idx = 0
        try:
            normal_view_idx = (
                int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
            )
        except (AttributeError, IndexError, ValueError):
            normal_view_idx = 0
        try:
            measure_view_idx = (
                int(measure_view_selector.split()[1]) - 1
                if measure_view_selector
                else 0
            )
        except (AttributeError, IndexError, ValueError):
            measure_view_idx = 0

        # Update all views with the newly filtered data
        depth_vis = update_depth_view(new_processed_data, depth_view_idx)
        normal_vis = update_normal_view(new_processed_data, normal_view_idx)
        measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
        return new_processed_data, depth_vis, normal_vis, measure_img, []
    except Exception as e:
        print(f"Error updating views on filter change: {e}")
        return processed_data, None, None, None, []
# -------------------------------------------------------------------------
# Example scene functions
# -------------------------------------------------------------------------
def get_scene_info(examples_dir):
    """Get information about scenes in the examples directory"""
    import glob

    scenes = []
    if not os.path.exists(examples_dir):
        return scenes
    for scene_folder in sorted(os.listdir(examples_dir)):
        scene_path = os.path.join(examples_dir, scene_folder)
        if os.path.isdir(scene_path):
            # Find all image files in the scene folder
            image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
            image_files = []
            for ext in image_extensions:
                image_files.extend(glob.glob(os.path.join(scene_path, ext)))
                image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
            if image_files:
                # Sort images and use the first one as the thumbnail
                image_files = sorted(image_files)
                first_image = image_files[0]
                num_images = len(image_files)
                scenes.append(
                    {
                        "name": scene_folder,
                        "path": scene_path,
                        "thumbnail": first_image,
                        "num_images": num_images,
                        "image_files": image_files,
                    }
                )
    return scenes
def load_example_scene(scene_name, examples_dir="examples"):
    """Load a scene from the examples directory"""
    scenes = get_scene_info(examples_dir)

    # Find the selected scene
    selected_scene = None
    for scene in scenes:
        if scene["name"] == scene_name:
            selected_scene = scene
            break
    if selected_scene is None:
        return None, None, None, "Scene not found"

    # Pass the scene's image paths through the unified upload pipeline
    file_objects = []
    for image_path in selected_scene["image_files"]:
        file_objects.append(image_path)

    # Create the target directory and copy images using the unified upload system
    target_dir, image_paths = handle_uploads(file_objects, 1.0)
    return (
        None,  # Clear reconstruction output
        target_dir,  # Set target directory
        image_paths,  # Set gallery
        f"已加载场景 '{scene_name}'({selected_scene['num_images']} 张图片)。点击「开始重建」进行3D处理。",
    )
# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
theme = get_gradio_theme()

# Custom CSS to keep the layout from jumping around
CUSTOM_CSS = GRADIO_CSS + """
/* Keep components from stretching the layout */
.gradio-container {
    max-width: 100% !important;
}
/* Fixed Gallery height */
.gallery-container {
    max-height: 350px !important;
    overflow-y: auto !important;
}
/* Fixed File component height */
.file-preview {
    max-height: 200px !important;
    overflow-y: auto !important;
}
/* Fixed Video component height */
.video-container {
    max-height: 300px !important;
}
/* Keep Textbox from growing without bound */
.textbox-container {
    max-height: 100px !important;
}
/* Keep the Tabs content area stable */
.tab-content {
    min-height: 550px !important;
}
/* Enhanced file upload area */
.file-upload-enhanced {
    position: relative;
}
/* Tighter spacing between Accordion and following content */
.accordion {
    margin-bottom: 10px !important;
    padding-bottom: 0 !important;
}
/* Compact styling for the example scenes area */
.accordion > .label-wrap {
    margin-bottom: 5px !important;
}
/* Info box */
.info-box {
    background-color: #E3F2FD !important;
    border-left: 4px solid #2196F3 !important;
    padding: 15px !important;
    margin-bottom: 15px !important;
    border-radius: 4px !important;
}
"""
with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything - 3D重建系统") as demo:
    # State variables for the tabbed interface
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")
    processed_data_state = gr.State(value=None)
    measure_points_state = gr.State(value=[])
    current_view_index = gr.State(value=0)  # Track current view index for navigation

    # JavaScript for clipboard-paste support
    PASTE_JS = """
    <script>
    // Clipboard support: forward pasted images to the file input
    document.addEventListener('paste', function(e) {
        const items = e.clipboardData.items;
        for (let i = 0; i < items.length; i++) {
            if (items[i].type.indexOf('image') !== -1) {
                const blob = items[i].getAsFile();
                const fileInput = document.querySelector('input[type="file"][multiple]');
                if (fileInput) {
                    const dataTransfer = new DataTransfer();
                    dataTransfer.items.add(blob);
                    fileInput.files = dataTransfer.files;
                    fileInput.dispatchEvent(new Event('change', { bubbles: true }));
                    console.log('✅ 图片已从剪贴板粘贴');
                }
            }
        }
    });
    console.log('💡 粘贴板功能已启用:使用 Ctrl+V 可直接粘贴截图');
    </script>
    """
    gr.HTML(PASTE_JS)

    # Styled page header
    gr.HTML("""
    <div style="text-align: center; margin: 20px 0;">
        <h2 style="color: #1976D2; margin-bottom: 10px;">MapAnything - 3D重建系统</h2>
        <p style="color: #666; font-size: 16px;">多视图3D重建 | 深度估计 | 法线计算 | 距离测量</p>
    </div>
    """)

    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
    with gr.Row(equal_height=False):
        # Left: input area
        with gr.Column(scale=1, min_width=300):
            gr.Markdown("### 📤 输入")

            # Unified upload component (files, drag-and-drop, clipboard paste)
            unified_upload = gr.File(
                file_count="multiple",
                label="上传视频或图片(支持拖拽、粘贴Ctrl+V📋)",
                interactive=True,
                file_types=["image", "video"],
            )

            # Webcam input (collapsible)
            with gr.Accordion("📷 使用摄像头拍照", open=False):
                camera_input = gr.Image(
                    label="拍照后自动添加",
                    sources=["webcam"],
                    type="filepath",
                    interactive=True,
                )

            with gr.Row():
                s_time_interval = gr.Slider(
                    minimum=0.1, maximum=5.0, value=1.0, step=0.1,
                    label="视频采样时间间隔(每x秒取一帧)",
                    interactive=True,
                    visible=True,
                    scale=3,
                )
                resample_btn = gr.Button(
                    "重新采样视频",
                    visible=False,
                    variant="secondary",
                    scale=1,
                )

            image_gallery = gr.Gallery(
                label="图片预览", columns=3, height=350,
                show_download_button=True, object_fit="contain", preview=True
            )
            clear_uploads_btn = gr.ClearButton(
                [unified_upload, camera_input, image_gallery],
                value="清空上传",
                variant="secondary",
                size="sm",
            )

            with gr.Row():
                submit_btn = gr.Button("🚀 开始重建", variant="primary", scale=2)
                clear_btn = gr.ClearButton(
                    [unified_upload, camera_input, target_dir_output, image_gallery],
                    value="🗑️ 清空", scale=1
                )

        # Right: output area
        with gr.Column(scale=2, min_width=600):
            gr.Markdown("### 🎯 输出")
            with gr.Tabs():
                with gr.Tab("🏗️ 原始3D"):
                    reconstruction_output = gr.Model3D(
                        height=550, zoom_speed=0.5, pan_speed=0.5,
                        clear_color=[0.0, 0.0, 0.0, 0.0]
                    )
                with gr.Tab("📊 深度图"):
                    with gr.Row(elem_classes=["navigation-row"]):
                        prev_depth_btn = gr.Button("◀", size="sm", scale=1)
                        depth_view_selector = gr.Dropdown(
                            choices=["View 1"], value="View 1",
                            label="视图", scale=3, interactive=True
                        )
                        next_depth_btn = gr.Button("▶", size="sm", scale=1)
                    depth_map = gr.Image(
                        type="numpy", label="", format="png", interactive=False,
                        height=500
                    )
                with gr.Tab("🧭 法线图"):
                    with gr.Row(elem_classes=["navigation-row"]):
                        prev_normal_btn = gr.Button("◀", size="sm", scale=1)
                        normal_view_selector = gr.Dropdown(
                            choices=["View 1"], value="View 1",
                            label="视图", scale=3, interactive=True
                        )
                        next_normal_btn = gr.Button("▶", size="sm", scale=1)
                    normal_map = gr.Image(
                        type="numpy", label="", format="png", interactive=False,
                        height=500
                    )
                with gr.Tab("📏 测量"):
                    gr.Markdown("**点击图片两次进行距离测量**")
                    with gr.Row(elem_classes=["navigation-row"]):
                        prev_measure_btn = gr.Button("◀", size="sm", scale=1)
                        measure_view_selector = gr.Dropdown(
                            choices=["View 1"], value="View 1",
                            label="视图", scale=3, interactive=True
                        )
                        next_measure_btn = gr.Button("▶", size="sm", scale=1)
                    measure_image = gr.Image(
                        type="numpy", show_label=False,
                        format="webp", interactive=False, sources=[],
                        height=500
                    )
                    measure_text = gr.Markdown("")

            log_output = gr.Textbox(
                value="📌 请上传图片或视频,然后点击「开始重建」",
                label="状态信息",
                interactive=False,
                lines=1,
                max_lines=1
            )

    # Advanced options (collapsed by default)
    with gr.Accordion("⚙️ 高级选项", open=False):
        with gr.Row(equal_height=False):
            with gr.Column(scale=1, min_width=300):
                gr.Markdown("#### 可视化参数")
                frame_filter = gr.Dropdown(
                    choices=["All"], value="All", label="显示帧"
                )
                show_cam = gr.Checkbox(label="显示相机", value=True)
                show_mesh = gr.Checkbox(label="显示网格", value=True)
                filter_black_bg = gr.Checkbox(label="过滤黑色背景", value=False)
                filter_white_bg = gr.Checkbox(label="过滤白色背景", value=False)
            with gr.Column(scale=1, min_width=300):
                gr.Markdown("#### 重建参数")
                apply_mask_checkbox = gr.Checkbox(
                    label="应用深度掩码", value=True
                )

    # Example scenes (collapsible)
    with gr.Accordion("🖼️ 示例场景", open=False):
        gr.Markdown("点击缩略图加载场景进行重建")
        scenes = get_scene_info("examples")
        if scenes:
            for i in range(0, len(scenes), 4):  # four scenes per row
                with gr.Row(equal_height=True):
                    for j in range(4):
                        scene_idx = i + j
                        if scene_idx < len(scenes):
                            scene = scenes[scene_idx]
                            with gr.Column(scale=1, min_width=150):
                                scene_img = gr.Image(
                                    value=scene["thumbnail"],
                                    height=150,
                                    interactive=False,
                                    show_label=False,
                                    sources=[],
                                    container=False
                                )
                                gr.Markdown(
                                    f"**{scene['name']}** ({scene['num_images']}张)",
                                    elem_classes=["text-center"]
                                )
                                scene_img.select(
                                    fn=lambda name=scene["name"]: load_example_scene(name),
                                    outputs=[
                                        reconstruction_output,
                                        target_dir_output,
                                        image_gallery,
                                        log_output,
                                    ],
                                )
    # === Event wiring ===

    # Auto-update on file upload
    def update_gallery_on_unified_upload(files, interval):
        if not files:
            return None, None, None
        target_dir, image_paths = handle_uploads(files, interval)
        return (
            target_dir,
            image_paths,
            "✅ 上传完成,点击「开始重建」进行 3D 处理",
        )

    # Handle webcam captures
    def update_gallery_on_camera(image):
        if image is None:
            return None, None, None
        # Wrap the single image in a list
        target_dir, image_paths = handle_uploads([image], 1.0)
        return (
            target_dir,
            image_paths,
            "✅ 摄像头照片已添加,点击「开始重建」进行 3D 处理",
        )

    def show_resample_button(files):
        """Show the resample button only when the upload contains a video."""
        if not files:
            return gr.update(visible=False)
        # Check for video files
        video_extensions = [
            ".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp",
        ]
        has_video = False
        for file_data in files:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = str(file_data)
            file_ext = os.path.splitext(file_path)[1].lower()
            if file_ext in video_extensions:
                has_video = True
                break
        return gr.update(visible=has_video)

    def resample_video_with_new_interval(files, new_interval, current_target_dir):
        """Re-sample the uploaded videos with the new slider value."""
        if not files:
            return (
                current_target_dir,
                None,
                "没有可重新采样的文件。",
                gr.update(visible=False),
            )
        # Check whether any video needs resampling
        video_extensions = [
            ".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp",
        ]
        has_video = any(
            os.path.splitext(
                str(file_data["name"] if isinstance(file_data, dict) else file_data)
            )[1].lower()
            in video_extensions
            for file_data in files
        )
        if not has_video:
            return (
                current_target_dir,
                None,
                "未找到视频进行重新采样。",
                gr.update(visible=False),
            )
        # Clean up the old target directory
        if (
            current_target_dir
            and current_target_dir != "None"
            and os.path.exists(current_target_dir)
        ):
            shutil.rmtree(current_target_dir)
        # Re-process the files with the new interval
        target_dir, image_paths = handle_uploads(files, new_interval)
        return (
            target_dir,
            image_paths,
            f"视频已使用 {new_interval}秒 间隔重新采样。点击「开始重建」进行 3D 处理。",
            gr.update(visible=False),
        )
    unified_upload.change(
        fn=update_gallery_on_unified_upload,
        inputs=[unified_upload, s_time_interval],
        outputs=[target_dir_output, image_gallery, log_output]
    ).then(
        fn=show_resample_button,
        inputs=[unified_upload],
        outputs=[resample_btn],
    )

    # Webcam capture event
    camera_input.change(
        fn=update_gallery_on_camera,
        inputs=[camera_input],
        outputs=[target_dir_output, image_gallery, log_output]
    )

    # Show the resample button when the slider changes (only if files are uploaded)
    s_time_interval.change(
        fn=show_resample_button,
        inputs=[unified_upload],
        outputs=[resample_btn],
    )

    # Handle resample button clicks
    resample_btn.click(
        fn=resample_video_with_new_interval,
        inputs=[unified_upload, s_time_interval, target_dir_output],
        outputs=[target_dir_output, image_gallery, log_output, resample_btn],
    )

    # Reconstruct button
    submit_btn.click(
        fn=clear_fields,
        outputs=[reconstruction_output]
    ).then(
        fn=update_log,
        outputs=[log_output]
    ).then(
        fn=gradio_demo,
        inputs=[
            target_dir_output, frame_filter, show_cam,
            filter_black_bg, filter_white_bg,
            apply_mask_checkbox, show_mesh
        ],
        outputs=[
            reconstruction_output, log_output, frame_filter,
            processed_data_state, depth_map, normal_map, measure_image,
            measure_text, depth_view_selector, normal_view_selector, measure_view_selector
        ]
    ).then(
        fn=lambda: "False",
        outputs=[is_example]
    )

    # Clear button
    clear_btn.add([reconstruction_output, log_output])

    # Live updates for visualization parameters
    for component in [frame_filter, show_cam, show_mesh]:
        component.change(
            fn=update_visualization,
            inputs=[
                target_dir_output, frame_filter, show_cam, is_example,
                filter_black_bg, filter_white_bg, show_mesh
            ],
            outputs=[reconstruction_output, log_output]
        )

    # Background filters refresh every view
    for bg_filter in [filter_black_bg, filter_white_bg]:
        bg_filter.change(
            fn=update_all_views_on_filter_change,
            inputs=[
                target_dir_output, filter_black_bg, filter_white_bg, processed_data_state,
                depth_view_selector, normal_view_selector, measure_view_selector
            ],
            outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state]
        )

    # Depth map navigation
    prev_depth_btn.click(
        fn=lambda pd, cs: navigate_depth_view(pd, cs, -1),
        inputs=[processed_data_state, depth_view_selector],
        outputs=[depth_view_selector, depth_map]
    )
    next_depth_btn.click(
        fn=lambda pd, cs: navigate_depth_view(pd, cs, 1),
        inputs=[processed_data_state, depth_view_selector],
        outputs=[depth_view_selector, depth_map]
    )
    depth_view_selector.change(
        fn=lambda pd, sv: update_depth_view(pd, int(sv.split()[1]) - 1) if sv else None,
        inputs=[processed_data_state, depth_view_selector],
        outputs=[depth_map]
    )

    # Normal map navigation
    prev_normal_btn.click(
        fn=lambda pd, cs: navigate_normal_view(pd, cs, -1),
        inputs=[processed_data_state, normal_view_selector],
        outputs=[normal_view_selector, normal_map]
    )
    next_normal_btn.click(
        fn=lambda pd, cs: navigate_normal_view(pd, cs, 1),
        inputs=[processed_data_state, normal_view_selector],
        outputs=[normal_view_selector, normal_map]
    )
    normal_view_selector.change(
        fn=lambda pd, sv: update_normal_view(pd, int(sv.split()[1]) - 1) if sv else None,
        inputs=[processed_data_state, normal_view_selector],
        outputs=[normal_map]
    )

    # Measurement
    measure_image.select(
        fn=measure,
        inputs=[processed_data_state, measure_points_state, measure_view_selector],
        outputs=[measure_image, measure_points_state, measure_text]
    )
    prev_measure_btn.click(
        fn=lambda pd, cs: navigate_measure_view(pd, cs, -1),
        inputs=[processed_data_state, measure_view_selector],
        outputs=[measure_view_selector, measure_image, measure_points_state]
    )
    next_measure_btn.click(
        fn=lambda pd, cs: navigate_measure_view(pd, cs, 1),
        inputs=[processed_data_state, measure_view_selector],
        outputs=[measure_view_selector, measure_image, measure_points_state]
    )
    measure_view_selector.change(
        fn=lambda pd, sv: update_measure_view(pd, int(sv.split()[1]) - 1) if sv else (None, []),
        inputs=[processed_data_state, measure_view_selector],
        outputs=[measure_image, measure_points_state]
    )

demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)