# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
MapAnything V2 - 3D重建系统(中文版)
- 多视图 3D 重建
- 深度估计与法线计算
- 距离测量功能
"""
import gc
import os
import shutil
import sys
import time
from datetime import datetime
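# Expandable CUDA allocator segments reduce memory fragmentation across repeated inference runs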
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from PIL import Image
from pillow_heif import register_heif_opener
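# Register HEIF/HEIC support with Pillow so Image.open can decode iPhone photos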
register_heif_opener()
sys.path.append("mapanything/")
from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
from mapanything.utils.hf_utils.css_and_html import (
GRADIO_CSS,
MEASURE_INSTRUCTIONS_HTML,
get_acknowledgements_html,
get_description_html,
get_gradio_theme,
get_header_html,
)
from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
from mapanything.utils.hf_utils.visual_util import predictions_to_glb
from mapanything.utils.image import load_images, rgb
def get_logo_base64():
"""Convert WAI logo to base64 for embedding in HTML"""
import base64
logo_path = "examples/WAI-Logo/wai_logo.png"
try:
with open(logo_path, "rb") as img_file:
img_data = img_file.read()
base64_str = base64.b64encode(img_data).decode()
return f"data:image/png;base64,{base64_str}"
except FileNotFoundError:
return None
# MapAnything Configuration
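# Note: resolution 518 with patch_size 14 gives a 37 x 37 patch grid (518 = 37 * 14);
# "dinov2" selects DINOv2-style input normalization, and the AMP flags record that
# the checkpoint was trained with bf16 autocast.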
high_level_config = {
"path": "configs/train.yaml",
"hf_model_name": "facebook/map-anything",
"model_str": "mapanything",
"config_overrides": [
"machine=aws",
"model=mapanything",
"model/task=images_only",
"model.encoder.uses_torch_hub=false",
],
"checkpoint_name": "model.safetensors",
"config_name": "config.json",
"trained_with_amp": True,
"trained_with_amp_dtype": "bf16",
"data_norm_type": "dinov2",
"patch_size": 14,
"resolution": 518,
}
# Initialize model - this will be done on GPU when needed
model = None
# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def run_model(
target_dir,
apply_mask=True,
mask_edges=True,
filter_black_bg=False,
filter_white_bg=False,
progress=gr.Progress(),
):
"""
Run the MapAnything model on images in the 'target_dir/images' folder and return predictions.
"""
global model
import torch # Ensure torch is available in function scope
start_time = time.time()
print(f"Processing images from {target_dir}")
# Device check
progress(0, desc="🔧 初始化设备...")
device = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device)
# Initialize model if not already done
progress(0.05, desc="📥 加载模型... (~5秒)")
if model is None:
model = initialize_mapanything_model(high_level_config, device)
else:
model = model.to(device)
model.eval()
# Load images using MapAnything's load_images function
progress(0.15, desc="📷 加载图片... (~2秒)")
print("Loading images...")
image_folder_path = os.path.join(target_dir, "images")
views = load_images(image_folder_path)
print(f"Loaded {len(views)} images")
if len(views) == 0:
raise ValueError("No images found. Check your upload.")
# Run model inference
num_images = len(views)
    estimated_time = num_images * 3  # rough estimate: ~3 seconds per image
    progress(0.2, desc=f"🚀 Running 3D reconstruction... ({num_images} images, ~{estimated_time} s)")
print("Running inference...")
inference_start = time.time()
    outputs = model.infer(
        views, apply_mask=apply_mask, mask_edges=mask_edges, memory_efficient_inference=False
    )
inference_time = time.time() - inference_start
# Convert predictions to format expected by visualization
progress(0.6, desc=f"🔄 处理预测结果... (推理耗时: {inference_time:.1f}秒)")
predictions = {}
# Initialize lists for the required keys
extrinsic_list = []
intrinsic_list = []
world_points_list = []
depth_maps_list = []
images_list = []
final_mask_list = []
# Loop through the outputs
for i, pred in enumerate(outputs):
if i % max(1, len(outputs) // 5) == 0:
            progress(0.6 + (i / len(outputs)) * 0.25, desc=f"🔄 Processing view {i+1}/{len(outputs)}...")
# Extract data from predictions
depthmap_torch = pred["depth_z"][0].squeeze(-1) # (H, W)
intrinsics_torch = pred["intrinsics"][0] # (3, 3)
camera_pose_torch = pred["camera_poses"][0] # (4, 4)
# Compute new pts3d using depth, intrinsics, and camera pose
pts3d_computed, valid_mask = depthmap_to_world_frame(
depthmap_torch, intrinsics_torch, camera_pose_torch
)
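        # Standard pinhole back-projection: pixel (u, v) with depth d maps to
        # camera_pose @ [d * K^-1 @ (u, v, 1)^T; 1]; valid_mask marks pixels
        # with usable depth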
        # Convert to numpy arrays for visualization. If the prediction carries
        # no mask, treat every pixel as valid.
        if "mask" in pred:
            mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
        else:
            mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
# Combine with valid depth mask
mask = mask & valid_mask.cpu().numpy()
image = pred["img_no_norm"][0].cpu().numpy()
# Append to lists
extrinsic_list.append(camera_pose_torch.cpu().numpy())
intrinsic_list.append(intrinsics_torch.cpu().numpy())
world_points_list.append(pts3d_computed.cpu().numpy())
depth_maps_list.append(depthmap_torch.cpu().numpy())
images_list.append(image) # Add image to list
final_mask_list.append(mask) # Add final_mask to list
# Convert lists to numpy arrays with required shapes
    # extrinsic: (S, 4, 4) - batch of cam-to-world pose matrices
predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
# intrinsic: (S, 3, 3) - batch of camera intrinsic matrices
predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
# world_points: (S, H, W, 3) - batch of 3D world points
predictions["world_points"] = np.stack(world_points_list, axis=0)
# depth: (S, H, W, 1) or (S, H, W) - batch of depth maps
depth_maps = np.stack(depth_maps_list, axis=0)
# Add channel dimension if needed to match (S, H, W, 1) format
if len(depth_maps.shape) == 3:
depth_maps = depth_maps[..., np.newaxis]
predictions["depth"] = depth_maps
# images: (S, H, W, 3) - batch of input images
predictions["images"] = np.stack(images_list, axis=0)
# final_mask: (S, H, W) - batch of final masks for filtering
predictions["final_mask"] = np.stack(final_mask_list, axis=0)
# Process data for visualization tabs (depth, normal, measure)
    progress(0.85, desc="🎨 Generating depth and normal maps...")
processed_data = process_predictions_for_visualization(
predictions, views, high_level_config, filter_black_bg, filter_white_bg
)
# Clean up
    progress(0.95, desc="🧹 Freeing memory...")
torch.cuda.empty_cache()
total_time = time.time() - start_time
progress(1.0, desc=f"✅ 完成!总耗时: {total_time:.1f}秒")
print(f"Total processing time: {total_time:.2f} seconds")
return predictions, processed_data
def update_view_selectors(processed_data):
"""Update view selector dropdowns based on available views"""
if processed_data is None or len(processed_data) == 0:
choices = ["View 1"]
else:
num_views = len(processed_data)
choices = [f"View {i + 1}" for i in range(num_views)]
return (
gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector
gr.Dropdown(choices=choices, value=choices[0]), # normal_view_selector
gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector
)
def get_view_data_by_index(processed_data, view_index):
"""Get view data by index, handling bounds"""
if processed_data is None or len(processed_data) == 0:
return None
view_keys = list(processed_data.keys())
if view_index < 0 or view_index >= len(view_keys):
view_index = 0
return processed_data[view_keys[view_index]]
def update_depth_view(processed_data, view_index):
"""Update depth view for a specific view index"""
view_data = get_view_data_by_index(processed_data, view_index)
if view_data is None or view_data["depth"] is None:
return None
return colorize_depth(view_data["depth"], mask=view_data.get("mask"))
def update_normal_view(processed_data, view_index):
"""Update normal view for a specific view index"""
view_data = get_view_data_by_index(processed_data, view_index)
if view_data is None or view_data["normal"] is None:
return None
return colorize_normal(view_data["normal"], mask=view_data.get("mask"))
def update_measure_view(processed_data, view_index):
"""Update measure view for a specific view index with mask overlay"""
view_data = get_view_data_by_index(processed_data, view_index)
if view_data is None:
return None, [] # image, measure_points
# Get the base image
image = view_data["image"].copy()
# Ensure image is in uint8 format
if image.dtype != np.uint8:
if image.max() <= 1.0:
image = (image * 255).astype(np.uint8)
else:
image = image.astype(np.uint8)
# Apply mask overlay if mask is available
if view_data["mask"] is not None:
mask = view_data["mask"]
        # Overlay masked-out areas with a light red tint
        # Masked areas (False values) get blended toward the overlay color
        invalid_mask = ~mask  # Areas where mask is False
        if invalid_mask.any():
            # Light red overlay color (RGB: 255, 220, 220)
            overlay_color = np.array([255, 220, 220], dtype=np.uint8)
            # Alpha-blend: out = (1 - alpha) * image + alpha * overlay
            alpha = 0.5  # Transparency level
for c in range(3): # RGB channels
image[:, :, c] = np.where(
invalid_mask,
(1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
image[:, :, c],
).astype(np.uint8)
return image, []
def navigate_depth_view(processed_data, current_selector_value, direction):
"""Navigate depth view (direction: -1 for previous, +1 for next)"""
if processed_data is None or len(processed_data) == 0:
return "View 1", None
# Parse current view number
try:
current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
current_view = 0
num_views = len(processed_data)
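    # Modulo wrap-around: stepping back from View 1 jumps to the last view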
new_view = (current_view + direction) % num_views
new_selector_value = f"View {new_view + 1}"
depth_vis = update_depth_view(processed_data, new_view)
return new_selector_value, depth_vis
def navigate_normal_view(processed_data, current_selector_value, direction):
"""Navigate normal view (direction: -1 for previous, +1 for next)"""
if processed_data is None or len(processed_data) == 0:
return "View 1", None
# Parse current view number
try:
current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
current_view = 0
num_views = len(processed_data)
new_view = (current_view + direction) % num_views
new_selector_value = f"View {new_view + 1}"
normal_vis = update_normal_view(processed_data, new_view)
return new_selector_value, normal_vis
def navigate_measure_view(processed_data, current_selector_value, direction):
"""Navigate measure view (direction: -1 for previous, +1 for next)"""
if processed_data is None or len(processed_data) == 0:
return "View 1", None, []
# Parse current view number
try:
current_view = int(current_selector_value.split()[1]) - 1
    except (AttributeError, IndexError, ValueError):
current_view = 0
num_views = len(processed_data)
new_view = (current_view + direction) % num_views
new_selector_value = f"View {new_view + 1}"
measure_image, measure_points = update_measure_view(processed_data, new_view)
return new_selector_value, measure_image, measure_points
def populate_visualization_tabs(processed_data):
"""Populate the depth, normal, and measure tabs with processed data"""
if processed_data is None or len(processed_data) == 0:
return None, None, None, []
    # Use the update functions so mask overlays and filtering are applied from the start
depth_vis = update_depth_view(processed_data, 0)
normal_vis = update_normal_view(processed_data, 0)
measure_img, _ = update_measure_view(processed_data, 0)
return depth_vis, normal_vis, measure_img, []
# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def handle_uploads(unified_upload, s_time_interval=1.0):
"""
Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
images or extracted frames from video into it. Return (target_dir, image_paths).
"""
start_time = time.time()
gc.collect()
torch.cuda.empty_cache()
# Create a unique folder name
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
target_dir = f"input_images_{timestamp}"
target_dir_images = os.path.join(target_dir, "images")
# Clean up if somehow that folder already exists
if os.path.exists(target_dir):
shutil.rmtree(target_dir)
os.makedirs(target_dir)
os.makedirs(target_dir_images)
image_paths = []
# --- Handle uploaded files (both images and videos) ---
if unified_upload is not None:
for file_data in unified_upload:
if isinstance(file_data, dict) and "name" in file_data:
file_path = file_data["name"]
else:
file_path = str(file_data)
file_ext = os.path.splitext(file_path)[1].lower()
# Check if it's a video file
video_extensions = [
".mp4",
".avi",
".mov",
".mkv",
".wmv",
".flv",
".webm",
".m4v",
".3gp",
]
if file_ext in video_extensions:
# Handle as video
vs = cv2.VideoCapture(file_path)
fps = vs.get(cv2.CAP_PROP_FPS)
                frame_interval = max(1, int(fps * s_time_interval))  # frames per interval; guard against fps reported as 0
count = 0
video_frame_num = 0
while True:
gotit, frame = vs.read()
if not gotit:
break
count += 1
if count % frame_interval == 0:
# Use original filename as prefix for frames
base_name = os.path.splitext(os.path.basename(file_path))[0]
image_path = os.path.join(
target_dir_images, f"{base_name}_{video_frame_num:06}.png"
)
cv2.imwrite(image_path, frame)
image_paths.append(image_path)
video_frame_num += 1
vs.release()
print(
f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}"
)
else:
# Handle as image
# Check if the file is a HEIC image
if file_ext in [".heic", ".heif"]:
# Convert HEIC to JPEG for better gallery compatibility
try:
with Image.open(file_path) as img:
# Convert to RGB if necessary (HEIC can have different color modes)
if img.mode not in ("RGB", "L"):
img = img.convert("RGB")
# Create JPEG filename
base_name = os.path.splitext(os.path.basename(file_path))[0]
dst_path = os.path.join(
target_dir_images, f"{base_name}.jpg"
)
# Save as JPEG with high quality
img.save(dst_path, "JPEG", quality=95)
image_paths.append(dst_path)
print(
f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}"
)
except Exception as e:
print(f"Error converting HEIC file {file_path}: {e}")
# Fall back to copying as is
dst_path = os.path.join(
target_dir_images, os.path.basename(file_path)
)
shutil.copy(file_path, dst_path)
image_paths.append(dst_path)
else:
# Regular image files - copy as is
dst_path = os.path.join(
target_dir_images, os.path.basename(file_path)
)
shutil.copy(file_path, dst_path)
image_paths.append(dst_path)
# Sort final images for gallery
image_paths = sorted(image_paths)
end_time = time.time()
print(
f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds"
)
return target_dir, image_paths
# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
    """
    Whenever the user uploads or changes files, handle them immediately and
    show them in the gallery. Returns (viewer, target_dir, image_paths, log);
    every output is None when nothing has been uploaded.
    """
    if not input_video and not input_images:
        return None, None, None, None
    # handle_uploads expects a single list of files, so merge both inputs
    files = []
    for item in (input_video, input_images):
        if isinstance(item, (list, tuple)):
            files.extend(item)
        elif item:
            files.append(item)
    target_dir, image_paths = handle_uploads(files, s_time_interval)
    return (
        None,
        target_dir,
        image_paths,
        "Upload complete. Click 'Start Reconstruction' to run 3D processing",
    )
# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
@spaces.GPU(duration=120)
def gradio_demo(
target_dir,
frame_filter="All",
show_cam=True,
filter_black_bg=False,
filter_white_bg=False,
apply_mask=True,
show_mesh=True,
progress=gr.Progress(),
):
"""
Perform reconstruction using the already-created target_dir/images.
"""
if not os.path.isdir(target_dir) or target_dir == "None":
return None, "❌ 未找到有效的目标目录,请先上传文件", None, None, None, None, None, None, None, None, None
progress(0, desc="🔄 准备重建...")
start_time = time.time()
gc.collect()
torch.cuda.empty_cache()
# Prepare frame_filter dropdown
target_dir_images = os.path.join(target_dir, "images")
all_files = (
sorted(os.listdir(target_dir_images))
if os.path.isdir(target_dir_images)
else []
)
all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
frame_filter_choices = ["All"] + all_files
progress(0.05, desc=f"🚀 运行 MapAnything 模型... ({len(all_files)}张图片)")
print("Running MapAnything model...")
with torch.no_grad():
predictions, processed_data = run_model(
target_dir, apply_mask, True, filter_black_bg, filter_white_bg, progress
)
# Save predictions
    progress(0.92, desc="💾 Saving predictions...")
prediction_save_path = os.path.join(target_dir, "predictions.npz")
np.savez(prediction_save_path, **predictions)
# Handle None frame_filter
if frame_filter is None:
frame_filter = "All"
# Build a GLB file name
    progress(0.93, desc="🏗️ Generating 3D model file...")
glbfile = os.path.join(
target_dir,
f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
)
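    # The viz parameters are baked into the file name so update_visualization()
    # can reuse an already-exported GLB when the same settings are requested again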
# Convert predictions to GLB
glbscene = predictions_to_glb(
predictions,
filter_by_frames=frame_filter,
show_cam=show_cam,
mask_black_bg=filter_black_bg,
mask_white_bg=filter_white_bg,
as_mesh=show_mesh, # Use the show_mesh parameter
)
glbscene.export(file_obj=glbfile)
# Cleanup
    progress(0.96, desc="🧹 Freeing memory...")
del predictions
gc.collect()
torch.cuda.empty_cache()
end_time = time.time()
total_time = end_time - start_time
print(f"总耗时: {total_time:.2f}秒")
log_msg = f"✅ 重建成功 ({len(all_files)} 帧,耗时 {total_time:.1f}秒)"
# Populate visualization tabs with processed data
    progress(0.98, desc="🎨 Generating visualizations...")
depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
processed_data
)
# Update view selectors based on available views
depth_selector, normal_selector, measure_selector = update_view_selectors(
processed_data
)
progress(1.0, desc="✅ 全部完成!")
return (
glbfile,
log_msg,
gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
processed_data,
depth_vis,
normal_vis,
measure_img,
"", # measure_text (empty initially)
depth_selector,
normal_selector,
measure_selector,
)
# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def colorize_depth(depth_map, mask=None):
"""Convert depth map to colorized visualization with optional mask"""
if depth_map is None:
return None
# Normalize depth to 0-1 range
depth_normalized = depth_map.copy()
valid_mask = depth_normalized > 0
# Apply additional mask if provided (for background filtering)
if mask is not None:
valid_mask = valid_mask & mask
if valid_mask.sum() > 0:
valid_depths = depth_normalized[valid_mask]
p5 = np.percentile(valid_depths, 5)
p95 = np.percentile(valid_depths, 95)
        denom = max(p95 - p5, 1e-8)  # guard against a constant depth map
        depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / denom
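        # Depths outside [p5, p95] fall outside [0, 1]; matplotlib colormaps clip
        # float inputs to that range, so extremes saturate at the endpoint colors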
# Apply colormap
import matplotlib.pyplot as plt
colormap = plt.cm.turbo_r
colored = colormap(depth_normalized)
colored = (colored[:, :, :3] * 255).astype(np.uint8)
# Set invalid pixels to white
colored[~valid_mask] = [255, 255, 255]
return colored
def colorize_normal(normal_map, mask=None):
"""Convert normal map to colorized visualization with optional mask"""
if normal_map is None:
return None
# Create a copy for modification
normal_vis = normal_map.copy()
# Apply mask if provided (set masked areas to [0, 0, 0] which becomes grey after normalization)
if mask is not None:
invalid_mask = ~mask
normal_vis[invalid_mask] = [0, 0, 0] # Set invalid areas to zero
# Normalize normals to [0, 1] range for visualization
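    # (x -> R, y -> G, z -> B; the zero vectors set above land on mid-grey, ~(127, 127, 127))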
normal_vis = (normal_vis + 1.0) / 2.0
normal_vis = (normal_vis * 255).astype(np.uint8)
return normal_vis
def process_predictions_for_visualization(
predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
):
"""Extract depth, normal, and 3D points from predictions for visualization"""
processed_data = {}
# Process each view
for view_idx, view in enumerate(views):
# Get image
image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
# Get predicted points
pred_pts3d = predictions["world_points"][view_idx]
# Initialize data for this view
view_data = {
"image": image[0],
"points3d": pred_pts3d,
"depth": None,
"normal": None,
"mask": None,
}
# Start with the final mask from predictions
mask = predictions["final_mask"][view_idx].copy()
# Apply black background filtering if enabled
if filter_black_bg:
# Get the image colors (ensure they're in 0-255 range)
view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
# Filter out black background pixels (sum of RGB < 16)
black_bg_mask = view_colors.sum(axis=2) >= 16
mask = mask & black_bg_mask
# Apply white background filtering if enabled
if filter_white_bg:
# Get the image colors (ensure they're in 0-255 range)
view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
# Filter out white background pixels (all RGB > 240)
white_bg_mask = ~(
(view_colors[:, :, 0] > 240)
& (view_colors[:, :, 1] > 240)
& (view_colors[:, :, 2] > 240)
)
mask = mask & white_bg_mask
view_data["mask"] = mask
view_data["depth"] = predictions["depth"][view_idx].squeeze()
normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
view_data["normal"] = normals
processed_data[view_idx] = view_data
return processed_data
def reset_measure(processed_data):
"""Reset measure points"""
if processed_data is None or len(processed_data) == 0:
return None, [], ""
# Return the first view image
first_view = list(processed_data.values())[0]
return first_view["image"], [], ""
def measure(
processed_data, measure_points, current_view_selector, event: gr.SelectData
):
"""Handle measurement on images"""
try:
print(f"Measure function called with selector: {current_view_selector}")
if processed_data is None or len(processed_data) == 0:
return None, [], "No data available"
# Use the currently selected view instead of always using the first view
try:
current_view_index = int(current_view_selector.split()[1]) - 1
        except (AttributeError, IndexError, ValueError):
current_view_index = 0
print(f"Using view index: {current_view_index}")
# Get view data safely
if current_view_index < 0 or current_view_index >= len(processed_data):
current_view_index = 0
view_keys = list(processed_data.keys())
current_view = processed_data[view_keys[current_view_index]]
if current_view is None:
return None, [], "No view data available"
point2d = event.index[0], event.index[1]
print(f"Clicked point: {point2d}")
# Check if the clicked point is in a masked area (prevent interaction)
if (
current_view["mask"] is not None
and 0 <= point2d[1] < current_view["mask"].shape[0]
and 0 <= point2d[0] < current_view["mask"].shape[1]
):
# Check if the point is in a masked (invalid) area
if not current_view["mask"][point2d[1], point2d[0]]:
print(f"Clicked point {point2d} is in masked area, ignoring click")
# Always return image with mask overlay
masked_image, _ = update_measure_view(
processed_data, current_view_index
)
return (
masked_image,
measure_points,
'<span style="color: red; font-weight: bold;">Cannot measure on masked areas (shown in grey)</span>',
)
measure_points.append(point2d)
# Get image with mask overlay and ensure it's valid
image, _ = update_measure_view(processed_data, current_view_index)
if image is None:
return None, [], "No image available"
image = image.copy()
points3d = current_view["points3d"]
# Ensure image is in uint8 format for proper cv2 operations
try:
if image.dtype != np.uint8:
if image.max() <= 1.0:
# Image is in [0, 1] range, convert to [0, 255]
image = (image * 255).astype(np.uint8)
else:
# Image is already in [0, 255] range
image = image.astype(np.uint8)
except Exception as e:
print(f"Image conversion error: {e}")
return None, [], f"Image conversion error: {e}"
# Draw circles for points
try:
for p in measure_points:
if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
image = cv2.circle(
image, p, radius=5, color=(255, 0, 0), thickness=2
)
except Exception as e:
print(f"Drawing error: {e}")
return None, [], f"Drawing error: {e}"
depth_text = ""
try:
for i, p in enumerate(measure_points):
if (
current_view["depth"] is not None
and 0 <= p[1] < current_view["depth"].shape[0]
and 0 <= p[0] < current_view["depth"].shape[1]
):
d = current_view["depth"][p[1], p[0]]
depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n"
else:
# Use Z coordinate of 3D points if depth not available
if (
points3d is not None
and 0 <= p[1] < points3d.shape[0]
and 0 <= p[0] < points3d.shape[1]
):
z = points3d[p[1], p[0], 2]
depth_text += f"- **P{i + 1} Z-coord: {z:.2f}m.**\n"
except Exception as e:
print(f"Depth text error: {e}")
depth_text = f"Error computing depth: {e}\n"
if len(measure_points) == 2:
try:
point1, point2 = measure_points
# Draw line
if (
0 <= point1[0] < image.shape[1]
and 0 <= point1[1] < image.shape[0]
and 0 <= point2[0] < image.shape[1]
and 0 <= point2[1] < image.shape[0]
):
image = cv2.line(
image, point1, point2, color=(255, 0, 0), thickness=2
)
# Compute 3D distance
distance_text = "- **Distance: Unable to compute**"
if (
points3d is not None
and 0 <= point1[1] < points3d.shape[0]
and 0 <= point1[0] < points3d.shape[1]
and 0 <= point2[1] < points3d.shape[0]
and 0 <= point2[0] < points3d.shape[1]
):
try:
p1_3d = points3d[point1[1], point1[0]]
p2_3d = points3d[point2[1], point2[0]]
distance = np.linalg.norm(p1_3d - p2_3d)
distance_text = f"- **Distance: {distance:.2f}m**"
except Exception as e:
print(f"Distance computation error: {e}")
distance_text = f"- **Distance computation error: {e}**"
measure_points = []
text = depth_text + distance_text
print(f"Measurement complete: {text}")
return [image, measure_points, text]
except Exception as e:
print(f"Final measurement error: {e}")
return None, [], f"Measurement error: {e}"
else:
print(f"Single point measurement: {depth_text}")
return [image, measure_points, depth_text]
except Exception as e:
print(f"Overall measure function error: {e}")
return None, [], f"Measure function error: {e}"
def clear_fields():
    """
    Clears the 3D viewer ahead of a new reconstruction run.
    """
    return None
def update_log():
"""
Display a quick log message while waiting.
"""
return "加载和重建中..."
def update_visualization(
target_dir,
frame_filter,
show_cam,
is_example,
filter_black_bg=False,
filter_white_bg=False,
show_mesh=True,
):
"""
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
and return it for the 3D viewer. If is_example == "True", skip.
"""
# If it's an example click, skip as requested
if is_example == "True":
return (
gr.update(),
"没有可用的重建。请先点击重建按钮。",
)
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
return (
gr.update(),
"没有可用的重建。请先点击重建按钮。",
)
predictions_path = os.path.join(target_dir, "predictions.npz")
if not os.path.exists(predictions_path):
return (
gr.update(),
f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
)
loaded = np.load(predictions_path, allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()}
glbfile = os.path.join(
target_dir,
f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
)
if not os.path.exists(glbfile):
glbscene = predictions_to_glb(
predictions,
filter_by_frames=frame_filter,
show_cam=show_cam,
mask_black_bg=filter_black_bg,
mask_white_bg=filter_white_bg,
as_mesh=show_mesh,
)
glbscene.export(file_obj=glbfile)
return (
glbfile,
"可视化已更新",
)
def update_all_views_on_filter_change(
target_dir,
filter_black_bg,
filter_white_bg,
processed_data,
depth_view_selector,
normal_view_selector,
measure_view_selector,
):
"""
Update all individual view tabs when background filtering checkboxes change.
This regenerates the processed data with new filtering and updates all views.
"""
# Check if we have a valid target directory and predictions
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
return processed_data, None, None, None, []
predictions_path = os.path.join(target_dir, "predictions.npz")
if not os.path.exists(predictions_path):
return processed_data, None, None, None, []
try:
# Load the original predictions and views
loaded = np.load(predictions_path, allow_pickle=True)
predictions = {key: loaded[key] for key in loaded.keys()}
# Load images using MapAnything's load_images function
image_folder_path = os.path.join(target_dir, "images")
views = load_images(image_folder_path)
# Regenerate processed data with new filtering settings
new_processed_data = process_predictions_for_visualization(
predictions, views, high_level_config, filter_black_bg, filter_white_bg
)
# Get current view indices
try:
depth_view_idx = (
int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
)
        except (AttributeError, IndexError, ValueError):
depth_view_idx = 0
try:
normal_view_idx = (
int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
)
        except (AttributeError, IndexError, ValueError):
normal_view_idx = 0
try:
measure_view_idx = (
int(measure_view_selector.split()[1]) - 1
if measure_view_selector
else 0
)
        except (AttributeError, IndexError, ValueError):
measure_view_idx = 0
# Update all views with new filtered data
depth_vis = update_depth_view(new_processed_data, depth_view_idx)
normal_vis = update_normal_view(new_processed_data, normal_view_idx)
measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
return new_processed_data, depth_vis, normal_vis, measure_img, []
except Exception as e:
print(f"Error updating views on filter change: {e}")
return processed_data, None, None, None, []
# -------------------------------------------------------------------------
# Example scene functions
# -------------------------------------------------------------------------
def get_scene_info(examples_dir):
"""Get information about scenes in the examples directory"""
import glob
scenes = []
if not os.path.exists(examples_dir):
return scenes
for scene_folder in sorted(os.listdir(examples_dir)):
scene_path = os.path.join(examples_dir, scene_folder)
if os.path.isdir(scene_path):
# Find all image files in the scene folder
image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
image_files = []
for ext in image_extensions:
image_files.extend(glob.glob(os.path.join(scene_path, ext)))
image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
if image_files:
# Sort images and get the first one for thumbnail
image_files = sorted(image_files)
first_image = image_files[0]
num_images = len(image_files)
scenes.append(
{
"name": scene_folder,
"path": scene_path,
"thumbnail": first_image,
"num_images": num_images,
"image_files": image_files,
}
)
return scenes
def load_example_scene(scene_name, examples_dir="examples"):
"""Load a scene from examples directory"""
scenes = get_scene_info(examples_dir)
# Find the selected scene
selected_scene = None
for scene in scenes:
if scene["name"] == scene_name:
selected_scene = scene
break
if selected_scene is None:
return None, None, None, "Scene not found"
    # handle_uploads accepts plain file paths, so pass the scene images directly
    file_objects = list(selected_scene["image_files"])
# Create target directory and copy images using the unified upload system
target_dir, image_paths = handle_uploads(file_objects, 1.0)
return (
None, # Clear reconstruction output
target_dir, # Set target directory
image_paths, # Set gallery
f"已加载场景 '{scene_name}'({selected_scene['num_images']} 张图片)。点击「开始重建」进行3D处理。",
)
# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
theme = get_gradio_theme()
# Custom CSS to prevent UI jitter
CUSTOM_CSS = GRADIO_CSS + """
/* Prevent components from stretching the layout */
.gradio-container {
    max-width: 100% !important;
}
/* Fix the Gallery height */
.gallery-container {
    max-height: 350px !important;
    overflow-y: auto !important;
}
/* Fix the File component height */
.file-preview {
    max-height: 200px !important;
    overflow-y: auto !important;
}
/* Fix the Video component height */
.video-container {
    max-height: 300px !important;
}
/* Prevent the Textbox from growing indefinitely */
.textbox-container {
    max-height: 100px !important;
}
/* Keep the Tabs content area stable */
.tab-content {
    min-height: 550px !important;
}
/* Enhanced file-upload area */
.file-upload-enhanced {
    position: relative;
}
/* Reduce spacing between the Accordion and the content below it */
.accordion {
    margin-bottom: 10px !important;
    padding-bottom: 0 !important;
}
/* Compact styling for the example-scenes area */
.accordion > .label-wrap {
    margin-bottom: 5px !important;
}
/* Info box */
.info-box {
    background-color: #E3F2FD !important;
    border-left: 4px solid #2196F3 !important;
    padding: 15px !important;
    margin-bottom: 15px !important;
    border-radius: 4px !important;
}
"""
with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything - 3D Reconstruction System") as demo:
# State variables for the tabbed interface
is_example = gr.Textbox(label="is_example", visible=False, value="None")
num_images = gr.Textbox(label="num_images", visible=False, value="None")
processed_data_state = gr.State(value=None)
measure_points_state = gr.State(value=[])
current_view_index = gr.State(value=0) # Track current view index for navigation
    # JavaScript adding clipboard-paste support
PASTE_JS = """
<script>
    // Add clipboard-paste support
document.addEventListener('paste', function(e) {
const items = e.clipboardData.items;
for (let i = 0; i < items.length; i++) {
if (items[i].type.indexOf('image') !== -1) {
const blob = items[i].getAsFile();
const fileInput = document.querySelector('input[type="file"][multiple]');
if (fileInput) {
const dataTransfer = new DataTransfer();
dataTransfer.items.add(blob);
fileInput.files = dataTransfer.files;
fileInput.dispatchEvent(new Event('change', { bubbles: true }));
                    console.log('✅ Image pasted from clipboard');
}
}
}
});
    console.log('💡 Clipboard paste enabled: press Ctrl+V to paste a screenshot directly');
</script>
"""
gr.HTML(PASTE_JS)
    # Styled page header
    gr.HTML("""
    <div style="text-align: center; margin: 20px 0;">
        <h2 style="color: #1976D2; margin-bottom: 10px;">MapAnything - 3D Reconstruction System</h2>
        <p style="color: #666; font-size: 16px;">Multi-view 3D reconstruction | Depth estimation | Normal computation | Distance measurement</p>
</div>
""")
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
with gr.Row(equal_height=False):
        # Left: input area
with gr.Column(scale=1, min_width=300):
gr.Markdown("### 📤 输入")
# 统一上传组件(支持文件、拖拽、粘贴板)
unified_upload = gr.File(
file_count="multiple",
label="上传视频或图片(支持拖拽、粘贴Ctrl+V📋)",
interactive=True,
file_types=["image", "video"],
)
            # Webcam input (collapsible)
            with gr.Accordion("📷 Take a photo with the webcam", open=False):
camera_input = gr.Image(
label="拍照后自动添加",
sources=["webcam"],
type="filepath",
interactive=True,
)
with gr.Row():
s_time_interval = gr.Slider(
minimum=0.1, maximum=5.0, value=1.0, step=0.1,
label="视频采样时间间隔(每x秒取一帧)",
interactive=True,
visible=True,
scale=3,
)
resample_btn = gr.Button(
"重新采样视频",
visible=False,
variant="secondary",
scale=1,
)
image_gallery = gr.Gallery(
label="图片预览", columns=3, height=350,
show_download_button=True, object_fit="contain", preview=True
)
clear_uploads_btn = gr.ClearButton(
[unified_upload, camera_input, image_gallery],
value="清空上传",
variant="secondary",
size="sm",
)
with gr.Row():
submit_btn = gr.Button("🚀 开始重建", variant="primary", scale=2)
clear_btn = gr.ClearButton(
[unified_upload, camera_input, target_dir_output, image_gallery],
value="🗑️ 清空", scale=1
)
        # Right: output area
with gr.Column(scale=2, min_width=600):
gr.Markdown("### 🎯 输出")
with gr.Tabs():
with gr.Tab("🏗️ 原始3D"):
reconstruction_output = gr.Model3D(
height=550, zoom_speed=0.5, pan_speed=0.5,
clear_color=[0.0, 0.0, 0.0, 0.0]
)
with gr.Tab("📊 深度图"):
with gr.Row(elem_classes=["navigation-row"]):
prev_depth_btn = gr.Button("◀", size="sm", scale=1)
depth_view_selector = gr.Dropdown(
choices=["View 1"], value="View 1",
label="视图", scale=3, interactive=True
)
next_depth_btn = gr.Button("▶", size="sm", scale=1)
depth_map = gr.Image(
type="numpy", label="", format="png", interactive=False,
height=500
)
with gr.Tab("🧭 法线图"):
with gr.Row(elem_classes=["navigation-row"]):
prev_normal_btn = gr.Button("◀", size="sm", scale=1)
normal_view_selector = gr.Dropdown(
choices=["View 1"], value="View 1",
label="视图", scale=3, interactive=True
)
next_normal_btn = gr.Button("▶", size="sm", scale=1)
normal_map = gr.Image(
type="numpy", label="", format="png", interactive=False,
height=500
)
with gr.Tab("📏 测量"):
gr.Markdown("**点击图片两次进行距离测量**")
with gr.Row(elem_classes=["navigation-row"]):
prev_measure_btn = gr.Button("◀", size="sm", scale=1)
measure_view_selector = gr.Dropdown(
choices=["View 1"], value="View 1",
label="视图", scale=3, interactive=True
)
next_measure_btn = gr.Button("▶", size="sm", scale=1)
measure_image = gr.Image(
type="numpy", show_label=False,
format="webp", interactive=False, sources=[],
height=500
)
measure_text = gr.Markdown("")
log_output = gr.Textbox(
value="📌 请上传图片或视频,然后点击「开始重建」",
label="状态信息",
interactive=False,
lines=1,
max_lines=1
)
    # Advanced options (collapsed by default)
    with gr.Accordion("⚙️ Advanced Options", open=False):
with gr.Row(equal_height=False):
with gr.Column(scale=1, min_width=300):
gr.Markdown("#### 可视化参数")
frame_filter = gr.Dropdown(
choices=["All"], value="All", label="显示帧"
)
show_cam = gr.Checkbox(label="显示相机", value=True)
show_mesh = gr.Checkbox(label="显示网格", value=True)
filter_black_bg = gr.Checkbox(label="过滤黑色背景", value=False)
filter_white_bg = gr.Checkbox(label="过滤白色背景", value=False)
with gr.Column(scale=1, min_width=300):
gr.Markdown("#### 重建参数")
apply_mask_checkbox = gr.Checkbox(
label="应用深度掩码", value=True
)
    # Example scenes (collapsible)
    with gr.Accordion("🖼️ Example Scenes", open=False):
        gr.Markdown("Click a thumbnail to load a scene for reconstruction")
scenes = get_scene_info("examples")
if scenes:
for i in range(0, len(scenes), 4): # Process 4 scenes per row
with gr.Row(equal_height=True):
for j in range(4):
scene_idx = i + j
if scene_idx < len(scenes):
scene = scenes[scene_idx]
with gr.Column(scale=1, min_width=150):
scene_img = gr.Image(
value=scene["thumbnail"],
height=150,
interactive=False,
show_label=False,
sources=[],
container=False
)
gr.Markdown(
f"**{scene['name']}** ({scene['num_images']}张)",
elem_classes=["text-center"]
)
scene_img.select(
fn=lambda name=scene["name"]: load_example_scene(name),
outputs=[
reconstruction_output,
target_dir_output,
image_gallery,
log_output,
],
)
    # === Event bindings ===
    # Handle files coming in through the unified upload component
def update_gallery_on_unified_upload(files, interval):
if not files:
return None, None, None
target_dir, image_paths = handle_uploads(files, interval)
return (
target_dir,
image_paths,
"✅ 上传完成,点击「开始重建」进行 3D 处理",
)
    # Handle webcam captures
def update_gallery_on_camera(image):
if image is None:
return None, None, None
        # Wrap the single image in a list
target_dir, image_paths = handle_uploads([image], 1.0)
return (
target_dir,
image_paths,
"✅ 摄像头照片已添加,点击「开始重建」进行 3D 处理",
)
def show_resample_button(files):
"""仅当上传的文件包含视频时显示重新采样按钮"""
if not files:
return gr.update(visible=False)
        # Check whether any uploaded file is a video
video_extensions = [
".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp",
]
has_video = False
for file_data in files:
if isinstance(file_data, dict) and "name" in file_data:
file_path = file_data["name"]
else:
file_path = str(file_data)
file_ext = os.path.splitext(file_path)[1].lower()
if file_ext in video_extensions:
has_video = True
break
return gr.update(visible=has_video)
def resample_video_with_new_interval(files, new_interval, current_target_dir):
"""使用新的滑块值重新采样视频"""
if not files:
return (
current_target_dir,
None,
"没有可重新采样的文件。",
gr.update(visible=False),
)
        # Check whether there is a video to resample
video_extensions = [
".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp",
]
has_video = any(
os.path.splitext(
str(file_data["name"] if isinstance(file_data, dict) else file_data)
)[1].lower()
in video_extensions
for file_data in files
)
if not has_video:
return (
current_target_dir,
None,
"未找到视频进行重新采样。",
gr.update(visible=False),
)
        # Clean up the old target directory
if (
current_target_dir
and current_target_dir != "None"
and os.path.exists(current_target_dir)
):
shutil.rmtree(current_target_dir)
        # Reprocess the files with the new interval
target_dir, image_paths = handle_uploads(files, new_interval)
return (
target_dir,
image_paths,
f"视频已使用 {new_interval}秒 间隔重新采样。点击「开始重建」进行 3D 处理。",
gr.update(visible=False),
)
unified_upload.change(
fn=update_gallery_on_unified_upload,
inputs=[unified_upload, s_time_interval],
outputs=[target_dir_output, image_gallery, log_output]
).then(
fn=show_resample_button,
inputs=[unified_upload],
outputs=[resample_btn],
)
    # Webcam capture event
camera_input.change(
fn=update_gallery_on_camera,
inputs=[camera_input],
outputs=[target_dir_output, image_gallery, log_output]
)
    # Show the resample button when the slider changes (only if files are already uploaded)
s_time_interval.change(
fn=show_resample_button,
inputs=[unified_upload],
outputs=[resample_btn],
)
    # Handle resample-button clicks
resample_btn.click(
fn=resample_video_with_new_interval,
inputs=[unified_upload, s_time_interval, target_dir_output],
outputs=[target_dir_output, image_gallery, log_output, resample_btn],
)
    # Reconstruct button
submit_btn.click(
fn=clear_fields,
outputs=[reconstruction_output]
).then(
fn=update_log,
outputs=[log_output]
).then(
fn=gradio_demo,
inputs=[
target_dir_output, frame_filter, show_cam,
filter_black_bg, filter_white_bg,
apply_mask_checkbox, show_mesh
],
outputs=[
reconstruction_output, log_output, frame_filter,
processed_data_state, depth_map, normal_map, measure_image,
measure_text, depth_view_selector, normal_view_selector, measure_view_selector
]
).then(
fn=lambda: "False",
outputs=[is_example]
)
    # Clear button
clear_btn.add([reconstruction_output, log_output])
    # Live updates for visualization parameters
for component in [frame_filter, show_cam, show_mesh]:
component.change(
fn=update_visualization,
inputs=[
target_dir_output, frame_filter, show_cam, is_example,
filter_black_bg, filter_white_bg, show_mesh
],
outputs=[reconstruction_output, log_output]
)
    # Background filters update all views
for bg_filter in [filter_black_bg, filter_white_bg]:
bg_filter.change(
fn=update_all_views_on_filter_change,
inputs=[
target_dir_output, filter_black_bg, filter_white_bg, processed_data_state,
depth_view_selector, normal_view_selector, measure_view_selector
],
outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state]
)
    # Depth map navigation
prev_depth_btn.click(
fn=lambda pd, cs: navigate_depth_view(pd, cs, -1),
inputs=[processed_data_state, depth_view_selector],
outputs=[depth_view_selector, depth_map]
)
next_depth_btn.click(
fn=lambda pd, cs: navigate_depth_view(pd, cs, 1),
inputs=[processed_data_state, depth_view_selector],
outputs=[depth_view_selector, depth_map]
)
depth_view_selector.change(
fn=lambda pd, sv: update_depth_view(pd, int(sv.split()[1]) - 1) if sv else None,
inputs=[processed_data_state, depth_view_selector],
outputs=[depth_map]
)
    # Normal map navigation
prev_normal_btn.click(
fn=lambda pd, cs: navigate_normal_view(pd, cs, -1),
inputs=[processed_data_state, normal_view_selector],
outputs=[normal_view_selector, normal_map]
)
next_normal_btn.click(
fn=lambda pd, cs: navigate_normal_view(pd, cs, 1),
inputs=[processed_data_state, normal_view_selector],
outputs=[normal_view_selector, normal_map]
)
normal_view_selector.change(
fn=lambda pd, sv: update_normal_view(pd, int(sv.split()[1]) - 1) if sv else None,
inputs=[processed_data_state, normal_view_selector],
outputs=[normal_map]
)
    # Measurement
measure_image.select(
fn=measure,
inputs=[processed_data_state, measure_points_state, measure_view_selector],
outputs=[measure_image, measure_points_state, measure_text]
)
prev_measure_btn.click(
fn=lambda pd, cs: navigate_measure_view(pd, cs, -1),
inputs=[processed_data_state, measure_view_selector],
outputs=[measure_view_selector, measure_image, measure_points_state]
)
next_measure_btn.click(
fn=lambda pd, cs: navigate_measure_view(pd, cs, 1),
inputs=[processed_data_state, measure_view_selector],
outputs=[measure_view_selector, measure_image, measure_points_state]
)
measure_view_selector.change(
fn=lambda pd, sv: update_measure_view(pd, int(sv.split()[1]) - 1) if sv else (None, []),
inputs=[processed_data_state, measure_view_selector],
outputs=[measure_image, measure_points_state]
)
demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)