diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..1b11ea33f9d33b64baf0a497b56b106066b79945 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/**/*.jpg filter=lfs diff=lfs merge=lfs -text
+examples/**/*.jpeg filter=lfs diff=lfs merge=lfs -text
+examples/**/*.png filter=lfs diff=lfs merge=lfs -text
+examples/**/*.bmp filter=lfs diff=lfs merge=lfs -text
+examples/**/*.tiff filter=lfs diff=lfs merge=lfs -text
+examples/**/*.tif filter=lfs diff=lfs merge=lfs -text
+examples/* filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..2099be4577a2bec84191d0a74064332812f9a1fa
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,276 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be added to the global gitignore or merged into this project gitignore. For a PyCharm
+# project, it is recommended to ignore the entire .idea directory, or at least the following:
+# .idea/workspace.xml
+# .idea/tasks.xml
+# .idea/usage.statistics.xml
+# .idea/dictionaries
+# .idea/shelf
+
+# VS Code
+.vscode/
+*.code-workspace
+
+# Local History for Visual Studio Code
+.history/
+
+# Built Visual Studio Code Extensions
+*.vsix
+
+# Hugging Face specific
+# Model files (usually large binary files)
+*.bin
+*.safetensors
+*.h5
+*.onnx
+*.pkl
+*.pth
+*.pt
+*.ckpt
+*.pb
+*.tflite
+*.mlmodel
+
+# Hugging Face cache and tokens
+.cache/
+cache/
+**/cache/
+hf_token*
+.huggingface/
+transformers_cache/
+datasets_cache/
+input_images_*
+
+# Gradio temporary files
+gradio_cached_examples/
+flagged/
+
+# Data directories
+data/
+checkpoints/
+outputs/
+results/
+logs/
+tmp/
+temp/
+# examples/*/
+# /examples*.jpg
+# *.png
+# *.jpeg
+# examples/
+
+# OS generated files
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
+desktop.ini
+
+# Backup files
+*.bak
+*.swp
+*.swo
+*~
+
+# Compressed files
+*.7z
+*.dmg
+*.gz
+*.iso
+*.jar
+*.rar
+*.tar
+*.zip
+
+# IDE and editor files
+.idea/
+*.sublime-project
+*.sublime-workspace
+.vscode/settings.json
+.vscode/tasks.json
+.vscode/launch.json
+.vscode/extensions.json
+
+# Node modules (if any frontend components)
+node_modules/
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+
+# Docker
+Dockerfile*
+docker-compose*
+.dockerignore
+
+# MLOps and experiment tracking
+wandb/
+.neptune/
+mlruns/
+.mlflow/
+tensorboard_logs/
+
+# Secrets and configuration
+*.secret
+*.key
+config.ini
+.env.local
+.env.*.local
+secrets.json
diff --git a/.gradio/certificate.pem b/.gradio/certificate.pem
new file mode 100644
index 0000000000000000000000000000000000000000..b85c8037f6b60976b2546fdbae88312c5246d9a3
--- /dev/null
+++ b/.gradio/certificate.pem
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
diff --git a/README.md b/README.md
index abc5c5f3d767acdc8548d48cc9b9fc328735ead7..fa42a21854aec982038d3554f04a6bc42805e4cf 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,13 @@
---
-title: Colaman Segmap
-emoji: 📊
-colorFrom: pink
-colorTo: purple
+title: Mapanything Gradio
+emoji: 🐠
+colorFrom: purple
+colorTo: green
sdk: gradio
-sdk_version: 5.49.1
+sdk_version: 5.44.1
app_file: app.py
pinned: false
+license: apache-2.0
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c6f99808af5ec2bc60747c8de2c116c3c6d011d
--- /dev/null
+++ b/app.py
@@ -0,0 +1,1985 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+MapAnything V2: 3D Reconstruction with Object Segmentation
+- Multi-view 3D reconstruction
+- GroundingDINO object detection
+- SAM precise segmentation
+- DBSCAN clustering for cross-view object matching
+"""
+
+import gc
+import os
+import shutil
+import sys
+import time
+from datetime import datetime
+from pathlib import Path
+from collections import defaultdict
+
+os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+
+import cv2
+import gradio as gr
+import numpy as np
+import spaces
+import torch
+import trimesh
+from PIL import Image
+from pillow_heif import register_heif_opener
+from sklearn.cluster import DBSCAN
+
+register_heif_opener()
+
+sys.path.append("mapanything/")
+
+from mapanything.utils.geometry import depthmap_to_world_frame, points_to_normals
+from mapanything.utils.hf_utils.css_and_html import (
+ GRADIO_CSS,
+ MEASURE_INSTRUCTIONS_HTML,
+ get_acknowledgements_html,
+ get_description_html,
+ get_gradio_theme,
+ get_header_html,
+)
+from mapanything.utils.hf_utils.hf_helpers import initialize_mapanything_model
+from mapanything.utils.hf_utils.visual_util import predictions_to_glb
+from mapanything.utils.image import load_images, rgb
+
+
+def get_logo_base64():
+ """Convert WAI logo to base64 for embedding in HTML"""
+ import base64
+
+ logo_path = "examples/WAI-Logo/wai_logo.png"
+ try:
+ with open(logo_path, "rb") as img_file:
+ img_data = img_file.read()
+ base64_str = base64.b64encode(img_data).decode()
+ return f"data:image/png;base64,{base64_str}"
+ except FileNotFoundError:
+ return None
+
+
+# ============================================================================
+# Configuration
+# ============================================================================
+
+# MapAnything Configuration
+high_level_config = {
+ "path": "configs/train.yaml",
+ "hf_model_name": "facebook/map-anything",
+ "model_str": "mapanything",
+ "config_overrides": [
+ "machine=aws",
+ "model=mapanything",
+ "model/task=images_only",
+ "model.encoder.uses_torch_hub=false",
+ ],
+ "checkpoint_name": "model.safetensors",
+ "config_name": "config.json",
+ "trained_with_amp": True,
+ "trained_with_amp_dtype": "bf16",
+ "data_norm_type": "dinov2",
+ "patch_size": 14,
+ "resolution": 518,
+}
+
+# GroundingDINO and SAM Configuration
+GROUNDING_DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
+GROUNDING_DINO_BOX_THRESHOLD = 0.25
+GROUNDING_DINO_TEXT_THRESHOLD = 0.2
+
+SAM_MODEL_ID = "facebook/sam-vit-huge"
+
+DEFAULT_TEXT_PROMPT = "chair . table . sofa . bed . desk . cabinet"
+
+# Common objects prompt for detection
+COMMON_OBJECTS_PROMPT = (
+ "person . face . hand . "
+ "chair . sofa . couch . bed . table . desk . cabinet . shelf . drawer . "
+ "door . window . wall . floor . ceiling . curtain . "
+ "tv . monitor . screen . computer . laptop . keyboard . mouse . "
+ "phone . tablet . remote . "
+ "lamp . light . chandelier . "
+ "book . magazine . paper . pen . pencil . "
+ "bottle . cup . glass . mug . plate . bowl . fork . knife . spoon . "
+ "vase . plant . flower . pot . "
+ "clock . picture . frame . mirror . "
+ "pillow . cushion . blanket . towel . "
+ "bag . backpack . suitcase . "
+ "box . basket . container . "
+ "shoe . hat . coat . "
+ "toy . ball . "
+ "car . bicycle . motorcycle . bus . truck . "
+ "tree . grass . sky . cloud . sun . "
+ "dog . cat . bird . "
+ "building . house . bridge . road . street . "
+ "sign . pole . bench"
+)
+
+# DBSCAN clustering configuration (eps in meters)
+DBSCAN_EPS_CONFIG = {
+ 'sofa': 1.5,
+ 'bed': 1.5,
+ 'couch': 1.5,
+ 'desk': 0.8,
+ 'table': 0.8,
+ 'chair': 0.6,
+ 'cabinet': 0.8,
+ 'window': 0.5,
+ 'door': 0.6,
+ 'tv': 0.6,
+ 'default': 1.0
+}
+
+DBSCAN_MIN_SAMPLES = 1
+
+# Quality control
+MIN_DETECTION_CONFIDENCE = 0.35
+MIN_MASK_AREA = 100
+
+# Global model variables
+model = None
+grounding_dino_model = None
+grounding_dino_processor = None
+sam_predictor = None
+
+
+# ============================================================================
+# Model Loading Functions
+# ============================================================================
+
+def load_grounding_dino_model(device):
+ """Load GroundingDINO model from HuggingFace"""
+ global grounding_dino_model, grounding_dino_processor
+
+ if grounding_dino_model is not None:
+ print("✅ GroundingDINO already loaded")
+ return
+
+ try:
+ from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
+
+ print(f"📥 Loading GroundingDINO from HuggingFace: {GROUNDING_DINO_MODEL_ID}")
+ grounding_dino_processor = AutoProcessor.from_pretrained(GROUNDING_DINO_MODEL_ID)
+ grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
+ GROUNDING_DINO_MODEL_ID
+ ).to(device).eval()
+
+ print("✅ GroundingDINO loaded successfully")
+
+ except Exception as e:
+ print(f"❌ GroundingDINO loading failed: {e}")
+ import traceback
+ traceback.print_exc()
+
+
+def load_sam_model(device):
+ """Load SAM model from HuggingFace"""
+ global sam_predictor
+
+ if sam_predictor is not None:
+ print("✅ SAM already loaded")
+ return
+
+ try:
+ from transformers import SamModel, SamProcessor
+
+ print(f"📥 Loading SAM from HuggingFace: {SAM_MODEL_ID}")
+ sam_model = SamModel.from_pretrained(SAM_MODEL_ID).to(device).eval()
+ sam_processor = SamProcessor.from_pretrained(SAM_MODEL_ID)
+
+ # Wrap in a predictor-like interface
+ class SAMPredictor:
+ def __init__(self, model, processor, device):
+ self.model = model
+ self.processor = processor
+ self.device = device
+ self.image = None
+
+ def set_image(self, image):
+ """Set image for prediction"""
+ if image.dtype == np.uint8:
+ self.image = Image.fromarray(image)
+ else:
+ self.image = Image.fromarray((image * 255).astype(np.uint8))
+
+ def predict(self, box, multimask_output=False):
+ """Predict mask from box"""
+ inputs = self.processor(
+ self.image,
+ input_boxes=[[[box]]],
+ return_tensors="pt"
+ ).to(self.device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ masks = self.processor.image_processor.post_process_masks(
+ outputs.pred_masks.cpu(),
+ inputs["original_sizes"].cpu(),
+ inputs["reshaped_input_sizes"].cpu()
+ )[0].squeeze().numpy()
+
+ if len(masks.shape) == 2:
+ masks = masks[np.newaxis, ...]
+
+ return masks, None, None
+
+ sam_predictor = SAMPredictor(sam_model, sam_processor, device)
+ print("✅ SAM loaded successfully")
+
+ except Exception as e:
+ print(f"❌ SAM loading failed: {e}")
+ print(" Falling back to bbox-based masks")
+ import traceback
+ traceback.print_exc()
+
+
+# ============================================================================
+# Segmentation Functions
+# ============================================================================
+
+def generate_distinct_colors(n):
+ """Generate N visually distinct colors (RGB, 0-255)"""
+ import colorsys
+ if n == 0:
+ return []
+
+ colors = []
+ for i in range(n):
+ hue = i / max(n, 1)
+ rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95)
+ rgb_color = tuple(int(c * 255) for c in rgb)
+ colors.append(rgb_color)
+
+ return colors
+
+
+def run_grounding_dino_detection(image_np, text_prompt, device):
+ """Run GroundingDINO detection"""
+ if grounding_dino_model is None or grounding_dino_processor is None:
+ print("⚠️ GroundingDINO not loaded")
+ return []
+
+ try:
+ print(f"🔍 GroundingDINO detection: {text_prompt}")
+
+ # Convert to PIL Image
+ if image_np.dtype == np.uint8:
+ pil_image = Image.fromarray(image_np)
+ else:
+ pil_image = Image.fromarray((image_np * 255).astype(np.uint8))
+
+ # Preprocess
+ inputs = grounding_dino_processor(images=pil_image, text=text_prompt, return_tensors="pt")
+ inputs = {k: v.to(device) for k, v in inputs.items()}
+
+ # Inference
+ with torch.no_grad():
+ outputs = grounding_dino_model(**inputs)
+
+ # Post-process
+ results = grounding_dino_processor.post_process_grounded_object_detection(
+ outputs,
+ inputs["input_ids"],
+ threshold=GROUNDING_DINO_BOX_THRESHOLD,
+ text_threshold=GROUNDING_DINO_TEXT_THRESHOLD,
+ target_sizes=[pil_image.size[::-1]]
+ )[0]
+
+ # Convert to unified format
+ detections = []
+ boxes = results["boxes"].cpu().numpy()
+ scores = results["scores"].cpu().numpy()
+ labels = results["labels"]
+
+ print(f"✅ Detected {len(boxes)} objects")
+
+ for box, score, label in zip(boxes, scores, labels):
+ detection = {
+ 'bbox': box.tolist(), # [x1, y1, x2, y2]
+ 'label': label,
+ 'confidence': float(score)
+ }
+ detections.append(detection)
+ print(f" - {label}: {score:.2f}")
+
+ return detections
+
+ except Exception as e:
+ print(f"❌ GroundingDINO detection failed: {e}")
+ import traceback
+ traceback.print_exc()
+ return []
+
+
+def run_sam_refinement(image_np, boxes):
+ """Run SAM precise segmentation"""
+ if sam_predictor is None:
+ print("⚠️ SAM not loaded, using bbox as mask")
+ # Use bbox to create simple rectangular mask
+ masks = []
+ h, w = image_np.shape[:2]
+ for box in boxes:
+ x1, y1, x2, y2 = map(int, box)
+ mask = np.zeros((h, w), dtype=bool)
+ mask[y1:y2, x1:x2] = True
+ masks.append(mask)
+ return masks
+
+ try:
+ print(f"🎯 SAM precise segmentation for {len(boxes)} regions...")
+ sam_predictor.set_image(image_np)
+
+ masks = []
+ for box in boxes:
+ x1, y1, x2, y2 = map(int, box)
+ box_array = np.array([x1, y1, x2, y2])
+
+ mask_output, _, _ = sam_predictor.predict(
+ box=box_array,
+ multimask_output=False
+ )
+ masks.append(mask_output[0])
+
+ print(f"✅ SAM segmentation complete")
+ return masks
+
+ except Exception as e:
+ print(f"❌ SAM segmentation failed: {e}")
+ # Fallback to bbox masks
+ masks = []
+ h, w = image_np.shape[:2]
+ for box in boxes:
+ x1, y1, x2, y2 = map(int, box)
+ mask = np.zeros((h, w), dtype=bool)
+ mask[y1:y2, x1:x2] = True
+ masks.append(mask)
+ return masks
+
+
+def normalize_label(label):
+ """Normalize label to main category"""
+ label = label.strip().lower()
+
+ priority_labels = ['sofa', 'bed', 'table', 'desk', 'chair', 'cabinet', 'window', 'door']
+
+ for priority in priority_labels:
+ if priority in label:
+ return priority
+
+ first_word = label.split()[0] if label else label
+
+ # Handle plural forms
+ if first_word.endswith('s') and len(first_word) > 1:
+ singular = first_word[:-1]
+ if first_word.endswith('sses'):
+ singular = first_word[:-2]
+ elif first_word.endswith('ies'):
+ singular = first_word[:-3] + 'y'
+ elif first_word.endswith('ves'):
+ singular = first_word[:-3] + 'f'
+ return singular
+
+ return first_word
+
+
+def compute_object_3d_center(points, mask):
+ """Compute 3D center of object"""
+ masked_points = points[mask]
+ if len(masked_points) == 0:
+ return None
+ return np.median(masked_points, axis=0)
+
+
+def compute_adaptive_eps(centers, base_eps):
+ """Adaptively compute eps value based on object distribution"""
+ if len(centers) <= 1:
+ return base_eps
+
+ from scipy.spatial.distance import pdist
+ distances = pdist(centers)
+
+ if len(distances) == 0:
+ return base_eps
+
+ median_dist = np.median(distances)
+
+ if median_dist > base_eps * 2:
+ adaptive_eps = min(median_dist * 0.6, base_eps * 2.5)
+ elif median_dist > base_eps:
+ adaptive_eps = median_dist * 0.5
+ else:
+ adaptive_eps = base_eps
+
+ return adaptive_eps
+
+
+def match_objects_across_views(all_view_detections):
+ """Match objects across views using DBSCAN clustering"""
+ print("\n🔗 Matching objects across views using DBSCAN clustering...")
+
+ objects_by_label = defaultdict(list)
+
+ for view_idx, detections in enumerate(all_view_detections):
+ for det_idx, det in enumerate(detections):
+ if det.get('center_3d') is None:
+ continue
+
+ norm_label = normalize_label(det['label'])
+ objects_by_label[norm_label].append({
+ 'view_idx': view_idx,
+ 'det_idx': det_idx,
+ 'label': det['label'],
+ 'norm_label': norm_label,
+ 'center_3d': det['center_3d'],
+ 'confidence': det['confidence'],
+ })
+
+ if len(objects_by_label) == 0:
+ return {}, []
+
+ object_id_map = defaultdict(dict)
+ unique_objects = []
+ next_global_id = 0
+
+ for norm_label, objects in objects_by_label.items():
+ print(f"\n 📦 Processing {norm_label}: {len(objects)} detections")
+
+ if len(objects) == 1:
+ obj = objects[0]
+ unique_objects.append({
+ 'global_id': next_global_id,
+ 'label': obj['label'],
+ 'views': [(obj['view_idx'], obj['det_idx'])],
+ 'center_3d': obj['center_3d'],
+ })
+ object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id
+ next_global_id += 1
+ print(f" → 1 cluster (single detection)")
+ continue
+
+ centers = np.array([obj['center_3d'] for obj in objects])
+
+ base_eps = DBSCAN_EPS_CONFIG.get(norm_label, DBSCAN_EPS_CONFIG.get('default', 1.0))
+ eps = compute_adaptive_eps(centers, base_eps)
+
+ clustering = DBSCAN(eps=eps, min_samples=DBSCAN_MIN_SAMPLES, metric='euclidean')
+ cluster_labels = clustering.fit_predict(centers)
+
+ n_clusters = len(set(cluster_labels)) - (1 if -1 in cluster_labels else 0)
+ n_noise = list(cluster_labels).count(-1)
+
+ if eps != base_eps:
+ print(f" → {n_clusters} clusters (base_eps={base_eps}m → adaptive_eps={eps:.2f}m)")
+ else:
+ print(f" → {n_clusters} clusters (eps={eps}m)")
+ if n_noise > 0:
+ print(f" ⚠️ {n_noise} noise points (isolated detections)")
+
+ for cluster_id in set(cluster_labels):
+ if cluster_id == -1:
+ for i, label in enumerate(cluster_labels):
+ if label == -1:
+ obj = objects[i]
+ unique_objects.append({
+ 'global_id': next_global_id,
+ 'label': obj['label'],
+ 'views': [(obj['view_idx'], obj['det_idx'])],
+ 'center_3d': obj['center_3d'],
+ })
+ object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id
+ next_global_id += 1
+ else:
+ cluster_objects = [objects[i] for i, label in enumerate(cluster_labels) if label == cluster_id]
+
+ total_conf = sum(o['confidence'] for o in cluster_objects)
+ weighted_center = sum(o['center_3d'] * o['confidence'] for o in cluster_objects) / total_conf
+
+ unique_objects.append({
+ 'global_id': next_global_id,
+ 'label': cluster_objects[0]['label'],
+ 'views': [(o['view_idx'], o['det_idx']) for o in cluster_objects],
+ 'center_3d': weighted_center,
+ })
+
+ for obj in cluster_objects:
+ object_id_map[obj['view_idx']][obj['det_idx']] = next_global_id
+
+ next_global_id += 1
+
+ print(f"\n 📊 Summary:")
+ print(f" Total detections: {sum(len(objs) for objs in objects_by_label.values())}")
+ print(f" Unique objects: {len(unique_objects)}")
+
+ return object_id_map, unique_objects
+
+
+def create_multi_view_segmented_mesh(processed_data, all_view_detections, all_view_masks,
+ object_id_map, unique_objects, target_dir):
+ """Create multi-view fused segmented mesh"""
+ try:
+ print("\n🎨 Generating multi-view segmented mesh...")
+
+ unique_normalized_labels = sorted(set(normalize_label(obj['label']) for obj in unique_objects))
+ label_colors = {}
+ colors = generate_distinct_colors(len(unique_normalized_labels))
+
+ for i, norm_label in enumerate(unique_normalized_labels):
+ label_colors[norm_label] = colors[i]
+
+ for obj in unique_objects:
+ norm_label = normalize_label(obj['label'])
+ obj['color'] = label_colors[norm_label]
+ obj['normalized_label'] = norm_label
+
+ print(f" Object category color mapping:")
+ for norm_label, color in sorted(label_colors.items()):
+ count = sum(1 for obj in unique_objects if normalize_label(obj['label']) == norm_label)
+ print(f" {norm_label} × {count} → RGB{color}")
+
+ import utils3d
+
+ all_meshes = []
+
+ for view_idx in range(len(processed_data)):
+ view_data = processed_data[view_idx]
+ image = view_data["image"]
+ points3d = view_data["points3d"]
+ mask = view_data.get("mask")
+ normal = view_data.get("normal")
+
+ detections = all_view_detections[view_idx]
+ masks = all_view_masks[view_idx]
+
+ if len(detections) == 0:
+ continue
+
+ if image.dtype != np.uint8:
+ if image.max() <= 1.0:
+ image = (image * 255).astype(np.uint8)
+ else:
+ image = image.astype(np.uint8)
+
+ colored_image = image.copy()
+ confidence_map = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32)
+
+ detections_info = []
+ filtered_count = 0
+ for det_idx, (det, seg_mask) in enumerate(zip(detections, masks)):
+ if det['confidence'] < MIN_DETECTION_CONFIDENCE:
+ filtered_count += 1
+ continue
+
+ mask_area = seg_mask.sum()
+ if mask_area < MIN_MASK_AREA:
+ filtered_count += 1
+ continue
+
+ global_id = object_id_map[view_idx].get(det_idx)
+ if global_id is None:
+ continue
+
+ unique_obj = next((obj for obj in unique_objects if obj['global_id'] == global_id), None)
+ if unique_obj is None:
+ continue
+
+ detections_info.append({
+ 'mask': seg_mask,
+ 'color': unique_obj['color'],
+ 'confidence': det['confidence'],
+ })
+
+ if filtered_count > 0:
+ print(f" View {view_idx + 1}: filtered {filtered_count} low-quality detections")
+
+ detections_info.sort(key=lambda x: x['confidence'])
+
+ for info in detections_info:
+ seg_mask = info['mask']
+ color = info['color']
+ conf = info['confidence']
+
+ update_mask = seg_mask & (conf > confidence_map)
+ colored_image[update_mask] = color
+ confidence_map[update_mask] = conf
+
+ height, width = image.shape[:2]
+
+ if normal is None:
+ faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh(
+ points3d,
+ colored_image.astype(np.float32) / 255,
+ utils3d.numpy.image_uv(width=width, height=height),
+ mask=mask if mask is not None else np.ones((height, width), dtype=bool),
+ tri=True
+ )
+ vertex_normals = None
+ else:
+ faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh(
+ points3d,
+ colored_image.astype(np.float32) / 255,
+ utils3d.numpy.image_uv(width=width, height=height),
+ normal,
+ mask=mask if mask is not None else np.ones((height, width), dtype=bool),
+ tri=True
+ )
+
+ vertices = vertices * np.array([1, -1, -1], dtype=np.float32)
+ if vertex_normals is not None:
+ vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32)
+
+ view_mesh = trimesh.Trimesh(
+ vertices=vertices,
+ faces=faces,
+ vertex_normals=vertex_normals,
+ vertex_colors=(vertex_colors * 255).astype(np.uint8),
+ process=False
+ )
+
+ all_meshes.append(view_mesh)
+ print(f" View {view_idx + 1}: {len(vertices):,} vertices, {len(faces):,} faces")
+
+ if len(all_meshes) == 0:
+ print("⚠️ No mesh generated")
+ return None
+
+ print(" Fusing all views...")
+ combined_mesh = trimesh.util.concatenate(all_meshes)
+
+ glb_path = os.path.join(target_dir, 'segmented_mesh.glb')
+ combined_mesh.export(glb_path)
+
+ print(f"✅ Multi-view segmented mesh saved: {glb_path}")
+ print(f" Total: {len(combined_mesh.vertices):,} vertices, {len(combined_mesh.faces):,} faces")
+ print(f" {len(unique_objects)} unique objects")
+
+ return glb_path
+
+ except Exception as e:
+ print(f"❌ Failed to generate multi-view mesh: {e}")
+ import traceback
+ traceback.print_exc()
+ return None
+
+
+# ============================================================================
+# Core Model Inference
+# ============================================================================
+
+@spaces.GPU(duration=120)
+def run_model(
+ target_dir,
+ apply_mask=True,
+ mask_edges=True,
+ filter_black_bg=False,
+ filter_white_bg=False,
+ enable_segmentation=False,
+ text_prompt=DEFAULT_TEXT_PROMPT,
+):
+ """
+ Run the MapAnything model + optional segmentation
+ """
+ global model
+ import torch
+
+ print(f"Processing images from {target_dir}")
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Initialize MapAnything model
+ if model is None:
+ model = initialize_mapanything_model(high_level_config, device)
+ else:
+ model = model.to(device)
+
+ model.eval()
+
+ # Load segmentation models if enabled
+ if enable_segmentation:
+ load_grounding_dino_model(device)
+ load_sam_model(device)
+
+ # Load images
+ print("Loading images...")
+ image_folder_path = os.path.join(target_dir, "images")
+ views = load_images(image_folder_path)
+
+ print(f"Loaded {len(views)} images")
+ if len(views) == 0:
+ raise ValueError("No images found. Check your upload.")
+
+ # Run model inference
+ print("Running inference...")
+ outputs = model.infer(
+ views, apply_mask=apply_mask, mask_edges=True, memory_efficient_inference=False
+ )
+
+ # Convert predictions
+ predictions = {}
+ extrinsic_list = []
+ intrinsic_list = []
+ world_points_list = []
+ depth_maps_list = []
+ images_list = []
+ final_mask_list = []
+
+ for pred in outputs:
+ depthmap_torch = pred["depth_z"][0].squeeze(-1)
+ intrinsics_torch = pred["intrinsics"][0]
+ camera_pose_torch = pred["camera_poses"][0]
+
+ pts3d_computed, valid_mask = depthmap_to_world_frame(
+ depthmap_torch, intrinsics_torch, camera_pose_torch
+ )
+
+ if "mask" in pred:
+ mask = pred["mask"][0].squeeze(-1).cpu().numpy().astype(bool)
+ else:
+ mask = np.ones_like(depthmap_torch.cpu().numpy(), dtype=bool)
+
+ mask = mask & valid_mask.cpu().numpy()
+ image = pred["img_no_norm"][0].cpu().numpy()
+
+ extrinsic_list.append(camera_pose_torch.cpu().numpy())
+ intrinsic_list.append(intrinsics_torch.cpu().numpy())
+ world_points_list.append(pts3d_computed.cpu().numpy())
+ depth_maps_list.append(depthmap_torch.cpu().numpy())
+ images_list.append(image)
+ final_mask_list.append(mask)
+
+ predictions["extrinsic"] = np.stack(extrinsic_list, axis=0)
+ predictions["intrinsic"] = np.stack(intrinsic_list, axis=0)
+ predictions["world_points"] = np.stack(world_points_list, axis=0)
+
+ depth_maps = np.stack(depth_maps_list, axis=0)
+ if len(depth_maps.shape) == 3:
+ depth_maps = depth_maps[..., np.newaxis]
+ predictions["depth"] = depth_maps
+
+ predictions["images"] = np.stack(images_list, axis=0)
+ predictions["final_mask"] = np.stack(final_mask_list, axis=0)
+
+ # Process visualization data
+ processed_data = process_predictions_for_visualization(
+ predictions, views, high_level_config, filter_black_bg, filter_white_bg
+ )
+
+ # Segmentation processing
+ segmented_glb = None
+ if enable_segmentation and grounding_dino_model is not None:
+ print("\n🎯 Starting segmentation...")
+ print(f"🔍 Detection prompt: {text_prompt[:100]}...")
+
+ all_view_detections = []
+ all_view_masks = []
+
+ for view_idx, ref_image in enumerate(images_list):
+ print(f"\n📸 Processing view {view_idx + 1}/{len(images_list)}...")
+
+ if ref_image.dtype != np.uint8:
+ ref_image_np = (ref_image * 255).astype(np.uint8)
+ else:
+ ref_image_np = ref_image
+
+ detections = run_grounding_dino_detection(ref_image_np, text_prompt, device)
+
+ if len(detections) > 0:
+ boxes = [d['bbox'] for d in detections]
+ masks = run_sam_refinement(ref_image_np, boxes)
+
+ points3d = world_points_list[view_idx]
+
+ for det_idx, (det, mask) in enumerate(zip(detections, masks)):
+ center_3d = compute_object_3d_center(points3d, mask)
+ det['center_3d'] = center_3d
+ det['mask_2d'] = mask
+
+ all_view_detections.append(detections)
+ all_view_masks.append(masks)
+ else:
+ all_view_detections.append([])
+ all_view_masks.append([])
+
+ # Match objects across views
+ if any(len(dets) > 0 for dets in all_view_detections):
+ object_id_map, unique_objects = match_objects_across_views(all_view_detections)
+
+ # Generate segmented mesh
+ segmented_glb = create_multi_view_segmented_mesh(
+ processed_data, all_view_detections, all_view_masks,
+ object_id_map, unique_objects, target_dir
+ )
+
+ # Cleanup
+ torch.cuda.empty_cache()
+
+ return predictions, processed_data, segmented_glb
+
+
+# ============================================================================
+# Helper Functions (from app.py)
+# ============================================================================
+
+def update_view_selectors(processed_data):
+ """Update view selector dropdowns based on available views"""
+ if processed_data is None or len(processed_data) == 0:
+ choices = ["View 1"]
+ else:
+ num_views = len(processed_data)
+ choices = [f"View {i + 1}" for i in range(num_views)]
+
+ return (
+ gr.Dropdown(choices=choices, value=choices[0]),
+ gr.Dropdown(choices=choices, value=choices[0]),
+ gr.Dropdown(choices=choices, value=choices[0]),
+ )
+
+
+def get_view_data_by_index(processed_data, view_index):
+ """Get view data by index, handling bounds"""
+ if processed_data is None or len(processed_data) == 0:
+ return None
+
+ view_keys = list(processed_data.keys())
+ if view_index < 0 or view_index >= len(view_keys):
+ view_index = 0
+
+ return processed_data[view_keys[view_index]]
+
+
+def update_depth_view(processed_data, view_index):
+ """Update depth view for a specific view index"""
+ view_data = get_view_data_by_index(processed_data, view_index)
+ if view_data is None or view_data["depth"] is None:
+ return None
+
+ return colorize_depth(view_data["depth"], mask=view_data.get("mask"))
+
+
+def update_normal_view(processed_data, view_index):
+ """Update normal view for a specific view index"""
+ view_data = get_view_data_by_index(processed_data, view_index)
+ if view_data is None or view_data["normal"] is None:
+ return None
+
+ return colorize_normal(view_data["normal"], mask=view_data.get("mask"))
+
+
+def update_measure_view(processed_data, view_index):
+ """Update measure view for a specific view index with mask overlay"""
+ view_data = get_view_data_by_index(processed_data, view_index)
+ if view_data is None:
+ return None, []
+
+ image = view_data["image"].copy()
+
+ if image.dtype != np.uint8:
+ if image.max() <= 1.0:
+ image = (image * 255).astype(np.uint8)
+ else:
+ image = image.astype(np.uint8)
+
+ if view_data["mask"] is not None:
+ mask = view_data["mask"]
+ invalid_mask = ~mask
+
+ if invalid_mask.any():
+ overlay_color = np.array([255, 220, 220], dtype=np.uint8)
+ alpha = 0.5
+ for c in range(3):
+ image[:, :, c] = np.where(
+ invalid_mask,
+ (1 - alpha) * image[:, :, c] + alpha * overlay_color[c],
+ image[:, :, c],
+ ).astype(np.uint8)
+
+ return image, []
+
+
+def navigate_depth_view(processed_data, current_selector_value, direction):
+ """Navigate depth view (direction: -1 for previous, +1 for next)"""
+ if processed_data is None or len(processed_data) == 0:
+ return "View 1", None
+
+ try:
+ current_view = int(current_selector_value.split()[1]) - 1
+ except:
+ current_view = 0
+
+ num_views = len(processed_data)
+ new_view = (current_view + direction) % num_views
+
+ new_selector_value = f"View {new_view + 1}"
+ depth_vis = update_depth_view(processed_data, new_view)
+
+ return new_selector_value, depth_vis
+
+
+def navigate_normal_view(processed_data, current_selector_value, direction):
+ """Navigate normal view (direction: -1 for previous, +1 for next)"""
+ if processed_data is None or len(processed_data) == 0:
+ return "View 1", None
+
+ try:
+ current_view = int(current_selector_value.split()[1]) - 1
+ except:
+ current_view = 0
+
+ num_views = len(processed_data)
+ new_view = (current_view + direction) % num_views
+
+ new_selector_value = f"View {new_view + 1}"
+ normal_vis = update_normal_view(processed_data, new_view)
+
+ return new_selector_value, normal_vis
+
+
+def navigate_measure_view(processed_data, current_selector_value, direction):
+ """Navigate measure view (direction: -1 for previous, +1 for next)"""
+ if processed_data is None or len(processed_data) == 0:
+ return "View 1", None, []
+
+ try:
+ current_view = int(current_selector_value.split()[1]) - 1
+ except:
+ current_view = 0
+
+ num_views = len(processed_data)
+ new_view = (current_view + direction) % num_views
+
+ new_selector_value = f"View {new_view + 1}"
+ measure_image, measure_points = update_measure_view(processed_data, new_view)
+
+ return new_selector_value, measure_image, measure_points
+
+
+def populate_visualization_tabs(processed_data):
+ """Populate the depth, normal, and measure tabs with processed data"""
+ if processed_data is None or len(processed_data) == 0:
+ return None, None, None, []
+
+ depth_vis = update_depth_view(processed_data, 0)
+ normal_vis = update_normal_view(processed_data, 0)
+ measure_img, _ = update_measure_view(processed_data, 0)
+
+ return depth_vis, normal_vis, measure_img, []
+
+
+def handle_uploads(unified_upload, s_time_interval=1.0):
+ """
+ Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
+ images or extracted frames from video into it. Return (target_dir, image_paths).
+ """
+ start_time = time.time()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+ target_dir = f"input_images_{timestamp}"
+ target_dir_images = os.path.join(target_dir, "images")
+
+ if os.path.exists(target_dir):
+ shutil.rmtree(target_dir)
+ os.makedirs(target_dir)
+ os.makedirs(target_dir_images)
+
+ image_paths = []
+
+ if unified_upload is not None:
+ for file_data in unified_upload:
+ if isinstance(file_data, dict) and "name" in file_data:
+ file_path = file_data["name"]
+ else:
+ file_path = str(file_data)
+
+ file_ext = os.path.splitext(file_path)[1].lower()
+
+ video_extensions = [
+ ".mp4", ".avi", ".mov", ".mkv", ".wmv", ".flv", ".webm", ".m4v", ".3gp",
+ ]
+ if file_ext in video_extensions:
+ vs = cv2.VideoCapture(file_path)
+ fps = vs.get(cv2.CAP_PROP_FPS)
+ frame_interval = int(fps * s_time_interval)
+
+ count = 0
+ video_frame_num = 0
+ while True:
+ gotit, frame = vs.read()
+ if not gotit:
+ break
+ count += 1
+ if count % frame_interval == 0:
+ base_name = os.path.splitext(os.path.basename(file_path))[0]
+ image_path = os.path.join(
+ target_dir_images, f"{base_name}_{video_frame_num:06}.png"
+ )
+ cv2.imwrite(image_path, frame)
+ image_paths.append(image_path)
+ video_frame_num += 1
+ vs.release()
+ print(f"Extracted {video_frame_num} frames from video: {os.path.basename(file_path)}")
+
+ else:
+ if file_ext in [".heic", ".heif"]:
+ try:
+ with Image.open(file_path) as img:
+ if img.mode not in ("RGB", "L"):
+ img = img.convert("RGB")
+
+ base_name = os.path.splitext(os.path.basename(file_path))[0]
+ dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
+
+ img.save(dst_path, "JPEG", quality=95)
+ image_paths.append(dst_path)
+ print(f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> {os.path.basename(dst_path)}")
+ except Exception as e:
+ print(f"Error converting HEIC file {file_path}: {e}")
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
+ shutil.copy(file_path, dst_path)
+ image_paths.append(dst_path)
+ else:
+ dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
+ shutil.copy(file_path, dst_path)
+ image_paths.append(dst_path)
+
+ image_paths = sorted(image_paths)
+
+ end_time = time.time()
+ print(f"Files processed to {target_dir_images}; took {end_time - start_time:.3f} seconds")
+ return target_dir, image_paths
+
+
+def update_gallery_on_upload(input_video, input_images, s_time_interval=1.0):
+ """Update gallery on upload"""
+ if not input_video and not input_images:
+ return None, None, None, None, None
+ target_dir, image_paths = handle_uploads(input_video, input_images, s_time_interval)
+ return (
+ None,
+ None,
+ target_dir,
+ image_paths,
+ "Upload complete. Click 'Reconstruct' to begin 3D processing.",
+ )
+
+
+@spaces.GPU(duration=120)
+def gradio_demo(
+ target_dir,
+ frame_filter="All",
+ show_cam=True,
+ filter_black_bg=False,
+ filter_white_bg=False,
+ conf_thres=3.0,
+ apply_mask=True,
+ show_mesh=True,
+ enable_segmentation=False,
+ text_prompt=DEFAULT_TEXT_PROMPT,
+ use_sam=True,
+):
+ """执行重建"""
+ if not os.path.isdir(target_dir) or target_dir == "None":
+ return None, None, "❌ 未找到有效的目标目录,请先上传文件", None, None, None, None, None, None, None, None, None
+
+ start_time = time.time()
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ target_dir_images = os.path.join(target_dir, "images")
+ all_files = (
+ sorted(os.listdir(target_dir_images))
+ if os.path.isdir(target_dir_images)
+ else []
+ )
+ all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
+ frame_filter_choices = ["All"] + all_files
+
+ print("运行 MapAnything 模型...")
+ with torch.no_grad():
+ predictions, processed_data, segmented_glb = run_model(
+ target_dir, apply_mask, True, filter_black_bg, filter_white_bg,
+ enable_segmentation, text_prompt
+ )
+
+ # 保存预测结果
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
+ np.savez(prediction_save_path, **predictions)
+
+ if frame_filter is None:
+ frame_filter = "All"
+
+ # 生成 GLB 文件名
+ glbfile = os.path.join(
+ target_dir,
+ f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}.glb",
+ )
+
+ # 转换预测结果为 GLB
+ glbscene = predictions_to_glb(
+ predictions,
+ filter_by_frames=frame_filter,
+ show_cam=show_cam,
+ mask_black_bg=filter_black_bg,
+ mask_white_bg=filter_white_bg,
+ as_mesh=show_mesh,
+ conf_percentile=conf_thres,
+ )
+ glbscene.export(file_obj=glbfile)
+
+ # 清理内存
+ del predictions
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ end_time = time.time()
+ print(f"总耗时: {end_time - start_time:.2f}秒")
+ log_msg = f"✅ 重建成功 ({len(all_files)} 帧)"
+
+ # Populate visualization tabs
+ depth_vis, normal_vis, measure_img, measure_pts = populate_visualization_tabs(
+ processed_data
+ )
+
+ # Update view selectors
+ depth_selector, normal_selector, measure_selector = update_view_selectors(
+ processed_data
+ )
+
+ return (
+ glbfile,
+ segmented_glb,
+ log_msg,
+ gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True),
+ processed_data,
+ depth_vis,
+ normal_vis,
+ measure_img,
+ "",
+ depth_selector,
+ normal_selector,
+ measure_selector,
+ )
+
+
+def colorize_depth(depth_map, mask=None):
+ """Convert depth map to colorized visualization with optional mask"""
+ if depth_map is None:
+ return None
+
+ depth_normalized = depth_map.copy()
+ valid_mask = depth_normalized > 0
+
+ if mask is not None:
+ valid_mask = valid_mask & mask
+
+ if valid_mask.sum() > 0:
+ valid_depths = depth_normalized[valid_mask]
+ p5 = np.percentile(valid_depths, 5)
+ p95 = np.percentile(valid_depths, 95)
+
+ depth_normalized[valid_mask] = (depth_normalized[valid_mask] - p5) / (p95 - p5)
+
+ import matplotlib.pyplot as plt
+
+ colormap = plt.cm.turbo_r
+ colored = colormap(depth_normalized)
+ colored = (colored[:, :, :3] * 255).astype(np.uint8)
+
+ colored[~valid_mask] = [255, 255, 255]
+
+ return colored
+
+
+def colorize_normal(normal_map, mask=None):
+ """Convert normal map to colorized visualization with optional mask"""
+ if normal_map is None:
+ return None
+
+ normal_vis = normal_map.copy()
+
+ if mask is not None:
+ invalid_mask = ~mask
+ normal_vis[invalid_mask] = [0, 0, 0]
+
+ normal_vis = (normal_vis + 1.0) / 2.0
+ normal_vis = (normal_vis * 255).astype(np.uint8)
+
+ return normal_vis
+
+
+def process_predictions_for_visualization(
+ predictions, views, high_level_config, filter_black_bg=False, filter_white_bg=False
+):
+ """Extract depth, normal, and 3D points from predictions for visualization"""
+ processed_data = {}
+
+ for view_idx, view in enumerate(views):
+ image = rgb(view["img"], norm_type=high_level_config["data_norm_type"])
+
+ pred_pts3d = predictions["world_points"][view_idx]
+
+ view_data = {
+ "image": image[0],
+ "points3d": pred_pts3d,
+ "depth": None,
+ "normal": None,
+ "mask": None,
+ }
+
+ mask = predictions["final_mask"][view_idx].copy()
+
+ if filter_black_bg:
+ view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
+ black_bg_mask = view_colors.sum(axis=2) >= 16
+ mask = mask & black_bg_mask
+
+ if filter_white_bg:
+ view_colors = image[0] * 255 if image[0].max() <= 1.0 else image[0]
+ white_bg_mask = ~(
+ (view_colors[:, :, 0] > 240)
+ & (view_colors[:, :, 1] > 240)
+ & (view_colors[:, :, 2] > 240)
+ )
+ mask = mask & white_bg_mask
+
+ view_data["mask"] = mask
+ view_data["depth"] = predictions["depth"][view_idx].squeeze()
+
+ normals, _ = points_to_normals(pred_pts3d, mask=view_data["mask"])
+ view_data["normal"] = normals
+
+ processed_data[view_idx] = view_data
+
+ return processed_data
+
+
+def reset_measure(processed_data):
+ """Reset measure points"""
+ if processed_data is None or len(processed_data) == 0:
+ return None, [], ""
+
+ first_view = list(processed_data.values())[0]
+ return first_view["image"], [], ""
+
+
+def measure(
+ processed_data, measure_points, current_view_selector, event: gr.SelectData
+):
+ """Handle measurement on images"""
+ try:
+ print(f"测量功能调用,选择器: {current_view_selector}")
+
+ if processed_data is None or len(processed_data) == 0:
+ return None, [], "❌ 没有可用数据"
+
+ try:
+ current_view_index = int(current_view_selector.split()[1]) - 1
+ except:
+ current_view_index = 0
+
+ print(f"使用视图索引: {current_view_index}")
+
+ if current_view_index < 0 or current_view_index >= len(processed_data):
+ current_view_index = 0
+
+ view_keys = list(processed_data.keys())
+ current_view = processed_data[view_keys[current_view_index]]
+
+ if current_view is None:
+ return None, [], "❌ 没有视图数据"
+
+ point2d = event.index[0], event.index[1]
+ print(f"点击点: {point2d}")
+
+ if (
+ current_view["mask"] is not None
+ and 0 <= point2d[1] < current_view["mask"].shape[0]
+ and 0 <= point2d[0] < current_view["mask"].shape[1]
+ ):
+ if not current_view["mask"][point2d[1], point2d[0]]:
+ print(f"点击点 {point2d} 在遮罩区域,忽略点击")
+ masked_image, _ = update_measure_view(
+ processed_data, current_view_index
+ )
+ return (
+ masked_image,
+ measure_points,
+ '⚠️ 无法在遮罩区域测量(显示为灰色)',
+ )
+
+ measure_points.append(point2d)
+
+ image, _ = update_measure_view(processed_data, current_view_index)
+ if image is None:
+ return None, [], "❌ 没有可用图像"
+
+ image = image.copy()
+ points3d = current_view["points3d"]
+
+ try:
+ if image.dtype != np.uint8:
+ if image.max() <= 1.0:
+ image = (image * 255).astype(np.uint8)
+ else:
+ image = image.astype(np.uint8)
+ except Exception as e:
+ print(f"图像转换错误: {e}")
+ return None, [], f"❌ 图像转换错误: {e}"
+
+ try:
+ for p in measure_points:
+ if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
+ image = cv2.circle(
+ image, p, radius=5, color=(255, 0, 0), thickness=2
+ )
+ except Exception as e:
+ print(f"绘制错误: {e}")
+ return None, [], f"❌ 绘制错误: {e}"
+
+ depth_text = ""
+ try:
+ for i, p in enumerate(measure_points):
+ if (
+ current_view["depth"] is not None
+ and 0 <= p[1] < current_view["depth"].shape[0]
+ and 0 <= p[0] < current_view["depth"].shape[1]
+ ):
+ d = current_view["depth"][p[1], p[0]]
+ depth_text += f"- **P{i + 1} 深度: {d:.2f}m**\n"
+ else:
+ if (
+ points3d is not None
+ and 0 <= p[1] < points3d.shape[0]
+ and 0 <= p[0] < points3d.shape[1]
+ ):
+ z = points3d[p[1], p[0], 2]
+ depth_text += f"- **P{i + 1} Z坐标: {z:.2f}m**\n"
+ except Exception as e:
+ print(f"深度文本错误: {e}")
+ depth_text = f"❌ 深度计算错误: {e}\n"
+
+ if len(measure_points) == 2:
+ try:
+ point1, point2 = measure_points
+ if (
+ 0 <= point1[0] < image.shape[1]
+ and 0 <= point1[1] < image.shape[0]
+ and 0 <= point2[0] < image.shape[1]
+ and 0 <= point2[1] < image.shape[0]
+ ):
+ image = cv2.line(
+ image, point1, point2, color=(255, 0, 0), thickness=2
+ )
+
+ distance_text = "- **距离: 无法计算**"
+ if (
+ points3d is not None
+ and 0 <= point1[1] < points3d.shape[0]
+ and 0 <= point1[0] < points3d.shape[1]
+ and 0 <= point2[1] < points3d.shape[0]
+ and 0 <= point2[0] < points3d.shape[1]
+ ):
+ try:
+ p1_3d = points3d[point1[1], point1[0]]
+ p2_3d = points3d[point2[1], point2[0]]
+ distance = np.linalg.norm(p1_3d - p2_3d)
+ distance_text = f"- **距离: {distance:.2f}m**"
+ except Exception as e:
+ print(f"距离计算错误: {e}")
+ distance_text = f"- **距离计算错误: {e}**"
+
+ measure_points = []
+ text = depth_text + distance_text
+ print(f"测量完成: {text}")
+ return [image, measure_points, text]
+ except Exception as e:
+ print(f"最终测量错误: {e}")
+ return None, [], f"❌ 测量错误: {e}"
+ else:
+ print(f"单点测量: {depth_text}")
+ return [image, measure_points, depth_text]
+
+ except Exception as e:
+ print(f"整体测量功能错误: {e}")
+ return None, [], f"❌ 测量功能错误: {e}"
+
+
+def clear_fields():
+ """清空 3D 查看器"""
+ return None, None
+
+
+def update_log():
+ """显示日志消息"""
+ return "🔄 加载和重建中..."
+
+
+def update_visualization(
+ target_dir,
+ frame_filter,
+ show_cam,
+ is_example,
+ conf_thres=None,
+ filter_black_bg=False,
+ filter_white_bg=False,
+ show_mesh=True,
+):
+ """更新可视化"""
+ if is_example == "True":
+ return gr.update(), "❌ 没有可用的重建。请先点击重建按钮。"
+
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
+ return gr.update(), "❌ 没有可用的重建。请先点击重建按钮。"
+
+ predictions_path = os.path.join(target_dir, "predictions.npz")
+ if not os.path.exists(predictions_path):
+ return gr.update(), f"❌ 没有可用的重建。请先运行「重建」。"
+
+ loaded = np.load(predictions_path, allow_pickle=True)
+ predictions = {key: loaded[key] for key in loaded.keys()}
+
+ glbfile = os.path.join(
+ target_dir,
+ f"glbscene_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_cam{show_cam}_mesh{show_mesh}_black{filter_black_bg}_white{filter_white_bg}.glb",
+ )
+
+ glbscene = predictions_to_glb(
+ predictions,
+ filter_by_frames=frame_filter,
+ show_cam=show_cam,
+ mask_black_bg=filter_black_bg,
+ mask_white_bg=filter_white_bg,
+ as_mesh=show_mesh,
+ conf_percentile=conf_thres,
+ )
+ glbscene.export(file_obj=glbfile)
+
+ return glbfile, "✅ 可视化已更新。"
+
+
+def update_all_views_on_filter_change(
+ target_dir,
+ filter_black_bg,
+ filter_white_bg,
+ processed_data,
+ depth_view_selector,
+ normal_view_selector,
+ measure_view_selector,
+):
+ """Update all individual view tabs when background filtering checkboxes change"""
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
+ return processed_data, None, None, None, []
+
+ predictions_path = os.path.join(target_dir, "predictions.npz")
+ if not os.path.exists(predictions_path):
+ return processed_data, None, None, None, []
+
+ try:
+ loaded = np.load(predictions_path, allow_pickle=True)
+ predictions = {key: loaded[key] for key in loaded.keys()}
+
+ image_folder_path = os.path.join(target_dir, "images")
+ views = load_images(image_folder_path)
+
+ new_processed_data = process_predictions_for_visualization(
+ predictions, views, high_level_config, filter_black_bg, filter_white_bg
+ )
+
+ try:
+ depth_view_idx = (
+ int(depth_view_selector.split()[1]) - 1 if depth_view_selector else 0
+ )
+ except:
+ depth_view_idx = 0
+
+ try:
+ normal_view_idx = (
+ int(normal_view_selector.split()[1]) - 1 if normal_view_selector else 0
+ )
+ except:
+ normal_view_idx = 0
+
+ try:
+ measure_view_idx = (
+ int(measure_view_selector.split()[1]) - 1
+ if measure_view_selector
+ else 0
+ )
+ except:
+ measure_view_idx = 0
+
+ depth_vis = update_depth_view(new_processed_data, depth_view_idx)
+ normal_vis = update_normal_view(new_processed_data, normal_view_idx)
+ measure_img, _ = update_measure_view(new_processed_data, measure_view_idx)
+
+ return new_processed_data, depth_vis, normal_vis, measure_img, []
+
+ except Exception as e:
+ print(f"Error updating views on filter change: {e}")
+ return processed_data, None, None, None, []
+
+
+# ============================================================================
+# Example Scene Functions
+# ============================================================================
+
+def get_scene_info(examples_dir):
+ """Get information about scenes in the examples directory"""
+ import glob
+
+ scenes = []
+ if not os.path.exists(examples_dir):
+ return scenes
+
+ for scene_folder in sorted(os.listdir(examples_dir)):
+ scene_path = os.path.join(examples_dir, scene_folder)
+ if os.path.isdir(scene_path):
+ image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
+ image_files = []
+ for ext in image_extensions:
+ image_files.extend(glob.glob(os.path.join(scene_path, ext)))
+ image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
+
+ if image_files:
+ image_files = sorted(image_files)
+ first_image = image_files[0]
+ num_images = len(image_files)
+
+ scenes.append(
+ {
+ "name": scene_folder,
+ "path": scene_path,
+ "thumbnail": first_image,
+ "num_images": num_images,
+ "image_files": image_files,
+ }
+ )
+
+ return scenes
+
+
+def load_example_scene(scene_name, examples_dir="examples"):
+ """从示例目录加载场景"""
+ scenes = get_scene_info(examples_dir)
+
+ selected_scene = None
+ for scene in scenes:
+ if scene["name"] == scene_name:
+ selected_scene = scene
+ break
+
+ if selected_scene is None:
+ return None, None, None, None, "❌ 场景未找到"
+
+ file_objects = []
+ for image_path in selected_scene["image_files"]:
+ file_objects.append(image_path)
+
+ target_dir, image_paths = handle_uploads(file_objects, 1.0)
+
+ return (
+ None,
+ None,
+ target_dir,
+ image_paths,
+ f"✅ 已加载场景 '{scene_name}' ({selected_scene['num_images']} 张图像)。点击「开始重建」进行 3D 处理。",
+ )
+
+
+# ============================================================================
+# Gradio UI
+# ============================================================================
+
+theme = get_gradio_theme()
+
+# 自定义CSS防止UI抖动
+CUSTOM_CSS = GRADIO_CSS + """
+/* 防止组件撑开布局 */
+.gradio-container {
+ max-width: 100% !important;
+}
+
+/* 固定Gallery高度 */
+.gallery-container {
+ max-height: 350px !important;
+ overflow-y: auto !important;
+}
+
+/* 固定File组件高度 */
+.file-preview {
+ max-height: 200px !important;
+ overflow-y: auto !important;
+}
+
+/* 固定Video组件高度 */
+.video-container {
+ max-height: 300px !important;
+}
+
+/* 防止Textbox无限扩展 */
+.textbox-container {
+ max-height: 100px !important;
+}
+
+/* 保持Tabs内容区域稳定 */
+.tab-content {
+ min-height: 550px !important;
+}
+"""
+
+with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="MapAnything V2 - 3D重建与物体分割") as demo:
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
+ processed_data_state = gr.State(value=None)
+ measure_points_state = gr.State(value=[])
+
+ # 顶部标题
+ gr.HTML("""
+
+
MapAnything V2 - 3D重建与物体分割
+
基于DBSCAN聚类的智能物体识别 | 多视图融合 | 自适应参数调整
+
+ """)
+
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
+
+ with gr.Row(equal_height=False):
+ # 左侧:输入区域
+ with gr.Column(scale=1, min_width=300):
+ gr.Markdown("### 📤 输入")
+
+ with gr.Tabs():
+ with gr.Tab("📷 图片"):
+ input_images = gr.File(
+ file_count="multiple",
+ label="上传多张图片(推荐3-10张)",
+ interactive=True,
+ height=200
+ )
+
+ with gr.Tab("🎥 视频"):
+ input_video = gr.Video(
+ label="上传视频",
+ interactive=True,
+ height=300
+ )
+ s_time_interval = gr.Slider(
+ minimum=0.1, maximum=5.0, value=1.0, step=0.1,
+ label="帧采样间隔(秒)", interactive=True
+ )
+
+ image_gallery = gr.Gallery(
+ label="图片预览", columns=3, height=350,
+ show_download_button=True, object_fit="contain", preview=True
+ )
+
+ with gr.Row():
+ submit_btn = gr.Button("🚀 开始重建", variant="primary", scale=2)
+ clear_btn = gr.ClearButton(
+ [input_video, input_images, target_dir_output, image_gallery],
+ value="🗑️ 清空", scale=1
+ )
+
+ # 右侧:输出区域
+ with gr.Column(scale=2, min_width=600):
+ gr.Markdown("### 🎯 输出")
+
+ with gr.Tabs():
+ with gr.Tab("🏗️ 原始3D"):
+ reconstruction_output = gr.Model3D(
+ height=550, zoom_speed=0.5, pan_speed=0.5,
+ clear_color=[0.0, 0.0, 0.0, 0.0]
+ )
+
+ with gr.Tab("🎨 分割3D"):
+ segmented_output = gr.Model3D(
+ height=550, zoom_speed=0.5, pan_speed=0.5,
+ clear_color=[0.0, 0.0, 0.0, 0.0]
+ )
+
+ with gr.Tab("📊 深度图"):
+ with gr.Row(elem_classes=["navigation-row"]):
+ prev_depth_btn = gr.Button("◀", size="sm", scale=1)
+ depth_view_selector = gr.Dropdown(
+ choices=["View 1"], value="View 1",
+ label="视图", scale=3, interactive=True
+ )
+ next_depth_btn = gr.Button("▶", size="sm", scale=1)
+ depth_map = gr.Image(
+ type="numpy", label="", format="png", interactive=False,
+ height=500
+ )
+
+ with gr.Tab("🧭 法线图"):
+ with gr.Row(elem_classes=["navigation-row"]):
+ prev_normal_btn = gr.Button("◀", size="sm", scale=1)
+ normal_view_selector = gr.Dropdown(
+ choices=["View 1"], value="View 1",
+ label="视图", scale=3, interactive=True
+ )
+ next_normal_btn = gr.Button("▶", size="sm", scale=1)
+ normal_map = gr.Image(
+ type="numpy", label="", format="png", interactive=False,
+ height=500
+ )
+
+ with gr.Tab("📏 测量"):
+ gr.Markdown("**点击图片两次进行距离测量**")
+ with gr.Row(elem_classes=["navigation-row"]):
+ prev_measure_btn = gr.Button("◀", size="sm", scale=1)
+ measure_view_selector = gr.Dropdown(
+ choices=["View 1"], value="View 1",
+ label="视图", scale=3, interactive=True
+ )
+ next_measure_btn = gr.Button("▶", size="sm", scale=1)
+ measure_image = gr.Image(
+ type="numpy", show_label=False,
+ format="webp", interactive=False, sources=[],
+ height=500
+ )
+ measure_text = gr.Markdown("")
+
+ log_output = gr.Textbox(
+ value="📌 请上传图片或视频,然后点击「开始重建」",
+ label="状态信息",
+ interactive=False,
+ lines=1,
+ max_lines=1
+ )
+
+ # 高级选项(可折叠)
+ with gr.Accordion("⚙️ 高级选项", open=False):
+ with gr.Row(equal_height=False):
+ with gr.Column(scale=1, min_width=300):
+ gr.Markdown("#### 可视化参数")
+ frame_filter = gr.Dropdown(
+ choices=["All"], value="All", label="显示帧"
+ )
+ conf_thres = gr.Slider(
+ minimum=0, maximum=100, value=0, step=0.1,
+ label="置信度阈值(百分位)"
+ )
+ show_cam = gr.Checkbox(label="显示相机", value=True)
+ show_mesh = gr.Checkbox(label="显示网格", value=True)
+ filter_black_bg = gr.Checkbox(label="过滤黑色背景", value=False)
+ filter_white_bg = gr.Checkbox(label="过滤白色背景", value=False)
+
+ with gr.Column(scale=1, min_width=300):
+ gr.Markdown("#### 重建参数")
+ apply_mask_checkbox = gr.Checkbox(
+ label="应用深度掩码", value=True
+ )
+
+ gr.Markdown("#### 分割参数")
+ enable_segmentation = gr.Checkbox(
+ label="启用语义分割", value=False
+ )
+
+ text_prompt = gr.Textbox(
+ value=DEFAULT_TEXT_PROMPT,
+ label="检测物体(用 . 分隔)",
+ placeholder="例如: chair . table . sofa",
+ lines=2,
+ max_lines=2
+ )
+
+ with gr.Row():
+ detect_all_btn = gr.Button("🔍 检测所有", size="sm")
+ restore_default_btn = gr.Button("↻ 默认", size="sm")
+
+ # 示例场景(可折叠)
+ with gr.Accordion("🖼️ 示例场景", open=False):
+ scenes = get_scene_info("examples")
+ if scenes:
+ for i in range(0, len(scenes), 4):
+ with gr.Row(equal_height=True):
+ for j in range(4):
+ scene_idx = i + j
+ if scene_idx < len(scenes):
+ scene = scenes[scene_idx]
+ with gr.Column(scale=1, min_width=150):
+ scene_img = gr.Image(
+ value=scene["thumbnail"],
+ height=150,
+ interactive=False,
+ show_label=False,
+ sources=[],
+ container=False
+ )
+ gr.Markdown(
+ f"**{scene['name']}** ({scene['num_images']}张)",
+ elem_classes=["text-center"]
+ )
+ scene_img.select(
+ fn=lambda name=scene["name"]: load_example_scene(name),
+ outputs=[
+ reconstruction_output, segmented_output,
+ target_dir_output, image_gallery, log_output
+ ]
+ )
+
+ # === 事件绑定 ===
+
+ # 分割选项按钮
+ detect_all_btn.click(
+ fn=lambda: COMMON_OBJECTS_PROMPT,
+ outputs=[text_prompt]
+ )
+ restore_default_btn.click(
+ fn=lambda: DEFAULT_TEXT_PROMPT,
+ outputs=[text_prompt]
+ )
+
+ # 上传文件自动更新
+ def update_gallery_on_unified_upload(files_video, files_images, interval):
+ if not files_video and not files_images:
+ return None, None, None, None
+ # Combine both inputs
+ all_files = []
+ if files_video:
+ all_files.append(files_video)
+ if files_images:
+ all_files.extend(files_images)
+ target_dir, image_paths = handle_uploads(all_files, interval)
+ return (
+ None,
+ target_dir,
+ image_paths,
+ "✅ 上传完成,点击「开始重建」进行 3D 处理",
+ )
+
+ input_video.change(
+ fn=update_gallery_on_unified_upload,
+ inputs=[input_video, input_images, s_time_interval],
+ outputs=[segmented_output, target_dir_output, image_gallery, log_output]
+ )
+ input_images.change(
+ fn=update_gallery_on_unified_upload,
+ inputs=[input_video, input_images, s_time_interval],
+ outputs=[segmented_output, target_dir_output, image_gallery, log_output]
+ )
+
+ # 重建按钮
+ submit_btn.click(
+ fn=clear_fields,
+ outputs=[reconstruction_output, segmented_output]
+ ).then(
+ fn=update_log,
+ outputs=[log_output]
+ ).then(
+ fn=gradio_demo,
+ inputs=[
+ target_dir_output, frame_filter, show_cam,
+ filter_black_bg, filter_white_bg, conf_thres,
+ apply_mask_checkbox, show_mesh,
+ enable_segmentation, text_prompt
+ ],
+ outputs=[
+ reconstruction_output, segmented_output, log_output, frame_filter,
+ processed_data_state, depth_map, normal_map, measure_image,
+ measure_text, depth_view_selector, normal_view_selector, measure_view_selector
+ ]
+ ).then(
+ fn=lambda: "False",
+ outputs=[is_example]
+ )
+
+ # 清空按钮
+ clear_btn.add([reconstruction_output, segmented_output, log_output])
+
+ # 可视化参数实时更新
+ for component in [frame_filter, show_cam, conf_thres, show_mesh]:
+ component.change(
+ fn=update_visualization,
+ inputs=[
+ target_dir_output, frame_filter, show_cam, is_example,
+ conf_thres, filter_black_bg, filter_white_bg, show_mesh
+ ],
+ outputs=[reconstruction_output, log_output]
+ )
+
+ # 背景过滤器更新所有视图
+ for bg_filter in [filter_black_bg, filter_white_bg]:
+ bg_filter.change(
+ fn=update_all_views_on_filter_change,
+ inputs=[
+ target_dir_output, filter_black_bg, filter_white_bg, processed_data_state,
+ depth_view_selector, normal_view_selector, measure_view_selector
+ ],
+ outputs=[processed_data_state, depth_map, normal_map, measure_image, measure_points_state]
+ )
+
+ # 深度图导航
+ prev_depth_btn.click(
+ fn=lambda pd, cs: navigate_depth_view(pd, cs, -1),
+ inputs=[processed_data_state, depth_view_selector],
+ outputs=[depth_view_selector, depth_map]
+ )
+ next_depth_btn.click(
+ fn=lambda pd, cs: navigate_depth_view(pd, cs, 1),
+ inputs=[processed_data_state, depth_view_selector],
+ outputs=[depth_view_selector, depth_map]
+ )
+ depth_view_selector.change(
+ fn=lambda pd, sv: update_depth_view(pd, int(sv.split()[1]) - 1) if sv else None,
+ inputs=[processed_data_state, depth_view_selector],
+ outputs=[depth_map]
+ )
+
+ # 法线图导航
+ prev_normal_btn.click(
+ fn=lambda pd, cs: navigate_normal_view(pd, cs, -1),
+ inputs=[processed_data_state, normal_view_selector],
+ outputs=[normal_view_selector, normal_map]
+ )
+ next_normal_btn.click(
+ fn=lambda pd, cs: navigate_normal_view(pd, cs, 1),
+ inputs=[processed_data_state, normal_view_selector],
+ outputs=[normal_view_selector, normal_map]
+ )
+ normal_view_selector.change(
+ fn=lambda pd, sv: update_normal_view(pd, int(sv.split()[1]) - 1) if sv else None,
+ inputs=[processed_data_state, normal_view_selector],
+ outputs=[normal_map]
+ )
+
+ # 测量功能
+ measure_image.select(
+ fn=measure,
+ inputs=[processed_data_state, measure_points_state, measure_view_selector],
+ outputs=[measure_image, measure_points_state, measure_text]
+ )
+ prev_measure_btn.click(
+ fn=lambda pd, cs: navigate_measure_view(pd, cs, -1),
+ inputs=[processed_data_state, measure_view_selector],
+ outputs=[measure_view_selector, measure_image, measure_points_state]
+ )
+ next_measure_btn.click(
+ fn=lambda pd, cs: navigate_measure_view(pd, cs, 1),
+ inputs=[processed_data_state, measure_view_selector],
+ outputs=[measure_view_selector, measure_image, measure_points_state]
+ )
+ measure_view_selector.change(
+ fn=lambda pd, sv: update_measure_view(pd, int(sv.split()[1]) - 1) if sv else (None, []),
+ inputs=[processed_data_state, measure_view_selector],
+ outputs=[measure_image, measure_points_state]
+ )
+
+# 启动信息
+print("\n" + "="*60)
+print("🚀 MapAnything V2 - 3D重建与物体分割")
+print("="*60)
+print("📊 核心技术: 自适应DBSCAN聚类 + 多视图融合")
+print(f"🔧 质量控制: 置信度≥{MIN_DETECTION_CONFIDENCE} | 面积≥{MIN_MASK_AREA}px")
+print(f"🎯 聚类半径: 沙发{DBSCAN_EPS_CONFIG['sofa']}m | 桌子{DBSCAN_EPS_CONFIG['table']}m | 窗户{DBSCAN_EPS_CONFIG['window']}m | 默认{DBSCAN_EPS_CONFIG['default']}m")
+print("="*60 + "\n")
+
+demo.queue(max_size=20).launch(show_error=True, share=True, ssr_mode=False)
diff --git a/configs/calibration_benchmark.yaml b/configs/calibration_benchmark.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78e6714cc164afbe8a39eee0c23c80267ad708a1
--- /dev/null
+++ b/configs/calibration_benchmark.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - machine: aws
+ - model: default
+ - dataset: default
+ - _self_
+
+output_dir: ${hydra:run.dir}
+root_data_dir: ${machine.root_data_dir}
+mapanything_dataset_metadata_dir: ${machine.mapanything_dataset_metadata_dir}
+root_pretrained_checkpoints_dir: ${machine.root_pretrained_checkpoints_dir}
+root_experiments_dir: ${machine.root_experiments_dir}
+root_uniception_pretrained_checkpoints_dir: ${machine.root_uniception_pretrained_checkpoints_dir}
+
+### Benchmarking args
+seed: 0
+# Disable CUDNN Benchmark (Disable for variable resolution & number of view training)
+disable_cudnn_benchmark: true
+# Batch size for inference (Metrics are computed per multi-view set and averaged, not per batch of multi-view sets)
+batch_size: 20
+# Use mixed precision for inference
+amp: 1
+# Floating point type to use for mixed precision
+amp_dtype: "bf16"
diff --git a/configs/dataset/ase_wai/default.yaml b/configs/dataset/ase_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/ase_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/ase_wai/train/default.yaml b/configs/dataset/ase_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2c77c0b49e50d015695f61db5f2ec4fd42fc8ca8
--- /dev/null
+++ b/configs/dataset/ase_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "ASEWAI(
+ split='${dataset.ase_wai.train.split}',
+ resolution=${dataset.ase_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.ase_wai.train.principal_point_centered},
+ aug_crop=${dataset.ase_wai.train.aug_crop},
+ transform='${dataset.ase_wai.train.transform}',
+ data_norm_type='${dataset.ase_wai.train.data_norm_type}',
+ ROOT='${dataset.ase_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.ase_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.ase_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.ase_wai.train.variable_num_views},
+ num_views=${dataset.ase_wai.train.num_views},
+ covisibility_thres=${dataset.ase_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/ase
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/ase_wai/val/default.yaml b/configs/dataset/ase_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ee2a92715e80edb49535765455ee30d6f782fb2c
--- /dev/null
+++ b/configs/dataset/ase_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "ASEWAI(
+ split='${dataset.ase_wai.val.split}',
+ resolution=${dataset.ase_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.ase_wai.val.principal_point_centered},
+ seed=${dataset.ase_wai.val.seed},
+ transform='${dataset.ase_wai.val.transform}',
+ data_norm_type='${dataset.ase_wai.val.data_norm_type}',
+ ROOT='${dataset.ase_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.ase_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.ase_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.ase_wai.val.variable_num_views},
+ num_views=${dataset.ase_wai.val.num_views},
+ covisibility_thres=${dataset.ase_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_ase}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/ase
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/bedlam_wai/default.yaml b/configs/dataset/bedlam_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/bedlam_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/bedlam_wai/train/default.yaml b/configs/dataset/bedlam_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..11dc8db66e605f9f42a33677c1bcba7236ca3ef7
--- /dev/null
+++ b/configs/dataset/bedlam_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "BedlamWAI(
+ split='${dataset.bedlam_wai.train.split}',
+ resolution=${dataset.bedlam_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.bedlam_wai.train.principal_point_centered},
+ aug_crop=${dataset.bedlam_wai.train.aug_crop},
+ transform='${dataset.bedlam_wai.train.transform}',
+ data_norm_type='${dataset.bedlam_wai.train.data_norm_type}',
+ ROOT='${dataset.bedlam_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.bedlam_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.bedlam_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.bedlam_wai.train.variable_num_views},
+ num_views=${dataset.bedlam_wai.train.num_views},
+ covisibility_thres=${dataset.bedlam_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/bedlam
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/bedlam_wai/val/default.yaml b/configs/dataset/bedlam_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8d1471050f84b32fa1858a4d11ad4dc798c0f002
--- /dev/null
+++ b/configs/dataset/bedlam_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "BedlamWAI(
+ split='${dataset.bedlam_wai.val.split}',
+ resolution=${dataset.bedlam_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.bedlam_wai.val.principal_point_centered},
+ seed=${dataset.bedlam_wai.val.seed},
+ transform='${dataset.bedlam_wai.val.transform}',
+ data_norm_type='${dataset.bedlam_wai.val.data_norm_type}',
+ ROOT='${dataset.bedlam_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.bedlam_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.bedlam_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.bedlam_wai.val.variable_num_views},
+ num_views=${dataset.bedlam_wai.val.num_views},
+ covisibility_thres=${dataset.bedlam_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_bedlam}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/bedlam
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/benchmark_512_eth3d_snpp_tav2.yaml b/configs/dataset/benchmark_512_eth3d_snpp_tav2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0bc7a5fd13da943404ada22c02028c3667e1cfc0
--- /dev/null
+++ b/configs/dataset/benchmark_512_eth3d_snpp_tav2.yaml
@@ -0,0 +1,20 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 2
+
+# Test Resolution
+resolution_test_eth3d: ${dataset.resolution_options.512_1_52_ar}
+resolution_test_scannetpp: ${dataset.resolution_options.512_1_52_ar}
+resolution_test_tav2_wb: ${dataset.resolution_options.512_1_00_ar}
+
+# Test Set
+# Sample 10 multi-view sets from each scene
+# ETH3D: 13 scenes
+# ScanNet++V2: 30 scenes
+# TartanAirV2-WB: 5 scenes
+test_dataset:
+ "+ 130 @ ${dataset.eth3d_wai.test.dataset_str}
+ + 300 @ ${dataset.scannetpp_wai.test.dataset_str}
+ + 50 @ ${dataset.tav2_wb_wai.test.dataset_str}"
diff --git a/configs/dataset/benchmark_512_snpp_tav2.yaml b/configs/dataset/benchmark_512_snpp_tav2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4be5baf2584b8557cef2d80d1e8e41bc9b4689e8
--- /dev/null
+++ b/configs/dataset/benchmark_512_snpp_tav2.yaml
@@ -0,0 +1,17 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 2
+
+# Test Resolution
+resolution_test_scannetpp: ${dataset.resolution_options.512_1_52_ar}
+resolution_test_tav2_wb: ${dataset.resolution_options.512_1_00_ar}
+
+# Test Set
+# Sample 10 multi-view sets from each scene
+# ScanNet++V2: 30 scenes
+# TartanAirV2-WB: 5 scenes
+test_dataset:
+ "+ 300 @ ${dataset.scannetpp_wai.test.dataset_str}
+ + 50 @ ${dataset.tav2_wb_wai.test.dataset_str}"
diff --git a/configs/dataset/benchmark_518_eth3d_snpp_tav2.yaml b/configs/dataset/benchmark_518_eth3d_snpp_tav2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a5251facb27c98ce7625829012e3f289cd11ed12
--- /dev/null
+++ b/configs/dataset/benchmark_518_eth3d_snpp_tav2.yaml
@@ -0,0 +1,20 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 2
+
+# Test Resolution
+resolution_test_eth3d: ${dataset.resolution_options.518_1_52_ar}
+resolution_test_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_test_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+
+# Test Set
+# Sample 10 multi-view sets from each scene
+# ETH3D: 13 scenes
+# ScanNet++V2: 30 scenes
+# TartanAirV2-WB: 5 scenes
+test_dataset:
+ "+ 130 @ ${dataset.eth3d_wai.test.dataset_str}
+ + 300 @ ${dataset.scannetpp_wai.test.dataset_str}
+ + 50 @ ${dataset.tav2_wb_wai.test.dataset_str}"
diff --git a/configs/dataset/benchmark_518_snpp_tav2.yaml b/configs/dataset/benchmark_518_snpp_tav2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4b15f825631ce0421857902f570c33812c6236f2
--- /dev/null
+++ b/configs/dataset/benchmark_518_snpp_tav2.yaml
@@ -0,0 +1,17 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 2
+
+# Test Resolution
+resolution_test_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_test_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+
+# Test Set
+# Sample 10 multi-view sets from each scene
+# ScanNet++V2: 30 scenes
+# TartanAirV2-WB: 5 scenes
+test_dataset:
+ "+ 300 @ ${dataset.scannetpp_wai.test.dataset_str}
+ + 50 @ ${dataset.tav2_wb_wai.test.dataset_str}"
diff --git a/configs/dataset/benchmark_sv_calib_518_many_ar_eth3d_snpp_tav2.yaml b/configs/dataset/benchmark_sv_calib_518_many_ar_eth3d_snpp_tav2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0d6266d248ef2d02b11b462ce922f0dfb8b52c91
--- /dev/null
+++ b/configs/dataset/benchmark_sv_calib_518_many_ar_eth3d_snpp_tav2.yaml
@@ -0,0 +1,20 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 1
+
+# Test Resolution
+resolution_test_eth3d: ${dataset.resolution_options.518_many_ar}
+resolution_test_scannetpp: ${dataset.resolution_options.518_many_ar}
+resolution_test_tav2_wb: ${dataset.resolution_options.518_many_ar}
+
+# Test Set
+# Sample 20 frames from each scene
+# ETH3D: 13 scenes
+# ScanNet++V2: 30 scenes
+# TartanAirV2-WB: 5 scenes
+test_dataset:
+ "+ 260 @ ${dataset.eth3d_wai.test.dataset_str}
+ + 600 @ ${dataset.scannetpp_wai.test.dataset_str}
+ + 100 @ ${dataset.tav2_wb_wai.test.dataset_str}"
diff --git a/configs/dataset/blendedmvs_wai/default.yaml b/configs/dataset/blendedmvs_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/blendedmvs_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/blendedmvs_wai/train/default.yaml b/configs/dataset/blendedmvs_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..39391d2da22bf69b6e8cc3d860c278d812f08944
--- /dev/null
+++ b/configs/dataset/blendedmvs_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "BlendedMVSWAI(
+ split='${dataset.blendedmvs_wai.train.split}',
+ resolution=${dataset.blendedmvs_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.blendedmvs_wai.train.principal_point_centered},
+ aug_crop=${dataset.blendedmvs_wai.train.aug_crop},
+ transform='${dataset.blendedmvs_wai.train.transform}',
+ data_norm_type='${dataset.blendedmvs_wai.train.data_norm_type}',
+ ROOT='${dataset.blendedmvs_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.blendedmvs_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.blendedmvs_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.blendedmvs_wai.train.variable_num_views},
+ num_views=${dataset.blendedmvs_wai.train.num_views},
+ covisibility_thres=${dataset.blendedmvs_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/blendedmvs
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/blendedmvs_wai/val/default.yaml b/configs/dataset/blendedmvs_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bbb876e5c41a4859d7c8f602b8ac37c7142a868b
--- /dev/null
+++ b/configs/dataset/blendedmvs_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "BlendedMVSWAI(
+ split='${dataset.blendedmvs_wai.val.split}',
+ resolution=${dataset.blendedmvs_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.blendedmvs_wai.val.principal_point_centered},
+ seed=${dataset.blendedmvs_wai.val.seed},
+ transform='${dataset.blendedmvs_wai.val.transform}',
+ data_norm_type='${dataset.blendedmvs_wai.val.data_norm_type}',
+ ROOT='${dataset.blendedmvs_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.blendedmvs_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.blendedmvs_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.blendedmvs_wai.val.variable_num_views},
+ num_views=${dataset.blendedmvs_wai.val.num_views},
+ covisibility_thres=${dataset.blendedmvs_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_blendedmvs}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/blendedmvs
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/default.yaml b/configs/dataset/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..84b4954bca4b0ac2127f7eafdc896ae4c46ce1b7
--- /dev/null
+++ b/configs/dataset/default.yaml
@@ -0,0 +1,45 @@
+defaults:
+ - resolution_options: default
+ - ase_wai: default
+ - bedlam_wai: default
+ - blendedmvs_wai: default
+ - dl3dv_wai: default
+ - dtu_wai: default
+ - dynamicreplica_wai: default
+ - eth3d_wai: default
+ - gta_sfm_wai: default
+ - matrixcity_wai: default
+ - megadepth_wai: default
+ - mpsd_wai: default
+ - mvs_synth_wai: default
+ - paralleldomain4d_wai: default
+ - sailvos3d_wai: default
+ - scannetpp_wai: default
+ - spring_wai: default
+ - structured3d_wai: default
+ - tav2_wb_wai: default
+ - unrealstereo4k_wai: default
+ - xrooms_wai: default
+
+# Training Set, For example: BlendedMVS(split='train', resolution=(512, 384), transform=...)
+train_dataset: ???
+# Validation Set
+test_dataset: "[null]"
+# Number of workers for dataloader
+num_workers: 12
+# Default resolution for training
+resolution_train: ???
+# Default resolution for validation
+resolution_val: ???
+# Number of views parameter for multi-view datasets
+num_views: 2
+# Use a centered principal point for all images
+principal_point_centered: false
+# Default config for multi-view datasets
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+val:
+ variable_num_views: false
+test:
+ variable_num_views: false
diff --git a/configs/dataset/dl3dv_wai/default.yaml b/configs/dataset/dl3dv_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/dl3dv_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/dl3dv_wai/train/default.yaml b/configs/dataset/dl3dv_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..af9610c79e25da5e6638043bfa0b8cec5acd4666
--- /dev/null
+++ b/configs/dataset/dl3dv_wai/train/default.yaml
@@ -0,0 +1,28 @@
+dataset_str:
+ "DL3DVWAI(
+ split='${dataset.dl3dv_wai.train.split}',
+ resolution=${dataset.dl3dv_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.dl3dv_wai.train.principal_point_centered},
+ aug_crop=${dataset.dl3dv_wai.train.aug_crop},
+ transform='${dataset.dl3dv_wai.train.transform}',
+ data_norm_type='${dataset.dl3dv_wai.train.data_norm_type}',
+ ROOT='${dataset.dl3dv_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.dl3dv_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.dl3dv_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.dl3dv_wai.train.variable_num_views},
+ num_views=${dataset.dl3dv_wai.train.num_views},
+ covisibility_thres=${dataset.dl3dv_wai.train.covisibility_thres},
+ mvs_confidence_filter_thres=${dataset.dl3dv_wai.train.mvs_confidence_filter_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/dl3dv
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
+mvs_confidence_filter_thres: 0.25
diff --git a/configs/dataset/dl3dv_wai/val/default.yaml b/configs/dataset/dl3dv_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e4ba1e12ee503c5e1c867f375d682096dff5fdb
--- /dev/null
+++ b/configs/dataset/dl3dv_wai/val/default.yaml
@@ -0,0 +1,28 @@
+dataset_str:
+ "DL3DVWAI(
+ split='${dataset.dl3dv_wai.val.split}',
+ resolution=${dataset.dl3dv_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.dl3dv_wai.val.principal_point_centered},
+ seed=${dataset.dl3dv_wai.val.seed},
+ transform='${dataset.dl3dv_wai.val.transform}',
+ data_norm_type='${dataset.dl3dv_wai.val.data_norm_type}',
+ ROOT='${dataset.dl3dv_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.dl3dv_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.dl3dv_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.dl3dv_wai.val.variable_num_views},
+ num_views=${dataset.dl3dv_wai.val.num_views},
+ covisibility_thres=${dataset.dl3dv_wai.val.covisibility_thres},
+ mvs_confidence_filter_thres=${dataset.dl3dv_wai.val.mvs_confidence_filter_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_dl3dv}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/dl3dv
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
+mvs_confidence_filter_thres: 0.25
diff --git a/configs/dataset/dtu_wai/default.yaml b/configs/dataset/dtu_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b1278dcc74c8a2ee16b87a31ebabca50234ab9fa
--- /dev/null
+++ b/configs/dataset/dtu_wai/default.yaml
@@ -0,0 +1,2 @@
+defaults:
+ - test: default
diff --git a/configs/dataset/dtu_wai/test/default.yaml b/configs/dataset/dtu_wai/test/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7910a3aae6da7cef8a26ca77932b090d998cf8bb
--- /dev/null
+++ b/configs/dataset/dtu_wai/test/default.yaml
@@ -0,0 +1,22 @@
+dataset_str:
+ "DTUWAI(
+ resolution=${dataset.dtu_wai.test.dataset_resolution},
+ principal_point_centered=${dataset.dtu_wai.test.principal_point_centered},
+ seed=${dataset.dtu_wai.test.seed},
+ transform='${dataset.dtu_wai.test.transform}',
+ data_norm_type='${dataset.dtu_wai.test.data_norm_type}',
+ ROOT='${dataset.dtu_wai.test.ROOT}',
+ dataset_metadata_dir='${dataset.dtu_wai.test.dataset_metadata_dir}',
+ variable_num_views=${dataset.dtu_wai.test.variable_num_views},
+ num_views=${dataset.dtu_wai.test.num_views},
+ covisibility_thres=${dataset.dtu_wai.test.covisibility_thres})"
+dataset_resolution: ${dataset.resolution_test_dtu}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/dtu
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+variable_num_views: ${dataset.test.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/dynamicreplica_wai/default.yaml b/configs/dataset/dynamicreplica_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/dynamicreplica_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/dynamicreplica_wai/train/default.yaml b/configs/dataset/dynamicreplica_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..aa8b82c5692d494da325bbe7419e9a4d31db7c5f
--- /dev/null
+++ b/configs/dataset/dynamicreplica_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "DynamicReplicaWAI(
+ split='${dataset.dynamicreplica_wai.train.split}',
+ resolution=${dataset.dynamicreplica_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.dynamicreplica_wai.train.principal_point_centered},
+ aug_crop=${dataset.dynamicreplica_wai.train.aug_crop},
+ transform='${dataset.dynamicreplica_wai.train.transform}',
+ data_norm_type='${dataset.dynamicreplica_wai.train.data_norm_type}',
+ ROOT='${dataset.dynamicreplica_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.dynamicreplica_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.dynamicreplica_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.dynamicreplica_wai.train.variable_num_views},
+ num_views=${dataset.dynamicreplica_wai.train.num_views},
+ covisibility_thres=${dataset.dynamicreplica_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/dynamicreplica
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/dynamicreplica_wai/val/default.yaml b/configs/dataset/dynamicreplica_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7f50be2705174dad31dc4d1be1f965a4419f36f3
--- /dev/null
+++ b/configs/dataset/dynamicreplica_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "DynamicReplicaWAI(
+ split='${dataset.dynamicreplica_wai.val.split}',
+ resolution=${dataset.dynamicreplica_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.dynamicreplica_wai.val.principal_point_centered},
+ seed=${dataset.dynamicreplica_wai.val.seed},
+ transform='${dataset.dynamicreplica_wai.val.transform}',
+ data_norm_type='${dataset.dynamicreplica_wai.val.data_norm_type}',
+ ROOT='${dataset.dynamicreplica_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.dynamicreplica_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.dynamicreplica_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.dynamicreplica_wai.val.variable_num_views},
+ num_views=${dataset.dynamicreplica_wai.val.num_views},
+ covisibility_thres=${dataset.dynamicreplica_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_dynamicreplica}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/dynamicreplica
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/eth3d_wai/default.yaml b/configs/dataset/eth3d_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b1278dcc74c8a2ee16b87a31ebabca50234ab9fa
--- /dev/null
+++ b/configs/dataset/eth3d_wai/default.yaml
@@ -0,0 +1,2 @@
+defaults:
+ - test: default
diff --git a/configs/dataset/eth3d_wai/test/default.yaml b/configs/dataset/eth3d_wai/test/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0848ef4ab87819185c059567cf66d36dabc152d4
--- /dev/null
+++ b/configs/dataset/eth3d_wai/test/default.yaml
@@ -0,0 +1,22 @@
+dataset_str:
+ "ETH3DWAI(
+ resolution=${dataset.eth3d_wai.test.dataset_resolution},
+ principal_point_centered=${dataset.eth3d_wai.test.principal_point_centered},
+ seed=${dataset.eth3d_wai.test.seed},
+ transform='${dataset.eth3d_wai.test.transform}',
+ data_norm_type='${dataset.eth3d_wai.test.data_norm_type}',
+ ROOT='${dataset.eth3d_wai.test.ROOT}',
+ dataset_metadata_dir='${dataset.eth3d_wai.test.dataset_metadata_dir}',
+ variable_num_views=${dataset.eth3d_wai.test.variable_num_views},
+ num_views=${dataset.eth3d_wai.test.num_views},
+ covisibility_thres=${dataset.eth3d_wai.test.covisibility_thres})"
+dataset_resolution: ${dataset.resolution_test_eth3d}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/eth3d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+variable_num_views: ${dataset.test.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.025
diff --git a/configs/dataset/gta_sfm_wai/default.yaml b/configs/dataset/gta_sfm_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/gta_sfm_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/gta_sfm_wai/train/default.yaml b/configs/dataset/gta_sfm_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..971b4f4a78f3207d38f667860becf035269a46a6
--- /dev/null
+++ b/configs/dataset/gta_sfm_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "GTASfMWAI(
+ split='${dataset.gta_sfm_wai.train.split}',
+ resolution=${dataset.gta_sfm_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.gta_sfm_wai.train.principal_point_centered},
+ aug_crop=${dataset.gta_sfm_wai.train.aug_crop},
+ transform='${dataset.gta_sfm_wai.train.transform}',
+ data_norm_type='${dataset.gta_sfm_wai.train.data_norm_type}',
+ ROOT='${dataset.gta_sfm_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.gta_sfm_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.gta_sfm_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.gta_sfm_wai.train.variable_num_views},
+ num_views=${dataset.gta_sfm_wai.train.num_views},
+ covisibility_thres=${dataset.gta_sfm_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/gta_sfm
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/gta_sfm_wai/val/default.yaml b/configs/dataset/gta_sfm_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..430ac9e292dc1059dcccbbab6bdf82b0f46f391e
--- /dev/null
+++ b/configs/dataset/gta_sfm_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "GTASfMWAI(
+ split='${dataset.gta_sfm_wai.val.split}',
+ resolution=${dataset.gta_sfm_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.gta_sfm_wai.val.principal_point_centered},
+ seed=${dataset.gta_sfm_wai.val.seed},
+ transform='${dataset.gta_sfm_wai.val.transform}',
+ data_norm_type='${dataset.gta_sfm_wai.val.data_norm_type}',
+ ROOT='${dataset.gta_sfm_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.gta_sfm_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.gta_sfm_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.gta_sfm_wai.val.variable_num_views},
+ num_views=${dataset.gta_sfm_wai.val.num_views},
+ covisibility_thres=${dataset.gta_sfm_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_gta_sfm}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/gta_sfm
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/matrixcity_wai/default.yaml b/configs/dataset/matrixcity_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/matrixcity_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/matrixcity_wai/train/default.yaml b/configs/dataset/matrixcity_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ca7412ba48e10620b969af8311bb9e511bf5e437
--- /dev/null
+++ b/configs/dataset/matrixcity_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MatrixCityWAI(
+ split='${dataset.matrixcity_wai.train.split}',
+ resolution=${dataset.matrixcity_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.matrixcity_wai.train.principal_point_centered},
+ aug_crop=${dataset.matrixcity_wai.train.aug_crop},
+ transform='${dataset.matrixcity_wai.train.transform}',
+ data_norm_type='${dataset.matrixcity_wai.train.data_norm_type}',
+ ROOT='${dataset.matrixcity_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.matrixcity_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.matrixcity_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.matrixcity_wai.train.variable_num_views},
+ num_views=${dataset.matrixcity_wai.train.num_views},
+ covisibility_thres=${dataset.matrixcity_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/matrixcity
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/matrixcity_wai/val/default.yaml b/configs/dataset/matrixcity_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..64a73059704da9e721b18b949ec3487b436c3607
--- /dev/null
+++ b/configs/dataset/matrixcity_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MatrixCityWAI(
+ split='${dataset.matrixcity_wai.val.split}',
+ resolution=${dataset.matrixcity_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.matrixcity_wai.val.principal_point_centered},
+ seed=${dataset.matrixcity_wai.val.seed},
+ transform='${dataset.matrixcity_wai.val.transform}',
+ data_norm_type='${dataset.matrixcity_wai.val.data_norm_type}',
+ ROOT='${dataset.matrixcity_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.matrixcity_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.matrixcity_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.matrixcity_wai.val.variable_num_views},
+ num_views=${dataset.matrixcity_wai.val.num_views},
+ covisibility_thres=${dataset.matrixcity_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_matrixcity}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/matrixcity
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/megadepth_wai/default.yaml b/configs/dataset/megadepth_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/megadepth_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/megadepth_wai/train/default.yaml b/configs/dataset/megadepth_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..901443908e42340cb2bbdaa2fd4b0614c1748003
--- /dev/null
+++ b/configs/dataset/megadepth_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MegaDepthWAI(
+ split='${dataset.megadepth_wai.train.split}',
+ resolution=${dataset.megadepth_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.megadepth_wai.train.principal_point_centered},
+ aug_crop=${dataset.megadepth_wai.train.aug_crop},
+ transform='${dataset.megadepth_wai.train.transform}',
+ data_norm_type='${dataset.megadepth_wai.train.data_norm_type}',
+ ROOT='${dataset.megadepth_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.megadepth_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.megadepth_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.megadepth_wai.train.variable_num_views},
+ num_views=${dataset.megadepth_wai.train.num_views},
+ covisibility_thres=${dataset.megadepth_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/megadepth
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/megadepth_wai/val/default.yaml b/configs/dataset/megadepth_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b37ff6fc3966d54b40ed40cb66f1b0979598a897
--- /dev/null
+++ b/configs/dataset/megadepth_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MegaDepthWAI(
+ split='${dataset.megadepth_wai.val.split}',
+ resolution=${dataset.megadepth_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.megadepth_wai.val.principal_point_centered},
+ seed=${dataset.megadepth_wai.val.seed},
+ transform='${dataset.megadepth_wai.val.transform}',
+ data_norm_type='${dataset.megadepth_wai.val.data_norm_type}',
+ ROOT='${dataset.megadepth_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.megadepth_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.megadepth_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.megadepth_wai.val.variable_num_views},
+ num_views=${dataset.megadepth_wai.val.num_views},
+ covisibility_thres=${dataset.megadepth_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_megadepth}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/megadepth
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/megatrain_11d_se_518_many_ar_48ipg_64g.yaml b/configs/dataset/megatrain_11d_se_518_many_ar_48ipg_64g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a5c5f087b72828f98ae5406f871e763145dbd7b0
--- /dev/null
+++ b/configs/dataset/megatrain_11d_se_518_many_ar_48ipg_64g.yaml
@@ -0,0 +1,53 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.518_many_ar}
+
+# Validation Resolution
+resolution_val_ase: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_dl3dv: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_dynamicreplica: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_mvs_synth: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_paralleldomain4d: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_sailvos3d: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 2_450_000 @ ${dataset.ase_wai.train.dataset_str}
+ + 250_000 @ ${dataset.dl3dv_wai.train.dataset_str}
+ + 12_400 @ ${dataset.dynamicreplica_wai.train.dataset_str}
+ + 1_675_000 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 3_000 @ ${dataset.mvs_synth_wai.train.dataset_str}
+ + 36_000 @ ${dataset.paralleldomain4d_wai.train.dataset_str}
+ + 4_000 @ ${dataset.sailvos3d_wai.train.dataset_str}
+ + 22_600 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 800 @ ${dataset.spring_wai.train.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 200 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.ase_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dl3dv_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mvs_synth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.paralleldomain4d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.sailvos3d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/megatrain_12d_518_many_ar_24ipg_16g.yaml b/configs/dataset/megatrain_12d_518_many_ar_24ipg_16g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e97bbb2c38a81334f7a57eb93854774c4df01b78
--- /dev/null
+++ b/configs/dataset/megatrain_12d_518_many_ar_24ipg_16g.yaml
@@ -0,0 +1,56 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.518_many_ar}
+
+# Validation Resolution
+resolution_val_ase: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_dynamicreplica: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_megadepth: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_mvs_synth: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_paralleldomain4d: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_sailvos3d: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 58_000 @ ${dataset.ase_wai.train.dataset_str}
+ + 58_000 @ ${dataset.blendedmvs_wai.train.dataset_str}
+ + 45_000 @ ${dataset.dynamicreplica_wai.train.dataset_str}
+ + 58_000 @ ${dataset.megadepth_wai.train.dataset_str}
+ + 58_000 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 58_000 @ ${dataset.mvs_synth_wai.train.dataset_str}
+ + 58_000 @ ${dataset.paralleldomain4d_wai.train.dataset_str}
+ + 58_000 @ ${dataset.sailvos3d_wai.train.dataset_str}
+ + 58_000 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 2_000 @ ${dataset.spring_wai.train.dataset_str}
+ + 58_000 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 5_500 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.ase_wai.val.dataset_str}
+ + 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str}
+ + 4_000 @ ${dataset.megadepth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mvs_synth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.paralleldomain4d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.sailvos3d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/megatrain_13d_512_many_ar_24ipg_16g.yaml b/configs/dataset/megatrain_13d_512_many_ar_24ipg_16g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e23dd5cca53addeb5515e213b8918b9de8d20da
--- /dev/null
+++ b/configs/dataset/megatrain_13d_512_many_ar_24ipg_16g.yaml
@@ -0,0 +1,59 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.512_many_ar}
+
+# Validation Resolution
+resolution_val_ase: ${dataset.resolution_options.512_1_00_ar}
+resolution_val_blendedmvs: ${dataset.resolution_options.512_1_33_ar}
+resolution_val_dl3dv: ${dataset.resolution_options.512_1_77_ar}
+resolution_val_dynamicreplica: ${dataset.resolution_options.512_1_77_ar}
+resolution_val_megadepth: ${dataset.resolution_options.512_1_52_ar}
+resolution_val_mpsd: ${dataset.resolution_options.512_1_77_ar}
+resolution_val_mvs_synth: ${dataset.resolution_options.512_1_77_ar}
+resolution_val_paralleldomain4d: ${dataset.resolution_options.512_1_33_ar}
+resolution_val_sailvos3d: ${dataset.resolution_options.512_1_52_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.512_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.512_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.512_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.512_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 52_500 @ ${dataset.ase_wai.train.dataset_str}
+ + 52_500 @ ${dataset.blendedmvs_wai.train.dataset_str}
+ + 52_500 @ ${dataset.dl3dv_wai.train.dataset_str}
+ + 40_000 @ ${dataset.dynamicreplica_wai.train.dataset_str}
+ + 52_500 @ ${dataset.megadepth_wai.train.dataset_str}
+ + 52_500 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 52_500 @ ${dataset.mvs_synth_wai.train.dataset_str}
+ + 52_500 @ ${dataset.paralleldomain4d_wai.train.dataset_str}
+ + 52_500 @ ${dataset.sailvos3d_wai.train.dataset_str}
+ + 52_500 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 2_000 @ ${dataset.spring_wai.train.dataset_str}
+ + 52_500 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 5_500 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.ase_wai.val.dataset_str}
+ + 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dl3dv_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str}
+ + 4_000 @ ${dataset.megadepth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mvs_synth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.paralleldomain4d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.sailvos3d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/megatrain_13d_518_many_ar_24ipg_16g.yaml b/configs/dataset/megatrain_13d_518_many_ar_24ipg_16g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..187bc543689db308d17000869933bc88e41027db
--- /dev/null
+++ b/configs/dataset/megatrain_13d_518_many_ar_24ipg_16g.yaml
@@ -0,0 +1,59 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.518_many_ar}
+
+# Validation Resolution
+resolution_val_ase: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_dl3dv: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_dynamicreplica: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_megadepth: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_mvs_synth: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_paralleldomain4d: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_sailvos3d: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 52_500 @ ${dataset.ase_wai.train.dataset_str}
+ + 52_500 @ ${dataset.blendedmvs_wai.train.dataset_str}
+ + 52_500 @ ${dataset.dl3dv_wai.train.dataset_str}
+ + 40_000 @ ${dataset.dynamicreplica_wai.train.dataset_str}
+ + 52_500 @ ${dataset.megadepth_wai.train.dataset_str}
+ + 52_500 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 52_500 @ ${dataset.mvs_synth_wai.train.dataset_str}
+ + 52_500 @ ${dataset.paralleldomain4d_wai.train.dataset_str}
+ + 52_500 @ ${dataset.sailvos3d_wai.train.dataset_str}
+ + 52_500 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 2_000 @ ${dataset.spring_wai.train.dataset_str}
+ + 52_500 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 5_500 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.ase_wai.val.dataset_str}
+ + 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dl3dv_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str}
+ + 4_000 @ ${dataset.megadepth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mvs_synth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.paralleldomain4d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.sailvos3d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/megatrain_13d_518_many_ar_48ipg_64g.yaml b/configs/dataset/megatrain_13d_518_many_ar_48ipg_64g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..199214082801e43e36b1f70930a8b3a2d24cfd3b
--- /dev/null
+++ b/configs/dataset/megatrain_13d_518_many_ar_48ipg_64g.yaml
@@ -0,0 +1,59 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.518_many_ar}
+
+# Validation Resolution
+resolution_val_ase: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_dl3dv: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_dynamicreplica: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_megadepth: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_mvs_synth: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_paralleldomain4d: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_sailvos3d: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 420_000 @ ${dataset.ase_wai.train.dataset_str}
+ + 420_000 @ ${dataset.blendedmvs_wai.train.dataset_str}
+ + 420_000 @ ${dataset.dl3dv_wai.train.dataset_str}
+ + 320_000 @ ${dataset.dynamicreplica_wai.train.dataset_str}
+ + 420_000 @ ${dataset.megadepth_wai.train.dataset_str}
+ + 420_000 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 420_000 @ ${dataset.mvs_synth_wai.train.dataset_str}
+ + 420_000 @ ${dataset.paralleldomain4d_wai.train.dataset_str}
+ + 420_000 @ ${dataset.sailvos3d_wai.train.dataset_str}
+ + 420_000 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 16_000 @ ${dataset.spring_wai.train.dataset_str}
+ + 420_000 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 44_000 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.ase_wai.val.dataset_str}
+ + 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dl3dv_wai.val.dataset_str}
+ + 4_000 @ ${dataset.dynamicreplica_wai.val.dataset_str}
+ + 4_000 @ ${dataset.megadepth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mvs_synth_wai.val.dataset_str}
+ + 4_000 @ ${dataset.paralleldomain4d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.sailvos3d_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/megatrain_6d_518_many_ar_48ipg_64g.yaml b/configs/dataset/megatrain_6d_518_many_ar_48ipg_64g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..add073ff63121c292c35867704eaf08b237d23ef
--- /dev/null
+++ b/configs/dataset/megatrain_6d_518_many_ar_48ipg_64g.yaml
@@ -0,0 +1,38 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.518_many_ar}
+
+# Validation Resolution
+resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 1_120_000 @ ${dataset.blendedmvs_wai.train.dataset_str}
+ + 1_120_000 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 1_120_000 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 44_000 @ ${dataset.spring_wai.train.dataset_str}
+ + 1_120_000 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 116_000 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/megatrain_6d_518_many_ar_48ipg_8g.yaml b/configs/dataset/megatrain_6d_518_many_ar_48ipg_8g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..31dc47a184bb839be495af040263561eba0f2ba3
--- /dev/null
+++ b/configs/dataset/megatrain_6d_518_many_ar_48ipg_8g.yaml
@@ -0,0 +1,38 @@
+defaults:
+ - default
+
+# Number of views parameter for the multi-view datasets
+num_views: 4
+
+train:
+ # If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2. (On by default for N-view training)
+ variable_num_views: true
+
+# Train Resolution
+resolution_train: ${dataset.resolution_options.518_many_ar}
+
+# Validation Resolution
+resolution_val_blendedmvs: ${dataset.resolution_options.518_1_33_ar}
+resolution_val_mpsd: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_scannetpp: ${dataset.resolution_options.518_1_52_ar}
+resolution_val_spring: ${dataset.resolution_options.518_1_77_ar}
+resolution_val_tav2_wb: ${dataset.resolution_options.518_1_00_ar}
+resolution_val_unrealstereo4k: ${dataset.resolution_options.518_1_77_ar}
+
+# Training Set
+train_dataset:
+ "+ 140_000 @ ${dataset.blendedmvs_wai.train.dataset_str}
+ + 140_000 @ ${dataset.mpsd_wai.train.dataset_str}
+ + 140_000 @ ${dataset.scannetpp_wai.train.dataset_str}
+ + 5_500 @ ${dataset.spring_wai.train.dataset_str}
+ + 140_000 @ ${dataset.tav2_wb_wai.train.dataset_str}
+ + 14_500 @ ${dataset.unrealstereo4k_wai.train.dataset_str}"
+
+# Validation Set
+test_dataset:
+ "+ 4_000 @ ${dataset.blendedmvs_wai.val.dataset_str}
+ + 4_000 @ ${dataset.mpsd_wai.val.dataset_str}
+ + 4_000 @ ${dataset.scannetpp_wai.val.dataset_str}
+ + 500 @ ${dataset.spring_wai.val.dataset_str}
+ + 4_000 @ ${dataset.tav2_wb_wai.val.dataset_str}
+ + 500 @ ${dataset.unrealstereo4k_wai.val.dataset_str}"
diff --git a/configs/dataset/mpsd_wai/default.yaml b/configs/dataset/mpsd_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/mpsd_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/mpsd_wai/train/default.yaml b/configs/dataset/mpsd_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2ad2e3ca2aec9e6b12c01a8e3a4b6ff0a635b95c
--- /dev/null
+++ b/configs/dataset/mpsd_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MPSDWAI(
+ split='${dataset.mpsd_wai.train.split}',
+ resolution=${dataset.mpsd_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.mpsd_wai.train.principal_point_centered},
+ aug_crop=${dataset.mpsd_wai.train.aug_crop},
+ transform='${dataset.mpsd_wai.train.transform}',
+ data_norm_type='${dataset.mpsd_wai.train.data_norm_type}',
+ ROOT='${dataset.mpsd_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.mpsd_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.mpsd_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.mpsd_wai.train.variable_num_views},
+ num_views=${dataset.mpsd_wai.train.num_views},
+ covisibility_thres=${dataset.mpsd_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/mpsd
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.15
diff --git a/configs/dataset/mpsd_wai/val/default.yaml b/configs/dataset/mpsd_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..cec643ba5bf830fe64cb2cbc6ef83ce28913d0ff
--- /dev/null
+++ b/configs/dataset/mpsd_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MPSDWAI(
+ split='${dataset.mpsd_wai.val.split}',
+ resolution=${dataset.mpsd_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.mpsd_wai.val.principal_point_centered},
+ seed=${dataset.mpsd_wai.val.seed},
+ transform='${dataset.mpsd_wai.val.transform}',
+ data_norm_type='${dataset.mpsd_wai.val.data_norm_type}',
+ ROOT='${dataset.mpsd_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.mpsd_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.mpsd_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.mpsd_wai.val.variable_num_views},
+ num_views=${dataset.mpsd_wai.val.num_views},
+ covisibility_thres=${dataset.mpsd_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_mpsd}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/mpsd
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.15
diff --git a/configs/dataset/mvs_synth_wai/default.yaml b/configs/dataset/mvs_synth_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/mvs_synth_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/mvs_synth_wai/train/default.yaml b/configs/dataset/mvs_synth_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3f325cfb8d9e34e8364c3652467bfc97e9ea4b70
--- /dev/null
+++ b/configs/dataset/mvs_synth_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MVSSynthWAI(
+ split='${dataset.mvs_synth_wai.train.split}',
+ resolution=${dataset.mvs_synth_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.mvs_synth_wai.train.principal_point_centered},
+ aug_crop=${dataset.mvs_synth_wai.train.aug_crop},
+ transform='${dataset.mvs_synth_wai.train.transform}',
+ data_norm_type='${dataset.mvs_synth_wai.train.data_norm_type}',
+ ROOT='${dataset.mvs_synth_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.mvs_synth_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.mvs_synth_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.mvs_synth_wai.train.variable_num_views},
+ num_views=${dataset.mvs_synth_wai.train.num_views},
+ covisibility_thres=${dataset.mvs_synth_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/mvs_synth
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/mvs_synth_wai/val/default.yaml b/configs/dataset/mvs_synth_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f84630347c54b059ec5afb3c1aecf9a883769259
--- /dev/null
+++ b/configs/dataset/mvs_synth_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "MVSSynthWAI(
+ split='${dataset.mvs_synth_wai.val.split}',
+ resolution=${dataset.mvs_synth_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.mvs_synth_wai.val.principal_point_centered},
+ seed=${dataset.mvs_synth_wai.val.seed},
+ transform='${dataset.mvs_synth_wai.val.transform}',
+ data_norm_type='${dataset.mvs_synth_wai.val.data_norm_type}',
+ ROOT='${dataset.mvs_synth_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.mvs_synth_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.mvs_synth_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.mvs_synth_wai.val.variable_num_views},
+ num_views=${dataset.mvs_synth_wai.val.num_views},
+ covisibility_thres=${dataset.mvs_synth_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_mvs_synth}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/mvs_synth
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/paralleldomain4d_wai/default.yaml b/configs/dataset/paralleldomain4d_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/paralleldomain4d_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/paralleldomain4d_wai/train/default.yaml b/configs/dataset/paralleldomain4d_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..14d3f6046ec5ccf83f489c5905093e94720101bd
--- /dev/null
+++ b/configs/dataset/paralleldomain4d_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "ParallelDomain4DWAI(
+ split='${dataset.paralleldomain4d_wai.train.split}',
+ resolution=${dataset.paralleldomain4d_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.paralleldomain4d_wai.train.principal_point_centered},
+ aug_crop=${dataset.paralleldomain4d_wai.train.aug_crop},
+ transform='${dataset.paralleldomain4d_wai.train.transform}',
+ data_norm_type='${dataset.paralleldomain4d_wai.train.data_norm_type}',
+ ROOT='${dataset.paralleldomain4d_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.paralleldomain4d_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.paralleldomain4d_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.paralleldomain4d_wai.train.variable_num_views},
+ num_views=${dataset.paralleldomain4d_wai.train.num_views},
+ covisibility_thres=${dataset.paralleldomain4d_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/paralleldomain4d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/paralleldomain4d_wai/val/default.yaml b/configs/dataset/paralleldomain4d_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d9409e28d5ba686866286411e820a376fadbb645
--- /dev/null
+++ b/configs/dataset/paralleldomain4d_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "ParallelDomain4DWAI(
+ split='${dataset.paralleldomain4d_wai.val.split}',
+ resolution=${dataset.paralleldomain4d_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.paralleldomain4d_wai.val.principal_point_centered},
+ seed=${dataset.paralleldomain4d_wai.val.seed},
+ transform='${dataset.paralleldomain4d_wai.val.transform}',
+ data_norm_type='${dataset.paralleldomain4d_wai.val.data_norm_type}',
+ ROOT='${dataset.paralleldomain4d_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.paralleldomain4d_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.paralleldomain4d_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.paralleldomain4d_wai.val.variable_num_views},
+ num_views=${dataset.paralleldomain4d_wai.val.num_views},
+ covisibility_thres=${dataset.paralleldomain4d_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_paralleldomain4d}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/paralleldomain4d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/resolution_options/default.yaml b/configs/dataset/resolution_options/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7adb12ff5a911bc004bd0e1ab832992708b95a4b
--- /dev/null
+++ b/configs/dataset/resolution_options/default.yaml
@@ -0,0 +1,77 @@
+518_many_ar: '[(518, 518), (518, 392), (518, 336), (518, 294), (518, 252), (518, 168), (392, 518), (336, 518), (294, 518), (252, 518)]'
+518_many_landscape_ar: '[(518, 518), (518, 392), (518, 336), (518, 294), (518, 252), (518, 168)]'
+518_many_non_square_landscape_ar: '[(518, 392), (518, 336), (518, 294), (518, 252), (518, 168)]'
+518_0_50_ar: (252, 518) # 1:2
+518_0_56_ar: (294, 518) # 9:16
+518_0_66_ar: (336, 518) # 2:3
+518_0_75_ar: (392, 518) # 3:4
+518_1_00_ar: (518, 518) # 1:1
+518_1_33_ar: (518, 392) # 4:3
+518_1_52_ar: (518, 336) # 3:2
+518_1_77_ar: (518, 294) # 16:9
+518_2_00_ar: (518, 252) # 2:1
+518_3_20_ar: (518, 168) # 3.2:1
+512_many_ar: '[(512, 512), (512, 384), (512, 336), (512, 288), (512, 256), (512, 160), (384, 512), (336, 512), (288, 512), (256, 512)]'
+512_many_landscape_ar: '[(512, 512), (512, 384), (512, 336), (512, 288), (512, 256), (512, 160)]'
+512_many_non_square_landscape_ar: '[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)]'
+512_0_50_ar: (256, 512)
+512_0_56_ar: (288, 512)
+512_0_66_ar: (336, 512)
+512_0_75_ar: (384, 512)
+512_1_00_ar: (512, 512)
+512_1_33_ar: (512, 384)
+512_1_52_ar: (512, 336)
+512_1_77_ar: (512, 288)
+512_2_00_ar: (512, 256)
+512_3_20_ar: (512, 160)
+504_many_ar: '[(504, 504), (504, 378), (504, 322), (504, 280), (504, 238), (504, 154), (378, 504), (322, 504), (280, 504), (238, 504)]'
+504_many_landscape_ar: '[(504, 504), (504, 378), (504, 322), (504, 280), (504, 238), (504, 154)]'
+504_many_non_square_landscape_ar: '[(504, 378), (504, 322), (504, 280), (504, 238), (504, 154)]'
+504_0_50_ar: (238, 504)
+504_0_56_ar: (280, 504)
+504_0_66_ar: (322, 504)
+504_0_75_ar: (378, 504)
+504_1_00_ar: (504, 504)
+504_1_33_ar: (504, 378)
+504_1_52_ar: (504, 322)
+504_1_77_ar: (504, 280)
+504_2_00_ar: (504, 238)
+504_3_20_ar: (504, 154)
+448_many_ar: '[(448, 448), (448, 336), (448, 294), (448, 252), (448, 224), (448, 140), (336, 448), (294, 448), (252, 448), (224, 448)]'
+448_many_landscape_ar: '[(448, 448), (448, 336), (448, 294), (448, 252), (448, 224), (448, 140)]'
+448_many_non_square_landscape_ar: '[(448, 336), (448, 294), (448, 252), (448, 224), (448, 140)]'
+448_0_50_ar: (224, 448)
+448_0_56_ar: (252, 448)
+448_0_66_ar: (294, 448)
+448_0_75_ar: (336, 448)
+448_1_00_ar: (448, 448)
+448_1_33_ar: (448, 336)
+448_1_52_ar: (448, 294)
+448_1_77_ar: (448, 252)
+448_2_00_ar: (448, 224)
+448_3_20_ar: (448, 140)
+224_many_ar_14ps: '[(224, 224), (224, 168), (224, 154), (224, 126), (224, 112), (224, 70), (168, 224), (154, 224), (126, 224), (112, 224)]'
+224_many_landscape_ar_14ps: '[(224, 224), (224, 168), (224, 154), (224, 126), (224, 112), (224, 70)]'
+224_many_non_square_landscape_ar_14ps: '[(224, 168), (224, 154), (224, 126), (224, 112), (224, 70)]'
+224_0_50_ar_14ps: (112, 224)
+224_0_56_ar_14ps: (126, 224)
+224_0_66_ar_14ps: (154, 224)
+224_0_75_ar_14ps: (168, 224)
+224_1_00_ar: (224, 224)
+224_1_33_ar_14ps: (224, 168)
+224_1_52_ar_14ps: (224, 154)
+224_1_77_ar_14ps: (224, 126)
+224_2_00_ar_14ps: (224, 112)
+224_3_20_ar_14ps: (224, 70)
+224_many_ar_16ps: '[(224, 224), (224, 176), (224, 160), (224, 128), (224, 112), (224, 80), (176, 224), (160, 224), (128, 224), (112, 224)]'
+224_many_landscape_ar_16ps: '[(224, 224), (224, 176), (224, 160), (224, 128), (224, 112), (224, 80)]'
+224_many_non_square_landscape_ar_16ps: '[(224, 176), (224, 160), (224, 128), (224, 112), (224, 80)]'
+224_0_50_ar_16ps: (112, 224)
+224_0_56_ar_16ps: (128, 224)
+224_0_66_ar_16ps: (160, 224)
+224_0_75_ar_16ps: (176, 224)
+224_1_33_ar_16ps: (224, 176)
+224_1_52_ar_16ps: (224, 160)
+224_1_77_ar_16ps: (224, 128)
+224_2_00_ar_16ps: (224, 112)
+224_3_20_ar_16ps: (224, 80)
diff --git a/configs/dataset/sailvos3d_wai/default.yaml b/configs/dataset/sailvos3d_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/sailvos3d_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/sailvos3d_wai/train/default.yaml b/configs/dataset/sailvos3d_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2187deaedc5af41b3a7c32d7371f24b78028f8ac
--- /dev/null
+++ b/configs/dataset/sailvos3d_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "SAILVOS3DWAI(
+ split='${dataset.sailvos3d_wai.train.split}',
+ resolution=${dataset.sailvos3d_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.sailvos3d_wai.train.principal_point_centered},
+ aug_crop=${dataset.sailvos3d_wai.train.aug_crop},
+ transform='${dataset.sailvos3d_wai.train.transform}',
+ data_norm_type='${dataset.sailvos3d_wai.train.data_norm_type}',
+ ROOT='${dataset.sailvos3d_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.sailvos3d_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.sailvos3d_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.sailvos3d_wai.train.variable_num_views},
+ num_views=${dataset.sailvos3d_wai.train.num_views},
+ covisibility_thres=${dataset.sailvos3d_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/sailvos3d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/sailvos3d_wai/val/default.yaml b/configs/dataset/sailvos3d_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f5936d324143f03abc8d100061e1b38217b2d35d
--- /dev/null
+++ b/configs/dataset/sailvos3d_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "SAILVOS3DWAI(
+ split='${dataset.sailvos3d_wai.val.split}',
+ resolution=${dataset.sailvos3d_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.sailvos3d_wai.val.principal_point_centered},
+ seed=${dataset.sailvos3d_wai.val.seed},
+ transform='${dataset.sailvos3d_wai.val.transform}',
+ data_norm_type='${dataset.sailvos3d_wai.val.data_norm_type}',
+ ROOT='${dataset.sailvos3d_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.sailvos3d_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.sailvos3d_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.sailvos3d_wai.val.variable_num_views},
+ num_views=${dataset.sailvos3d_wai.val.num_views},
+ covisibility_thres=${dataset.sailvos3d_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_sailvos3d}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/sailvos3d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/scannetpp_wai/default.yaml b/configs/dataset/scannetpp_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d760ece911feb086e706395c16cb1eb86d758d79
--- /dev/null
+++ b/configs/dataset/scannetpp_wai/default.yaml
@@ -0,0 +1,4 @@
+defaults:
+ - train: default
+ - val: default
+ - test: default
diff --git a/configs/dataset/scannetpp_wai/test/default.yaml b/configs/dataset/scannetpp_wai/test/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a14c936aac97853cfebfcdc6671247facd77d291
--- /dev/null
+++ b/configs/dataset/scannetpp_wai/test/default.yaml
@@ -0,0 +1,24 @@
+dataset_str:
+ "ScanNetPPWAI(
+ split='${dataset.scannetpp_wai.test.split}',
+ resolution=${dataset.scannetpp_wai.test.dataset_resolution},
+ principal_point_centered=${dataset.scannetpp_wai.test.principal_point_centered},
+ seed=${dataset.scannetpp_wai.test.seed},
+ transform='${dataset.scannetpp_wai.test.transform}',
+ data_norm_type='${dataset.scannetpp_wai.test.data_norm_type}',
+ ROOT='${dataset.scannetpp_wai.test.ROOT}',
+ dataset_metadata_dir='${dataset.scannetpp_wai.test.dataset_metadata_dir}',
+ variable_num_views=${dataset.scannetpp_wai.test.variable_num_views},
+ num_views=${dataset.scannetpp_wai.test.num_views},
+ covisibility_thres=${dataset.scannetpp_wai.test.covisibility_thres})"
+split: 'test'
+dataset_resolution: ${dataset.resolution_test_scannetpp}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/scannetppv2
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+variable_num_views: ${dataset.test.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/scannetpp_wai/train/default.yaml b/configs/dataset/scannetpp_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9cad920ccb90ad492fdff199413fb9c569bd4bac
--- /dev/null
+++ b/configs/dataset/scannetpp_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "ScanNetPPWAI(
+ split='${dataset.scannetpp_wai.train.split}',
+ resolution=${dataset.scannetpp_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.scannetpp_wai.train.principal_point_centered},
+ aug_crop=${dataset.scannetpp_wai.train.aug_crop},
+ transform='${dataset.scannetpp_wai.train.transform}',
+ data_norm_type='${dataset.scannetpp_wai.train.data_norm_type}',
+ ROOT='${dataset.scannetpp_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.scannetpp_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.scannetpp_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.scannetpp_wai.train.variable_num_views},
+ num_views=${dataset.scannetpp_wai.train.num_views},
+ covisibility_thres=${dataset.scannetpp_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/scannetppv2
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/scannetpp_wai/val/default.yaml b/configs/dataset/scannetpp_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a829383ecdcf50c2fee841e9dd8be61b2bd4f599
--- /dev/null
+++ b/configs/dataset/scannetpp_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "ScanNetPPWAI(
+ split='${dataset.scannetpp_wai.val.split}',
+ resolution=${dataset.scannetpp_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.scannetpp_wai.val.principal_point_centered},
+ seed=${dataset.scannetpp_wai.val.seed},
+ transform='${dataset.scannetpp_wai.val.transform}',
+ data_norm_type='${dataset.scannetpp_wai.val.data_norm_type}',
+ ROOT='${dataset.scannetpp_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.scannetpp_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.scannetpp_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.scannetpp_wai.val.variable_num_views},
+ num_views=${dataset.scannetpp_wai.val.num_views},
+ covisibility_thres=${dataset.scannetpp_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_scannetpp}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/scannetppv2
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/spring_wai/default.yaml b/configs/dataset/spring_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/spring_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/spring_wai/train/default.yaml b/configs/dataset/spring_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b073ccb8aa70267c06ce127a901b7e3fe47773db
--- /dev/null
+++ b/configs/dataset/spring_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "SpringWAI(
+ split='${dataset.spring_wai.train.split}',
+ resolution=${dataset.spring_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.spring_wai.train.principal_point_centered},
+ aug_crop=${dataset.spring_wai.train.aug_crop},
+ transform='${dataset.spring_wai.train.transform}',
+ data_norm_type='${dataset.spring_wai.train.data_norm_type}',
+ ROOT='${dataset.spring_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.spring_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.spring_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.spring_wai.train.variable_num_views},
+ num_views=${dataset.spring_wai.train.num_views},
+ covisibility_thres=${dataset.spring_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/spring
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/spring_wai/val/default.yaml b/configs/dataset/spring_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..92a28220e58b62cd46e5ed772fb9e4f690f4ad8e
--- /dev/null
+++ b/configs/dataset/spring_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "SpringWAI(
+ split='${dataset.spring_wai.val.split}',
+ resolution=${dataset.spring_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.spring_wai.val.principal_point_centered},
+ seed=${dataset.spring_wai.val.seed},
+ transform='${dataset.spring_wai.val.transform}',
+ data_norm_type='${dataset.spring_wai.val.data_norm_type}',
+ ROOT='${dataset.spring_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.spring_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.spring_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.spring_wai.val.variable_num_views},
+ num_views=${dataset.spring_wai.val.num_views},
+ covisibility_thres=${dataset.spring_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_spring}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/spring
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/structured3d_wai/default.yaml b/configs/dataset/structured3d_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/structured3d_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/structured3d_wai/train/default.yaml b/configs/dataset/structured3d_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8556d92cf9f9e830b35c797778ca63f79bc31a56
--- /dev/null
+++ b/configs/dataset/structured3d_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "Structured3DWAI(
+ split='${dataset.structured3d_wai.train.split}',
+ resolution=${dataset.structured3d_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.structured3d_wai.train.principal_point_centered},
+ aug_crop=${dataset.structured3d_wai.train.aug_crop},
+ transform='${dataset.structured3d_wai.train.transform}',
+ data_norm_type='${dataset.structured3d_wai.train.data_norm_type}',
+ ROOT='${dataset.structured3d_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.structured3d_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.structured3d_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.structured3d_wai.train.variable_num_views},
+ num_views=${dataset.structured3d_wai.train.num_views},
+ covisibility_thres=${dataset.structured3d_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/structured3d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/structured3d_wai/val/default.yaml b/configs/dataset/structured3d_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..396a399b95993b4c20e452aa77f8e43d7177205b
--- /dev/null
+++ b/configs/dataset/structured3d_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "Structured3DWAI(
+ split='${dataset.structured3d_wai.val.split}',
+ resolution=${dataset.structured3d_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.structured3d_wai.val.principal_point_centered},
+ seed=${dataset.structured3d_wai.val.seed},
+ transform='${dataset.structured3d_wai.val.transform}',
+ data_norm_type='${dataset.structured3d_wai.val.data_norm_type}',
+ ROOT='${dataset.structured3d_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.structured3d_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.structured3d_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.structured3d_wai.val.variable_num_views},
+ num_views=${dataset.structured3d_wai.val.num_views},
+ covisibility_thres=${dataset.structured3d_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_structured3d}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/structured3d
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/tav2_wb_wai/default.yaml b/configs/dataset/tav2_wb_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d760ece911feb086e706395c16cb1eb86d758d79
--- /dev/null
+++ b/configs/dataset/tav2_wb_wai/default.yaml
@@ -0,0 +1,4 @@
+defaults:
+ - train: default
+ - val: default
+ - test: default
diff --git a/configs/dataset/tav2_wb_wai/test/default.yaml b/configs/dataset/tav2_wb_wai/test/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a616954c15ea1e4ec4dbe4083efb33b5349de2ff
--- /dev/null
+++ b/configs/dataset/tav2_wb_wai/test/default.yaml
@@ -0,0 +1,24 @@
+dataset_str:
+ "TartanAirV2WBWAI(
+ split='${dataset.tav2_wb_wai.test.split}',
+ resolution=${dataset.tav2_wb_wai.test.dataset_resolution},
+ principal_point_centered=${dataset.tav2_wb_wai.test.principal_point_centered},
+ seed=${dataset.tav2_wb_wai.test.seed},
+ transform='${dataset.tav2_wb_wai.test.transform}',
+ data_norm_type='${dataset.tav2_wb_wai.test.data_norm_type}',
+ ROOT='${dataset.tav2_wb_wai.test.ROOT}',
+ dataset_metadata_dir='${dataset.tav2_wb_wai.test.dataset_metadata_dir}',
+ variable_num_views=${dataset.tav2_wb_wai.test.variable_num_views},
+ num_views=${dataset.tav2_wb_wai.test.num_views},
+ covisibility_thres=${dataset.tav2_wb_wai.test.covisibility_thres})"
+split: 'test'
+dataset_resolution: ${dataset.resolution_test_tav2_wb}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/tav2_wb
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+variable_num_views: ${dataset.test.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/tav2_wb_wai/train/default.yaml b/configs/dataset/tav2_wb_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e897aefbcc774792bffc98bb430e8082dd1a163e
--- /dev/null
+++ b/configs/dataset/tav2_wb_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "TartanAirV2WBWAI(
+ split='${dataset.tav2_wb_wai.train.split}',
+ resolution=${dataset.tav2_wb_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.tav2_wb_wai.train.principal_point_centered},
+ aug_crop=${dataset.tav2_wb_wai.train.aug_crop},
+ transform='${dataset.tav2_wb_wai.train.transform}',
+ data_norm_type='${dataset.tav2_wb_wai.train.data_norm_type}',
+ ROOT='${dataset.tav2_wb_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.tav2_wb_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.tav2_wb_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.tav2_wb_wai.train.variable_num_views},
+ num_views=${dataset.tav2_wb_wai.train.num_views},
+ covisibility_thres=${dataset.tav2_wb_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/tav2_wb
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/tav2_wb_wai/val/default.yaml b/configs/dataset/tav2_wb_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..97b04c1a4778079a38dec5631708ecf97a84e496
--- /dev/null
+++ b/configs/dataset/tav2_wb_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "TartanAirV2WBWAI(
+ split='${dataset.tav2_wb_wai.val.split}',
+ resolution=${dataset.tav2_wb_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.tav2_wb_wai.val.principal_point_centered},
+ seed=${dataset.tav2_wb_wai.val.seed},
+ transform='${dataset.tav2_wb_wai.val.transform}',
+ data_norm_type='${dataset.tav2_wb_wai.val.data_norm_type}',
+ ROOT='${dataset.tav2_wb_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.tav2_wb_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.tav2_wb_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.tav2_wb_wai.val.variable_num_views},
+ num_views=${dataset.tav2_wb_wai.val.num_views},
+ covisibility_thres=${dataset.tav2_wb_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_tav2_wb}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/tav2_wb
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/unrealstereo4k_wai/default.yaml b/configs/dataset/unrealstereo4k_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/unrealstereo4k_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/unrealstereo4k_wai/train/default.yaml b/configs/dataset/unrealstereo4k_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..3ecce8dde40fbdfa33da5f3e9d6a2631b9ccb2d4
--- /dev/null
+++ b/configs/dataset/unrealstereo4k_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "UnrealStereo4KWAI(
+ split='${dataset.unrealstereo4k_wai.train.split}',
+ resolution=${dataset.unrealstereo4k_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.unrealstereo4k_wai.train.principal_point_centered},
+ aug_crop=${dataset.unrealstereo4k_wai.train.aug_crop},
+ transform='${dataset.unrealstereo4k_wai.train.transform}',
+ data_norm_type='${dataset.unrealstereo4k_wai.train.data_norm_type}',
+ ROOT='${dataset.unrealstereo4k_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.unrealstereo4k_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.unrealstereo4k_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.unrealstereo4k_wai.train.variable_num_views},
+ num_views=${dataset.unrealstereo4k_wai.train.num_views},
+ covisibility_thres=${dataset.unrealstereo4k_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/unrealstereo4k
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/unrealstereo4k_wai/val/default.yaml b/configs/dataset/unrealstereo4k_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4570290d7259f26925234b208f92243f40c34fb3
--- /dev/null
+++ b/configs/dataset/unrealstereo4k_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "UnrealStereo4KWAI(
+ split='${dataset.unrealstereo4k_wai.val.split}',
+ resolution=${dataset.unrealstereo4k_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.unrealstereo4k_wai.val.principal_point_centered},
+ seed=${dataset.unrealstereo4k_wai.val.seed},
+ transform='${dataset.unrealstereo4k_wai.val.transform}',
+ data_norm_type='${dataset.unrealstereo4k_wai.val.data_norm_type}',
+ ROOT='${dataset.unrealstereo4k_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.unrealstereo4k_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.unrealstereo4k_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.unrealstereo4k_wai.val.variable_num_views},
+ num_views=${dataset.unrealstereo4k_wai.val.num_views},
+ covisibility_thres=${dataset.unrealstereo4k_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_unrealstereo4k}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/unrealstereo4k
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/xrooms_wai/default.yaml b/configs/dataset/xrooms_wai/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..78448df144ecb19398c761f707b5e264dcaaae29
--- /dev/null
+++ b/configs/dataset/xrooms_wai/default.yaml
@@ -0,0 +1,3 @@
+defaults:
+ - train: default
+ - val: default
diff --git a/configs/dataset/xrooms_wai/train/default.yaml b/configs/dataset/xrooms_wai/train/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2a6131e36392f0efe537dbdf3b6767c83d7b9a3b
--- /dev/null
+++ b/configs/dataset/xrooms_wai/train/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "XRoomsWAI(
+ split='${dataset.xrooms_wai.train.split}',
+ resolution=${dataset.xrooms_wai.train.dataset_resolution},
+ principal_point_centered=${dataset.xrooms_wai.train.principal_point_centered},
+ aug_crop=${dataset.xrooms_wai.train.aug_crop},
+ transform='${dataset.xrooms_wai.train.transform}',
+ data_norm_type='${dataset.xrooms_wai.train.data_norm_type}',
+ ROOT='${dataset.xrooms_wai.train.ROOT}',
+ dataset_metadata_dir='${dataset.xrooms_wai.train.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.xrooms_wai.train.overfit_num_sets},
+ variable_num_views=${dataset.xrooms_wai.train.variable_num_views},
+ num_views=${dataset.xrooms_wai.train.num_views},
+ covisibility_thres=${dataset.xrooms_wai.train.covisibility_thres})"
+split: 'train'
+dataset_resolution: ${dataset.resolution_train}
+principal_point_centered: ${dataset.principal_point_centered}
+aug_crop: 16
+transform: 'colorjitter+grayscale+gaublur'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/xrooms
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.train.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dataset/xrooms_wai/val/default.yaml b/configs/dataset/xrooms_wai/val/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..90044d43e2cc0f3f6f67cd6e73e27c4898f88d30
--- /dev/null
+++ b/configs/dataset/xrooms_wai/val/default.yaml
@@ -0,0 +1,26 @@
+dataset_str:
+ "XRoomsWAI(
+ split='${dataset.xrooms_wai.val.split}',
+ resolution=${dataset.xrooms_wai.val.dataset_resolution},
+ principal_point_centered=${dataset.xrooms_wai.val.principal_point_centered},
+ seed=${dataset.xrooms_wai.val.seed},
+ transform='${dataset.xrooms_wai.val.transform}',
+ data_norm_type='${dataset.xrooms_wai.val.data_norm_type}',
+ ROOT='${dataset.xrooms_wai.val.ROOT}',
+ dataset_metadata_dir='${dataset.xrooms_wai.val.dataset_metadata_dir}',
+ overfit_num_sets=${dataset.xrooms_wai.val.overfit_num_sets},
+ variable_num_views=${dataset.xrooms_wai.val.variable_num_views},
+ num_views=${dataset.xrooms_wai.val.num_views},
+ covisibility_thres=${dataset.xrooms_wai.val.covisibility_thres})"
+split: 'val'
+dataset_resolution: ${dataset.resolution_val_xrooms}
+principal_point_centered: ${dataset.principal_point_centered}
+seed: 777
+transform: 'imgnorm'
+data_norm_type: ${model.data_norm_type}
+ROOT: ${root_data_dir}/xrooms
+dataset_metadata_dir: ${mapanything_dataset_metadata_dir}
+overfit_num_sets: null
+variable_num_views: ${dataset.val.variable_num_views}
+num_views: ${dataset.num_views}
+covisibility_thres: 0.25
diff --git a/configs/dense_n_view_benchmark.yaml b/configs/dense_n_view_benchmark.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..53f63f54d13bb2e421c47f2e5e2b3d23c746ff12
--- /dev/null
+++ b/configs/dense_n_view_benchmark.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - machine: aws
+ - model: default
+ - dataset: default
+ - _self_
+
+output_dir: ${hydra:run.dir}
+root_data_dir: ${machine.root_data_dir}
+mapanything_dataset_metadata_dir: ${machine.mapanything_dataset_metadata_dir}
+root_pretrained_checkpoints_dir: ${machine.root_pretrained_checkpoints_dir}
+root_experiments_dir: ${machine.root_experiments_dir}
+root_uniception_pretrained_checkpoints_dir: ${machine.root_uniception_pretrained_checkpoints_dir}
+
+### Benchmarking args
+seed: 0
+# Disable CUDNN Benchmark (Disable for variable resolution & number of view training)
+disable_cudnn_benchmark: true
+# Batch size for inference (Metrics are computed per multi-view set and averaged, not per batch of multi-view sets)
+batch_size: 10
+# Use mixed precision for inference
+amp: 1
+# Floating point type to use for mixed precision
+amp_dtype: "bf16"
diff --git a/configs/distributed/default.yaml b/configs/distributed/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..780d4ea77de057e5c554aa8cf25f3e7618740b2a
--- /dev/null
+++ b/configs/distributed/default.yaml
@@ -0,0 +1,6 @@
+# Distributed Training Params
+# Number of distributed processes
+world_size: 1
+local_rank: -1
+# Url used to set up distributed training
+dist_url: 'env://'
diff --git a/configs/loss/conf_pm_mask_loss.yaml b/configs/loss/conf_pm_mask_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2d55324da8ff5c2b946a606de2640bd7f5160ee9
--- /dev/null
+++ b/configs/loss/conf_pm_mask_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfLoss(Regr3D(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='?avg_dis', loss_in_log=True), alpha=0.2) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(Regr3D(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='?avg_dis', flatten_across_image_only=True, loss_in_log=True), top_n_percent=5, apply_to_real_data_only=True) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/conf_pm_mask_scale_loss.yaml b/configs/loss/conf_pm_mask_scale_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eb186d734746083fa5bb327fa3c67d9a497eccfa
--- /dev/null
+++ b/configs/loss/conf_pm_mask_scale_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfLoss(PointsPlusScaleRegr3D(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', flatten_across_image_only=True, loss_in_log=True), alpha=0.2) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(PointsPlusScaleRegr3D(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', flatten_across_image_only=True, loss_in_log=True), top_n_percent=5, apply_to_real_data_only=True) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/default.yaml b/configs/loss/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e345afed2a90fa4dc0c6d36e51a835f4451f0329
--- /dev/null
+++ b/configs/loss/default.yaml
@@ -0,0 +1,6 @@
+# Training Loss, For example: "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)"
+train_criterion: ""
+# Validation Loss, For example:
+# "Regr3D_ScaleShiftInv(L21, norm_mode='?avg_dis', ambiguous_loss_value=0)" (DUSt3R)
+# "Regr3D(L21, norm_mode='?avg_dis', ambiguous_loss_value=2)" (MASt3R)
+test_criterion: ""
diff --git a/configs/loss/entangled_metric_loss.yaml b/configs/loss/entangled_metric_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b5107d58bfb4f34cd0bb983a6bbe8fcc27a87ee3
--- /dev/null
+++ b/configs/loss/entangled_metric_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='?avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='?avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/no_depth_loss.yaml b/configs/loss/no_depth_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eaa5c251136bb987a6475ec5effef6dc623b7ea2
--- /dev/null
+++ b/configs/loss/no_depth_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, depth_loss_weight=0.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, depth_loss_weight=0.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/no_log_scaling.yaml b/configs/loss/no_log_scaling.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8e9e6b68279a4b869675fbc11a63c8c564c0b4bb
--- /dev/null
+++ b/configs/loss/no_log_scaling.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=False, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=False, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/no_points_loss.yaml b/configs/loss/no_points_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7126a4692c90439c1a4a81f056a364a704f1b4b4
--- /dev/null
+++ b/configs/loss/no_points_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=False, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, cam_frame_points_loss_weight=0.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[1]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=False, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, cam_frame_points_loss_weight=0.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[1]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/no_pose_loss.yaml b/configs/loss/no_pose_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5515eccdce79909e0c6ff40dd8e4ab4356011c5f
--- /dev/null
+++ b/configs/loss/no_pose_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, pose_quats_loss_weight=0.0, pose_trans_loss_weight=0.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, pose_quats_loss_weight=0.0, pose_trans_loss_weight=0.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/no_ray_dirs_loss.yaml b/configs/loss/no_ray_dirs_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4493b31c269bd4667bec579513a7dcb70868b08e
--- /dev/null
+++ b/configs/loss/no_ray_dirs_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, ray_directions_loss_weight=0.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0, ray_directions_loss_weight=0.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/no_robust_loss.yaml b/configs/loss/no_robust_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..94afee0e092a53707134429d35335b8a9745ee81
--- /dev/null
+++ b/configs/loss/no_robust_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(L2Loss(), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=1.0, gm_loss_weight=1.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.1 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(L2Loss(), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=1.0, gm_loss_weight=1.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.1 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/overall_disentangled_loss.yaml b/configs/loss/overall_disentangled_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..60348d9add9433b451bc1cd59e64d2f1ae773dbf
--- /dev/null
+++ b/configs/loss/overall_disentangled_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfLoss(DisentangledFactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), alpha=0.2, loss_set_indices=[0]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(DisentangledFactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/overall_loss.yaml b/configs/loss/overall_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fcbc8dc28f99c31a7c193bb209d0f4128b7b99a3
--- /dev/null
+++ b/configs/loss/overall_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/overall_loss_weigh_pm_higher.yaml b/configs/loss/overall_loss_weigh_pm_higher.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..23706c64548095601553ebd2eff6919821760d61
--- /dev/null
+++ b/configs/loss/overall_loss_weigh_pm_higher.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, cam_frame_points_loss_weight=0.1, depth_loss_weight=0.1, ray_directions_loss_weight=0.1, pose_quats_loss_weight=0.1, pose_trans_loss_weight=0.1, scale_loss_weight=0.1, world_frame_points_loss_weight=1, normal_loss_weight=0.3, gm_loss_weight=0.3), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.03 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryScaleRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, cam_frame_points_loss_weight=0.1, depth_loss_weight=0.1, ray_directions_loss_weight=0.1, pose_quats_loss_weight=0.1, pose_trans_loss_weight=0.1, scale_loss_weight=0.1, world_frame_points_loss_weight=1, normal_loss_weight=0.3, gm_loss_weight=0.3), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.03 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/up_to_scale_loss.yaml b/configs/loss/up_to_scale_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f2cc22d132ce3209db0c8f34d3e625cf734ee861
--- /dev/null
+++ b/configs/loss/up_to_scale_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_along_ray', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2]) + 0.3 * NonAmbiguousMaskLoss(BCELoss())"
diff --git a/configs/loss/vggt_loss.yaml b/configs/loss/vggt_loss.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f42d7de190624a128d71d8c028b2787b48a0b201
--- /dev/null
+++ b/configs/loss/vggt_loss.yaml
@@ -0,0 +1,4 @@
+# Training Loss
+train_criterion: "ConfAndExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_z', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), conf_alpha=0.2, top_n_percent=5, apply_to_real_data_only=True, conf_loss_set_indices=[0], exclude_loss_set_indices=[1, 2])"
+# Validation Loss
+test_criterion: "ExcludeTopNPercentPixelLoss(FactoredGeometryRegr3DPlusNormalGMLoss(RobustRegressionLoss(alpha=0.5, scaling_c=0.05), norm_mode='avg_dis', depth_type_for_loss='depth_z', loss_in_log=True, flatten_across_image_only=True, compute_pairwise_relative_pose_loss=False, compute_world_frame_points_loss=True, apply_normal_and_gm_loss_to_synthetic_data_only=True, normal_loss_weight=3.0, gm_loss_weight=3.0), top_n_percent=5, apply_to_real_data_only=True, loss_set_indices=[0, 1, 2])"
diff --git a/configs/machine/aws.yaml b/configs/machine/aws.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..584639ff57977dcb505b43e866443ac8df155295
--- /dev/null
+++ b/configs/machine/aws.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - default
+
+# Root directory containing all datasets
+root_data_dir: "/fsx/xrtech/data"
+# Dataset metadata directory
+mapanything_dataset_metadata_dir: "/fsx/nkeetha/mapanything_dataset_metadata"
+# Root directory containing pretrained checkpoints for custom models
+root_pretrained_checkpoints_dir: "/fsx/nkeetha/mapanything_checkpoints"
+# Root directory to log experiments
+root_experiments_dir: "/fsx/nkeetha/experiments"
+# Root directory containing UniCeption pretrained checkpoints
+root_uniception_pretrained_checkpoints_dir: "/fsx/nkeetha/uniception_checkpoints"
diff --git a/configs/machine/default.yaml b/configs/machine/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d89f6d360340a9eb79a12707ce19b07aba3a083a
--- /dev/null
+++ b/configs/machine/default.yaml
@@ -0,0 +1,10 @@
+# Root directory containing all datasets
+root_data_dir: ???
+# Dataset metadata directory
+mapanything_dataset_metadata_dir: ???
+# Root directory containing pretrained checkpoints for custom models
+root_pretrained_checkpoints_dir: ???
+# Root directory to log experiments
+root_experiments_dir: ???
+# Root directory containing UniCeption pretrained checkpoints
+root_uniception_pretrained_checkpoints_dir: ???
diff --git a/configs/machine/psc.yaml b/configs/machine/psc.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1e529be838f2f9838e428bb25244b7b4c94aefa3
--- /dev/null
+++ b/configs/machine/psc.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - default
+
+# Root directory containing all datasets
+root_data_dir: "/ocean/projects/cis220039p/shared/datasets"
+# Dataset metadata directory
+mapanything_dataset_metadata_dir: "/ocean/projects/cis220039p/shared/mapanything_dataset_metadata"
+# Root directory containing pretrained checkpoints for custom models
+root_pretrained_checkpoints_dir: "/ocean/projects/cis220039p/nkeetha/code/AnyMap/checkpoints"
+# Root directory to log experiments
+root_experiments_dir: "/ocean/projects/cis220039p/nkeetha/experiments"
+# Root directory containing UniCeption pretrained checkpoints
+root_uniception_pretrained_checkpoints_dir: "/ocean/projects/cis220039p/nkeetha/code/AnyMap/UniCeption/checkpoints"
diff --git a/configs/machine/psc_yuchen.yaml b/configs/machine/psc_yuchen.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d071afa04eadb2cd35b4f41bc12c1e8f820fe27c
--- /dev/null
+++ b/configs/machine/psc_yuchen.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - default
+
+# Root directory containing all datasets
+root_data_dir: "/ocean/projects/cis220039p/shared/datasets"
+# Dataset metadata directory
+mapanything_dataset_metadata_dir: "/ocean/projects/cis220039p/shared/mapanything_dataset_metadata"
+# Root directory containing pretrained checkpoints for custom models
+root_pretrained_checkpoints_dir: "/jet/home/yzhang25/AnyMap/checkpoints"
+# Root directory to log experiments
+root_experiments_dir: "/jet/home/yzhang25/AnyMap/outputs"
+# Root directory containing UniCeption pretrained checkpoints
+root_uniception_pretrained_checkpoints_dir: "/ocean/projects/cis220039p/shared/uniception/checkpoints/"
diff --git a/configs/machine/xri_dgx.yaml b/configs/machine/xri_dgx.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ba77beedfdd736d6a6d7c236ff3a1b9033ea8b24
--- /dev/null
+++ b/configs/machine/xri_dgx.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - default
+
+# Root directory containing all datasets
+root_data_dir: "/mnt/xri_mapsresearch/data/nkeetha"
+# Dataset metadata directory
+mapanything_dataset_metadata_dir: "/mnt/xri_mapsresearch/data/nkeetha/mapanything_dataset_metadata"
+# Root directory containing pretrained checkpoints for custom models
+root_pretrained_checkpoints_dir: "/mnt/xri_mapsresearch/code/nkeetha/AnyMap/checkpoints"
+# Root directory to log experiments
+root_experiments_dir: "/mnt/xri_mapsresearch/experiments/nkeetha"
+# Root directory containing UniCeption pretrained checkpoints
+root_uniception_pretrained_checkpoints_dir: "/mnt/xri_mapsresearch/code/nkeetha/AnyMap/UniCeption/checkpoints"
diff --git a/configs/model/anycalib.yaml b/configs/model/anycalib.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..76688024dc56b3cbcef486b09efb1b7a1231b94b
--- /dev/null
+++ b/configs/model/anycalib.yaml
@@ -0,0 +1,11 @@
+# String for model factory
+model_str: "anycalib"
+# Model config
+model_config:
+ name: "anycalib"
+# Image Normalization Type
+data_norm_type: "identity"
+# AnyCalib checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/default.yaml b/configs/model/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0926e7bf93b61f81bc4397ec2a16331d10f756be
--- /dev/null
+++ b/configs/model/default.yaml
@@ -0,0 +1,16 @@
+# String for model factory (Options: "mapanything", "mapanything_ablations", "modular_dust3r", "vggt", "pi3")
+model_str: ???
+# Model config
+model_config:
+ # Path to pretrained model checkpoint
+ pretrained_checkpoint_path: null
+ # Load specific submodules from the checkpoint
+ load_specific_pretrained_submodules: False
+ # List of submodules to load from the checkpoint (if load_specific_pretrained_submodules is True)
+ specific_pretrained_submodules: []
+# Path of a starting checkpoint (to enable backward compatibility with original DUSt3R class)
+pretrained: null
+# Image normalization type
+data_norm_type: ???
+# Torch hub force reload
+torch_hub_force_reload: False
diff --git a/configs/model/dust3r.yaml b/configs/model/dust3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5496bb311883ea3db42f69d9f22a5208b26d28f9
--- /dev/null
+++ b/configs/model/dust3r.yaml
@@ -0,0 +1,23 @@
+# String for model factory
+model_str: "dust3r"
+# Model config
+model_config:
+ name: "dust3r"
+ # Checkpoint path
+ ckpt_path: "${root_pretrained_checkpoints_dir}/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
+ # Scene graph for BA
+ scene_graph: "complete"
+ # Pairwise inference batch size
+ inference_batch_size: 32
+ # Global optim schedule
+ global_optim_schedule: "cosine"
+ # Global optim lr
+ global_optim_lr: 0.01
+ # Number of iterations for global optimization
+ global_optim_niter: 300
+# Image Normalization Type
+data_norm_type: "dust3r"
+# DUSt3R checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/encoder/croco_512.yaml b/configs/model/encoder/croco_512.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..37d15cb7ac75f83836b7d1fed956a398b61af9df
--- /dev/null
+++ b/configs/model/encoder/croco_512.yaml
@@ -0,0 +1,16 @@
+# UniCeption encoder string used for selecting encoder class (python3 -m uniception.models.encoders.list)
+encoder_str: "croco"
+# Name of the encoder
+name: "croco_512"
+# Data normalization type
+data_norm_type: "croco"
+# Patch embedding class
+patch_embed_cls: "PatchEmbedDust3R"
+# Image size
+img_size: [512, 512] # This parameter has no influence for PatchEmbedDust3R
+# Path to the pretrained encoder checkpoint
+pretrained_checkpoint_path: '${machine.root_uniception_pretrained_checkpoints_dir}/encoders/CroCo_Encoder_224.pth'
+# Override attributes in the pretrained checkpoint
+override_checkpoint_attributes: True
+# Flag to indicate whether model class uses torch hub
+uses_torch_hub: False
diff --git a/configs/model/encoder/croco_512_data_norm_dust3r.yaml b/configs/model/encoder/croco_512_data_norm_dust3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e98d9d3b91af5fe8ab35256ee8f0fe36a49b9890
--- /dev/null
+++ b/configs/model/encoder/croco_512_data_norm_dust3r.yaml
@@ -0,0 +1,16 @@
+# UniCeption encoder string used for selecting encoder class (python3 -m uniception.models.encoders.list)
+encoder_str: "croco"
+# Name of the encoder
+name: "croco_512_img_norm_dust3r"
+# Data normalization type
+data_norm_type: "dust3r"
+# Patch embedding class
+patch_embed_cls: "PatchEmbedDust3R"
+# Image size
+img_size: [512, 512] # This parameter has no influence for PatchEmbedDust3R
+# Path to the pretrained encoder checkpoint
+pretrained_checkpoint_path: '${machine.root_uniception_pretrained_checkpoints_dir}/encoders/CroCo_Encoder_224.pth'
+# Override attributes in the pretrained checkpoint
+override_checkpoint_attributes: True
+# Flag to indicate whether model class uses torch hub
+uses_torch_hub: False
diff --git a/configs/model/encoder/dinov2_large.yaml b/configs/model/encoder/dinov2_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a15ad5aa58cfd7da159b4713d206d54ac0a8f9a7
--- /dev/null
+++ b/configs/model/encoder/dinov2_large.yaml
@@ -0,0 +1,14 @@
+# UniCeption encoder string used for selecting encoder class (python3 -m uniception.models.encoders.list)
+encoder_str: "dinov2"
+# Name of the encoder
+name: "dinov2_large"
+# Data normalization type
+data_norm_type: "dinov2"
+# ViT size
+size: "large"
+# Registers
+with_registers: False
+# Flag to indicate whether model class uses torch hub
+uses_torch_hub: True
+# Flag to indicate whether to use gradient checkpointing for encoder
+gradient_checkpointing: False
diff --git a/configs/model/encoder/radio_v2_5_large.yaml b/configs/model/encoder/radio_v2_5_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..95ce4103aa426bcb3a7be5b48184778d935afac2
--- /dev/null
+++ b/configs/model/encoder/radio_v2_5_large.yaml
@@ -0,0 +1,10 @@
+# UniCeption encoder string used for selecting encoder class (python3 -m uniception.models.encoders.list)
+encoder_str: "radio"
+# Name of the encoder
+name: "radio_v2.5-large"
+# Data normalization type
+data_norm_type: "radio"
+# Model version
+model_version: "radio_v2.5-l"
+# Flag to indicate whether model class uses torch hub
+uses_torch_hub: True
diff --git a/configs/model/info_sharing/aat_ifr_24_layers.yaml b/configs/model/info_sharing/aat_ifr_24_layers.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..34f787a24dc3d3facbdc6e387046f1eafbe2d099
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_24_layers.yaml
@@ -0,0 +1,22 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_24_layers_ifr"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 17]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "24_layers"
+ # Depth
+ depth: 24
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: True
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/aat_ifr_24_layers_escaling.yaml b/configs/model/info_sharing/aat_ifr_24_layers_escaling.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..63705b14ff02c25d6a58f59791a10daa479f907f
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_24_layers_escaling.yaml
@@ -0,0 +1,24 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_24_layers_ifr"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 17]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "24_layers"
+ # Depth
+ depth: 24
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: True
+ # Scale Entropy in Attention
+ use_entropy_scaling: True
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/aat_ifr_24_layers_no_ref_view.yaml b/configs/model/info_sharing/aat_ifr_24_layers_no_ref_view.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4abf1d168eca3ccd6c7a8ca2717c1aee1c0e927b
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_24_layers_no_ref_view.yaml
@@ -0,0 +1,22 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_24_layers_ifr_no_ref_view"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 17]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "24_layers"
+ # Depth
+ depth: 24
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: False
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/aat_ifr_24_layers_w_view_pe.yaml b/configs/model/info_sharing/aat_ifr_24_layers_w_view_pe.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ef2fb6ccae11b740d2ea081181893147d4ee415c
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_24_layers_w_view_pe.yaml
@@ -0,0 +1,26 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_24_layers_ifr_w_view_pe"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 17]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "24_layers"
+ # Depth
+ depth: 24
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: True
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
+ # Maximum number of views for positional encoding
+ max_num_views_for_pe: 1000
+ # Use random indices within range (1, max_num_views_for_pe) for positional encoding of non reference views
+ use_rand_idx_pe_for_non_reference_views: True
diff --git a/configs/model/info_sharing/aat_ifr_48_layers.yaml b/configs/model/info_sharing/aat_ifr_48_layers.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7249706a5b5031e71218604df44cd377a90dc37b
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_48_layers.yaml
@@ -0,0 +1,26 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_48_layers_ifr"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 23, 35]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "48_layers"
+ # Depth (this includes both frame-wise and gloabl attention layers)
+ depth: 48
+ # Feature dim (similar to ViT-Large)
+ dim: 1024
+ # Number of heads (similar to ViT-Large)
+ num_heads: 16
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: True
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/aat_ifr_48_layers_escaling.yaml b/configs/model/info_sharing/aat_ifr_48_layers_escaling.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..fc8ca1da63b0a523ea537d47515974473820ef9e
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_48_layers_escaling.yaml
@@ -0,0 +1,28 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_48_layers_ifr"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 23, 35]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "48_layers"
+ # Depth (this includes both frame-wise and gloabl attention layers)
+ depth: 48
+ # Feature dim (similar to ViT-Large)
+ dim: 1024
+ # Number of heads (similar to ViT-Large)
+ num_heads: 16
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: True
+ # Scale Entropy in Attention
+ use_entropy_scaling: True
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/aat_ifr_48_layers_no_ref_view.yaml b/configs/model/info_sharing/aat_ifr_48_layers_no_ref_view.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..68456c32359c48f5e626a2ef9ecb0eb08faa2d54
--- /dev/null
+++ b/configs/model/info_sharing/aat_ifr_48_layers_no_ref_view.yaml
@@ -0,0 +1,26 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "alternating_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "aat_48_layers_ifr_no_ref_view"
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 23, 35]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "48_layers"
+ # Depth (this includes both frame-wise and gloabl attention layers)
+ depth: 48
+ # Feature dim (similar to ViT-Large)
+ dim: 1024
+ # Number of heads (similar to ViT-Large)
+ num_heads: 16
+ # Distinguish Reference and Non-Reference Views
+ distinguish_ref_and_non_ref_views: False
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/cat_ifr_dust3r.yaml b/configs/model/info_sharing/cat_ifr_dust3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e52e1d37cefc92d7121bb6a4076198fef8a72ee6
--- /dev/null
+++ b/configs/model/info_sharing/cat_ifr_dust3r.yaml
@@ -0,0 +1,18 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "cross_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: "RoPE100"
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "base_cat_ifr_dust3r"
+ # Number of views
+ num_views: 2
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [5, 8]
+ # Normalize intermediate features
+ norm_intermediate: False
+ # Load CroCo cross-attention transformer for DUSt3R Init
+ pretrained_checkpoint_path: '${machine.root_uniception_pretrained_checkpoints_dir}/info_sharing/cross_attn_transformer/Two_View_Cross_Attention_Transformer_CroCo.pth'
diff --git a/configs/model/info_sharing/gat_ifr_24_layers.yaml b/configs/model/info_sharing/gat_ifr_24_layers.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..60423714da681f71ca2088f101f4b9fa734b254b
--- /dev/null
+++ b/configs/model/info_sharing/gat_ifr_24_layers.yaml
@@ -0,0 +1,24 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "global_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "gat_24_layers_ifr"
+ # Maximum number of views for positional encoding
+ max_num_views: 1000
+ # Use random indices within range (1, max_num_views) for positional encoding of non reference views
+ use_rand_idx_pe_for_non_reference_views: True
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 17]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "24_layers"
+ # Depth
+ depth: 24
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/info_sharing/gat_ifr_24_layers_escaling.yaml b/configs/model/info_sharing/gat_ifr_24_layers_escaling.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6143b5b37adde56a6a742b01a9b0d4fc0824fa35
--- /dev/null
+++ b/configs/model/info_sharing/gat_ifr_24_layers_escaling.yaml
@@ -0,0 +1,26 @@
+# Model type (Options: ["cross_attention", "global_attention", "alternating_attention"])
+model_type: "global_attention"
+# Model class type (Options: ["no_intermediate_features", "intermediate_features"])
+model_return_type: "intermediate_features"
+# Custom positional encoding (Options: ["RoPEfreq"], Callable Function, null)
+custom_positional_encoding: null
+# Module arguments
+module_args:
+ # Name of the info sharing module
+ name: "gat_24_layers_ifr"
+ # Maximum number of views for positional encoding
+ max_num_views: 1000
+ # Use random indices within range (1, max_num_views) for positional encoding of non reference views
+ use_rand_idx_pe_for_non_reference_views: True
+ # Indices of the intermediate features to be shared (indices start from 0)
+ indices: [11, 17]
+ # Normalize intermediate features
+ norm_intermediate: True
+ # Size string
+ size: "24_layers"
+ # Depth
+ depth: 24
+ # Scale Entropy in Attention
+ use_entropy_scaling: True
+ # Flag to indicate whether to use gradient checkpointing
+ gradient_checkpointing: False
diff --git a/configs/model/mapanything.yaml b/configs/model/mapanything.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a69c44ec9cd9d684e2dc0764365d462fcc095030
--- /dev/null
+++ b/configs/model/mapanything.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - default
+ - encoder: dinov2_large
+ - info_sharing: aat_ifr_24_layers
+ - pred_head: dpt_pose_scale
+ - task: images_only
+
+# String for model factory
+model_str: "mapanything"
+# Model config
+model_config:
+ name: "mapanything"
+ encoder_config: ${model.encoder}
+ info_sharing_config: ${model.info_sharing}
+ pred_head_config: ${model.pred_head}
+ geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: ${model.encoder.data_norm_type}
diff --git a/configs/model/mapanything_ablations.yaml b/configs/model/mapanything_ablations.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4b8dc0d7a29ca93bdd2fe34660ea5c5a883e7d60
--- /dev/null
+++ b/configs/model/mapanything_ablations.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - default
+ - encoder: dinov2_large
+ - info_sharing: aat_ifr_24_layers
+ - pred_head: dpt_pose
+ - task: images_only
+
+# String for model factory
+model_str: "mapanything_ablations"
+# Model config
+model_config:
+ name: "mapanything_ablations"
+ encoder_config: ${model.encoder}
+ info_sharing_config: ${model.info_sharing}
+ pred_head_config: ${model.pred_head}
+ geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: ${model.encoder.data_norm_type}
diff --git a/configs/model/mapanything_inference.yaml b/configs/model/mapanything_inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a081212dc00ee6b2e105254346129e430531f642
--- /dev/null
+++ b/configs/model/mapanything_inference.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - default
+ - encoder: dinov2_large
+ - info_sharing: aat_ifr_24_layers_escaling
+ - pred_head: dpt_pose_scale
+ - task: images_only
+
+# String for model factory
+model_str: "mapanything"
+# Model config
+model_config:
+ name: "mapanything"
+ encoder_config: ${model.encoder}
+ info_sharing_config: ${model.info_sharing}
+ pred_head_config: ${model.pred_head}
+ geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: ${model.encoder.data_norm_type}
diff --git a/configs/model/mapanything_large.yaml b/configs/model/mapanything_large.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..371e2fe6513fee6f5a0619aff5fff0c515a4d8d9
--- /dev/null
+++ b/configs/model/mapanything_large.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - default
+ - encoder: dinov2_large
+ - info_sharing: aat_ifr_48_layers
+ - pred_head: dpt_pose_scale
+ - task: images_only
+
+# String for model factory
+model_str: "mapanything"
+# Model config
+model_config:
+ name: "mapanything"
+ encoder_config: ${model.encoder}
+ info_sharing_config: ${model.info_sharing}
+ pred_head_config: ${model.pred_head}
+ geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: ${model.encoder.data_norm_type}
diff --git a/configs/model/mapanything_large_inference.yaml b/configs/model/mapanything_large_inference.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2e57bf7ddbcc416d4e637276202e60515535e0e1
--- /dev/null
+++ b/configs/model/mapanything_large_inference.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - default
+ - encoder: dinov2_large
+ - info_sharing: aat_ifr_48_layers_escaling
+ - pred_head: dpt_pose_scale
+ - task: images_only
+
+# String for model factory
+model_str: "mapanything"
+# Model config
+model_config:
+ name: "mapanything"
+ encoder_config: ${model.encoder}
+ info_sharing_config: ${model.info_sharing}
+ pred_head_config: ${model.pred_head}
+ geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: ${model.encoder.data_norm_type}
diff --git a/configs/model/mast3r.yaml b/configs/model/mast3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dc315b3a24d0748a2b671a6d29e10dda9d446b3b
--- /dev/null
+++ b/configs/model/mast3r.yaml
@@ -0,0 +1,15 @@
+# String for model factory
+model_str: "mast3r"
+# Model config
+model_config:
+ name: "mast3r"
+ # Checkpoint path
+ ckpt_path: "${root_pretrained_checkpoints_dir}/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
+ # Cache dir
+ cache_dir: "${root_pretrained_checkpoints_dir}/mast3r_cache"
+# Image Normalization Type
+data_norm_type: "dust3r"
+# MASt3R checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/metric_dust3r.yaml b/configs/model/metric_dust3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..329a3cb9378b163049d898d1e54b778d9d2a4b0f
--- /dev/null
+++ b/configs/model/metric_dust3r.yaml
@@ -0,0 +1,23 @@
+# String for model factory
+model_str: "dust3r"
+# Model config
+model_config:
+ name: "metric_dust3r"
+ # Checkpoint path
+ ckpt_path: "${root_pretrained_checkpoints_dir}/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
+ # Scene graph for BA
+ scene_graph: "complete"
+ # Pairwise inference batch size
+ inference_batch_size: 32
+ # Global optim schedule
+ global_optim_schedule: "cosine"
+ # Global optim lr
+ global_optim_lr: 0.01
+ # Number of iterations for global optimization
+ global_optim_niter: 300
+# Image Normalization Type
+data_norm_type: "dust3r"
+# DUSt3R checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/modular_dust3r_512_dpt.yaml b/configs/model/modular_dust3r_512_dpt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a0ffe75a5af160b0f93e0790bf3788cd1fc1455
--- /dev/null
+++ b/configs/model/modular_dust3r_512_dpt.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - default
+ - encoder: croco_512_data_norm_dust3r
+ - info_sharing: cat_ifr_dust3r
+ - pred_head: dpt
+
+# String for model factory
+model_str: "modular_dust3r"
+# Model config
+model_config:
+ name: "dust3r_512_dpt"
+ encoder_config: ${model.encoder}
+ info_sharing_config: ${model.info_sharing}
+ pred_head_config: ${model.pred_head}
+# Image Normalization Type
+data_norm_type: "dust3r"
diff --git a/configs/model/moge_1.yaml b/configs/model/moge_1.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4db9bcf8e7a5637393a71f5a9ae9bf0003b72d90
--- /dev/null
+++ b/configs/model/moge_1.yaml
@@ -0,0 +1,17 @@
+# String for model factory
+model_str: "moge"
+# Model config
+model_config:
+ name: "moge-1"
+ # MoGe pre-trained model checkpoint string
+ model_string: "Ruicheng/moge-vitl"
+ # Load custom checkpoint
+ load_custom_ckpt: false
+ # Custom checkpoint path
+ custom_ckpt_path: null
+# Image Normalization Type
+data_norm_type: "identity"
+# MoGe checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/moge_2.yaml b/configs/model/moge_2.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0b048344b8cd34011ae17fdb7d543a5adce1f403
--- /dev/null
+++ b/configs/model/moge_2.yaml
@@ -0,0 +1,17 @@
+# String for model factory
+model_str: "moge"
+# Model config
+model_config:
+ name: "moge-2"
+ # MoGe pre-trained model checkpoint string
+ model_string: "Ruicheng/moge-2-vitl"
+ # Load custom checkpoint
+ load_custom_ckpt: false
+ # Custom checkpoint path
+ custom_ckpt_path: null
+# Image Normalization Type
+data_norm_type: "identity"
+# MoGe checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/must3r.yaml b/configs/model/must3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..09bc01eadb2853a1aea3ed316bfae3f2a5ad0371
--- /dev/null
+++ b/configs/model/must3r.yaml
@@ -0,0 +1,15 @@
+# String for model factory
+model_str: "must3r"
+# Model config
+model_config:
+ name: "must3r"
+ # Checkpoint path
+ ckpt_path: "${root_pretrained_checkpoints_dir}/MUSt3R_512.pth"
+ # Retrieval Checkpoint path
+ retrieval_ckpt_path: "${root_pretrained_checkpoints_dir}/MUSt3R_512_retrieval_trainingfree.pth"
+# Image Normalization Type
+data_norm_type: "dust3r"
+# MASt3R checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/pi3.yaml b/configs/model/pi3.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..138e1393ffd1996f3412d6d91c775e1769c5c8c6
--- /dev/null
+++ b/configs/model/pi3.yaml
@@ -0,0 +1,13 @@
+# String for model factory
+model_str: "pi3"
+# Model config
+model_config:
+ name: "pi3"
+ # Load pre-trained weights
+ load_pretrained_weights: true
+# Image Normalization Type
+data_norm_type: "identity"
+# Pi3 checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: False
diff --git a/configs/model/pow3r.yaml b/configs/model/pow3r.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e7141fc206bb51d62ff03093409fa221d204bfee
--- /dev/null
+++ b/configs/model/pow3r.yaml
@@ -0,0 +1,19 @@
+defaults:
+ - default
+ - task: images_only
+
+# String for model factory
+model_str: "pow3r"
+# Model config
+model_config:
+ name: "pow3r"
+ # Checkpoint path
+ ckpt_path: "${root_pretrained_checkpoints_dir}/Pow3R_ViTLarge_BaseDecoder_512_linear.pth"
+ # Geometric input config
+ geometric_input_config: ${model.task}
+# Image Normalization Type
+data_norm_type: "dust3r"
+# Pow3R checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/pow3r_ba.yaml b/configs/model/pow3r_ba.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..45c4dbca68f54d79f519e7757cd1490599c4839d
--- /dev/null
+++ b/configs/model/pow3r_ba.yaml
@@ -0,0 +1,29 @@
+defaults:
+ - default
+ - task: images_only
+
+# String for model factory
+model_str: "pow3r_ba"
+# Model config
+model_config:
+ name: "pow3r_ba"
+ # Checkpoint path
+ ckpt_path: "${root_pretrained_checkpoints_dir}/Pow3R_ViTLarge_BaseDecoder_512_linear.pth"
+ # Geometric input config
+ geometric_input_config: ${model.task}
+ # Scene graph for BA
+ scene_graph: "complete"
+ # Pairwise inference batch size
+ inference_batch_size: 32
+ # Global optim schedule
+ global_optim_schedule: "cosine"
+ # Global optim lr
+ global_optim_lr: 0.01
+ # Number of iterations for global optimization
+ global_optim_niter: 300
+# Image Normalization Type
+data_norm_type: "dust3r"
+# Pow3R checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/pred_head/adaptor_config/campointmap_pose_confidence_mask.yaml b/configs/model/pred_head/adaptor_config/campointmap_pose_confidence_mask.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c98de936666ba0c0f469df99d75935f34f492cc4
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/campointmap_pose_confidence_mask.yaml
@@ -0,0 +1,22 @@
+# Camera Frame Pointmap + Global Camera Pose (Trans + Quats) + Confidence + Mask
+input_dim: 5
+scene_rep_dim: 3
+type: "campointmap+pose+confidence+mask"
+scene_rep_type: "campointmap+pose"
+dense_pred_init_dict:
+ name: "campointmap+pose+confidence+mask+scale"
+ pointmap_mode: "z_exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+pose_pred_init_dict:
+ name: "campointmap+pose+confidence+mask+scale"
+ cam_trans_mode: "linear"
+ cam_trans_vmin: ${special_float:"-inf"}
+ cam_trans_vmax: ${special_float:"inf"}
+ quaternions_mode: "linear"
+ quaternions_normalize: true
+ quaternions_vmin: ${special_float:"-inf"}
+ quaternions_vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/adaptor_config/campointmap_pose_confidence_mask_scale.yaml b/configs/model/pred_head/adaptor_config/campointmap_pose_confidence_mask_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5ef92a540e7e2979c73541ed532ee225b0df3c11
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/campointmap_pose_confidence_mask_scale.yaml
@@ -0,0 +1,27 @@
+# Camera Frame Pointmap + Global Camera Pose (Trans + Quats) + Confidence + Mask + Scene-wide Metric Scaling Factor
+input_dim: 5
+scene_rep_dim: 3
+type: "campointmap+pose+confidence+mask"
+scene_rep_type: "campointmap+pose"
+dense_pred_init_dict:
+ name: "campointmap+pose+confidence+mask+scale"
+ pointmap_mode: "z_exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+pose_pred_init_dict:
+ name: "campointmap+pose+confidence+mask+scale"
+ cam_trans_mode: "linear"
+ cam_trans_vmin: ${special_float:"-inf"}
+ cam_trans_vmax: ${special_float:"inf"}
+ quaternions_mode: "linear"
+ quaternions_normalize: true
+ quaternions_vmin: ${special_float:"-inf"}
+ quaternions_vmax: ${special_float:"inf"}
+scale_pred_init_dict:
+ name: "campointmap+pose+confidence+mask+scale"
+ mode: "exp"
+ vmin: 1e-08
+ vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/adaptor_config/pointmap_confidence.yaml b/configs/model/pred_head/adaptor_config/pointmap_confidence.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..96b60f9a47407a048c1f4af698f5de693eb1d109
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/pointmap_confidence.yaml
@@ -0,0 +1,13 @@
+# Pointmap + Confidence
+input_dim: 4
+scene_rep_dim: 3
+type: "pointmap+confidence"
+scene_rep_type: "pointmap"
+init_dict:
+ name: "pointmap+confidence"
+ pointmap_mode: "exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/adaptor_config/pointmap_confidence_mask.yaml b/configs/model/pred_head/adaptor_config/pointmap_confidence_mask.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..833df59a2a8d941788cbf9ef6e988cd8e0cf5e6b
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/pointmap_confidence_mask.yaml
@@ -0,0 +1,13 @@
+# Pointmap + Confidence + Mask
+input_dim: 5
+scene_rep_dim: 3
+type: "pointmap+confidence+mask"
+scene_rep_type: "pointmap"
+init_dict:
+ name: "pointmap+confidence+mask"
+ pointmap_mode: "exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/adaptor_config/pointmap_confidence_mask_scale.yaml b/configs/model/pred_head/adaptor_config/pointmap_confidence_mask_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4230d94989652d7b5ee779bebd7d8ac52d8e3cd7
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/pointmap_confidence_mask_scale.yaml
@@ -0,0 +1,18 @@
+# Pointmap + Confidence + Mask + Scene-wide Metric Scaling Factor
+input_dim: 5
+scene_rep_dim: 3
+type: "pointmap+confidence+mask"
+scene_rep_type: "pointmap"
+init_dict:
+ name: "pointmap+confidence+mask"
+ pointmap_mode: "exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+scale_pred_init_dict:
+ name: "campointmap+pose+confidence+mask+scale"
+ mode: "exp"
+ vmin: 1e-08
+ vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/adaptor_config/pointmap_factored_raydirs_depth_pose_confidence_mask_scale.yaml b/configs/model/pred_head/adaptor_config/pointmap_factored_raydirs_depth_pose_confidence_mask_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2f93a518ba2322cc8e6915abfa4b1cdf00fb2dbf
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/pointmap_factored_raydirs_depth_pose_confidence_mask_scale.yaml
@@ -0,0 +1,39 @@
+# Global Pointmaps + Ray Directions on Unit Sphere + Depth along Ray + Global Camera Pose (Trans + Quats) + Confidence + Mask + Global Metric Scaling Factor
+input_dim: 9
+scene_rep_dim: 7
+type: "pointmap+raydirs+depth+pose+confidence+mask"
+scene_rep_type: "pointmap+raydirs+depth+pose"
+dense_pred_init_dict:
+ name: "pointmap+raydirs+depth+pose+confidence+mask+scale"
+ pointmap_mode: "exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ ray_directions_mode: "linear"
+ ray_directions_normalize_to_unit_sphere: true
+ ray_directions_normalize_to_unit_image_plane: false
+ ray_directions_vmin: ${special_float:"-inf"}
+ ray_directions_vmax: ${special_float:"inf"}
+ ray_directions_clamp_min_of_z_dir: false
+ ray_directions_z_dir_min: ${special_float:"-inf"}
+ depth_mode: "exp"
+ depth_vmin: 0
+ depth_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+pose_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ cam_trans_mode: "linear"
+ cam_trans_vmin: ${special_float:"-inf"}
+ cam_trans_vmax: ${special_float:"inf"}
+ quaternions_mode: "linear"
+ quaternions_normalize: true
+ quaternions_vmin: ${special_float:"-inf"}
+ quaternions_vmax: ${special_float:"inf"}
+scale_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ mode: "exp"
+ vmin: 1e-08
+ vmax: ${special_float:"inf"}
+# Flag to decide what representaion to use for global pointmaps
+use_factored_predictions_for_global_pointmaps: true
diff --git a/configs/model/pred_head/adaptor_config/pointmap_raydirs_depth_pose_confidence_mask_scale.yaml b/configs/model/pred_head/adaptor_config/pointmap_raydirs_depth_pose_confidence_mask_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..bfb2408fd39ab13bd619b3b503e29f9900433a61
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/pointmap_raydirs_depth_pose_confidence_mask_scale.yaml
@@ -0,0 +1,39 @@
+# Global Pointmaps + Ray Directions on Unit Sphere + Depth along Ray + Global Camera Pose (Trans + Quats) + Confidence + Mask + Global Metric Scaling Factor
+input_dim: 9
+scene_rep_dim: 7
+type: "pointmap+raydirs+depth+pose+confidence+mask"
+scene_rep_type: "pointmap+raydirs+depth+pose"
+dense_pred_init_dict:
+ name: "pointmap+raydirs+depth+pose+confidence+mask+scale"
+ pointmap_mode: "exp"
+ pointmap_vmin: ${special_float:"-inf"}
+ pointmap_vmax: ${special_float:"inf"}
+ ray_directions_mode: "linear"
+ ray_directions_normalize_to_unit_sphere: true
+ ray_directions_normalize_to_unit_image_plane: false
+ ray_directions_vmin: ${special_float:"-inf"}
+ ray_directions_vmax: ${special_float:"inf"}
+ ray_directions_clamp_min_of_z_dir: false
+ ray_directions_z_dir_min: ${special_float:"-inf"}
+ depth_mode: "exp"
+ depth_vmin: 0
+ depth_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+pose_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ cam_trans_mode: "linear"
+ cam_trans_vmin: ${special_float:"-inf"}
+ cam_trans_vmax: ${special_float:"inf"}
+ quaternions_mode: "linear"
+ quaternions_normalize: true
+ quaternions_vmin: ${special_float:"-inf"}
+ quaternions_vmax: ${special_float:"inf"}
+scale_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ mode: "exp"
+ vmin: 1e-08
+ vmax: ${special_float:"inf"}
+# Flag to decide what representaion to use for global pointmaps
+use_factored_predictions_for_global_pointmaps: false
diff --git a/configs/model/pred_head/adaptor_config/raydirs_depth_pose_confidence_mask.yaml b/configs/model/pred_head/adaptor_config/raydirs_depth_pose_confidence_mask.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ee83ea0b976e3d173302d167b51404446b950b11
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/raydirs_depth_pose_confidence_mask.yaml
@@ -0,0 +1,29 @@
+# Ray Directions on Unit Sphere + Depth along Ray + Global Camera Pose (Trans + Quats) + Confidence + Mask
+input_dim: 6
+scene_rep_dim: 4
+type: "raydirs+depth+pose+confidence+mask"
+scene_rep_type: "raydirs+depth+pose"
+dense_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask"
+ ray_directions_mode: "linear"
+ ray_directions_normalize_to_unit_sphere: true
+ ray_directions_normalize_to_unit_image_plane: false
+ ray_directions_vmin: ${special_float:"-inf"}
+ ray_directions_vmax: ${special_float:"inf"}
+ ray_directions_clamp_min_of_z_dir: false
+ ray_directions_z_dir_min: ${special_float:"-inf"}
+ depth_mode: "exp"
+ depth_vmin: 0
+ depth_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+pose_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask"
+ cam_trans_mode: "linear"
+ cam_trans_vmin: ${special_float:"-inf"}
+ cam_trans_vmax: ${special_float:"inf"}
+ quaternions_mode: "linear"
+ quaternions_normalize: true
+ quaternions_vmin: ${special_float:"-inf"}
+ quaternions_vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/adaptor_config/raydirs_depth_pose_confidence_mask_scale.yaml b/configs/model/pred_head/adaptor_config/raydirs_depth_pose_confidence_mask_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4a939ad9eeed4131990941976bde8124bfbeba15
--- /dev/null
+++ b/configs/model/pred_head/adaptor_config/raydirs_depth_pose_confidence_mask_scale.yaml
@@ -0,0 +1,34 @@
+# Ray Directions on Unit Sphere + Depth along Ray + Global Camera Pose (Trans + Quats) + Confidence + Mask + Global Metric Scaling Factor
+input_dim: 6
+scene_rep_dim: 4
+type: "raydirs+depth+pose+confidence+mask"
+scene_rep_type: "raydirs+depth+pose"
+dense_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ ray_directions_mode: "linear"
+ ray_directions_normalize_to_unit_sphere: true
+ ray_directions_normalize_to_unit_image_plane: false
+ ray_directions_vmin: ${special_float:"-inf"}
+ ray_directions_vmax: ${special_float:"inf"}
+ ray_directions_clamp_min_of_z_dir: false
+ ray_directions_z_dir_min: ${special_float:"-inf"}
+ depth_mode: "exp"
+ depth_vmin: 0
+ depth_vmax: ${special_float:"inf"}
+ confidence_type: "exp"
+ confidence_vmin: 1
+ confidence_vmax: ${special_float:"inf"}
+pose_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ cam_trans_mode: "linear"
+ cam_trans_vmin: ${special_float:"-inf"}
+ cam_trans_vmax: ${special_float:"inf"}
+ quaternions_mode: "linear"
+ quaternions_normalize: true
+ quaternions_vmin: ${special_float:"-inf"}
+ quaternions_vmax: ${special_float:"inf"}
+scale_pred_init_dict:
+ name: "raydirs+depth+pose+confidence+mask+scale"
+ mode: "exp"
+ vmin: 1e-08
+ vmax: ${special_float:"inf"}
diff --git a/configs/model/pred_head/dpt.yaml b/configs/model/pred_head/dpt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2d05fe0b2695236470388b522d0ed70fcf603b1d
--- /dev/null
+++ b/configs/model/pred_head/dpt.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - adaptor_config: pointmap_confidence
+
+type: "dpt"
+feature_head:
+ feature_dim: 256
+ hooks: [0, 1, 2, 3]
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+regressor_head:
+ output_dim: ${model.pred_head.adaptor_config.input_dim}
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+adaptor_type: ${model.pred_head.adaptor_config.type}
+adaptor: ${model.pred_head.adaptor_config.init_dict}
+# Flag to indicate whether to use gradient checkpointing
+gradient_checkpointing: False
diff --git a/configs/model/pred_head/dpt_pose.yaml b/configs/model/pred_head/dpt_pose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..24d50ae5dbfa1ada961b0afde5422ce9dac97a8b
--- /dev/null
+++ b/configs/model/pred_head/dpt_pose.yaml
@@ -0,0 +1,19 @@
+defaults:
+ - adaptor_config: raydirs_depth_pose_confidence_mask
+
+type: "dpt+pose"
+feature_head:
+ feature_dim: 256
+ hooks: [0, 1, 2, 3]
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+regressor_head:
+ output_dim: ${model.pred_head.adaptor_config.input_dim}
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+pose_head:
+ num_resconv_block: 2
+ rot_representation_dim: 4
+adaptor_type: ${model.pred_head.adaptor_config.type}
+dpt_adaptor: ${model.pred_head.adaptor_config.dense_pred_init_dict}
+pose_adaptor: ${model.pred_head.adaptor_config.pose_pred_init_dict}
+# Flag to indicate whether to use gradient checkpointing
+gradient_checkpointing: False
diff --git a/configs/model/pred_head/dpt_pose_scale.yaml b/configs/model/pred_head/dpt_pose_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..068dc181712e4d92f2e1e1e088f363ed698c2582
--- /dev/null
+++ b/configs/model/pred_head/dpt_pose_scale.yaml
@@ -0,0 +1,22 @@
+defaults:
+ - adaptor_config: raydirs_depth_pose_confidence_mask_scale
+
+type: "dpt+pose"
+feature_head:
+ feature_dim: 256
+ hooks: [0, 1, 2, 3]
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+regressor_head:
+ output_dim: ${model.pred_head.adaptor_config.input_dim}
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+pose_head:
+ num_resconv_block: 2
+ rot_representation_dim: 4
+scale_head:
+ output_dim: 1
+adaptor_type: ${model.pred_head.adaptor_config.type}
+dpt_adaptor: ${model.pred_head.adaptor_config.dense_pred_init_dict}
+pose_adaptor: ${model.pred_head.adaptor_config.pose_pred_init_dict}
+scale_adaptor: ${model.pred_head.adaptor_config.scale_pred_init_dict}
+# Flag to indicate whether to use gradient checkpointing
+gradient_checkpointing: False
diff --git a/configs/model/pred_head/dpt_scale.yaml b/configs/model/pred_head/dpt_scale.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..286eb589259cd07c6ff0d2c24e5a58a6d1c748c5
--- /dev/null
+++ b/configs/model/pred_head/dpt_scale.yaml
@@ -0,0 +1,18 @@
+defaults:
+ - adaptor_config: pointmap_confidence_mask_scale
+
+type: "dpt"
+feature_head:
+ feature_dim: 256
+ hooks: [0, 1, 2, 3]
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+regressor_head:
+ output_dim: ${model.pred_head.adaptor_config.input_dim}
+ checkpoint_gradient: ${model.pred_head.gradient_checkpointing}
+scale_head:
+ output_dim: 1
+adaptor_type: ${model.pred_head.adaptor_config.type}
+adaptor: ${model.pred_head.adaptor_config.init_dict}
+scale_adaptor: ${model.pred_head.adaptor_config.scale_pred_init_dict}
+# Flag to indicate whether to use gradient checkpointing
+gradient_checkpointing: False
diff --git a/configs/model/task/aug_training.yaml b/configs/model/task/aug_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1dfae15dcd58a86c6cef932a31ca6d74ba760aa4
--- /dev/null
+++ b/configs/model/task/aug_training.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 0.9
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0.05
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 0.5
+# Probability of Geometric Inputs with Depths
+depth_prob: 0.5
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 0.5
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0.5
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0.9
+# Probability for skipping input of the metric scale quantities for the input metric high quality gt depth
+# If 0, the metric scale quantities will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0.05
+# Probability for skipping input of the metric scale quantities for the input metric pose
+# If 0, the metric scale quantities will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0.05
diff --git a/configs/model/task/calibrated_sfm.yaml b/configs/model/task/calibrated_sfm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8d24c0033e82e7db23c7c6c40428d5c295380c8b
--- /dev/null
+++ b/configs/model/task/calibrated_sfm.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 0
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/default.yaml b/configs/model/task/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8fdc9a98f4931a014dda25939aa44eadbdd614ed
--- /dev/null
+++ b/configs/model/task/default.yaml
@@ -0,0 +1,27 @@
+# Ray Directions Encoder Config
+ray_dirs_encoder_config:
+ name: "ray_dirs_encoder"
+ in_chans: 3
+ encoder_str: "dense_rep_encoder"
+ apply_pe: false
+# Depth Encoder Config
+depth_encoder_config:
+ name: "depth_encoder"
+ in_chans: 1
+ encoder_str: "dense_rep_encoder"
+ apply_pe: false
+# Cam Rotation (Quats) Encoder Config
+cam_rot_encoder_config:
+ name: "cam_rot_quats_encoder"
+ in_chans: 4
+ encoder_str: "global_rep_encoder"
+# Cam Translation Encoder Config
+cam_trans_encoder_config:
+ name: "cam_trans_encoder"
+ in_chans: 3
+ encoder_str: "global_rep_encoder"
+# Scale Encoder Config
+scale_encoder_config:
+ name: "scale_encoder"
+ in_chans: 1
+ encoder_str: "global_rep_encoder"
diff --git a/configs/model/task/depth_completion.yaml b/configs/model/task/depth_completion.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0461a2c7fce649a4b1fe4e55b80955375197a88e
--- /dev/null
+++ b/configs/model/task/depth_completion.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 1
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0.9
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/images_only.yaml b/configs/model/task/images_only.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..99607dbd43ffc1687a665cfc32fa67e2d5b5908d
--- /dev/null
+++ b/configs/model/task/images_only.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 0
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 1
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 0
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 0
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/mvs.yaml b/configs/model/task/mvs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8cd2850ba2cad9d925a268639d09a5b19158f323
--- /dev/null
+++ b/configs/model/task/mvs.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/mvs_non_metric.yaml b/configs/model/task/mvs_non_metric.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ba80d5b4694e1c2798191d4949dea22688187185
--- /dev/null
+++ b/configs/model/task/mvs_non_metric.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 1
diff --git a/configs/model/task/mvs_training.yaml b/configs/model/task/mvs_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..eb6b4db744cb16eb7ea351082ffde5a6edd434cb
--- /dev/null
+++ b/configs/model/task/mvs_training.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale quantities for the input metric high quality gt depth
+# If 0, the metric scale quantities will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale quantities for the input metric pose
+# If 0, the metric scale quantities will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0.05
diff --git a/configs/model/task/non_metric_poses_metric_depth.yaml b/configs/model/task/non_metric_poses_metric_depth.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1bc224e89b0df9c44de0676bb4a085c1276abc4a
--- /dev/null
+++ b/configs/model/task/non_metric_poses_metric_depth.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 1
diff --git a/configs/model/task/non_metric_poses_metric_depth_sparse.yaml b/configs/model/task/non_metric_poses_metric_depth_sparse.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9c92ed95b830a78cf0224dacbb139fcc14e6b2f8
--- /dev/null
+++ b/configs/model/task/non_metric_poses_metric_depth_sparse.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 1
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0.9
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 1
diff --git a/configs/model/task/non_metric_poses_non_metric_depth.yaml b/configs/model/task/non_metric_poses_non_metric_depth.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..9a5e9902c59379e6e9c8e9b2384a363f58e510bf
--- /dev/null
+++ b/configs/model/task/non_metric_poses_non_metric_depth.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 1
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 1
diff --git a/configs/model/task/non_metric_poses_non_metric_depth_sparse.yaml b/configs/model/task/non_metric_poses_non_metric_depth_sparse.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..5ee97881825a37347ad476f6c35b46639de7e296
--- /dev/null
+++ b/configs/model/task/non_metric_poses_non_metric_depth_sparse.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 1
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0.9
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 1
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 1
diff --git a/configs/model/task/pass_through.yaml b/configs/model/task/pass_through.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..97aa6245875bbd587019d73c16edf169e489d425
--- /dev/null
+++ b/configs/model/task/pass_through.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/posed_sfm.yaml b/configs/model/task/posed_sfm.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7c86724a0d08256b91c74ecc625ed46f2a700467
--- /dev/null
+++ b/configs/model/task/posed_sfm.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 0
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/posed_sfm_non_metric.yaml b/configs/model/task/posed_sfm_non_metric.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..62b540e366656503f984ed7053c453ad7bd1386c
--- /dev/null
+++ b/configs/model/task/posed_sfm_non_metric.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 0
+# Probability of Geometric Inputs with Depths
+depth_prob: 0
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 1
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 1
diff --git a/configs/model/task/registration.yaml b/configs/model/task/registration.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e1f7fc206dc81b91aa27f99d304f4386ffb705bb
--- /dev/null
+++ b/configs/model/task/registration.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 0
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/registration_sparse.yaml b/configs/model/task/registration_sparse.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..548d43b39ace8f8f70b81e98cc9fd46144a43bc6
--- /dev/null
+++ b/configs/model/task/registration_sparse.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 0
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 1
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0.9
+# Probability for skipping input of the metric scale norm factor for the input metric high quality gt depth
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0
+# Probability for skipping input of the metric scale norm factor for the input metric pose
+# If 0, the metric scale norm factor will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/task/registration_training.yaml b/configs/model/task/registration_training.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0676132925de37e46560550a917a83a777662beb
--- /dev/null
+++ b/configs/model/task/registration_training.yaml
@@ -0,0 +1,23 @@
+defaults:
+ - default
+
+# Overall Probability of Geometric Inputs
+overall_prob: 1
+# Dropout Probability of Geometric Inputs (for each sample across batch size and number of views)
+dropout_prob: 0
+# Probability of Geometric Inputs with Ray Directions
+ray_dirs_prob: 1
+# Probability of Geometric Inputs with Depths
+depth_prob: 1
+# Probability of Geometric Inputs with Camera Poses
+cam_prob: 0
+# Probability of sparsely sampling the high quality gt depth
+sparse_depth_prob: 0.5
+# Percentage of the valid depth to remove if the probability of using sparse depth is greater than 0 (Range: [0, 1])
+sparsification_removal_percent: 0.9
+# Probability for skipping input of the metric scale quantities for the input metric high quality gt depth
+# If 0, the metric scale quantities will be provided as input to the model for all the metric scale conditionings
+depth_scale_norm_all_prob: 0.05
+# Probability for skipping input of the metric scale quantities for the input metric pose
+# If 0, the metric scale quantities will be provided as input to the model for all the metric scale conditionings
+pose_scale_norm_all_prob: 0
diff --git a/configs/model/vggt.yaml b/configs/model/vggt.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1911b3979b708da810046b5bdd59fa6addb25368
--- /dev/null
+++ b/configs/model/vggt.yaml
@@ -0,0 +1,17 @@
+# String for model factory
+model_str: "vggt"
+# Model config
+model_config:
+ name: "vggt"
+ # Load pre-trained weights
+ load_pretrained_weights: true
+ # Load custom checkpoint
+ load_custom_ckpt: false
+ # Custom checkpoint path
+ custom_ckpt_path: null
+# Image Normalization Type
+data_norm_type: "identity"
+# VGGT checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/vggt_commercial.yaml b/configs/model/vggt_commercial.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2a5dea00b52828f553daa3502f9474ed378d1192
--- /dev/null
+++ b/configs/model/vggt_commercial.yaml
@@ -0,0 +1,17 @@
+# String for model factory
+model_str: "vggt"
+# Model config
+model_config:
+ name: "vggt"
+ # Load pre-trained weights
+ load_pretrained_weights: true
+ # Load custom checkpoint
+ load_custom_ckpt: true
+ # Custom checkpoint path
+ custom_ckpt_path: "${root_pretrained_checkpoints_dir}/vggt_1B_commercial.pt"
+# Image Normalization Type
+data_norm_type: "identity"
+# VGGT checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/model/vggt_non_pretrained.yaml b/configs/model/vggt_non_pretrained.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a67a996aa754c7fe3ec53fa66bdad45fdc3db736
--- /dev/null
+++ b/configs/model/vggt_non_pretrained.yaml
@@ -0,0 +1,17 @@
+# String for model factory
+model_str: "vggt"
+# Model config
+model_config:
+ name: "vggt"
+ # Load pre-trained weights
+ load_pretrained_weights: false
+ # Load custom checkpoint
+ load_custom_ckpt: false
+ # Custom checkpoint path
+ custom_ckpt_path: null
+# Image Normalization Type
+data_norm_type: "identity"
+# VGGT checkpoint is already loaded in the inference wrapper
+pretrained: null
+# Torch hub force reload
+torch_hub_force_reload: false
diff --git a/configs/rmvd_benchmark.yaml b/configs/rmvd_benchmark.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b4042f433e9eabcf9fde4221d81c443495af8a64
--- /dev/null
+++ b/configs/rmvd_benchmark.yaml
@@ -0,0 +1,34 @@
+defaults:
+ - machine: aws
+ - model: default
+ - dataset: default
+ - _self_
+
+# Path Settings
+output_dir: ${hydra:run.dir}
+root_data_dir: ${machine.root_data_dir}
+mapanything_dataset_metadata_dir: ${machine.mapanything_dataset_metadata_dir}
+root_pretrained_checkpoints_dir: ${machine.root_pretrained_checkpoints_dir}
+root_experiments_dir: ${machine.root_experiments_dir}
+root_uniception_pretrained_checkpoints_dir: ${machine.root_uniception_pretrained_checkpoints_dir}
+
+### Benchmarking args
+seed: 0
+# Disable CUDNN Benchmark (Disable for variable resolution & number of view training)
+disable_cudnn_benchmark: true
+# Batch size for inference (Metrics are computed per multi-view set and averaged, not per batch of multi-view sets)
+batch_size: 10
+# Use mixed precision for inference
+amp: 1
+# Floating point type to use for mixed precision
+amp_dtype: "bf16"
+# Choose from eth3d, kitti, scannet
+eval_dataset: eth3d
+# Choose from img, img+intrinsics, img+intrinsics+pose
+evaluation_conditioning: img
+# Choose from "median", "none"
+evaluation_alignment: median
+# Choose from "multi_view", "single_view"
+evaluation_views: multi_view
+# Resolution to inference the selected model.
+evaluation_resolution: ${dataset.resolution_options.518_1_33_ar}
diff --git a/configs/train.yaml b/configs/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a8a4ad870c3a0fb88eb8671b95fd77be254c0931
--- /dev/null
+++ b/configs/train.yaml
@@ -0,0 +1,15 @@
+defaults:
+ - machine: aws
+ - model: default
+ - dataset: default
+ - loss: default
+ - train_params: default
+ - distributed: default
+ - _self_
+
+output_dir: ${hydra:run.dir}
+root_data_dir: ${machine.root_data_dir}
+mapanything_dataset_metadata_dir: ${machine.mapanything_dataset_metadata_dir}
+root_pretrained_checkpoints_dir: ${machine.root_pretrained_checkpoints_dir}
+root_experiments_dir: ${machine.root_experiments_dir}
+root_uniception_pretrained_checkpoints_dir: ${machine.root_uniception_pretrained_checkpoints_dir}
diff --git a/configs/train_params/default.yaml b/configs/train_params/default.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d059c9d5ee6d4d0feb69ef6a5d3cef2014a037e1
--- /dev/null
+++ b/configs/train_params/default.yaml
@@ -0,0 +1,41 @@
+# Random Seed
+seed: 0
+# Maximum number of images per GPU (changes based on available GPU memory)
+max_num_of_imgs_per_gpu: 48
+# Accumulate gradient iterations (for increasing the effective batch size under memory constraints)
+accum_iter: 1
+# Maximum number of epochs for the scheduler
+epochs: 100
+## Default Optimizer parameters
+# Learning rate (absolute lr)
+lr: 0.0001
+# Lower lr bound for cyclic schedulers that hit 0
+min_lr: 1e-06
+# Epochs to warmup LR
+warmup_epochs: 10
+# Weight decay
+weight_decay: 0.05
+# LR schedule type
+schedule_type: "linear_warmup_half_cycle_cosine_decay"
+# Warn if model params are not in the below submodule_configs
+warn_not_in_submodule: False
+# Optimizer parameters specific to submodules
+submodule_configs: {}
+# Use Automatic Mixed Precision for pretraining
+amp: 1
+# Floating point type to use for mixed precision training
+amp_dtype: "bf16"
+# Disable CUDNN Benchmark (Disable for variable resolution & number of view training)
+disable_cudnn_benchmark: true
+# Freeze the validation samples across all epochs
+freeze_val_samples_across_all_epochs: true
+# Test loss evaluation frequency
+eval_freq: 1
+# Frequency (number of epochs) to save checkpoint in checkpoint-last.pth
+save_freq: 1
+# Frequency (number of epochs) to save checkpoint in checkpoint-%d.pth
+keep_freq: 10
+# Frequence (number of iterations) to print infos while training (includes tensorboard logging)
+print_freq: 20
+# Resume Training from last checkpoint
+resume: True
diff --git a/configs/train_params/finetune_with_lower_encoder_lr.yaml b/configs/train_params/finetune_with_lower_encoder_lr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1dda66b8c03670e8cdbee34aa0f1f0f0fb44d29a
--- /dev/null
+++ b/configs/train_params/finetune_with_lower_encoder_lr.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - default
+
+# Use 20x lower lr for finetuning
+lr: 5e-06
+min_lr: 5e-08
+
+# Optimizer parameters specific to submodules
+submodule_configs:
+ # Encoder
+ encoder:
+ lr: 2.5e-07
+ min_lr: 2.5e-09
+ warmup_epochs: ${train_params.warmup_epochs}
+ weight_decay: ${train_params.weight_decay}
+ schedule_type: ${train_params.schedule_type}
diff --git a/configs/train_params/finetune_with_lower_encoder_lr_64g.yaml b/configs/train_params/finetune_with_lower_encoder_lr_64g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f593dd3a43c6808bd65d5d34b8c0c5ad3253d28e
--- /dev/null
+++ b/configs/train_params/finetune_with_lower_encoder_lr_64g.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - default
+
+# Use 20x lower lr for finetuning
+lr: 1e-05
+min_lr: 1e-07
+
+# Optimizer parameters specific to submodules
+submodule_configs:
+ # Encoder
+ encoder:
+ lr: 5e-07
+ min_lr: 5e-09
+ warmup_epochs: ${train_params.warmup_epochs}
+ weight_decay: ${train_params.weight_decay}
+ schedule_type: ${train_params.schedule_type}
diff --git a/configs/train_params/freeze_encoder.yaml b/configs/train_params/freeze_encoder.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..25e7508544f98564678bd863db491de856e010e4
--- /dev/null
+++ b/configs/train_params/freeze_encoder.yaml
@@ -0,0 +1,8 @@
+defaults:
+ - default
+
+# Optimizer parameters specific to submodules
+submodule_configs:
+ # Encoder
+ encoder:
+ lr: 0
diff --git a/configs/train_params/lower_encoder_lr.yaml b/configs/train_params/lower_encoder_lr.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4a1ae06dae85f31174b5103c056eb4d3a163e2a8
--- /dev/null
+++ b/configs/train_params/lower_encoder_lr.yaml
@@ -0,0 +1,12 @@
+defaults:
+ - default
+
+# Optimizer parameters specific to submodules
+submodule_configs:
+ # Encoder
+ encoder:
+ lr: 5e-06
+ min_lr: 5e-08
+ warmup_epochs: ${train_params.warmup_epochs}
+ weight_decay: ${train_params.weight_decay}
+ schedule_type: ${train_params.schedule_type}
diff --git a/configs/train_params/lower_encoder_lr_64g.yaml b/configs/train_params/lower_encoder_lr_64g.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4364232c4d65b6117a0ad25adace93f2444af002
--- /dev/null
+++ b/configs/train_params/lower_encoder_lr_64g.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - default
+
+# Use 2x higher lr for 8x higher effective batch size
+lr: 2e-04
+min_lr: 2e-07
+
+# Optimizer parameters specific to submodules
+submodule_configs:
+ # Encoder
+ encoder:
+ lr: 1e-05
+ min_lr: 1e-08
+ warmup_epochs: ${train_params.warmup_epochs}
+ weight_decay: ${train_params.weight_decay}
+ schedule_type: ${train_params.schedule_type}
diff --git a/configs/train_params/vggt_finetune.yaml b/configs/train_params/vggt_finetune.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d130d4f12ad4d8e2de2165bfeb058addd0b668a8
--- /dev/null
+++ b/configs/train_params/vggt_finetune.yaml
@@ -0,0 +1,16 @@
+defaults:
+ - default
+
+# Use 10x lower lr for finetuning
+lr: 1e-05
+min_lr: 1e-07
+
+# Optimizer parameters specific to submodules
+submodule_configs:
+ # DINOv2
+ model.aggregator.patch_embed:
+ lr: 5e-07
+ min_lr: 5e-09
+ warmup_epochs: ${train_params.warmup_epochs}
+ weight_decay: ${train_params.weight_decay}
+ schedule_type: ${train_params.schedule_type}
diff --git a/mapanything/__init__.py b/mapanything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/datasets/__init__.py b/mapanything/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca5b413385e67e07ae394402431ce78a3b480300
--- /dev/null
+++ b/mapanything/datasets/__init__.py
@@ -0,0 +1,178 @@
+"""
+MapAnything Datasets
+"""
+
+import torch
+
+from mapanything.datasets.wai.ase import ASEWAI # noqa
+from mapanything.datasets.wai.bedlam import BedlamWAI # noqa
+from mapanything.datasets.wai.blendedmvs import BlendedMVSWAI # noqa
+from mapanything.datasets.wai.dl3dv import DL3DVWAI # noqa
+from mapanything.datasets.wai.dtu import DTUWAI # noqa
+from mapanything.datasets.wai.dynamicreplica import DynamicReplicaWAI # noqa
+from mapanything.datasets.wai.eth3d import ETH3DWAI # noqa
+from mapanything.datasets.wai.gta_sfm import GTASfMWAI # noqa
+from mapanything.datasets.wai.matrixcity import MatrixCityWAI # noqa
+from mapanything.datasets.wai.megadepth import MegaDepthWAI # noqa
+from mapanything.datasets.wai.mpsd import MPSDWAI # noqa
+from mapanything.datasets.wai.mvs_synth import MVSSynthWAI # noqa
+from mapanything.datasets.wai.paralleldomain4d import ParallelDomain4DWAI # noqa
+from mapanything.datasets.wai.sailvos3d import SAILVOS3DWAI # noqa
+from mapanything.datasets.wai.scannetpp import ScanNetPPWAI # noqa
+from mapanything.datasets.wai.spring import SpringWAI # noqa
+from mapanything.datasets.wai.structured3d import Structured3DWAI # noqa
+from mapanything.datasets.wai.tav2_wb import TartanAirV2WBWAI # noqa
+from mapanything.datasets.wai.unrealstereo4k import UnrealStereo4KWAI # noqa
+from mapanything.datasets.wai.xrooms import XRoomsWAI # noqa
+from mapanything.utils.train_tools import get_rank, get_world_size
+
+
+def get_test_data_loader(
+ dataset, batch_size, num_workers=8, shuffle=False, drop_last=False, pin_mem=True
+):
+ "Get simple PyTorch dataloader corresponding to the testing dataset"
+ # PyTorch dataset
+ if isinstance(dataset, str):
+ dataset = eval(dataset)
+
+ world_size = get_world_size()
+ rank = get_rank()
+
+ if torch.distributed.is_initialized():
+ sampler = torch.utils.data.DistributedSampler(
+ dataset,
+ num_replicas=world_size,
+ rank=rank,
+ shuffle=shuffle,
+ drop_last=drop_last,
+ )
+ elif shuffle:
+ sampler = torch.utils.data.RandomSampler(dataset)
+ else:
+ sampler = torch.utils.data.SequentialSampler(dataset)
+
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ sampler=sampler,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_memory=pin_mem,
+ drop_last=drop_last,
+ )
+
+ return data_loader
+
+
+def get_test_many_ar_data_loader(
+ dataset, batch_size, num_workers=8, drop_last=False, pin_mem=True
+):
+ "Get PyTorch dataloader corresponding to the testing dataset that supports many aspect ratios"
+ # PyTorch dataset
+ if isinstance(dataset, str):
+ dataset = eval(dataset)
+
+ world_size = get_world_size()
+ rank = get_rank()
+
+ # Get BatchedMultiFeatureRandomSampler
+ sampler = dataset.make_sampler(
+ batch_size,
+ shuffle=True,
+ world_size=world_size,
+ rank=rank,
+ drop_last=drop_last,
+ use_dynamic_sampler=False,
+ )
+
+ # Init the data laoder
+ data_loader = torch.utils.data.DataLoader(
+ dataset,
+ sampler=sampler,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_memory=pin_mem,
+ drop_last=drop_last,
+ )
+
+ return data_loader
+
+
+class DynamicBatchDatasetWrapper:
+ """
+ Wrapper dataset that handles DynamicBatchedMultiFeatureRandomSampler output.
+
+ The dynamic sampler returns batches (lists of tuples) instead of individual samples.
+ This wrapper ensures that the underlying dataset's __getitem__ method gets called
+ with individual tuples as expected.
+ """
+
+ def __init__(self, dataset):
+ self.dataset = dataset
+
+ def __getitem__(self, batch_indices):
+ """
+ Handle batch of indices from DynamicBatchedMultiFeatureRandomSampler.
+
+ Args:
+ batch_indices: List of tuples like [(sample_idx, feat_idx_1, feat_idx_2, ...), ...]
+
+ Returns:
+ List of samples from the underlying dataset
+ """
+ if isinstance(batch_indices, (list, tuple)) and len(batch_indices) > 0:
+ # If it's a batch (list of tuples), process each item
+ if isinstance(batch_indices[0], (list, tuple)):
+ return [self.dataset[idx] for idx in batch_indices]
+ else:
+ # Single tuple, call dataset directly
+ return self.dataset[batch_indices]
+ else:
+ # Fallback for single index
+ return self.dataset[batch_indices]
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def __getattr__(self, name):
+ # Delegate all other attributes to the wrapped dataset
+ return getattr(self.dataset, name)
+
+
+def get_train_data_loader(
+ dataset,
+ max_num_of_imgs_per_gpu,
+ num_workers=8,
+ shuffle=True,
+ drop_last=True,
+ pin_mem=True,
+):
+ "Dynamic PyTorch dataloader corresponding to the training dataset"
+ # PyTorch dataset
+ if isinstance(dataset, str):
+ dataset = eval(dataset)
+
+ world_size = get_world_size()
+ rank = get_rank()
+
+ # Get DynamicBatchedMultiFeatureRandomSampler
+ batch_sampler = dataset.make_sampler(
+ shuffle=shuffle,
+ world_size=world_size,
+ rank=rank,
+ drop_last=drop_last,
+ max_num_of_images_per_gpu=max_num_of_imgs_per_gpu,
+ use_dynamic_sampler=True,
+ )
+
+ # Wrap the dataset to handle batch format from dynamic sampler
+ wrapped_dataset = DynamicBatchDatasetWrapper(dataset)
+
+ # Init the dynamic data loader
+ data_loader = torch.utils.data.DataLoader(
+ wrapped_dataset,
+ batch_sampler=batch_sampler,
+ num_workers=num_workers,
+ pin_memory=pin_mem,
+ )
+
+ return data_loader
diff --git a/mapanything/datasets/base/__init__.py b/mapanything/datasets/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/datasets/base/base_dataset.py b/mapanything/datasets/base/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a184d200d98b181f68a358ae8dea73c3e3b072f
--- /dev/null
+++ b/mapanything/datasets/base/base_dataset.py
@@ -0,0 +1,692 @@
+"""
+Base class for MapAnything datasets.
+"""
+
+from typing import List, Tuple, Union
+
+import numpy as np
+import PIL
+import torch
+import torchvision.transforms as tvf
+from scipy.spatial.transform import Rotation
+
+from mapanything.datasets.base.easy_dataset import EasyDataset
+from mapanything.utils.cropping import (
+ bbox_from_intrinsics_in_out,
+ camera_matrix_of_crop,
+ crop_image_and_other_optional_info,
+ rescale_image_and_other_optional_info,
+)
+from mapanything.utils.geometry import (
+ depthmap_to_camera_coordinates,
+ get_absolute_pointmaps_and_rays_info,
+)
+from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT
+
+
+class BaseDataset(EasyDataset):
+ """
+ Define all basic options.
+
+ Usage:
+ class MyDataset(BaseDataset):
+ def _get_views(self, idx):
+ views = []
+ views.append(dict(img=, ...))
+ return views
+ """
+
+ def __init__(
+ self,
+ num_views: int,
+ variable_num_views: bool = False,
+ split: str = None,
+ covisibility_thres: float = None,
+ resolution: Union[int, Tuple[int, int], List[Tuple[int, int]]] = None,
+ principal_point_centered: bool = False,
+ transform: str = None,
+ data_norm_type: str = None,
+ aug_crop: int = 0,
+ seed: int = None,
+ max_num_retries: int = 5,
+ ):
+ """
+ PyTorch dataset for multi-view images sampled from scenes, where the images form a single connected component.
+
+ Args:
+ num_views (int): Number of views.
+ variable_num_views (bool): If True, the number of views can vary from batch to batch. The maximum number of views is num_views and minimum is 2.
+ On by default for N-view train dataloader (hydra config).
+ split (str): 'train', 'val', 'test', etc.
+ covisibility_thres (float): Covisibility (%) threshold to determine if another image is a neighbor or not
+ resolution (int or tuple or list of tuples): Resolution of the images
+ principal_point_centered (bool): If True, the principal point is centered in the image.
+ transform (str): Transform to apply to the images. Options:
+ - 'colorjitter+grayscale+gaublur':
+ tvf.Compose([
+ tvf.RandomApply([tvf.ColorJittter(0.3, 0.4, 0.2, 0.1)], p=0.75),
+ tvf.RandomGrayscale(p=0.05),
+ tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05),
+ ]) after ImgNorm
+ - 'colorjitter': tvf.ColorJittter(0.5, 0.5, 0.5, 0.1) after ImgNorm
+ - 'imgnorm': ImgNorm only
+ data_norm_type (str): Image normalization type.
+ For options, see UniCeption image normalization dict.
+ aug_crop (int): Augment crop. If int greater than 0, indicates the number of pixels to increase in target resolution.
+ seed (int): Seed for the random number generator.
+ max_num_retries (int): Maximum number of retries for loading a different sample from the dataset, if provided idx fails.
+ """
+ self.num_views = num_views
+ self.variable_num_views = variable_num_views
+ self.num_views_min = 2
+ self.split = split
+ self.covisibility_thres = covisibility_thres
+ self._set_resolutions(resolution)
+ self.principal_point_centered = principal_point_centered
+
+ # Update the number of views if necessary and make it a list if variable_num_views is True
+ if self.variable_num_views and self.num_views > self.num_views_min:
+ self.num_views = list(range(self.num_views_min, self.num_views + 1))
+
+ # Initialize the image normalization type
+ if data_norm_type in IMAGE_NORMALIZATION_DICT.keys():
+ self.data_norm_type = data_norm_type
+ image_norm = IMAGE_NORMALIZATION_DICT[data_norm_type]
+ ImgNorm = tvf.Compose(
+ [
+ tvf.ToTensor(),
+ tvf.Normalize(mean=image_norm.mean, std=image_norm.std),
+ ]
+ )
+ elif data_norm_type == "identity":
+ self.data_norm_type = data_norm_type
+ ImgNorm = tvf.Compose([tvf.ToTensor()])
+ else:
+ raise ValueError(
+ f"Unknown data_norm_type: {data_norm_type}. Available options: identity or {list(IMAGE_NORMALIZATION_DICT.keys())}"
+ )
+
+ # Initialize torchvision transforms
+ if transform == "imgnorm":
+ self.transform = ImgNorm
+ elif transform == "colorjitter":
+ self.transform = tvf.Compose([tvf.ColorJitter(0.5, 0.5, 0.5, 0.1), ImgNorm])
+ elif transform == "colorjitter+grayscale+gaublur":
+ self.transform = tvf.Compose(
+ [
+ tvf.RandomApply([tvf.ColorJitter(0.3, 0.4, 0.2, 0.1)], p=0.75),
+ tvf.RandomGrayscale(p=0.05),
+ tvf.RandomApply([tvf.GaussianBlur(5, sigma=(0.1, 1.0))], p=0.05),
+ ImgNorm,
+ ]
+ )
+ else:
+ raise ValueError(
+ 'Unknown transform. Available options: "imgnorm", "colorjitter", "colorjitter+grayscale+gaublur"'
+ )
+
+ # Initialize the augmentation parameters
+ self.aug_crop = aug_crop
+
+ # Initialize the seed for the random number generator
+ self.seed = seed
+ self._seed_offset = 0
+
+ # Initialize the maximum number of retries for loading a different sample from the dataset, if the first idx fails
+ self.max_num_retries = max_num_retries
+
+ # Initialize the dataset type flags
+ self.is_metric_scale = False # by default a dataset is not metric scale, subclasses can overwrite this
+ self.is_synthetic = False # by default a dataset is not synthetic, subclasses can overwrite this
+
+ def _load_data(self):
+ self.scenes = []
+ self.num_of_scenes = len(self.scenes)
+
+ def __len__(self):
+ "Length of the dataset is determined by the number of scenes in the dataset split"
+ return self.num_of_scenes
+
+ def get_stats(self):
+ "Get the number of scenes in the dataset split"
+ return f"{self.num_of_scenes} scenes"
+
+ def __repr__(self):
+ resolutions_str = "[" + ";".join(f"{w}x{h}" for w, h in self._resolutions) + "]"
+ return (
+ f"""{type(self).__name__}({self.get_stats()},
+ {self.num_views=}
+ {self.split=},
+ {self.seed=},
+ resolutions={resolutions_str},
+ {self.transform=})""".replace("self.", "")
+ .replace("\n", "")
+ .replace(" ", "")
+ )
+
+ def _get_views(self, idx, num_views_to_sample, resolution):
+ raise NotImplementedError()
+
+ def _set_seed_offset(self, idx):
+ """
+ Set the seed offset. This is directly added to self.seed when setting the random seed.
+ """
+ self._seed_offset = idx
+
+ def _set_resolutions(self, resolutions):
+ assert resolutions is not None, "undefined resolution"
+
+ if isinstance(resolutions, int):
+ resolutions = [resolutions]
+ elif isinstance(resolutions, tuple):
+ resolutions = [resolutions]
+ elif isinstance(resolutions, list):
+ assert all(isinstance(res, tuple) for res in resolutions), (
+ f"Bad type for {resolutions=}, should be int or tuple of ints or list of tuples of ints"
+ )
+ else:
+ raise ValueError(
+ f"Bad type for {resolutions=}, should be int or tuple of ints or list of tuples of ints"
+ )
+
+ self._resolutions = []
+ for resolution in resolutions:
+ if isinstance(resolution, int):
+ width = height = resolution
+ else:
+ width, height = resolution
+ assert isinstance(width, int), (
+ f"Bad type for {width=} {type(width)=}, should be int"
+ )
+ assert isinstance(height, int), (
+ f"Bad type for {height=} {type(height)=}, should be int"
+ )
+ self._resolutions.append((width, height))
+
+ def _crop_resize_if_necessary(
+ self,
+ image,
+ resolution,
+ depthmap,
+ intrinsics,
+ additional_quantities=None,
+ ):
+ """
+ Process an image by downsampling and cropping as needed to match the target resolution.
+
+ This method performs the following operations:
+ 1. Converts the image to PIL.Image if necessary
+ 2. Crops the image centered on the principal point if requested
+ 3. Downsamples the image using high-quality Lanczos filtering
+ 4. Performs final cropping to match the target resolution
+
+ Args:
+ image (numpy.ndarray or PIL.Image.Image): Input image to be processed
+ resolution (tuple): Target resolution as (width, height)
+ depthmap (numpy.ndarray): Depth map corresponding to the image
+ intrinsics (numpy.ndarray): Camera intrinsics matrix (3x3)
+ additional_quantities (dict, optional): Additional image-related data to be processed
+ alongside the main image with nearest interpolation. Defaults to None.
+
+ Returns:
+ tuple: Processed image, depthmap, and updated intrinsics matrix.
+ If additional_quantities is provided, it returns those as well.
+ """
+ if not isinstance(image, PIL.Image.Image):
+ image = PIL.Image.fromarray(image)
+
+ # Cropping centered on the principal point if necessary
+ if self.principal_point_centered:
+ W, H = image.size
+ cx, cy = intrinsics[:2, 2].round().astype(int)
+ if cx < 0 or cx >= W or cy < 0 or cy >= H:
+ # Skip centered cropping if principal point is outside image bounds
+ pass
+ else:
+ min_margin_x = min(cx, W - cx)
+ min_margin_y = min(cy, H - cy)
+ left, top = cx - min_margin_x, cy - min_margin_y
+ right, bottom = cx + min_margin_x, cy + min_margin_y
+ crop_bbox = (left, top, right, bottom)
+ # Only perform the centered crop if the crop_bbox is larger than the target resolution
+ crop_width = right - left
+ crop_height = bottom - top
+ if crop_width > resolution[0] and crop_height > resolution[1]:
+ image, depthmap, intrinsics, additional_quantities = (
+ crop_image_and_other_optional_info(
+ image=image,
+ crop_bbox=crop_bbox,
+ depthmap=depthmap,
+ camera_intrinsics=intrinsics,
+ additional_quantities=additional_quantities,
+ )
+ )
+
+ # Get the target resolution for re-scaling
+ target_rescale_resolution = np.array(resolution)
+ if self.aug_crop > 1:
+ target_rescale_resolution += self._rng.integers(0, self.aug_crop)
+
+ # High-quality Lanczos down-scaling if necessary
+ image, depthmap, intrinsics, additional_quantities = (
+ rescale_image_and_other_optional_info(
+ image=image,
+ output_resolution=target_rescale_resolution,
+ depthmap=depthmap,
+ camera_intrinsics=intrinsics,
+ additional_quantities_to_be_resized_with_nearest=additional_quantities,
+ )
+ )
+
+ # Actual cropping (if necessary)
+ new_intrinsics = camera_matrix_of_crop(
+ input_camera_matrix=intrinsics,
+ input_resolution=image.size,
+ output_resolution=resolution,
+ offset_factor=0.5,
+ )
+ crop_bbox = bbox_from_intrinsics_in_out(
+ input_camera_matrix=intrinsics,
+ output_camera_matrix=new_intrinsics,
+ output_resolution=resolution,
+ )
+ image, depthmap, new_intrinsics, additional_quantities = (
+ crop_image_and_other_optional_info(
+ image=image,
+ crop_bbox=crop_bbox,
+ depthmap=depthmap,
+ camera_intrinsics=intrinsics,
+ additional_quantities=additional_quantities,
+ )
+ )
+
+ # Return the output
+ if additional_quantities is not None:
+ return image, depthmap, new_intrinsics, additional_quantities
+ else:
+ return image, depthmap, new_intrinsics
+
+ def _random_walk_sampling(
+ self,
+ scene_pairwise_covisibility,
+ num_of_samples,
+ max_retries=4,
+ use_bidirectional_covis=True,
+ ):
+ """
+ Randomly samples S indices from an N x N covisbility matrix by forming adjacency edges such that the resulting subgraph (given by the indices) is connected.
+ If the current node has no new unvisited neighbors, backtracking occurs.
+ Retries with different starting indices if the desired number of samples is not reached, excluding previously visited components.
+
+ Args:
+ scene_pairwise_covisibility : np.ndarray (mmap)
+ N x N covisibility matrix for the scene, where N is the number of views in the scene.
+ num_of_samples : int
+ The desired number of nodes to sample (num_of_samples < N).
+ max_retries : int
+ The maximum number of retries with different starting indices.
+ use_bidirectional_covis : bool
+ Whether to compute bidirectional covisibility by averaging row and column values.
+ If False, uses only row access (faster for large memory-mapped arrays).
+ Defaults to True.
+
+ Returns:
+ np.ndarray
+ An array of sampled indices forming a connected subgraph.
+ """
+ excluded_nodes = set()
+ best_walk = [] # To keep track of the best walk found
+ for _ in range(max_retries):
+ visited = set()
+ walk = [] # List to store the random walk sampling order
+ stack = [] # Stack for backtracking
+
+ # Choose a random starting index that is not in the excluded set
+ all_nodes = set(range(len(scene_pairwise_covisibility)))
+ available_nodes = list(all_nodes - excluded_nodes)
+ if not available_nodes:
+ break # No more nodes to try
+ start = self._rng.choice(available_nodes)
+ walk.append(start)
+ visited.add(start)
+ stack.append(start)
+
+ # Continue until we have sampled S indices or all expandable nodes are exhausted
+ while len(walk) < num_of_samples and stack:
+ current = stack[-1]
+ # Get the pairwise covisibility for the current node
+ if use_bidirectional_covis:
+ # Use bidirectional covisibility (slower for large memory-mapped arrays)
+ pairwise_covisibility = (
+ scene_pairwise_covisibility[current, :]
+ + scene_pairwise_covisibility[:, current].T
+ ) / 2
+ else:
+ # Use only row access (faster for large memory-mapped arrays)
+ pairwise_covisibility = scene_pairwise_covisibility[current, :]
+ # Normalize the covisibility using self covisibility
+ pairwise_covisibility = pairwise_covisibility / (
+ pairwise_covisibility[current] + 1e-8
+ )
+ # Assign overlap score of zero to self-pairs
+ pairwise_covisibility[current] = 0
+ # Threshold the covisibility to get adjacency list for the current node
+ adjacency_list_for_current = (
+ pairwise_covisibility > self.covisibility_thres
+ ).astype(int)
+ adjacency_list_for_current = np.flatnonzero(adjacency_list_for_current)
+ # Get all unvisited neighbors
+ candidates = [
+ idx for idx in adjacency_list_for_current if idx not in visited
+ ] # Remove visited nodes
+ if candidates:
+ # Randomly select one of the unvisited overlapping neighbors
+ next_node = self._rng.choice(candidates)
+ walk.append(next_node)
+ visited.add(next_node)
+ stack.append(next_node)
+ else:
+ # If no unvisited neighbor is available, backtrack
+ stack.pop()
+
+ # Update the best walk if the current walk is larger
+ if len(walk) > len(best_walk):
+ best_walk = walk
+
+ # If we have enough samples, return the result
+ if len(walk) >= num_of_samples:
+ return np.array(walk)
+
+ # Add all visited nodes to the excluded set
+ excluded_nodes.update(visited)
+
+ # If all retries are exhausted and we still don't have enough samples, return the best walk found
+ return np.array(best_walk)
+
+ def _sample_view_indices(
+ self,
+ num_views_to_sample,
+ num_views_in_scene,
+ scene_pairwise_covisibility,
+ use_bidirectional_covis=True,
+ ):
+ """
+ Sample view indices from a scene based on the adjacency list and the number of views to sample.
+
+ Args:
+ num_views_to_sample (int): Number of views to sample.
+ num_views_in_scene (int): Total number of views available in the scene.
+ scene_pairwise_covisibility (np.ndarray): N x N covisibility matrix for the scene, where N is the number of views in the scene.
+ use_bidirectional_covis (bool): Whether to compute bidirectional covisibility by averaging row and column values.
+ If False, uses only row access (faster for large memory-mapped arrays).
+
+ Returns:
+ numpy.ndarray: Array of sampled view indices.
+ """
+ if num_views_to_sample == num_views_in_scene:
+ # Select all views in the scene
+ view_indices = self._rng.permutation(num_views_in_scene)
+ elif num_views_to_sample > num_views_in_scene:
+ # Select all views in the scene and repeat them to get the desired number of views
+ view_indices = self._rng.choice(
+ num_views_in_scene, size=num_views_to_sample, replace=True
+ )
+ else:
+ # Select a subset of single component connected views in the scene using random walk sampling
+ view_indices = self._random_walk_sampling(
+ scene_pairwise_covisibility,
+ num_views_to_sample,
+ use_bidirectional_covis=use_bidirectional_covis,
+ )
+ # If the required num of views can't be obtained even with 4 retries, repeat existing indices to get the desired number of views
+ if len(view_indices) < num_views_to_sample:
+ view_indices = self._rng.choice(
+ view_indices, size=num_views_to_sample, replace=True
+ )
+
+ return view_indices
+
+ def _getitem_fn(self, idx):
+ if isinstance(idx, tuple):
+ # The idx is a tuple if specifying the aspect-ratio or/and the number of views
+ if isinstance(self.num_views, int):
+ idx, ar_idx = idx
+ else:
+ idx, ar_idx, num_views_to_sample_idx = idx
+ else:
+ assert len(self._resolutions) == 1
+ assert isinstance(self.num_views, int)
+ ar_idx = 0
+
+ # Setup the rng
+ if self.seed: # reseed for each _getitem_fn
+ # Leads to deterministic sampling where repeating self.seed and self._seed_offset yields the same multi-view set again
+ # Scenes will be repeated if size of dataset is artificially increased using "N @" or "N *"
+ # When scenes are repeated, self._seed_offset is increased to ensure new multi-view sets
+ # This is useful for evaluation if the number of dataset scenes is < N, yet we want unique multi-view sets each iter
+ self._rng = np.random.default_rng(seed=self.seed + self._seed_offset + idx)
+ elif not hasattr(self, "_rng"):
+ seed = torch.initial_seed() # this is different for each dataloader process
+ self._rng = np.random.default_rng(seed=seed)
+
+ # Get the views for the given index and check that the number of views is correct
+ resolution = self._resolutions[ar_idx]
+ if isinstance(self.num_views, int):
+ num_views_to_sample = self.num_views
+ else:
+ num_views_to_sample = self.num_views[num_views_to_sample_idx]
+ views = self._get_views(idx, num_views_to_sample, resolution)
+ if isinstance(self.num_views, int):
+ assert len(views) == self.num_views
+ else:
+ assert len(views) in self.num_views
+
+ for v, view in enumerate(views):
+ # Store the index and other metadata
+ view["idx"] = (idx, ar_idx, v)
+ view["is_metric_scale"] = self.is_metric_scale
+ view["is_synthetic"] = self.is_synthetic
+
+ # Check the depth, intrinsics, and pose data (also other data if present)
+ assert "camera_intrinsics" in view
+ assert "camera_pose" in view
+ assert np.isfinite(view["camera_pose"]).all(), (
+ f"NaN or infinite values in camera pose for view {view_name(view)}"
+ )
+ assert np.isfinite(view["depthmap"]).all(), (
+ f"NaN or infinite values in depthmap for view {view_name(view)}"
+ )
+ assert "valid_mask" not in view
+ assert "pts3d" not in view, (
+ f"pts3d should not be there, they will be computed afterwards based on intrinsics+depthmap for view {view_name(view)}"
+ )
+ if "prior_depth_z" in view:
+ assert np.isfinite(view["prior_depth_z"]).all(), (
+ f"NaN or infinite values in prior_depth_z for view {view_name(view)}"
+ )
+ if "non_ambiguous_mask" in view:
+ assert np.isfinite(view["non_ambiguous_mask"]).all(), (
+ f"NaN or infinite values in non_ambiguous_mask for view {view_name(view)}"
+ )
+
+ # Encode the image
+ width, height = view["img"].size
+ view["true_shape"] = np.int32((height, width))
+ view["img"] = self.transform(view["img"])
+ view["data_norm_type"] = self.data_norm_type
+
+ # Compute the pointmaps, raymap and depth along ray
+ (
+ pts3d,
+ valid_mask,
+ ray_origins_world,
+ ray_directions_world,
+ depth_along_ray,
+ ray_directions_cam,
+ pts3d_cam,
+ ) = get_absolute_pointmaps_and_rays_info(**view)
+ view["pts3d"] = pts3d
+ view["valid_mask"] = valid_mask & np.isfinite(pts3d).all(axis=-1)
+ view["depth_along_ray"] = depth_along_ray
+ view["ray_directions_cam"] = ray_directions_cam
+ view["pts3d_cam"] = pts3d_cam
+
+ # Compute the prior depth along ray if present
+ if "prior_depth_z" in view:
+ prior_pts3d, _ = depthmap_to_camera_coordinates(
+ view["prior_depth_z"], view["camera_intrinsics"]
+ )
+ view["prior_depth_along_ray"] = np.linalg.norm(prior_pts3d, axis=-1)
+ view["prior_depth_along_ray"] = view["prior_depth_along_ray"][..., None]
+ del view["prior_depth_z"]
+
+ # Convert ambiguous mask dtype to match valid mask dtype
+ if "non_ambiguous_mask" in view:
+ view["non_ambiguous_mask"] = view["non_ambiguous_mask"].astype(
+ view["valid_mask"].dtype
+ )
+ else:
+ ambiguous_mask = view["depthmap"] < 0
+ view["non_ambiguous_mask"] = ~ambiguous_mask
+ view["non_ambiguous_mask"] = view["non_ambiguous_mask"].astype(
+ view["valid_mask"].dtype
+ )
+
+ # Check all datatypes
+ for key, val in view.items():
+ res, err_msg = is_good_type(val)
+ assert res, f"{err_msg} with {key}={val} for view {view_name(view)}"
+
+ # Check shapes
+ assert view["depthmap"].shape == view["img"].shape[1:]
+ assert view["depthmap"].shape == view["pts3d"].shape[:2]
+ assert view["depthmap"].shape == view["valid_mask"].shape
+ assert view["depthmap"].shape == view["depth_along_ray"].shape[:2]
+ assert view["depthmap"].shape == view["ray_directions_cam"].shape[:2]
+ assert view["depthmap"].shape == view["pts3d_cam"].shape[:2]
+ if "prior_depth_along_ray" in view:
+ assert view["depthmap"].shape == view["prior_depth_along_ray"].shape[:2]
+ if "non_ambiguous_mask" in view:
+ assert view["depthmap"].shape == view["non_ambiguous_mask"].shape
+
+ # Expand the last dimennsion of the depthmap
+ view["depthmap"] = view["depthmap"][..., None]
+
+ # Append RNG state to the views, this allows to check whether the RNG is in the same state each time
+ view["rng"] = int.from_bytes(self._rng.bytes(4), "big")
+
+ # Compute and store the quaternions and translation for the camera poses
+ # Notation is (x, y, z, w) for quaternions
+ # This also ensures that the camera poses have a positive determinant (right-handed coordinate system)
+ view["camera_pose_quats"] = (
+ Rotation.from_matrix(view["camera_pose"][:3, :3])
+ .as_quat()
+ .astype(view["camera_pose"].dtype)
+ )
+ view["camera_pose_trans"] = view["camera_pose"][:3, 3].astype(
+ view["camera_pose"].dtype
+ )
+
+ # Check the pointmaps, rays, depth along ray, and camera pose quaternions and translation to ensure they are finite
+ assert np.isfinite(view["pts3d"]).all(), (
+ f"NaN in pts3d for view {view_name(view)}"
+ )
+ assert np.isfinite(view["valid_mask"]).all(), (
+ f"NaN in valid_mask for view {view_name(view)}"
+ )
+ assert np.isfinite(view["depth_along_ray"]).all(), (
+ f"NaN in depth_along_ray for view {view_name(view)}"
+ )
+ assert np.isfinite(view["ray_directions_cam"]).all(), (
+ f"NaN in ray_directions_cam for view {view_name(view)}"
+ )
+ assert np.isfinite(view["pts3d_cam"]).all(), (
+ f"NaN in pts3d_cam for view {view_name(view)}"
+ )
+ assert np.isfinite(view["camera_pose_quats"]).all(), (
+ f"NaN in camera_pose_quats for view {view_name(view)}"
+ )
+ assert np.isfinite(view["camera_pose_trans"]).all(), (
+ f"NaN in camera_pose_trans for view {view_name(view)}"
+ )
+ if "prior_depth_along_ray" in view:
+ assert np.isfinite(view["prior_depth_along_ray"]).all(), (
+ f"NaN in prior_depth_along_ray for view {view_name(view)}"
+ )
+
+ return views
+
+ def __getitem__(self, idx):
+ if self.max_num_retries == 0:
+ return self._getitem_fn(idx)
+
+ num_retries = 0
+ while num_retries <= self.max_num_retries:
+ try:
+ return self._getitem_fn(idx)
+ except Exception as e:
+ scene_idx = idx[0] if isinstance(idx, tuple) else idx
+ print(
+ f"Error in {type(self).__name__}.__getitem__ for scene_idx={scene_idx}: {e}"
+ )
+
+ if num_retries >= self.max_num_retries:
+ print(
+ f"Max retries ({self.max_num_retries}) reached, raising the exception"
+ )
+ raise e
+
+ # Retry with a different scene index
+ num_retries += 1
+ if isinstance(idx, tuple):
+ # The scene index is the first element of the tuple
+ idx_list = list(idx)
+ idx_list[0] = np.random.randint(0, len(self))
+ idx = tuple(idx_list)
+ else:
+ # The scene index is idx
+ idx = np.random.randint(0, len(self))
+ scene_idx = idx[0] if isinstance(idx, tuple) else idx
+ print(
+ f"Retrying with scene_idx={scene_idx} ({num_retries} of {self.max_num_retries})"
+ )
+
+
+def is_good_type(v):
+ """
+ Check if a value has an acceptable data type for processing in the dataset.
+
+ Args:
+ v: The value to check.
+
+ Returns:
+ tuple: A tuple containing:
+ - bool: True if the type is acceptable, False otherwise.
+ - str or None: Error message if the type is not acceptable, None otherwise.
+ """
+ if isinstance(v, (str, int, tuple)):
+ return True, None
+ if v.dtype not in (np.float32, torch.float32, bool, np.int32, np.int64, np.uint8):
+ return False, f"bad {v.dtype=}"
+ return True, None
+
+
+def view_name(view, batch_index=None):
+ """
+ Generate a string identifier for a view based on its dataset, label, and instance.
+
+ Args:
+ view (dict): Dictionary containing view information with 'dataset', 'label', and 'instance' keys.
+ batch_index (int, optional): Index to select from batched data. Defaults to None.
+
+ Returns:
+ str: A formatted string in the form "dataset/label/instance".
+ """
+
+ def sel(x):
+ return x[batch_index] if batch_index not in (None, slice(None)) else x
+
+ db = sel(view["dataset"])
+ label = sel(view["label"])
+ instance = sel(view["instance"])
+ return f"{db}/{label}/{instance}"
diff --git a/mapanything/datasets/base/batched_sampler.py b/mapanything/datasets/base/batched_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..322a66bfadd30b88ea6d14b1f8241dc570263846
--- /dev/null
+++ b/mapanything/datasets/base/batched_sampler.py
@@ -0,0 +1,426 @@
+"""
+Utilities for random sampling under a single or multiple constraints
+
+References: DUSt3R
+"""
+
+import numpy as np
+import torch
+
+
+def round_by(total, multiple, up=False):
+ """
+ Round a number to the nearest multiple of another number.
+
+ Args:
+ total (int): The number to round
+ multiple (int): The multiple to round to
+ up (bool, optional): Whether to round up. Defaults to False.
+
+ Returns:
+ int: The rounded number
+ """
+ if up:
+ total = total + multiple - 1
+ return (total // multiple) * multiple
+
+
+class BatchedRandomSampler:
+ """
+ Random sampling under a constraint: each sample in the batch has the same feature,
+ which is chosen randomly from a known pool of 'features' for each batch.
+
+ For instance, the 'feature' could be the image aspect-ratio.
+
+ The index returned is a tuple (sample_idx, feat_idx).
+ This sampler ensures that each series of `batch_size` indices has the same `feat_idx`.
+ """
+
+ def __init__(
+ self, dataset, batch_size, pool_size, world_size=1, rank=0, drop_last=True
+ ):
+ """
+ Args:
+ dataset: Dataset to sample from
+ batch_size: Number of samples per batch
+ pool_size: Integer representing the size of feature pool
+ world_size: Number of distributed processes
+ rank: Rank of the current process
+ drop_last: Whether to drop the last incomplete batch
+ """
+ self.batch_size = batch_size
+ self.pool_size = pool_size
+
+ self.len_dataset = N = len(dataset)
+ self.total_size = round_by(N, batch_size * world_size) if drop_last else N
+ assert world_size == 1 or drop_last, (
+ "must drop the last batch in distributed mode"
+ )
+
+ # Distributed sampler
+ self.world_size = world_size
+ self.rank = rank
+ self.epoch = None
+
+ def __len__(self):
+ """
+ Get the length of the sampler.
+
+ Returns:
+ int: The number of samples in the sampler for the current process
+ """
+ return self.total_size // self.world_size
+
+ def set_epoch(self, epoch):
+ """
+ Set the epoch for this sampler.
+
+ This should be called before each epoch to ensure proper shuffling of the data.
+
+ Args:
+ epoch (int): The current epoch number
+ """
+ self.epoch = epoch
+
+ def __iter__(self):
+ """
+ Iterator over the indices.
+
+ This method generates random indices for each batch, ensuring that all samples
+ within a batch have the same feature index for the given feature pool.
+
+ Yields:
+ tuple: A tuple containing (sample_idx, feat_idx)
+ """
+ # Prepare RNG
+ if self.epoch is None:
+ assert self.world_size == 1 and self.rank == 0, (
+ "use set_epoch() if distributed mode is used"
+ )
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ else:
+ seed = self.epoch + 777
+ rng = np.random.default_rng(seed=seed)
+
+ # Random indices (will restart from 0 if not drop_last)
+ sample_idxs = np.arange(self.total_size)
+ rng.shuffle(sample_idxs)
+
+ # Random feat_idxs (same across each batch)
+ n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
+ feat_idxs = rng.integers(self.pool_size, size=n_batches)
+ feat_idxs = np.broadcast_to(feat_idxs[:, None], (n_batches, self.batch_size))
+ feat_idxs = feat_idxs.ravel()[: self.total_size]
+
+ # Put them together
+ idxs = np.c_[sample_idxs, feat_idxs] # shape = (total_size, 2)
+
+ # Distributed sampler: we select a subset of batches
+ # Make sure the slice for each node is aligned with batch_size
+ size_per_proc = self.batch_size * (
+ (self.total_size + self.world_size * self.batch_size - 1)
+ // (self.world_size * self.batch_size)
+ )
+ idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]
+
+ yield from (tuple(idx) for idx in idxs)
+
+
+class BatchedMultiFeatureRandomSampler:
+ """
+ Random sampling under multiple constraints: each sample in the batch has the same features,
+ which are chosen randomly from known pools of 'features' for each batch.
+
+ For instance, the 'features' could be the image aspect-ratio and scene type.
+
+ The index returned is a tuple (sample_idx, feat_idx_1, feat_idx_2, ...).
+ This sampler ensures that each series of `batch_size` indices has the same feature indices.
+ """
+
+ def __init__(
+ self, dataset, batch_size, pool_sizes, world_size=1, rank=0, drop_last=True
+ ):
+ """
+ Args:
+ dataset: Dataset to sample from
+ batch_size: Number of samples per batch
+ pool_sizes: List of integers representing the size of each feature pool
+ world_size: Number of distributed processes
+ rank: Rank of the current process
+ drop_last: Whether to drop the last incomplete batch
+ """
+ self.batch_size = batch_size
+ self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes]
+
+ self.len_dataset = N = len(dataset)
+ self.total_size = round_by(N, batch_size * world_size) if drop_last else N
+ assert world_size == 1 or drop_last, (
+ "must drop the last batch in distributed mode"
+ )
+
+ # Distributed sampler
+ self.world_size = world_size
+ self.rank = rank
+ self.epoch = None
+
+ def __len__(self):
+ """
+ Get the length of the sampler.
+
+ Returns:
+ int: The number of samples in the sampler for the current process
+ """
+ return self.total_size // self.world_size
+
+ def set_epoch(self, epoch):
+ """
+ Set the epoch for this sampler.
+
+ This should be called before each epoch to ensure proper shuffling of the data.
+
+ Args:
+ epoch (int): The current epoch number
+ """
+ self.epoch = epoch
+
+ def __iter__(self):
+ """
+ Iterator over the indices.
+
+ This method generates random indices for each batch, ensuring that all samples
+ within a batch have the same feature indices for multiple features.
+
+ Yields:
+ tuple: A tuple containing (sample_idx, feat_idx_1, feat_idx_2, ...)
+ """
+ # Prepare RNG
+ if self.epoch is None:
+ assert self.world_size == 1 and self.rank == 0, (
+ "use set_epoch() if distributed mode is used"
+ )
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ else:
+ seed = self.epoch + 777
+ rng = np.random.default_rng(seed=seed)
+
+ # Random indices (will restart from 0 if not drop_last)
+ sample_idxs = np.arange(self.total_size)
+ rng.shuffle(sample_idxs)
+
+ # Random feat_idxs (same across each batch)
+ n_batches = (self.total_size + self.batch_size - 1) // self.batch_size
+
+ # Generate feature indices for each feature pool
+ all_feat_idxs = []
+ for pool_size in self.pool_sizes:
+ feat_idxs = rng.integers(pool_size, size=n_batches)
+ feat_idxs = np.broadcast_to(
+ feat_idxs[:, None], (n_batches, self.batch_size)
+ )
+ feat_idxs = feat_idxs.ravel()[: self.total_size]
+ all_feat_idxs.append(feat_idxs)
+
+ # Put them together
+ idxs = np.column_stack(
+ [sample_idxs] + all_feat_idxs
+ ) # shape = (total_size, 1 + len(pool_sizes))
+
+ # Distributed sampler: we select a subset of batches
+ # Make sure the slice for each node is aligned with batch_size
+ size_per_proc = self.batch_size * (
+ (self.total_size + self.world_size * self.batch_size - 1)
+ // (self.world_size * self.batch_size)
+ )
+ idxs = idxs[self.rank * size_per_proc : (self.rank + 1) * size_per_proc]
+
+ yield from (tuple(idx) for idx in idxs)
+
+
+class DynamicBatchedMultiFeatureRandomSampler:
+ """
+ Random sampling under multiple constraints with dynamic batch size:
+ each sample in the batch has the same features, which are chosen randomly
+ from known pools of 'features' for each batch.
+
+ The batch size is dynamically determined based on a specified feature index,
+ using a direct mapping from feature values to batch sizes.
+
+ For instance, if one of the features is the number of images in a multi-view set,
+ you can specify different batch sizes for different numbers of images to optimize
+ GPU memory usage. This is achieved by using the feature_to_batch_size_map parameter
+ to directly specify what batch size to use for each feature value.
+
+ The returned index is a list of tuples [(sample_idx, feat_idx_1, feat_idx_2, ...), ...].
+ """
+
+ def __init__(
+ self,
+ dataset,
+ pool_sizes,
+ scaling_feature_idx=0,
+ feature_to_batch_size_map=None,
+ world_size=1,
+ rank=0,
+ drop_last=True,
+ ):
+ """
+ Args:
+ dataset: Dataset to sample from
+ pool_sizes: List of integers representing the size of each feature pool
+ scaling_feature_idx: Index of the feature to use for determining batch size (0-based index into pool_sizes)
+ feature_to_batch_size_map: Optional function or dict that maps feature values directly to batch sizes.
+ For example, if the feature represents number of views, this maps number of views
+ to appropriate batch size that can fit in GPU memory.
+ If None, uses a default batch size of 1 for all feature values.
+ world_size: Number of distributed processes
+ rank: Rank of the current process
+ drop_last: Whether to drop the last incomplete batch
+ """
+ self.pool_sizes = pool_sizes if isinstance(pool_sizes, list) else [pool_sizes]
+ self.scaling_feature_idx = scaling_feature_idx
+
+ # Ensure scaling_feature_idx is valid
+ if scaling_feature_idx < 0 or scaling_feature_idx >= len(self.pool_sizes):
+ raise ValueError(
+ f"scaling_feature_idx must be between 0 and {len(self.pool_sizes) - 1}"
+ )
+
+ # Set up mapping from feature values to batch sizes
+ self.feature_to_batch_size_map = feature_to_batch_size_map
+ if self.feature_to_batch_size_map is None:
+ # Default: batch size of 1 for all feature values
+ self.feature_to_batch_size_map = {
+ i: 1 for i in range(self.pool_sizes[scaling_feature_idx])
+ }
+
+ self.len_dataset = N = len(dataset)
+
+ # We don't know the exact batch size yet, so we use a large number for total_size
+ # This will be adjusted during iteration
+ self.total_size = N
+
+ # Distributed sampler
+ self.world_size = world_size
+ self.rank = rank
+ self.epoch = None
+ self.drop_last = drop_last
+
+ def __len__(self):
+ """
+ Get the approximate length of the sampler.
+
+ Since batch size varies, this is an estimate based on the largest batch size
+ in the mapping, which provides a lower bound on the number of batches.
+
+ Returns:
+ int: The estimated minimum number of samples in the sampler for the current process
+ """
+ # Find the largest batch size in the mapping
+ if callable(self.feature_to_batch_size_map):
+ # If it's a function, sample some values to find the maximum
+ batch_sizes = [
+ self.feature_to_batch_size_map(i)
+ for i in range(self.pool_sizes[self.scaling_feature_idx])
+ ]
+ max_batch_size = max(batch_sizes)
+ else:
+ # If it's a dict or similar, find the maximum directly
+ max_batch_size = max(self.feature_to_batch_size_map.values())
+
+ # Ensure minimum batch size of 1
+ max_batch_size = max(1, max_batch_size)
+
+ # Estimate total batches using the largest batch size
+ # This gives a lower bound on the number of batches
+ total_batches = self.total_size // max_batch_size
+ if not self.drop_last and self.total_size % max_batch_size > 0:
+ total_batches += 1
+
+ # Distribute among processes
+ return total_batches // self.world_size
+
+ def set_epoch(self, epoch):
+ """
+ Set the epoch for this sampler.
+
+ This should be called before each epoch to ensure proper shuffling of the data.
+
+ Args:
+ epoch (int): The current epoch number
+ """
+ self.epoch = epoch
+
+ def __iter__(self):
+ """
+ Iterator over the indices with dynamic batch sizes.
+
+ This method generates random indices for each batch, ensuring that all samples
+ within a batch have the same feature indices for multiple features.
+ The batch size is determined directly from the feature_to_batch_size_map.
+
+ The iterator enforces the length returned by __len__() by stopping after
+ exactly that many batches have been yielded for this process.
+
+ Yields:
+ list of tuples: A batch of tuples, each containing (sample_idx, feat_idx_1, feat_idx_2, ...)
+ """
+ # Prepare RNG
+ if self.epoch is None:
+ assert self.world_size == 1 and self.rank == 0, (
+ "use set_epoch() if distributed mode is used"
+ )
+ seed = int(torch.empty((), dtype=torch.int64).random_().item())
+ else:
+ seed = self.epoch + 777
+ rng = np.random.default_rng(seed=seed)
+
+ # Random indices for the entire dataset
+ sample_idxs = np.arange(self.total_size)
+ rng.shuffle(sample_idxs)
+
+ # Get the target number of batches for this process (enforce strict length)
+ target_batches_for_process = len(self)
+ batches_yielded_for_process = 0
+
+ # Process indices in batches with dynamic sizing
+ idx = 0
+ batch_idx = 0 # Track batch index for even distribution
+ while idx < len(sample_idxs) and (
+ batches_yielded_for_process < target_batches_for_process
+ ):
+ # Randomly select feature indices for this batch
+ feat_idxs = [rng.integers(pool_size) for pool_size in self.pool_sizes]
+
+ # Get the scaling feature value
+ scaling_feat = feat_idxs[self.scaling_feature_idx]
+
+ # Get the batch size directly from the mapping
+ if callable(self.feature_to_batch_size_map):
+ batch_size = self.feature_to_batch_size_map(scaling_feat)
+ else:
+ batch_size = self.feature_to_batch_size_map.get(scaling_feat, 1)
+
+ # Ensure minimum batch size of 1
+ batch_size = max(1, batch_size)
+
+ # Ensure we don't go beyond available samples
+ remaining = len(sample_idxs) - idx
+ if remaining < batch_size:
+ if self.drop_last:
+ break
+ batch_size = remaining
+
+ # Create batch with consistent feature indices
+ batch = []
+ for i in range(batch_size):
+ if idx + i < len(sample_idxs):
+ sample_idx = sample_idxs[idx + i]
+ batch.append(tuple([sample_idx] + feat_idxs))
+
+ # Distribute batches among processes in round-robin fashion
+ if len(batch) > 0 and (batch_idx % self.world_size == self.rank):
+ yield batch
+ batches_yielded_for_process += 1
+
+ batch_idx += 1 # Increment batch index
+ idx += batch_size
diff --git a/mapanything/datasets/base/easy_dataset.py b/mapanything/datasets/base/easy_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ac429bf808762e4f83f04af50f562e885a3dd61
--- /dev/null
+++ b/mapanything/datasets/base/easy_dataset.py
@@ -0,0 +1,473 @@
+"""
+Base dataset class that enables easy resizing and combining
+
+References: DUSt3R
+"""
+
+import numpy as np
+
+from mapanything.datasets.base.batched_sampler import (
+ BatchedMultiFeatureRandomSampler,
+ DynamicBatchedMultiFeatureRandomSampler,
+)
+
+
+class EasyDataset:
+ """
+ Dataset that can be easily resized and combined.
+
+ Examples:
+ ---------
+ 2 * dataset ==> Duplicate each element 2x
+
+ 10 @ dataset ==> Set the size to 10 (random sampling, duplicates if necessary)
+
+ Dataset1 + Dataset2 ==> Concatenate datasets
+ """
+
+ def __add__(self, other):
+ """
+ Concatenate this dataset with another dataset.
+
+ Args:
+ other (EasyDataset): Another dataset to concatenate with this one
+
+ Returns:
+ CatDataset: A new dataset that is the concatenation of this dataset and the other
+ """
+ return CatDataset([self, other])
+
+ def __rmul__(self, factor):
+ """
+ Multiply the dataset by a factor, duplicating each element.
+
+ Args:
+ factor (int): Number of times to duplicate each element
+
+ Returns:
+ MulDataset: A new dataset with each element duplicated 'factor' times
+ """
+ return MulDataset(factor, self)
+
+ def __rmatmul__(self, factor):
+ """
+ Resize the dataset to a specific size using random sampling.
+
+ Args:
+ factor (int): The new size of the dataset
+
+ Returns:
+ ResizedDataset: A new dataset with the specified size
+ """
+ return ResizedDataset(factor, self)
+
+ def set_epoch(self, epoch):
+ """
+ Set the current epoch for all constituent datasets.
+
+ Args:
+ epoch (int): The current epoch number
+ """
+ pass # nothing to do by default
+
+ def make_sampler(
+ self,
+ batch_size=None,
+ shuffle=True,
+ world_size=1,
+ rank=0,
+ drop_last=True,
+ max_num_of_images_per_gpu=None,
+ use_dynamic_sampler=True,
+ ):
+ """
+ Create a sampler for this dataset.
+
+ Args:
+ batch_size (int, optional): Number of samples per batch (used for non-dynamic sampler). Defaults to None.
+ shuffle (bool, optional): Whether to shuffle the dataset. Defaults to True.
+ world_size (int, optional): Number of distributed processes. Defaults to 1.
+ rank (int, optional): Rank of the current process. Defaults to 0.
+ drop_last (bool, optional): Whether to drop the last incomplete batch. Defaults to True.
+ max_num_of_images_per_gpu (int, optional): Maximum number of images per GPU for dynamic batching. Defaults to None.
+ use_dynamic_sampler (bool, optional): Whether to use the dynamic sampler. Defaults to True.
+
+ Returns:
+ DynamicBatchedMultiFeatureRandomSampler or BatchedMultiFeatureRandomSampler: A sampler for this dataset
+
+ Raises:
+ NotImplementedError: If shuffle is False
+ ValueError: If num_views has an invalid type or required parameters are missing
+ """
+ if not (shuffle):
+ raise NotImplementedError() # cannot deal yet
+
+ if isinstance(self.num_views, int):
+ num_of_aspect_ratios = len(self._resolutions)
+ feature_pool_sizes = [num_of_aspect_ratios]
+ scaling_feature_idx = 0 # Use aspect ratio as scaling feature
+ elif isinstance(self.num_views, list):
+ num_of_aspect_ratios = len(self._resolutions)
+ num_of_num_views = len(self.num_views)
+ feature_pool_sizes = [num_of_aspect_ratios, num_of_num_views]
+ scaling_feature_idx = 1 # Use num_views as scaling feature
+ else:
+ raise ValueError(
+ f"Bad type for {self.num_views=}, should be int or list of ints"
+ )
+
+ if use_dynamic_sampler:
+ if max_num_of_images_per_gpu is None:
+ raise ValueError(
+ "max_num_of_images_per_gpu must be provided when using dynamic sampler"
+ )
+
+ # Create feature-to-batch-size mapping
+ if isinstance(self.num_views, list):
+ # Map num_views_idx to batch size: max(1, max_num_of_images_per_gpu // (num_views_idx + dataset.num_views_min))
+ feature_to_batch_size_map = {}
+ for num_views_idx, num_views in enumerate(self.num_views):
+ batch_size_for_multi_view_sets = max(
+ 1, max_num_of_images_per_gpu // num_views
+ )
+ feature_to_batch_size_map[num_views_idx] = (
+ batch_size_for_multi_view_sets
+ )
+ else:
+ # For fixed num_views, use a simple mapping
+ feature_to_batch_size_map = {
+ 0: max(1, max_num_of_images_per_gpu // self.num_views)
+ }
+
+ return DynamicBatchedMultiFeatureRandomSampler(
+ self,
+ pool_sizes=feature_pool_sizes,
+ scaling_feature_idx=scaling_feature_idx,
+ feature_to_batch_size_map=feature_to_batch_size_map,
+ world_size=world_size,
+ rank=rank,
+ drop_last=drop_last,
+ )
+ else:
+ if batch_size is None:
+ raise ValueError(
+ "batch_size must be provided when not using dynamic sampler"
+ )
+
+ return BatchedMultiFeatureRandomSampler(
+ self,
+ batch_size,
+ feature_pool_sizes,
+ world_size=world_size,
+ rank=rank,
+ drop_last=drop_last,
+ )
+
+
+class MulDataset(EasyDataset):
+ """Artifically augmenting the size of a dataset."""
+
+ multiplicator: int
+
+ def __init__(self, multiplicator, dataset):
+ """
+ Initialize a dataset that artificially augments the size of another dataset.
+
+ Args:
+ multiplicator (int): Factor by which to multiply the dataset size
+ dataset (EasyDataset): The dataset to augment
+ """
+ assert isinstance(multiplicator, int) and multiplicator > 0
+ self.multiplicator = multiplicator
+ self.dataset = dataset
+
+ def __len__(self):
+ """
+ Get the length of the dataset.
+
+ Returns:
+ int: The number of samples in the dataset
+ """
+ return self.multiplicator * len(self.dataset)
+
+ def __repr__(self):
+ """
+ Get a string representation of the dataset.
+
+ Returns:
+ str: String representation showing the multiplication factor and the original dataset
+ """
+ return f"{self.multiplicator}*{repr(self.dataset)}"
+
+ def __getitem__(self, idx):
+ """
+ Get an item from the dataset.
+
+ Args:
+ idx: Index or tuple of indices to retrieve
+
+ Returns:
+ The item at the specified index from the original dataset
+ """
+ if isinstance(idx, tuple):
+ other = idx[1:]
+ idx = idx[0]
+ new_idx = (idx // self.multiplicator, *other)
+ return self.dataset[new_idx]
+ else:
+ return self.dataset[idx // self.multiplicator]
+
+ @property
+ def _resolutions(self):
+ """
+ Get the resolutions of the dataset.
+
+ Returns:
+ The resolutions from the original dataset
+ """
+ return self.dataset._resolutions
+
+ @property
+ def num_views(self):
+ """
+ Get the number of views used for the dataset.
+
+ Returns:
+ int or list: The number of views parameter from the original dataset
+ """
+ return self.dataset.num_views
+
+
+class ResizedDataset(EasyDataset):
+ """Artifically changing the size of a dataset."""
+
+ new_size: int
+
+ def __init__(self, new_size, dataset):
+ """
+ Initialize a dataset with an artificially changed size.
+
+ Args:
+ new_size (int): The new size of the dataset
+ dataset (EasyDataset): The original dataset
+ """
+ assert isinstance(new_size, int) and new_size > 0
+ self.new_size = new_size
+ self.dataset = dataset
+
+ def __len__(self):
+ """
+ Get the length of the dataset.
+
+ Returns:
+ int: The new size of the dataset
+ """
+ return self.new_size
+
+ def __repr__(self):
+ """
+ Get a string representation of the dataset.
+
+ Returns:
+ str: String representation showing the new size and the original dataset
+ """
+ size_str = str(self.new_size)
+ for i in range((len(size_str) - 1) // 3):
+ sep = -4 * i - 3
+ size_str = size_str[:sep] + "_" + size_str[sep:]
+ return f"{size_str} @ {repr(self.dataset)}"
+
+ def set_epoch(self, epoch):
+ """
+ Set the current epoch and generate a new random mapping of indices.
+
+ This method must be called before using __getitem__.
+
+ Args:
+ epoch (int): The current epoch number
+ """
+ # This random shuffle only depends on the epoch
+ rng = np.random.default_rng(seed=epoch + 777)
+
+ # Shuffle all indices
+ perm = rng.permutation(len(self.dataset))
+
+ # Calculate how many repetitions we need
+ num_repetitions = 1 + (len(self) - 1) // len(self.dataset)
+
+ # Rotary extension until target size is met
+ shuffled_idxs = np.concatenate([perm] * num_repetitions)
+ self._idxs_mapping = shuffled_idxs[: self.new_size]
+
+ # Generate the seed offset for each repetition
+ # This is needed to ensure we see unique samples when we repeat a scene
+ seed_offset_per_repetition = [
+ np.full(len(self.dataset), i) for i in range(num_repetitions)
+ ]
+ seed_offset_idxs = np.concatenate(seed_offset_per_repetition)
+ self._idxs_seed_offset = seed_offset_idxs[: self.new_size]
+
+ assert len(self._idxs_mapping) == self.new_size
+ assert len(self._idxs_seed_offset) == self.new_size
+
+ def __getitem__(self, idx):
+ """
+ Get an item from the dataset.
+
+ Args:
+ idx: Index or tuple of indices to retrieve
+
+ Returns:
+ The item at the mapped index from the original dataset
+
+ Raises:
+ AssertionError: If set_epoch has not been called
+ """
+ assert hasattr(self, "_idxs_mapping"), (
+ "You need to call dataset.set_epoch() to use ResizedDataset.__getitem__()"
+ )
+ if isinstance(idx, tuple):
+ other = idx[1:]
+ idx = idx[0]
+ self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
+ new_idx = (self._idxs_mapping[idx], *other)
+ return self.dataset[new_idx]
+ else:
+ self.dataset._set_seed_offset(self._idxs_seed_offset[idx])
+ return self.dataset[self._idxs_mapping[idx]]
+
+ @property
+ def _resolutions(self):
+ """
+ Get the resolutions of the dataset.
+
+ Returns:
+ The resolutions from the original dataset
+ """
+ return self.dataset._resolutions
+
+ @property
+ def num_views(self):
+ """
+ Get the number of views used for the dataset.
+
+ Returns:
+ int or list: The number of views parameter from the original dataset
+ """
+ return self.dataset.num_views
+
+
+class CatDataset(EasyDataset):
+ """Concatenation of several datasets"""
+
+ def __init__(self, datasets):
+ """
+ Initialize a dataset that is a concatenation of several datasets.
+
+ Args:
+ datasets (list): List of EasyDataset instances to concatenate
+ """
+ for dataset in datasets:
+ assert isinstance(dataset, EasyDataset)
+ self.datasets = datasets
+ self._cum_sizes = np.cumsum([len(dataset) for dataset in datasets])
+
+ def __len__(self):
+ """
+ Get the length of the concatenated dataset.
+
+ Returns:
+ int: Total number of samples across all datasets
+ """
+ return self._cum_sizes[-1]
+
+ def __repr__(self):
+ """
+ Get a string representation of the concatenated dataset.
+
+ Returns:
+ str: String representation showing all concatenated datasets joined by '+'
+ """
+ # Remove uselessly long transform
+ return " + ".join(
+ repr(dataset).replace(
+ ",transform=Compose( ToTensor() Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))",
+ "",
+ )
+ for dataset in self.datasets
+ )
+
+ def set_epoch(self, epoch):
+ """
+ Set the current epoch for all constituent datasets.
+
+ Args:
+ epoch (int): The current epoch number
+ """
+ for dataset in self.datasets:
+ dataset.set_epoch(epoch)
+
+ def __getitem__(self, idx):
+ """
+ Get an item from the concatenated dataset.
+
+ Args:
+ idx: Index or tuple of indices to retrieve
+
+ Returns:
+ The item at the specified index from the appropriate constituent dataset
+
+ Raises:
+ IndexError: If the index is out of range
+ """
+ other = None
+ if isinstance(idx, tuple):
+ other = idx[1:]
+ idx = idx[0]
+
+ if not (0 <= idx < len(self)):
+ raise IndexError()
+
+ db_idx = np.searchsorted(self._cum_sizes, idx, "right")
+ dataset = self.datasets[db_idx]
+ new_idx = idx - (self._cum_sizes[db_idx - 1] if db_idx > 0 else 0)
+
+ if other is not None:
+ new_idx = (new_idx, *other)
+ return dataset[new_idx]
+
+ @property
+ def _resolutions(self):
+ """
+ Get the resolutions of the dataset.
+
+ Returns:
+ The resolutions from the first dataset (all datasets must have the same resolutions)
+
+ Raises:
+ AssertionError: If datasets have different resolutions
+ """
+ resolutions = self.datasets[0]._resolutions
+ for dataset in self.datasets[1:]:
+ assert tuple(dataset._resolutions) == tuple(resolutions), (
+ "All datasets must have the same resolutions"
+ )
+ return resolutions
+
+ @property
+ def num_views(self):
+ """
+ Get the number of views used for the dataset.
+
+ Returns:
+ int or list: The number of views parameter from the first dataset
+
+ Raises:
+ AssertionError: If datasets have different num_views
+ """
+ num_views = self.datasets[0].num_views
+ for dataset in self.datasets[1:]:
+ assert dataset.num_views == num_views, (
+ "All datasets must have the same num_views and variable_num_views parameters"
+ )
+ return num_views
diff --git a/mapanything/datasets/utils/__init__.py b/mapanything/datasets/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/datasets/utils/data_splits.py b/mapanything/datasets/utils/data_splits.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a69ec6df5a98cd21e15c9de1791be49903d8383
--- /dev/null
+++ b/mapanything/datasets/utils/data_splits.py
@@ -0,0 +1,1741 @@
+"""
+Modules containing dataset split information
+"""
+
+
+class BlendedMVSSplits:
+ """
+ This class contains the information about the BlendedMVS dataset splits.
+ """
+
+ def __init__(self):
+ """
+ The splits are generated using the following logic:
+ # Get all seqls and seqhs using self.blendedmvs_info.all_sequences
+ all_sequences = self.blendedmvs_info.all_sequences
+ all_seqls = [int(seq[8:], 16) for seq in all_sequences]
+ all_seqhs = [int(seq[:8], 16) for seq in all_sequences]
+ # Split the seqls (& corresponding seqhs) using the DUSt3R train/val split logic
+ if split is None:
+ selection = slice(None)
+ elif split in ["train", "overfit"]:
+ # select 90% of all scenes
+ selection = [(seql % 10) > 0 for seql in all_seqls]
+ elif split == "val":
+ # select 10% of all scenes
+ selection = [(seql % 10) == 0 for seql in all_seqls]
+ else:
+ raise ValueError(f"Unknown split {split}, must be None, train, val or overfit")
+ # Filter sequences based on the selection
+ selected_seqls = [seql for seql, sel in zip(all_seqls, selection) if sel]
+ selected_seqhs = [seqh for seqh, sel in zip(all_seqhs, selection) if sel]
+ # Put them back into sequence names f"{seqh:08x}{seql:016x}"
+ sequence_names = [f"{seqh:08x}{seql:016x}" for seqh, seql in zip(selected_seqhs, selected_seqls)]
+ # Remove invalid sequence names which don't exist in self.blendedmvs_info.sequences
+ valid_sequences = set(self.blendedmvs_info.sequences)
+ valid_sequence_names = [name for name in sequence_names if name in valid_sequences]
+ """
+ # All the 502 sequences in the dataset (totals to 115k images)
+ self.all_scenes = [
+ "000000000000000000000000",
+ "00000000000000000000000a",
+ "00000000000000000000000b",
+ "00000000000000000000000c",
+ "00000000000000000000000d",
+ "00000000000000000000000e",
+ "00000000000000000000000f",
+ "000000000000000000000001",
+ "00000000000000000000001a",
+ "00000000000000000000001b",
+ "00000000000000000000001d",
+ "000000000000000000000002",
+ "000000000000000000000003",
+ "000000000000000000000004",
+ "000000000000000000000005",
+ "5a2a95f032a1c655cfe3de62",
+ "5a2af22b32a1c655cfe46013",
+ "5a2ba6de32a1c655cfe51b79",
+ "5a3b9731e24cd76dad1a5f1b",
+ "5a3ca9cb270f0e3f14d0eddb",
+ "5a3cb4e4270f0e3f14d12f43",
+ "5a03e732454a8a7ec672776c",
+ "5a3f4aba5889373fbbc5d3b5",
+ "5a4a38dad38c8a075495b5d2",
+ "5a5a1e48d62c7a12d5d00e47",
+ "5a6b1c418d100c2f8fdc4411",
+ "5a6feeb54a7fbc3f874f9db7",
+ "5a7cb1d6fe5c0d6fb53e64fb",
+ "5a7d3db14989e929563eb153",
+ "5a8aa0fab18050187cbe060e",
+ "5a9e5df65baeef72b4a021cd",
+ "5a48ba95c7dab83a7d7b44ed",
+ "5a48c4e9c7dab83a7d7b5cc7",
+ "5a48d4b2c7dab83a7d7b9851",
+ "5a69c47d0d5d0a7f3b2e9752",
+ "5a77b46b318efe6c6736e68a",
+ "5a355c271b63f53d5970f362",
+ "5a489fb1c7dab83a7d7b1070",
+ "5a533e8034d7582116e34209",
+ "5a562fc7425d0f5186314725",
+ "5a572fd9fc597b0478a81d14",
+ "5a588a8193ac3d233f77fbca",
+ "5a618c72784780334bc1972d",
+ "5a752d42acc41e2423f17674",
+ "5a969eea91dfc339a9a3ad2c",
+ "5a8315f624b8e938486e0bd8",
+ "5a57542f333d180827dfc132",
+ "5a0271884e62597cdee0d0eb",
+ "5a6400933d809f1d8200af15",
+ "5a6464143d809f1d8208c43c",
+ "5a563183425d0f5186314855",
+ "5aa0f9d7a9efce63548c69a1",
+ "5aa0f478a9efce63548c1cb4",
+ "5aa7db90bfdd572271e95246",
+ "5aa235f64a17b335eeaf9609",
+ "5aa515e613d42d091d29d300",
+ "5aa1196ea9efce63548ed649",
+ "5aaadd4cbc13235570d178a7",
+ "5ab6af12ac4291329b1072ab",
+ "5ab7e00aac4291329b15864d",
+ "5ab8b8e029f5351f7f2ccf59",
+ "5ab74bf2ac4291329b11e879",
+ "5ab85f1dac4291329b17cb50",
+ "5ab8713ba3799a1d138bd69a",
+ "5abc2506b53b042ead637d86",
+ "5acc7459a7853c4b5ebbef59",
+ "5acf8ca0f3d8a750097e4b15",
+ "5adc6bd52430a05ecb2ffb85",
+ "5ae2e9c5fe405c5076abc6b2",
+ "5af02e904c8216544b4ab5a2",
+ "5af28cea59bc705737003253",
+ "5af545d0559359053d25dcf5",
+ "5afacb69ab00705d0cefdd5b",
+ "5b2c67b5e0878c381608b8d8",
+ "5b3b2b9e8d46a939f933fdc0",
+ "5b3b353d8d46a939f93524b9",
+ "5b6e716d67b396324c2d77cb",
+ "5b6eff8b67b396324c5b2672",
+ "5b7a3890fc8fcf6781e2593a",
+ "5b21e18c58e2823a67a10dd8",
+ "5b60fa0c764f146feef84df0",
+ "5b69cc0cb44b61786eb959bf",
+ "5b78e57afc8fcf6781d0c3ba",
+ "5b192eb2170cf166458ff886",
+ "5b558a928bbfb62204e77ba2",
+ "5b864d850d072a699b32f4ae",
+ "5b908d3dc6ab78485f3d24a9",
+ "5b950c71608de421b1e7318f",
+ "5b4933abf2b5f44e95de482a",
+ "5b08286b2775267d5b0634ba",
+ "5b37189a35304b6f75e7583e",
+ "5b271079e0878c3816dacca4",
+ "5b22269758e2823a67a3bd03",
+ "5b62647143840965efc0dbde",
+ "5ba19a8a360c7c30c1c169df",
+ "5ba75d79d76ffa2c86cf2f05",
+ "5bb7a08aea1cfa39f1a947ab",
+ "5bb8a49aea1cfa39f1aa7f75",
+ "5bbb6eb2ea1cfa39f1af7e0c",
+ "5bc5f0e896b66a2cd8f9bd36",
+ "5bccd6beca24970bce448134",
+ "5bce7ac9ca24970bce4934b6",
+ "5bcf979a6d5f586b95c258cd",
+ "5bd43b4ba6b28b1ee86b92dd",
+ "5be3a5fb8cfdd56947f6b67c",
+ "5be3ae47f44e235bdbbc9771",
+ "5be4ab93870d330ff2dce134",
+ "5be47bf9b18881428d8fbc1d",
+ "5be883a4f98cee15019d5b83",
+ "5bea87f4abd34c35e1860ab5",
+ "5beb6e66abd34c35e18e66b9",
+ "5bf3a82cd439231948877aed",
+ "5bf7d63575c26f32dbf7413b",
+ "5bf17c0fd439231948355385",
+ "5bf26cbbd43923194854b270",
+ "5bf03590d4392319481971dc",
+ "5bf18642c50e6f7f8bdbd492",
+ "5bf21799d43923194842c001",
+ "5bfc9d5aec61ca1dd69132a2",
+ "5bfd0f32ec61ca1dd69dc77b",
+ "5bfe5ae0fe0ea555e6a969ca",
+ "5bff3c5cfe0ea555e6bcbf3a",
+ "5c0d13b795da9479e12e2ee9",
+ "5c1af2e2bee9a723c963d019",
+ "5c1b1500bee9a723c96c3e78",
+ "5c1dbf200843bc542d8ef8c4",
+ "5c1f33f1d33e1f2e4aa6dda4",
+ "5c2b3ed5e611832e8aed46bf",
+ "5c20ca3a0843bc542d94e3e2",
+ "5c062d84a96e33018ff6f0a6",
+ "5c189f2326173c3a09ed7ef3",
+ "5c1892f726173c3a09ea9aeb",
+ "5c34300a73a8df509add216d",
+ "5c34529873a8df509ae57b58",
+ "000000000000000000000006",
+ "000000000000000000000007",
+ "000000000000000000000008",
+ "000000000000000000000009",
+ "000000000000000000000010",
+ "000000000000000000000011",
+ "000000000000000000000012",
+ "000000000000000000000015",
+ "000000000000000000000016",
+ "000000000000000000000017",
+ "000000000000000000000018",
+ "000000000000000000000019",
+ "56d73ba74bd29b8c35abade2",
+ "56f34064e296120e10484dc4",
+ "57a4a7bb6b9272286e26dc18",
+ "57f8d9bbe73f6760f10e916a",
+ "58a0a2f33d0b4542479a11b1",
+ "58a0dd1a3d0b4542479a28f3",
+ "58a1a7914a4d262a170b1101",
+ "58a1bc804a4d262a170b2f01",
+ "58a1d9d14a4d262a170b58fe",
+ "58a01dea38486e3c98475871",
+ "58a1f5d74a4d262a170b65fc",
+ "58a2a09e156b87103d3d668c",
+ "58a2d9c3156b87103d3da90f",
+ "58a3ccb0156b87103d3e4332",
+ "58a3f2f8156b87103d3e5838",
+ "58a3f6c0156b87103d3e5971",
+ "58a3fc95156b87103d3e5d9b",
+ "58a07ce53d0b45424799fdde",
+ "58a07f233d0b45424799ffe7",
+ "58a44df2156b87103d3ee239",
+ "58a164f73d0b4542479a7a8e",
+ "58a0365e38486e3c984783eb",
+ "58a439cf156b87103d3ec885",
+ "58a464aa156b87103d3eec04",
+ "58a4452f156b87103d3ed55b",
+ "58a160983d0b4542479a7347",
+ "58a186444a4d262a170ae3ae",
+ "58a285424a4d262a170baf3e",
+ "58a41819156b87103d3e92a5",
+ "58a44463156b87103d3ed45e",
+ "58a47552156b87103d3f00a4",
+ "58c4bb4f4a69c55606122be4",
+ "58c6451e4a69c556061894f1",
+ "58ca7014affdfd07c70a95ce",
+ "58cf4771d0f5fb221defe6da",
+ "58d36897f387231e6c929903",
+ "58eaf1513353456af3a1682a",
+ "58f7f7299f5b5647873cb110",
+ "58f73e7c9f5b56478738929f",
+ "59a8f851597729752c31e7e0",
+ "59a452bf9b460239aa5d1c72",
+ "59a9619a825418241fb88191",
+ "59acd2f4b891807f439c8992",
+ "59bf97fe7e7b31545da34439",
+ "59c1c3e2fd6e3d4ead9f1013",
+ "59d2657f82ca7774b1ec081d",
+ "59da1fb88a126011d0394ae9",
+ "59e75a2ca9e91f2c5526005d",
+ "59e864b2a9e91f2c5529325f",
+ "59ecfd02e225f6492d20fcc9",
+ "59f37f74b45be2233001ba18",
+ "59f70ab1e5c5d366af29bf3e",
+ "59f87d0bfa6280566fb38c9a",
+ "59f363a8b45be22330016cad",
+ "564a27b26d07883f460d8ab0",
+ "565fb1dead14d4154dae2b94",
+ "567a0fb0a825d2fb79ac9a20",
+ "569b92eb826bcba945ca002b",
+ "576fefa017ce5a16397e87fd",
+ "584a7333fe3cb463906c9fe6",
+ "584aa8e9fe3cb463906cc7d0",
+ "584ad76bfe3cb463906ce6dc",
+ "584af003fe3cb463906d0e9b",
+ "584b9a747072670e72bfc49d",
+ "584b671f7072670e72bfaaf8",
+ "584b81747072670e72bfbbfd",
+ "584ba35f7072670e72bfca4d",
+ "584ba5977072670e72bfcc2d",
+ "584bc53c7072670e72bfe85f",
+ "584bc3997072670e72bfe58d",
+ "584bc4407072670e72bfe665",
+ "584bd5587072670e72bffe39",
+ "584bdadf7072670e72c0005c",
+ "584be5ed7072670e72c007b3",
+ "584c9ad27072670e72c060c5",
+ "584c9cc67072670e72c063a1",
+ "584c58b77072670e72c03990",
+ "584cea557072670e72c07fb4",
+ "584d19d47072670e72c0c6c0",
+ "584dfe467072670e72c1665a",
+ "584e875c7072670e72c1ec94",
+ "584e05667072670e72c17167",
+ "584f94e87072670e72c2d3f7",
+ "584fdffd7072670e72c32dc7",
+ "584fe07f7072670e72c32e59",
+ "585a2a71b338a62ad50138dc",
+ "585a206ab338a62ad501298f",
+ "585a217cb338a62ad5012b38",
+ "585b34afb338a62ad501e836",
+ "585bb25fc49c8507c3ce7812",
+ "585bbe55c49c8507c3ce81cd",
+ "585d6c8a2a57cc11d4920a1e",
+ "585e54c72a57cc11d492f71a",
+ "585e34302a57cc11d492be30",
+ "585ee0632a57cc11d4933608",
+ "585f9661712e2761468dabca",
+ "585ffe9a712e2761468df643",
+ "586a37ec9d1b5e34c28184fc",
+ "586a515a9d1b5e34c281b431",
+ "586a94939d1b5e34c2823b5d",
+ "586abc689d1b5e34c2826360",
+ "586b0e219d1b5e34c2828862",
+ "586b3db89d1b5e34c282cd52",
+ "586b4c459d1b5e34c282e66d",
+ "586b7d7d9d1b5e34c283359e",
+ "586b8f149d1b5e34c283497c",
+ "586b8f629d1b5e34c28349d6",
+ "586c4c4d9d1b5e34c28391a1",
+ "586c5b5b9d1b5e34c2839a5b",
+ "586c9fdf9d1b5e34c283b657",
+ "586c48329d1b5e34c2838e80",
+ "586caab99d1b5e34c283c213",
+ "586cd0779d1b5e34c28403a7",
+ "586d6d249d1b5e34c284b80e",
+ "586d8a029d1b5e34c284c948",
+ "586d55af9d1b5e34c284a999",
+ "586d07869d1b5e34c2842e5b",
+ "586d27489d1b5e34c28453af",
+ "586df9849d1b5e34c28506de",
+ "586e279c9d1b5e34c2852180",
+ "587bc5ec2366dd5d06e262c1",
+ "587c1abf2366dd5d06e28901",
+ "587c03f12366dd5d06e27722",
+ "587c19da2366dd5d06e2877b",
+ "587c31b92366dd5d06e2a9dc",
+ "587c87d02366dd5d06e2f989",
+ "587c97a52366dd5d06e30a96",
+ "587c45192366dd5d06e2c0eb",
+ "587cec702366dd5d06e37862",
+ "587cef0a2366dd5d06e379e3",
+ "587db5872366dd5d06e3e0af",
+ "587e2b1d2366dd5d06e41af0",
+ "587e2ea62366dd5d06e41f2e",
+ "587e5cb52366dd5d06e4486e",
+ "587eb1822366dd5d06e45f29",
+ "587f365d2366dd5d06e4906e",
+ "588a9c5fec4d5a1c088ec350",
+ "588a34cfec4d5a1c088ea8d1",
+ "588ab5bdec4d5a1c088ed60f",
+ "588aff9d90414422fbe7885a",
+ "588b20d290414422fbe79f40",
+ "588c08d590414422fbe8200b",
+ "588c203d90414422fbe8319e",
+ "588c989a90414422fbe86d96",
+ "588ca09d90414422fbe871a1",
+ "588cce2190414422fbe88520",
+ "588cd5ef90414422fbe8875c",
+ "588cf0ad90414422fbe8a20f",
+ "588e0d8c90414422fbe8f8b2",
+ "588e01c490414422fbe8ee2a",
+ "588e35e690414422fbe90a53",
+ "588f017e90414422fbe9b74b",
+ "588f095190414422fbe9c1ee",
+ "589aca717dc3d323d55671c4",
+ "589af2c97dc3d323d55691e8",
+ "589b49ea7dc3d323d556d9b4",
+ "589b04287dc3d323d556a185",
+ "589bf6a57dc3d323d55743ab",
+ "589c3c497dc3d323d5578468",
+ "589c3c577dc3d323d5578480",
+ "589c300f7dc3d323d5577926",
+ "589c24527dc3d323d5577126",
+ "589c35457dc3d323d5577d8d",
+ "589ca6a6b896147a1b73aff7",
+ "589d1e1fb896147a1b73ee5b",
+ "589d5c58b896147a1b742256",
+ "589d95538fa2cf375df3317b",
+ "589df0ffb504a864ad63521a",
+ "589ea316b504a864ad639a2b",
+ "589ec97cb504a864ad63adc3",
+ "589f214338486e3c9846f123",
+ "589fdfe738486e3c984736cf",
+ "590c2d70336bb52a190be886",
+ "590f91851225725be9e25d4e",
+ "591a467a6109e14d4f09b776",
+ "591cf3033162411cf9047f37",
+ "591ea44850991c70dc99a207",
+ "599aa591d5b41f366fed0d58",
+ "5643df56138263b51db1b5f3",
+ "5644bdac138263b51db9f669",
+ "5692a4c2adafac1f14201821",
+ "5850d4f97072670e72c425d6",
+ "5854c405804be105852330fe",
+ "5855a4fc804be1058523bd75",
+ "5856ac15804be105852419d8",
+ "5856ae8b804be10585241bae",
+ "5856b460804be10585242059",
+ "5857aa5ab338a62ad5ff4dbe",
+ "5857acf8b338a62ad5ff5107",
+ "5858db6cb338a62ad500103b",
+ "5858dbcab338a62ad5001081",
+ "5859d84fb338a62ad500e5cf",
+ "5861d8ea712e2761468f3cb3",
+ "5863edf8712e27614690cce0",
+ "5864a935712e2761469111b4",
+ "5864b076712e27614691197e",
+ "5864da88712e276146913d8b",
+ "5865f4a8712e27614691e39b",
+ "5867a434833dfe3f7b88edaf",
+ "5868cd15833dfe3f7b89bfa3",
+ "5880b3692366dd5d06e5d534",
+ "5880e3422366dd5d06e5ff8e",
+ "5880f0ef2366dd5d06e6166e",
+ "5881d2bfb6844814c136a119",
+ "5881f11d8ce2c2754d0714c3",
+ "5881fee18ce2c2754d0723f8",
+ "5882cda2b116682b4adebd25",
+ "5882d58fb116682b4adec7db",
+ "5884c256932ba84fbed70bf5",
+ "5884cc13932ba84fbed71ec4",
+ "5885bc5296fa095e0671a7f0",
+ "5886d14cb791366d617a362c",
+ "5888becfc02346100f4b0b21",
+ "5888e408c02346100f4b1a29",
+ "5889da66ec4d5a1c088e5187",
+ "5889e344ec4d5a1c088e59be",
+ "5889e754ec4d5a1c088e60ba",
+ "5890c16b90414422fbeb0262",
+ "5891d8ae9a8c0314c5cd30ab",
+ "5891d0479a8c0314c5cd2abd",
+ "5891ecf19a8c0314c5cd490a",
+ "5892c0cd9a8c0314c5cdc977",
+ "5894ab309a8c0314c5cee57d",
+ "5895a6a89a8c0314c5cfca7c",
+ "5895b8c29a8c0314c5cfd051",
+ "5895d38f9a8c0314c5cfe50c",
+ "5895f2329a8c0314c5d00117",
+ "5896bb989a8c0314c5d086b6",
+ "5896ebf39a8c0314c5d0a8c4",
+ "5898b1bac9dccc22987b7f74",
+ "5898b6ffc9dccc22987b8a03",
+ "5898b31cc9dccc22987b82ec",
+ "5898bbaac9dccc22987b8eba",
+ "5899cfa6b76d7a3780a4cb64",
+ "5899e5dcb76d7a3780a4ecc1",
+ "5947b62af1b45630bd0c2a02",
+ "57102be2877e1421026358af",
+ "57153d4031bb9900425bde85",
+ "57177cd7fb8d93461afc4527",
+ "58497cdf97b73e0b090c4273",
+ "58500b007072670e72c35588",
+ "58510bf97072670e72c46ddf",
+ "58522bd56789802282f2ecb3",
+ "58524a2e0e7012308944bcf3",
+ "58524a080e7012308944bcbf",
+ "58524c1d0e7012308944bfda",
+ "58524f170e7012308944c200",
+ "58529a4e0e70123089454c6f",
+ "58551bdf804be1058523556d",
+ "58568c9a804be10585240b03",
+ "58574b35804be105852455fd",
+ "58577c60b338a62ad5ff1564",
+ "58592d69b338a62ad5007a74",
+ "58598db2b338a62ad500bc38",
+ "58625f42712e2761468fb44c",
+ "58651bcc712e2761469166dc",
+ "58660e79712e27614691fe3d",
+ "58669aad712e27614692834c",
+ "58669c02712e27614692851a",
+ "58676c36833dfe3f7b88b7f2",
+ "58678b2d833dfe3f7b88e244",
+ "58790c82ce911104a3467c88",
+ "58800b0b2366dd5d06e5312d",
+ "58805eac2366dd5d06e56460",
+ "58806e422366dd5d06e57bb6",
+ "58831d060db9bf59bf8ab98b",
+ "58851ebb932ba84fbed7abad",
+ "58871dc3b791366d617a55ff",
+ "58873cabb791366d617a65a7",
+ "58873d44b791366d617a65dd",
+ "58888b3dc02346100f4af665",
+ "58897f62c02346100f4b8ee6",
+ "58933bac9a8c0314c5ce3508",
+ "58938e6d9a8c0314c5ce726f",
+ "58951cb49a8c0314c5cf4d5e",
+ "58970fd09a8c0314c5d0e383",
+ "58977ef09a8c0314c5d17b26",
+ "59056e6760bb961de55f3501",
+ "59071f2e5a6dbd3af4130f98",
+ "59102c811225725be9e64149",
+ "59338e76772c3e6384afbb15",
+ "59350ca084b7f26bf5ce6eb8",
+ "59397e493a87372f2c9e882b",
+ "59521e0b9096412211c2aa9d",
+ "59817e4a1bd4b175e7038d19",
+ "567884f58d2828b95e3c8eba",
+ "585559d9804be10585238ddf",
+ "585834cdb338a62ad5ffab4d",
+ "586082d8712e2761468e2877",
+ "586133c2712e2761468ecfe3",
+ "586281d2712e2761468fcaa2",
+ "586316e5712e276146903c4d",
+ "586326ad712e276146904571",
+ "586375c9712e276146907429",
+ "586389c9712e276146908da6",
+ "586496fa712e2761469108e7",
+ "586669c6712e27614692597a",
+ "586913a49d1b5e34c2808b02",
+ "586922da9d1b5e34c2809ff3",
+ "588185d8dfb7a15588a114a3",
+ "588305ed0db9bf59bf8a8c80",
+ "588315c60db9bf59bf8aa928",
+ "588332ee0db9bf59bf8ae9c3",
+ "588457b8932ba84fbed69942",
+ "588519d5932ba84fbed7a04a",
+ "588824d1b791366d617adeef",
+ "588857f6c02346100f4ac09f",
+ "589145ef90414422fbeb2e08",
+ "589433fa9a8c0314c5ce9656",
+ "589765d39a8c0314c5d16b12",
+ "5851165f7072670e72c4860d",
+ "5859341ab338a62ad500848d",
+ "5862388b712e2761468f84aa",
+ "5863915b712e276146909135",
+ "5866445b712e27614692383e",
+ "5866500d712e2761469240fd",
+ "5867785a833dfe3f7b88c764",
+ "5867969c833dfe3f7b88e8bc",
+ "5868040c833dfe3f7b8934f7",
+ "5880675a2366dd5d06e570ca",
+ "5882372c8ce2c2754d076af0",
+ "5883535e932ba84fbed5ad07",
+ "5888358cb791366d617af69d",
+ "5890330d90414422fbeaa0cb",
+ "5897076e9a8c0314c5d0d31b",
+ "5940564ec2d9527ab869f7e2",
+ "5947719bf1b45630bd096665",
+ "5948194ff1b45630bd0f47e3",
+ "5950206a41b158666ac50506",
+ "5983012d1bd4b175e70c985a",
+ "58586810b338a62ad5ffc20c",
+ "58592046b338a62ad5006b33",
+ "58592854b338a62ad500750a",
+ "58596531b338a62ad500aace",
+ "58818685dfb7a15588a11626",
+ "58829563f42b1d3ee3ec835f",
+ "58894345c02346100f4b51ca",
+ "585289980e7012308945276a",
+ "585369770e7012308945c709",
+ "585373640e7012308945cab9",
+ "588230658ce2c2754d076728",
+ "589388059a8c0314c5ce718b",
+ "595979485ec6a95e86a58c8d",
+ "5841206219d291325678ca90",
+ "58563650804be1058523da55",
+ "58564084804be1058523e116",
+ "58636467712e27614690661f",
+ "58647495712e27614690f36d",
+ "58654563712e276146918643",
+ "58664251712e276146923738",
+ "588084032366dd5d06e59e82",
+ "588159582366dd5d06e66877",
+ "5890279190414422fbea9734",
+ "5890523090414422fbeab3f0",
+ "5890641690414422fbeabbe7",
+ "585203546789802282f2aaf5",
+ ]
+
+ # Final sequences to be used after filtering (some of the sequences have incorrect/low quality depth)
+ # Generally water bodies like lakes have incorrect depth
+ # Filtered out sequences:
+ # "5692a4c2adafac1f14201821" # Incorrect Depth
+ # "5864a935712e2761469111b4" # Noisy Depth and artifacts near horizon
+ # "59f87d0bfa6280566fb38c9a" # Object-centric, noise with background and sometimes in front of object
+ # "58a44463156b87103d3ed45e" # Very noisy depth in background
+ # "5c2b3ed5e611832e8aed46bf" # Depth occluded by artifacts
+ # "5bf03590d4392319481971dc" # Depth occluded by artifacts
+ # "00000000000000000000001a" # Largely incomplete depth
+ # "00000000000000000000000c" # Imprecise depth for buildings
+ # "000000000000000000000000" # Incorrect depth for planar terrain
+ self.scenes = [
+ "00000000000000000000000a",
+ "00000000000000000000000b",
+ "00000000000000000000000d",
+ "00000000000000000000000e",
+ "00000000000000000000000f",
+ "000000000000000000000001",
+ "00000000000000000000001b",
+ "00000000000000000000001d",
+ "000000000000000000000002",
+ "000000000000000000000003",
+ "000000000000000000000004",
+ "000000000000000000000005",
+ "5a2a95f032a1c655cfe3de62",
+ "5a2af22b32a1c655cfe46013",
+ "5a2ba6de32a1c655cfe51b79",
+ "5a3b9731e24cd76dad1a5f1b",
+ "5a3ca9cb270f0e3f14d0eddb",
+ "5a3cb4e4270f0e3f14d12f43",
+ "5a03e732454a8a7ec672776c",
+ "5a3f4aba5889373fbbc5d3b5",
+ "5a4a38dad38c8a075495b5d2",
+ "5a5a1e48d62c7a12d5d00e47",
+ "5a6b1c418d100c2f8fdc4411",
+ "5a6feeb54a7fbc3f874f9db7",
+ "5a7cb1d6fe5c0d6fb53e64fb",
+ "5a7d3db14989e929563eb153",
+ "5a8aa0fab18050187cbe060e",
+ "5a9e5df65baeef72b4a021cd",
+ "5a48ba95c7dab83a7d7b44ed",
+ "5a48c4e9c7dab83a7d7b5cc7",
+ "5a48d4b2c7dab83a7d7b9851",
+ "5a69c47d0d5d0a7f3b2e9752",
+ "5a77b46b318efe6c6736e68a",
+ "5a355c271b63f53d5970f362",
+ "5a489fb1c7dab83a7d7b1070",
+ "5a533e8034d7582116e34209",
+ "5a562fc7425d0f5186314725",
+ "5a572fd9fc597b0478a81d14",
+ "5a588a8193ac3d233f77fbca",
+ "5a618c72784780334bc1972d",
+ "5a752d42acc41e2423f17674",
+ "5a969eea91dfc339a9a3ad2c",
+ "5a8315f624b8e938486e0bd8",
+ "5a57542f333d180827dfc132",
+ "5a0271884e62597cdee0d0eb",
+ "5a6400933d809f1d8200af15",
+ "5a6464143d809f1d8208c43c",
+ "5a563183425d0f5186314855",
+ "5aa0f9d7a9efce63548c69a1",
+ "5aa0f478a9efce63548c1cb4",
+ "5aa7db90bfdd572271e95246",
+ "5aa235f64a17b335eeaf9609",
+ "5aa515e613d42d091d29d300",
+ "5aa1196ea9efce63548ed649",
+ "5aaadd4cbc13235570d178a7",
+ "5ab6af12ac4291329b1072ab",
+ "5ab7e00aac4291329b15864d",
+ "5ab8b8e029f5351f7f2ccf59",
+ "5ab74bf2ac4291329b11e879",
+ "5ab85f1dac4291329b17cb50",
+ "5ab8713ba3799a1d138bd69a",
+ "5abc2506b53b042ead637d86",
+ "5acc7459a7853c4b5ebbef59",
+ "5acf8ca0f3d8a750097e4b15",
+ "5adc6bd52430a05ecb2ffb85",
+ "5ae2e9c5fe405c5076abc6b2",
+ "5af02e904c8216544b4ab5a2",
+ "5af28cea59bc705737003253",
+ "5af545d0559359053d25dcf5",
+ "5afacb69ab00705d0cefdd5b",
+ "5b2c67b5e0878c381608b8d8",
+ "5b3b2b9e8d46a939f933fdc0",
+ "5b3b353d8d46a939f93524b9",
+ "5b6e716d67b396324c2d77cb",
+ "5b6eff8b67b396324c5b2672",
+ "5b7a3890fc8fcf6781e2593a",
+ "5b21e18c58e2823a67a10dd8",
+ "5b60fa0c764f146feef84df0",
+ "5b69cc0cb44b61786eb959bf",
+ "5b78e57afc8fcf6781d0c3ba",
+ "5b192eb2170cf166458ff886",
+ "5b558a928bbfb62204e77ba2",
+ "5b864d850d072a699b32f4ae",
+ "5b908d3dc6ab78485f3d24a9",
+ "5b950c71608de421b1e7318f",
+ "5b4933abf2b5f44e95de482a",
+ "5b08286b2775267d5b0634ba",
+ "5b37189a35304b6f75e7583e",
+ "5b271079e0878c3816dacca4",
+ "5b22269758e2823a67a3bd03",
+ "5b62647143840965efc0dbde",
+ "5ba19a8a360c7c30c1c169df",
+ "5ba75d79d76ffa2c86cf2f05",
+ "5bb7a08aea1cfa39f1a947ab",
+ "5bb8a49aea1cfa39f1aa7f75",
+ "5bbb6eb2ea1cfa39f1af7e0c",
+ "5bc5f0e896b66a2cd8f9bd36",
+ "5bccd6beca24970bce448134",
+ "5bce7ac9ca24970bce4934b6",
+ "5bcf979a6d5f586b95c258cd",
+ "5bd43b4ba6b28b1ee86b92dd",
+ "5be3a5fb8cfdd56947f6b67c",
+ "5be3ae47f44e235bdbbc9771",
+ "5be4ab93870d330ff2dce134",
+ "5be47bf9b18881428d8fbc1d",
+ "5be883a4f98cee15019d5b83",
+ "5bea87f4abd34c35e1860ab5",
+ "5beb6e66abd34c35e18e66b9",
+ "5bf3a82cd439231948877aed",
+ "5bf7d63575c26f32dbf7413b",
+ "5bf17c0fd439231948355385",
+ "5bf26cbbd43923194854b270",
+ "5bf18642c50e6f7f8bdbd492",
+ "5bf21799d43923194842c001",
+ "5bfc9d5aec61ca1dd69132a2",
+ "5bfd0f32ec61ca1dd69dc77b",
+ "5bfe5ae0fe0ea555e6a969ca",
+ "5bff3c5cfe0ea555e6bcbf3a",
+ "5c0d13b795da9479e12e2ee9",
+ "5c1af2e2bee9a723c963d019",
+ "5c1b1500bee9a723c96c3e78",
+ "5c1dbf200843bc542d8ef8c4",
+ "5c1f33f1d33e1f2e4aa6dda4",
+ "5c20ca3a0843bc542d94e3e2",
+ "5c062d84a96e33018ff6f0a6",
+ "5c189f2326173c3a09ed7ef3",
+ "5c1892f726173c3a09ea9aeb",
+ "5c34300a73a8df509add216d",
+ "5c34529873a8df509ae57b58",
+ "000000000000000000000006",
+ "000000000000000000000007",
+ "000000000000000000000008",
+ "000000000000000000000009",
+ "000000000000000000000010",
+ "000000000000000000000011",
+ "000000000000000000000012",
+ "000000000000000000000015",
+ "000000000000000000000016",
+ "000000000000000000000017",
+ "000000000000000000000018",
+ "000000000000000000000019",
+ "56d73ba74bd29b8c35abade2",
+ "56f34064e296120e10484dc4",
+ "57a4a7bb6b9272286e26dc18",
+ "57f8d9bbe73f6760f10e916a",
+ "58a0a2f33d0b4542479a11b1",
+ "58a0dd1a3d0b4542479a28f3",
+ "58a1a7914a4d262a170b1101",
+ "58a1bc804a4d262a170b2f01",
+ "58a1d9d14a4d262a170b58fe",
+ "58a01dea38486e3c98475871",
+ "58a1f5d74a4d262a170b65fc",
+ "58a2a09e156b87103d3d668c",
+ "58a2d9c3156b87103d3da90f",
+ "58a3ccb0156b87103d3e4332",
+ "58a3f2f8156b87103d3e5838",
+ "58a3f6c0156b87103d3e5971",
+ "58a3fc95156b87103d3e5d9b",
+ "58a07ce53d0b45424799fdde",
+ "58a07f233d0b45424799ffe7",
+ "58a44df2156b87103d3ee239",
+ "58a164f73d0b4542479a7a8e",
+ "58a0365e38486e3c984783eb",
+ "58a439cf156b87103d3ec885",
+ "58a464aa156b87103d3eec04",
+ "58a4452f156b87103d3ed55b",
+ "58a160983d0b4542479a7347",
+ "58a186444a4d262a170ae3ae",
+ "58a285424a4d262a170baf3e",
+ "58a41819156b87103d3e92a5",
+ "58a47552156b87103d3f00a4",
+ "58c4bb4f4a69c55606122be4",
+ "58c6451e4a69c556061894f1",
+ "58ca7014affdfd07c70a95ce",
+ "58cf4771d0f5fb221defe6da",
+ "58d36897f387231e6c929903",
+ "58eaf1513353456af3a1682a",
+ "58f7f7299f5b5647873cb110",
+ "58f73e7c9f5b56478738929f",
+ "59a8f851597729752c31e7e0",
+ "59a452bf9b460239aa5d1c72",
+ "59a9619a825418241fb88191",
+ "59acd2f4b891807f439c8992",
+ "59bf97fe7e7b31545da34439",
+ "59c1c3e2fd6e3d4ead9f1013",
+ "59d2657f82ca7774b1ec081d",
+ "59da1fb88a126011d0394ae9",
+ "59e75a2ca9e91f2c5526005d",
+ "59e864b2a9e91f2c5529325f",
+ "59ecfd02e225f6492d20fcc9",
+ "59f37f74b45be2233001ba18",
+ "59f70ab1e5c5d366af29bf3e",
+ "59f363a8b45be22330016cad",
+ "564a27b26d07883f460d8ab0",
+ "565fb1dead14d4154dae2b94",
+ "567a0fb0a825d2fb79ac9a20",
+ "569b92eb826bcba945ca002b",
+ "576fefa017ce5a16397e87fd",
+ "584a7333fe3cb463906c9fe6",
+ "584aa8e9fe3cb463906cc7d0",
+ "584ad76bfe3cb463906ce6dc",
+ "584af003fe3cb463906d0e9b",
+ "584b9a747072670e72bfc49d",
+ "584b671f7072670e72bfaaf8",
+ "584b81747072670e72bfbbfd",
+ "584ba35f7072670e72bfca4d",
+ "584ba5977072670e72bfcc2d",
+ "584bc53c7072670e72bfe85f",
+ "584bc3997072670e72bfe58d",
+ "584bc4407072670e72bfe665",
+ "584bd5587072670e72bffe39",
+ "584bdadf7072670e72c0005c",
+ "584be5ed7072670e72c007b3",
+ "584c9ad27072670e72c060c5",
+ "584c9cc67072670e72c063a1",
+ "584c58b77072670e72c03990",
+ "584cea557072670e72c07fb4",
+ "584d19d47072670e72c0c6c0",
+ "584dfe467072670e72c1665a",
+ "584e875c7072670e72c1ec94",
+ "584e05667072670e72c17167",
+ "584f94e87072670e72c2d3f7",
+ "584fdffd7072670e72c32dc7",
+ "584fe07f7072670e72c32e59",
+ "585a2a71b338a62ad50138dc",
+ "585a206ab338a62ad501298f",
+ "585a217cb338a62ad5012b38",
+ "585b34afb338a62ad501e836",
+ "585bb25fc49c8507c3ce7812",
+ "585bbe55c49c8507c3ce81cd",
+ "585d6c8a2a57cc11d4920a1e",
+ "585e54c72a57cc11d492f71a",
+ "585e34302a57cc11d492be30",
+ "585ee0632a57cc11d4933608",
+ "585f9661712e2761468dabca",
+ "585ffe9a712e2761468df643",
+ "586a37ec9d1b5e34c28184fc",
+ "586a515a9d1b5e34c281b431",
+ "586a94939d1b5e34c2823b5d",
+ "586abc689d1b5e34c2826360",
+ "586b0e219d1b5e34c2828862",
+ "586b3db89d1b5e34c282cd52",
+ "586b4c459d1b5e34c282e66d",
+ "586b7d7d9d1b5e34c283359e",
+ "586b8f149d1b5e34c283497c",
+ "586b8f629d1b5e34c28349d6",
+ "586c4c4d9d1b5e34c28391a1",
+ "586c5b5b9d1b5e34c2839a5b",
+ "586c9fdf9d1b5e34c283b657",
+ "586c48329d1b5e34c2838e80",
+ "586caab99d1b5e34c283c213",
+ "586cd0779d1b5e34c28403a7",
+ "586d6d249d1b5e34c284b80e",
+ "586d8a029d1b5e34c284c948",
+ "586d55af9d1b5e34c284a999",
+ "586d07869d1b5e34c2842e5b",
+ "586d27489d1b5e34c28453af",
+ "586df9849d1b5e34c28506de",
+ "586e279c9d1b5e34c2852180",
+ "587bc5ec2366dd5d06e262c1",
+ "587c1abf2366dd5d06e28901",
+ "587c03f12366dd5d06e27722",
+ "587c19da2366dd5d06e2877b",
+ "587c31b92366dd5d06e2a9dc",
+ "587c87d02366dd5d06e2f989",
+ "587c97a52366dd5d06e30a96",
+ "587c45192366dd5d06e2c0eb",
+ "587cec702366dd5d06e37862",
+ "587cef0a2366dd5d06e379e3",
+ "587db5872366dd5d06e3e0af",
+ "587e2b1d2366dd5d06e41af0",
+ "587e2ea62366dd5d06e41f2e",
+ "587e5cb52366dd5d06e4486e",
+ "587eb1822366dd5d06e45f29",
+ "587f365d2366dd5d06e4906e",
+ "588a9c5fec4d5a1c088ec350",
+ "588a34cfec4d5a1c088ea8d1",
+ "588ab5bdec4d5a1c088ed60f",
+ "588aff9d90414422fbe7885a",
+ "588b20d290414422fbe79f40",
+ "588c08d590414422fbe8200b",
+ "588c203d90414422fbe8319e",
+ "588c989a90414422fbe86d96",
+ "588ca09d90414422fbe871a1",
+ "588cce2190414422fbe88520",
+ "588cd5ef90414422fbe8875c",
+ "588cf0ad90414422fbe8a20f",
+ "588e0d8c90414422fbe8f8b2",
+ "588e01c490414422fbe8ee2a",
+ "588e35e690414422fbe90a53",
+ "588f017e90414422fbe9b74b",
+ "588f095190414422fbe9c1ee",
+ "589aca717dc3d323d55671c4",
+ "589af2c97dc3d323d55691e8",
+ "589b49ea7dc3d323d556d9b4",
+ "589b04287dc3d323d556a185",
+ "589bf6a57dc3d323d55743ab",
+ "589c3c497dc3d323d5578468",
+ "589c3c577dc3d323d5578480",
+ "589c300f7dc3d323d5577926",
+ "589c24527dc3d323d5577126",
+ "589c35457dc3d323d5577d8d",
+ "589ca6a6b896147a1b73aff7",
+ "589d1e1fb896147a1b73ee5b",
+ "589d5c58b896147a1b742256",
+ "589d95538fa2cf375df3317b",
+ "589df0ffb504a864ad63521a",
+ "589ea316b504a864ad639a2b",
+ "589ec97cb504a864ad63adc3",
+ "589f214338486e3c9846f123",
+ "589fdfe738486e3c984736cf",
+ "590c2d70336bb52a190be886",
+ "590f91851225725be9e25d4e",
+ "591a467a6109e14d4f09b776",
+ "591cf3033162411cf9047f37",
+ "591ea44850991c70dc99a207",
+ "599aa591d5b41f366fed0d58",
+ "5643df56138263b51db1b5f3",
+ "5644bdac138263b51db9f669",
+ "5850d4f97072670e72c425d6",
+ "5854c405804be105852330fe",
+ "5855a4fc804be1058523bd75",
+ "5856ac15804be105852419d8",
+ "5856ae8b804be10585241bae",
+ "5856b460804be10585242059",
+ "5857aa5ab338a62ad5ff4dbe",
+ "5857acf8b338a62ad5ff5107",
+ "5858db6cb338a62ad500103b",
+ "5858dbcab338a62ad5001081",
+ "5859d84fb338a62ad500e5cf",
+ "5861d8ea712e2761468f3cb3",
+ "5863edf8712e27614690cce0",
+ "5864b076712e27614691197e",
+ "5864da88712e276146913d8b",
+ "5865f4a8712e27614691e39b",
+ "5867a434833dfe3f7b88edaf",
+ "5868cd15833dfe3f7b89bfa3",
+ "5880b3692366dd5d06e5d534",
+ "5880e3422366dd5d06e5ff8e",
+ "5880f0ef2366dd5d06e6166e",
+ "5881d2bfb6844814c136a119",
+ "5881f11d8ce2c2754d0714c3",
+ "5881fee18ce2c2754d0723f8",
+ "5882cda2b116682b4adebd25",
+ "5882d58fb116682b4adec7db",
+ "5884c256932ba84fbed70bf5",
+ "5884cc13932ba84fbed71ec4",
+ "5885bc5296fa095e0671a7f0",
+ "5886d14cb791366d617a362c",
+ "5888becfc02346100f4b0b21",
+ "5888e408c02346100f4b1a29",
+ "5889da66ec4d5a1c088e5187",
+ "5889e344ec4d5a1c088e59be",
+ "5889e754ec4d5a1c088e60ba",
+ "5890c16b90414422fbeb0262",
+ "5891d8ae9a8c0314c5cd30ab",
+ "5891d0479a8c0314c5cd2abd",
+ "5891ecf19a8c0314c5cd490a",
+ "5892c0cd9a8c0314c5cdc977",
+ "5894ab309a8c0314c5cee57d",
+ "5895a6a89a8c0314c5cfca7c",
+ "5895b8c29a8c0314c5cfd051",
+ "5895d38f9a8c0314c5cfe50c",
+ "5895f2329a8c0314c5d00117",
+ "5896bb989a8c0314c5d086b6",
+ "5896ebf39a8c0314c5d0a8c4",
+ "5898b1bac9dccc22987b7f74",
+ "5898b6ffc9dccc22987b8a03",
+ "5898b31cc9dccc22987b82ec",
+ "5898bbaac9dccc22987b8eba",
+ "5899cfa6b76d7a3780a4cb64",
+ "5899e5dcb76d7a3780a4ecc1",
+ "5947b62af1b45630bd0c2a02",
+ "57102be2877e1421026358af",
+ "57153d4031bb9900425bde85",
+ "57177cd7fb8d93461afc4527",
+ "58497cdf97b73e0b090c4273",
+ "58500b007072670e72c35588",
+ "58510bf97072670e72c46ddf",
+ "58522bd56789802282f2ecb3",
+ "58524a2e0e7012308944bcf3",
+ "58524a080e7012308944bcbf",
+ "58524c1d0e7012308944bfda",
+ "58524f170e7012308944c200",
+ "58529a4e0e70123089454c6f",
+ "58551bdf804be1058523556d",
+ "58568c9a804be10585240b03",
+ "58574b35804be105852455fd",
+ "58577c60b338a62ad5ff1564",
+ "58592d69b338a62ad5007a74",
+ "58598db2b338a62ad500bc38",
+ "58625f42712e2761468fb44c",
+ "58651bcc712e2761469166dc",
+ "58660e79712e27614691fe3d",
+ "58669aad712e27614692834c",
+ "58669c02712e27614692851a",
+ "58676c36833dfe3f7b88b7f2",
+ "58678b2d833dfe3f7b88e244",
+ "58790c82ce911104a3467c88",
+ "58800b0b2366dd5d06e5312d",
+ "58805eac2366dd5d06e56460",
+ "58806e422366dd5d06e57bb6",
+ "58831d060db9bf59bf8ab98b",
+ "58851ebb932ba84fbed7abad",
+ "58871dc3b791366d617a55ff",
+ "58873cabb791366d617a65a7",
+ "58873d44b791366d617a65dd",
+ "58888b3dc02346100f4af665",
+ "58897f62c02346100f4b8ee6",
+ "58933bac9a8c0314c5ce3508",
+ "58938e6d9a8c0314c5ce726f",
+ "58951cb49a8c0314c5cf4d5e",
+ "58970fd09a8c0314c5d0e383",
+ "58977ef09a8c0314c5d17b26",
+ "59056e6760bb961de55f3501",
+ "59071f2e5a6dbd3af4130f98",
+ "59102c811225725be9e64149",
+ "59338e76772c3e6384afbb15",
+ "59350ca084b7f26bf5ce6eb8",
+ "59397e493a87372f2c9e882b",
+ "59521e0b9096412211c2aa9d",
+ "59817e4a1bd4b175e7038d19",
+ "567884f58d2828b95e3c8eba",
+ "585559d9804be10585238ddf",
+ "585834cdb338a62ad5ffab4d",
+ "586082d8712e2761468e2877",
+ "586133c2712e2761468ecfe3",
+ "586281d2712e2761468fcaa2",
+ "586316e5712e276146903c4d",
+ "586326ad712e276146904571",
+ "586375c9712e276146907429",
+ "586389c9712e276146908da6",
+ "586496fa712e2761469108e7",
+ "586669c6712e27614692597a",
+ "586913a49d1b5e34c2808b02",
+ "586922da9d1b5e34c2809ff3",
+ "588185d8dfb7a15588a114a3",
+ "588305ed0db9bf59bf8a8c80",
+ "588315c60db9bf59bf8aa928",
+ "588332ee0db9bf59bf8ae9c3",
+ "588457b8932ba84fbed69942",
+ "588519d5932ba84fbed7a04a",
+ "588824d1b791366d617adeef",
+ "588857f6c02346100f4ac09f",
+ "589145ef90414422fbeb2e08",
+ "589433fa9a8c0314c5ce9656",
+ "589765d39a8c0314c5d16b12",
+ "5851165f7072670e72c4860d",
+ "5859341ab338a62ad500848d",
+ "5862388b712e2761468f84aa",
+ "5863915b712e276146909135",
+ "5866445b712e27614692383e",
+ "5866500d712e2761469240fd",
+ "5867785a833dfe3f7b88c764",
+ "5867969c833dfe3f7b88e8bc",
+ "5868040c833dfe3f7b8934f7",
+ "5880675a2366dd5d06e570ca",
+ "5882372c8ce2c2754d076af0",
+ "5883535e932ba84fbed5ad07",
+ "5888358cb791366d617af69d",
+ "5890330d90414422fbeaa0cb",
+ "5897076e9a8c0314c5d0d31b",
+ "5940564ec2d9527ab869f7e2",
+ "5947719bf1b45630bd096665",
+ "5948194ff1b45630bd0f47e3",
+ "5950206a41b158666ac50506",
+ "5983012d1bd4b175e70c985a",
+ "58586810b338a62ad5ffc20c",
+ "58592046b338a62ad5006b33",
+ "58592854b338a62ad500750a",
+ "58596531b338a62ad500aace",
+ "58818685dfb7a15588a11626",
+ "58829563f42b1d3ee3ec835f",
+ "58894345c02346100f4b51ca",
+ "585289980e7012308945276a",
+ "585369770e7012308945c709",
+ "585373640e7012308945cab9",
+ "588230658ce2c2754d076728",
+ "589388059a8c0314c5ce718b",
+ "595979485ec6a95e86a58c8d",
+ "5841206219d291325678ca90",
+ "58563650804be1058523da55",
+ "58564084804be1058523e116",
+ "58636467712e27614690661f",
+ "58647495712e27614690f36d",
+ "58654563712e276146918643",
+ "58664251712e276146923738",
+ "588084032366dd5d06e59e82",
+ "588159582366dd5d06e66877",
+ "5890279190414422fbea9734",
+ "5890523090414422fbeab3f0",
+ "5890641690414422fbeabbe7",
+ "585203546789802282f2aaf5",
+ ]
+
+ # Train set sequences after filtering
+ self.train_split_scenes = [
+ "00000000000000000000000b",
+ "00000000000000000000000d",
+ "00000000000000000000000e",
+ "00000000000000000000000f",
+ "000000000000000000000001",
+ "00000000000000000000001b",
+ "00000000000000000000001d",
+ "000000000000000000000002",
+ "000000000000000000000003",
+ "000000000000000000000004",
+ "000000000000000000000005",
+ "5a2a95f032a1c655cfe3de62",
+ "5a2af22b32a1c655cfe46013",
+ "5a2ba6de32a1c655cfe51b79",
+ "5a3b9731e24cd76dad1a5f1b",
+ "5a3ca9cb270f0e3f14d0eddb",
+ "5a3cb4e4270f0e3f14d12f43",
+ "5a03e732454a8a7ec672776c",
+ "5a3f4aba5889373fbbc5d3b5",
+ "5a5a1e48d62c7a12d5d00e47",
+ "5a6b1c418d100c2f8fdc4411",
+ "5a6feeb54a7fbc3f874f9db7",
+ "5a7cb1d6fe5c0d6fb53e64fb",
+ "5a7d3db14989e929563eb153",
+ "5a8aa0fab18050187cbe060e",
+ "5a9e5df65baeef72b4a021cd",
+ "5a48ba95c7dab83a7d7b44ed",
+ "5a48c4e9c7dab83a7d7b5cc7",
+ "5a48d4b2c7dab83a7d7b9851",
+ "5a69c47d0d5d0a7f3b2e9752",
+ "5a77b46b318efe6c6736e68a",
+ "5a355c271b63f53d5970f362",
+ "5a533e8034d7582116e34209",
+ "5a562fc7425d0f5186314725",
+ "5a618c72784780334bc1972d",
+ "5a752d42acc41e2423f17674",
+ "5a969eea91dfc339a9a3ad2c",
+ "5a8315f624b8e938486e0bd8",
+ "5a57542f333d180827dfc132",
+ "5a0271884e62597cdee0d0eb",
+ "5a6400933d809f1d8200af15",
+ "5a6464143d809f1d8208c43c",
+ "5a563183425d0f5186314855",
+ "5aa0f9d7a9efce63548c69a1",
+ "5aa7db90bfdd572271e95246",
+ "5aa235f64a17b335eeaf9609",
+ "5aa515e613d42d091d29d300",
+ "5aa1196ea9efce63548ed649",
+ "5aaadd4cbc13235570d178a7",
+ "5ab6af12ac4291329b1072ab",
+ "5ab7e00aac4291329b15864d",
+ "5ab8b8e029f5351f7f2ccf59",
+ "5ab74bf2ac4291329b11e879",
+ "5ab85f1dac4291329b17cb50",
+ "5ab8713ba3799a1d138bd69a",
+ "5abc2506b53b042ead637d86",
+ "5acc7459a7853c4b5ebbef59",
+ "5acf8ca0f3d8a750097e4b15",
+ "5adc6bd52430a05ecb2ffb85",
+ "5af02e904c8216544b4ab5a2",
+ "5af28cea59bc705737003253",
+ "5af545d0559359053d25dcf5",
+ "5afacb69ab00705d0cefdd5b",
+ "5b3b2b9e8d46a939f933fdc0",
+ "5b3b353d8d46a939f93524b9",
+ "5b6e716d67b396324c2d77cb",
+ "5b6eff8b67b396324c5b2672",
+ "5b7a3890fc8fcf6781e2593a",
+ "5b60fa0c764f146feef84df0",
+ "5b69cc0cb44b61786eb959bf",
+ "5b78e57afc8fcf6781d0c3ba",
+ "5b192eb2170cf166458ff886",
+ "5b558a928bbfb62204e77ba2",
+ "5b908d3dc6ab78485f3d24a9",
+ "5b950c71608de421b1e7318f",
+ "5b08286b2775267d5b0634ba",
+ "5b271079e0878c3816dacca4",
+ "5b22269758e2823a67a3bd03",
+ "5b62647143840965efc0dbde",
+ "5ba19a8a360c7c30c1c169df",
+ "5ba75d79d76ffa2c86cf2f05",
+ "5bb7a08aea1cfa39f1a947ab",
+ "5bb8a49aea1cfa39f1aa7f75",
+ "5bbb6eb2ea1cfa39f1af7e0c",
+ "5bce7ac9ca24970bce4934b6",
+ "5bcf979a6d5f586b95c258cd",
+ "5bd43b4ba6b28b1ee86b92dd",
+ "5be3a5fb8cfdd56947f6b67c",
+ "5be3ae47f44e235bdbbc9771",
+ "5be4ab93870d330ff2dce134",
+ "5be47bf9b18881428d8fbc1d",
+ "5be883a4f98cee15019d5b83",
+ "5bea87f4abd34c35e1860ab5",
+ "5beb6e66abd34c35e18e66b9",
+ "5bf3a82cd439231948877aed",
+ "5bf7d63575c26f32dbf7413b",
+ "5bf17c0fd439231948355385",
+ "5bf21799d43923194842c001",
+ "5bfd0f32ec61ca1dd69dc77b",
+ "5bfe5ae0fe0ea555e6a969ca",
+ "5c0d13b795da9479e12e2ee9",
+ "5c1af2e2bee9a723c963d019",
+ "5c1b1500bee9a723c96c3e78",
+ "5c1dbf200843bc542d8ef8c4",
+ "5c20ca3a0843bc542d94e3e2",
+ "5c062d84a96e33018ff6f0a6",
+ "5c189f2326173c3a09ed7ef3",
+ "5c1892f726173c3a09ea9aeb",
+ "5c34300a73a8df509add216d",
+ "000000000000000000000006",
+ "000000000000000000000007",
+ "000000000000000000000008",
+ "000000000000000000000009",
+ "000000000000000000000010",
+ "000000000000000000000011",
+ "000000000000000000000012",
+ "000000000000000000000015",
+ "000000000000000000000016",
+ "000000000000000000000017",
+ "000000000000000000000018",
+ "000000000000000000000019",
+ "56d73ba74bd29b8c35abade2",
+ "56f34064e296120e10484dc4",
+ "57a4a7bb6b9272286e26dc18",
+ "57f8d9bbe73f6760f10e916a",
+ "58a0a2f33d0b4542479a11b1",
+ "58a0dd1a3d0b4542479a28f3",
+ "58a1a7914a4d262a170b1101",
+ "58a1bc804a4d262a170b2f01",
+ "58a1d9d14a4d262a170b58fe",
+ "58a01dea38486e3c98475871",
+ "58a1f5d74a4d262a170b65fc",
+ "58a2a09e156b87103d3d668c",
+ "58a2d9c3156b87103d3da90f",
+ "58a3ccb0156b87103d3e4332",
+ "58a3f2f8156b87103d3e5838",
+ "58a3f6c0156b87103d3e5971",
+ "58a3fc95156b87103d3e5d9b",
+ "58a07ce53d0b45424799fdde",
+ "58a07f233d0b45424799ffe7",
+ "58a44df2156b87103d3ee239",
+ "58a164f73d0b4542479a7a8e",
+ "58a0365e38486e3c984783eb",
+ "58a439cf156b87103d3ec885",
+ "58a464aa156b87103d3eec04",
+ "58a4452f156b87103d3ed55b",
+ "58a160983d0b4542479a7347",
+ "58a285424a4d262a170baf3e",
+ "58a41819156b87103d3e92a5",
+ "58a47552156b87103d3f00a4",
+ "58c4bb4f4a69c55606122be4",
+ "58c6451e4a69c556061894f1",
+ "58ca7014affdfd07c70a95ce",
+ "58cf4771d0f5fb221defe6da",
+ "58d36897f387231e6c929903",
+ "58eaf1513353456af3a1682a",
+ "58f73e7c9f5b56478738929f",
+ "59a8f851597729752c31e7e0",
+ "59a452bf9b460239aa5d1c72",
+ "59a9619a825418241fb88191",
+ "59bf97fe7e7b31545da34439",
+ "59c1c3e2fd6e3d4ead9f1013",
+ "59d2657f82ca7774b1ec081d",
+ "59da1fb88a126011d0394ae9",
+ "59e75a2ca9e91f2c5526005d",
+ "59e864b2a9e91f2c5529325f",
+ "59ecfd02e225f6492d20fcc9",
+ "59f37f74b45be2233001ba18",
+ "59f70ab1e5c5d366af29bf3e",
+ "59f363a8b45be22330016cad",
+ "564a27b26d07883f460d8ab0",
+ "565fb1dead14d4154dae2b94",
+ "569b92eb826bcba945ca002b",
+ "576fefa017ce5a16397e87fd",
+ "584a7333fe3cb463906c9fe6",
+ "584aa8e9fe3cb463906cc7d0",
+ "584af003fe3cb463906d0e9b",
+ "584b9a747072670e72bfc49d",
+ "584b671f7072670e72bfaaf8",
+ "584b81747072670e72bfbbfd",
+ "584ba35f7072670e72bfca4d",
+ "584ba5977072670e72bfcc2d",
+ "584bc53c7072670e72bfe85f",
+ "584bc3997072670e72bfe58d",
+ "584bc4407072670e72bfe665",
+ "584bd5587072670e72bffe39",
+ "584bdadf7072670e72c0005c",
+ "584be5ed7072670e72c007b3",
+ "584c9ad27072670e72c060c5",
+ "584c9cc67072670e72c063a1",
+ "584cea557072670e72c07fb4",
+ "584d19d47072670e72c0c6c0",
+ "584dfe467072670e72c1665a",
+ "584e875c7072670e72c1ec94",
+ "584e05667072670e72c17167",
+ "584f94e87072670e72c2d3f7",
+ "584fdffd7072670e72c32dc7",
+ "584fe07f7072670e72c32e59",
+ "585a2a71b338a62ad50138dc",
+ "585a206ab338a62ad501298f",
+ "585a217cb338a62ad5012b38",
+ "585b34afb338a62ad501e836",
+ "585bb25fc49c8507c3ce7812",
+ "585bbe55c49c8507c3ce81cd",
+ "585d6c8a2a57cc11d4920a1e",
+ "585e54c72a57cc11d492f71a",
+ "585e34302a57cc11d492be30",
+ "585ee0632a57cc11d4933608",
+ "585f9661712e2761468dabca",
+ "585ffe9a712e2761468df643",
+ "586a37ec9d1b5e34c28184fc",
+ "586a515a9d1b5e34c281b431",
+ "586a94939d1b5e34c2823b5d",
+ "586abc689d1b5e34c2826360",
+ "586b0e219d1b5e34c2828862",
+ "586b3db89d1b5e34c282cd52",
+ "586b4c459d1b5e34c282e66d",
+ "586b7d7d9d1b5e34c283359e",
+ "586b8f149d1b5e34c283497c",
+ "586b8f629d1b5e34c28349d6",
+ "586c4c4d9d1b5e34c28391a1",
+ "586c5b5b9d1b5e34c2839a5b",
+ "586c9fdf9d1b5e34c283b657",
+ "586caab99d1b5e34c283c213",
+ "586cd0779d1b5e34c28403a7",
+ "586d6d249d1b5e34c284b80e",
+ "586d8a029d1b5e34c284c948",
+ "586d55af9d1b5e34c284a999",
+ "586d07869d1b5e34c2842e5b",
+ "586d27489d1b5e34c28453af",
+ "586e279c9d1b5e34c2852180",
+ "587bc5ec2366dd5d06e262c1",
+ "587c1abf2366dd5d06e28901",
+ "587c03f12366dd5d06e27722",
+ "587c19da2366dd5d06e2877b",
+ "587c31b92366dd5d06e2a9dc",
+ "587c87d02366dd5d06e2f989",
+ "587c97a52366dd5d06e30a96",
+ "587c45192366dd5d06e2c0eb",
+ "587cec702366dd5d06e37862",
+ "587cef0a2366dd5d06e379e3",
+ "587db5872366dd5d06e3e0af",
+ "587e2b1d2366dd5d06e41af0",
+ "587e2ea62366dd5d06e41f2e",
+ "587e5cb52366dd5d06e4486e",
+ "587eb1822366dd5d06e45f29",
+ "587f365d2366dd5d06e4906e",
+ "588a9c5fec4d5a1c088ec350",
+ "588a34cfec4d5a1c088ea8d1",
+ "588ab5bdec4d5a1c088ed60f",
+ "588aff9d90414422fbe7885a",
+ "588b20d290414422fbe79f40",
+ "588c08d590414422fbe8200b",
+ "588c203d90414422fbe8319e",
+ "588c989a90414422fbe86d96",
+ "588ca09d90414422fbe871a1",
+ "588cce2190414422fbe88520",
+ "588cd5ef90414422fbe8875c",
+ "588cf0ad90414422fbe8a20f",
+ "588e01c490414422fbe8ee2a",
+ "588e35e690414422fbe90a53",
+ "588f017e90414422fbe9b74b",
+ "588f095190414422fbe9c1ee",
+ "589aca717dc3d323d55671c4",
+ "589af2c97dc3d323d55691e8",
+ "589b49ea7dc3d323d556d9b4",
+ "589b04287dc3d323d556a185",
+ "589bf6a57dc3d323d55743ab",
+ "589c3c497dc3d323d5578468",
+ "589c3c577dc3d323d5578480",
+ "589c24527dc3d323d5577126",
+ "589c35457dc3d323d5577d8d",
+ "589ca6a6b896147a1b73aff7",
+ "589d1e1fb896147a1b73ee5b",
+ "589d5c58b896147a1b742256",
+ "589d95538fa2cf375df3317b",
+ "589df0ffb504a864ad63521a",
+ "589ea316b504a864ad639a2b",
+ "589ec97cb504a864ad63adc3",
+ "589f214338486e3c9846f123",
+ "589fdfe738486e3c984736cf",
+ "590c2d70336bb52a190be886",
+ "591a467a6109e14d4f09b776",
+ "591cf3033162411cf9047f37",
+ "591ea44850991c70dc99a207",
+ "599aa591d5b41f366fed0d58",
+ "5643df56138263b51db1b5f3",
+ "5644bdac138263b51db9f669",
+ "5850d4f97072670e72c425d6",
+ "5854c405804be105852330fe",
+ "5855a4fc804be1058523bd75",
+ "5856ac15804be105852419d8",
+ "5856ae8b804be10585241bae",
+ "5856b460804be10585242059",
+ "5857aa5ab338a62ad5ff4dbe",
+ "5857acf8b338a62ad5ff5107",
+ "5858db6cb338a62ad500103b",
+ "5858dbcab338a62ad5001081",
+ "5859d84fb338a62ad500e5cf",
+ "5861d8ea712e2761468f3cb3",
+ "5863edf8712e27614690cce0",
+ "5864b076712e27614691197e",
+ "5864da88712e276146913d8b",
+ "5865f4a8712e27614691e39b",
+ "5867a434833dfe3f7b88edaf",
+ "5868cd15833dfe3f7b89bfa3",
+ "5880b3692366dd5d06e5d534",
+ "5880e3422366dd5d06e5ff8e",
+ "5880f0ef2366dd5d06e6166e",
+ "5881d2bfb6844814c136a119",
+ "5881f11d8ce2c2754d0714c3",
+ "5881fee18ce2c2754d0723f8",
+ "5882cda2b116682b4adebd25",
+ "5882d58fb116682b4adec7db",
+ "5884c256932ba84fbed70bf5",
+ "5884cc13932ba84fbed71ec4",
+ "5885bc5296fa095e0671a7f0",
+ "5886d14cb791366d617a362c",
+ "5888becfc02346100f4b0b21",
+ "5888e408c02346100f4b1a29",
+ "5889da66ec4d5a1c088e5187",
+ "5889e754ec4d5a1c088e60ba",
+ "5890c16b90414422fbeb0262",
+ "5891d8ae9a8c0314c5cd30ab",
+ "5891d0479a8c0314c5cd2abd",
+ "5891ecf19a8c0314c5cd490a",
+ "5892c0cd9a8c0314c5cdc977",
+ "5894ab309a8c0314c5cee57d",
+ "5895a6a89a8c0314c5cfca7c",
+ "5895b8c29a8c0314c5cfd051",
+ "5895d38f9a8c0314c5cfe50c",
+ "5895f2329a8c0314c5d00117",
+ "5896bb989a8c0314c5d086b6",
+ "5896ebf39a8c0314c5d0a8c4",
+ "5898b1bac9dccc22987b7f74",
+ "5898b6ffc9dccc22987b8a03",
+ "5898bbaac9dccc22987b8eba",
+ "5899cfa6b76d7a3780a4cb64",
+ "5899e5dcb76d7a3780a4ecc1",
+ "57102be2877e1421026358af",
+ "57153d4031bb9900425bde85",
+ "57177cd7fb8d93461afc4527",
+ "58497cdf97b73e0b090c4273",
+ "58500b007072670e72c35588",
+ "58510bf97072670e72c46ddf",
+ "58522bd56789802282f2ecb3",
+ "58524a2e0e7012308944bcf3",
+ "58524a080e7012308944bcbf",
+ "58524c1d0e7012308944bfda",
+ "58524f170e7012308944c200",
+ "58529a4e0e70123089454c6f",
+ "58551bdf804be1058523556d",
+ "58568c9a804be10585240b03",
+ "58574b35804be105852455fd",
+ "58577c60b338a62ad5ff1564",
+ "58592d69b338a62ad5007a74",
+ "58625f42712e2761468fb44c",
+ "58651bcc712e2761469166dc",
+ "58660e79712e27614691fe3d",
+ "58669aad712e27614692834c",
+ "58676c36833dfe3f7b88b7f2",
+ "58678b2d833dfe3f7b88e244",
+ "58800b0b2366dd5d06e5312d",
+ "58805eac2366dd5d06e56460",
+ "58806e422366dd5d06e57bb6",
+ "58831d060db9bf59bf8ab98b",
+ "58851ebb932ba84fbed7abad",
+ "58871dc3b791366d617a55ff",
+ "58873cabb791366d617a65a7",
+ "58873d44b791366d617a65dd",
+ "58888b3dc02346100f4af665",
+ "58933bac9a8c0314c5ce3508",
+ "58938e6d9a8c0314c5ce726f",
+ "58951cb49a8c0314c5cf4d5e",
+ "58970fd09a8c0314c5d0e383",
+ "58977ef09a8c0314c5d17b26",
+ "59056e6760bb961de55f3501",
+ "59071f2e5a6dbd3af4130f98",
+ "59102c811225725be9e64149",
+ "59338e76772c3e6384afbb15",
+ "59350ca084b7f26bf5ce6eb8",
+ "59397e493a87372f2c9e882b",
+ "59521e0b9096412211c2aa9d",
+ "59817e4a1bd4b175e7038d19",
+ "567884f58d2828b95e3c8eba",
+ "585559d9804be10585238ddf",
+ "585834cdb338a62ad5ffab4d",
+ "586082d8712e2761468e2877",
+ "586133c2712e2761468ecfe3",
+ "586281d2712e2761468fcaa2",
+ "586316e5712e276146903c4d",
+ "586326ad712e276146904571",
+ "586375c9712e276146907429",
+ "586389c9712e276146908da6",
+ "586496fa712e2761469108e7",
+ "586669c6712e27614692597a",
+ "586913a49d1b5e34c2808b02",
+ "586922da9d1b5e34c2809ff3",
+ "588185d8dfb7a15588a114a3",
+ "588315c60db9bf59bf8aa928",
+ "588332ee0db9bf59bf8ae9c3",
+ "588519d5932ba84fbed7a04a",
+ "588824d1b791366d617adeef",
+ "588857f6c02346100f4ac09f",
+ "589145ef90414422fbeb2e08",
+ "589433fa9a8c0314c5ce9656",
+ "589765d39a8c0314c5d16b12",
+ "5851165f7072670e72c4860d",
+ "5859341ab338a62ad500848d",
+ "5863915b712e276146909135",
+ "5866445b712e27614692383e",
+ "5866500d712e2761469240fd",
+ "5867785a833dfe3f7b88c764",
+ "5867969c833dfe3f7b88e8bc",
+ "5868040c833dfe3f7b8934f7",
+ "5882372c8ce2c2754d076af0",
+ "5883535e932ba84fbed5ad07",
+ "5888358cb791366d617af69d",
+ "5890330d90414422fbeaa0cb",
+ "5897076e9a8c0314c5d0d31b",
+ "5940564ec2d9527ab869f7e2",
+ "5947719bf1b45630bd096665",
+ "5948194ff1b45630bd0f47e3",
+ "5950206a41b158666ac50506",
+ "5983012d1bd4b175e70c985a",
+ "58586810b338a62ad5ffc20c",
+ "58592046b338a62ad5006b33",
+ "58592854b338a62ad500750a",
+ "58596531b338a62ad500aace",
+ "58818685dfb7a15588a11626",
+ "58829563f42b1d3ee3ec835f",
+ "58894345c02346100f4b51ca",
+ "585289980e7012308945276a",
+ "585369770e7012308945c709",
+ "585373640e7012308945cab9",
+ "588230658ce2c2754d076728",
+ "589388059a8c0314c5ce718b",
+ "595979485ec6a95e86a58c8d",
+ "5841206219d291325678ca90",
+ "58563650804be1058523da55",
+ "58564084804be1058523e116",
+ "58636467712e27614690661f",
+ "58647495712e27614690f36d",
+ "58654563712e276146918643",
+ "58664251712e276146923738",
+ "588084032366dd5d06e59e82",
+ "588159582366dd5d06e66877",
+ "5890279190414422fbea9734",
+ "5890641690414422fbeabbe7",
+ "585203546789802282f2aaf5",
+ ]
+
+ # Validation set sequences after filtering
+ self.val_split_scenes = [
+ "00000000000000000000000a",
+ "5a4a38dad38c8a075495b5d2",
+ "5a489fb1c7dab83a7d7b1070",
+ "5a572fd9fc597b0478a81d14",
+ "5a588a8193ac3d233f77fbca",
+ "5aa0f478a9efce63548c1cb4",
+ "5ae2e9c5fe405c5076abc6b2",
+ "5b2c67b5e0878c381608b8d8",
+ "5b21e18c58e2823a67a10dd8",
+ "5b864d850d072a699b32f4ae",
+ "5b4933abf2b5f44e95de482a",
+ "5b37189a35304b6f75e7583e",
+ "5bc5f0e896b66a2cd8f9bd36",
+ "5bccd6beca24970bce448134",
+ "5bf26cbbd43923194854b270",
+ "5bf18642c50e6f7f8bdbd492",
+ "5bfc9d5aec61ca1dd69132a2",
+ "5bff3c5cfe0ea555e6bcbf3a",
+ "5c1f33f1d33e1f2e4aa6dda4",
+ "5c34529873a8df509ae57b58",
+ "58a186444a4d262a170ae3ae",
+ "58f7f7299f5b5647873cb110",
+ "59acd2f4b891807f439c8992",
+ "567a0fb0a825d2fb79ac9a20",
+ "584ad76bfe3cb463906ce6dc",
+ "584c58b77072670e72c03990",
+ "586c48329d1b5e34c2838e80",
+ "586df9849d1b5e34c28506de",
+ "588e0d8c90414422fbe8f8b2",
+ "589c300f7dc3d323d5577926",
+ "590f91851225725be9e25d4e",
+ "5889e344ec4d5a1c088e59be",
+ "5898b31cc9dccc22987b82ec",
+ "5947b62af1b45630bd0c2a02",
+ "58598db2b338a62ad500bc38",
+ "58669c02712e27614692851a",
+ "58790c82ce911104a3467c88",
+ "58897f62c02346100f4b8ee6",
+ "588305ed0db9bf59bf8a8c80",
+ "588457b8932ba84fbed69942",
+ "5862388b712e2761468f84aa",
+ "5880675a2366dd5d06e570ca",
+ "5890523090414422fbeab3f0",
+ ]
+
+
+class TartanAirV2Splits:
+ """
+ This class contains the information about the splits of the TartanAir V2 dataset.
+ """
+
+ def __init__(self):
+ """
+ Splits of environments with unique geometry selected based on TartanVO & UFM splits.
+ """
+ # Apart from the below 2 splits, all other TAv2 scenes are in the train split
+ # Val split
+ self.val_split_scenes = ["EndofTheWorld", "HongKong", "WesternDesertTown"]
+
+ # Test split
+ self.test_split_scenes = [
+ "DesertGasStation",
+ "OldScandinavia",
+ "PolarSciFi",
+ "Sewerage",
+ "Supermarket",
+ ]
+
+
+class MegaDepthSplits:
+ """
+ This class contains the information about the splits of the MegaDepth dataset.
+ """
+
+ def __init__(self):
+ """
+ Validation split is based on scenes used in DUSt3R.
+ """
+ self.val_split_scenes = ["0015_0", "0015_1", "0022_0"]
+
+
+class SpringSplits:
+ """
+ This class contains the information about the splits of the Spring dataset.
+ """
+
+ def __init__(self):
+ self.val_split_scenes = ["0013", "0023", "0037"]
+
+
+class MPSDSplits:
+ """
+ This class contains the information about the splits of the MPSD dataset.
+ """
+
+ def __init__(self):
+ """
+ Train & Validation split numpy files containing folder names are generated during preprocessing of MPSD dataset.
+ Load the numpy files to get the list of scenes in the train & validation split.
+ A 95% (Train) & 5% (Validation) split is used.
+ """
+ self.train_split_scenes = "load_numpy_file_with_train_scenes"
+ self.val_split_scenes = "load_numpy_file_with_val_scenes"
+
+
+class ScanNetPPSplits:
+ """
+ This class contains the information about the splits of the ScanNetPP dataset.
+ """
+
+ def __init__(self):
+ """
+ Validation & Test split only contains scenes from ScanNet++V2 to prevent data leak with other methods such as DUSt3R during benchmarking.
+
+ Following logic was used to generate the splits:
+ # Select 80%, 10%, 10% of the scenes for train, val, test respectively from ScanNet++ V2 (~300 scene subset; excluding V1 scenes)
+ snpp_v2_test_scenes = np.random.choice(
+ snpp_v2_processed_scenes, size=int(0.1 * len(snpp_v2_processed_scenes)), replace=False
+ )
+ remaining_scenes = [scene for scene in snpp_v2_processed_scenes if scene not in snpp_v2_test_scenes]
+ snpp_v2_val_scenes = np.random.choice(
+ remaining_scenes, size=int(0.1 * len(snpp_v2_processed_scenes)), replace=False
+ )
+ snpp_v2_train_scenes = [
+ scene for scene in remaining_scenes if scene not in snpp_v2_val_scenes and scene not in snpp_v2_test_scenes
+ ]
+ """
+ # Validation Scenes
+ self.val_split_scenes = [
+ "1c7a683c92",
+ "2a1b555966",
+ "3a43c7b8d2",
+ "4aef651da7",
+ "06bc6d1b24",
+ "7f22d5ef1b",
+ "7f77abce34",
+ "8ea517a2fc",
+ "29c7afafed",
+ "41eb967018",
+ "77b40ce601",
+ "086f09d6e3",
+ "307e3262f1",
+ "639f2c4d5a",
+ "894dbd41f1",
+ "898a7dfd0c",
+ "2779f8f9e2",
+ "151178afd7",
+ "182932a4f3",
+ "635852d56e",
+ "9906136b57",
+ "af112b8903",
+ "b0f057c684",
+ "b37177e6c8",
+ "b119249da7",
+ "be8367fcbe",
+ "c8fc01c453",
+ "e1fb8626c8",
+ "e2caaaf5b5",
+ "fe3fc057a1",
+ ]
+
+ # Test Scenes
+ self.test_split_scenes = [
+ "0e900bcc5c",
+ "0eba3981c9",
+ "1cbb105c6a",
+ "3c8d535d49",
+ "5d902f1593",
+ "6bd39ac392",
+ "6c14d5fd01",
+ "7c31a42404",
+ "9bfbc75700",
+ "13b4efaf62",
+ "062e5a23a6",
+ "95b9971d01",
+ "246fe09e98",
+ "637a27d04b",
+ "725b8f0cba",
+ "413085a827",
+ "696317583f",
+ "a4c043ac48",
+ "a9e4791c7e",
+ "b0b004c40f",
+ "c3bc5e82c5",
+ "c31ebd4b22",
+ "cba701332a",
+ "cc5ea8026c",
+ "cec8312f4e",
+ "e3b3b0d0c7",
+ "e667e09fe6",
+ "eaa6c90310",
+ "f9397af4cb",
+ "fb893ffaf3",
+ ]
+
+
+class DL3DV10KSplits:
+ """
+ This class contains the information about the splits of the DL3DV-10K dataset.
+ We use the official benchmark split as the val split.
+ """
+
+ def __init__(self):
+ """
+ Validation split is based on DL3DV-Benchmark.
+ """
+ self.val_split_scenes = [
+ "load https://huggingface.co/datasets/DL3DV/DL3DV-Benchmark/raw/main/benchmark-meta.csv \
+ & https://raw.githubusercontent.com/DL3DV-10K/Dataset/main/cache/DL3DV-valid.csv"
+ ]
+
+
+class DTUSplits:
+ """
+ This class contains the information about the splits of the DTU dataset.
+ """
+
+ def __init__(self):
+ """
+ All scenes are in the test split.
+ """
+ self.test_split_scenes = "all"
+
+
+class ETH3DSplits:
+ """
+ This class contains the information about the splits of the ETH3D dataset.
+ """
+
+ def __init__(self):
+ """
+ All scenes are in the test split.
+ """
+ self.test_split_scenes = "all"
diff --git a/mapanything/datasets/wai/__init__.py b/mapanything/datasets/wai/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/datasets/wai/ase.py b/mapanything/datasets/wai/ase.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3bf49ffe6fc61323265178a583c3585d53824b2
--- /dev/null
+++ b/mapanything/datasets/wai/ase.py
@@ -0,0 +1,289 @@
+"""
+ASE Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class ASEWAI(BaseDataset):
+ """
+ ASE dataset containing large diversity of synthetic indoor scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"ase_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="ASE",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/ase", type=str)
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = ASEWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 518),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = ASEWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 518),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "ASE_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/bedlam.py b/mapanything/datasets/wai/bedlam.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4daa56470f73d8e67f9b6a4c5ba3fd822ef9ff1
--- /dev/null
+++ b/mapanything/datasets/wai/bedlam.py
@@ -0,0 +1,309 @@
+"""
+Bedlam Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class BedlamWAI(BaseDataset):
+ """
+ Bedlam dataset containing diverse synthetic scenes with humans.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"bedlam_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ # Bedlam scenes have very large number of images
+ # Thus, we use unidirectional covis for faster access
+ view_indices = self._sample_view_indices(
+ num_views_to_sample,
+ num_views_in_scene,
+ pairwise_covisibility,
+ use_bidirectional_covis=False,
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = image[:, :, :3] # RGBA to RGB
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (see through window or horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="Bedlam",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/bedlam", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = BedlamWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = BedlamWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "Bedlam_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/blendedmvs.py b/mapanything/datasets/wai/blendedmvs.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ef278842807b667f478ba4b1f013bb702173632
--- /dev/null
+++ b/mapanything/datasets/wai/blendedmvs.py
@@ -0,0 +1,308 @@
+"""
+BlendedMVS Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class BlendedMVSWAI(BaseDataset):
+ """
+ BlendedMVS dataset containing object-centric and birds-eye-view scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = False
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"blendedmvs_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth", "pred_mask/moge2"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non_ambiguous_mask and ensure it matches image resolution
+ non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+ non_ambiguous_mask = cv2.resize(
+ non_ambiguous_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Mask out the GT depth using the non_ambiguous_mask
+ depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="BlendedMVS",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/blendedmvs", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = BlendedMVSWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 392),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = BlendedMVSWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 392),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "BlendedMVS_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=10, replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/dl3dv.py b/mapanything/datasets/wai/dl3dv.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab9bf341a231e77d6a3dde7b7c3e203638b223a7
--- /dev/null
+++ b/mapanything/datasets/wai/dl3dv.py
@@ -0,0 +1,351 @@
+"""
+DL3DV Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.cropping import (
+ rescale_image_and_other_optional_info,
+ resize_with_nearest_interpolation_to_match_aspect_ratio,
+)
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class DL3DVWAI(BaseDataset):
+ """
+ DL3DV dataset containing over 10k in-the-wild and indoor scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ mvs_confidence_filter_thres: float = 0.25,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ mvs_confidence_filter_thres: Confidence threshold to filter MVS depth. Defaults to 0.25.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self.mvs_confidence_filter_thres = mvs_confidence_filter_thres
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = False
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"dl3dv_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0_mvsa_based"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=[
+ "image",
+ "pred_depth/mvsanywhere",
+ "pred_mask/moge2",
+ "depth_confidence/mvsanywhere",
+ ],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["pred_depth/mvsanywhere"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the dimensions of the original image
+ img_h, img_w = image.shape[:2]
+
+ # Resize depth to match image aspect ratio while ensuring that depth resolution doesn't increase
+ depthmap, target_depth_h, target_depth_w = (
+ resize_with_nearest_interpolation_to_match_aspect_ratio(
+ input_data=depthmap, img_h=img_h, img_w=img_w
+ )
+ )
+
+ # Now resize the image and update intrinsics to match the resized depth
+ image, _, intrinsics, _ = rescale_image_and_other_optional_info(
+ image=image,
+ output_resolution=(target_depth_w, target_depth_h),
+ depthmap=None,
+ camera_intrinsics=intrinsics,
+ )
+ image = np.array(image)
+
+ # Get the depth confidence map and mask out the MVS depth
+ confidence_map = view_data["depth_confidence/mvsanywhere"].numpy()
+ confidence_mask = (
+ confidence_map > self.mvs_confidence_filter_thres
+ ).astype(int)
+ confidence_mask = cv2.resize(
+ confidence_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ depthmap = np.where(confidence_mask, depthmap, 0)
+
+ # Get the non_ambiguous_mask and ensure it matches image resolution
+ non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+ non_ambiguous_mask = cv2.resize(
+ non_ambiguous_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Mask out the GT depth using the non_ambiguous_mask
+ depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="DL3DV",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/dl3dv", type=str)
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = DL3DVWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ mvs_confidence_filter_thres=0.25,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = DL3DVWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # mvs_confidence_filter_thres=0.25,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "DL3DV_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/dtu.py b/mapanything/datasets/wai/dtu.py
new file mode 100644
index 0000000000000000000000000000000000000000..fde4a666752c36a2668913c010d42fb93d62dd3b
--- /dev/null
+++ b/mapanything/datasets/wai/dtu.py
@@ -0,0 +1,272 @@
+"""
+DTU Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class DTUWAI(BaseDataset):
+ """
+ DTU dataset containing high-quality multi-view stereo object scans.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = "test"
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = False
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"dtu_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="DTU",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/dtu", type=str)
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = DTUWAI(
+ num_views=args.num_of_views,
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 392),
+ seed=777,
+ transform="imgnorm",
+ data_norm_type="dinov2",
+ )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "DTU_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/dynamicreplica.py b/mapanything/datasets/wai/dynamicreplica.py
new file mode 100644
index 0000000000000000000000000000000000000000..20d1ce109631f394c6dd5469f2439b0de5e52e97
--- /dev/null
+++ b/mapanything/datasets/wai/dynamicreplica.py
@@ -0,0 +1,292 @@
+"""
+Dynamic Replica Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class DynamicReplicaWAI(BaseDataset):
+ """
+ Dynamic Replica dataset containing synthetic scenes with humans and animals.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"dynamicreplica_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = image[:, :, :3] # RGBA to RGB
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="DynamicReplica",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/dynamicreplica", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = DynamicReplicaWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = DynamicReplicaWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "DynamicReplica_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/eth3d.py b/mapanything/datasets/wai/eth3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..51a8cebe347aeb18031379f31fff4f6581f9802b
--- /dev/null
+++ b/mapanything/datasets/wai/eth3d.py
@@ -0,0 +1,272 @@
+"""
+ETH3D Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class ETH3DWAI(BaseDataset):
+ """
+ ETH3D dataset containing high-quality outdoor and indoor scans of the ETH Zurich campus.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = "test"
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"eth3d_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="ETH3D",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/eth3d", type=str)
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = ETH3DWAI(
+ num_views=args.num_of_views,
+ covisibility_thres=0.025,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 336),
+ seed=777,
+ transform="imgnorm",
+ data_norm_type="dinov2",
+ )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "ETH3D_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/gta_sfm.py b/mapanything/datasets/wai/gta_sfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..06122b8f305dd1a63dfb7f992914fb334aa63f56
--- /dev/null
+++ b/mapanything/datasets/wai/gta_sfm.py
@@ -0,0 +1,303 @@
+"""
+GTA SfM Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class GTASfMWAI(BaseDataset):
+ """
+ GTA SfM dataset containing large diversity of synthetic in-the-wild scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"gta_sfm_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="GTASfM",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/gta_sfm", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = GTASfMWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 392),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = GTASfMWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 392),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "GTASfM_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/matrixcity.py b/mapanything/datasets/wai/matrixcity.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff1b9262d81c5e79b11cc9720bb3750b89c475c5
--- /dev/null
+++ b/mapanything/datasets/wai/matrixcity.py
@@ -0,0 +1,307 @@
+"""
+Matrix City Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class MatrixCityWAI(BaseDataset):
+ """
+ Matrix City dataset containing large scale aerial & street-view urban synthetic scenes.
+ Depth maps are antialiased and there are floaters at all object boundaries due to interpolation.
+ https://github.com/city-super/MatrixCity/issues/4#issuecomment-3027961575
+ Normal based edge masking doesn't fix this issue completely.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"matrixcity_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = image[:, :, :3] # RGBA to RGB
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="MatrixCity",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/matrixcity", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = MatrixCityWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = MatrixCityWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "MatrixCity_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/megadepth.py b/mapanything/datasets/wai/megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..deb87f45872231ad7088fa781a22fd72bed8a084
--- /dev/null
+++ b/mapanything/datasets/wai/megadepth.py
@@ -0,0 +1,309 @@
+"""
+MegaDepth Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class MegaDepthWAI(BaseDataset):
+ """
+ MegaDepth dataset containing outdoor phototourism and in-the-wild scenes.
+ Also includes Tanks & Temples scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = False
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"megadepth_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth", "pred_mask/moge2"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non_ambiguous_mask and ensure it matches image resolution
+ non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+ non_ambiguous_mask = cv2.resize(
+ non_ambiguous_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Mask out the GT depth using the non_ambiguous_mask
+ depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="MegaDepth",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/megadepth", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = MegaDepthWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 336),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = MegaDepthWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 336),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "MegaDepth_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/mpsd.py b/mapanything/datasets/wai/mpsd.py
new file mode 100644
index 0000000000000000000000000000000000000000..efba1a506fdd9a770d4e94d0ebacd5724ee77ff4
--- /dev/null
+++ b/mapanything/datasets/wai/mpsd.py
@@ -0,0 +1,306 @@
+"""
+MPSD Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class MPSDWAI(BaseDataset):
+ """
+ MPSD dataset containing outdoor planet scale metric reconstructions.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"mpsd_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth", "pred_mask/moge2"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non_ambiguous_mask and ensure it matches image resolution
+ non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+ non_ambiguous_mask = cv2.resize(
+ non_ambiguous_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Mask out the GT depth using the non_ambiguous_mask
+ depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="MPSD",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-rd", "--root_dir", default="/fsx/xrtech/data/mpsd", type=str)
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = MPSDWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.15,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 392),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = MPSDWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.15,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 392),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "MPSD_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/mvs_synth.py b/mapanything/datasets/wai/mvs_synth.py
new file mode 100644
index 0000000000000000000000000000000000000000..ceb6e9c8aa5d430a968626ed67a2abf6cb6ae6b6
--- /dev/null
+++ b/mapanything/datasets/wai/mvs_synth.py
@@ -0,0 +1,303 @@
+"""
+MVS Synth Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class MVSSynthWAI(BaseDataset):
+ """
+ MVS Synth dataset containing large diversity of synthetic in-the-wild scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"mvs_synth_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="MVSSynth",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/mvs_synth", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = MVSSynthWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = MVSSynthWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "MVSSynth_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/paralleldomain4d.py b/mapanything/datasets/wai/paralleldomain4d.py
new file mode 100644
index 0000000000000000000000000000000000000000..091a53e83ffa742347f075993afaa63fa3b194be
--- /dev/null
+++ b/mapanything/datasets/wai/paralleldomain4d.py
@@ -0,0 +1,304 @@
+"""
+Parallel Domain 4D Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class ParallelDomain4DWAI(BaseDataset):
+ """
+ Parallel Domain 4D dataset containing large diversity of synthetic AV scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"paralleldomain4d_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = image[:, :, :3] # RGBA to RGB
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="ParallelDomain4D",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/paralleldomain4d", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = ParallelDomain4DWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 392),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = ParallelDomain4DWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 392),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "ParallelDomain4D_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/sailvos3d.py b/mapanything/datasets/wai/sailvos3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..eff48a6b3a2b5190b4e61386d58ad056bca504b4
--- /dev/null
+++ b/mapanything/datasets/wai/sailvos3d.py
@@ -0,0 +1,303 @@
+"""
+SAIL-VOS 3D Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class SAILVOS3DWAI(BaseDataset):
+ """
+ SAIL-VOS 3D dataset containing large diversity of synthetic in-the-wild cut scenes from GTA.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"sailvos3d_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="SAILVOS3D",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/sailvos3d", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = SAILVOS3DWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 336),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = SAILVOS3DWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 336),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "SAILVOS3D_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/scannetpp.py b/mapanything/datasets/wai/scannetpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ecde6dfe22c08162cbe9ffb1e53c0315187c2da
--- /dev/null
+++ b/mapanything/datasets/wai/scannetpp.py
@@ -0,0 +1,302 @@
+"""
+ScanNet++V2 Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class ScanNetPPWAI(BaseDataset):
+ """
+ ScanNet++V2 dataset containing large diversity of indoor scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = False
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"scannetppv2_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "rendered_depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["rendered_depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="ScanNetPP",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/scannetppv2", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = ScanNetPPWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 336),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = ScanNetPPWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 336),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ # dataset = ScanNetPPWAI(
+ # num_views=args.num_of_views,
+ # split="test",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 336),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "ScanNetPP_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=10, replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/spring.py b/mapanything/datasets/wai/spring.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bbfd375eb03ab662e838505e5159dbf64d8eef2
--- /dev/null
+++ b/mapanything/datasets/wai/spring.py
@@ -0,0 +1,311 @@
+"""
+Spring Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class SpringWAI(BaseDataset):
+ """
+ Spring dataset containing high-quality large-scale in-the-wild scenes with unique animated objects.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"spring_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ ) # Assumes only npy file in directory is covisbility map
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth", "skymask", "pred_mask/moge2"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Get the sky mask and mask out GT depth
+ sky_mask = view_data["skymask"].numpy().astype(int)
+ depthmap = np.where(sky_mask, 0, depthmap)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non_ambiguous_mask and ensure it matches image resolution
+ non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+ non_ambiguous_mask = cv2.resize(
+ non_ambiguous_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Mask out the GT depth using the non_ambiguous_mask
+ depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="Spring",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/spring", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = SpringWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = SpringWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "Spring_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=10, replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/structured3d.py b/mapanything/datasets/wai/structured3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8b54723714b4f3da530a236684b45cc8e6f7bf7
--- /dev/null
+++ b/mapanything/datasets/wai/structured3d.py
@@ -0,0 +1,292 @@
+"""
+Structured3D Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class Structured3DWAI(BaseDataset):
+ """
+ Structured3D dataset containing large diversity of synthetic multi-room indoor scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"structured3d_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = image[:, :, :3] # RGBA to RGB
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="Structured3D",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/structured3d", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = Structured3DWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = Structured3DWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "Structured3D_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/tav2_wb.py b/mapanything/datasets/wai/tav2_wb.py
new file mode 100644
index 0000000000000000000000000000000000000000..f19bc5c00de55b7b16eb53252260519448443dee
--- /dev/null
+++ b/mapanything/datasets/wai/tav2_wb.py
@@ -0,0 +1,323 @@
+"""
+TartanAirV2-WB Dataset using WAI format data.
+"""
+
+import os
+
+import cv2
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class TartanAirV2WBWAI(BaseDataset):
+ """
+ TartanAirV2-WB dataset containing vastly-sized in-the-wild synthetic scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"tav2_wb_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth", "pred_mask/moge2"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Mask out the outlier depth caused due to transparent windows in TartanAirV2
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Get the non_ambiguous_mask and ensure it matches image resolution
+ non_ambiguous_mask = view_data["pred_mask/moge2"].numpy().astype(int)
+ non_ambiguous_mask = cv2.resize(
+ non_ambiguous_mask,
+ (image.shape[1], image.shape[0]),
+ interpolation=cv2.INTER_NEAREST,
+ )
+
+ # Mask out the GT depth using the non_ambiguous_mask
+ depthmap = np.where(non_ambiguous_mask, depthmap, 0)
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="TartanAirV2WB",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/tav2_wb", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = TartanAirV2WBWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 518),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = TartanAirV2WBWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 518),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ # dataset = TartanAirV2WBWAI(
+ # num_views=args.num_of_views,
+ # split="test",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 518),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "TartanAirV2WB_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/unrealstereo4k.py b/mapanything/datasets/wai/unrealstereo4k.py
new file mode 100644
index 0000000000000000000000000000000000000000..99b9d32a094b416509261450bc11058c634001a3
--- /dev/null
+++ b/mapanything/datasets/wai/unrealstereo4k.py
@@ -0,0 +1,304 @@
+"""
+UnrealStereo4K Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class UnrealStereo4KWAI(BaseDataset):
+ """
+ UnrealStereo4K dataset containing synthetic in-the-wild scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"unrealstereo4k_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = image[:, :, :3] # RGBA to RGB
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Get the non ambiguous mask (zero depth pixels are sky or ambiguous)
+ non_ambiguous_mask = (depthmap > 0).astype(int)
+
+ # Mask out the outlier depth (horizon depth)
+ percentile_depth = np.percentile(depthmap, 95)
+ depthmap[depthmap > percentile_depth] = 0
+
+ # Resize the data to match the desired resolution
+ additional_quantities_to_resize = [non_ambiguous_mask]
+ image, depthmap, intrinsics, additional_quantities_to_resize = (
+ self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=additional_quantities_to_resize,
+ )
+ )
+ non_ambiguous_mask = additional_quantities_to_resize[0]
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ non_ambiguous_mask=non_ambiguous_mask,
+ dataset="UnrealStereo4K",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/unrealstereo4k", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = UnrealStereo4KWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 294),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = UnrealStereo4KWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 294),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "UnrealStereo4K_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=len(dataset), replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/datasets/wai/xrooms.py b/mapanything/datasets/wai/xrooms.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8e31b8c59c02507dc1aa90b5b81d4883bf18a35
--- /dev/null
+++ b/mapanything/datasets/wai/xrooms.py
@@ -0,0 +1,300 @@
+"""
+XRooms Dataset using WAI format data.
+"""
+
+import os
+
+import numpy as np
+
+from mapanything.datasets.base.base_dataset import BaseDataset
+from mapanything.utils.wai.core import load_data, load_frame
+
+
+class XRoomsWAI(BaseDataset):
+ """
+ XRooms dataset containing large diversity of synthetic re-lightable indoor scenes.
+ """
+
+ def __init__(
+ self,
+ *args,
+ ROOT,
+ dataset_metadata_dir,
+ split,
+ overfit_num_sets=None,
+ sample_specific_scene: bool = False,
+ specific_scene_name: str = None,
+ **kwargs,
+ ):
+ """
+ Initialize the dataset attributes.
+ Args:
+ ROOT: Root directory of the dataset.
+ dataset_metadata_dir: Path to the dataset metadata directory.
+ split: Dataset split (train, val, test).
+ overfit_num_sets: If None, use all sets. Else, the dataset will be truncated to this number of sets.
+ sample_specific_scene: Whether to sample a specific scene from the dataset.
+ specific_scene_name: Name of the specific scene to sample.
+ """
+ # Initialize the dataset attributes
+ super().__init__(*args, **kwargs)
+ self.ROOT = ROOT
+ self.dataset_metadata_dir = dataset_metadata_dir
+ self.split = split
+ self.overfit_num_sets = overfit_num_sets
+ self.sample_specific_scene = sample_specific_scene
+ self.specific_scene_name = specific_scene_name
+ self._load_data()
+
+ # Define the dataset type flags
+ self.is_metric_scale = True
+ self.is_synthetic = True
+
+ def _load_data(self):
+ "Load the precomputed dataset metadata"
+ # Load the dataset metadata corresponding to the split
+ split_metadata_path = os.path.join(
+ self.dataset_metadata_dir,
+ self.split,
+ f"xrooms_scene_list_{self.split}.npy",
+ )
+ split_scene_list = np.load(split_metadata_path, allow_pickle=True)
+
+ # Get the list of all scenes
+ if not self.sample_specific_scene:
+ self.scenes = list(split_scene_list)
+ else:
+ self.scenes = [self.specific_scene_name]
+ self.num_of_scenes = len(self.scenes)
+
+ def _get_views(self, sampled_idx, num_views_to_sample, resolution):
+ # Get the scene name of the sampled index
+ scene_index = sampled_idx
+ scene_name = self.scenes[scene_index]
+
+ # Get the metadata corresponding to the scene
+ scene_root = os.path.join(self.ROOT, scene_name)
+ scene_meta = load_data(
+ os.path.join(scene_root, "scene_meta.json"), "scene_meta"
+ )
+ scene_file_names = list(scene_meta["frame_names"].keys())
+ num_views_in_scene = len(scene_file_names)
+
+ # Load the scene pairwise covisibility mmap
+ covisibility_version_key = "v0"
+ covisibility_map_dir = os.path.join(
+ scene_root, "covisibility", covisibility_version_key
+ )
+ # Assumes only npy file in directory is covisbility map
+ covisibility_map_name = next(
+ f for f in os.listdir(covisibility_map_dir) if f.endswith(".npy")
+ )
+ covisibility_map_path = os.path.join(
+ scene_root, "covisibility", covisibility_version_key, covisibility_map_name
+ )
+ pairwise_covisibility = load_data(covisibility_map_path, "mmap")
+
+ ### HOTFIX HACK for incompatible covisibility in a few scenes
+ ### TODO: Re-mine covisibility on errorenous scenes
+ if len(pairwise_covisibility) == num_views_in_scene:
+ # Get the indices of the N views in the scene
+ view_indices = self._sample_view_indices(
+ num_views_to_sample, num_views_in_scene, pairwise_covisibility
+ )
+ else:
+ # Get a random view index
+ view_indices = self._rng.choice(num_views_in_scene, size=1, replace=False)
+ # Repeat the view index to get the desired number of views
+ view_indices = np.repeat(view_indices, num_views_to_sample)
+ ### END HOTFIX HACK
+
+ # Get the views corresponding to the selected view indices
+ views = []
+ for view_index in view_indices:
+ # Load the data corresponding to the view
+ view_file_name = scene_file_names[view_index]
+ view_data = load_frame(
+ scene_root,
+ view_file_name,
+ modalities=["image", "depth"],
+ scene_meta=scene_meta,
+ )
+
+ # Convert necessary data to numpy
+ image = view_data["image"].permute(1, 2, 0).numpy()
+ image = (image * 255).astype(np.uint8)
+ depthmap = view_data["depth"].numpy().astype(np.float32)
+ intrinsics = view_data["intrinsics"].numpy().astype(np.float32)
+ c2w_pose = view_data["extrinsics"].numpy().astype(np.float32)
+
+ # Ensure that the depthmap has all valid values
+ depthmap = np.nan_to_num(depthmap, nan=0.0, posinf=0.0, neginf=0.0)
+
+ # Resize the data to match the desired resolution
+ image, depthmap, intrinsics = self._crop_resize_if_necessary(
+ image=image,
+ resolution=resolution,
+ depthmap=depthmap,
+ intrinsics=intrinsics,
+ additional_quantities=None,
+ )
+
+ # Append the view dictionary to the list of views
+ views.append(
+ dict(
+ img=image,
+ depthmap=depthmap,
+ camera_pose=c2w_pose, # cam2world
+ camera_intrinsics=intrinsics,
+ dataset="XRooms",
+ label=scene_name,
+ instance=os.path.join("images", str(view_file_name)),
+ )
+ )
+
+ return views
+
+
+def get_parser():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-rd", "--root_dir", default="/fsx/xrtech/data/xrooms", type=str
+ )
+ parser.add_argument(
+ "-dmd",
+ "--dataset_metadata_dir",
+ default="/fsx/nkeetha/mapanything_dataset_metadata",
+ type=str,
+ )
+ parser.add_argument(
+ "-nv",
+ "--num_of_views",
+ default=2,
+ type=int,
+ )
+ parser.add_argument("--viz", action="store_true")
+
+ return parser
+
+
+if __name__ == "__main__":
+ import rerun as rr
+ from tqdm import tqdm
+
+ from mapanything.datasets.base.base_dataset import view_name
+ from mapanything.utils.image import rgb
+ from mapanything.utils.viz import script_add_rerun_args
+
+ parser = get_parser()
+ script_add_rerun_args(
+ parser
+ ) # Options: --headless, --connect, --serve, --addr, --save, --stdout
+ args = parser.parse_args()
+
+ dataset = XRoomsWAI(
+ num_views=args.num_of_views,
+ split="train",
+ covisibility_thres=0.25,
+ ROOT=args.root_dir,
+ dataset_metadata_dir=args.dataset_metadata_dir,
+ resolution=(518, 518),
+ aug_crop=16,
+ transform="colorjitter+grayscale+gaublur",
+ data_norm_type="dinov2",
+ )
+ # dataset = XRoomsWAI(
+ # num_views=args.num_of_views,
+ # split="val",
+ # covisibility_thres=0.25,
+ # ROOT=args.root_dir,
+ # dataset_metadata_dir=args.dataset_metadata_dir,
+ # resolution=(518, 518),
+ # seed=777,
+ # transform="imgnorm",
+ # data_norm_type="dinov2",
+ # )
+ print(dataset.get_stats())
+
+ if args.viz:
+ rr.script_setup(args, "XRooms_Dataloader")
+ rr.set_time("stable_time", sequence=0)
+ rr.log("world", rr.ViewCoordinates.RDF, static=True)
+
+ sampled_indices = np.random.choice(len(dataset), size=10, replace=False)
+
+ for num, idx in enumerate(tqdm(sampled_indices)):
+ views = dataset[idx]
+ assert len(views) == args.num_of_views
+ sample_name = f"{idx}"
+ for view_idx in range(args.num_of_views):
+ sample_name += f" {view_name(views[view_idx])}"
+ print(sample_name)
+ for view_idx in range(args.num_of_views):
+ image = rgb(
+ views[view_idx]["img"], norm_type=views[view_idx]["data_norm_type"]
+ )
+ depthmap = views[view_idx]["depthmap"]
+ pose = views[view_idx]["camera_pose"]
+ intrinsics = views[view_idx]["camera_intrinsics"]
+ pts3d = views[view_idx]["pts3d"]
+ valid_mask = views[view_idx]["valid_mask"]
+ if "non_ambiguous_mask" in views[view_idx]:
+ non_ambiguous_mask = views[view_idx]["non_ambiguous_mask"]
+ else:
+ non_ambiguous_mask = None
+ if "prior_depth_along_ray" in views[view_idx]:
+ prior_depth_along_ray = views[view_idx]["prior_depth_along_ray"]
+ else:
+ prior_depth_along_ray = None
+ if args.viz:
+ rr.set_time("stable_time", sequence=num)
+ base_name = f"world/view_{view_idx}"
+ pts_name = f"world/view_{view_idx}_pointcloud"
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if prior_depth_along_ray is not None:
+ rr.log(
+ f"prior_depth_along_ray_{view_idx}",
+ rr.DepthImage(prior_depth_along_ray),
+ )
+ if non_ambiguous_mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/non_ambiguous_mask",
+ rr.SegmentationImage(non_ambiguous_mask.astype(int)),
+ )
+ # Log points in 3D
+ filtered_pts = pts3d[valid_mask]
+ filtered_pts_col = image[valid_mask]
+ rr.log(
+ pts_name,
+ rr.Points3D(
+ positions=filtered_pts.reshape(-1, 3),
+ colors=filtered_pts_col.reshape(-1, 3),
+ ),
+ )
diff --git a/mapanything/models/__init__.py b/mapanything/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1b9de65a1cb8ff5cb15fab527bc5a7194896f8a
--- /dev/null
+++ b/mapanything/models/__init__.py
@@ -0,0 +1,185 @@
+"""
+Model Factory for MapAnything
+"""
+
+import importlib.util
+import logging
+import warnings
+
+import numpy as np
+from omegaconf import DictConfig, OmegaConf
+
+# Core models that are always available
+from mapanything.models.mapanything import (
+ MapAnything,
+ MapAnythingAblations,
+ ModularDUSt3R,
+)
+
+# Suppress DINOv2 warnings
+logging.getLogger("dinov2").setLevel(logging.WARNING)
+warnings.filterwarnings("ignore", message="xFormers is available", category=UserWarning)
+warnings.filterwarnings(
+ "ignore", message="xFormers is not available", category=UserWarning
+)
+
+
+def resolve_special_float(value):
+ if value == "inf":
+ return np.inf
+ elif value == "-inf":
+ return -np.inf
+ else:
+ raise ValueError(f"Unknown special float value: {value}")
+
+
+def init_model(
+ model_str: str, model_config: DictConfig, torch_hub_force_reload: bool = False
+):
+ """
+ Initialize a model using OmegaConf configuration.
+
+ Args:
+ model_str (str): Name of the model class to create.
+ model_config (DictConfig): OmegaConf model configuration.
+ torch_hub_force_reload (bool): Whether to force reload relevant parts of the model from torch hub.
+ """
+ if not OmegaConf.has_resolver("special_float"):
+ OmegaConf.register_new_resolver("special_float", resolve_special_float)
+ model_dict = OmegaConf.to_container(model_config, resolve=True)
+ model = model_factory(
+ model_str, torch_hub_force_reload=torch_hub_force_reload, **model_dict
+ )
+
+ return model
+
+
+# Define model configurations with import paths
+MODEL_CONFIGS = {
+ # Core models
+ "mapanything": {
+ "class": MapAnything,
+ },
+ "mapanything_ablations": {
+ "class": MapAnythingAblations,
+ },
+ "modular_dust3r": {
+ "class": ModularDUSt3R,
+ },
+ # External models
+ "anycalib": {
+ "module": "mapanything.models.external.anycalib",
+ "class_name": "AnyCalibWrapper",
+ },
+ "dust3r": {
+ "module": "mapanything.models.external.dust3r",
+ "class_name": "DUSt3RBAWrapper",
+ },
+ "mast3r": {
+ "module": "mapanything.models.external.mast3r",
+ "class_name": "MASt3RSGAWrapper",
+ },
+ "moge": {
+ "module": "mapanything.models.external.moge",
+ "class_name": "MoGeWrapper",
+ },
+ "must3r": {
+ "module": "mapanything.models.external.must3r",
+ "class_name": "MUSt3RWrapper",
+ },
+ "pi3": {
+ "module": "mapanything.models.external.pi3",
+ "class_name": "Pi3Wrapper",
+ },
+ "pow3r": {
+ "module": "mapanything.models.external.pow3r",
+ "class_name": "Pow3RWrapper",
+ },
+ "pow3r_ba": {
+ "module": "mapanything.models.external.pow3r",
+ "class_name": "Pow3RBAWrapper",
+ },
+ "vggt": {
+ "module": "mapanything.models.external.vggt",
+ "class_name": "VGGTWrapper",
+ },
+ # Add other model classes here
+}
+
+
+def check_module_exists(module_path):
+ """
+ Check if a module can be imported without actually importing it.
+
+ Args:
+ module_path (str): The path to the module to check.
+
+ Returns:
+ bool: True if the module can be imported, False otherwise.
+ """
+ return importlib.util.find_spec(module_path) is not None
+
+
+def model_factory(model_str: str, **kwargs):
+ """
+ Model factory for MapAnything.
+
+ Args:
+ model_str (str): Name of the model to create.
+ **kwargs: Additional keyword arguments to pass to the model constructor.
+
+ Returns:
+ nn.Module: An instance of the specified model.
+ """
+ if model_str not in MODEL_CONFIGS:
+ raise ValueError(
+ f"Unknown model: {model_str}. Valid options are: {', '.join(MODEL_CONFIGS.keys())}"
+ )
+
+ model_config = MODEL_CONFIGS[model_str]
+
+ # Handle core models directly
+ if "class" in model_config:
+ model_class = model_config["class"]
+ # Handle external models with dynamic imports
+ elif "module" in model_config:
+ module_path = model_config["module"]
+ class_name = model_config["class_name"]
+
+ # Check if the module can be imported
+ if not check_module_exists(module_path):
+ raise ImportError(
+ f"Model '{model_str}' requires module '{module_path}' which is not installed. "
+ f"Please install the corresponding submodule or package."
+ )
+
+ # Dynamically import the module and get the class
+ try:
+ module = importlib.import_module(module_path)
+ model_class = getattr(module, class_name)
+ except (ImportError, AttributeError) as e:
+ raise ImportError(
+ f"Failed to import {class_name} from {module_path}: {str(e)}"
+ )
+ else:
+ raise ValueError(f"Invalid model configuration for {model_str}")
+
+ print(f"Initializing {model_class} with kwargs: {kwargs}")
+ if model_str != "org_dust3r":
+ return model_class(**kwargs)
+ else:
+ eval_str = kwargs.get("model_eval_str", None)
+ return eval(eval_str)
+
+
+def get_available_models() -> list:
+ """
+ Get a list of available models in MapAnything.
+
+ Returns:
+ list: A list of available model names.
+ """
+ return list(MODEL_CONFIGS.keys())
+
+
+__all__ = ["model_factory", "get_available_models"]
diff --git a/mapanything/models/mapanything/__init__.py b/mapanything/models/mapanything/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..13e6352b2a6e6a5374d5e8a079ef65fc503816e8
--- /dev/null
+++ b/mapanything/models/mapanything/__init__.py
@@ -0,0 +1,9 @@
+from mapanything.models.mapanything.ablations import MapAnythingAblations
+from mapanything.models.mapanything.model import MapAnything
+from mapanything.models.mapanything.modular_dust3r import ModularDUSt3R
+
+__all__ = [
+ "MapAnything",
+ "MapAnythingAblations",
+ "ModularDUSt3R",
+]
diff --git a/mapanything/models/mapanything/ablations.py b/mapanything/models/mapanything/ablations.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bf4cb633d30b09815157d17087cc3eca93b39a4
--- /dev/null
+++ b/mapanything/models/mapanything/ablations.py
@@ -0,0 +1,1655 @@
+"""
+MapAnything Ablation model classes defined using UniCeption modules.
+"""
+
+from functools import partial
+from typing import Callable, Dict, Type, Union
+
+import torch
+import torch.nn as nn
+
+from mapanything.utils.geometry import (
+ apply_log_to_norm,
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
+ normalize_depth_using_non_zero_pixels,
+ normalize_pose_translations,
+ transform_pose_using_quats_and_trans_2_to_1,
+)
+from uniception.models.encoders import (
+ encoder_factory,
+ EncoderGlobalRepInput,
+ ViTEncoderInput,
+ ViTEncoderNonImageInput,
+)
+from uniception.models.info_sharing.alternating_attention_transformer import (
+ MultiViewAlternatingAttentionTransformer,
+ MultiViewAlternatingAttentionTransformerIFR,
+)
+from uniception.models.info_sharing.base import MultiViewTransformerInput
+from uniception.models.info_sharing.cross_attention_transformer import (
+ MultiViewCrossAttentionTransformer,
+ MultiViewCrossAttentionTransformerIFR,
+)
+from uniception.models.info_sharing.global_attention_transformer import (
+ MultiViewGlobalAttentionTransformer,
+ MultiViewGlobalAttentionTransformerIFR,
+)
+from uniception.models.libs.croco.pos_embed import RoPE2D
+from uniception.models.prediction_heads.adaptors import (
+ CamTranslationPlusQuatsAdaptor,
+ PointMapAdaptor,
+ PointMapPlusRayDirectionsPlusDepthAdaptor,
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor,
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor,
+ PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor,
+ PointMapWithConfidenceAdaptor,
+ PointMapWithConfidenceAndMaskAdaptor,
+ PointMapWithMaskAdaptor,
+ RayDirectionsPlusDepthAdaptor,
+ RayDirectionsPlusDepthWithConfidenceAdaptor,
+ RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor,
+ RayDirectionsPlusDepthWithMaskAdaptor,
+ RayMapPlusDepthAdaptor,
+ RayMapPlusDepthWithConfidenceAdaptor,
+ RayMapPlusDepthWithConfidenceAndMaskAdaptor,
+ RayMapPlusDepthWithMaskAdaptor,
+)
+from uniception.models.prediction_heads.base import (
+ AdaptorInput,
+ PredictionHeadInput,
+ PredictionHeadLayeredInput,
+)
+from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
+from uniception.models.prediction_heads.linear import LinearFeature
+from uniception.models.prediction_heads.pose_head import PoseHead
+
+# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12)
+if hasattr(torch.backends.cuda, "matmul") and hasattr(
+ torch.backends.cuda.matmul, "allow_tf32"
+):
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+
+class MapAnythingAblations(nn.Module):
+ "Modular MapAnything Multi-View model class with no scale token."
+
+ def __init__(
+ self,
+ name: str,
+ encoder_config: Dict,
+ info_sharing_config: Dict,
+ pred_head_config: Dict,
+ geometric_input_config: Dict,
+ fusion_norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(
+ nn.LayerNorm, eps=1e-6
+ ),
+ pretrained_checkpoint_path: str = None,
+ load_specific_pretrained_submodules: bool = False,
+ specific_pretrained_submodules: list = [],
+ torch_hub_force_reload: bool = False,
+ ):
+ """
+ Multi-view model containing an image encoder followed by a multi-view attention transformer and respective downstream heads.
+ The goal is to output scene representation directly in view 0's frame.
+
+ Args:
+ name (str): Name of the model.
+ encoder_config (Dict): Configuration for the encoder.
+ info_sharing_config (Dict): Configuration for the multi-view attention transformer.
+ pred_head_config (Dict): Configuration for the prediction heads.
+ pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None)
+ load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False)
+ specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: [])
+ torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. (default: False)
+ """
+ super().__init__()
+
+ # Initalize the attributes
+ self.name = name
+ self.encoder_config = encoder_config
+ self.info_sharing_config = info_sharing_config
+ self.pred_head_config = pred_head_config
+ self.geometric_input_config = geometric_input_config
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
+ self.load_specific_pretrained_submodules = load_specific_pretrained_submodules
+ self.specific_pretrained_submodules = specific_pretrained_submodules
+ self.torch_hub_force_reload = torch_hub_force_reload
+ self.class_init_args = {
+ "name": self.name,
+ "encoder_config": self.encoder_config,
+ "info_sharing_config": self.info_sharing_config,
+ "pred_head_config": self.pred_head_config,
+ "geometric_input_config": self.geometric_input_config,
+ "pretrained_checkpoint_path": self.pretrained_checkpoint_path,
+ "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules,
+ "specific_pretrained_submodules": self.specific_pretrained_submodules,
+ "torch_hub_force_reload": self.torch_hub_force_reload,
+ }
+
+ # Get relevant parameters from the configs
+ self.info_sharing_type = info_sharing_config["model_type"]
+ self.info_sharing_return_type = info_sharing_config["model_return_type"]
+ self.pred_head_type = pred_head_config["type"]
+
+ # Initialize image encoder
+ if self.encoder_config["uses_torch_hub"]:
+ self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
+ # Create a copy of the config before deleting the key to preserve it for serialization
+ encoder_config_copy = self.encoder_config.copy()
+ del encoder_config_copy["uses_torch_hub"]
+ self.encoder = encoder_factory(**encoder_config_copy)
+
+ # Initialize the encoder for ray directions
+ ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"]
+ ray_dirs_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ ray_dirs_encoder_config["patch_size"] = self.encoder.patch_size
+ self.ray_dirs_encoder = encoder_factory(**ray_dirs_encoder_config)
+
+ # Initialize the encoder for depth (normalized per view and values after normalization are scaled logarithmically)
+ depth_encoder_config = self.geometric_input_config["depth_encoder_config"]
+ depth_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ depth_encoder_config["patch_size"] = self.encoder.patch_size
+ self.depth_encoder = encoder_factory(**depth_encoder_config)
+
+ # Initialize the encoder for log scale factor of depth
+ depth_scale_encoder_config = self.geometric_input_config["scale_encoder_config"]
+ depth_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.depth_scale_encoder = encoder_factory(**depth_scale_encoder_config)
+
+ # Initialize the encoder for camera rotation
+ cam_rot_encoder_config = self.geometric_input_config["cam_rot_encoder_config"]
+ cam_rot_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.cam_rot_encoder = encoder_factory(**cam_rot_encoder_config)
+
+ # Initialize the encoder for camera translation (normalized across all provided camera translations)
+ cam_trans_encoder_config = self.geometric_input_config[
+ "cam_trans_encoder_config"
+ ]
+ cam_trans_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.cam_trans_encoder = encoder_factory(**cam_trans_encoder_config)
+
+ # Initialize the encoder for log scale factor of camera translation
+ cam_trans_scale_encoder_config = self.geometric_input_config[
+ "scale_encoder_config"
+ ]
+ cam_trans_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.cam_trans_scale_encoder = encoder_factory(**cam_trans_scale_encoder_config)
+
+ # Initialize the fusion norm layer
+ self.fusion_norm_layer = fusion_norm_layer(self.encoder.enc_embed_dim)
+
+ # Initialize the info sharing module (Multi-View Transformer)
+ self._initialize_info_sharing(info_sharing_config)
+
+ # Initialize the prediction heads
+ self._initialize_prediction_heads(pred_head_config)
+
+ # Initialize the final adaptors
+ self._initialize_adaptors(pred_head_config)
+
+ # Load pretrained weights
+ self._load_pretrained_weights()
+
+ def _initialize_info_sharing(self, info_sharing_config):
+ """
+ Initialize the information sharing module based on the configuration.
+
+ This method sets up the custom positional encoding if specified and initializes
+ the appropriate multi-view transformer based on the configuration type.
+
+ Args:
+ info_sharing_config (Dict): Configuration for the multi-view attention transformer.
+ Should contain 'custom_positional_encoding', 'model_type', and 'model_return_type'.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If invalid configuration options are provided.
+ """
+ # Initialize Custom Positional Encoding if required
+ custom_positional_encoding = info_sharing_config["custom_positional_encoding"]
+ if custom_positional_encoding is not None:
+ if isinstance(custom_positional_encoding, str):
+ print(
+ f"Using custom positional encoding for multi-view attention transformer: {custom_positional_encoding}"
+ )
+ if custom_positional_encoding.startswith("RoPE"):
+ rope_freq = float(custom_positional_encoding[len("RoPE") :])
+ print(f"RoPE frequency: {rope_freq}")
+ self.custom_positional_encoding = RoPE2D(freq=rope_freq)
+ else:
+ raise ValueError(
+ f"Invalid custom_positional_encoding: {custom_positional_encoding}."
+ )
+ elif isinstance(custom_positional_encoding, Callable):
+ print(
+ "Using callable function as custom positional encoding for multi-view attention transformer."
+ )
+ self.custom_positional_encoding = custom_positional_encoding
+ else:
+ self.custom_positional_encoding = None
+
+ # Add dependecies to info_sharing_config
+ info_sharing_config["module_args"]["input_embed_dim"] = (
+ self.encoder.enc_embed_dim
+ )
+ info_sharing_config["module_args"]["custom_positional_encoding"] = (
+ self.custom_positional_encoding
+ )
+
+ # Initialize Multi-View Transformer
+ if self.info_sharing_return_type == "no_intermediate_features":
+ # Returns only normalized last layer features
+ # Intialize multi-view transformer based on type
+ if self.info_sharing_type == "cross_attention":
+ self.info_sharing = MultiViewCrossAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "global_attention":
+ self.info_sharing = MultiViewGlobalAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "alternating_attention":
+ self.info_sharing = MultiViewAlternatingAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+ )
+ elif self.info_sharing_return_type == "intermediate_features":
+ # Returns intermediate features and normalized last layer features
+ # Initialize mulit-view transformer based on type
+ if self.info_sharing_type == "cross_attention":
+ self.info_sharing = MultiViewCrossAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "global_attention":
+ self.info_sharing = MultiViewGlobalAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "alternating_attention":
+ self.info_sharing = MultiViewAlternatingAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+ )
+ # Assess if the DPT needs to use encoder features
+ if len(self.info_sharing.indices) == 2:
+ self.use_encoder_features_for_dpt = True
+ elif len(self.info_sharing.indices) == 3:
+ self.use_encoder_features_for_dpt = False
+ else:
+ raise ValueError(
+ "Invalid number of indices provided for info sharing feature returner. Please provide 2 or 3 indices."
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_return_type: {self.info_sharing_return_type}. Valid options: ['no_intermediate_features', 'intermediate_features']"
+ )
+
+ def _initialize_prediction_heads(self, pred_head_config):
+ """
+ Initialize the prediction heads based on the prediction head configuration.
+
+ This method configures and initializes the appropriate prediction heads based on the
+ specified prediction head type (linear, DPT, or DPT+pose). It sets up the necessary
+ dependencies and creates the required model components.
+
+ Args:
+ pred_head_config (Dict): Configuration for the prediction heads.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If an invalid pred_head_type is provided.
+ """
+ # Add dependencies to prediction head config
+ pred_head_config["feature_head"]["patch_size"] = self.encoder.patch_size
+ if self.pred_head_type == "linear":
+ pred_head_config["feature_head"]["input_feature_dim"] = (
+ self.info_sharing.dim
+ )
+ elif "dpt" in self.pred_head_type:
+ # Add dependencies for DPT & Regressor head
+ if self.use_encoder_features_for_dpt:
+ pred_head_config["feature_head"]["input_feature_dims"] = [
+ self.encoder.enc_embed_dim
+ ] + [self.info_sharing.dim] * 3
+ else:
+ pred_head_config["feature_head"]["input_feature_dims"] = [
+ self.info_sharing.dim
+ ] * 4
+ pred_head_config["regressor_head"]["input_feature_dim"] = pred_head_config[
+ "feature_head"
+ ]["feature_dim"]
+ # Add dependencies for Pose head if required
+ if "pose" in self.pred_head_type:
+ pred_head_config["pose_head"]["patch_size"] = self.encoder.patch_size
+ pred_head_config["pose_head"]["input_feature_dim"] = (
+ self.info_sharing.dim
+ )
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ # Initialize Prediction Heads
+ if self.pred_head_type == "linear":
+ # Initialize Dense Prediction Head for all views
+ self.dense_head = LinearFeature(**pred_head_config["feature_head"])
+ elif "dpt" in self.pred_head_type:
+ # Initialze Dense Predction Head for all views
+ self.dpt_feature_head = DPTFeature(**pred_head_config["feature_head"])
+ self.dpt_regressor_head = DPTRegressionProcessor(
+ **pred_head_config["regressor_head"]
+ )
+ self.dense_head = nn.Sequential(
+ self.dpt_feature_head, self.dpt_regressor_head
+ )
+ # Initialize Pose Head for all views if required
+ if "pose" in self.pred_head_type:
+ self.pose_head = PoseHead(**pred_head_config["pose_head"])
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ def _initialize_adaptors(self, pred_head_config):
+ """
+ Initialize the adaptors based on the prediction head configuration.
+
+ This method sets up the appropriate adaptors for different scene representation types,
+ such as pointmaps, ray maps with depth, or ray directions with depth and pose.
+
+ Args:
+ pred_head_config (Dict): Configuration for the prediction heads including adaptor type.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If an invalid adaptor_type is provided.
+ AssertionError: If ray directions + depth + pose is used with an incompatible head type.
+ """
+ if pred_head_config["adaptor_type"] == "pointmap":
+ self.dense_adaptor = PointMapAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "pointmap"
+ elif pred_head_config["adaptor_type"] == "pointmap+confidence":
+ self.dense_adaptor = PointMapWithConfidenceAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "pointmap+confidence"
+ elif pred_head_config["adaptor_type"] == "pointmap+mask":
+ self.dense_adaptor = PointMapWithMaskAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "pointmap+mask"
+ elif pred_head_config["adaptor_type"] == "pointmap+confidence+mask":
+ self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "pointmap+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "raymap+depth":
+ self.dense_adaptor = RayMapPlusDepthAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "raymap+depth"
+ elif pred_head_config["adaptor_type"] == "raymap+depth+confidence":
+ self.dense_adaptor = RayMapPlusDepthWithConfidenceAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "raymap+depth+confidence"
+ elif pred_head_config["adaptor_type"] == "raymap+depth+mask":
+ self.dense_adaptor = RayMapPlusDepthWithMaskAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "raymap+depth+mask"
+ elif pred_head_config["adaptor_type"] == "raymap+depth+confidence+mask":
+ self.dense_adaptor = RayMapPlusDepthWithConfidenceAndMaskAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "raymap+depth+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose+confidence"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthWithMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose+mask"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapAdaptor(**pred_head_config["dpt_adaptor"])
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapWithConfidenceAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose+confidence"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapWithMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose+mask"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose":
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose"
+ elif (
+ pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+confidence"
+ ):
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = (
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence"
+ elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose+mask"
+ elif (
+ pred_head_config["adaptor_type"]
+ == "pointmap+raydirs+depth+pose+confidence+mask"
+ ):
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = (
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence+mask"
+ else:
+ raise ValueError(
+ f"Invalid adaptor_type: {pred_head_config['adaptor_type']}. \
+ Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \
+ 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \
+ 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \
+ 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']"
+ )
+
+ def _load_pretrained_weights(self):
+ """
+ Load pretrained weights from a checkpoint file.
+
+ If load_specific_pretrained_submodules is True, only loads weights for the specified submodules.
+ Otherwise, loads all weights from the checkpoint.
+
+ Returns:
+ None
+ """
+ if self.pretrained_checkpoint_path is not None:
+ if not self.load_specific_pretrained_submodules:
+ print(
+ f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} ..."
+ )
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+ print(self.load_state_dict(ckpt["model"]))
+ else:
+ print(
+ f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} for specific submodules: {self.specific_pretrained_submodules} ..."
+ )
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+ filtered_ckpt = {}
+ for ckpt_key, ckpt_value in ckpt["model"].items():
+ for submodule in self.specific_pretrained_submodules:
+ if ckpt_key.startswith(submodule):
+ filtered_ckpt[ckpt_key] = ckpt_value
+ print(self.load_state_dict(filtered_ckpt, strict=False))
+
+ def _encode_n_views(self, views):
+ """
+ Encode all the input views (batch of images) in a single forward pass.
+ Assumes all the input views have the same image shape, batch size, and data normalization type.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+
+ Returns:
+ List[torch.Tensor]: A list containing the encoded features for all N views.
+ """
+ num_views = len(views)
+ data_norm_type = views[0]["data_norm_type"][0]
+ imgs_list = [view["img"] for view in views]
+ all_imgs_across_views = torch.cat(imgs_list, dim=0)
+ encoder_input = ViTEncoderInput(
+ image=all_imgs_across_views, data_norm_type=data_norm_type
+ )
+ encoder_output = self.encoder(encoder_input)
+ all_encoder_features_across_views = encoder_output.features.chunk(
+ num_views, dim=0
+ )
+
+ return all_encoder_features_across_views
+
+ def _compute_pose_quats_and_trans_for_across_views_in_ref_view(
+ self,
+ views,
+ num_views,
+ device,
+ dtype,
+ batch_size_per_view,
+ per_sample_cam_input_mask,
+ ):
+ """
+ Compute the pose quats and trans for all the views in the frame of the reference view 0.
+ Returns identity pose for views where the camera input mask is False or the pose is not provided.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ device (torch.device): Device to use for the computation.
+ dtype (torch.dtype): Data type to use for the computation.
+ per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4)
+ torch.Tensor: A tensor containing the pose trans for all the views in the frame of the reference view 0. (batch_size_per_view * view, 3)
+ torch.Tensor: A tensor containing the per sample camera input mask.
+ """
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ pose_quats_non_ref_views = []
+ pose_trans_non_ref_views = []
+ pose_quats_ref_view_0 = []
+ pose_trans_ref_view_0 = []
+ for view_idx in range(num_views):
+ per_sample_cam_input_mask_for_curr_view = per_sample_cam_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view
+ ]
+ if (
+ "camera_pose_quats" in views[view_idx]
+ and "camera_pose_trans" in views[view_idx]
+ and per_sample_cam_input_mask_for_curr_view.any()
+ ):
+ # Get the camera pose quats and trans for the current view
+ cam_pose_quats = views[view_idx]["camera_pose_quats"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ cam_pose_trans = views[view_idx]["camera_pose_trans"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ # Append to the list
+ pose_quats_non_ref_views.append(cam_pose_quats)
+ pose_trans_non_ref_views.append(cam_pose_trans)
+ # Get the camera pose quats and trans for the reference view 0
+ cam_pose_quats = views[0]["camera_pose_quats"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ cam_pose_trans = views[0]["camera_pose_trans"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ # Append to the list
+ pose_quats_ref_view_0.append(cam_pose_quats)
+ pose_trans_ref_view_0.append(cam_pose_trans)
+ else:
+ per_sample_cam_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ] = False
+
+ # Initialize the pose quats and trans for all views as identity
+ pose_quats_across_views = torch.tensor(
+ [0.0, 0.0, 0.0, 1.0], dtype=dtype, device=device
+ ).repeat(batch_size_per_view * num_views, 1) # (q_x, q_y, q_z, q_w)
+ pose_trans_across_views = torch.zeros(
+ (batch_size_per_view * num_views, 3), dtype=dtype, device=device
+ )
+
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ if len(pose_quats_non_ref_views) > 0:
+ # Stack the pose quats and trans for all the non-reference views and reference view 0
+ pose_quats_non_ref_views = torch.cat(pose_quats_non_ref_views, dim=0)
+ pose_trans_non_ref_views = torch.cat(pose_trans_non_ref_views, dim=0)
+ pose_quats_ref_view_0 = torch.cat(pose_quats_ref_view_0, dim=0)
+ pose_trans_ref_view_0 = torch.cat(pose_trans_ref_view_0, dim=0)
+
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ (
+ pose_quats_non_ref_views_in_ref_view_0,
+ pose_trans_non_ref_views_in_ref_view_0,
+ ) = transform_pose_using_quats_and_trans_2_to_1(
+ pose_quats_ref_view_0,
+ pose_trans_ref_view_0,
+ pose_quats_non_ref_views,
+ pose_trans_non_ref_views,
+ )
+
+ # Update the pose quats and trans for all the non-reference views
+ pose_quats_across_views[per_sample_cam_input_mask] = (
+ pose_quats_non_ref_views_in_ref_view_0.to(dtype=dtype)
+ )
+ pose_trans_across_views[per_sample_cam_input_mask] = (
+ pose_trans_non_ref_views_in_ref_view_0.to(dtype=dtype)
+ )
+
+ return (
+ pose_quats_across_views,
+ pose_trans_across_views,
+ per_sample_cam_input_mask,
+ )
+
+ def _encode_and_fuse_ray_dirs(
+ self,
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_ray_dirs_input_mask,
+ ):
+ """
+ Encode the ray directions for all the views and fuse it with the other encoder features in a single forward pass.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ batch_size_per_view (int): Batch size per view.
+ all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views.
+ per_sample_ray_dirs_input_mask (torch.Tensor): Tensor containing the per sample ray direction input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the encoded features for all the views.
+ """
+ # Get the height and width of the images
+ _, _, height, width = views[0]["img"].shape
+
+ # Get the ray directions for all the views where info is provided and the ray direction input mask is True
+ ray_dirs_list = []
+ for view_idx in range(num_views):
+ per_sample_ray_dirs_input_mask_for_curr_view = (
+ per_sample_ray_dirs_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ]
+ )
+ ray_dirs_for_curr_view = torch.zeros(
+ (batch_size_per_view, height, width, 3),
+ dtype=all_encoder_features_across_views.dtype,
+ device=all_encoder_features_across_views.device,
+ )
+ if (
+ "ray_directions_cam" in views[view_idx]
+ and per_sample_ray_dirs_input_mask_for_curr_view.any()
+ ):
+ ray_dirs_for_curr_view[per_sample_ray_dirs_input_mask_for_curr_view] = (
+ views[view_idx]["ray_directions_cam"][
+ per_sample_ray_dirs_input_mask_for_curr_view
+ ]
+ )
+ else:
+ per_sample_ray_dirs_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ] = False
+ ray_dirs_list.append(ray_dirs_for_curr_view)
+
+ # Stack the ray directions for all the views and permute to (B * V, C, H, W)
+ ray_dirs = torch.cat(ray_dirs_list, dim=0) # (B * V, H, W, 3)
+ ray_dirs = ray_dirs.permute(0, 3, 1, 2).contiguous() # (B * V, 3, H, W)
+
+ # Encode the ray directions
+ ray_dirs_features_across_views = self.ray_dirs_encoder(
+ ViTEncoderNonImageInput(data=ray_dirs)
+ ).features
+
+ # Fuse the ray direction features with the other encoder features (zero out the features where the ray direction input mask is False)
+ ray_dirs_features_across_views = (
+ ray_dirs_features_across_views
+ * per_sample_ray_dirs_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ )
+ all_encoder_features_across_views = (
+ all_encoder_features_across_views + ray_dirs_features_across_views
+ )
+
+ return all_encoder_features_across_views
+
+ def _encode_and_fuse_depths(
+ self,
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_depth_input_mask,
+ ):
+ """
+ Encode the z depths for all the views and fuse it with the other encoder features in a single forward pass.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ batch_size_per_view (int): Batch size per view.
+ all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views.
+ per_sample_depth_input_mask (torch.Tensor): Tensor containing the per sample depth input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the encoded features for all the views.
+ """
+ # Get the device and height and width of the images
+ device = all_encoder_features_across_views.device
+ _, _, height, width = views[0]["img"].shape
+
+ # Decide to use randomly sampled sparse depth or dense depth
+ if torch.rand(1) < self.geometric_input_config["sparse_depth_prob"]:
+ use_sparse_depth = True
+ else:
+ use_sparse_depth = False
+
+ # Get the depths for all the views
+ depth_list = []
+ depth_norm_factors_list = []
+ metric_scale_depth_mask_list = []
+ for view_idx in range(num_views):
+ # Get the input mask for current view
+ per_sample_depth_input_mask_for_curr_view = per_sample_depth_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view
+ ]
+ depth_for_curr_view = torch.zeros(
+ (batch_size_per_view, height, width, 1),
+ dtype=all_encoder_features_across_views.dtype,
+ device=device,
+ )
+ depth_norm_factor_for_curr_view = torch.zeros(
+ (batch_size_per_view),
+ dtype=all_encoder_features_across_views.dtype,
+ device=device,
+ )
+ metric_scale_mask_for_curr_view = torch.zeros(
+ (batch_size_per_view),
+ dtype=torch.bool,
+ device=device,
+ )
+ if (
+ "depth_along_ray" in views[view_idx]
+ ) and per_sample_depth_input_mask_for_curr_view.any():
+ # Get depth for current view
+ depth_for_curr_view_input = views[view_idx]["depth_along_ray"][
+ per_sample_depth_input_mask_for_curr_view
+ ]
+ # Get the metric scale mask
+ if "is_metric_scale" in views[view_idx]:
+ metric_scale_mask = views[view_idx]["is_metric_scale"][
+ per_sample_depth_input_mask_for_curr_view
+ ]
+ else:
+ metric_scale_mask = torch.zeros(
+ depth_for_curr_view_input.shape[0],
+ dtype=torch.bool,
+ device=device,
+ )
+ # Turn off indication of metric scale samples based on the depth_scale_norm_all_prob
+ depth_scale_norm_all_mask = (
+ torch.rand(metric_scale_mask.shape[0])
+ < self.geometric_input_config["depth_scale_norm_all_prob"]
+ )
+ if depth_scale_norm_all_mask.any():
+ metric_scale_mask[depth_scale_norm_all_mask] = False
+ # Assign the metric scale mask to the respective indices
+ metric_scale_mask_for_curr_view[
+ per_sample_depth_input_mask_for_curr_view
+ ] = metric_scale_mask
+ # Sparsely sample the depth if required
+ if use_sparse_depth:
+ # Create a mask of ones
+ sparsification_mask = torch.ones_like(
+ depth_for_curr_view_input, device=device
+ )
+ # Create a mask for valid pixels (depth > 0)
+ valid_pixel_mask = depth_for_curr_view_input > 0
+ # Calculate the number of valid pixels
+ num_valid_pixels = valid_pixel_mask.sum().item()
+ # Calculate the number of valid pixels to set to zero
+ num_to_zero = int(
+ num_valid_pixels
+ * self.geometric_input_config["sparsification_removal_percent"]
+ )
+ if num_to_zero > 0:
+ # Get the indices of valid pixels
+ valid_indices = valid_pixel_mask.nonzero(as_tuple=True)
+ # Randomly select indices to zero out
+ indices_to_zero = torch.randperm(num_valid_pixels)[:num_to_zero]
+ # Set selected valid indices to zero in the mask
+ sparsification_mask[
+ valid_indices[0][indices_to_zero],
+ valid_indices[1][indices_to_zero],
+ valid_indices[2][indices_to_zero],
+ valid_indices[3][indices_to_zero],
+ ] = 0
+ # Apply the mask on the depth
+ depth_for_curr_view_input = (
+ depth_for_curr_view_input * sparsification_mask
+ )
+ # Normalize the depth
+ scaled_depth_for_curr_view_input, depth_norm_factor = (
+ normalize_depth_using_non_zero_pixels(
+ depth_for_curr_view_input, return_norm_factor=True
+ )
+ )
+ # Assign the depth and depth norm factor to the respective indices
+ depth_for_curr_view[per_sample_depth_input_mask_for_curr_view] = (
+ scaled_depth_for_curr_view_input
+ )
+ depth_norm_factor_for_curr_view[
+ per_sample_depth_input_mask_for_curr_view
+ ] = depth_norm_factor
+ else:
+ per_sample_depth_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ] = False
+ # Append the depths, depth norm factor and metric scale mask for the current view
+ depth_list.append(depth_for_curr_view)
+ depth_norm_factors_list.append(depth_norm_factor_for_curr_view)
+ metric_scale_depth_mask_list.append(metric_scale_mask_for_curr_view)
+
+ # Stack the depths for all the views and permute to (B * V, C, H, W)
+ depths = torch.cat(depth_list, dim=0) # (B * V, H, W, 1)
+ depths = apply_log_to_norm(
+ depths
+ ) # Scale logarithimically (norm is computed along last dim)
+ depths = depths.permute(0, 3, 1, 2).contiguous() # (B * V, 1, H, W)
+ # Encode the depths using the depth encoder
+ depth_features_across_views = self.depth_encoder(
+ ViTEncoderNonImageInput(data=depths)
+ ).features
+ # Zero out the depth features where the depth input mask is False
+ depth_features_across_views = (
+ depth_features_across_views
+ * per_sample_depth_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ )
+
+ # Stack the depth norm factors for all the views
+ depth_norm_factors = torch.cat(depth_norm_factors_list, dim=0) # (B * V, )
+ # Encode the depth norm factors using the log scale encoder for depth
+ log_depth_norm_factors = torch.log(depth_norm_factors + 1e-8) # (B * V, )
+ depth_scale_features_across_views = self.depth_scale_encoder(
+ EncoderGlobalRepInput(data=log_depth_norm_factors.unsqueeze(-1))
+ ).features
+ # Zero out the depth scale features where the depth input mask is False
+ depth_scale_features_across_views = (
+ depth_scale_features_across_views
+ * per_sample_depth_input_mask.unsqueeze(-1)
+ )
+ # Stack the metric scale mask for all the views
+ metric_scale_depth_mask = torch.cat(
+ metric_scale_depth_mask_list, dim=0
+ ) # (B * V, )
+ # Zero out the depth scale features where the metric scale mask is False
+ # Scale encoding is only provided for metric scale samples
+ depth_scale_features_across_views = (
+ depth_scale_features_across_views * metric_scale_depth_mask.unsqueeze(-1)
+ )
+
+ # Fuse the depth features & depth scale features with the other encoder features
+ all_encoder_features_across_views = (
+ all_encoder_features_across_views
+ + depth_features_across_views
+ + depth_scale_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ )
+
+ return all_encoder_features_across_views
+
+ def _encode_and_fuse_cam_quats_and_trans(
+ self,
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ pose_quats_across_views,
+ pose_trans_across_views,
+ per_sample_cam_input_mask,
+ ):
+ """
+ Encode the camera quats and trans for all the views and fuse it with the other encoder features in a single forward pass.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ batch_size_per_view (int): Batch size per view.
+ all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views.
+ pose_quats_across_views (torch.Tensor): Tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4)
+ pose_trans_across_views (torch.Tensor): Tensor containing the pose trans for all the views in the frame of the reference view 0. (batch_size_per_view * view, 3)
+ per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the encoded features for all the views.
+ """
+ # Encode the pose quats
+ pose_quats_features_across_views = self.cam_rot_encoder(
+ EncoderGlobalRepInput(data=pose_quats_across_views)
+ ).features
+ # Zero out the pose quat features where the camera input mask is False
+ pose_quats_features_across_views = (
+ pose_quats_features_across_views * per_sample_cam_input_mask.unsqueeze(-1)
+ )
+
+ # Get the metric scale mask for all samples
+ device = all_encoder_features_across_views.device
+ metric_scale_pose_trans_mask = torch.zeros(
+ (batch_size_per_view * num_views), dtype=torch.bool, device=device
+ )
+ for view_idx in range(num_views):
+ if "is_metric_scale" in views[view_idx]:
+ # Get the metric scale mask for the input pose priors
+ metric_scale_mask = views[view_idx]["is_metric_scale"]
+ else:
+ metric_scale_mask = torch.zeros(
+ batch_size_per_view, dtype=torch.bool, device=device
+ )
+ metric_scale_pose_trans_mask[
+ view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view
+ ] = metric_scale_mask
+
+ # Turn off indication of metric scale samples based on the pose_scale_norm_all_prob
+ pose_norm_all_mask = (
+ torch.rand(batch_size_per_view * num_views)
+ < self.geometric_input_config["pose_scale_norm_all_prob"]
+ )
+ if pose_norm_all_mask.any():
+ metric_scale_pose_trans_mask[pose_norm_all_mask] = False
+
+ # Get the scale norm factor for all the samples and scale the pose translations
+ pose_trans_across_views = torch.split(
+ pose_trans_across_views, batch_size_per_view, dim=0
+ ) # Split into num_views chunks
+ pose_trans_across_views = torch.stack(
+ pose_trans_across_views, dim=1
+ ) # Stack the views along a new dimension (batch_size_per_view, num_views, 3)
+ scaled_pose_trans_across_views, pose_trans_norm_factors = (
+ normalize_pose_translations(
+ pose_trans_across_views, return_norm_factor=True
+ )
+ )
+
+ # Resize the pose translation back to (batch_size_per_view * num_views, 3) and extend the norm factor to (batch_size_per_view * num_views, 1)
+ scaled_pose_trans_across_views = scaled_pose_trans_across_views.unbind(
+ dim=1
+ ) # Convert back to list of views, where each view has batch_size_per_view tensor
+ scaled_pose_trans_across_views = torch.cat(
+ scaled_pose_trans_across_views, dim=0
+ ) # Concatenate back to (batch_size_per_view * num_views, 3)
+ pose_trans_norm_factors_across_views = pose_trans_norm_factors.unsqueeze(
+ -1
+ ).repeat(num_views, 1) # (B, ) -> (B * V, 1)
+
+ # Encode the pose trans
+ pose_trans_features_across_views = self.cam_trans_encoder(
+ EncoderGlobalRepInput(data=scaled_pose_trans_across_views)
+ ).features
+ # Zero out the pose trans features where the camera input mask is False
+ pose_trans_features_across_views = (
+ pose_trans_features_across_views * per_sample_cam_input_mask.unsqueeze(-1)
+ )
+
+ # Encode the pose translation norm factors using the log scale encoder for pose trans
+ log_pose_trans_norm_factors_across_views = torch.log(
+ pose_trans_norm_factors_across_views + 1e-8
+ )
+ pose_trans_scale_features_across_views = self.cam_trans_scale_encoder(
+ EncoderGlobalRepInput(data=log_pose_trans_norm_factors_across_views)
+ ).features
+ # Zero out the pose trans scale features where the camera input mask is False
+ pose_trans_scale_features_across_views = (
+ pose_trans_scale_features_across_views
+ * per_sample_cam_input_mask.unsqueeze(-1)
+ )
+ # Zero out the pose trans scale features where the metric scale mask is False
+ # Scale encoding is only provided for metric scale samples
+ pose_trans_scale_features_across_views = (
+ pose_trans_scale_features_across_views
+ * metric_scale_pose_trans_mask.unsqueeze(-1)
+ )
+
+ # Fuse the pose quat features, pose trans features, pose trans scale features and pose trans type PE features with the other encoder features
+ all_encoder_features_across_views = (
+ all_encoder_features_across_views
+ + pose_quats_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ + pose_trans_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ + pose_trans_scale_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ )
+
+ return all_encoder_features_across_views
+
+ def _encode_and_fuse_optional_geometric_inputs(
+ self, views, all_encoder_features_across_views_list
+ ):
+ """
+ Encode all the input optional geometric modalities and fuses it with the image encoder features in a single forward pass.
+ Assumes all the input views have the same shape and batch size.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ all_encoder_features_across_views (List[torch.Tensor]): List of tensors containing the encoded image features for all N views.
+
+ Returns:
+ List[torch.Tensor]: A list containing the encoded features for all N views.
+ """
+ num_views = len(views)
+ batch_size_per_view, _, _, _ = views[0]["img"].shape
+ device = all_encoder_features_across_views_list[0].device
+ dtype = all_encoder_features_across_views_list[0].dtype
+ all_encoder_features_across_views = torch.cat(
+ all_encoder_features_across_views_list, dim=0
+ )
+
+ # Get the overall input mask for all the views
+ overall_geometric_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["overall_prob"]
+ )
+ overall_geometric_input_mask = overall_geometric_input_mask.repeat(num_views)
+
+ # Get the per sample input mask after dropout
+ # Per sample input mask is in view-major order so that index v*B + b in each mask corresponds to sample b of view v: (B * V)
+ per_sample_geometric_input_mask = torch.rand(
+ batch_size_per_view * num_views, device=device
+ ) < (1 - self.geometric_input_config["dropout_prob"])
+ per_sample_geometric_input_mask = (
+ per_sample_geometric_input_mask & overall_geometric_input_mask
+ )
+
+ # Get the ray direction input mask
+ per_sample_ray_dirs_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["ray_dirs_prob"]
+ )
+ per_sample_ray_dirs_input_mask = per_sample_ray_dirs_input_mask.repeat(
+ num_views
+ )
+ per_sample_ray_dirs_input_mask = (
+ per_sample_ray_dirs_input_mask & per_sample_geometric_input_mask
+ )
+
+ # Get the depth input mask
+ per_sample_depth_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["depth_prob"]
+ )
+ per_sample_depth_input_mask = per_sample_depth_input_mask.repeat(num_views)
+ per_sample_depth_input_mask = (
+ per_sample_depth_input_mask & per_sample_geometric_input_mask
+ )
+
+ # Get the camera input mask
+ per_sample_cam_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["cam_prob"]
+ )
+ per_sample_cam_input_mask = per_sample_cam_input_mask.repeat(num_views)
+ per_sample_cam_input_mask = (
+ per_sample_cam_input_mask & per_sample_geometric_input_mask
+ )
+
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ # Returned pose quats and trans represent identity pose for views/samples where the camera input mask is False
+ pose_quats_across_views, pose_trans_across_views, per_sample_cam_input_mask = (
+ self._compute_pose_quats_and_trans_for_across_views_in_ref_view(
+ views,
+ num_views,
+ device,
+ dtype,
+ batch_size_per_view,
+ per_sample_cam_input_mask,
+ )
+ )
+
+ # Encode the ray directions and fuse with the image encoder features
+ all_encoder_features_across_views = self._encode_and_fuse_ray_dirs(
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_ray_dirs_input_mask,
+ )
+
+ # Encode the depths and fuse with the image encoder features
+ all_encoder_features_across_views = self._encode_and_fuse_depths(
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_depth_input_mask,
+ )
+
+ # Encode the cam quat and trans and fuse with the image encoder features
+ all_encoder_features_across_views = self._encode_and_fuse_cam_quats_and_trans(
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ pose_quats_across_views,
+ pose_trans_across_views,
+ per_sample_cam_input_mask,
+ )
+
+ # Normalize the fused features (permute -> normalize -> permute)
+ all_encoder_features_across_views = all_encoder_features_across_views.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ all_encoder_features_across_views = self.fusion_norm_layer(
+ all_encoder_features_across_views
+ )
+ all_encoder_features_across_views = all_encoder_features_across_views.permute(
+ 0, 3, 1, 2
+ ).contiguous()
+
+ # Split the batched views into individual views
+ fused_all_encoder_features_across_views = (
+ all_encoder_features_across_views.chunk(num_views, dim=0)
+ )
+
+ return fused_all_encoder_features_across_views
+
+ def forward(self, views):
+ """
+ Forward pass performing the following operations:
+ 1. Encodes the N input views (images).
+ 2. Encodes the optional geometric inputs (ray directions, depths, camera rotations, camera translations).
+ 3. Fuses the encoded features from the N input views and the optional geometric inputs using addition and normalization.
+ 4. Information sharing between the encoded features using a multi-view attention transformer.
+ 5. Passes the final features through the prediction heads.
+ 6. Returns the processed final outputs for N views.
+
+ Assumption:
+ - All the input views have the same image shape.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Each dictionary should contain the following keys:
+ "img" (tensor): Image tensor of shape (B, C, H, W). Input images must be normalized based on the data norm type of image encoder.
+ "data_norm_type" (list): [model.encoder.data_norm_type]
+
+ Returns:
+ List[dict]: A list containing the final outputs for all N views.
+ """
+ # Get input shape of the images, number of views, and batch size per view
+ batch_size_per_view, _, height, width = views[0]["img"].shape
+ img_shape = (int(height), int(width))
+ num_views = len(views)
+
+ # Run the encoder on all the input views
+ all_encoder_features_across_views = self._encode_n_views(views)
+
+ # Encode the optional geometric inputs and fuse with the encoded features from the N input views
+ # Use high precision to prevent NaN values after layer norm in dense representation encoder (due to high variance in last dim of features)
+ with torch.autocast("cuda", enabled=False):
+ all_encoder_features_across_views = (
+ self._encode_and_fuse_optional_geometric_inputs(
+ views, all_encoder_features_across_views
+ )
+ )
+
+ # Combine all images into view-centric representation
+ info_sharing_input = MultiViewTransformerInput(
+ features=all_encoder_features_across_views
+ )
+ if self.info_sharing_return_type == "no_intermediate_features":
+ final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input)
+ elif self.info_sharing_return_type == "intermediate_features":
+ (
+ final_info_sharing_multi_view_feat,
+ intermediate_info_sharing_multi_view_feat,
+ ) = self.info_sharing(info_sharing_input)
+
+ if self.pred_head_type == "linear":
+ # Stack the features for all views
+ dense_head_inputs = torch.cat(
+ final_info_sharing_multi_view_feat.features, dim=0
+ )
+ elif self.pred_head_type in ["dpt", "dpt+pose"]:
+ # Get the list of features for all views
+ dense_head_inputs_list = []
+ if self.use_encoder_features_for_dpt:
+ # Stack all the image encoder features for all views
+ stacked_encoder_features = torch.cat(
+ all_encoder_features_across_views, dim=0
+ )
+ dense_head_inputs_list.append(stacked_encoder_features)
+ # Stack the first intermediate features for all views
+ stacked_intermediate_features_1 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[0].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_1)
+ # Stack the second intermediate features for all views
+ stacked_intermediate_features_2 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[1].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_2)
+ # Stack the last layer features for all views
+ stacked_final_features = torch.cat(
+ final_info_sharing_multi_view_feat.features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_final_features)
+ else:
+ # Stack the first intermediate features for all views
+ stacked_intermediate_features_1 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[0].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_1)
+ # Stack the second intermediate features for all views
+ stacked_intermediate_features_2 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[1].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_2)
+ # Stack the third intermediate features for all views
+ stacked_intermediate_features_3 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[2].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_3)
+ # Stack the last layer
+ stacked_final_features = torch.cat(
+ final_info_sharing_multi_view_feat.features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_final_features)
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ # Downstream task prediction
+ with torch.autocast("cuda", enabled=False):
+ # Run Prediction Heads & Post-Process Outputs
+ if self.pred_head_type == "linear":
+ dense_head_outputs = self.dense_head(
+ PredictionHeadInput(last_feature=dense_head_inputs)
+ )
+ dense_final_outputs = self.dense_adaptor(
+ AdaptorInput(
+ adaptor_feature=dense_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ elif self.pred_head_type == "dpt":
+ dense_head_outputs = self.dense_head(
+ PredictionHeadLayeredInput(
+ list_features=dense_head_inputs_list,
+ target_output_shape=img_shape,
+ )
+ )
+ dense_final_outputs = self.dense_adaptor(
+ AdaptorInput(
+ adaptor_feature=dense_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ elif self.pred_head_type == "dpt+pose":
+ dense_head_outputs = self.dense_head(
+ PredictionHeadLayeredInput(
+ list_features=dense_head_inputs_list,
+ target_output_shape=img_shape,
+ )
+ )
+ dense_final_outputs = self.dense_adaptor(
+ AdaptorInput(
+ adaptor_feature=dense_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ pose_head_outputs = self.pose_head(
+ PredictionHeadInput(last_feature=dense_head_inputs_list[-1])
+ )
+ pose_final_outputs = self.pose_adaptor(
+ AdaptorInput(
+ adaptor_feature=pose_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ # Prepare the final scene representation for all views
+ if self.scene_rep_type in [
+ "pointmap",
+ "pointmap+confidence",
+ "pointmap+mask",
+ "pointmap+confidence+mask",
+ ]:
+ output_pts3d = dense_final_outputs.value
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_pts3d = output_pts3d.permute(0, 2, 3, 1).contiguous()
+ # Split the predicted pointmaps back to their respective views
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append({"pts3d": output_pts3d_per_view[i]})
+ elif self.scene_rep_type in [
+ "raymap+depth",
+ "raymap+depth+confidence",
+ "raymap+depth+mask",
+ "raymap+depth+confidence+mask",
+ ]:
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_scene_rep = dense_final_outputs.value.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ # Get the predicted ray origins, directions, and depths along rays
+ output_ray_origins, output_ray_directions, output_depth_along_ray = (
+ output_scene_rep.split([3, 3, 1], dim=-1)
+ )
+ # Get the predicted pointmaps
+ output_pts3d = (
+ output_ray_origins + output_ray_directions * output_depth_along_ray
+ )
+ # Split the predicted quantities back to their respective views
+ output_ray_origins_per_view = output_ray_origins.chunk(num_views, dim=0)
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i],
+ "ray_origins": output_ray_origins_per_view[i],
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i],
+ }
+ )
+ elif self.scene_rep_type in [
+ "raydirs+depth+pose",
+ "raydirs+depth+pose+confidence",
+ "raydirs+depth+pose+mask",
+ "raydirs+depth+pose+confidence+mask",
+ ]:
+ # Reshape output dense rep to (B * V, H, W, C)
+ output_dense_rep = dense_final_outputs.value.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ # Get the predicted ray directions and depths along rays
+ output_ray_directions, output_depth_along_ray = output_dense_rep.split(
+ [3, 1], dim=-1
+ )
+ # Get the predicted camera translations and quaternions
+ output_cam_translations, output_cam_quats = (
+ pose_final_outputs.value.split([3, 4], dim=-1)
+ )
+ # Get the predicted pointmaps in world frame and camera frame
+ output_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ output_ray_directions,
+ output_depth_along_ray,
+ output_cam_translations,
+ output_cam_quats,
+ )
+ )
+ output_pts3d_cam = output_ray_directions * output_depth_along_ray
+ # Split the predicted quantities back to their respective views
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_cam_translations_per_view = output_cam_translations.chunk(
+ num_views, dim=0
+ )
+ output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0)
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i],
+ "pts3d_cam": output_pts3d_cam_per_view[i],
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i],
+ "cam_trans": output_cam_translations_per_view[i],
+ "cam_quats": output_cam_quats_per_view[i],
+ }
+ )
+ elif self.scene_rep_type in [
+ "campointmap+pose",
+ "campointmap+pose+confidence",
+ "campointmap+pose+mask",
+ "campointmap+pose+confidence+mask",
+ ]:
+ # Get the predicted camera frame pointmaps
+ output_pts3d_cam = dense_final_outputs.value
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_pts3d_cam = output_pts3d_cam.permute(0, 2, 3, 1).contiguous()
+ # Get the predicted camera translations and quaternions
+ output_cam_translations, output_cam_quats = (
+ pose_final_outputs.value.split([3, 4], dim=-1)
+ )
+ # Get the ray directions and depths along rays
+ output_depth_along_ray = torch.norm(
+ output_pts3d_cam, dim=-1, keepdim=True
+ )
+ output_ray_directions = output_pts3d_cam / output_depth_along_ray
+ # Get the predicted pointmaps in world frame
+ output_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ output_ray_directions,
+ output_depth_along_ray,
+ output_cam_translations,
+ output_cam_quats,
+ )
+ )
+ # Split the predicted quantities back to their respective views
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_cam_translations_per_view = output_cam_translations.chunk(
+ num_views, dim=0
+ )
+ output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0)
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i],
+ "pts3d_cam": output_pts3d_cam_per_view[i],
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i],
+ "cam_trans": output_cam_translations_per_view[i],
+ "cam_quats": output_cam_quats_per_view[i],
+ }
+ )
+ elif self.scene_rep_type in [
+ "pointmap+raydirs+depth+pose",
+ "pointmap+raydirs+depth+pose+confidence",
+ "pointmap+raydirs+depth+pose+mask",
+ "pointmap+raydirs+depth+pose+confidence+mask",
+ ]:
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_dense_rep = dense_final_outputs.value.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ # Get the predicted pointmaps, ray directions and depths along rays
+ output_pts3d, output_ray_directions, output_depth_along_ray = (
+ output_dense_rep.split([3, 3, 1], dim=-1)
+ )
+ # Get the predicted camera translations and quaternions
+ output_cam_translations, output_cam_quats = (
+ pose_final_outputs.value.split([3, 4], dim=-1)
+ )
+ # Get the predicted pointmaps in camera frame
+ output_pts3d_cam = output_ray_directions * output_depth_along_ray
+ # Replace the predicted world-frame pointmaps if required
+ if self.pred_head_config["adaptor_config"][
+ "use_factored_predictions_for_global_pointmaps"
+ ]:
+ output_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ output_ray_directions,
+ output_depth_along_ray,
+ output_cam_translations,
+ output_cam_quats,
+ )
+ )
+ # Split the predicted quantities back to their respective views
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_cam_translations_per_view = output_cam_translations.chunk(
+ num_views, dim=0
+ )
+ output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0)
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i],
+ "pts3d_cam": output_pts3d_cam_per_view[i],
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i],
+ "cam_trans": output_cam_translations_per_view[i],
+ "cam_quats": output_cam_quats_per_view[i],
+ }
+ )
+ else:
+ raise ValueError(
+ f"Invalid scene_rep_type: {self.scene_rep_type}. \
+ Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \
+ 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \
+ 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \
+ 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']"
+ )
+
+ # Get the output confidences for all views (if available) and add them to the result
+ if "confidence" in self.scene_rep_type:
+ output_confidences = dense_final_outputs.confidence
+ # Reshape confidences to (B * V, H, W)
+ output_confidences = (
+ output_confidences.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ )
+ # Split the predicted confidences back to their respective views
+ output_confidences_per_view = output_confidences.chunk(num_views, dim=0)
+ # Add the confidences to the result
+ for i in range(num_views):
+ res[i]["conf"] = output_confidences_per_view[i]
+
+ # Get the output masks (and logits) for all views (if available) and add them to the result
+ if "mask" in self.scene_rep_type:
+ # Get the output masks
+ output_masks = dense_final_outputs.mask
+ # Reshape masks to (B * V, H, W)
+ output_masks = output_masks.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ # Threshold the masks at 0.5 to get binary masks (0: ambiguous/invalid, 1: non-ambiguous/valid)
+ output_masks = output_masks > 0.5
+ # Split the predicted masks back to their respective views
+ output_masks_per_view = output_masks.chunk(num_views, dim=0)
+ # Get the output mask logits (for loss)
+ output_mask_logits = dense_final_outputs.logits
+ # Reshape mask logits to (B * V, H, W)
+ output_mask_logits = (
+ output_mask_logits.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ )
+ # Split the predicted mask logits back to their respective views
+ output_mask_logits_per_view = output_mask_logits.chunk(num_views, dim=0)
+ # Add the masks and logits to the result
+ for i in range(num_views):
+ res[i]["non_ambiguous_mask"] = output_masks_per_view[i]
+ res[i]["non_ambiguous_mask_logits"] = output_mask_logits_per_view[i]
+
+ return res
diff --git a/mapanything/models/mapanything/model.py b/mapanything/models/mapanything/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2de75cf5be57a05fa97b41f877264113f27e11c3
--- /dev/null
+++ b/mapanything/models/mapanything/model.py
@@ -0,0 +1,2107 @@
+"""
+MapAnything model class defined using UniCeption modules.
+"""
+
+import warnings
+from functools import partial
+from typing import Any, Callable, Dict, List, Tuple, Type, Union
+
+import torch
+import torch.nn as nn
+from huggingface_hub import PyTorchModelHubMixin
+
+from mapanything.utils.geometry import (
+ apply_log_to_norm,
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
+ normalize_depth_using_non_zero_pixels,
+ normalize_pose_translations,
+ transform_pose_using_quats_and_trans_2_to_1,
+)
+from mapanything.utils.inference import (
+ postprocess_model_outputs_for_inference,
+ preprocess_input_views_for_inference,
+ validate_input_views_for_inference,
+)
+from uniception.models.encoders import (
+ encoder_factory,
+ EncoderGlobalRepInput,
+ ViTEncoderInput,
+ ViTEncoderNonImageInput,
+)
+from uniception.models.info_sharing.alternating_attention_transformer import (
+ MultiViewAlternatingAttentionTransformer,
+ MultiViewAlternatingAttentionTransformerIFR,
+)
+from uniception.models.info_sharing.base import MultiViewTransformerInput
+from uniception.models.info_sharing.cross_attention_transformer import (
+ MultiViewCrossAttentionTransformer,
+ MultiViewCrossAttentionTransformerIFR,
+)
+from uniception.models.info_sharing.global_attention_transformer import (
+ MultiViewGlobalAttentionTransformer,
+ MultiViewGlobalAttentionTransformerIFR,
+)
+from uniception.models.prediction_heads.adaptors import (
+ CamTranslationPlusQuatsAdaptor,
+ PointMapAdaptor,
+ PointMapPlusRayDirectionsPlusDepthAdaptor,
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor,
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor,
+ PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor,
+ PointMapWithConfidenceAdaptor,
+ PointMapWithConfidenceAndMaskAdaptor,
+ PointMapWithMaskAdaptor,
+ RayDirectionsPlusDepthAdaptor,
+ RayDirectionsPlusDepthWithConfidenceAdaptor,
+ RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor,
+ RayDirectionsPlusDepthWithMaskAdaptor,
+ RayMapPlusDepthAdaptor,
+ RayMapPlusDepthWithConfidenceAdaptor,
+ RayMapPlusDepthWithConfidenceAndMaskAdaptor,
+ RayMapPlusDepthWithMaskAdaptor,
+ ScaleAdaptor,
+)
+from uniception.models.prediction_heads.base import (
+ AdaptorInput,
+ PredictionHeadInput,
+ PredictionHeadLayeredInput,
+ PredictionHeadTokenInput,
+)
+from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
+from uniception.models.prediction_heads.linear import LinearFeature
+from uniception.models.prediction_heads.mlp_head import MLPHead
+from uniception.models.prediction_heads.pose_head import PoseHead
+
+# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12)
+if hasattr(torch.backends.cuda, "matmul") and hasattr(
+ torch.backends.cuda.matmul, "allow_tf32"
+):
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+
+class MapAnything(nn.Module, PyTorchModelHubMixin):
+ "Modular MapAnything model class that supports input of images & optional geometric modalities (multiple reconstruction tasks)."
+
+ def __init__(
+ self,
+ name: str,
+ encoder_config: Dict,
+ info_sharing_config: Dict,
+ pred_head_config: Dict,
+ geometric_input_config: Dict,
+ fusion_norm_layer: Union[Type[nn.Module], Callable[..., nn.Module]] = partial(
+ nn.LayerNorm, eps=1e-6
+ ),
+ pretrained_checkpoint_path: str = None,
+ load_specific_pretrained_submodules: bool = False,
+ specific_pretrained_submodules: list = None,
+ torch_hub_force_reload: bool = False,
+ ):
+ """
+ Multi-view model containing an image encoder fused with optional geometric modalities followed by a multi-view attention transformer and respective downstream heads.
+ The goal is to output scene representation.
+ The multi-view attention transformer also takes as input a scale token to predict the metric scaling factor for the predicted scene representation.
+
+ Args:
+ name (str): Name of the model.
+ encoder_config (Dict): Configuration for the encoder.
+ info_sharing_config (Dict): Configuration for the multi-view attention transformer.
+ pred_head_config (Dict): Configuration for the prediction heads.
+ geometric_input_config (Dict): Configuration for the input of optional geometric modalities.
+ fusion_norm_layer (Union[Type[nn.Module], Callable[..., nn.Module]]): Normalization layer to use after fusion (addition) of encoder and geometric modalities. (default: partial(nn.LayerNorm, eps=1e-6))
+ pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None)
+ load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False)
+ specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: None)
+ torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. (default: False)
+ """
+ super().__init__()
+
+ # Initalize the attributes
+ self.name = name
+ self.encoder_config = encoder_config
+ self.info_sharing_config = info_sharing_config
+ self.pred_head_config = pred_head_config
+ self.geometric_input_config = geometric_input_config
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
+ self.load_specific_pretrained_submodules = load_specific_pretrained_submodules
+ self.specific_pretrained_submodules = specific_pretrained_submodules
+ self.torch_hub_force_reload = torch_hub_force_reload
+ self.class_init_args = {
+ "name": self.name,
+ "encoder_config": self.encoder_config,
+ "info_sharing_config": self.info_sharing_config,
+ "pred_head_config": self.pred_head_config,
+ "geometric_input_config": self.geometric_input_config,
+ "pretrained_checkpoint_path": self.pretrained_checkpoint_path,
+ "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules,
+ "specific_pretrained_submodules": self.specific_pretrained_submodules,
+ "torch_hub_force_reload": self.torch_hub_force_reload,
+ }
+
+ # Get relevant parameters from the configs
+ self.info_sharing_type = info_sharing_config["model_type"]
+ self.info_sharing_return_type = info_sharing_config["model_return_type"]
+ self.pred_head_type = pred_head_config["type"]
+
+ # Initialize image encoder
+ if self.encoder_config["uses_torch_hub"]:
+ self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
+ # Create a copy of the config before deleting the key to preserve it for serialization
+ encoder_config_copy = self.encoder_config.copy()
+ del encoder_config_copy["uses_torch_hub"]
+ self.encoder = encoder_factory(**encoder_config_copy)
+
+ # Initialize the encoder for ray directions
+ ray_dirs_encoder_config = self.geometric_input_config["ray_dirs_encoder_config"]
+ ray_dirs_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ ray_dirs_encoder_config["patch_size"] = self.encoder.patch_size
+ self.ray_dirs_encoder = encoder_factory(**ray_dirs_encoder_config)
+
+ # Initialize the encoder for depth (normalized per view and values after normalization are scaled logarithmically)
+ depth_encoder_config = self.geometric_input_config["depth_encoder_config"]
+ depth_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ depth_encoder_config["patch_size"] = self.encoder.patch_size
+ self.depth_encoder = encoder_factory(**depth_encoder_config)
+
+ # Initialize the encoder for log scale factor of depth
+ depth_scale_encoder_config = self.geometric_input_config["scale_encoder_config"]
+ depth_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.depth_scale_encoder = encoder_factory(**depth_scale_encoder_config)
+
+ # Initialize the encoder for camera rotation
+ cam_rot_encoder_config = self.geometric_input_config["cam_rot_encoder_config"]
+ cam_rot_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.cam_rot_encoder = encoder_factory(**cam_rot_encoder_config)
+
+ # Initialize the encoder for camera translation (normalized across all provided camera translations)
+ cam_trans_encoder_config = self.geometric_input_config[
+ "cam_trans_encoder_config"
+ ]
+ cam_trans_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.cam_trans_encoder = encoder_factory(**cam_trans_encoder_config)
+
+ # Initialize the encoder for log scale factor of camera translation
+ cam_trans_scale_encoder_config = self.geometric_input_config[
+ "scale_encoder_config"
+ ]
+ cam_trans_scale_encoder_config["enc_embed_dim"] = self.encoder.enc_embed_dim
+ self.cam_trans_scale_encoder = encoder_factory(**cam_trans_scale_encoder_config)
+
+ # Initialize the fusion norm layer
+ self.fusion_norm_layer = fusion_norm_layer(self.encoder.enc_embed_dim)
+
+ # Initialize the Scale Token
+ # Used to scale the final scene predictions to metric scale
+ # During inference extended to (B, C, T), where T is the number of tokens (i.e., 1)
+ self.scale_token = nn.Parameter(torch.zeros(self.encoder.enc_embed_dim))
+ torch.nn.init.trunc_normal_(self.scale_token, std=0.02)
+
+ # Initialize the info sharing module (multi-view transformer)
+ self._initialize_info_sharing(info_sharing_config)
+
+ # Initialize the prediction heads
+ self._initialize_prediction_heads(pred_head_config)
+
+ # Initialize the final adaptors
+ self._initialize_adaptors(pred_head_config)
+
+ # Load pretrained weights
+ self._load_pretrained_weights()
+
+ @property
+ def device(self) -> torch.device:
+ return next(self.parameters()).device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ return next(self.parameters()).dtype
+
+ def _initialize_info_sharing(self, info_sharing_config):
+ """
+ Initialize the information sharing module based on the configuration.
+
+ This method sets up the custom positional encoding if specified and initializes
+ the appropriate multi-view transformer based on the configuration type.
+
+ Args:
+ info_sharing_config (Dict): Configuration for the multi-view attention transformer.
+ Should contain 'custom_positional_encoding', 'model_type', and 'model_return_type'.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If invalid configuration options are provided.
+ """
+ # Initialize Custom Positional Encoding if required
+ custom_positional_encoding = info_sharing_config["custom_positional_encoding"]
+ if custom_positional_encoding is not None:
+ if isinstance(custom_positional_encoding, str):
+ print(
+ f"Using custom positional encoding for multi-view attention transformer: {custom_positional_encoding}"
+ )
+ raise ValueError(
+ f"Invalid custom_positional_encoding: {custom_positional_encoding}. None implemented."
+ )
+ elif isinstance(custom_positional_encoding, Callable):
+ print(
+ "Using callable function as custom positional encoding for multi-view attention transformer."
+ )
+ self.custom_positional_encoding = custom_positional_encoding
+ else:
+ self.custom_positional_encoding = None
+
+ # Add dependecies to info_sharing_config
+ info_sharing_config["module_args"]["input_embed_dim"] = (
+ self.encoder.enc_embed_dim
+ )
+ info_sharing_config["module_args"]["custom_positional_encoding"] = (
+ self.custom_positional_encoding
+ )
+
+ # Initialize Multi-View Transformer
+ if self.info_sharing_return_type == "no_intermediate_features":
+ # Returns only normalized last layer features
+ # Intialize multi-view transformer based on type
+ if self.info_sharing_type == "cross_attention":
+ self.info_sharing = MultiViewCrossAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "global_attention":
+ self.info_sharing = MultiViewGlobalAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "alternating_attention":
+ self.info_sharing = MultiViewAlternatingAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+ )
+ elif self.info_sharing_return_type == "intermediate_features":
+ # Returns intermediate features and normalized last layer features
+ # Initialize mulit-view transformer based on type
+ if self.info_sharing_type == "cross_attention":
+ self.info_sharing = MultiViewCrossAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "global_attention":
+ self.info_sharing = MultiViewGlobalAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "alternating_attention":
+ self.info_sharing = MultiViewAlternatingAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+ )
+ # Assess if the DPT needs to use encoder features
+ if len(self.info_sharing.indices) == 2:
+ self.use_encoder_features_for_dpt = True
+ elif len(self.info_sharing.indices) == 3:
+ self.use_encoder_features_for_dpt = False
+ else:
+ raise ValueError(
+ "Invalid number of indices provided for info sharing feature returner. Please provide 2 or 3 indices."
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_return_type: {self.info_sharing_return_type}. Valid options: ['no_intermediate_features', 'intermediate_features']"
+ )
+
+ def _initialize_prediction_heads(self, pred_head_config):
+ """
+ Initialize the prediction heads based on the prediction head configuration.
+
+ This method configures and initializes the appropriate prediction heads based on the
+ specified prediction head type (linear, DPT, or DPT+pose). It sets up the necessary
+ dependencies and creates the required model components.
+
+ Args:
+ pred_head_config (Dict): Configuration for the prediction heads.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If an invalid pred_head_type is provided.
+ """
+ # Add dependencies to prediction head config
+ pred_head_config["feature_head"]["patch_size"] = self.encoder.patch_size
+ if self.pred_head_type == "linear":
+ pred_head_config["feature_head"]["input_feature_dim"] = (
+ self.info_sharing.dim
+ )
+ elif "dpt" in self.pred_head_type:
+ # Add dependencies for DPT & Regressor head
+ if self.use_encoder_features_for_dpt:
+ pred_head_config["feature_head"]["input_feature_dims"] = [
+ self.encoder.enc_embed_dim
+ ] + [self.info_sharing.dim] * 3
+ else:
+ pred_head_config["feature_head"]["input_feature_dims"] = [
+ self.info_sharing.dim
+ ] * 4
+ pred_head_config["regressor_head"]["input_feature_dim"] = pred_head_config[
+ "feature_head"
+ ]["feature_dim"]
+ # Add dependencies for Pose head if required
+ if "pose" in self.pred_head_type:
+ pred_head_config["pose_head"]["patch_size"] = self.encoder.patch_size
+ pred_head_config["pose_head"]["input_feature_dim"] = (
+ self.info_sharing.dim
+ )
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+ pred_head_config["scale_head"]["input_feature_dim"] = self.info_sharing.dim
+
+ # Initialize Prediction Heads
+ if self.pred_head_type == "linear":
+ # Initialize Dense Prediction Head for all views
+ self.dense_head = LinearFeature(**pred_head_config["feature_head"])
+ elif "dpt" in self.pred_head_type:
+ # Initialze Dense Predction Head for all views
+ self.dpt_feature_head = DPTFeature(**pred_head_config["feature_head"])
+ self.dpt_regressor_head = DPTRegressionProcessor(
+ **pred_head_config["regressor_head"]
+ )
+ self.dense_head = nn.Sequential(
+ self.dpt_feature_head, self.dpt_regressor_head
+ )
+ # Initialize Pose Head for all views if required
+ if "pose" in self.pred_head_type:
+ self.pose_head = PoseHead(**pred_head_config["pose_head"])
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+ self.scale_head = MLPHead(**pred_head_config["scale_head"])
+
+ def _initialize_adaptors(self, pred_head_config):
+ """
+ Initialize the adaptors based on the prediction head configuration.
+
+ This method sets up the appropriate adaptors for different scene representation types,
+ such as pointmaps, ray maps with depth, or ray directions with depth and pose.
+
+ Args:
+ pred_head_config (Dict): Configuration for the prediction heads including adaptor type.
+
+ Returns:
+ None
+
+ Raises:
+ ValueError: If an invalid adaptor_type is provided.
+ AssertionError: If ray directions + depth + pose is used with an incompatible head type.
+ """
+ if pred_head_config["adaptor_type"] == "pointmap":
+ self.dense_adaptor = PointMapAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "pointmap"
+ elif pred_head_config["adaptor_type"] == "pointmap+confidence":
+ self.dense_adaptor = PointMapWithConfidenceAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "pointmap+confidence"
+ elif pred_head_config["adaptor_type"] == "pointmap+mask":
+ self.dense_adaptor = PointMapWithMaskAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "pointmap+mask"
+ elif pred_head_config["adaptor_type"] == "pointmap+confidence+mask":
+ self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "pointmap+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "raymap+depth":
+ self.dense_adaptor = RayMapPlusDepthAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "raymap+depth"
+ elif pred_head_config["adaptor_type"] == "raymap+depth+confidence":
+ self.dense_adaptor = RayMapPlusDepthWithConfidenceAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "raymap+depth+confidence"
+ elif pred_head_config["adaptor_type"] == "raymap+depth+mask":
+ self.dense_adaptor = RayMapPlusDepthWithMaskAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "raymap+depth+mask"
+ elif pred_head_config["adaptor_type"] == "raymap+depth+confidence+mask":
+ self.dense_adaptor = RayMapPlusDepthWithConfidenceAndMaskAdaptor(
+ **pred_head_config["adaptor"]
+ )
+ self.scene_rep_type = "raymap+depth+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose+confidence"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthWithMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose+mask"
+ elif pred_head_config["adaptor_type"] == "raydirs+depth+pose+confidence+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = RayDirectionsPlusDepthWithConfidenceAndMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "raydirs+depth+pose+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapAdaptor(**pred_head_config["dpt_adaptor"])
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapWithConfidenceAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose+confidence"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapWithMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose+mask"
+ elif pred_head_config["adaptor_type"] == "campointmap+pose+confidence+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Camera pointmap + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapWithConfidenceAndMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "campointmap+pose+confidence+mask"
+ elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose":
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose"
+ elif (
+ pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+confidence"
+ ):
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = (
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence"
+ elif pred_head_config["adaptor_type"] == "pointmap+raydirs+depth+pose+mask":
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = PointMapPlusRayDirectionsPlusDepthWithMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose+mask"
+ elif (
+ pred_head_config["adaptor_type"]
+ == "pointmap+raydirs+depth+pose+confidence+mask"
+ ):
+ assert self.pred_head_type == "dpt+pose", (
+ "Pointmap + ray directions + depth + pose can only be used as scene representation with dpt + pose head."
+ )
+ self.dense_adaptor = (
+ PointMapPlusRayDirectionsPlusDepthWithConfidenceAndMaskAdaptor(
+ **pred_head_config["dpt_adaptor"]
+ )
+ )
+ self.pose_adaptor = CamTranslationPlusQuatsAdaptor(
+ **pred_head_config["pose_adaptor"]
+ )
+ self.scene_rep_type = "pointmap+raydirs+depth+pose+confidence+mask"
+ else:
+ raise ValueError(
+ f"Invalid adaptor_type: {pred_head_config['adaptor_type']}. \
+ Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \
+ 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \
+ 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \
+ 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']"
+ )
+ self.scale_adaptor = ScaleAdaptor(**pred_head_config["scale_adaptor"])
+
+ def _load_pretrained_weights(self):
+ """
+ Load pretrained weights from a checkpoint file.
+
+ If load_specific_pretrained_submodules is True, only loads weights for the specified submodules.
+ Otherwise, loads all weights from the checkpoint.
+
+ Returns:
+ None
+ """
+ if self.pretrained_checkpoint_path is not None:
+ if not self.load_specific_pretrained_submodules:
+ print(
+ f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} ..."
+ )
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+ print(self.load_state_dict(ckpt["model"]))
+ else:
+ print(
+ f"Loading pretrained MapAnything weights from {self.pretrained_checkpoint_path} for specific submodules: {self.specific_pretrained_submodules} ..."
+ )
+ assert self.pred_head_type is not None, (
+ "Specific submodules to load cannot be None."
+ )
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+ filtered_ckpt = {}
+ for ckpt_key, ckpt_value in ckpt["model"].items():
+ for submodule in self.specific_pretrained_submodules:
+ if ckpt_key.startswith(submodule):
+ filtered_ckpt[ckpt_key] = ckpt_value
+ print(self.load_state_dict(filtered_ckpt, strict=False))
+
+ def _encode_n_views(self, views):
+ """
+ Encode all the input views (batch of images) in a single forward pass.
+ Assumes all the input views have the same image shape, batch size, and data normalization type.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+
+ Returns:
+ List[torch.Tensor]: A list containing the encoded features for all N views.
+ """
+ num_views = len(views)
+ data_norm_type = views[0]["data_norm_type"][0]
+ imgs_list = [view["img"] for view in views]
+ all_imgs_across_views = torch.cat(imgs_list, dim=0)
+ encoder_input = ViTEncoderInput(
+ image=all_imgs_across_views, data_norm_type=data_norm_type
+ )
+ encoder_output = self.encoder(encoder_input)
+ all_encoder_features_across_views = encoder_output.features.chunk(
+ num_views, dim=0
+ )
+
+ return all_encoder_features_across_views
+
+ def _compute_pose_quats_and_trans_for_across_views_in_ref_view(
+ self,
+ views,
+ num_views,
+ device,
+ dtype,
+ batch_size_per_view,
+ per_sample_cam_input_mask,
+ ):
+ """
+ Compute the pose quats and trans for all the views in the frame of the reference view 0.
+ Returns identity pose for views where the camera input mask is False or the pose is not provided.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ device (torch.device): Device to use for the computation.
+ dtype (torch.dtype): Data type to use for the computation.
+ per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4)
+ torch.Tensor: A tensor containing the pose trans for all the views in the frame of the reference view 0. (batch_size_per_view * view, 3)
+ torch.Tensor: A tensor containing the per sample camera input mask.
+ """
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ pose_quats_non_ref_views = []
+ pose_trans_non_ref_views = []
+ pose_quats_ref_view_0 = []
+ pose_trans_ref_view_0 = []
+ for view_idx in range(num_views):
+ per_sample_cam_input_mask_for_curr_view = per_sample_cam_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view
+ ]
+ if (
+ "camera_pose_quats" in views[view_idx]
+ and "camera_pose_trans" in views[view_idx]
+ and per_sample_cam_input_mask_for_curr_view.any()
+ ):
+ # Get the camera pose quats and trans for the current view
+ cam_pose_quats = views[view_idx]["camera_pose_quats"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ cam_pose_trans = views[view_idx]["camera_pose_trans"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ # Append to the list
+ pose_quats_non_ref_views.append(cam_pose_quats)
+ pose_trans_non_ref_views.append(cam_pose_trans)
+ # Get the camera pose quats and trans for the reference view 0
+ cam_pose_quats = views[0]["camera_pose_quats"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ cam_pose_trans = views[0]["camera_pose_trans"][
+ per_sample_cam_input_mask_for_curr_view
+ ]
+ # Append to the list
+ pose_quats_ref_view_0.append(cam_pose_quats)
+ pose_trans_ref_view_0.append(cam_pose_trans)
+ else:
+ per_sample_cam_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ] = False
+
+ # Initialize the pose quats and trans for all views as identity
+ pose_quats_across_views = torch.tensor(
+ [0.0, 0.0, 0.0, 1.0], dtype=dtype, device=device
+ ).repeat(batch_size_per_view * num_views, 1) # (q_x, q_y, q_z, q_w)
+ pose_trans_across_views = torch.zeros(
+ (batch_size_per_view * num_views, 3), dtype=dtype, device=device
+ )
+
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ if len(pose_quats_non_ref_views) > 0:
+ # Stack the pose quats and trans for all the non-reference views and reference view 0
+ pose_quats_non_ref_views = torch.cat(pose_quats_non_ref_views, dim=0)
+ pose_trans_non_ref_views = torch.cat(pose_trans_non_ref_views, dim=0)
+ pose_quats_ref_view_0 = torch.cat(pose_quats_ref_view_0, dim=0)
+ pose_trans_ref_view_0 = torch.cat(pose_trans_ref_view_0, dim=0)
+
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ (
+ pose_quats_non_ref_views_in_ref_view_0,
+ pose_trans_non_ref_views_in_ref_view_0,
+ ) = transform_pose_using_quats_and_trans_2_to_1(
+ pose_quats_ref_view_0,
+ pose_trans_ref_view_0,
+ pose_quats_non_ref_views,
+ pose_trans_non_ref_views,
+ )
+
+ # Update the pose quats and trans for all the non-reference views
+ pose_quats_across_views[per_sample_cam_input_mask] = (
+ pose_quats_non_ref_views_in_ref_view_0.to(dtype=dtype)
+ )
+ pose_trans_across_views[per_sample_cam_input_mask] = (
+ pose_trans_non_ref_views_in_ref_view_0.to(dtype=dtype)
+ )
+
+ return (
+ pose_quats_across_views,
+ pose_trans_across_views,
+ per_sample_cam_input_mask,
+ )
+
+ def _encode_and_fuse_ray_dirs(
+ self,
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_ray_dirs_input_mask,
+ ):
+ """
+ Encode the ray directions for all the views and fuse it with the other encoder features in a single forward pass.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ batch_size_per_view (int): Batch size per view.
+ all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views.
+ per_sample_ray_dirs_input_mask (torch.Tensor): Tensor containing the per sample ray direction input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the encoded features for all the views.
+ """
+ # Get the height and width of the images
+ _, _, height, width = views[0]["img"].shape
+
+ # Get the ray directions for all the views where info is provided and the ray direction input mask is True
+ ray_dirs_list = []
+ for view_idx in range(num_views):
+ per_sample_ray_dirs_input_mask_for_curr_view = (
+ per_sample_ray_dirs_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ]
+ )
+ ray_dirs_for_curr_view = torch.zeros(
+ (batch_size_per_view, height, width, 3),
+ dtype=all_encoder_features_across_views.dtype,
+ device=all_encoder_features_across_views.device,
+ )
+ if (
+ "ray_directions_cam" in views[view_idx]
+ and per_sample_ray_dirs_input_mask_for_curr_view.any()
+ ):
+ ray_dirs_for_curr_view[per_sample_ray_dirs_input_mask_for_curr_view] = (
+ views[view_idx]["ray_directions_cam"][
+ per_sample_ray_dirs_input_mask_for_curr_view
+ ]
+ )
+ else:
+ per_sample_ray_dirs_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ] = False
+ ray_dirs_list.append(ray_dirs_for_curr_view)
+
+ # Stack the ray directions for all the views and permute to (B * V, C, H, W)
+ ray_dirs = torch.cat(ray_dirs_list, dim=0) # (B * V, H, W, 3)
+ ray_dirs = ray_dirs.permute(0, 3, 1, 2).contiguous() # (B * V, 3, H, W)
+
+ # Encode the ray directions
+ ray_dirs_features_across_views = self.ray_dirs_encoder(
+ ViTEncoderNonImageInput(data=ray_dirs)
+ ).features
+
+ # Fuse the ray direction features with the other encoder features (zero out the features where the ray direction input mask is False)
+ ray_dirs_features_across_views = (
+ ray_dirs_features_across_views
+ * per_sample_ray_dirs_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ )
+ all_encoder_features_across_views = (
+ all_encoder_features_across_views + ray_dirs_features_across_views
+ )
+
+ return all_encoder_features_across_views
+
+ def _encode_and_fuse_depths(
+ self,
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_depth_input_mask,
+ ):
+ """
+ Encode the z depths for all the views and fuse it with the other encoder features in a single forward pass.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ batch_size_per_view (int): Batch size per view.
+ all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views.
+ per_sample_depth_input_mask (torch.Tensor): Tensor containing the per sample depth input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the encoded features for all the views.
+ """
+ # Get the device and height and width of the images
+ device = all_encoder_features_across_views.device
+ _, _, height, width = views[0]["img"].shape
+
+ # Decide to use randomly sampled sparse depth or dense depth
+ if torch.rand(1) < self.geometric_input_config["sparse_depth_prob"]:
+ use_sparse_depth = True
+ else:
+ use_sparse_depth = False
+
+ # Get the depths for all the views
+ depth_list = []
+ depth_norm_factors_list = []
+ metric_scale_depth_mask_list = []
+ for view_idx in range(num_views):
+ # Get the input mask for current view
+ per_sample_depth_input_mask_for_curr_view = per_sample_depth_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view
+ ]
+ depth_for_curr_view = torch.zeros(
+ (batch_size_per_view, height, width, 1),
+ dtype=all_encoder_features_across_views.dtype,
+ device=device,
+ )
+ depth_norm_factor_for_curr_view = torch.zeros(
+ (batch_size_per_view),
+ dtype=all_encoder_features_across_views.dtype,
+ device=device,
+ )
+ metric_scale_mask_for_curr_view = torch.zeros(
+ (batch_size_per_view),
+ dtype=torch.bool,
+ device=device,
+ )
+ if (
+ "depth_along_ray" in views[view_idx]
+ ) and per_sample_depth_input_mask_for_curr_view.any():
+ # Get depth for current view
+ depth_for_curr_view_input = views[view_idx]["depth_along_ray"][
+ per_sample_depth_input_mask_for_curr_view
+ ]
+ # Get the metric scale mask
+ if "is_metric_scale" in views[view_idx]:
+ metric_scale_mask = views[view_idx]["is_metric_scale"][
+ per_sample_depth_input_mask_for_curr_view
+ ]
+ else:
+ metric_scale_mask = torch.zeros(
+ depth_for_curr_view_input.shape[0],
+ dtype=torch.bool,
+ device=device,
+ )
+ # Turn off indication of metric scale samples based on the depth_scale_norm_all_prob
+ depth_scale_norm_all_mask = (
+ torch.rand(metric_scale_mask.shape[0])
+ < self.geometric_input_config["depth_scale_norm_all_prob"]
+ )
+ if depth_scale_norm_all_mask.any():
+ metric_scale_mask[depth_scale_norm_all_mask] = False
+ # Assign the metric scale mask to the respective indices
+ metric_scale_mask_for_curr_view[
+ per_sample_depth_input_mask_for_curr_view
+ ] = metric_scale_mask
+ # Sparsely sample the depth if required
+ if use_sparse_depth:
+ # Create a mask of ones
+ sparsification_mask = torch.ones_like(
+ depth_for_curr_view_input, device=device
+ )
+ # Create a mask for valid pixels (depth > 0)
+ valid_pixel_mask = depth_for_curr_view_input > 0
+ # Calculate the number of valid pixels
+ num_valid_pixels = valid_pixel_mask.sum().item()
+ # Calculate the number of valid pixels to set to zero
+ num_to_zero = int(
+ num_valid_pixels
+ * self.geometric_input_config["sparsification_removal_percent"]
+ )
+ if num_to_zero > 0:
+ # Get the indices of valid pixels
+ valid_indices = valid_pixel_mask.nonzero(as_tuple=True)
+ # Randomly select indices to zero out
+ indices_to_zero = torch.randperm(num_valid_pixels)[:num_to_zero]
+ # Set selected valid indices to zero in the mask
+ sparsification_mask[
+ valid_indices[0][indices_to_zero],
+ valid_indices[1][indices_to_zero],
+ valid_indices[2][indices_to_zero],
+ valid_indices[3][indices_to_zero],
+ ] = 0
+ # Apply the mask on the depth
+ depth_for_curr_view_input = (
+ depth_for_curr_view_input * sparsification_mask
+ )
+ # Normalize the depth
+ scaled_depth_for_curr_view_input, depth_norm_factor = (
+ normalize_depth_using_non_zero_pixels(
+ depth_for_curr_view_input, return_norm_factor=True
+ )
+ )
+ # Assign the depth and depth norm factor to the respective indices
+ depth_for_curr_view[per_sample_depth_input_mask_for_curr_view] = (
+ scaled_depth_for_curr_view_input
+ )
+ depth_norm_factor_for_curr_view[
+ per_sample_depth_input_mask_for_curr_view
+ ] = depth_norm_factor
+ else:
+ per_sample_depth_input_mask[
+ view_idx * batch_size_per_view : (view_idx + 1)
+ * batch_size_per_view
+ ] = False
+ # Append the depths, depth norm factor and metric scale mask for the current view
+ depth_list.append(depth_for_curr_view)
+ depth_norm_factors_list.append(depth_norm_factor_for_curr_view)
+ metric_scale_depth_mask_list.append(metric_scale_mask_for_curr_view)
+
+ # Stack the depths for all the views and permute to (B * V, C, H, W)
+ depths = torch.cat(depth_list, dim=0) # (B * V, H, W, 1)
+ depths = apply_log_to_norm(
+ depths
+ ) # Scale logarithimically (norm is computed along last dim)
+ depths = depths.permute(0, 3, 1, 2).contiguous() # (B * V, 1, H, W)
+ # Encode the depths using the depth encoder
+ depth_features_across_views = self.depth_encoder(
+ ViTEncoderNonImageInput(data=depths)
+ ).features
+ # Zero out the depth features where the depth input mask is False
+ depth_features_across_views = (
+ depth_features_across_views
+ * per_sample_depth_input_mask.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+ )
+
+ # Stack the depth norm factors for all the views
+ depth_norm_factors = torch.cat(depth_norm_factors_list, dim=0) # (B * V, )
+ # Encode the depth norm factors using the log scale encoder for depth
+ log_depth_norm_factors = torch.log(depth_norm_factors + 1e-8) # (B * V, )
+ depth_scale_features_across_views = self.depth_scale_encoder(
+ EncoderGlobalRepInput(data=log_depth_norm_factors.unsqueeze(-1))
+ ).features
+ # Zero out the depth scale features where the depth input mask is False
+ depth_scale_features_across_views = (
+ depth_scale_features_across_views
+ * per_sample_depth_input_mask.unsqueeze(-1)
+ )
+ # Stack the metric scale mask for all the views
+ metric_scale_depth_mask = torch.cat(
+ metric_scale_depth_mask_list, dim=0
+ ) # (B * V, )
+ # Zero out the depth scale features where the metric scale mask is False
+ # Scale encoding is only provided for metric scale samples
+ depth_scale_features_across_views = (
+ depth_scale_features_across_views * metric_scale_depth_mask.unsqueeze(-1)
+ )
+
+ # Fuse the depth features & depth scale features with the other encoder features
+ all_encoder_features_across_views = (
+ all_encoder_features_across_views
+ + depth_features_across_views
+ + depth_scale_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ )
+
+ return all_encoder_features_across_views
+
+ def _encode_and_fuse_cam_quats_and_trans(
+ self,
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ pose_quats_across_views,
+ pose_trans_across_views,
+ per_sample_cam_input_mask,
+ ):
+ """
+ Encode the camera quats and trans for all the views and fuse it with the other encoder features in a single forward pass.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ num_views (int): Number of views.
+ batch_size_per_view (int): Batch size per view.
+ all_encoder_features_across_views (torch.Tensor): Tensor containing the encoded features for all N views.
+ pose_quats_across_views (torch.Tensor): Tensor containing the pose quats for all the views in the frame of the reference view 0. (batch_size_per_view * view, 4)
+ pose_trans_across_views (torch.Tensor): Tensor containing the pose trans for all the views in the frame of the reference view 0. (batch_size_per_view * view, 3)
+ per_sample_cam_input_mask (torch.Tensor): Tensor containing the per sample camera input mask.
+
+ Returns:
+ torch.Tensor: A tensor containing the encoded features for all the views.
+ """
+ # Encode the pose quats
+ pose_quats_features_across_views = self.cam_rot_encoder(
+ EncoderGlobalRepInput(data=pose_quats_across_views)
+ ).features
+ # Zero out the pose quat features where the camera input mask is False
+ pose_quats_features_across_views = (
+ pose_quats_features_across_views * per_sample_cam_input_mask.unsqueeze(-1)
+ )
+
+ # Get the metric scale mask for all samples
+ device = all_encoder_features_across_views.device
+ metric_scale_pose_trans_mask = torch.zeros(
+ (batch_size_per_view * num_views), dtype=torch.bool, device=device
+ )
+ for view_idx in range(num_views):
+ if "is_metric_scale" in views[view_idx]:
+ # Get the metric scale mask for the input pose priors
+ metric_scale_mask = views[view_idx]["is_metric_scale"]
+ else:
+ metric_scale_mask = torch.zeros(
+ batch_size_per_view, dtype=torch.bool, device=device
+ )
+ metric_scale_pose_trans_mask[
+ view_idx * batch_size_per_view : (view_idx + 1) * batch_size_per_view
+ ] = metric_scale_mask
+
+ # Turn off indication of metric scale samples based on the pose_scale_norm_all_prob
+ pose_norm_all_mask = (
+ torch.rand(batch_size_per_view * num_views)
+ < self.geometric_input_config["pose_scale_norm_all_prob"]
+ )
+ if pose_norm_all_mask.any():
+ metric_scale_pose_trans_mask[pose_norm_all_mask] = False
+
+ # Get the scale norm factor for all the samples and scale the pose translations
+ pose_trans_across_views = torch.split(
+ pose_trans_across_views, batch_size_per_view, dim=0
+ ) # Split into num_views chunks
+ pose_trans_across_views = torch.stack(
+ pose_trans_across_views, dim=1
+ ) # Stack the views along a new dimension (batch_size_per_view, num_views, 3)
+ scaled_pose_trans_across_views, pose_trans_norm_factors = (
+ normalize_pose_translations(
+ pose_trans_across_views, return_norm_factor=True
+ )
+ )
+
+ # Resize the pose translation back to (batch_size_per_view * num_views, 3) and extend the norm factor to (batch_size_per_view * num_views, 1)
+ scaled_pose_trans_across_views = scaled_pose_trans_across_views.unbind(
+ dim=1
+ ) # Convert back to list of views, where each view has batch_size_per_view tensor
+ scaled_pose_trans_across_views = torch.cat(
+ scaled_pose_trans_across_views, dim=0
+ ) # Concatenate back to (batch_size_per_view * num_views, 3)
+ pose_trans_norm_factors_across_views = pose_trans_norm_factors.unsqueeze(
+ -1
+ ).repeat(num_views, 1) # (B, ) -> (B * V, 1)
+
+ # Encode the pose trans
+ pose_trans_features_across_views = self.cam_trans_encoder(
+ EncoderGlobalRepInput(data=scaled_pose_trans_across_views)
+ ).features
+ # Zero out the pose trans features where the camera input mask is False
+ pose_trans_features_across_views = (
+ pose_trans_features_across_views * per_sample_cam_input_mask.unsqueeze(-1)
+ )
+
+ # Encode the pose translation norm factors using the log scale encoder for pose trans
+ log_pose_trans_norm_factors_across_views = torch.log(
+ pose_trans_norm_factors_across_views + 1e-8
+ )
+ pose_trans_scale_features_across_views = self.cam_trans_scale_encoder(
+ EncoderGlobalRepInput(data=log_pose_trans_norm_factors_across_views)
+ ).features
+ # Zero out the pose trans scale features where the camera input mask is False
+ pose_trans_scale_features_across_views = (
+ pose_trans_scale_features_across_views
+ * per_sample_cam_input_mask.unsqueeze(-1)
+ )
+ # Zero out the pose trans scale features where the metric scale mask is False
+ # Scale encoding is only provided for metric scale samples
+ pose_trans_scale_features_across_views = (
+ pose_trans_scale_features_across_views
+ * metric_scale_pose_trans_mask.unsqueeze(-1)
+ )
+
+ # Fuse the pose quat features, pose trans features, pose trans scale features and pose trans type PE features with the other encoder features
+ all_encoder_features_across_views = (
+ all_encoder_features_across_views
+ + pose_quats_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ + pose_trans_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ + pose_trans_scale_features_across_views.unsqueeze(-1).unsqueeze(-1)
+ )
+
+ return all_encoder_features_across_views
+
+ def _encode_and_fuse_optional_geometric_inputs(
+ self, views, all_encoder_features_across_views_list
+ ):
+ """
+ Encode all the input optional geometric modalities and fuses it with the image encoder features in a single forward pass.
+ Assumes all the input views have the same shape and batch size.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ all_encoder_features_across_views (List[torch.Tensor]): List of tensors containing the encoded image features for all N views.
+
+ Returns:
+ List[torch.Tensor]: A list containing the encoded features for all N views.
+ """
+ num_views = len(views)
+ batch_size_per_view, _, _, _ = views[0]["img"].shape
+ device = all_encoder_features_across_views_list[0].device
+ dtype = all_encoder_features_across_views_list[0].dtype
+ all_encoder_features_across_views = torch.cat(
+ all_encoder_features_across_views_list, dim=0
+ )
+
+ # Get the overall input mask for all the views
+ overall_geometric_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["overall_prob"]
+ )
+ overall_geometric_input_mask = overall_geometric_input_mask.repeat(num_views)
+
+ # Get the per sample input mask after dropout
+ # Per sample input mask is in view-major order so that index v*B + b in each mask corresponds to sample b of view v: (B * V)
+ per_sample_geometric_input_mask = torch.rand(
+ batch_size_per_view * num_views, device=device
+ ) < (1 - self.geometric_input_config["dropout_prob"])
+ per_sample_geometric_input_mask = (
+ per_sample_geometric_input_mask & overall_geometric_input_mask
+ )
+
+ # Get the ray direction input mask
+ per_sample_ray_dirs_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["ray_dirs_prob"]
+ )
+ per_sample_ray_dirs_input_mask = per_sample_ray_dirs_input_mask.repeat(
+ num_views
+ )
+ per_sample_ray_dirs_input_mask = (
+ per_sample_ray_dirs_input_mask & per_sample_geometric_input_mask
+ )
+
+ # Get the depth input mask
+ per_sample_depth_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["depth_prob"]
+ )
+ per_sample_depth_input_mask = per_sample_depth_input_mask.repeat(num_views)
+ per_sample_depth_input_mask = (
+ per_sample_depth_input_mask & per_sample_geometric_input_mask
+ )
+
+ # Get the camera input mask
+ per_sample_cam_input_mask = (
+ torch.rand(batch_size_per_view, device=device)
+ < self.geometric_input_config["cam_prob"]
+ )
+ per_sample_cam_input_mask = per_sample_cam_input_mask.repeat(num_views)
+ per_sample_cam_input_mask = (
+ per_sample_cam_input_mask & per_sample_geometric_input_mask
+ )
+
+ # Compute the pose quats and trans for all the non-reference views in the frame of the reference view 0
+ # Returned pose quats and trans represent identity pose for views/samples where the camera input mask is False
+ pose_quats_across_views, pose_trans_across_views, per_sample_cam_input_mask = (
+ self._compute_pose_quats_and_trans_for_across_views_in_ref_view(
+ views,
+ num_views,
+ device,
+ dtype,
+ batch_size_per_view,
+ per_sample_cam_input_mask,
+ )
+ )
+
+ # Encode the ray directions and fuse with the image encoder features
+ all_encoder_features_across_views = self._encode_and_fuse_ray_dirs(
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_ray_dirs_input_mask,
+ )
+
+ # Encode the depths and fuse with the image encoder features
+ all_encoder_features_across_views = self._encode_and_fuse_depths(
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ per_sample_depth_input_mask,
+ )
+
+ # Encode the cam quat and trans and fuse with the image encoder features
+ all_encoder_features_across_views = self._encode_and_fuse_cam_quats_and_trans(
+ views,
+ num_views,
+ batch_size_per_view,
+ all_encoder_features_across_views,
+ pose_quats_across_views,
+ pose_trans_across_views,
+ per_sample_cam_input_mask,
+ )
+
+ # Normalize the fused features (permute -> normalize -> permute)
+ all_encoder_features_across_views = all_encoder_features_across_views.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ all_encoder_features_across_views = self.fusion_norm_layer(
+ all_encoder_features_across_views
+ )
+ all_encoder_features_across_views = all_encoder_features_across_views.permute(
+ 0, 3, 1, 2
+ ).contiguous()
+
+ # Split the batched views into individual views
+ fused_all_encoder_features_across_views = (
+ all_encoder_features_across_views.chunk(num_views, dim=0)
+ )
+
+ return fused_all_encoder_features_across_views
+
+ def _compute_adaptive_minibatch_size(
+ self,
+ memory_safety_factor: float = 0.95,
+ ) -> int:
+ """
+ Compute adaptive minibatch size based on available PyTorch memory.
+
+ Args:
+ memory_safety_factor: Safety factor to avoid OOM (0.95 = use 95% of available memory)
+
+ Returns:
+ Computed minibatch size
+ """
+ device = self.device
+
+ if device.type == "cuda":
+ # Get available GPU memory
+ torch.cuda.empty_cache()
+ available_memory = torch.cuda.mem_get_info()[0] # Free memory in bytes
+ usable_memory = (
+ available_memory * memory_safety_factor
+ ) # Use safety factor to avoid OOM
+ else:
+ # For non-CUDA devices, use conservative default
+ print(
+ "Non-CUDA device detected. Using conservative default minibatch size of 1 for memory efficient dense prediction head inference."
+ )
+ return 1
+
+ # Determine minibatch size based on available memory
+ max_estimated_memory_per_sample = (
+ 680 * 1024 * 1024
+ ) # 680 MB per sample (upper bound profiling using a 518 x 518 input)
+ computed_minibatch_size = int(usable_memory / max_estimated_memory_per_sample)
+ if computed_minibatch_size < 1:
+ computed_minibatch_size = 1
+
+ return computed_minibatch_size
+
+ def downstream_dense_head(
+ self,
+ dense_head_inputs: Union[torch.Tensor, List[torch.Tensor]],
+ img_shape: Tuple[int, int],
+ ):
+ """
+ Run the downstream dense prediction head
+ """
+ if self.pred_head_type == "linear":
+ dense_head_outputs = self.dense_head(
+ PredictionHeadInput(last_feature=dense_head_inputs)
+ )
+ dense_final_outputs = self.dense_adaptor(
+ AdaptorInput(
+ adaptor_feature=dense_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ elif self.pred_head_type in ["dpt", "dpt+pose"]:
+ dense_head_outputs = self.dense_head(
+ PredictionHeadLayeredInput(
+ list_features=dense_head_inputs,
+ target_output_shape=img_shape,
+ )
+ )
+ dense_final_outputs = self.dense_adaptor(
+ AdaptorInput(
+ adaptor_feature=dense_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ return dense_final_outputs
+
+ def downstream_head(
+ self,
+ dense_head_inputs: Union[torch.Tensor, List[torch.Tensor]],
+ scale_head_inputs: torch.Tensor,
+ img_shape: Tuple[int, int],
+ memory_efficient_inference: bool = False,
+ ):
+ """
+ Run Prediction Heads & Post-Process Outputs
+ """
+ # Get device
+ device = self.device
+
+ # Use mini-batch inference to run the dense prediction head (the memory bottleneck)
+ # This saves memory and is slower than running the dense prediction head in one go
+ if memory_efficient_inference:
+ # Obtain the batch size of the dense head inputs
+ if self.pred_head_type == "linear":
+ batch_size = dense_head_inputs.shape[0]
+ elif self.pred_head_type in ["dpt", "dpt+pose"]:
+ batch_size = dense_head_inputs[0].shape[0]
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ # Compute the mini batch size and number of mini batches adaptively based on available memory
+ minibatch = self._compute_adaptive_minibatch_size()
+ num_batches = (batch_size + minibatch - 1) // minibatch
+
+ # Run prediction for each mini-batch
+ dense_final_outputs_list = []
+ pose_final_outputs_list = [] if self.pred_head_type == "dpt+pose" else None
+ for batch_idx in range(num_batches):
+ start_idx = batch_idx * minibatch
+ end_idx = min((batch_idx + 1) * minibatch, batch_size)
+
+ # Get the inputs for the current mini-batch
+ if self.pred_head_type == "linear":
+ dense_head_inputs_batch = dense_head_inputs[start_idx:end_idx]
+ elif self.pred_head_type in ["dpt", "dpt+pose"]:
+ dense_head_inputs_batch = [
+ x[start_idx:end_idx] for x in dense_head_inputs
+ ]
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ # Dense prediction (mini-batched)
+ dense_final_outputs_batch = self.downstream_dense_head(
+ dense_head_inputs_batch, img_shape
+ )
+ dense_final_outputs_list.append(dense_final_outputs_batch)
+
+ # Pose prediction (mini-batched)
+ if self.pred_head_type == "dpt+pose":
+ pose_head_inputs_batch = dense_head_inputs[-1][start_idx:end_idx]
+ pose_head_outputs_batch = self.pose_head(
+ PredictionHeadInput(last_feature=pose_head_inputs_batch)
+ )
+ pose_final_outputs_batch = self.pose_adaptor(
+ AdaptorInput(
+ adaptor_feature=pose_head_outputs_batch.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ pose_final_outputs_list.append(pose_final_outputs_batch)
+
+ # Concatenate the dense prediction head outputs from all mini-batches
+ available_keys = dense_final_outputs_batch.__dict__.keys()
+ dense_pred_data_dict = {
+ key: torch.cat(
+ [getattr(output, key) for output in dense_final_outputs_list], dim=0
+ )
+ for key in available_keys
+ }
+ dense_final_outputs = dense_final_outputs_batch.__class__(
+ **dense_pred_data_dict
+ )
+
+ # Concatenate the pose prediction head outputs from all mini-batches
+ pose_final_outputs = None
+ if self.pred_head_type == "dpt+pose":
+ available_keys = pose_final_outputs_batch.__dict__.keys()
+ pose_pred_data_dict = {
+ key: torch.cat(
+ [getattr(output, key) for output in pose_final_outputs_list],
+ dim=0,
+ )
+ for key in available_keys
+ }
+ pose_final_outputs = pose_final_outputs_batch.__class__(
+ **pose_pred_data_dict
+ )
+
+ # Clear CUDA cache for better memory efficiency
+ if device.type == "cuda":
+ torch.cuda.empty_cache()
+ else:
+ # Run prediction for all (batch_size * num_views) in one go
+ # Dense prediction
+ dense_final_outputs = self.downstream_dense_head(
+ dense_head_inputs, img_shape
+ )
+
+ # Pose prediction
+ pose_final_outputs = None
+ if self.pred_head_type == "dpt+pose":
+ pose_head_outputs = self.pose_head(
+ PredictionHeadInput(last_feature=dense_head_inputs[-1])
+ )
+ pose_final_outputs = self.pose_adaptor(
+ AdaptorInput(
+ adaptor_feature=pose_head_outputs.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+
+ # Scale prediction is lightweight, so we can run it in one go
+ scale_head_output = self.scale_head(
+ PredictionHeadTokenInput(last_feature=scale_head_inputs)
+ )
+ scale_final_output = self.scale_adaptor(
+ AdaptorInput(
+ adaptor_feature=scale_head_output.decoded_channels,
+ output_shape_hw=img_shape,
+ )
+ )
+ scale_final_output = scale_final_output.value.squeeze(-1) # (B, 1, 1) -> (B, 1)
+
+ # Clear CUDA cache for better memory efficiency
+ if memory_efficient_inference and device.type == "cuda":
+ torch.cuda.empty_cache()
+
+ return dense_final_outputs, pose_final_outputs, scale_final_output
+
+ def forward(self, views, memory_efficient_inference=False):
+ """
+ Forward pass performing the following operations:
+ 1. Encodes the N input views (images).
+ 2. Encodes the optional geometric inputs (ray directions, depths, camera rotations, camera translations).
+ 3. Fuses the encoded features from the N input views and the optional geometric inputs using addition and normalization.
+ 4. Information sharing across the encoded features and a scale token using a multi-view attention transformer.
+ 5. Passes the final features from transformer through the prediction heads.
+ 6. Returns the processed final outputs for N views.
+
+ Assumption:
+ - All the input views and dense geometric inputs have the same image shape.
+
+ Args:
+ views (List[dict]): List of dictionaries containing the input views' images and instance information.
+ Each dictionary should contain the following keys:
+ "img" (tensor): Image tensor of shape (B, C, H, W). Input images must be normalized based on the data norm type of image encoder.
+ "data_norm_type" (list): [model.encoder.data_norm_type]
+ Optionally, each dictionary can also contain the following keys for the respective optional geometric inputs:
+ "ray_directions_cam" (tensor): Ray directions in the local camera frame. Tensor of shape (B, H, W, 3).
+ "depth_along_ray" (tensor): Depth along the ray. Tensor of shape (B, H, W, 1).
+ "camera_pose_quats" (tensor): Camera pose quaternions. Tensor of shape (B, 4). Camera pose is opencv (RDF) cam2world transformation.
+ "camera_pose_trans" (tensor): Camera pose translations. Tensor of shape (B, 3). Camera pose is opencv (RDF) cam2world transformation.
+ "is_metric_scale" (tensor): Boolean tensor indicating whether the geometric inputs are in metric scale or not. Tensor of shape (B, 1).
+ memory_efficient_inference (bool): Whether to use memory efficient inference or not. This runs the dense prediction head (the memory bottleneck) in a memory efficient manner. Default is False.
+
+ Returns:
+ List[dict]: A list containing the final outputs for all N views.
+ """
+ # Get input shape of the images, number of views, and batch size per view
+ batch_size_per_view, _, height, width = views[0]["img"].shape
+ img_shape = (int(height), int(width))
+ num_views = len(views)
+
+ # Run the image encoder on all the input views
+ all_encoder_features_across_views = self._encode_n_views(views)
+
+ # Encode the optional geometric inputs and fuse with the encoded features from the N input views
+ # Use high precision to prevent NaN values after layer norm in dense representation encoder (due to high variance in last dim of features)
+ with torch.autocast("cuda", enabled=False):
+ all_encoder_features_across_views = (
+ self._encode_and_fuse_optional_geometric_inputs(
+ views, all_encoder_features_across_views
+ )
+ )
+
+ # Expand the scale token to match the batch size
+ input_scale_token = (
+ self.scale_token.unsqueeze(0)
+ .unsqueeze(-1)
+ .repeat(batch_size_per_view, 1, 1)
+ ) # (B, C, 1)
+
+ # Combine all images into view-centric representation
+ # Output is a list containing the encoded features for all N views after information sharing.
+ info_sharing_input = MultiViewTransformerInput(
+ features=all_encoder_features_across_views,
+ additional_input_tokens=input_scale_token,
+ )
+ if self.info_sharing_return_type == "no_intermediate_features":
+ final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input)
+ elif self.info_sharing_return_type == "intermediate_features":
+ (
+ final_info_sharing_multi_view_feat,
+ intermediate_info_sharing_multi_view_feat,
+ ) = self.info_sharing(info_sharing_input)
+
+ if self.pred_head_type == "linear":
+ # Stack the features for all views
+ dense_head_inputs = torch.cat(
+ final_info_sharing_multi_view_feat.features, dim=0
+ )
+ elif self.pred_head_type in ["dpt", "dpt+pose"]:
+ # Get the list of features for all views
+ dense_head_inputs_list = []
+ if self.use_encoder_features_for_dpt:
+ # Stack all the image encoder features for all views
+ stacked_encoder_features = torch.cat(
+ all_encoder_features_across_views, dim=0
+ )
+ dense_head_inputs_list.append(stacked_encoder_features)
+ # Stack the first intermediate features for all views
+ stacked_intermediate_features_1 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[0].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_1)
+ # Stack the second intermediate features for all views
+ stacked_intermediate_features_2 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[1].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_2)
+ # Stack the last layer features for all views
+ stacked_final_features = torch.cat(
+ final_info_sharing_multi_view_feat.features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_final_features)
+ else:
+ # Stack the first intermediate features for all views
+ stacked_intermediate_features_1 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[0].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_1)
+ # Stack the second intermediate features for all views
+ stacked_intermediate_features_2 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[1].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_2)
+ # Stack the third intermediate features for all views
+ stacked_intermediate_features_3 = torch.cat(
+ intermediate_info_sharing_multi_view_feat[2].features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_intermediate_features_3)
+ # Stack the last layer
+ stacked_final_features = torch.cat(
+ final_info_sharing_multi_view_feat.features, dim=0
+ )
+ dense_head_inputs_list.append(stacked_final_features)
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt', 'dpt+pose']"
+ )
+
+ with torch.autocast("cuda", enabled=False):
+ # Prepare inputs for the downstream heads
+ if self.pred_head_type == "linear":
+ dense_head_inputs = dense_head_inputs
+ elif self.pred_head_type in ["dpt", "dpt+pose"]:
+ dense_head_inputs = dense_head_inputs_list
+ scale_head_inputs = (
+ final_info_sharing_multi_view_feat.additional_token_features
+ )
+
+ # Run the downstream heads
+ dense_final_outputs, pose_final_outputs, scale_final_output = (
+ self.downstream_head(
+ dense_head_inputs=dense_head_inputs,
+ scale_head_inputs=scale_head_inputs,
+ img_shape=img_shape,
+ memory_efficient_inference=memory_efficient_inference,
+ )
+ )
+
+ # Prepare the final scene representation for all views
+ if self.scene_rep_type in [
+ "pointmap",
+ "pointmap+confidence",
+ "pointmap+mask",
+ "pointmap+confidence+mask",
+ ]:
+ output_pts3d = dense_final_outputs.value
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_pts3d = output_pts3d.permute(0, 2, 3, 1).contiguous()
+ # Split the predicted pointmaps back to their respective views
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "metric_scaling_factor": scale_final_output,
+ }
+ )
+ elif self.scene_rep_type in [
+ "raymap+depth",
+ "raymap+depth+confidence",
+ "raymap+depth+mask",
+ "raymap+depth+confidence+mask",
+ ]:
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_scene_rep = dense_final_outputs.value.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ # Get the predicted ray origins, directions, and depths along rays
+ output_ray_origins, output_ray_directions, output_depth_along_ray = (
+ output_scene_rep.split([3, 3, 1], dim=-1)
+ )
+ # Get the predicted pointmaps
+ output_pts3d = (
+ output_ray_origins + output_ray_directions * output_depth_along_ray
+ )
+ # Split the predicted quantities back to their respective views
+ output_ray_origins_per_view = output_ray_origins.chunk(num_views, dim=0)
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "ray_origins": output_ray_origins_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "metric_scaling_factor": scale_final_output,
+ }
+ )
+ elif self.scene_rep_type in [
+ "raydirs+depth+pose",
+ "raydirs+depth+pose+confidence",
+ "raydirs+depth+pose+mask",
+ "raydirs+depth+pose+confidence+mask",
+ ]:
+ # Reshape output dense rep to (B * V, H, W, C)
+ output_dense_rep = dense_final_outputs.value.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ # Get the predicted ray directions and depths along rays
+ output_ray_directions, output_depth_along_ray = output_dense_rep.split(
+ [3, 1], dim=-1
+ )
+ # Get the predicted camera translations and quaternions
+ output_cam_translations, output_cam_quats = (
+ pose_final_outputs.value.split([3, 4], dim=-1)
+ )
+ # Get the predicted pointmaps in world frame and camera frame
+ output_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ output_ray_directions,
+ output_depth_along_ray,
+ output_cam_translations,
+ output_cam_quats,
+ )
+ )
+ output_pts3d_cam = output_ray_directions * output_depth_along_ray
+ # Split the predicted quantities back to their respective views
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_cam_translations_per_view = output_cam_translations.chunk(
+ num_views, dim=0
+ )
+ output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0)
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "pts3d_cam": output_pts3d_cam_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "cam_trans": output_cam_translations_per_view[i]
+ * scale_final_output,
+ "cam_quats": output_cam_quats_per_view[i],
+ "metric_scaling_factor": scale_final_output,
+ }
+ )
+ elif self.scene_rep_type in [
+ "campointmap+pose",
+ "campointmap+pose+confidence",
+ "campointmap+pose+mask",
+ "campointmap+pose+confidence+mask",
+ ]:
+ # Get the predicted camera frame pointmaps
+ output_pts3d_cam = dense_final_outputs.value
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_pts3d_cam = output_pts3d_cam.permute(0, 2, 3, 1).contiguous()
+ # Get the predicted camera translations and quaternions
+ output_cam_translations, output_cam_quats = (
+ pose_final_outputs.value.split([3, 4], dim=-1)
+ )
+ # Get the ray directions and depths along rays
+ output_depth_along_ray = torch.norm(
+ output_pts3d_cam, dim=-1, keepdim=True
+ )
+ output_ray_directions = output_pts3d_cam / output_depth_along_ray
+ # Get the predicted pointmaps in world frame
+ output_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ output_ray_directions,
+ output_depth_along_ray,
+ output_cam_translations,
+ output_cam_quats,
+ )
+ )
+ # Split the predicted quantities back to their respective views
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_cam_translations_per_view = output_cam_translations.chunk(
+ num_views, dim=0
+ )
+ output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0)
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "pts3d_cam": output_pts3d_cam_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "cam_trans": output_cam_translations_per_view[i]
+ * scale_final_output,
+ "cam_quats": output_cam_quats_per_view[i],
+ "metric_scaling_factor": scale_final_output,
+ }
+ )
+ elif self.scene_rep_type in [
+ "pointmap+raydirs+depth+pose",
+ "pointmap+raydirs+depth+pose+confidence",
+ "pointmap+raydirs+depth+pose+mask",
+ "pointmap+raydirs+depth+pose+confidence+mask",
+ ]:
+ # Reshape final scene representation to (B * V, H, W, C)
+ output_dense_rep = dense_final_outputs.value.permute(
+ 0, 2, 3, 1
+ ).contiguous()
+ # Get the predicted pointmaps, ray directions and depths along rays
+ output_pts3d, output_ray_directions, output_depth_along_ray = (
+ output_dense_rep.split([3, 3, 1], dim=-1)
+ )
+ # Get the predicted camera translations and quaternions
+ output_cam_translations, output_cam_quats = (
+ pose_final_outputs.value.split([3, 4], dim=-1)
+ )
+ # Get the predicted pointmaps in camera frame
+ output_pts3d_cam = output_ray_directions * output_depth_along_ray
+ # Replace the predicted world-frame pointmaps if required
+ if self.pred_head_config["adaptor_config"][
+ "use_factored_predictions_for_global_pointmaps"
+ ]:
+ output_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ output_ray_directions,
+ output_depth_along_ray,
+ output_cam_translations,
+ output_cam_quats,
+ )
+ )
+ # Split the predicted quantities back to their respective views
+ output_ray_directions_per_view = output_ray_directions.chunk(
+ num_views, dim=0
+ )
+ output_depth_along_ray_per_view = output_depth_along_ray.chunk(
+ num_views, dim=0
+ )
+ output_cam_translations_per_view = output_cam_translations.chunk(
+ num_views, dim=0
+ )
+ output_cam_quats_per_view = output_cam_quats.chunk(num_views, dim=0)
+ output_pts3d_per_view = output_pts3d.chunk(num_views, dim=0)
+ output_pts3d_cam_per_view = output_pts3d_cam.chunk(num_views, dim=0)
+ # Pack the output as a list of dictionaries
+ res = []
+ for i in range(num_views):
+ res.append(
+ {
+ "pts3d": output_pts3d_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "pts3d_cam": output_pts3d_cam_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "ray_directions": output_ray_directions_per_view[i],
+ "depth_along_ray": output_depth_along_ray_per_view[i]
+ * scale_final_output.unsqueeze(-1).unsqueeze(-1),
+ "cam_trans": output_cam_translations_per_view[i]
+ * scale_final_output,
+ "cam_quats": output_cam_quats_per_view[i],
+ "metric_scaling_factor": scale_final_output,
+ }
+ )
+ else:
+ raise ValueError(
+ f"Invalid scene_rep_type: {self.scene_rep_type}. \
+ Valid options: ['pointmap', 'raymap+depth', 'raydirs+depth+pose', 'campointmap+pose', 'pointmap+raydirs+depth+pose' \
+ 'pointmap+confidence', 'raymap+depth+confidence', 'raydirs+depth+pose+confidence', 'campointmap+pose+confidence', 'pointmap+raydirs+depth+pose+confidence' \
+ 'pointmap+mask', 'raymap+depth+mask', 'raydirs+depth+pose+mask', 'campointmap+pose+mask', 'pointmap+raydirs+depth+pose+mask' \
+ 'pointmap+confidence+mask', 'raymap+depth+confidence+mask', 'raydirs+depth+pose+confidence+mask', 'campointmap+pose+confidence+mask', 'pointmap+raydirs+depth+pose+confidence+mask']"
+ )
+
+ # Get the output confidences for all views (if available) and add them to the result
+ if "confidence" in self.scene_rep_type:
+ output_confidences = dense_final_outputs.confidence
+ # Reshape confidences to (B * V, H, W)
+ output_confidences = (
+ output_confidences.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ )
+ # Split the predicted confidences back to their respective views
+ output_confidences_per_view = output_confidences.chunk(num_views, dim=0)
+ # Add the confidences to the result
+ for i in range(num_views):
+ res[i]["conf"] = output_confidences_per_view[i]
+
+ # Get the output masks (and logits) for all views (if available) and add them to the result
+ if "mask" in self.scene_rep_type:
+ # Get the output masks
+ output_masks = dense_final_outputs.mask
+ # Reshape masks to (B * V, H, W)
+ output_masks = output_masks.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ # Threshold the masks at 0.5 to get binary masks (0: ambiguous, 1: non-ambiguous)
+ output_masks = output_masks > 0.5
+ # Split the predicted masks back to their respective views
+ output_masks_per_view = output_masks.chunk(num_views, dim=0)
+ # Get the output mask logits (for loss)
+ output_mask_logits = dense_final_outputs.logits
+ # Reshape mask logits to (B * V, H, W)
+ output_mask_logits = (
+ output_mask_logits.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ )
+ # Split the predicted mask logits back to their respective views
+ output_mask_logits_per_view = output_mask_logits.chunk(num_views, dim=0)
+ # Add the masks and logits to the result
+ for i in range(num_views):
+ res[i]["non_ambiguous_mask"] = output_masks_per_view[i]
+ res[i]["non_ambiguous_mask_logits"] = output_mask_logits_per_view[i]
+
+ return res
+
+ def _configure_geometric_input_config(
+ self,
+ use_calibration: bool,
+ use_depth: bool,
+ use_pose: bool,
+ use_depth_scale: bool,
+ use_pose_scale: bool,
+ ):
+ """
+ Configure the geometric input configuration
+ """
+ # Store original config for restoration
+ if not hasattr(self, "_original_geometric_config"):
+ self._original_geometric_config = dict(self.geometric_input_config)
+
+ # Set the geometric input configuration
+ if not (use_calibration or use_depth or use_pose):
+ # No geometric inputs (images-only mode)
+ self.geometric_input_config.update(
+ {
+ "overall_prob": 0.0,
+ "dropout_prob": 1.0,
+ "ray_dirs_prob": 0.0,
+ "depth_prob": 0.0,
+ "cam_prob": 0.0,
+ "sparse_depth_prob": 0.0,
+ "depth_scale_norm_all_prob": 0.0,
+ "pose_scale_norm_all_prob": 0.0,
+ }
+ )
+ else:
+ # Enable geometric inputs with deterministic behavior
+ self.geometric_input_config.update(
+ {
+ "overall_prob": 1.0,
+ "dropout_prob": 0.0,
+ "ray_dirs_prob": 1.0 if use_calibration else 0.0,
+ "depth_prob": 1.0 if use_depth else 0.0,
+ "cam_prob": 1.0 if use_pose else 0.0,
+ "sparse_depth_prob": 0.0,
+ "depth_scale_norm_all_prob": 0.0 if use_depth_scale else 1.0,
+ "pose_scale_norm_all_prob": 0.0 if use_pose_scale else 1.0,
+ }
+ )
+
+ def _restore_original_geometric_input_config(self):
+ """
+ Restore original geometric input configuration
+ """
+ if hasattr(self, "_original_geometric_config"):
+ self.geometric_input_config.update(self._original_geometric_config)
+
+ @torch.inference_mode()
+ def infer(
+ self,
+ views: List[Dict[str, Any]],
+ memory_efficient_inference: bool = False,
+ use_amp: bool = True,
+ amp_dtype: str = "bf16",
+ apply_mask: bool = True,
+ mask_edges: bool = True,
+ edge_normal_threshold: float = 5.0,
+ edge_depth_threshold: float = 0.03,
+ apply_confidence_mask: bool = False,
+ confidence_percentile: float = 10,
+ ignore_calibration_inputs: bool = False,
+ ignore_depth_inputs: bool = False,
+ ignore_pose_inputs: bool = False,
+ ignore_depth_scale_inputs: bool = False,
+ ignore_pose_scale_inputs: bool = False,
+ ) -> List[Dict[str, torch.Tensor]]:
+ """
+ User-friendly inference with strict input validation and automatic conversion.
+
+ Args:
+ views: List of view dictionaries. Each dict can contain:
+ Required:
+ - 'img': torch.Tensor of shape (B, 3, H, W) - normalized RGB images
+ - 'data_norm_type': str - normalization type used to normalize the images (must be equal to self.model.encoder.data_norm_type)
+
+ Optional Geometric Inputs (only one of intrinsics OR ray_directions):
+ - 'intrinsics': torch.Tensor of shape (B, 3, 3) - will be converted to ray directions
+ - 'ray_directions': torch.Tensor of shape (B, H, W, 3) - ray directions in camera frame
+ - 'depth_z': torch.Tensor of shape (B, H, W, 1) - Z depth in camera frame (intrinsics or ray_directions must be provided)
+ - 'camera_poses': torch.Tensor of shape (B, 4, 4) or tuple of (quats - (B, 4), trans - (B, 3)) - can be any world frame
+ - 'is_metric_scale': bool or torch.Tensor of shape (B,) - if not provided, defaults to True
+
+ Optional Additional Info:
+ - 'instance': List[str] where length of list is B - instance info for each view
+ - 'idx': List[int] where length of list is B - index info for each view
+ - 'true_shape': List[tuple] where length of list is B - true shape info (H, W) for each view
+
+ memory_efficient_inference: Whether to use memory-efficient inference for dense prediction heads (trades off speed). Defaults to False.
+ use_amp: Whether to use automatic mixed precision for faster inference. Defaults to True.
+ amp_dtype: The dtype to use for mixed precision. Defaults to "bf16" (bfloat16). Options: "fp16", "bf16", "fp32".
+ apply_mask: Whether to apply the non-ambiguous mask to the output. Defaults to True.
+ mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
+ edge_normal_threshold: Tolerance threshold for normals-based edge detection. Defaults to 5.0.
+ edge_depth_threshold: Relative tolerance threshold for depth-based edge detection. Defaults to 0.03.
+ apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False.
+ confidence_percentile: The percentile to use for the confidence threshold. Defaults to 10.
+ ignore_calibration_inputs: Whether to ignore the calibration inputs (intrinsics and ray_directions). Defaults to False.
+ ignore_depth_inputs: Whether to ignore the depth inputs. Defaults to False.
+ ignore_pose_inputs: Whether to ignore the pose inputs. Defaults to False.
+ ignore_depth_scale_inputs: Whether to ignore the depth scale inputs. Defaults to False.
+ ignore_pose_scale_inputs: Whether to ignore the pose scale inputs. Defaults to False.
+
+ IMPORTANT CONSTRAINTS:
+ - Cannot provide both 'intrinsics' and 'ray_directions' (they represent the same information)
+ - If 'depth' is provided, then 'intrinsics' or 'ray_directions' must also be provided
+ - If ANY view has 'camera_poses', then view 0 (first view) MUST also have 'camera_poses'
+
+ Returns:
+ List of prediction dictionaries, one per view. Each dict contains:
+ - 'img_no_norm': torch.Tensor of shape (B, H, W, 3) - denormalized rgb images
+ - 'pts3d': torch.Tensor of shape (B, H, W, 3) - predicted points in world frame
+ - 'pts3d_cam': torch.Tensor of shape (B, H, W, 3) - predicted points in camera frame
+ - 'ray_directions': torch.Tensor of shape (B, H, W, 3) - ray directions in camera frame
+ - 'intrinsics': torch.Tensor of shape (B, 3, 3) - pinhole camera intrinsics recovered from ray directions
+ - 'depth_along_ray': torch.Tensor of shape (B, H, W, 1) - depth along ray in camera frame
+ - 'depth_z': torch.Tensor of shape (B, H, W, 1) - Z depth in camera frame
+ - 'cam_trans': torch.Tensor of shape (B, 3) - camera translation in world frame
+ - 'cam_quats': torch.Tensor of shape (B, 4) - camera quaternion in world frame
+ - 'camera_poses': torch.Tensor of shape (B, 4, 4) - camera pose in world frame
+ - 'metric_scaling_factor': torch.Tensor of shape (B,) - applied metric scaling factor
+ - 'mask': torch.Tensor of shape (B, H, W, 1) - combo of non-ambiguous mask, edge mask and confidence-based mask if used
+ - 'non_ambiguous_mask': torch.Tensor of shape (B, H, W) - non-ambiguous mask
+ - 'non_ambiguous_mask_logits': torch.Tensor of shape (B, H, W) - non-ambiguous mask logits
+ - 'conf': torch.Tensor of shape (B, H, W) - confidence
+
+ Raises:
+ ValueError: For invalid inputs, missing required keys, conflicting modalities, or constraint violations
+ """
+ # Determine the mixed precision floating point type
+ if use_amp:
+ if amp_dtype == "fp16":
+ amp_dtype = torch.float16
+ elif amp_dtype == "bf16":
+ if torch.cuda.is_bf16_supported():
+ amp_dtype = torch.bfloat16
+ else:
+ warnings.warn(
+ "bf16 is not supported on this device. Using fp16 instead."
+ )
+ amp_dtype = torch.float16
+ elif amp_dtype == "fp32":
+ amp_dtype = torch.float32
+ else:
+ amp_dtype = torch.float32
+
+ # Validate the input views
+ validated_views = validate_input_views_for_inference(views)
+
+ # Transfer the views to the same device as the model
+ ignore_keys = set(
+ [
+ "instance",
+ "idx",
+ "true_shape",
+ "data_norm_type",
+ ]
+ )
+ for view in validated_views:
+ for name in view.keys():
+ if name in ignore_keys:
+ continue
+ view[name] = view[name].to(self.device, non_blocking=True)
+
+ # Pre-process the input views
+ processed_views = preprocess_input_views_for_inference(validated_views)
+
+ # Set the model input probabilities based on input args for ignoring inputs
+ self._configure_geometric_input_config(
+ use_calibration=not ignore_calibration_inputs,
+ use_depth=not ignore_depth_inputs,
+ use_pose=not ignore_pose_inputs,
+ use_depth_scale=not ignore_depth_scale_inputs,
+ use_pose_scale=not ignore_pose_scale_inputs,
+ )
+
+ # Run the model
+ with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype):
+ preds = self.forward(
+ processed_views, memory_efficient_inference=memory_efficient_inference
+ )
+
+ # Post-process the model outputs
+ preds = postprocess_model_outputs_for_inference(
+ raw_outputs=preds,
+ input_views=processed_views,
+ apply_mask=apply_mask,
+ mask_edges=mask_edges,
+ edge_normal_threshold=edge_normal_threshold,
+ edge_depth_threshold=edge_depth_threshold,
+ apply_confidence_mask=apply_confidence_mask,
+ confidence_percentile=confidence_percentile,
+ )
+
+ # Restore the original configuration
+ self._restore_original_geometric_input_config()
+
+ return preds
diff --git a/mapanything/models/mapanything/modular_dust3r.py b/mapanything/models/mapanything/modular_dust3r.py
new file mode 100644
index 0000000000000000000000000000000000000000..00914e7830eaf9e11c91a9d0384461fb3536ae98
--- /dev/null
+++ b/mapanything/models/mapanything/modular_dust3r.py
@@ -0,0 +1,470 @@
+"""
+Modular DUSt3R class defined using UniCeption modules.
+"""
+
+from typing import Callable, Dict
+
+import torch
+import torch.nn as nn
+
+from uniception.models.encoders import encoder_factory, ViTEncoderInput
+from uniception.models.info_sharing.alternating_attention_transformer import (
+ MultiViewAlternatingAttentionTransformer,
+ MultiViewAlternatingAttentionTransformerIFR,
+)
+from uniception.models.info_sharing.base import MultiViewTransformerInput
+from uniception.models.info_sharing.cross_attention_transformer import (
+ MultiViewCrossAttentionTransformer,
+ MultiViewCrossAttentionTransformerIFR,
+)
+from uniception.models.info_sharing.global_attention_transformer import (
+ MultiViewGlobalAttentionTransformer,
+ MultiViewGlobalAttentionTransformerIFR,
+)
+from uniception.models.libs.croco.pos_embed import RoPE2D
+from uniception.models.prediction_heads.adaptors import PointMapWithConfidenceAdaptor
+from uniception.models.prediction_heads.base import (
+ AdaptorInput,
+ PredictionHeadInput,
+ PredictionHeadLayeredInput,
+)
+from uniception.models.prediction_heads.dpt import DPTFeature, DPTRegressionProcessor
+from uniception.models.prediction_heads.linear import LinearFeature
+
+# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12)
+if hasattr(torch.backends.cuda, "matmul") and hasattr(
+ torch.backends.cuda.matmul, "allow_tf32"
+):
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+
+class ModularDUSt3R(nn.Module):
+ "Modular DUSt3R model class."
+
+ def __init__(
+ self,
+ name: str,
+ encoder_config: Dict,
+ info_sharing_config: Dict,
+ pred_head_config: Dict,
+ pretrained_checkpoint_path: str = None,
+ load_specific_pretrained_submodules: bool = False,
+ specific_pretrained_submodules: list = [],
+ torch_hub_force_reload: bool = False,
+ *args,
+ **kwargs,
+ ):
+ """
+ Two-view model containing siamese encoders followed by a two-view attention transformer and respective downstream heads.
+ The goal is to output scene representation directly, both outputs in view1's frame (hence the asymmetry).
+
+ Args:
+ name (str): Name of the model.
+ encoder_config (Dict): Configuration for the encoder.
+ info_sharing_config (Dict): Configuration for the two-view attention transformer.
+ pred_head_config (Dict): Configuration for the prediction heads.
+ pretrained_checkpoint_path (str): Path to pretrained checkpoint. (default: None)
+ load_specific_pretrained_submodules (bool): Whether to load specific pretrained submodules. (default: False)
+ specific_pretrained_submodules (list): List of specific pretrained submodules to load. Must be provided when load_specific_pretrained_submodules is True. (default: [])
+ torch_hub_force_reload (bool): Whether to force reload the encoder from torch hub. (default: False)
+ """
+ super().__init__(*args, **kwargs)
+
+ # Initalize the attributes
+ self.name = name
+ self.encoder_config = encoder_config
+ self.info_sharing_config = info_sharing_config
+ self.pred_head_config = pred_head_config
+ self.pretrained_checkpoint_path = pretrained_checkpoint_path
+ self.load_specific_pretrained_submodules = load_specific_pretrained_submodules
+ self.specific_pretrained_submodules = specific_pretrained_submodules
+ self.torch_hub_force_reload = torch_hub_force_reload
+ self.class_init_args = {
+ "name": self.name,
+ "encoder_config": self.encoder_config,
+ "info_sharing_config": self.info_sharing_config,
+ "pred_head_config": self.pred_head_config,
+ "pretrained_checkpoint_path": self.pretrained_checkpoint_path,
+ "load_specific_pretrained_submodules": self.load_specific_pretrained_submodules,
+ "specific_pretrained_submodules": self.specific_pretrained_submodules,
+ "torch_hub_force_reload": self.torch_hub_force_reload,
+ }
+
+ # Get relevant parameters from the configs
+ custom_positional_encoding = info_sharing_config["custom_positional_encoding"]
+ self.info_sharing_type = info_sharing_config["model_type"]
+ self.info_sharing_return_type = info_sharing_config["model_return_type"]
+ self.pred_head_type = pred_head_config["type"]
+
+ # Initialize Encoder
+ if self.encoder_config["uses_torch_hub"]:
+ self.encoder_config["torch_hub_force_reload"] = torch_hub_force_reload
+ # Create a copy of the config before deleting the key to preserve it for serialization
+ encoder_config_copy = self.encoder_config.copy()
+ del encoder_config_copy["uses_torch_hub"]
+ self.encoder = encoder_factory(**encoder_config_copy)
+
+ # Initialize Custom Positional Encoding if required
+ if custom_positional_encoding is not None:
+ if isinstance(custom_positional_encoding, str):
+ print(
+ f"Using custom positional encoding for multi-view cross attention transformer: {custom_positional_encoding}"
+ )
+ if custom_positional_encoding.startswith("RoPE"):
+ rope_freq = float(custom_positional_encoding[len("RoPE") :])
+ print(f"RoPE frequency: {rope_freq}")
+ self.custom_positional_encoding = RoPE2D(freq=rope_freq)
+ else:
+ raise ValueError(
+ f"Invalid custom_positional_encoding: {custom_positional_encoding}."
+ )
+ elif isinstance(custom_positional_encoding, Callable):
+ print(
+ "Using callable function as custom positional encoding for multi-view cross attention transformer."
+ )
+ self.custom_positional_encoding = custom_positional_encoding
+ else:
+ self.custom_positional_encoding = None
+
+ # Add dependecies to info_sharing_config
+ info_sharing_config["module_args"]["input_embed_dim"] = (
+ self.encoder.enc_embed_dim
+ )
+ info_sharing_config["module_args"]["custom_positional_encoding"] = (
+ self.custom_positional_encoding
+ )
+
+ # Initialize Multi-View Transformer
+ if self.info_sharing_return_type == "no_intermediate_features":
+ # Returns only normalized last layer features
+ # Intialize multi-view transformer based on type
+ if self.info_sharing_type == "cross_attention":
+ self.info_sharing = MultiViewCrossAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "global_attention":
+ self.info_sharing = MultiViewGlobalAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "alternating_attention":
+ self.info_sharing = MultiViewAlternatingAttentionTransformer(
+ **info_sharing_config["module_args"]
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+ )
+ elif self.info_sharing_return_type == "intermediate_features":
+ # Returns intermediate features and normalized last layer features
+ # Initialize mulit-view transformer based on type
+ if self.info_sharing_type == "cross_attention":
+ self.info_sharing = MultiViewCrossAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "global_attention":
+ self.info_sharing = MultiViewGlobalAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ elif self.info_sharing_type == "alternating_attention":
+ self.info_sharing = MultiViewAlternatingAttentionTransformerIFR(
+ **info_sharing_config["module_args"]
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_type: {self.info_sharing_type}. Valid options: ['cross_attention', 'global_attention', 'alternating_attention']"
+ )
+ # Assess if the DPT needs to use encoder features
+ if len(self.info_sharing.indices) == 2:
+ self.use_encoder_features_for_dpt = True
+ elif len(self.info_sharing.indices) == 3:
+ self.use_encoder_features_for_dpt = False
+ else:
+ raise ValueError(
+ "Invalid number of indices provided for info sharing feature returner. Please provide 2 or 3 indices."
+ )
+ else:
+ raise ValueError(
+ f"Invalid info_sharing_return_type: {self.info_sharing_return_type}. Valid options: ['no_intermediate_features', 'intermediate_features']"
+ )
+
+ # Add dependencies to prediction head config
+ pred_head_config["feature_head"]["patch_size"] = self.encoder.patch_size
+ if self.pred_head_type == "linear":
+ pred_head_config["feature_head"]["input_feature_dim"] = (
+ self.info_sharing.dim
+ )
+ elif self.pred_head_type == "dpt":
+ if self.use_encoder_features_for_dpt:
+ pred_head_config["feature_head"]["input_feature_dims"] = [
+ self.encoder.enc_embed_dim
+ ] + [self.info_sharing.dim] * 3
+ else:
+ pred_head_config["feature_head"]["input_feature_dims"] = [
+ self.info_sharing.dim
+ ] * 4
+ pred_head_config["regressor_head"]["input_feature_dim"] = pred_head_config[
+ "feature_head"
+ ]["feature_dim"]
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt']"
+ )
+
+ # Initialize Prediction Heads
+ if self.pred_head_type == "linear":
+ # Initialize Prediction Head 1
+ self.head1 = LinearFeature(**pred_head_config["feature_head"])
+ # Initialize Prediction Head 2
+ self.head2 = LinearFeature(**pred_head_config["feature_head"])
+ elif self.pred_head_type == "dpt":
+ # Initialze Predction Head 1
+ self.dpt_feature_head1 = DPTFeature(**pred_head_config["feature_head"])
+ self.dpt_regressor_head1 = DPTRegressionProcessor(
+ **pred_head_config["regressor_head"]
+ )
+ self.head1 = nn.Sequential(self.dpt_feature_head1, self.dpt_regressor_head1)
+ # Initialize Prediction Head 2
+ self.dpt_feature_head2 = DPTFeature(**pred_head_config["feature_head"])
+ self.dpt_regressor_head2 = DPTRegressionProcessor(
+ **pred_head_config["regressor_head"]
+ )
+ self.head2 = nn.Sequential(self.dpt_feature_head2, self.dpt_regressor_head2)
+ else:
+ raise ValueError(
+ f"Invalid pred_head_type: {self.pred_head_type}. Valid options: ['linear', 'dpt']"
+ )
+
+ # Initialize Final Output Adaptor
+ if pred_head_config["adaptor_type"] == "pointmap+confidence":
+ self.adaptor = PointMapWithConfidenceAdaptor(**pred_head_config["adaptor"])
+ self.scene_rep_type = "pointmap"
+ else:
+ raise ValueError(
+ f"Invalid adaptor_type: {pred_head_config['adaptor_type']}. Valid options: ['pointmap+confidence']"
+ )
+
+ # Load pretrained weights
+ if self.pretrained_checkpoint_path is not None:
+ if not self.load_specific_pretrained_submodules:
+ print(
+ f"Loading pretrained weights from {self.pretrained_checkpoint_path} ..."
+ )
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+ print(self.load_state_dict(ckpt["model"]))
+ else:
+ print(
+ f"Loading pretrained weights from {self.pretrained_checkpoint_path} for specific submodules: {specific_pretrained_submodules} ..."
+ )
+ ckpt = torch.load(self.pretrained_checkpoint_path, weights_only=False)
+ filtered_ckpt = {}
+ for ckpt_key, ckpt_value in ckpt["model"].items():
+ for submodule in specific_pretrained_submodules:
+ if ckpt_key.startswith(submodule):
+ filtered_ckpt[ckpt_key] = ckpt_value
+ print(self.load_state_dict(filtered_ckpt, strict=False))
+
+ def _encode_image_pairs(self, img1, img2, data_norm_type):
+ "Encode two different batches of images (each batch can have different image shape)"
+ if img1.shape[-2:] == img2.shape[-2:]:
+ encoder_input = ViTEncoderInput(
+ image=torch.cat((img1, img2), dim=0), data_norm_type=data_norm_type
+ )
+ encoder_output = self.encoder(encoder_input)
+ out, out2 = encoder_output.features.chunk(2, dim=0)
+ else:
+ encoder_input = ViTEncoderInput(image=img1, data_norm_type=data_norm_type)
+ out = self.encoder(encoder_input)
+ out = out.features
+ encoder_input2 = ViTEncoderInput(image=img2, data_norm_type=data_norm_type)
+ out2 = self.encoder(encoder_input2)
+ out2 = out2.features
+
+ return out, out2
+
+ def _encode_symmetrized(self, view1, view2):
+ "Encode image pairs accounting for symmetrization, i.e., (a, b) and (b, a) always exist in the input"
+ img1 = view1["img"]
+ img2 = view2["img"]
+ if isinstance(view1["data_norm_type"], list):
+ assert all(
+ [x == view1["data_norm_type"][0] for x in view1["data_norm_type"]]
+ ), "All data_norm_type values should be the same in the list."
+ data_norm_type = view1["data_norm_type"][0]
+ elif isinstance(view1["data_norm_type"], str):
+ data_norm_type = view1["data_norm_type"]
+ else:
+ raise ValueError(
+ f"Invalid data_norm_type: {view1['data_norm_type']}. Should be either a list with all same values or a string."
+ )
+ feat1, feat2 = self._encode_image_pairs(
+ img1, img2, data_norm_type=data_norm_type
+ )
+
+ return feat1, feat2
+
+ def _downstream_head(self, head_num, decout, img_shape):
+ "Run the respective prediction heads"
+ head = getattr(self, f"head{head_num}")
+ if self.pred_head_type == "linear":
+ head_input = PredictionHeadInput(last_feature=decout[f"{head_num}"])
+ elif self.pred_head_type == "dpt":
+ head_input = PredictionHeadLayeredInput(
+ list_features=decout[f"{head_num}"], target_output_shape=img_shape
+ )
+
+ return head(head_input)
+
+ def forward(self, views):
+ """
+ Forward pass performing the following operations:
+ 1. Encodes the two input views (images).
+ 2. Combines the encoded features using a two-view attention transformer.
+ 3. Passes the combined features through the respective prediction heads.
+ 4. Returns the processed final outputs for both views.
+
+ Args:
+ views (List(dict)): A list of size two whose elements are:
+ view1 (dict): Dictionary containing the first view's images and instance information.
+ "img" is a required key and value is a tensor of shape (B, C, H, W).
+ view2 (dict): Dictionary containing the second view's images and instance information.
+ "img" is a required key and value is a tensor of shape (B, C, H, W).
+
+ Returns:
+ List[dict, dict]: A list containing the final outputs for both views.
+ """
+ # Get input shapes
+ view1 = views[0]
+ view2 = views[1]
+ _, _, height1, width1 = view1["img"].shape
+ _, _, height2, width2 = view2["img"].shape
+ shape1 = (int(height1), int(width1))
+ shape2 = (int(height2), int(width2))
+
+ if "img_encoder_feats" in view1 and "img_encoder_feats" in view2:
+ # Reuse the pre-computed image features for the two views
+ feat1 = view1["img_encoder_feats"]
+ feat2 = view2["img_encoder_feats"]
+ else:
+ # Encode the two images --> Each feat output: BCHW features (batch_size, feature_dim, feature_height, feature_width)
+ feat1, feat2 = self._encode_symmetrized(view1, view2)
+
+ # Combine all images into view-centric representation
+ info_sharing_input = MultiViewTransformerInput(features=[feat1, feat2])
+ if self.info_sharing_return_type == "no_intermediate_features":
+ final_info_sharing_multi_view_feat = self.info_sharing(info_sharing_input)
+ elif self.info_sharing_return_type == "intermediate_features":
+ (
+ final_info_sharing_multi_view_feat,
+ intermediate_info_sharing_multi_view_feat,
+ ) = self.info_sharing(info_sharing_input)
+
+ if self.pred_head_type == "linear":
+ # Define feature dictionary for linear head
+ info_sharing_outputs = {
+ "1": final_info_sharing_multi_view_feat.features[0].float(),
+ "2": final_info_sharing_multi_view_feat.features[1].float(),
+ }
+ elif self.pred_head_type == "dpt":
+ # Define feature dictionary for DPT head
+ if self.use_encoder_features_for_dpt:
+ info_sharing_outputs = {
+ "1": [
+ feat1.float(),
+ intermediate_info_sharing_multi_view_feat[0]
+ .features[0]
+ .float(),
+ intermediate_info_sharing_multi_view_feat[1]
+ .features[0]
+ .float(),
+ final_info_sharing_multi_view_feat.features[0].float(),
+ ],
+ "2": [
+ feat2.float(),
+ intermediate_info_sharing_multi_view_feat[0]
+ .features[1]
+ .float(),
+ intermediate_info_sharing_multi_view_feat[1]
+ .features[1]
+ .float(),
+ final_info_sharing_multi_view_feat.features[1].float(),
+ ],
+ }
+ else:
+ info_sharing_outputs = {
+ "1": [
+ intermediate_info_sharing_multi_view_feat[0]
+ .features[0]
+ .float(),
+ intermediate_info_sharing_multi_view_feat[1]
+ .features[0]
+ .float(),
+ intermediate_info_sharing_multi_view_feat[2]
+ .features[0]
+ .float(),
+ final_info_sharing_multi_view_feat.features[0].float(),
+ ],
+ "2": [
+ intermediate_info_sharing_multi_view_feat[0]
+ .features[1]
+ .float(),
+ intermediate_info_sharing_multi_view_feat[1]
+ .features[1]
+ .float(),
+ intermediate_info_sharing_multi_view_feat[2]
+ .features[1]
+ .float(),
+ final_info_sharing_multi_view_feat.features[1].float(),
+ ],
+ }
+
+ # Downstream task prediction
+ with torch.autocast("cuda", enabled=False):
+ # Prediction heads
+ head_output1 = self._downstream_head(1, info_sharing_outputs, shape1)
+ head_output2 = self._downstream_head(2, info_sharing_outputs, shape2)
+
+ # Post-process outputs
+ final_output1 = self.adaptor(
+ AdaptorInput(
+ adaptor_feature=head_output1.decoded_channels,
+ output_shape_hw=shape1,
+ )
+ )
+ final_output2 = self.adaptor(
+ AdaptorInput(
+ adaptor_feature=head_output2.decoded_channels,
+ output_shape_hw=shape2,
+ )
+ )
+
+ # Reshape final scene representation to (B, H, W, C)
+ final_scene_rep1 = final_output1.value.permute(0, 2, 3, 1).contiguous()
+ final_scene_rep2 = final_output2.value.permute(0, 2, 3, 1).contiguous()
+
+ # Convert output scene representation to pointmaps
+ if self.scene_rep_type == "pointmap":
+ output_pts3d1 = final_scene_rep1
+ output_pts3d2 = final_scene_rep2
+ else:
+ raise ValueError(f"Invalid scene_rep_type: {self.scene_rep_type}.")
+
+ # Reshape confidence to (B, H, W, 1)
+ output_conf1 = (
+ final_output1.confidence.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ )
+ output_conf2 = (
+ final_output2.confidence.permute(0, 2, 3, 1).squeeze(-1).contiguous()
+ )
+
+ # Convert outputs to dictionary
+ res1 = {
+ "pts3d": output_pts3d1,
+ "conf": output_conf1,
+ }
+ res2 = {
+ "pts3d": output_pts3d2,
+ "conf": output_conf2,
+ }
+ res = [res1, res2]
+
+ return res
diff --git a/mapanything/train/__init__.py b/mapanything/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/train/losses.py b/mapanything/train/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..097b5684fe006969287fdb89d8b6f020a50f1fab
--- /dev/null
+++ b/mapanything/train/losses.py
@@ -0,0 +1,5060 @@
+"""
+Multi-view geometric losses for training 3D reconstruction models.
+
+References: DUSt3R & MASt3R
+"""
+
+import math
+from copy import copy, deepcopy
+
+import einops as ein
+import torch
+import torch.nn as nn
+
+from mapanything.utils.geometry import (
+ angle_diff_vec3,
+ apply_log_to_norm,
+ closed_form_pose_inverse,
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap,
+ geotrf,
+ normalize_multiple_pointclouds,
+ quaternion_inverse,
+ quaternion_multiply,
+ quaternion_to_rotation_matrix,
+ transform_pose_using_quats_and_trans_2_to_1,
+)
+
+
+def get_loss_terms_and_details(
+ losses_dict, valid_masks, self_name, n_views, flatten_across_image_only
+):
+ """
+ Helper function to generate loss terms and details for different loss types.
+
+ Args:
+ losses_dict (dict): Dictionary mapping loss types to their values.
+ Format: {
+ 'loss_type': {
+ 'values': list_of_loss_tensors or single_tensor,
+ 'use_mask': bool,
+ 'is_multi_view': bool
+ }
+ }
+ valid_masks (list): List of valid masks for each view.
+ self_name (str): Name of the loss class.
+ n_views (int): Number of views.
+ flatten_across_image_only (bool): Whether flattening was done across image only.
+
+ Returns:
+ tuple: (loss_terms, details) where loss_terms is a list of tuples (loss, mask, type)
+ and details is a dictionary of loss details.
+ """
+ loss_terms = []
+ details = {}
+
+ for loss_type, loss_info in losses_dict.items():
+ values = loss_info["values"]
+ use_mask = loss_info["use_mask"]
+ is_multi_view = loss_info["is_multi_view"]
+ if is_multi_view:
+ # Handle multi-view losses (list of tensors)
+ view_loss_details = []
+ for i in range(n_views):
+ mask = valid_masks[i] if use_mask else None
+ loss_terms.append((values[i], mask, loss_type))
+
+ # Add details for individual view
+ if not flatten_across_image_only or not use_mask:
+ values_after_masking = values[i]
+ else:
+ values_after_masking = values[i][mask]
+
+ if values_after_masking.numel() > 0:
+ view_loss_detail = float(values_after_masking.mean())
+ if view_loss_detail > 0:
+ details[f"{self_name}_{loss_type}_view{i + 1}"] = (
+ view_loss_detail
+ )
+ view_loss_details.append(view_loss_detail)
+ # Add average across views
+ if len(view_loss_details) > 0:
+ details[f"{self_name}_{loss_type}_avg"] = sum(view_loss_details) / len(
+ view_loss_details
+ )
+ else:
+ # Handle single tensor losses
+ if values is not None:
+ loss_terms.append((values, None, loss_type))
+ if values.numel() > 0:
+ loss_detail = float(values.mean())
+ if loss_detail > 0:
+ details[f"{self_name}_{loss_type}"] = loss_detail
+
+ return loss_terms, details
+
+
+def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor:
+ if beta == 0:
+ return err
+ else:
+ return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)
+
+
+def compute_normal_loss(points, gt_points, mask):
+ """
+ Compute the normal loss between the predicted and ground truth points.
+ References:
+ https://github.com/microsoft/MoGe/blob/a8c37341bc0325ca99b9d57981cc3bb2bd3e255b/moge/train/losses.py#L205
+
+ Args:
+ points (torch.Tensor): Predicted points. Shape: (..., H, W, 3).
+ gt_points (torch.Tensor): Ground truth points. Shape: (..., H, W, 3).
+ mask (torch.Tensor): Mask indicating valid points. Shape: (..., H, W).
+
+ Returns:
+ torch.Tensor: Normal loss.
+ """
+ height, width = points.shape[-3:-1]
+
+ leftup, rightup, leftdown, rightdown = (
+ points[..., :-1, :-1, :],
+ points[..., :-1, 1:, :],
+ points[..., 1:, :-1, :],
+ points[..., 1:, 1:, :],
+ )
+ upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1)
+ leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1)
+ downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1)
+ rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1)
+
+ gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = (
+ gt_points[..., :-1, :-1, :],
+ gt_points[..., :-1, 1:, :],
+ gt_points[..., 1:, :-1, :],
+ gt_points[..., 1:, 1:, :],
+ )
+ gt_upxleft = torch.cross(
+ gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1
+ )
+ gt_leftxdown = torch.cross(
+ gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1
+ )
+ gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1)
+ gt_rightxup = torch.cross(
+ gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1
+ )
+
+ mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = (
+ mask[..., :-1, :-1],
+ mask[..., :-1, 1:],
+ mask[..., 1:, :-1],
+ mask[..., 1:, 1:],
+ )
+ mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown
+ mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup
+ mask_downxright = mask_leftdown & mask_rightup & mask_leftup
+ mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown
+
+ MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3)
+
+ loss = (
+ mask_upxleft
+ * _smooth(
+ angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE),
+ beta=BETA_RAD,
+ )
+ + mask_leftxdown
+ * _smooth(
+ angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE),
+ beta=BETA_RAD,
+ )
+ + mask_downxright
+ * _smooth(
+ angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE),
+ beta=BETA_RAD,
+ )
+ + mask_rightxup
+ * _smooth(
+ angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE),
+ beta=BETA_RAD,
+ )
+ )
+
+ total_valid_mask = mask_upxleft | mask_leftxdown | mask_downxright | mask_rightxup
+ valid_count = total_valid_mask.sum()
+ if valid_count > 0:
+ loss = loss.sum() / (valid_count * (4 * max(points.shape[-3:-1])))
+ else:
+ loss = 0 * loss.sum()
+
+ return loss
+
+
+def compute_gradient_loss(prediction, gt_target, mask):
+ """
+ Compute the gradient loss between the prediction and GT target at valid points.
+ References:
+ https://docs.nerf.studio/_modules/nerfstudio/model_components/losses.html#GradientLoss
+ https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py
+
+ Args:
+ prediction (torch.Tensor): Predicted scene representation. Shape: (B, H, W, C).
+ gt_target (torch.Tensor): Ground truth scene representation. Shape: (B, H, W, C).
+ mask (torch.Tensor): Mask indicating valid points. Shape: (B, H, W).
+ """
+ # Expand mask to match number of channels in prediction
+ mask = mask[..., None].expand(-1, -1, -1, prediction.shape[-1])
+ summed_mask = torch.sum(mask, (1, 2, 3))
+
+ # Compute the gradient of the prediction and GT target
+ diff = prediction - gt_target
+ diff = torch.mul(mask, diff)
+
+ # Gradient in x direction
+ grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1])
+ mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1])
+ grad_x = torch.mul(mask_x, grad_x)
+
+ # Gradient in y direction
+ grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :])
+ mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :])
+ grad_y = torch.mul(mask_y, grad_y)
+
+ # Clamp the outlier gradients
+ grad_x = grad_x.clamp(max=100)
+ grad_y = grad_y.clamp(max=100)
+
+ # Compute the total loss
+ image_loss = torch.sum(grad_x, (1, 2, 3)) + torch.sum(grad_y, (1, 2, 3))
+ num_valid_pixels = torch.sum(summed_mask)
+ if num_valid_pixels > 0:
+ image_loss = torch.sum(image_loss) / num_valid_pixels
+ else:
+ image_loss = 0 * torch.sum(image_loss)
+
+ return image_loss
+
+
+def compute_gradient_matching_loss(prediction, gt_target, mask, scales=4):
+ """
+ Compute the multi-scale gradient matching loss between the prediction and GT target at valid points.
+ This loss biases discontinuities to be sharp and to coincide with discontinuities in the ground truth.
+ More info in MiDAS: https://arxiv.org/pdf/1907.01341.pdf; Equation 11
+ References:
+ https://docs.nerf.studio/_modules/nerfstudio/model_components/losses.html#GradientLoss
+ https://github.com/autonomousvision/monosdf/blob/main/code/model/loss.py
+
+ Args:
+ prediction (torch.Tensor): Predicted scene representation. Shape: (B, H, W, C).
+ gt_target (torch.Tensor): Ground truth scene representation. Shape: (B, H, W, C).
+ mask (torch.Tensor): Mask indicating valid points. Shape: (B, H, W).
+ scales (int): Number of scales to compute the loss at. Default: 4.
+ """
+ # Define total loss
+ total_loss = 0.0
+
+ # Compute the gradient loss at different scales
+ for scale in range(scales):
+ step = pow(2, scale)
+ grad_loss = compute_gradient_loss(
+ prediction[:, ::step, ::step],
+ gt_target[:, ::step, ::step],
+ mask[:, ::step, ::step],
+ )
+ total_loss += grad_loss
+
+ return total_loss
+
+
+def Sum(*losses_and_masks):
+ """
+ Aggregates multiple losses into a single loss value or returns the original losses.
+
+ Args:
+ *losses_and_masks: Variable number of tuples, each containing (loss, mask, rep_type)
+ - loss: Tensor containing loss values
+ - mask: Mask indicating valid pixels/regions
+ - rep_type: String indicating the type of representation (e.g., 'pts3d', 'depth')
+
+ Returns:
+ If the first loss has dimensions > 0:
+ Returns the original list of (loss, mask, rep_type) tuples
+ Otherwise:
+ Returns a scalar tensor that is the sum of all loss values
+ """
+ loss, mask, rep_type = losses_and_masks[0]
+ if loss.ndim > 0:
+ # we are actually returning the loss for every pixels
+ return losses_and_masks
+ else:
+ # we are returning the global loss
+ for loss2, mask2, rep_type2 in losses_and_masks[1:]:
+ loss = loss + loss2
+ return loss
+
+
+class BaseCriterion(nn.Module):
+ "Base Criterion to support different reduction methods"
+
+ def __init__(self, reduction="mean"):
+ super().__init__()
+ self.reduction = reduction
+
+
+class LLoss(BaseCriterion):
+ "L-norm loss"
+
+ def forward(self, a, b, **kwargs):
+ assert a.shape == b.shape and a.ndim >= 2 and 1 <= a.shape[-1] <= 4, (
+ f"Bad shape = {a.shape}"
+ )
+ dist = self.distance(a, b, **kwargs)
+ assert dist.ndim == a.ndim - 1 # one dimension less
+ if self.reduction == "none":
+ return dist
+ if self.reduction == "sum":
+ return dist.sum()
+ if self.reduction == "mean":
+ return dist.mean() if dist.numel() > 0 else dist.new_zeros(())
+ raise ValueError(f"bad {self.reduction=} mode")
+
+ def distance(self, a, b, **kwargs):
+ raise NotImplementedError()
+
+
+class L1Loss(LLoss):
+ "L1 distance"
+
+ def distance(self, a, b, **kwargs):
+ return torch.abs(a - b).sum(dim=-1)
+
+
+class L2Loss(LLoss):
+ "Euclidean (L2 Norm) distance"
+
+ def distance(self, a, b, **kwargs):
+ return torch.norm(a - b, dim=-1)
+
+
+class GenericLLoss(LLoss):
+ "Criterion that supports different L-norms"
+
+ def distance(self, a, b, loss_type, **kwargs):
+ if loss_type == "l1":
+ # L1 distance
+ return torch.abs(a - b).sum(dim=-1)
+ elif loss_type == "l2":
+ # Euclidean (L2 norm) distance
+ return torch.norm(a - b, dim=-1)
+ else:
+ raise ValueError(
+ f"Unsupported loss type: {loss_type}. Supported types are 'l1' and 'l2'."
+ )
+
+
+class FactoredLLoss(LLoss):
+ "Criterion that supports different L-norms for the factored loss functions"
+
+ def __init__(
+ self,
+ reduction="mean",
+ points_loss_type="l2",
+ depth_loss_type="l1",
+ ray_directions_loss_type="l1",
+ pose_quats_loss_type="l1",
+ pose_trans_loss_type="l1",
+ scale_loss_type="l1",
+ ):
+ super().__init__(reduction)
+ self.points_loss_type = points_loss_type
+ self.depth_loss_type = depth_loss_type
+ self.ray_directions_loss_type = ray_directions_loss_type
+ self.pose_quats_loss_type = pose_quats_loss_type
+ self.pose_trans_loss_type = pose_trans_loss_type
+ self.scale_loss_type = scale_loss_type
+
+ def _distance(self, a, b, loss_type):
+ if loss_type == "l1":
+ # L1 distance
+ return torch.abs(a - b).sum(dim=-1)
+ elif loss_type == "l2":
+ # Euclidean (L2 norm) distance
+ return torch.norm(a - b, dim=-1)
+ else:
+ raise ValueError(f"Unsupported loss type: {loss_type}.")
+
+ def distance(self, a, b, factor, **kwargs):
+ if factor == "points":
+ return self._distance(a, b, self.points_loss_type)
+ elif factor == "depth":
+ return self._distance(a, b, self.depth_loss_type)
+ elif factor == "ray_directions":
+ return self._distance(a, b, self.ray_directions_loss_type)
+ elif factor == "pose_quats":
+ return self._distance(a, b, self.pose_quats_loss_type)
+ elif factor == "pose_trans":
+ return self._distance(a, b, self.pose_trans_loss_type)
+ elif factor == "scale":
+ return self._distance(a, b, self.scale_loss_type)
+ else:
+ raise ValueError(f"Unsupported factor type: {factor}.")
+
+
+class RobustRegressionLoss(LLoss):
+ """
+ Generalized Robust Loss introduced in https://arxiv.org/abs/1701.03077.
+ """
+
+ def __init__(self, alpha=0.5, scaling_c=0.25, reduction="mean"):
+ """
+ Initialize the Robust Regression Loss.
+
+ Args:
+ alpha (float): Shape parameter controlling the robustness of the loss.
+ Lower values make the loss more robust to outliers. Default: 0.5.
+ scaling_c (float): Scale parameter controlling the transition between
+ quadratic and robust behavior. Default: 0.1.
+ reduction (str): Specifies the reduction to apply to the output:
+ 'none' | 'mean' | 'sum'. Default: 'mean'.
+ """
+ super().__init__(reduction)
+ self.alpha = alpha
+ self.scaling_c = scaling_c
+
+ def distance(self, a, b, **kwargs):
+ error_scaled = torch.sum(((a - b) / self.scaling_c) ** 2, dim=-1)
+ robust_loss = (abs(self.alpha - 2) / self.alpha) * (
+ torch.pow((error_scaled / abs(self.alpha - 2)) + 1, self.alpha / 2) - 1
+ )
+ return robust_loss
+
+
+class BCELoss(BaseCriterion):
+ """Binary Cross Entropy loss"""
+
+ def forward(self, predicted_logits, reference_mask):
+ """
+ Args:
+ predicted_logits: (B, H, W) tensor of predicted logits for the mask
+ reference_mask: (B, H, W) tensor of reference mask
+
+ Returns:
+ loss: scalar tensor of the BCE loss
+ """
+ bce_loss = torch.nn.functional.binary_cross_entropy_with_logits(
+ predicted_logits, reference_mask.float()
+ )
+
+ return bce_loss
+
+
+class Criterion(nn.Module):
+ """
+ Base class for all criterion modules that wrap a BaseCriterion.
+
+ This class serves as a wrapper around BaseCriterion objects, providing
+ additional functionality like naming and reduction mode control.
+
+ Args:
+ criterion (BaseCriterion): The base criterion to wrap.
+ """
+
+ def __init__(self, criterion=None):
+ super().__init__()
+ assert isinstance(criterion, BaseCriterion), (
+ f"{criterion} is not a proper criterion!"
+ )
+ self.criterion = copy(criterion)
+
+ def get_name(self):
+ """
+ Returns a string representation of this criterion.
+
+ Returns:
+ str: A string containing the class name and the wrapped criterion.
+ """
+ return f"{type(self).__name__}({self.criterion})"
+
+ def with_reduction(self, mode="none"):
+ """
+ Creates a deep copy of this criterion with the specified reduction mode.
+
+ This method recursively sets the reduction mode for this criterion and
+ any chained MultiLoss criteria.
+
+ Args:
+ mode (str): The reduction mode to set. Default: "none".
+
+ Returns:
+ Criterion: A new criterion with the specified reduction mode.
+ """
+ res = loss = deepcopy(self)
+ while loss is not None:
+ assert isinstance(loss, Criterion)
+ loss.criterion.reduction = mode # make it return the loss for each sample
+ loss = loss._loss2 # we assume loss is a Multiloss
+ return res
+
+
+class MultiLoss(nn.Module):
+ """
+ Base class for combinable loss functions with automatic tracking of individual loss values.
+
+ This class enables easy combination of multiple loss functions through arithmetic operations:
+ loss = MyLoss1() + 0.1*MyLoss2()
+
+ The combined loss functions maintain their individual weights and the forward pass
+ automatically computes and aggregates all losses while tracking individual loss values.
+
+ Usage:
+ Inherit from this class and override get_name() and compute_loss() methods.
+
+ Attributes:
+ _alpha (float): Weight multiplier for this loss component.
+ _loss2 (MultiLoss): Reference to the next loss in the chain, if any.
+ """
+
+ def __init__(self):
+ """Initialize the MultiLoss with default weight of 1 and no chained loss."""
+ super().__init__()
+ self._alpha = 1
+ self._loss2 = None
+
+ def compute_loss(self, *args, **kwargs):
+ """
+ Compute the loss value for this specific loss component.
+
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+ torch.Tensor or tuple: Either the loss tensor or a tuple of (loss, details_dict).
+
+ Raises:
+ NotImplementedError: This method must be implemented by subclasses.
+ """
+ raise NotImplementedError()
+
+ def get_name(self):
+ """
+ Get the name of this loss component.
+
+ Returns:
+ str: The name of the loss.
+
+ Raises:
+ NotImplementedError: This method must be implemented by subclasses.
+ """
+ raise NotImplementedError()
+
+ def __mul__(self, alpha):
+ """
+ Multiply the loss by a scalar weight.
+
+ Args:
+ alpha (int or float): The weight to multiply the loss by.
+
+ Returns:
+ MultiLoss: A new loss object with the updated weight.
+
+ Raises:
+ AssertionError: If alpha is not a number.
+ """
+ assert isinstance(alpha, (int, float))
+ res = copy(self)
+ res._alpha = alpha
+ return res
+
+ __rmul__ = __mul__ # Support both loss*alpha and alpha*loss
+
+ def __add__(self, loss2):
+ """
+ Add another loss to this loss, creating a chain of losses.
+
+ Args:
+ loss2 (MultiLoss): Another loss to add to this one.
+
+ Returns:
+ MultiLoss: A new loss object representing the combined losses.
+
+ Raises:
+ AssertionError: If loss2 is not a MultiLoss.
+ """
+ assert isinstance(loss2, MultiLoss)
+ res = cur = copy(self)
+ # Find the end of the chain
+ while cur._loss2 is not None:
+ cur = cur._loss2
+ cur._loss2 = loss2
+ return res
+
+ def __repr__(self):
+ """
+ Create a string representation of the loss, including weights and chained losses.
+
+ Returns:
+ str: String representation of the loss.
+ """
+ name = self.get_name()
+ if self._alpha != 1:
+ name = f"{self._alpha:g}*{name}"
+ if self._loss2:
+ name = f"{name} + {self._loss2}"
+ return name
+
+ def forward(self, *args, **kwargs):
+ """
+ Compute the weighted loss and aggregate with any chained losses.
+
+ Args:
+ *args: Variable length argument list.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+ tuple: A tuple containing:
+ - torch.Tensor: The total weighted loss.
+ - dict: Details about individual loss components.
+ """
+ loss = self.compute_loss(*args, **kwargs)
+ if isinstance(loss, tuple):
+ loss, details = loss
+ elif loss.ndim == 0:
+ details = {self.get_name(): float(loss)}
+ else:
+ details = {}
+ loss = loss * self._alpha
+
+ if self._loss2:
+ loss2, details2 = self._loss2(*args, **kwargs)
+ loss = loss + loss2
+ details |= details2
+
+ return loss, details
+
+
+class NonAmbiguousMaskLoss(Criterion, MultiLoss):
+ """
+ Loss on non-ambiguous mask prediction logits.
+ """
+
+ def __init__(self, criterion):
+ super().__init__(criterion)
+
+ def compute_loss(self, batch, preds, **kw):
+ """
+ Args:
+ batch: list of dicts with the gt data
+ preds: list of dicts with the predictions
+
+ Returns:
+ loss: Sum class of the lossses for N-views and the loss details
+ """
+ # Init loss list to keep track of individual losses for each view
+ loss_list = []
+ mask_loss_details = {}
+ mask_loss_total = 0
+ self_name = type(self).__name__
+
+ # Loop over the views
+ for view_idx, (gt, pred) in enumerate(zip(batch, preds)):
+ # Get the GT non-ambiguous masks
+ gt_non_ambiguous_mask = gt["non_ambiguous_mask"]
+
+ # Get the predicted non-ambiguous mask logits
+ pred_non_ambiguous_mask_logits = pred["non_ambiguous_mask_logits"]
+
+ # Compute the loss for the current view
+ loss = self.criterion(pred_non_ambiguous_mask_logits, gt_non_ambiguous_mask)
+
+ # Add the loss to the list
+ loss_list.append((loss, None, "non_ambiguous_mask"))
+
+ # Add the loss details to the dictionary
+ mask_loss_details[f"{self_name}_mask_view{view_idx + 1}"] = float(loss)
+ mask_loss_total += float(loss)
+
+ # Compute the average loss across all views
+ mask_loss_details[f"{self_name}_mask_avg"] = mask_loss_total / len(batch)
+
+ return Sum(*loss_list), (mask_loss_details | {})
+
+
+class ConfLoss(MultiLoss):
+ """
+ Applies confidence-weighted regression loss using model-predicted confidence values.
+
+ The confidence-weighted loss has the form:
+ conf_loss = raw_loss * conf - alpha * log(conf)
+
+ Where:
+ - raw_loss is the original per-pixel loss
+ - conf is the predicted confidence (higher values = higher confidence)
+ - alpha is a hyperparameter controlling the regularization strength
+
+ This loss can be selectively applied to specific loss components in factored and multi-view settings.
+ """
+
+ def __init__(self, pixel_loss, alpha=1, loss_set_indices=None):
+ """
+ Args:
+ pixel_loss (MultiLoss): The pixel-level regression loss to be used.
+ alpha (float): Hyperparameter controlling the confidence regularization strength.
+ loss_set_indices (list or None): Indices of the loss sets to apply confidence weighting to.
+ Each index selects a specific loss set across all views (with the same rep_type).
+ If None, defaults to [0] which applies to the first loss set only.
+ """
+ super().__init__()
+ assert alpha > 0
+ self.alpha = alpha
+ self.pixel_loss = pixel_loss.with_reduction("none")
+ self.loss_set_indices = [0] if loss_set_indices is None else loss_set_indices
+
+ def get_name(self):
+ return f"ConfLoss({self.pixel_loss})"
+
+ def get_conf_log(self, x):
+ return x, torch.log(x)
+
+ def compute_loss(self, batch, preds, **kw):
+ # Init loss list and details
+ total_loss = 0
+ conf_loss_details = {}
+ running_avg_dict = {}
+ self_name = type(self.pixel_loss).__name__
+ n_views = len(batch)
+
+ # Compute per-pixel loss for each view
+ losses, pixel_loss_details = self.pixel_loss(batch, preds, **kw)
+
+ # Select specific loss sets based on indices
+ selected_losses = []
+ processed_indices = set()
+ for idx in self.loss_set_indices:
+ start_idx = idx * n_views
+ end_idx = min((idx + 1) * n_views, len(losses))
+ selected_losses.extend(losses[start_idx:end_idx])
+ processed_indices.update(range(start_idx, end_idx))
+
+ # Process selected losses with confidence weighting
+ for loss_idx, (loss, msk, rep_type) in enumerate(selected_losses):
+ view_idx = loss_idx % n_views # Map to corresponding view index
+
+ if loss.numel() == 0:
+ # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
+ continue
+
+ # Get the confidence and log confidence
+ if (
+ hasattr(self.pixel_loss, "flatten_across_image_only")
+ and self.pixel_loss.flatten_across_image_only
+ ):
+ # Reshape confidence to match the flattened dimensions
+ conf_reshaped = preds[view_idx]["conf"].view(
+ preds[view_idx]["conf"].shape[0], -1
+ )
+ conf, log_conf = self.get_conf_log(conf_reshaped[msk])
+ loss = loss[msk]
+ else:
+ conf, log_conf = self.get_conf_log(preds[view_idx]["conf"][msk])
+
+ # Weight the loss by the confidence
+ conf_loss = loss * conf - self.alpha * log_conf
+
+ # Only add to total loss and store details if there are valid elements
+ if conf_loss.numel() > 0:
+ conf_loss = conf_loss.mean()
+ total_loss = total_loss + conf_loss
+
+ # Store details
+ conf_loss_details[
+ f"{self_name}_{rep_type}_conf_loss_view{view_idx + 1}"
+ ] = float(conf_loss)
+
+ # Initialize or update running average directly
+ avg_key = f"{self_name}_{rep_type}_conf_loss_avg"
+ if avg_key not in conf_loss_details:
+ conf_loss_details[avg_key] = float(conf_loss)
+ running_avg_dict[
+ f"{self_name}_{rep_type}_conf_loss_valid_views"
+ ] = 1
+ else:
+ valid_views = (
+ running_avg_dict[
+ f"{self_name}_{rep_type}_conf_loss_valid_views"
+ ]
+ + 1
+ )
+ running_avg_dict[
+ f"{self_name}_{rep_type}_conf_loss_valid_views"
+ ] = valid_views
+ conf_loss_details[avg_key] += (
+ float(conf_loss) - conf_loss_details[avg_key]
+ ) / valid_views
+
+ # Add unmodified losses for sets not in selected_losses
+ for idx, (loss, msk, rep_type) in enumerate(losses):
+ if idx not in processed_indices:
+ if msk is not None:
+ loss_after_masking = loss[msk]
+ else:
+ loss_after_masking = loss
+ if loss_after_masking.numel() > 0:
+ loss_mean = loss_after_masking.mean()
+ else:
+ # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
+ loss_mean = 0
+ total_loss = total_loss + loss_mean
+
+ return total_loss, dict(**conf_loss_details, **pixel_loss_details)
+
+
+class ExcludeTopNPercentPixelLoss(MultiLoss):
+ """
+ Pixel-level regression loss where for each instance in a batch the top N% of per-pixel loss values are ignored
+ for the mean loss computation.
+ Allows selecting which pixel-level regression loss sets to apply the exclusion to.
+ """
+
+ def __init__(
+ self,
+ pixel_loss,
+ top_n_percent=5,
+ apply_to_real_data_only=True,
+ loss_set_indices=None,
+ ):
+ """
+ Args:
+ pixel_loss (MultiLoss): The pixel-level regression loss to be used.
+ top_n_percent (float): The percentage of top per-pixel loss values to ignore. Range: [0, 100]. Default: 5.
+ apply_to_real_data_only (bool): Whether to apply the loss only to real world data. Default: True.
+ loss_set_indices (list or None): Indices of the loss sets to apply the exclusion to.
+ Each index selects a specific loss set across all views (with the same rep_type).
+ If None, defaults to [0] which applies to the first loss set only.
+ """
+ super().__init__()
+ self.pixel_loss = pixel_loss.with_reduction("none")
+ self.top_n_percent = top_n_percent
+ self.bottom_n_percent = 100 - top_n_percent
+ self.apply_to_real_data_only = apply_to_real_data_only
+ self.loss_set_indices = [0] if loss_set_indices is None else loss_set_indices
+
+ def get_name(self):
+ return f"ExcludeTopNPercentPixelLoss({self.pixel_loss})"
+
+ def keep_bottom_n_percent(self, tensor, mask, bottom_n_percent):
+ """
+ Function to compute the mask for keeping the bottom n percent of per-pixel loss values.
+
+ Args:
+ tensor (torch.Tensor): The tensor containing the per-pixel loss values.
+ Shape: (B, N) where B is the batch size and N is the number of total pixels.
+ mask (torch.Tensor): The mask indicating valid pixels. Shape: (B, N).
+
+ Returns:
+ torch.Tensor: Flattened tensor containing the bottom n percent of per-pixel loss values.
+ """
+ B, N = tensor.shape
+
+ # Calculate the number of valid elements (where mask is True)
+ num_valid = mask.sum(dim=1)
+
+ # Calculate the number of elements to keep (n% of valid elements)
+ num_keep = (num_valid * bottom_n_percent / 100).long()
+
+ # Create a mask for the bottom n% elements
+ keep_mask = torch.arange(N, device=tensor.device).unsqueeze(
+ 0
+ ) < num_keep.unsqueeze(1)
+
+ # Create a tensor with inf where mask is False
+ masked_tensor = torch.where(
+ mask, tensor, torch.tensor(float("inf"), device=tensor.device)
+ )
+
+ # Sort the masked tensor along the N dimension
+ sorted_tensor, _ = torch.sort(masked_tensor, dim=1, descending=False)
+
+ # Get the bottom n% elements
+ bottom_n_percent_elements = sorted_tensor[keep_mask]
+
+ return bottom_n_percent_elements
+
+ def compute_loss(self, batch, preds, **kw):
+ # Compute per-pixel loss
+ losses, details = self.pixel_loss(batch, preds, **kw)
+ n_views = len(batch)
+
+ # Select specific loss sets based on indices
+ selected_losses = []
+ processed_indices = set()
+ for idx in self.loss_set_indices:
+ start_idx = idx * n_views
+ end_idx = min((idx + 1) * n_views, len(losses))
+ selected_losses.extend(losses[start_idx:end_idx])
+ processed_indices.update(range(start_idx, end_idx))
+
+ # Initialize total loss
+ total_loss = 0.0
+ loss_details = {}
+ running_avg_dict = {}
+ self_name = type(self.pixel_loss).__name__
+
+ # Process selected losses with top N percent exclusion
+ for loss_idx, (loss, msk, rep_type) in enumerate(selected_losses):
+ view_idx = loss_idx % n_views # Map to corresponding view index
+
+ if loss.numel() == 0:
+ # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
+ continue
+
+ # Create empty list for current view's aggregated tensors
+ aggregated_losses = []
+
+ if self.apply_to_real_data_only:
+ # Get the synthetic and real world data mask
+ synthetic_mask = batch[view_idx]["is_synthetic"]
+ real_data_mask = ~batch[view_idx]["is_synthetic"]
+ else:
+ # Apply the filtering to all data
+ synthetic_mask = torch.zeros_like(batch[view_idx]["is_synthetic"])
+ real_data_mask = torch.ones_like(batch[view_idx]["is_synthetic"])
+
+ # Process synthetic data
+ if synthetic_mask.any():
+ synthetic_loss = loss[synthetic_mask]
+ synthetic_msk = msk[synthetic_mask]
+ aggregated_losses.append(synthetic_loss[synthetic_msk])
+
+ # Process real data
+ if real_data_mask.any():
+ real_loss = loss[real_data_mask]
+ real_msk = msk[real_data_mask]
+ real_bottom_n_percent_loss = self.keep_bottom_n_percent(
+ real_loss, real_msk, self.bottom_n_percent
+ )
+ aggregated_losses.append(real_bottom_n_percent_loss)
+
+ # Compute view loss
+ view_loss = torch.cat(aggregated_losses, dim=0)
+
+ # Only add to total loss and store details if there are valid elements
+ if view_loss.numel() > 0:
+ view_loss = view_loss.mean()
+ total_loss = total_loss + view_loss
+
+ # Store details
+ loss_details[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_view{view_idx + 1}"
+ ] = float(view_loss)
+
+ # Initialize or update running average directly
+ avg_key = f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_avg"
+ if avg_key not in loss_details:
+ loss_details[avg_key] = float(view_loss)
+ running_avg_dict[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
+ ] = 1
+ else:
+ valid_views = (
+ running_avg_dict[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
+ ]
+ + 1
+ )
+ running_avg_dict[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
+ ] = valid_views
+ loss_details[avg_key] += (
+ float(view_loss) - loss_details[avg_key]
+ ) / valid_views
+
+ # Add unmodified losses for sets not in selected_losses
+ for idx, (loss, msk, rep_type) in enumerate(losses):
+ if idx not in processed_indices:
+ if msk is not None:
+ loss_after_masking = loss[msk]
+ else:
+ loss_after_masking = loss
+ if loss_after_masking.numel() > 0:
+ loss_mean = loss_after_masking.mean()
+ else:
+ # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
+ loss_mean = 0
+ total_loss = total_loss + loss_mean
+
+ return total_loss, dict(**loss_details, **details)
+
+
+class ConfAndExcludeTopNPercentPixelLoss(MultiLoss):
+ """
+ Combined loss that applies ConfLoss to one set of pixel-level regression losses
+ and ExcludeTopNPercentPixelLoss to another set of pixel-level regression losses.
+ """
+
+ def __init__(
+ self,
+ pixel_loss,
+ conf_alpha=1,
+ top_n_percent=5,
+ apply_to_real_data_only=True,
+ conf_loss_set_indices=None,
+ exclude_loss_set_indices=None,
+ ):
+ """
+ Args:
+ pixel_loss (MultiLoss): The pixel-level regression loss to be used.
+ conf_alpha (float): Alpha parameter for ConfLoss. Default: 1.
+ top_n_percent (float): Percentage of top per-pixel loss values to ignore. Range: [0, 100]. Default: 5.
+ apply_to_real_data_only (bool): Whether to apply the exclude loss only to real world data. Default: True.
+ conf_loss_set_indices (list or None): Indices of the loss sets to apply confidence weighting to.
+ Each index selects a specific loss set across all views (with the same rep_type).
+ If None, defaults to [0] which applies to the first loss set only.
+ exclude_loss_set_indices (list or None): Indices of the loss sets to apply top N percent exclusion to.
+ Each index selects a specific loss set across all views (with the same rep_type).
+ If None, defaults to [1] which applies to the second loss set only.
+ """
+ super().__init__()
+ self.pixel_loss = pixel_loss.with_reduction("none")
+ assert conf_alpha > 0
+ self.conf_alpha = conf_alpha
+ self.top_n_percent = top_n_percent
+ self.bottom_n_percent = 100 - top_n_percent
+ self.apply_to_real_data_only = apply_to_real_data_only
+ self.conf_loss_set_indices = (
+ [0] if conf_loss_set_indices is None else conf_loss_set_indices
+ )
+ self.exclude_loss_set_indices = (
+ [1] if exclude_loss_set_indices is None else exclude_loss_set_indices
+ )
+
+ def get_name(self):
+ return f"ConfAndExcludeTopNPercentPixelLoss({self.pixel_loss})"
+
+ def get_conf_log(self, x):
+ return x, torch.log(x)
+
+ def keep_bottom_n_percent(self, tensor, mask, bottom_n_percent):
+ """
+ Function to compute the mask for keeping the bottom n percent of per-pixel loss values.
+ """
+ B, N = tensor.shape
+
+ # Calculate the number of valid elements (where mask is True)
+ num_valid = mask.sum(dim=1)
+
+ # Calculate the number of elements to keep (n% of valid elements)
+ num_keep = (num_valid * bottom_n_percent / 100).long()
+
+ # Create a mask for the bottom n% elements
+ keep_mask = torch.arange(N, device=tensor.device).unsqueeze(
+ 0
+ ) < num_keep.unsqueeze(1)
+
+ # Create a tensor with inf where mask is False
+ masked_tensor = torch.where(
+ mask, tensor, torch.tensor(float("inf"), device=tensor.device)
+ )
+
+ # Sort the masked tensor along the N dimension
+ sorted_tensor, _ = torch.sort(masked_tensor, dim=1, descending=False)
+
+ # Get the bottom n% elements
+ bottom_n_percent_elements = sorted_tensor[keep_mask]
+
+ return bottom_n_percent_elements
+
+ def compute_loss(self, batch, preds, **kw):
+ # Compute per-pixel loss
+ losses, pixel_loss_details = self.pixel_loss(batch, preds, **kw)
+ n_views = len(batch)
+
+ # Select specific loss sets for confidence weighting
+ conf_selected_losses = []
+ conf_processed_indices = set()
+ for idx in self.conf_loss_set_indices:
+ start_idx = idx * n_views
+ end_idx = min((idx + 1) * n_views, len(losses))
+ conf_selected_losses.extend(losses[start_idx:end_idx])
+ conf_processed_indices.update(range(start_idx, end_idx))
+
+ # Select specific loss sets for top N percent exclusion
+ exclude_selected_losses = []
+ exclude_processed_indices = set()
+ for idx in self.exclude_loss_set_indices:
+ start_idx = idx * n_views
+ end_idx = min((idx + 1) * n_views, len(losses))
+ exclude_selected_losses.extend(losses[start_idx:end_idx])
+ exclude_processed_indices.update(range(start_idx, end_idx))
+
+ # Initialize total loss and details
+ total_loss = 0
+ loss_details = {}
+ running_avg_dict = {}
+ self_name = type(self.pixel_loss).__name__
+
+ # Process selected losses with confidence weighting
+ for loss_idx, (loss, msk, rep_type) in enumerate(conf_selected_losses):
+ view_idx = loss_idx % n_views # Map to corresponding view index
+
+ if loss.numel() == 0:
+ # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views}) for conf loss", force=True)
+ continue
+
+ # Get the confidence and log confidence
+ if (
+ hasattr(self.pixel_loss, "flatten_across_image_only")
+ and self.pixel_loss.flatten_across_image_only
+ ):
+ # Reshape confidence to match the flattened dimensions
+ conf_reshaped = preds[view_idx]["conf"].view(
+ preds[view_idx]["conf"].shape[0], -1
+ )
+ conf, log_conf = self.get_conf_log(conf_reshaped[msk])
+ loss = loss[msk]
+ else:
+ conf, log_conf = self.get_conf_log(preds[view_idx]["conf"][msk])
+
+ # Weight the loss by the confidence
+ conf_loss = loss * conf - self.conf_alpha * log_conf
+
+ # Only add to total loss and store details if there are valid elements
+ if conf_loss.numel() > 0:
+ conf_loss = conf_loss.mean()
+ total_loss = total_loss + conf_loss
+
+ # Store details
+ loss_details[f"{self_name}_{rep_type}_conf_loss_view{view_idx + 1}"] = (
+ float(conf_loss)
+ )
+
+ # Initialize or update running average directly
+ avg_key = f"{self_name}_{rep_type}_conf_loss_avg"
+ if avg_key not in loss_details:
+ loss_details[avg_key] = float(conf_loss)
+ running_avg_dict[
+ f"{self_name}_{rep_type}_conf_loss_valid_views"
+ ] = 1
+ else:
+ valid_views = (
+ running_avg_dict[
+ f"{self_name}_{rep_type}_conf_loss_valid_views"
+ ]
+ + 1
+ )
+ running_avg_dict[
+ f"{self_name}_{rep_type}_conf_loss_valid_views"
+ ] = valid_views
+ loss_details[avg_key] += (
+ float(conf_loss) - loss_details[avg_key]
+ ) / valid_views
+
+ # Process selected losses with top N percent exclusion
+ for loss_idx, (loss, msk, rep_type) in enumerate(exclude_selected_losses):
+ view_idx = loss_idx % n_views # Map to corresponding view index
+
+ if loss.numel() == 0:
+ # print(f"NO VALID VALUES in loss idx {loss_idx} (Rep Type: {rep_type}, Num Views: {n_views}) for exclude loss", force=True)
+ continue
+
+ # Create empty list for current view's aggregated tensors
+ aggregated_losses = []
+
+ if self.apply_to_real_data_only:
+ # Get the synthetic and real world data mask
+ synthetic_mask = batch[view_idx]["is_synthetic"]
+ real_data_mask = ~batch[view_idx]["is_synthetic"]
+ else:
+ # Apply the filtering to all data
+ synthetic_mask = torch.zeros_like(batch[view_idx]["is_synthetic"])
+ real_data_mask = torch.ones_like(batch[view_idx]["is_synthetic"])
+
+ # Process synthetic data
+ if synthetic_mask.any():
+ synthetic_loss = loss[synthetic_mask]
+ synthetic_msk = msk[synthetic_mask]
+ aggregated_losses.append(synthetic_loss[synthetic_msk])
+
+ # Process real data
+ if real_data_mask.any():
+ real_loss = loss[real_data_mask]
+ real_msk = msk[real_data_mask]
+ real_bottom_n_percent_loss = self.keep_bottom_n_percent(
+ real_loss, real_msk, self.bottom_n_percent
+ )
+ aggregated_losses.append(real_bottom_n_percent_loss)
+
+ # Compute view loss
+ view_loss = torch.cat(aggregated_losses, dim=0)
+
+ # Only add to total loss and store details if there are valid elements
+ if view_loss.numel() > 0:
+ view_loss = view_loss.mean()
+ total_loss = total_loss + view_loss
+
+ # Store details
+ loss_details[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_view{view_idx + 1}"
+ ] = float(view_loss)
+
+ # Initialize or update running average directly
+ avg_key = f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_loss_avg"
+ if avg_key not in loss_details:
+ loss_details[avg_key] = float(view_loss)
+ running_avg_dict[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
+ ] = 1
+ else:
+ valid_views = (
+ running_avg_dict[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
+ ]
+ + 1
+ )
+ running_avg_dict[
+ f"{self_name}_{rep_type}_bot{self.bottom_n_percent}%_valid_views"
+ ] = valid_views
+ loss_details[avg_key] += (
+ float(view_loss) - loss_details[avg_key]
+ ) / valid_views
+
+ # Add unmodified losses for sets not processed with either confidence or exclusion
+ all_processed_indices = conf_processed_indices.union(exclude_processed_indices)
+ for idx, (loss, msk, rep_type) in enumerate(losses):
+ if idx not in all_processed_indices:
+ if msk is not None:
+ loss_after_masking = loss[msk]
+ else:
+ loss_after_masking = loss
+ if loss_after_masking.numel() > 0:
+ loss_mean = loss_after_masking.mean()
+ else:
+ # print(f"NO VALID VALUES in loss idx {idx} (Rep Type: {rep_type}, Num Views: {n_views})", force=True)
+ loss_mean = 0
+ total_loss = total_loss + loss_mean
+
+ return total_loss, dict(**loss_details, **pixel_loss_details)
+
+
+class Regr3D(Criterion, MultiLoss):
+ """
+ Regression Loss for World Frame Pointmaps.
+ Asymmetric loss where view 1 is supposed to be the anchor.
+
+ For each view i:
+ Pi = RTi @ Di
+ lossi = (RTi1 @ pred_Di) - (RT1^-1 @ RTi @ Di)
+ where RT1 is the anchor view camera pose
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_mode="?avg_dis",
+ gt_scale=False,
+ ambiguous_loss_value=0,
+ max_metric_scale=False,
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ ):
+ """
+ Initialize the loss criterion for World Frame Pointmaps.
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis".
+ If prefixed with "?", normalization is only applied to non-metric scale data.
+ gt_scale (bool): If True, enforce predictions to have the same scale as ground truth.
+ If False, both GT and predictions are normalized independently. Default: False.
+ ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+ If 0, ambiguous pixels are ignored. Default: 0.
+ max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this
+ value, it will be treated as non-metric. Default: False (no limit).
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for pointmaps. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ """
+ super().__init__(criterion)
+ if norm_mode.startswith("?"):
+ # Do no norm pts from metric scale datasets
+ self.norm_all = False
+ self.norm_mode = norm_mode[1:]
+ else:
+ self.norm_all = True
+ self.norm_mode = norm_mode
+ self.gt_scale = gt_scale
+ self.ambiguous_loss_value = ambiguous_loss_value
+ self.max_metric_scale = max_metric_scale
+ self.loss_in_log = loss_in_log
+ self.flatten_across_image_only = flatten_across_image_only
+
+ def get_all_info(self, batch, preds, dist_clip=None):
+ n_views = len(batch)
+ in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
+
+ # Initialize lists to store points and masks
+ no_norm_gt_pts = []
+ valid_masks = []
+
+ # Process ground truth points and valid masks
+ for view_idx in range(n_views):
+ no_norm_gt_pts.append(
+ geotrf(in_camera0, batch[view_idx]["pts3d"])
+ ) # B,H,W,3
+ valid_masks.append(batch[view_idx]["valid_mask"].clone())
+
+ if dist_clip is not None:
+ # Points that are too far-away == invalid
+ for view_idx in range(n_views):
+ dis = no_norm_gt_pts[view_idx].norm(dim=-1) # (B, H, W)
+ valid_masks[view_idx] = valid_masks[view_idx] & (dis <= dist_clip)
+
+ # Get predicted points
+ no_norm_pr_pts = []
+ for view_idx in range(n_views):
+ no_norm_pr_pts.append(preds[view_idx]["pts3d"])
+
+ if not self.norm_all:
+ if self.max_metric_scale:
+ B = valid_masks[0].shape[0]
+ # Calculate distances to camera for all views
+ dists_to_cam1 = []
+ for view_idx in range(n_views):
+ dist = torch.where(
+ valid_masks[view_idx],
+ torch.norm(no_norm_gt_pts[view_idx], dim=-1),
+ 0,
+ ).view(B, -1)
+ dists_to_cam1.append(dist)
+
+ # Update metric scale flags
+ metric_scale_mask = batch[0]["is_metric_scale"]
+ for dist in dists_to_cam1:
+ metric_scale_mask = metric_scale_mask & (
+ dist.max(dim=-1).values < self.max_metric_scale
+ )
+
+ for view_idx in range(n_views):
+ batch[view_idx]["is_metric_scale"] = metric_scale_mask
+
+ non_metric_scale_mask = ~batch[0]["is_metric_scale"]
+ else:
+ non_metric_scale_mask = torch.ones_like(batch[0]["is_metric_scale"])
+
+ # Initialize normalized points
+ gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
+ pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
+
+ # Normalize 3d points
+ if self.norm_mode and non_metric_scale_mask.any():
+ normalized_pr_pts = normalize_multiple_pointclouds(
+ [pts[non_metric_scale_mask] for pts in no_norm_pr_pts],
+ [mask[non_metric_scale_mask] for mask in valid_masks],
+ self.norm_mode,
+ )
+ for i in range(n_views):
+ pr_pts[i][non_metric_scale_mask] = normalized_pr_pts[i]
+ elif non_metric_scale_mask.any():
+ for i in range(n_views):
+ pr_pts[i][non_metric_scale_mask] = no_norm_pr_pts[i][
+ non_metric_scale_mask
+ ]
+
+ if self.norm_mode and not self.gt_scale:
+ gt_normalization_output = normalize_multiple_pointclouds(
+ no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
+ )
+ normalized_gt_pts = gt_normalization_output[:-1]
+ norm_factor = gt_normalization_output[-1]
+ for i in range(n_views):
+ gt_pts[i] = normalized_gt_pts[i]
+ pr_pts[i][~non_metric_scale_mask] = (
+ no_norm_pr_pts[i][~non_metric_scale_mask]
+ / norm_factor[~non_metric_scale_mask]
+ )
+ elif ~non_metric_scale_mask.any():
+ for i in range(n_views):
+ gt_pts[i] = no_norm_gt_pts[i]
+ pr_pts[i][~non_metric_scale_mask] = no_norm_pr_pts[i][
+ ~non_metric_scale_mask
+ ]
+ else:
+ for i in range(n_views):
+ gt_pts[i] = no_norm_gt_pts[i]
+
+ # Get ambiguous masks
+ ambiguous_masks = []
+ for view_idx in range(n_views):
+ ambiguous_masks.append(
+ (~batch[view_idx]["non_ambiguous_mask"]) & (~valid_masks[view_idx])
+ )
+
+ return gt_pts, pr_pts, valid_masks, ambiguous_masks, {}
+
+ def compute_loss(self, batch, preds, **kw):
+ gt_pts, pred_pts, masks, ambiguous_masks, monitoring = self.get_all_info(
+ batch, preds, **kw
+ )
+ n_views = len(batch)
+
+ if self.ambiguous_loss_value > 0:
+ assert self.criterion.reduction == "none", (
+ "ambiguous_loss_value should be 0 if no conf loss"
+ )
+ # Add the ambiguous pixels as "valid" pixels
+ masks = [mask | amb_mask for mask, amb_mask in zip(masks, ambiguous_masks)]
+
+ losses = []
+ details = {}
+ running_avg_dict = {}
+ self_name = type(self).__name__
+
+ if not self.flatten_across_image_only:
+ for view_idx in range(n_views):
+ pred = pred_pts[view_idx][masks[view_idx]]
+ gt = gt_pts[view_idx][masks[view_idx]]
+
+ if self.loss_in_log:
+ pred = apply_log_to_norm(pred)
+ gt = apply_log_to_norm(gt)
+
+ loss = self.criterion(pred, gt)
+
+ if self.ambiguous_loss_value > 0:
+ loss = torch.where(
+ ambiguous_masks[view_idx][masks[view_idx]],
+ self.ambiguous_loss_value,
+ loss,
+ )
+
+ losses.append((loss, masks[view_idx], "pts3d"))
+ if loss.numel() > 0:
+ loss_mean = float(loss.mean())
+ details[f"{self_name}_pts3d_view{view_idx + 1}"] = loss_mean
+ # Initialize or update running average directly
+ avg_key = f"{self_name}_pts3d_avg"
+ if avg_key not in details:
+ details[avg_key] = loss_mean
+ running_avg_dict[f"{self_name}_pts3d_valid_views"] = 1
+ else:
+ valid_views = (
+ running_avg_dict[f"{self_name}_pts3d_valid_views"] + 1
+ )
+ running_avg_dict[f"{self_name}_pts3d_valid_views"] = valid_views
+ details[avg_key] += (loss_mean - details[avg_key]) / valid_views
+ else:
+ batch_size, _, _, dim = gt_pts[0].shape
+
+ for view_idx in range(n_views):
+ gt = gt_pts[view_idx].view(batch_size, -1, dim)
+ pred = pred_pts[view_idx].view(batch_size, -1, dim)
+ view_mask = masks[view_idx].view(batch_size, -1)
+ amb_mask = ambiguous_masks[view_idx].view(batch_size, -1)
+
+ if self.loss_in_log:
+ pred = apply_log_to_norm(pred)
+ gt = apply_log_to_norm(gt)
+
+ loss = self.criterion(pred, gt)
+
+ if self.ambiguous_loss_value > 0:
+ loss = torch.where(amb_mask, self.ambiguous_loss_value, loss)
+
+ losses.append((loss, view_mask, "pts3d"))
+ loss_after_masking = loss[view_mask]
+ if loss_after_masking.numel() > 0:
+ loss_mean = float(loss_after_masking.mean())
+ details[f"{self_name}_pts3d_view{view_idx + 1}"] = loss_mean
+ # Initialize or update running average directly
+ avg_key = f"{self_name}_pts3d_avg"
+ if avg_key not in details:
+ details[avg_key] = loss_mean
+ running_avg_dict[f"{self_name}_pts3d_valid_views"] = 1
+ else:
+ valid_views = (
+ running_avg_dict[f"{self_name}_pts3d_valid_views"] + 1
+ )
+ running_avg_dict[f"{self_name}_pts3d_valid_views"] = valid_views
+ details[avg_key] += (loss_mean - details[avg_key]) / valid_views
+
+ return Sum(*losses), (details | monitoring)
+
+
+class PointsPlusScaleRegr3D(Criterion, MultiLoss):
+ """
+ Regression Loss for World Frame Pointmaps & Scale.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_predictions=True,
+ norm_mode="avg_dis",
+ ambiguous_loss_value=0,
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ world_frame_points_loss_weight=1,
+ scale_loss_weight=1,
+ ):
+ """
+ Initialize the loss criterion for World Frame Pointmaps & Scale.
+ The predicited scene representation is always normalized w.r.t. the frame of view0.
+ Loss is applied between the predicted metric scale and the ground truth metric scale.
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_predictions (bool): If True, normalize the predictions before computing the loss.
+ norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
+ ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+ If 0, ambiguous pixels are ignored. Default: 0.
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth, pointmaps and scale. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
+ scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
+ """
+ super().__init__(criterion)
+ self.norm_predictions = norm_predictions
+ self.norm_mode = norm_mode
+ self.ambiguous_loss_value = ambiguous_loss_value
+ self.loss_in_log = loss_in_log
+ self.flatten_across_image_only = flatten_across_image_only
+ self.world_frame_points_loss_weight = world_frame_points_loss_weight
+ self.scale_loss_weight = scale_loss_weight
+
+ def get_all_info(self, batch, preds, dist_clip=None):
+ """
+ Function to get all the information needed to compute the loss.
+ Returns all quantities normalized w.r.t. camera of view0.
+ """
+ n_views = len(batch)
+
+ # Everything is normalized w.r.t. camera of view0
+ # Intialize lists to store data for all views
+ # Ground truth quantities
+ in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
+ no_norm_gt_pts = []
+ valid_masks = []
+ # Predicted quantities
+ no_norm_pr_pts = []
+ metric_pr_pts_to_compute_scale = []
+
+ # Get ground truth & prediction info for all views
+ for i in range(n_views):
+ # Get the ground truth
+ no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"]))
+ valid_masks.append(batch[i]["valid_mask"].clone())
+
+ # Get predictions for normalized loss
+ if "metric_scaling_factor" in preds[i].keys():
+ # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans
+ # This detaches the predicted metric scaling factor from the geometry based loss
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ else:
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"]
+ no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
+
+ # Get the predicted metric scale points
+ if "metric_scaling_factor" in preds[i].keys():
+ # Detach the raw predicted points so that the scale loss is only applied to the scaling factor
+ curr_view_metric_pr_pts_to_compute_scale = (
+ curr_view_no_norm_pr_pts.detach()
+ * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1)
+ )
+ else:
+ curr_view_metric_pr_pts_to_compute_scale = (
+ curr_view_no_norm_pr_pts.clone()
+ )
+ metric_pr_pts_to_compute_scale.append(
+ curr_view_metric_pr_pts_to_compute_scale
+ )
+
+ if dist_clip is not None:
+ # Points that are too far-away == invalid
+ for i in range(n_views):
+ dis = no_norm_gt_pts[i].norm(dim=-1)
+ valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
+
+ # Initialize normalized tensors
+ gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
+ pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
+
+ # Normalize the predicted points if specified
+ if self.norm_predictions:
+ pr_normalization_output = normalize_multiple_pointclouds(
+ no_norm_pr_pts,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_pts_norm = pr_normalization_output[:-1]
+
+ # Normalize the ground truth points
+ gt_normalization_output = normalize_multiple_pointclouds(
+ no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
+ )
+ gt_pts_norm = gt_normalization_output[:-1]
+ gt_norm_factor = gt_normalization_output[-1]
+
+ for i in range(n_views):
+ if self.norm_predictions:
+ # Assign the normalized predictions
+ pr_pts[i] = pr_pts_norm[i]
+ else:
+ pr_pts[i] = no_norm_pr_pts[i]
+ # Assign the normalized ground truth quantities
+ gt_pts[i] = gt_pts_norm[i]
+
+ # Get the mask indicating ground truth metric scale quantities
+ metric_scale_mask = batch[0]["is_metric_scale"]
+ valid_gt_norm_factor_mask = (
+ gt_norm_factor[:, 0, 0, 0] > 1e-8
+ ) # Mask out cases where depth for all views is invalid
+ valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask
+
+ if valid_metric_scale_mask.any():
+ # Compute the scale norm factor using the predicted metric scale points
+ metric_pr_normalization_output = normalize_multiple_pointclouds(
+ metric_pr_pts_to_compute_scale,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_metric_norm_factor = metric_pr_normalization_output[-1]
+
+ # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities
+ gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask]
+ pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask]
+ else:
+ gt_metric_norm_factor = None
+ pr_metric_norm_factor = None
+
+ # Get ambiguous masks
+ ambiguous_masks = []
+ for i in range(n_views):
+ ambiguous_masks.append(
+ (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i])
+ )
+
+ # Pack into info dicts
+ gt_info = []
+ pred_info = []
+ for i in range(n_views):
+ gt_info.append(
+ {
+ "pts3d": gt_pts[i],
+ }
+ )
+ pred_info.append(
+ {
+ "pts3d": pr_pts[i],
+ }
+ )
+
+ return (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ )
+
+ def compute_loss(self, batch, preds, **kw):
+ (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ ) = self.get_all_info(batch, preds, **kw)
+ n_views = len(batch)
+
+ if self.ambiguous_loss_value > 0:
+ assert self.criterion.reduction == "none", (
+ "ambiguous_loss_value should be 0 if no conf loss"
+ )
+ # Add the ambiguous pixel as "valid" pixels...
+ valid_masks = [
+ mask | ambig_mask
+ for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
+ ]
+
+ pts3d_losses = []
+
+ for i in range(n_views):
+ # Get the predicted dense quantities
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks
+ pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
+ gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W
+ batch_size, _, _, pts_dim = gt_info[i]["pts3d"].shape
+ gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space if specified
+ if self.loss_in_log:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_pts3d = apply_log_to_norm(pred_pts3d)
+
+ # Compute point loss
+ pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
+ pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
+ pts3d_losses.append(pts3d_loss)
+
+ # Handle ambiguous pixels
+ if self.ambiguous_loss_value > 0:
+ if not self.flatten_across_image_only:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+ else:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+
+ # Compute the scale loss
+ if gt_metric_norm_factor is not None:
+ if self.loss_in_log:
+ gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
+ pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
+ scale_loss = (
+ self.criterion(
+ pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
+ )
+ * self.scale_loss_weight
+ )
+ else:
+ scale_loss = None
+
+ # Use helper function to generate loss terms and details
+
+ losses_dict = {
+ "pts3d": {
+ "values": pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "scale": {
+ "values": scale_loss,
+ "use_mask": False,
+ "is_multi_view": False,
+ },
+ }
+
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
+
+
+class NormalGMLoss(MultiLoss):
+ """
+ Normal & Gradient Matching Loss for Monocular Depth Training.
+ """
+
+ def __init__(
+ self,
+ norm_predictions=True,
+ norm_mode="avg_dis",
+ apply_normal_and_gm_loss_to_synthetic_data_only=True,
+ ):
+ """
+ Initialize the loss criterion for Normal & Gradient Matching Loss (currently only valid for 1 view).
+ Computes:
+ (1) Normal Loss over the PointMap (naturally will be in local frame) in euclidean coordinates,
+ (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space)
+
+ Args:
+ norm_predictions (bool): If True, normalize the predictions before computing the loss.
+ norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
+ apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
+ If False, apply the normal and gm loss to all data. Default: True.
+ """
+ super().__init__()
+ self.norm_predictions = norm_predictions
+ self.norm_mode = norm_mode
+ self.apply_normal_and_gm_loss_to_synthetic_data_only = (
+ apply_normal_and_gm_loss_to_synthetic_data_only
+ )
+
+ def get_all_info(self, batch, preds, dist_clip=None):
+ """
+ Function to get all the information needed to compute the loss.
+ Returns all quantities normalized.
+ """
+ n_views = len(batch)
+ assert n_views == 1, (
+ "Normal & Gradient Matching Loss Class only supports 1 view"
+ )
+
+ # Everything is normalized w.r.t. camera of view1
+ in_camera1 = closed_form_pose_inverse(batch[0]["camera_pose"])
+
+ # Initialize lists to store data for all views
+ no_norm_gt_pts = []
+ valid_masks = []
+ no_norm_pr_pts = []
+
+ # Get ground truth & prediction info for all views
+ for i in range(n_views):
+ # Get ground truth
+ no_norm_gt_pts.append(geotrf(in_camera1, batch[i]["pts3d"]))
+ valid_masks.append(batch[i]["valid_mask"].clone())
+
+ # Get predictions for normalized loss
+ if "metric_scaling_factor" in preds[i].keys():
+ # Divide by the predicted metric scaling factor to get the raw predicted points
+ # This detaches the predicted metric scaling factor from the geometry based loss
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ else:
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"]
+ no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
+
+ if dist_clip is not None:
+ # Points that are too far-away == invalid
+ for i in range(n_views):
+ dis = no_norm_gt_pts[i].norm(dim=-1)
+ valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
+
+ # Initialize normalized tensors
+ gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
+ pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
+
+ # Normalize the predicted points if specified
+ if self.norm_predictions:
+ pr_normalization_output = normalize_multiple_pointclouds(
+ no_norm_pr_pts,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_pts_norm = pr_normalization_output[:-1]
+
+ # Normalize the ground truth points
+ gt_normalization_output = normalize_multiple_pointclouds(
+ no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
+ )
+ gt_pts_norm = gt_normalization_output[:-1]
+
+ for i in range(n_views):
+ if self.norm_predictions:
+ # Assign the normalized predictions
+ pr_pts[i] = pr_pts_norm[i]
+ else:
+ # Assign the raw predicted points
+ pr_pts[i] = no_norm_pr_pts[i]
+ # Assign the normalized ground truth
+ gt_pts[i] = gt_pts_norm[i]
+
+ return gt_pts, pr_pts, valid_masks
+
+ def compute_loss(self, batch, preds, **kw):
+ gt_pts, pred_pts, valid_masks = self.get_all_info(batch, preds, **kw)
+ n_views = len(batch)
+ assert n_views == 1, (
+ "Normal & Gradient Matching Loss Class only supports 1 view"
+ )
+
+ normal_losses = []
+ gradient_matching_losses = []
+ details = {}
+ running_avg_dict = {}
+ self_name = type(self).__name__
+
+ for i in range(n_views):
+ # Get the local frame points, log space depth_z & valid masks
+ pred_local_pts3d = pred_pts[i]
+ pred_depth_z = pred_local_pts3d[..., 2:]
+ pred_depth_z = apply_log_to_norm(pred_depth_z)
+ gt_local_pts3d = gt_pts[i]
+ gt_depth_z = gt_local_pts3d[..., 2:]
+ gt_depth_z = apply_log_to_norm(gt_depth_z)
+ valid_mask_for_normal_gm_loss = valid_masks[i].clone()
+
+ # Update the validity mask for normal & gm loss based on the synthetic data mask if required
+ if self.apply_normal_and_gm_loss_to_synthetic_data_only:
+ synthetic_mask = batch[i]["is_synthetic"] # (B, )
+ synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
+ synthetic_mask = synthetic_mask.expand(
+ -1, pred_depth_z.shape[1], pred_depth_z.shape[2]
+ ) # (B, H, W)
+ valid_mask_for_normal_gm_loss = (
+ valid_mask_for_normal_gm_loss & synthetic_mask
+ )
+
+ # Compute the normal loss
+ normal_loss = compute_normal_loss(
+ pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
+ )
+ normal_losses.append(normal_loss)
+
+ # Compute the gradient matching loss
+ gradient_matching_loss = compute_gradient_matching_loss(
+ pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
+ )
+ gradient_matching_losses.append(gradient_matching_loss)
+
+ # Add loss details if only valid values are present
+ # Initialize or update running average directly
+ # Normal loss details
+ if float(normal_loss) > 0:
+ details[f"{self_name}_normal_view{i + 1}"] = float(normal_loss)
+ normal_avg_key = f"{self_name}_normal_avg"
+ if normal_avg_key not in details:
+ details[normal_avg_key] = float(normal_losses[i])
+ running_avg_dict[f"{self_name}_normal_valid_views"] = 1
+ else:
+ normal_valid_views = (
+ running_avg_dict[f"{self_name}_normal_valid_views"] + 1
+ )
+ running_avg_dict[f"{self_name}_normal_valid_views"] = (
+ normal_valid_views
+ )
+ details[normal_avg_key] += (
+ float(normal_losses[i]) - details[normal_avg_key]
+ ) / normal_valid_views
+
+ # Gradient Matching loss details
+ if float(gradient_matching_loss) > 0:
+ details[f"{self_name}_gradient_matching_view{i + 1}"] = float(
+ gradient_matching_loss
+ )
+ # For gradient matching loss
+ gm_avg_key = f"{self_name}_gradient_matching_avg"
+ if gm_avg_key not in details:
+ details[gm_avg_key] = float(gradient_matching_losses[i])
+ running_avg_dict[f"{self_name}_gm_valid_views"] = 1
+ else:
+ gm_valid_views = running_avg_dict[f"{self_name}_gm_valid_views"] + 1
+ running_avg_dict[f"{self_name}_gm_valid_views"] = gm_valid_views
+ details[gm_avg_key] += (
+ float(gradient_matching_losses[i]) - details[gm_avg_key]
+ ) / gm_valid_views
+
+ # Put the losses together
+ loss_terms = []
+ for i in range(n_views):
+ loss_terms.append((normal_losses[i], None, "normal"))
+ loss_terms.append((gradient_matching_losses[i], None, "gradient_matching"))
+ losses = Sum(*loss_terms)
+
+ return losses, details
+
+
+class FactoredGeometryRegr3D(Criterion, MultiLoss):
+ """
+ Regression Loss for Factored Geometry.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_mode="?avg_dis",
+ gt_scale=False,
+ ambiguous_loss_value=0,
+ max_metric_scale=False,
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ depth_type_for_loss="depth_along_ray",
+ cam_frame_points_loss_weight=1,
+ depth_loss_weight=1,
+ ray_directions_loss_weight=1,
+ pose_quats_loss_weight=1,
+ pose_trans_loss_weight=1,
+ compute_pairwise_relative_pose_loss=False,
+ convert_predictions_to_view0_frame=False,
+ compute_world_frame_points_loss=True,
+ world_frame_points_loss_weight=1,
+ ):
+ """
+ Initialize the loss criterion for Factored Geometry (Ray Directions, Depth, Pose),
+ and the Collective Geometry i.e. Local Frame Pointmaps & optionally World Frame Pointmaps.
+ If world-frame pointmap loss is computed, the pixel-level losses are computed in the following order:
+ (1) world points, (2) cam points, (3) depth, (4) ray directions, (5) pose quats, (6) pose trans.
+ Else, the pixel-level losses are returned in the following order:
+ (1) cam points, (2) depth, (3) ray directions, (4) pose quats, (5) pose trans.
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_mode (str): Normalization mode for scene representation. Default: "?avg_dis".
+ If prefixed with "?", normalization is only applied to non-metric scale data.
+ gt_scale (bool): If True, enforce predictions to have the same scale as ground truth.
+ If False, both GT and predictions are normalized independently. Default: False.
+ ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+ If 0, ambiguous pixels are ignored. Default: 0.
+ max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this
+ value, it will be treated as non-metric. Default: False (no limit).
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth and pointmaps. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+ Options: "depth_along_ray", "depth_z"
+ cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
+ depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+ ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+ pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+ pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+ compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
+ exhaustive pairwise relative poses. Default: False.
+ convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
+ Use this if the predictions are not already in the view0 frame. Default: False.
+ compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
+ world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
+ """
+ super().__init__(criterion)
+ if norm_mode.startswith("?"):
+ # Do no norm pts from metric scale datasets
+ self.norm_all = False
+ self.norm_mode = norm_mode[1:]
+ else:
+ self.norm_all = True
+ self.norm_mode = norm_mode
+ self.gt_scale = gt_scale
+ self.ambiguous_loss_value = ambiguous_loss_value
+ self.max_metric_scale = max_metric_scale
+ self.loss_in_log = loss_in_log
+ self.flatten_across_image_only = flatten_across_image_only
+ self.depth_type_for_loss = depth_type_for_loss
+ assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], (
+ "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']"
+ )
+ self.cam_frame_points_loss_weight = cam_frame_points_loss_weight
+ self.depth_loss_weight = depth_loss_weight
+ self.ray_directions_loss_weight = ray_directions_loss_weight
+ self.pose_quats_loss_weight = pose_quats_loss_weight
+ self.pose_trans_loss_weight = pose_trans_loss_weight
+ self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
+ self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame
+ self.compute_world_frame_points_loss = compute_world_frame_points_loss
+ self.world_frame_points_loss_weight = world_frame_points_loss_weight
+
+ def get_all_info(self, batch, preds, dist_clip=None):
+ """
+ Function to get all the information needed to compute the loss.
+ Returns all quantities normalized w.r.t. camera of view0.
+ """
+ n_views = len(batch)
+
+ # Everything is normalized w.r.t. camera of view0
+ # Intialize lists to store data for all views
+ # Ground truth quantities
+ in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
+ no_norm_gt_pts = []
+ no_norm_gt_pts_cam = []
+ no_norm_gt_depth = []
+ no_norm_gt_pose_trans = []
+ valid_masks = []
+ gt_ray_directions = []
+ gt_pose_quats = []
+ # Predicted quantities
+ if self.convert_predictions_to_view0_frame:
+ # Get the camera transform to convert quantities to view0 frame
+ pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze(
+ 0
+ )
+ batch_size = preds[0]["cam_quats"].shape[0]
+ pred_camera0 = pred_camera0.repeat(batch_size, 1, 1)
+ pred_camera0_rot = quaternion_to_rotation_matrix(
+ preds[0]["cam_quats"].clone()
+ )
+ pred_camera0[..., :3, :3] = pred_camera0_rot
+ pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone()
+ pred_in_camera0 = closed_form_pose_inverse(pred_camera0)
+ no_norm_pr_pts = []
+ no_norm_pr_pts_cam = []
+ no_norm_pr_depth = []
+ no_norm_pr_pose_trans = []
+ pr_ray_directions = []
+ pr_pose_quats = []
+
+ # Get ground truth & prediction info for all views
+ for i in range(n_views):
+ # Get ground truth
+ no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"]))
+ valid_masks.append(batch[i]["valid_mask"].clone())
+ no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"])
+ gt_ray_directions.append(batch[i]["ray_directions_cam"])
+ if self.depth_type_for_loss == "depth_along_ray":
+ no_norm_gt_depth.append(batch[i]["depth_along_ray"])
+ elif self.depth_type_for_loss == "depth_z":
+ no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:])
+ if i == 0:
+ # For view0, initialize identity pose
+ gt_pose_quats.append(
+ torch.tensor(
+ [0, 0, 0, 1],
+ dtype=gt_ray_directions[0].dtype,
+ device=gt_ray_directions[0].device,
+ )
+ .unsqueeze(0)
+ .repeat(gt_ray_directions[0].shape[0], 1)
+ )
+ no_norm_gt_pose_trans.append(
+ torch.tensor(
+ [0, 0, 0],
+ dtype=gt_ray_directions[0].dtype,
+ device=gt_ray_directions[0].device,
+ )
+ .unsqueeze(0)
+ .repeat(gt_ray_directions[0].shape[0], 1)
+ )
+ else:
+ # For other views, transform pose to view0's frame
+ gt_pose_quats_world = batch[i]["camera_pose_quats"]
+ no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"]
+ gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = (
+ transform_pose_using_quats_and_trans_2_to_1(
+ batch[0]["camera_pose_quats"],
+ batch[0]["camera_pose_trans"],
+ gt_pose_quats_world,
+ no_norm_gt_pose_trans_world,
+ )
+ )
+ gt_pose_quats.append(gt_pose_quats_in_view0)
+ no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
+
+ # Get the local predictions
+ no_norm_pr_pts_cam.append(preds[i]["pts3d_cam"])
+ pr_ray_directions.append(preds[i]["ray_directions"])
+ if self.depth_type_for_loss == "depth_along_ray":
+ no_norm_pr_depth.append(preds[i]["depth_along_ray"])
+ elif self.depth_type_for_loss == "depth_z":
+ no_norm_pr_depth.append(preds[i]["pts3d_cam"][..., 2:])
+
+ # Get the predicted global predictions in view0's frame
+ if self.convert_predictions_to_view0_frame:
+ # Convert predictions to view0 frame
+ pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"])
+ pr_pose_quats_in_view0, pr_pose_trans_in_view0 = (
+ transform_pose_using_quats_and_trans_2_to_1(
+ preds[0]["cam_quats"],
+ preds[0]["cam_trans"],
+ preds[i]["cam_quats"],
+ preds[i]["cam_trans"],
+ )
+ )
+ no_norm_pr_pts.append(pr_pts3d_in_view0)
+ no_norm_pr_pose_trans.append(pr_pose_trans_in_view0)
+ pr_pose_quats.append(pr_pose_quats_in_view0)
+ else:
+ # Predictions are already in view0 frame
+ no_norm_pr_pts.append(preds[i]["pts3d"])
+ no_norm_pr_pose_trans.append(preds[i]["cam_trans"])
+ pr_pose_quats.append(preds[i]["cam_quats"])
+
+ if dist_clip is not None:
+ # Points that are too far-away == invalid
+ for i in range(n_views):
+ dis = no_norm_gt_pts[i].norm(dim=-1)
+ valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
+
+ # Handle metric scale
+ if not self.norm_all:
+ if self.max_metric_scale:
+ B = valid_masks[0].shape[0]
+ dists_to_cam1 = []
+ for i in range(n_views):
+ dists_to_cam1.append(
+ torch.where(
+ valid_masks[i], torch.norm(no_norm_gt_pts[i], dim=-1), 0
+ ).view(B, -1)
+ )
+
+ batch[0]["is_metric_scale"] = batch[0]["is_metric_scale"]
+ for dist in dists_to_cam1:
+ batch[0]["is_metric_scale"] &= (
+ dist.max(dim=-1).values < self.max_metric_scale
+ )
+
+ for i in range(1, n_views):
+ batch[i]["is_metric_scale"] = batch[0]["is_metric_scale"]
+
+ non_metric_scale_mask = ~batch[0]["is_metric_scale"]
+ else:
+ non_metric_scale_mask = torch.ones_like(batch[0]["is_metric_scale"])
+
+ # Initialize normalized tensors
+ gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
+ gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam]
+ gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth]
+ gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans]
+
+ pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
+ pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam]
+ pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth]
+ pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans]
+
+ # Normalize points
+ if self.norm_mode and non_metric_scale_mask.any():
+ pr_normalization_output = normalize_multiple_pointclouds(
+ [pts[non_metric_scale_mask] for pts in no_norm_pr_pts],
+ [mask[non_metric_scale_mask] for mask in valid_masks],
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_pts_norm = pr_normalization_output[:-1]
+ pr_norm_factor = pr_normalization_output[-1]
+
+ for i in range(n_views):
+ pr_pts[i][non_metric_scale_mask] = pr_pts_norm[i]
+ pr_pts_cam[i][non_metric_scale_mask] = (
+ no_norm_pr_pts_cam[i][non_metric_scale_mask] / pr_norm_factor
+ )
+ pr_depth[i][non_metric_scale_mask] = (
+ no_norm_pr_depth[i][non_metric_scale_mask] / pr_norm_factor
+ )
+ pr_pose_trans[i][non_metric_scale_mask] = (
+ no_norm_pr_pose_trans[i][non_metric_scale_mask]
+ / pr_norm_factor[:, :, 0, 0]
+ )
+
+ elif non_metric_scale_mask.any():
+ for i in range(n_views):
+ pr_pts[i][non_metric_scale_mask] = no_norm_pr_pts[i][
+ non_metric_scale_mask
+ ]
+ pr_pts_cam[i][non_metric_scale_mask] = no_norm_pr_pts_cam[i][
+ non_metric_scale_mask
+ ]
+ pr_depth[i][non_metric_scale_mask] = no_norm_pr_depth[i][
+ non_metric_scale_mask
+ ]
+ pr_pose_trans[i][non_metric_scale_mask] = no_norm_pr_pose_trans[i][
+ non_metric_scale_mask
+ ]
+
+ if self.norm_mode and not self.gt_scale:
+ gt_normalization_output = normalize_multiple_pointclouds(
+ no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
+ )
+ gt_pts_norm = gt_normalization_output[:-1]
+ norm_factor = gt_normalization_output[-1]
+
+ for i in range(n_views):
+ gt_pts[i] = gt_pts_norm[i]
+ gt_pts_cam[i] = no_norm_gt_pts_cam[i] / norm_factor
+ gt_depth[i] = no_norm_gt_depth[i] / norm_factor
+ gt_pose_trans[i] = no_norm_gt_pose_trans[i] / norm_factor[:, :, 0, 0]
+
+ pr_pts[i][~non_metric_scale_mask] = (
+ no_norm_pr_pts[i][~non_metric_scale_mask]
+ / norm_factor[~non_metric_scale_mask]
+ )
+ pr_pts_cam[i][~non_metric_scale_mask] = (
+ no_norm_pr_pts_cam[i][~non_metric_scale_mask]
+ / norm_factor[~non_metric_scale_mask]
+ )
+ pr_depth[i][~non_metric_scale_mask] = (
+ no_norm_pr_depth[i][~non_metric_scale_mask]
+ / norm_factor[~non_metric_scale_mask]
+ )
+ pr_pose_trans[i][~non_metric_scale_mask] = (
+ no_norm_pr_pose_trans[i][~non_metric_scale_mask]
+ / norm_factor[~non_metric_scale_mask][:, :, 0, 0]
+ )
+
+ elif ~non_metric_scale_mask.any():
+ for i in range(n_views):
+ gt_pts[i] = no_norm_gt_pts[i]
+ gt_pts_cam[i] = no_norm_gt_pts_cam[i]
+ gt_depth[i] = no_norm_gt_depth[i]
+ gt_pose_trans[i] = no_norm_gt_pose_trans[i]
+ pr_pts[i][~non_metric_scale_mask] = no_norm_pr_pts[i][
+ ~non_metric_scale_mask
+ ]
+ pr_pts_cam[i][~non_metric_scale_mask] = no_norm_pr_pts_cam[i][
+ ~non_metric_scale_mask
+ ]
+ pr_depth[i][~non_metric_scale_mask] = no_norm_pr_depth[i][
+ ~non_metric_scale_mask
+ ]
+ pr_pose_trans[i][~non_metric_scale_mask] = no_norm_pr_pose_trans[i][
+ ~non_metric_scale_mask
+ ]
+ else:
+ for i in range(n_views):
+ gt_pts[i] = no_norm_gt_pts[i]
+ gt_pts_cam[i] = no_norm_gt_pts_cam[i]
+ gt_depth[i] = no_norm_gt_depth[i]
+ gt_pose_trans[i] = no_norm_gt_pose_trans[i]
+
+ # Get ambiguous masks
+ ambiguous_masks = []
+ for i in range(n_views):
+ ambiguous_masks.append(
+ (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i])
+ )
+
+ # Pack into info dicts
+ gt_info = []
+ pred_info = []
+ for i in range(n_views):
+ gt_info.append(
+ {
+ "ray_directions": gt_ray_directions[i],
+ self.depth_type_for_loss: gt_depth[i],
+ "pose_trans": gt_pose_trans[i],
+ "pose_quats": gt_pose_quats[i],
+ "pts3d": gt_pts[i],
+ "pts3d_cam": gt_pts_cam[i],
+ }
+ )
+ pred_info.append(
+ {
+ "ray_directions": pr_ray_directions[i],
+ self.depth_type_for_loss: pr_depth[i],
+ "pose_trans": pr_pose_trans[i],
+ "pose_quats": pr_pose_quats[i],
+ "pts3d": pr_pts[i],
+ "pts3d_cam": pr_pts_cam[i],
+ }
+ )
+
+ return gt_info, pred_info, valid_masks, ambiguous_masks
+
+ def compute_loss(self, batch, preds, **kw):
+ gt_info, pred_info, valid_masks, ambiguous_masks = self.get_all_info(
+ batch, preds, **kw
+ )
+ n_views = len(batch)
+
+ # Mask out samples in the batch where the gt depth validity mask is entirely zero
+ valid_norm_factor_masks = [
+ mask.sum(dim=(1, 2)) > 0 for mask in valid_masks
+ ] # List of (B,)
+
+ if self.ambiguous_loss_value > 0:
+ assert self.criterion.reduction == "none", (
+ "ambiguous_loss_value should be 0 if no conf loss"
+ )
+ # Add the ambiguous pixel as "valid" pixels...
+ valid_masks = [
+ mask | ambig_mask
+ for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
+ ]
+
+ pose_trans_losses = []
+ pose_quats_losses = []
+ ray_directions_losses = []
+ depth_losses = []
+ cam_pts3d_losses = []
+ if self.compute_world_frame_points_loss:
+ pts3d_losses = []
+
+ for i in range(n_views):
+ # Get the predicted dense quantities
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks
+ pred_ray_directions = pred_info[i]["ray_directions"]
+ gt_ray_directions = gt_info[i]["ray_directions"]
+ pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]]
+ gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]]
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]]
+ if self.compute_world_frame_points_loss:
+ pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
+ gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W
+ batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape
+ gt_ray_directions = gt_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ pred_ray_directions = pred_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1]
+ gt_depth = gt_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ pred_depth = pred_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim)
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"].view(
+ batch_size, -1, cam_pts_dim
+ )
+ if self.compute_world_frame_points_loss:
+ pts_dim = gt_info[i]["pts3d"].shape[-1]
+ gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space for depth if specified
+ if self.loss_in_log:
+ gt_depth = apply_log_to_norm(gt_depth)
+ pred_depth = apply_log_to_norm(pred_depth)
+ gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d)
+ pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d)
+ if self.compute_world_frame_points_loss:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_pts3d = apply_log_to_norm(pred_pts3d)
+
+ if self.compute_pairwise_relative_pose_loss:
+ # Get the inverse of current view predicted pose
+ pred_inv_curr_view_pose_quats = quaternion_inverse(
+ pred_info[i]["pose_quats"]
+ )
+ pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ pred_inv_curr_view_pose_quats
+ )
+ pred_inv_curr_view_pose_trans = -1 * ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the inverse of the current view GT pose
+ gt_inv_curr_view_pose_quats = quaternion_inverse(
+ gt_info[i]["pose_quats"]
+ )
+ gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ gt_inv_curr_view_pose_quats
+ )
+ gt_inv_curr_view_pose_trans = -1 * ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the other N-1 relative poses using the current pose as reference frame
+ pred_rel_pose_quats = []
+ pred_rel_pose_trans = []
+ gt_rel_pose_quats = []
+ gt_rel_pose_trans = []
+ for ov_idx in range(n_views):
+ if ov_idx == i:
+ continue
+ # Get the relative predicted pose
+ pred_ov_rel_pose_quats = quaternion_multiply(
+ pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"]
+ )
+ pred_ov_rel_pose_trans = (
+ ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + pred_inv_curr_view_pose_trans
+ )
+
+ # Get the relative GT pose
+ gt_ov_rel_pose_quats = quaternion_multiply(
+ gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"]
+ )
+ gt_ov_rel_pose_trans = (
+ ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + gt_inv_curr_view_pose_trans
+ )
+
+ # Get the valid translations using valid_norm_factor_masks for current view and other view
+ overall_valid_mask_for_trans = (
+ valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx]
+ )
+
+ # Append the relative poses
+ pred_rel_pose_quats.append(pred_ov_rel_pose_quats)
+ pred_rel_pose_trans.append(
+ pred_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+ gt_rel_pose_quats.append(gt_ov_rel_pose_quats)
+ gt_rel_pose_trans.append(
+ gt_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+
+ # Cat the N-1 relative poses along the batch dimension
+ pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0)
+ pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0)
+ gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0)
+ gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0)
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(
+ pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats"
+ ),
+ self.criterion(
+ pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+ else:
+ # Get the pose info for the current view
+ pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ pred_pose_quats = pred_info[i]["pose_quats"]
+ gt_pose_quats = gt_info[i]["pose_quats"]
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_pose_trans, gt_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"),
+ self.criterion(
+ pred_pose_quats, -gt_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+
+ # Compute ray direction loss
+ ray_directions_loss = self.criterion(
+ pred_ray_directions, gt_ray_directions, factor="ray_directions"
+ )
+ ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
+ ray_directions_losses.append(ray_directions_loss)
+
+ # Compute depth loss
+ depth_loss = self.criterion(pred_depth, gt_depth, factor="depth")
+ depth_loss = depth_loss * self.depth_loss_weight
+ depth_losses.append(depth_loss)
+
+ # Compute camera frame point loss
+ cam_pts3d_loss = self.criterion(
+ pred_cam_pts3d, gt_cam_pts3d, factor="points"
+ )
+ cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight
+ cam_pts3d_losses.append(cam_pts3d_loss)
+
+ if self.compute_world_frame_points_loss:
+ # Compute point loss
+ pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
+ pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
+ pts3d_losses.append(pts3d_loss)
+
+ # Handle ambiguous pixels
+ if self.ambiguous_loss_value > 0:
+ if not self.flatten_across_image_only:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+ else:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+
+ # Use helper function to generate loss terms and details
+ if self.compute_world_frame_points_loss:
+ losses_dict = {
+ "pts3d": {
+ "values": pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ }
+ else:
+ losses_dict = {}
+ losses_dict.update(
+ {
+ "cam_pts3d": {
+ "values": cam_pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ self.depth_type_for_loss: {
+ "values": depth_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "ray_directions": {
+ "values": ray_directions_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_quats": {
+ "values": pose_quats_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_trans": {
+ "values": pose_trans_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ }
+ )
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
+
+
+class FactoredGeometryRegr3DPlusNormalGMLoss(FactoredGeometryRegr3D):
+ """
+ Regression, Normals & Gradient Matching Loss for Factored Geometry.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_mode="?avg_dis",
+ gt_scale=False,
+ ambiguous_loss_value=0,
+ max_metric_scale=False,
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ depth_type_for_loss="depth_along_ray",
+ cam_frame_points_loss_weight=1,
+ depth_loss_weight=1,
+ ray_directions_loss_weight=1,
+ pose_quats_loss_weight=1,
+ pose_trans_loss_weight=1,
+ compute_pairwise_relative_pose_loss=False,
+ convert_predictions_to_view0_frame=False,
+ compute_world_frame_points_loss=True,
+ world_frame_points_loss_weight=1,
+ apply_normal_and_gm_loss_to_synthetic_data_only=True,
+ normal_loss_weight=1,
+ gm_loss_weight=1,
+ ):
+ """
+ Initialize the loss criterion for Factored Geometry (see parent class for details).
+ Additionally computes:
+ (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates,
+ (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space)
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_mode (str): Normalization mode for scene representation. Default: "avg_dis".
+ If prefixed with "?", normalization is only applied to non-metric scale data.
+ gt_scale (bool): If True, enforce predictions to have the same scale as ground truth.
+ If False, both GT and predictions are normalized independently. Default: False.
+ ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+ If 0, ambiguous pixels are ignored. Default: 0.
+ max_metric_scale (float): Maximum scale for metric scale data. If data exceeds this
+ value, it will be treated as non-metric. Default: False (no limit).
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth and pointmaps. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+ Options: "depth_along_ray", "depth_z"
+ cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
+ depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+ ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+ pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+ pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+ compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
+ exhaustive pairwise relative poses. Default: False.
+ convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
+ Use this if the predictions are not already in the view0 frame. Default: False.
+ compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
+ world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
+ apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
+ If False, apply the normal and gm loss to all data. Default: True.
+ normal_loss_weight (float): Weight to use for the normal loss. Default: 1.
+ gm_loss_weight (float): Weight to use for the gm loss. Default: 1.
+ """
+ super().__init__(
+ criterion=criterion,
+ norm_mode=norm_mode,
+ gt_scale=gt_scale,
+ ambiguous_loss_value=ambiguous_loss_value,
+ max_metric_scale=max_metric_scale,
+ loss_in_log=loss_in_log,
+ flatten_across_image_only=flatten_across_image_only,
+ depth_type_for_loss=depth_type_for_loss,
+ cam_frame_points_loss_weight=cam_frame_points_loss_weight,
+ depth_loss_weight=depth_loss_weight,
+ ray_directions_loss_weight=ray_directions_loss_weight,
+ pose_quats_loss_weight=pose_quats_loss_weight,
+ pose_trans_loss_weight=pose_trans_loss_weight,
+ compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
+ convert_predictions_to_view0_frame=convert_predictions_to_view0_frame,
+ compute_world_frame_points_loss=compute_world_frame_points_loss,
+ world_frame_points_loss_weight=world_frame_points_loss_weight,
+ )
+ self.apply_normal_and_gm_loss_to_synthetic_data_only = (
+ apply_normal_and_gm_loss_to_synthetic_data_only
+ )
+ self.normal_loss_weight = normal_loss_weight
+ self.gm_loss_weight = gm_loss_weight
+
+ def compute_loss(self, batch, preds, **kw):
+ gt_info, pred_info, valid_masks, ambiguous_masks = self.get_all_info(
+ batch, preds, **kw
+ )
+ n_views = len(batch)
+
+ # Mask out samples in the batch where the gt depth validity mask is entirely zero
+ valid_norm_factor_masks = [
+ mask.sum(dim=(1, 2)) > 0 for mask in valid_masks
+ ] # List of (B,)
+
+ if self.ambiguous_loss_value > 0:
+ assert self.criterion.reduction == "none", (
+ "ambiguous_loss_value should be 0 if no conf loss"
+ )
+ # Add the ambiguous pixel as "valid" pixels...
+ valid_masks = [
+ mask | ambig_mask
+ for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
+ ]
+
+ normal_losses = []
+ gradient_matching_losses = []
+ pose_trans_losses = []
+ pose_quats_losses = []
+ ray_directions_losses = []
+ depth_losses = []
+ cam_pts3d_losses = []
+ if self.compute_world_frame_points_loss:
+ pts3d_losses = []
+
+ for i in range(n_views):
+ # Get the camera frame points, log space depth_z & valid masks
+ pred_local_pts3d = pred_info[i]["pts3d_cam"]
+ pred_depth_z = pred_local_pts3d[..., 2:]
+ pred_depth_z = apply_log_to_norm(pred_depth_z)
+ gt_local_pts3d = gt_info[i]["pts3d_cam"]
+ gt_depth_z = gt_local_pts3d[..., 2:]
+ gt_depth_z = apply_log_to_norm(gt_depth_z)
+ valid_mask_for_normal_gm_loss = valid_masks[i].clone()
+
+ # Update the validity mask for normal & gm loss based on the synthetic data mask if required
+ if self.apply_normal_and_gm_loss_to_synthetic_data_only:
+ synthetic_mask = batch[i]["is_synthetic"] # (B, )
+ synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
+ synthetic_mask = synthetic_mask.expand(
+ -1, pred_depth_z.shape[1], pred_depth_z.shape[2]
+ ) # (B, H, W)
+ valid_mask_for_normal_gm_loss = (
+ valid_mask_for_normal_gm_loss & synthetic_mask
+ )
+
+ # Compute the normal loss
+ normal_loss = compute_normal_loss(
+ pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
+ )
+ normal_loss = normal_loss * self.normal_loss_weight
+ normal_losses.append(normal_loss)
+
+ # Compute the gradient matching loss
+ gradient_matching_loss = compute_gradient_matching_loss(
+ pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
+ )
+ gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight
+ gradient_matching_losses.append(gradient_matching_loss)
+
+ # Get the predicted dense quantities
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks
+ pred_ray_directions = pred_info[i]["ray_directions"]
+ gt_ray_directions = gt_info[i]["ray_directions"]
+ pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]]
+ gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]]
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]]
+ if self.compute_world_frame_points_loss:
+ pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
+ gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W
+ batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape
+ gt_ray_directions = gt_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ pred_ray_directions = pred_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1]
+ gt_depth = gt_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ pred_depth = pred_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim)
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"].view(
+ batch_size, -1, cam_pts_dim
+ )
+ if self.compute_world_frame_points_loss:
+ pts_dim = gt_info[i]["pts3d"].shape[-1]
+ gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space for depth if specified
+ if self.loss_in_log:
+ gt_depth = apply_log_to_norm(gt_depth)
+ pred_depth = apply_log_to_norm(pred_depth)
+ gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d)
+ pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d)
+ if self.compute_world_frame_points_loss:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_pts3d = apply_log_to_norm(pred_pts3d)
+
+ if self.compute_pairwise_relative_pose_loss:
+ # Get the inverse of current view predicted pose
+ pred_inv_curr_view_pose_quats = quaternion_inverse(
+ pred_info[i]["pose_quats"]
+ )
+ pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ pred_inv_curr_view_pose_quats
+ )
+ pred_inv_curr_view_pose_trans = -1 * ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the inverse of the current view GT pose
+ gt_inv_curr_view_pose_quats = quaternion_inverse(
+ gt_info[i]["pose_quats"]
+ )
+ gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ gt_inv_curr_view_pose_quats
+ )
+ gt_inv_curr_view_pose_trans = -1 * ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the other N-1 relative poses using the current pose as reference frame
+ pred_rel_pose_quats = []
+ pred_rel_pose_trans = []
+ gt_rel_pose_quats = []
+ gt_rel_pose_trans = []
+ for ov_idx in range(n_views):
+ if ov_idx == i:
+ continue
+ # Get the relative predicted pose
+ pred_ov_rel_pose_quats = quaternion_multiply(
+ pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"]
+ )
+ pred_ov_rel_pose_trans = (
+ ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + pred_inv_curr_view_pose_trans
+ )
+
+ # Get the relative GT pose
+ gt_ov_rel_pose_quats = quaternion_multiply(
+ gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"]
+ )
+ gt_ov_rel_pose_trans = (
+ ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + gt_inv_curr_view_pose_trans
+ )
+
+ # Get the valid translations using valid_norm_factor_masks for current view and other view
+ overall_valid_mask_for_trans = (
+ valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx]
+ )
+
+ # Append the relative poses
+ pred_rel_pose_quats.append(pred_ov_rel_pose_quats)
+ pred_rel_pose_trans.append(
+ pred_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+ gt_rel_pose_quats.append(gt_ov_rel_pose_quats)
+ gt_rel_pose_trans.append(
+ gt_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+
+ # Cat the N-1 relative poses along the batch dimension
+ pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0)
+ pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0)
+ gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0)
+ gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0)
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(
+ pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats"
+ ),
+ self.criterion(
+ pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+ else:
+ # Get the pose info for the current view
+ pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ pred_pose_quats = pred_info[i]["pose_quats"]
+ gt_pose_quats = gt_info[i]["pose_quats"]
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_pose_trans, gt_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"),
+ self.criterion(
+ pred_pose_quats, -gt_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+
+ # Compute ray direction loss
+ ray_directions_loss = self.criterion(
+ pred_ray_directions, gt_ray_directions, factor="ray_directions"
+ )
+ ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
+ ray_directions_losses.append(ray_directions_loss)
+
+ # Compute depth loss
+ depth_loss = self.criterion(pred_depth, gt_depth, factor="depth")
+ depth_loss = depth_loss * self.depth_loss_weight
+ depth_losses.append(depth_loss)
+
+ # Compute camera frame point loss
+ cam_pts3d_loss = self.criterion(
+ pred_cam_pts3d, gt_cam_pts3d, factor="points"
+ )
+ cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight
+ cam_pts3d_losses.append(cam_pts3d_loss)
+
+ if self.compute_world_frame_points_loss:
+ # Compute point loss
+ pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
+ pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
+ pts3d_losses.append(pts3d_loss)
+
+ # Handle ambiguous pixels
+ if self.ambiguous_loss_value > 0:
+ if not self.flatten_across_image_only:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+ else:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+
+ # Use helper function to generate loss terms and details
+ if self.compute_world_frame_points_loss:
+ losses_dict = {
+ "pts3d": {
+ "values": pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ }
+ else:
+ losses_dict = {}
+ losses_dict.update(
+ {
+ "cam_pts3d": {
+ "values": cam_pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ self.depth_type_for_loss: {
+ "values": depth_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "ray_directions": {
+ "values": ray_directions_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_quats": {
+ "values": pose_quats_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_trans": {
+ "values": pose_trans_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "normal": {
+ "values": normal_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "gradient_matching": {
+ "values": gradient_matching_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ }
+ )
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
+
+
+class FactoredGeometryScaleRegr3D(Criterion, MultiLoss):
+ """
+ Regression Loss for Factored Geometry & Scale.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_predictions=True,
+ norm_mode="avg_dis",
+ ambiguous_loss_value=0,
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ depth_type_for_loss="depth_along_ray",
+ cam_frame_points_loss_weight=1,
+ depth_loss_weight=1,
+ ray_directions_loss_weight=1,
+ pose_quats_loss_weight=1,
+ pose_trans_loss_weight=1,
+ scale_loss_weight=1,
+ compute_pairwise_relative_pose_loss=False,
+ convert_predictions_to_view0_frame=False,
+ compute_world_frame_points_loss=True,
+ world_frame_points_loss_weight=1,
+ ):
+ """
+ Initialize the loss criterion for Factored Geometry (Ray Directions, Depth, Pose), Scale
+ and the Collective Geometry i.e. Local Frame Pointmaps & optionally World Frame Pointmaps.
+ If world-frame pointmap loss is computed, the pixel-level losses are computed in the following order:
+ (1) world points, (2) cam points, (3) depth, (4) ray directions, (5) pose quats, (6) pose trans, (7) scale.
+ Else, the pixel-level losses are returned in the following order:
+ (1) cam points, (2) depth, (3) ray directions, (4) pose quats, (5) pose trans, (6) scale.
+ The predicited scene representation is always normalized w.r.t. the frame of view0.
+ Loss is applied between the predicted metric scale and the ground truth metric scale.
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_predictions (bool): If True, normalize the predictions before computing the loss.
+ norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
+ ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+ If 0, ambiguous pixels are ignored. Default: 0.
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth, pointmaps and scale. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+ Options: "depth_along_ray", "depth_z"
+ cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
+ depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+ ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+ pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+ pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+ scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
+ compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
+ exhaustive pairwise relative poses. Default: False.
+ convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
+ Use this if the predictions are not already in the view0 frame. Default: False.
+ compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
+ world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
+ """
+ super().__init__(criterion)
+ self.norm_predictions = norm_predictions
+ self.norm_mode = norm_mode
+ self.ambiguous_loss_value = ambiguous_loss_value
+ self.loss_in_log = loss_in_log
+ self.flatten_across_image_only = flatten_across_image_only
+ self.depth_type_for_loss = depth_type_for_loss
+ assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], (
+ "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']"
+ )
+ self.cam_frame_points_loss_weight = cam_frame_points_loss_weight
+ self.depth_loss_weight = depth_loss_weight
+ self.ray_directions_loss_weight = ray_directions_loss_weight
+ self.pose_quats_loss_weight = pose_quats_loss_weight
+ self.pose_trans_loss_weight = pose_trans_loss_weight
+ self.scale_loss_weight = scale_loss_weight
+ self.compute_pairwise_relative_pose_loss = compute_pairwise_relative_pose_loss
+ self.convert_predictions_to_view0_frame = convert_predictions_to_view0_frame
+ self.compute_world_frame_points_loss = compute_world_frame_points_loss
+ self.world_frame_points_loss_weight = world_frame_points_loss_weight
+
+ def get_all_info(self, batch, preds, dist_clip=None):
+ """
+ Function to get all the information needed to compute the loss.
+ Returns all quantities normalized w.r.t. camera of view0.
+ """
+ n_views = len(batch)
+
+ # Everything is normalized w.r.t. camera of view0
+ # Intialize lists to store data for all views
+ # Ground truth quantities
+ in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
+ no_norm_gt_pts = []
+ no_norm_gt_pts_cam = []
+ no_norm_gt_depth = []
+ no_norm_gt_pose_trans = []
+ valid_masks = []
+ gt_ray_directions = []
+ gt_pose_quats = []
+ # Predicted quantities
+ if self.convert_predictions_to_view0_frame:
+ # Get the camera transform to convert quantities to view0 frame
+ pred_camera0 = torch.eye(4, device=preds[0]["cam_quats"].device).unsqueeze(
+ 0
+ )
+ batch_size = preds[0]["cam_quats"].shape[0]
+ pred_camera0 = pred_camera0.repeat(batch_size, 1, 1)
+ pred_camera0_rot = quaternion_to_rotation_matrix(
+ preds[0]["cam_quats"].clone()
+ )
+ pred_camera0[..., :3, :3] = pred_camera0_rot
+ pred_camera0[..., :3, 3] = preds[0]["cam_trans"].clone()
+ pred_in_camera0 = closed_form_pose_inverse(pred_camera0)
+ no_norm_pr_pts = []
+ no_norm_pr_pts_cam = []
+ no_norm_pr_depth = []
+ no_norm_pr_pose_trans = []
+ pr_ray_directions = []
+ pr_pose_quats = []
+ metric_pr_pts_to_compute_scale = []
+
+ # Get ground truth & prediction info for all views
+ for i in range(n_views):
+ # Get the ground truth
+ no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"]))
+ valid_masks.append(batch[i]["valid_mask"].clone())
+ no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"])
+ gt_ray_directions.append(batch[i]["ray_directions_cam"])
+ if self.depth_type_for_loss == "depth_along_ray":
+ no_norm_gt_depth.append(batch[i]["depth_along_ray"])
+ elif self.depth_type_for_loss == "depth_z":
+ no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:])
+ if i == 0:
+ # For view0, initialize identity pose
+ gt_pose_quats.append(
+ torch.tensor(
+ [0, 0, 0, 1],
+ dtype=gt_ray_directions[0].dtype,
+ device=gt_ray_directions[0].device,
+ )
+ .unsqueeze(0)
+ .repeat(gt_ray_directions[0].shape[0], 1)
+ )
+ no_norm_gt_pose_trans.append(
+ torch.tensor(
+ [0, 0, 0],
+ dtype=gt_ray_directions[0].dtype,
+ device=gt_ray_directions[0].device,
+ )
+ .unsqueeze(0)
+ .repeat(gt_ray_directions[0].shape[0], 1)
+ )
+ else:
+ # For other views, transform pose to view0's frame
+ gt_pose_quats_world = batch[i]["camera_pose_quats"]
+ no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"]
+ gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = (
+ transform_pose_using_quats_and_trans_2_to_1(
+ batch[0]["camera_pose_quats"],
+ batch[0]["camera_pose_trans"],
+ gt_pose_quats_world,
+ no_norm_gt_pose_trans_world,
+ )
+ )
+ gt_pose_quats.append(gt_pose_quats_in_view0)
+ no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
+
+ # Get the global predictions in view0's frame
+ if self.convert_predictions_to_view0_frame:
+ # Convert predictions to view0 frame
+ pr_pts3d_in_view0 = geotrf(pred_in_camera0, preds[i]["pts3d"])
+ pr_pose_quats_in_view0, pr_pose_trans_in_view0 = (
+ transform_pose_using_quats_and_trans_2_to_1(
+ preds[0]["cam_quats"],
+ preds[0]["cam_trans"],
+ preds[i]["cam_quats"],
+ preds[i]["cam_trans"],
+ )
+ )
+ else:
+ # Predictions are already in view0 frame
+ pr_pts3d_in_view0 = preds[i]["pts3d"]
+ pr_pose_trans_in_view0 = preds[i]["cam_trans"]
+ pr_pose_quats_in_view0 = preds[i]["cam_quats"]
+
+ # Get predictions for normalized loss
+ if self.depth_type_for_loss == "depth_along_ray":
+ curr_view_no_norm_depth = preds[i]["depth_along_ray"]
+ elif self.depth_type_for_loss == "depth_z":
+ curr_view_no_norm_depth = preds[i]["pts3d_cam"][..., 2:]
+ if "metric_scaling_factor" in preds[i].keys():
+ # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans
+ # This detaches the predicted metric scaling factor from the geometry based loss
+ curr_view_no_norm_pr_pts = pr_pts3d_in_view0 / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ curr_view_no_norm_depth = curr_view_no_norm_depth / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ curr_view_no_norm_pr_pose_trans = (
+ pr_pose_trans_in_view0 / preds[i]["metric_scaling_factor"]
+ )
+ else:
+ curr_view_no_norm_pr_pts = pr_pts3d_in_view0
+ curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"]
+ curr_view_no_norm_depth = curr_view_no_norm_depth
+ curr_view_no_norm_pr_pose_trans = pr_pose_trans_in_view0
+ no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
+ no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam)
+ no_norm_pr_depth.append(curr_view_no_norm_depth)
+ no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans)
+ pr_ray_directions.append(preds[i]["ray_directions"])
+ pr_pose_quats.append(pr_pose_quats_in_view0)
+
+ # Get the predicted metric scale points
+ if "metric_scaling_factor" in preds[i].keys():
+ # Detach the raw predicted points so that the scale loss is only applied to the scaling factor
+ curr_view_metric_pr_pts_to_compute_scale = (
+ curr_view_no_norm_pr_pts.detach()
+ * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1)
+ )
+ else:
+ curr_view_metric_pr_pts_to_compute_scale = (
+ curr_view_no_norm_pr_pts.clone()
+ )
+ metric_pr_pts_to_compute_scale.append(
+ curr_view_metric_pr_pts_to_compute_scale
+ )
+
+ if dist_clip is not None:
+ # Points that are too far-away == invalid
+ for i in range(n_views):
+ dis = no_norm_gt_pts[i].norm(dim=-1)
+ valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
+
+ # Initialize normalized tensors
+ gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
+ gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam]
+ gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth]
+ gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans]
+
+ pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
+ pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam]
+ pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth]
+ pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans]
+
+ # Normalize the predicted points if specified
+ if self.norm_predictions:
+ pr_normalization_output = normalize_multiple_pointclouds(
+ no_norm_pr_pts,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_pts_norm = pr_normalization_output[:-1]
+ pr_norm_factor = pr_normalization_output[-1]
+
+ # Normalize the ground truth points
+ gt_normalization_output = normalize_multiple_pointclouds(
+ no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
+ )
+ gt_pts_norm = gt_normalization_output[:-1]
+ gt_norm_factor = gt_normalization_output[-1]
+
+ for i in range(n_views):
+ if self.norm_predictions:
+ # Assign the normalized predictions
+ pr_pts[i] = pr_pts_norm[i]
+ pr_pts_cam[i] = no_norm_pr_pts_cam[i] / pr_norm_factor
+ pr_depth[i] = no_norm_pr_depth[i] / pr_norm_factor
+ pr_pose_trans[i] = no_norm_pr_pose_trans[i] / pr_norm_factor[:, :, 0, 0]
+ else:
+ pr_pts[i] = no_norm_pr_pts[i]
+ pr_pts_cam[i] = no_norm_pr_pts_cam[i]
+ pr_depth[i] = no_norm_pr_depth[i]
+ pr_pose_trans[i] = no_norm_pr_pose_trans[i]
+ # Assign the normalized ground truth quantities
+ gt_pts[i] = gt_pts_norm[i]
+ gt_pts_cam[i] = no_norm_gt_pts_cam[i] / gt_norm_factor
+ gt_depth[i] = no_norm_gt_depth[i] / gt_norm_factor
+ gt_pose_trans[i] = no_norm_gt_pose_trans[i] / gt_norm_factor[:, :, 0, 0]
+
+ # Get the mask indicating ground truth metric scale quantities
+ metric_scale_mask = batch[0]["is_metric_scale"]
+ valid_gt_norm_factor_mask = (
+ gt_norm_factor[:, 0, 0, 0] > 1e-8
+ ) # Mask out cases where depth for all views is invalid
+ valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask
+
+ if valid_metric_scale_mask.any():
+ # Compute the scale norm factor using the predicted metric scale points
+ metric_pr_normalization_output = normalize_multiple_pointclouds(
+ metric_pr_pts_to_compute_scale,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_metric_norm_factor = metric_pr_normalization_output[-1]
+
+ # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities
+ gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask]
+ pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask]
+ else:
+ gt_metric_norm_factor = None
+ pr_metric_norm_factor = None
+
+ # Get ambiguous masks
+ ambiguous_masks = []
+ for i in range(n_views):
+ ambiguous_masks.append(
+ (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i])
+ )
+
+ # Pack into info dicts
+ gt_info = []
+ pred_info = []
+ for i in range(n_views):
+ gt_info.append(
+ {
+ "ray_directions": gt_ray_directions[i],
+ self.depth_type_for_loss: gt_depth[i],
+ "pose_trans": gt_pose_trans[i],
+ "pose_quats": gt_pose_quats[i],
+ "pts3d": gt_pts[i],
+ "pts3d_cam": gt_pts_cam[i],
+ }
+ )
+ pred_info.append(
+ {
+ "ray_directions": pr_ray_directions[i],
+ self.depth_type_for_loss: pr_depth[i],
+ "pose_trans": pr_pose_trans[i],
+ "pose_quats": pr_pose_quats[i],
+ "pts3d": pr_pts[i],
+ "pts3d_cam": pr_pts_cam[i],
+ }
+ )
+
+ return (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ )
+
+ def compute_loss(self, batch, preds, **kw):
+ (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ ) = self.get_all_info(batch, preds, **kw)
+ n_views = len(batch)
+
+ # Mask out samples in the batch where the gt depth validity mask is entirely zero
+ valid_norm_factor_masks = [
+ mask.sum(dim=(1, 2)) > 0 for mask in valid_masks
+ ] # List of (B,)
+
+ if self.ambiguous_loss_value > 0:
+ assert self.criterion.reduction == "none", (
+ "ambiguous_loss_value should be 0 if no conf loss"
+ )
+ # Add the ambiguous pixel as "valid" pixels...
+ valid_masks = [
+ mask | ambig_mask
+ for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
+ ]
+
+ pose_trans_losses = []
+ pose_quats_losses = []
+ ray_directions_losses = []
+ depth_losses = []
+ cam_pts3d_losses = []
+ if self.compute_world_frame_points_loss:
+ pts3d_losses = []
+
+ for i in range(n_views):
+ # Get the predicted dense quantities
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks
+ pred_ray_directions = pred_info[i]["ray_directions"]
+ gt_ray_directions = gt_info[i]["ray_directions"]
+ pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]]
+ gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]]
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]]
+ if self.compute_world_frame_points_loss:
+ pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
+ gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W
+ batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape
+ gt_ray_directions = gt_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ pred_ray_directions = pred_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1]
+ gt_depth = gt_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ pred_depth = pred_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim)
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"].view(
+ batch_size, -1, cam_pts_dim
+ )
+ if self.compute_world_frame_points_loss:
+ pts_dim = gt_info[i]["pts3d"].shape[-1]
+ gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space for depth if specified
+ if self.loss_in_log:
+ gt_depth = apply_log_to_norm(gt_depth)
+ pred_depth = apply_log_to_norm(pred_depth)
+ gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d)
+ pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d)
+ if self.compute_world_frame_points_loss:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_pts3d = apply_log_to_norm(pred_pts3d)
+
+ if self.compute_pairwise_relative_pose_loss:
+ # Get the inverse of current view predicted pose
+ pred_inv_curr_view_pose_quats = quaternion_inverse(
+ pred_info[i]["pose_quats"]
+ )
+ pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ pred_inv_curr_view_pose_quats
+ )
+ pred_inv_curr_view_pose_trans = -1 * ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the inverse of the current view GT pose
+ gt_inv_curr_view_pose_quats = quaternion_inverse(
+ gt_info[i]["pose_quats"]
+ )
+ gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ gt_inv_curr_view_pose_quats
+ )
+ gt_inv_curr_view_pose_trans = -1 * ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the other N-1 relative poses using the current pose as reference frame
+ pred_rel_pose_quats = []
+ pred_rel_pose_trans = []
+ gt_rel_pose_quats = []
+ gt_rel_pose_trans = []
+ for ov_idx in range(n_views):
+ if ov_idx == i:
+ continue
+ # Get the relative predicted pose
+ pred_ov_rel_pose_quats = quaternion_multiply(
+ pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"]
+ )
+ pred_ov_rel_pose_trans = (
+ ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + pred_inv_curr_view_pose_trans
+ )
+
+ # Get the relative GT pose
+ gt_ov_rel_pose_quats = quaternion_multiply(
+ gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"]
+ )
+ gt_ov_rel_pose_trans = (
+ ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + gt_inv_curr_view_pose_trans
+ )
+
+ # Get the valid translations using valid_norm_factor_masks for current view and other view
+ overall_valid_mask_for_trans = (
+ valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx]
+ )
+
+ # Append the relative poses
+ pred_rel_pose_quats.append(pred_ov_rel_pose_quats)
+ pred_rel_pose_trans.append(
+ pred_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+ gt_rel_pose_quats.append(gt_ov_rel_pose_quats)
+ gt_rel_pose_trans.append(
+ gt_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+
+ # Cat the N-1 relative poses along the batch dimension
+ pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0)
+ pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0)
+ gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0)
+ gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0)
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(
+ pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats"
+ ),
+ self.criterion(
+ pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+ else:
+ # Get the pose info for the current view
+ pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ pred_pose_quats = pred_info[i]["pose_quats"]
+ gt_pose_quats = gt_info[i]["pose_quats"]
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_pose_trans, gt_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"),
+ self.criterion(
+ pred_pose_quats, -gt_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+
+ # Compute ray direction loss
+ ray_directions_loss = self.criterion(
+ pred_ray_directions, gt_ray_directions, factor="ray_directions"
+ )
+ ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
+ ray_directions_losses.append(ray_directions_loss)
+
+ # Compute depth loss
+ depth_loss = self.criterion(pred_depth, gt_depth, factor="depth")
+ depth_loss = depth_loss * self.depth_loss_weight
+ depth_losses.append(depth_loss)
+
+ # Compute camera frame point loss
+ cam_pts3d_loss = self.criterion(
+ pred_cam_pts3d, gt_cam_pts3d, factor="points"
+ )
+ cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight
+ cam_pts3d_losses.append(cam_pts3d_loss)
+
+ if self.compute_world_frame_points_loss:
+ # Compute point loss
+ pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
+ pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
+ pts3d_losses.append(pts3d_loss)
+
+ # Handle ambiguous pixels
+ if self.ambiguous_loss_value > 0:
+ if not self.flatten_across_image_only:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+ else:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+
+ # Compute the scale loss
+ if gt_metric_norm_factor is not None:
+ if self.loss_in_log:
+ gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
+ pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
+ scale_loss = (
+ self.criterion(
+ pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
+ )
+ * self.scale_loss_weight
+ )
+ else:
+ scale_loss = None
+
+ # Use helper function to generate loss terms and details
+ if self.compute_world_frame_points_loss:
+ losses_dict = {
+ "pts3d": {
+ "values": pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ }
+ else:
+ losses_dict = {}
+ losses_dict.update(
+ {
+ "cam_pts3d": {
+ "values": cam_pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ self.depth_type_for_loss: {
+ "values": depth_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "ray_directions": {
+ "values": ray_directions_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_quats": {
+ "values": pose_quats_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_trans": {
+ "values": pose_trans_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "scale": {
+ "values": scale_loss,
+ "use_mask": False,
+ "is_multi_view": False,
+ },
+ }
+ )
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
+
+
+class FactoredGeometryScaleRegr3DPlusNormalGMLoss(FactoredGeometryScaleRegr3D):
+ """
+ Regression, Normals & Gradient Matching Loss for Factored Geometry & Scale.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_predictions=True,
+ norm_mode="avg_dis",
+ ambiguous_loss_value=0,
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ depth_type_for_loss="depth_along_ray",
+ cam_frame_points_loss_weight=1,
+ depth_loss_weight=1,
+ ray_directions_loss_weight=1,
+ pose_quats_loss_weight=1,
+ pose_trans_loss_weight=1,
+ scale_loss_weight=1,
+ compute_pairwise_relative_pose_loss=False,
+ convert_predictions_to_view0_frame=False,
+ compute_world_frame_points_loss=True,
+ world_frame_points_loss_weight=1,
+ apply_normal_and_gm_loss_to_synthetic_data_only=True,
+ normal_loss_weight=1,
+ gm_loss_weight=1,
+ ):
+ """
+ Initialize the loss criterion for Ray Directions, Depth, Pose, Pointmaps & Scale.
+ Additionally computes:
+ (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates,
+ (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space)
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_predictions (bool): If True, normalize the predictions before computing the loss.
+ norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
+ ambiguous_loss_value (float): Value to use for ambiguous pixels in the loss.
+ If 0, ambiguous pixels are ignored. Default: 0.
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth, pointmaps and scale. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+ Options: "depth_along_ray", "depth_z"
+ cam_frame_points_loss_weight (float): Weight to use for the camera frame pointmap loss. Default: 1.
+ depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+ ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+ pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+ pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+ scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
+ compute_pairwise_relative_pose_loss (bool): If True, the pose loss is computed on the
+ exhaustive pairwise relative poses. Default: False.
+ convert_predictions_to_view0_frame (bool): If True, convert predictions to view0 frame.
+ Use this if the predictions are not already in the view0 frame. Default: False.
+ compute_world_frame_points_loss (bool): If True, compute the world frame pointmap loss. Default: True.
+ world_frame_points_loss_weight (float): Weight to use for the world frame pointmap loss. Default: 1.
+ apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
+ If False, apply the normal and gm loss to all data. Default: True.
+ normal_loss_weight (float): Weight to use for the normal loss. Default: 1.
+ gm_loss_weight (float): Weight to use for the gm loss. Default: 1.
+ """
+ super().__init__(
+ criterion=criterion,
+ norm_predictions=norm_predictions,
+ norm_mode=norm_mode,
+ ambiguous_loss_value=ambiguous_loss_value,
+ loss_in_log=loss_in_log,
+ flatten_across_image_only=flatten_across_image_only,
+ depth_type_for_loss=depth_type_for_loss,
+ cam_frame_points_loss_weight=cam_frame_points_loss_weight,
+ depth_loss_weight=depth_loss_weight,
+ ray_directions_loss_weight=ray_directions_loss_weight,
+ pose_quats_loss_weight=pose_quats_loss_weight,
+ pose_trans_loss_weight=pose_trans_loss_weight,
+ scale_loss_weight=scale_loss_weight,
+ compute_pairwise_relative_pose_loss=compute_pairwise_relative_pose_loss,
+ convert_predictions_to_view0_frame=convert_predictions_to_view0_frame,
+ compute_world_frame_points_loss=compute_world_frame_points_loss,
+ world_frame_points_loss_weight=world_frame_points_loss_weight,
+ )
+ self.apply_normal_and_gm_loss_to_synthetic_data_only = (
+ apply_normal_and_gm_loss_to_synthetic_data_only
+ )
+ self.normal_loss_weight = normal_loss_weight
+ self.gm_loss_weight = gm_loss_weight
+
+ def compute_loss(self, batch, preds, **kw):
+ (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ ) = self.get_all_info(batch, preds, **kw)
+ n_views = len(batch)
+
+ # Mask out samples in the batch where the gt depth validity mask is entirely zero
+ valid_norm_factor_masks = [
+ mask.sum(dim=(1, 2)) > 0 for mask in valid_masks
+ ] # List of (B,)
+
+ if self.ambiguous_loss_value > 0:
+ assert self.criterion.reduction == "none", (
+ "ambiguous_loss_value should be 0 if no conf loss"
+ )
+ # Add the ambiguous pixel as "valid" pixels...
+ valid_masks = [
+ mask | ambig_mask
+ for mask, ambig_mask in zip(valid_masks, ambiguous_masks)
+ ]
+
+ normal_losses = []
+ gradient_matching_losses = []
+ pose_trans_losses = []
+ pose_quats_losses = []
+ ray_directions_losses = []
+ depth_losses = []
+ cam_pts3d_losses = []
+ if self.compute_world_frame_points_loss:
+ pts3d_losses = []
+
+ for i in range(n_views):
+ # Get the camera frame points, log space depth_z & valid masks
+ pred_local_pts3d = pred_info[i]["pts3d_cam"]
+ pred_depth_z = pred_local_pts3d[..., 2:]
+ pred_depth_z = apply_log_to_norm(pred_depth_z)
+ gt_local_pts3d = gt_info[i]["pts3d_cam"]
+ gt_depth_z = gt_local_pts3d[..., 2:]
+ gt_depth_z = apply_log_to_norm(gt_depth_z)
+ valid_mask_for_normal_gm_loss = valid_masks[i].clone()
+
+ # Update the validity mask for normal & gm loss based on the synthetic data mask if required
+ if self.apply_normal_and_gm_loss_to_synthetic_data_only:
+ synthetic_mask = batch[i]["is_synthetic"] # (B, )
+ synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
+ synthetic_mask = synthetic_mask.expand(
+ -1, pred_depth_z.shape[1], pred_depth_z.shape[2]
+ ) # (B, H, W)
+ valid_mask_for_normal_gm_loss = (
+ valid_mask_for_normal_gm_loss & synthetic_mask
+ )
+
+ # Compute the normal loss
+ normal_loss = compute_normal_loss(
+ pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
+ )
+ normal_loss = normal_loss * self.normal_loss_weight
+ normal_losses.append(normal_loss)
+
+ # Compute the gradient matching loss
+ gradient_matching_loss = compute_gradient_matching_loss(
+ pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
+ )
+ gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight
+ gradient_matching_losses.append(gradient_matching_loss)
+
+ # Get the predicted dense quantities
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks and compute the metrics
+ pred_ray_directions = pred_info[i]["ray_directions"]
+ gt_ray_directions = gt_info[i]["ray_directions"]
+ pred_depth = pred_info[i][self.depth_type_for_loss][valid_masks[i]]
+ gt_depth = gt_info[i][self.depth_type_for_loss][valid_masks[i]]
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"][valid_masks[i]]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"][valid_masks[i]]
+ if self.compute_world_frame_points_loss:
+ pred_pts3d = pred_info[i]["pts3d"][valid_masks[i]]
+ gt_pts3d = gt_info[i]["pts3d"][valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W and compute the metrics
+ batch_size, _, _, direction_dim = gt_info[i]["ray_directions"].shape
+ gt_ray_directions = gt_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ pred_ray_directions = pred_info[i]["ray_directions"].view(
+ batch_size, -1, direction_dim
+ )
+ depth_dim = gt_info[i][self.depth_type_for_loss].shape[-1]
+ gt_depth = gt_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ pred_depth = pred_info[i][self.depth_type_for_loss].view(
+ batch_size, -1, depth_dim
+ )
+ cam_pts_dim = gt_info[i]["pts3d_cam"].shape[-1]
+ gt_cam_pts3d = gt_info[i]["pts3d_cam"].view(batch_size, -1, cam_pts_dim)
+ pred_cam_pts3d = pred_info[i]["pts3d_cam"].view(
+ batch_size, -1, cam_pts_dim
+ )
+ if self.compute_world_frame_points_loss:
+ pts_dim = gt_info[i]["pts3d"].shape[-1]
+ gt_pts3d = gt_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ pred_pts3d = pred_info[i]["pts3d"].view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space for depth if specified
+ if self.loss_in_log:
+ gt_depth = apply_log_to_norm(gt_depth)
+ pred_depth = apply_log_to_norm(pred_depth)
+ gt_cam_pts3d = apply_log_to_norm(gt_cam_pts3d)
+ pred_cam_pts3d = apply_log_to_norm(pred_cam_pts3d)
+ if self.compute_world_frame_points_loss:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_pts3d = apply_log_to_norm(pred_pts3d)
+
+ if self.compute_pairwise_relative_pose_loss:
+ # Get the inverse of current view predicted pose
+ pred_inv_curr_view_pose_quats = quaternion_inverse(
+ pred_info[i]["pose_quats"]
+ )
+ pred_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ pred_inv_curr_view_pose_quats
+ )
+ pred_inv_curr_view_pose_trans = -1 * ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the inverse of the current view GT pose
+ gt_inv_curr_view_pose_quats = quaternion_inverse(
+ gt_info[i]["pose_quats"]
+ )
+ gt_inv_curr_view_pose_rot_mat = quaternion_to_rotation_matrix(
+ gt_inv_curr_view_pose_quats
+ )
+ gt_inv_curr_view_pose_trans = -1 * ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[i]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+
+ # Get the other N-1 relative poses using the current pose as reference frame
+ pred_rel_pose_quats = []
+ pred_rel_pose_trans = []
+ gt_rel_pose_quats = []
+ gt_rel_pose_trans = []
+ for ov_idx in range(n_views):
+ if ov_idx == i:
+ continue
+ # Get the relative predicted pose
+ pred_ov_rel_pose_quats = quaternion_multiply(
+ pred_inv_curr_view_pose_quats, pred_info[ov_idx]["pose_quats"]
+ )
+ pred_ov_rel_pose_trans = (
+ ein.einsum(
+ pred_inv_curr_view_pose_rot_mat,
+ pred_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + pred_inv_curr_view_pose_trans
+ )
+
+ # Get the relative GT pose
+ gt_ov_rel_pose_quats = quaternion_multiply(
+ gt_inv_curr_view_pose_quats, gt_info[ov_idx]["pose_quats"]
+ )
+ gt_ov_rel_pose_trans = (
+ ein.einsum(
+ gt_inv_curr_view_pose_rot_mat,
+ gt_info[ov_idx]["pose_trans"],
+ "b i j, b j -> b i",
+ )
+ + gt_inv_curr_view_pose_trans
+ )
+
+ # Get the valid translations using valid_norm_factor_masks for current view and other view
+ overall_valid_mask_for_trans = (
+ valid_norm_factor_masks[i] & valid_norm_factor_masks[ov_idx]
+ )
+
+ # Append the relative poses
+ pred_rel_pose_quats.append(pred_ov_rel_pose_quats)
+ pred_rel_pose_trans.append(
+ pred_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+ gt_rel_pose_quats.append(gt_ov_rel_pose_quats)
+ gt_rel_pose_trans.append(
+ gt_ov_rel_pose_trans[overall_valid_mask_for_trans]
+ )
+
+ # Cat the N-1 relative poses along the batch dimension
+ pred_rel_pose_quats = torch.cat(pred_rel_pose_quats, dim=0)
+ pred_rel_pose_trans = torch.cat(pred_rel_pose_trans, dim=0)
+ gt_rel_pose_quats = torch.cat(gt_rel_pose_quats, dim=0)
+ gt_rel_pose_trans = torch.cat(gt_rel_pose_trans, dim=0)
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_rel_pose_trans, gt_rel_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(
+ pred_rel_pose_quats, gt_rel_pose_quats, factor="pose_quats"
+ ),
+ self.criterion(
+ pred_rel_pose_quats, -gt_rel_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+ else:
+ # Get the pose info for the current view
+ pred_pose_trans = pred_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ gt_pose_trans = gt_info[i]["pose_trans"][valid_norm_factor_masks[i]]
+ pred_pose_quats = pred_info[i]["pose_quats"]
+ gt_pose_quats = gt_info[i]["pose_quats"]
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_pose_trans, gt_pose_trans, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ # Handle quaternion two-to-one mapping
+ pose_quats_loss = torch.minimum(
+ self.criterion(pred_pose_quats, gt_pose_quats, factor="pose_quats"),
+ self.criterion(
+ pred_pose_quats, -gt_pose_quats, factor="pose_quats"
+ ),
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+
+ # Compute ray direction loss
+ ray_directions_loss = self.criterion(
+ pred_ray_directions, gt_ray_directions, factor="ray_directions"
+ )
+ ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
+ ray_directions_losses.append(ray_directions_loss)
+
+ # Compute depth loss
+ depth_loss = self.criterion(pred_depth, gt_depth, factor="depth")
+ depth_loss = depth_loss * self.depth_loss_weight
+ depth_losses.append(depth_loss)
+
+ # Compute camera frame point loss
+ cam_pts3d_loss = self.criterion(
+ pred_cam_pts3d, gt_cam_pts3d, factor="points"
+ )
+ cam_pts3d_loss = cam_pts3d_loss * self.cam_frame_points_loss_weight
+ cam_pts3d_losses.append(cam_pts3d_loss)
+
+ if self.compute_world_frame_points_loss:
+ # Compute point loss
+ pts3d_loss = self.criterion(pred_pts3d, gt_pts3d, factor="points")
+ pts3d_loss = pts3d_loss * self.world_frame_points_loss_weight
+ pts3d_losses.append(pts3d_loss)
+
+ # Handle ambiguous pixels
+ if self.ambiguous_loss_value > 0:
+ if not self.flatten_across_image_only:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i][valid_masks[i]],
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+ else:
+ depth_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ depth_losses[i],
+ )
+ cam_pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ cam_pts3d_losses[i],
+ )
+ if self.compute_world_frame_points_loss:
+ pts3d_losses[i] = torch.where(
+ ambiguous_masks[i].view(ambiguous_masks[i].shape[0], -1),
+ self.ambiguous_loss_value,
+ pts3d_losses[i],
+ )
+
+ # Compute the scale loss
+ if gt_metric_norm_factor is not None:
+ if self.loss_in_log:
+ gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
+ pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
+ scale_loss = (
+ self.criterion(
+ pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
+ )
+ * self.scale_loss_weight
+ )
+ else:
+ scale_loss = None
+
+ # Use helper function to generate loss terms and details
+ if self.compute_world_frame_points_loss:
+ losses_dict = {
+ "pts3d": {
+ "values": pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ }
+ else:
+ losses_dict = {}
+ losses_dict.update(
+ {
+ "cam_pts3d": {
+ "values": cam_pts3d_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ self.depth_type_for_loss: {
+ "values": depth_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "ray_directions": {
+ "values": ray_directions_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_quats": {
+ "values": pose_quats_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "pose_trans": {
+ "values": pose_trans_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "scale": {
+ "values": scale_loss,
+ "use_mask": False,
+ "is_multi_view": False,
+ },
+ "normal": {
+ "values": normal_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "gradient_matching": {
+ "values": gradient_matching_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ }
+ )
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
+
+
+class DisentangledFactoredGeometryScaleRegr3D(Criterion, MultiLoss):
+ """
+ Disentangled Regression Loss for Factored Geometry & Scale.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_predictions=True,
+ norm_mode="avg_dis",
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ depth_type_for_loss="depth_along_ray",
+ depth_loss_weight=1,
+ ray_directions_loss_weight=1,
+ pose_quats_loss_weight=1,
+ pose_trans_loss_weight=1,
+ scale_loss_weight=1,
+ ):
+ """
+ Initialize the disentangled loss criterion for Factored Geometry (Ray Directions, Depth, Pose) & Scale.
+ It isolates/disentangles the contribution of each factor to the final task of 3D reconstruction.
+ All the losses are in the same space where the loss for each factor is computed by constructing world-frame pointmaps.
+ This sidesteps the difficulty of finding a proper weighting.
+ For insance, for predicted rays, the GT depth & pose is used to construct the predicted world-frame pointmaps on which the loss is computed.
+ Inspired by https://openaccess.thecvf.com/content_ICCV_2019/papers/Simonelli_Disentangling_Monocular_3D_Object_Detection_ICCV_2019_paper.pdf
+
+ The pixel-level losses are computed in the following order:
+ (1) depth, (2) ray directions, (3) pose quats, (4) pose trans, (5) scale.
+ The predicited scene representation is always normalized w.r.t. the frame of view0.
+ Loss is applied between the predicted metric scale and the ground truth metric scale.
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_predictions (bool): If True, normalize the predictions before computing the loss.
+ norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth, pointmaps and scale. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+ Options: "depth_along_ray", "depth_z"
+ depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+ ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+ pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+ pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+ scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
+ """
+ super().__init__(criterion)
+ self.norm_predictions = norm_predictions
+ self.norm_mode = norm_mode
+ self.loss_in_log = loss_in_log
+ self.flatten_across_image_only = flatten_across_image_only
+ self.depth_type_for_loss = depth_type_for_loss
+ assert self.depth_type_for_loss in ["depth_along_ray", "depth_z"], (
+ "depth_type_for_loss must be one of ['depth_along_ray', 'depth_z']"
+ )
+ self.depth_loss_weight = depth_loss_weight
+ self.ray_directions_loss_weight = ray_directions_loss_weight
+ self.pose_quats_loss_weight = pose_quats_loss_weight
+ self.pose_trans_loss_weight = pose_trans_loss_weight
+ self.scale_loss_weight = scale_loss_weight
+
+ def get_all_info(self, batch, preds, dist_clip=None):
+ """
+ Function to get all the information needed to compute the loss.
+ Returns all quantities normalized w.r.t. camera of view0.
+ """
+ n_views = len(batch)
+
+ # Everything is normalized w.r.t. camera of view0
+ # Intialize lists to store data for all views
+ # Ground truth quantities
+ in_camera0 = closed_form_pose_inverse(batch[0]["camera_pose"])
+ no_norm_gt_pts = []
+ no_norm_gt_pts_cam = []
+ no_norm_gt_depth = []
+ no_norm_gt_pose_trans = []
+ valid_masks = []
+ gt_ray_directions = []
+ gt_pose_quats = []
+ # Predicted quantities
+ no_norm_pr_pts = []
+ no_norm_pr_pts_cam = []
+ no_norm_pr_depth = []
+ no_norm_pr_pose_trans = []
+ pr_ray_directions = []
+ pr_pose_quats = []
+ metric_pr_pts_to_compute_scale = []
+
+ # Get ground truth & prediction info for all views
+ for i in range(n_views):
+ # Get the ground truth
+ no_norm_gt_pts.append(geotrf(in_camera0, batch[i]["pts3d"]))
+ valid_masks.append(batch[i]["valid_mask"].clone())
+ no_norm_gt_pts_cam.append(batch[i]["pts3d_cam"])
+ gt_ray_directions.append(batch[i]["ray_directions_cam"])
+ if self.depth_type_for_loss == "depth_along_ray":
+ no_norm_gt_depth.append(batch[i]["depth_along_ray"])
+ elif self.depth_type_for_loss == "depth_z":
+ no_norm_gt_depth.append(batch[i]["pts3d_cam"][..., 2:])
+ if i == 0:
+ # For view0, initialize identity pose
+ gt_pose_quats.append(
+ torch.tensor(
+ [0, 0, 0, 1],
+ dtype=gt_ray_directions[0].dtype,
+ device=gt_ray_directions[0].device,
+ )
+ .unsqueeze(0)
+ .repeat(gt_ray_directions[0].shape[0], 1)
+ )
+ no_norm_gt_pose_trans.append(
+ torch.tensor(
+ [0, 0, 0],
+ dtype=gt_ray_directions[0].dtype,
+ device=gt_ray_directions[0].device,
+ )
+ .unsqueeze(0)
+ .repeat(gt_ray_directions[0].shape[0], 1)
+ )
+ else:
+ # For other views, transform pose to view0's frame
+ gt_pose_quats_world = batch[i]["camera_pose_quats"]
+ no_norm_gt_pose_trans_world = batch[i]["camera_pose_trans"]
+ gt_pose_quats_in_view0, no_norm_gt_pose_trans_in_view0 = (
+ transform_pose_using_quats_and_trans_2_to_1(
+ batch[0]["camera_pose_quats"],
+ batch[0]["camera_pose_trans"],
+ gt_pose_quats_world,
+ no_norm_gt_pose_trans_world,
+ )
+ )
+ gt_pose_quats.append(gt_pose_quats_in_view0)
+ no_norm_gt_pose_trans.append(no_norm_gt_pose_trans_in_view0)
+
+ # Get predictions for normalized loss
+ if self.depth_type_for_loss == "depth_along_ray":
+ curr_view_no_norm_depth = preds[i]["depth_along_ray"]
+ elif self.depth_type_for_loss == "depth_z":
+ curr_view_no_norm_depth = preds[i]["pts3d_cam"][..., 2:]
+ if "metric_scaling_factor" in preds[i].keys():
+ # Divide by the predicted metric scaling factor to get the raw predicted points, depth_along_ray, and pose_trans
+ # This detaches the predicted metric scaling factor from the geometry based loss
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"] / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"] / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ curr_view_no_norm_depth = curr_view_no_norm_depth / preds[i][
+ "metric_scaling_factor"
+ ].unsqueeze(-1).unsqueeze(-1)
+ curr_view_no_norm_pr_pose_trans = (
+ preds[i]["cam_trans"] / preds[i]["metric_scaling_factor"]
+ )
+ else:
+ curr_view_no_norm_pr_pts = preds[i]["pts3d"]
+ curr_view_no_norm_pr_pts_cam = preds[i]["pts3d_cam"]
+ curr_view_no_norm_depth = curr_view_no_norm_depth
+ curr_view_no_norm_pr_pose_trans = preds[i]["cam_trans"]
+ no_norm_pr_pts.append(curr_view_no_norm_pr_pts)
+ no_norm_pr_pts_cam.append(curr_view_no_norm_pr_pts_cam)
+ no_norm_pr_depth.append(curr_view_no_norm_depth)
+ no_norm_pr_pose_trans.append(curr_view_no_norm_pr_pose_trans)
+ pr_ray_directions.append(preds[i]["ray_directions"])
+ pr_pose_quats.append(preds[i]["cam_quats"])
+
+ # Get the predicted metric scale points
+ if "metric_scaling_factor" in preds[i].keys():
+ # Detach the raw predicted points so that the scale loss is only applied to the scaling factor
+ curr_view_metric_pr_pts_to_compute_scale = (
+ curr_view_no_norm_pr_pts.detach()
+ * preds[i]["metric_scaling_factor"].unsqueeze(-1).unsqueeze(-1)
+ )
+ else:
+ curr_view_metric_pr_pts_to_compute_scale = (
+ curr_view_no_norm_pr_pts.clone()
+ )
+ metric_pr_pts_to_compute_scale.append(
+ curr_view_metric_pr_pts_to_compute_scale
+ )
+
+ if dist_clip is not None:
+ # Points that are too far-away == invalid
+ for i in range(n_views):
+ dis = no_norm_gt_pts[i].norm(dim=-1)
+ valid_masks[i] = valid_masks[i] & (dis <= dist_clip)
+
+ # Initialize normalized tensors
+ gt_pts = [torch.zeros_like(pts) for pts in no_norm_gt_pts]
+ gt_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_gt_pts_cam]
+ gt_depth = [torch.zeros_like(depth) for depth in no_norm_gt_depth]
+ gt_pose_trans = [torch.zeros_like(trans) for trans in no_norm_gt_pose_trans]
+
+ pr_pts = [torch.zeros_like(pts) for pts in no_norm_pr_pts]
+ pr_pts_cam = [torch.zeros_like(pts_cam) for pts_cam in no_norm_pr_pts_cam]
+ pr_depth = [torch.zeros_like(depth) for depth in no_norm_pr_depth]
+ pr_pose_trans = [torch.zeros_like(trans) for trans in no_norm_pr_pose_trans]
+
+ # Normalize the predicted points if specified
+ if self.norm_predictions:
+ pr_normalization_output = normalize_multiple_pointclouds(
+ no_norm_pr_pts,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_pts_norm = pr_normalization_output[:-1]
+ pr_norm_factor = pr_normalization_output[-1]
+
+ # Normalize the ground truth points
+ gt_normalization_output = normalize_multiple_pointclouds(
+ no_norm_gt_pts, valid_masks, self.norm_mode, ret_factor=True
+ )
+ gt_pts_norm = gt_normalization_output[:-1]
+ gt_norm_factor = gt_normalization_output[-1]
+
+ for i in range(n_views):
+ if self.norm_predictions:
+ # Assign the normalized predictions
+ pr_pts[i] = pr_pts_norm[i]
+ pr_pts_cam[i] = no_norm_pr_pts_cam[i] / pr_norm_factor
+ pr_depth[i] = no_norm_pr_depth[i] / pr_norm_factor
+ pr_pose_trans[i] = no_norm_pr_pose_trans[i] / pr_norm_factor[:, :, 0, 0]
+ else:
+ pr_pts[i] = no_norm_pr_pts[i]
+ pr_pts_cam[i] = no_norm_pr_pts_cam[i]
+ pr_depth[i] = no_norm_pr_depth[i]
+ pr_pose_trans[i] = no_norm_pr_pose_trans[i]
+ # Assign the normalized ground truth quantities
+ gt_pts[i] = gt_pts_norm[i]
+ gt_pts_cam[i] = no_norm_gt_pts_cam[i] / gt_norm_factor
+ gt_depth[i] = no_norm_gt_depth[i] / gt_norm_factor
+ gt_pose_trans[i] = no_norm_gt_pose_trans[i] / gt_norm_factor[:, :, 0, 0]
+
+ # Get the mask indicating ground truth metric scale quantities
+ metric_scale_mask = batch[0]["is_metric_scale"]
+ valid_gt_norm_factor_mask = (
+ gt_norm_factor[:, 0, 0, 0] > 1e-8
+ ) # Mask out cases where depth for all views is invalid
+ valid_metric_scale_mask = metric_scale_mask & valid_gt_norm_factor_mask
+
+ if valid_metric_scale_mask.any():
+ # Compute the scale norm factor using the predicted metric scale points
+ metric_pr_normalization_output = normalize_multiple_pointclouds(
+ metric_pr_pts_to_compute_scale,
+ valid_masks,
+ self.norm_mode,
+ ret_factor=True,
+ )
+ pr_metric_norm_factor = metric_pr_normalization_output[-1]
+
+ # Get the valid ground truth and predicted scale norm factors for the metric ground truth quantities
+ gt_metric_norm_factor = gt_norm_factor[valid_metric_scale_mask]
+ pr_metric_norm_factor = pr_metric_norm_factor[valid_metric_scale_mask]
+ else:
+ gt_metric_norm_factor = None
+ pr_metric_norm_factor = None
+
+ # Get ambiguous masks
+ ambiguous_masks = []
+ for i in range(n_views):
+ ambiguous_masks.append(
+ (~batch[i]["non_ambiguous_mask"]) & (~valid_masks[i])
+ )
+
+ # Pack into info dicts
+ gt_info = []
+ pred_info = []
+ for i in range(n_views):
+ gt_info.append(
+ {
+ "ray_directions": gt_ray_directions[i],
+ self.depth_type_for_loss: gt_depth[i],
+ "pose_trans": gt_pose_trans[i],
+ "pose_quats": gt_pose_quats[i],
+ "pts3d": gt_pts[i],
+ "pts3d_cam": gt_pts_cam[i],
+ }
+ )
+ pred_info.append(
+ {
+ "ray_directions": pr_ray_directions[i],
+ self.depth_type_for_loss: pr_depth[i],
+ "pose_trans": pr_pose_trans[i],
+ "pose_quats": pr_pose_quats[i],
+ "pts3d": pr_pts[i],
+ "pts3d_cam": pr_pts_cam[i],
+ }
+ )
+
+ return (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ )
+
+ def compute_loss(self, batch, preds, **kw):
+ (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ ) = self.get_all_info(batch, preds, **kw)
+ n_views = len(batch)
+
+ pose_trans_losses = []
+ pose_quats_losses = []
+ ray_directions_losses = []
+ depth_losses = []
+
+ for i in range(n_views):
+ # Get the GT factored quantities for the current view
+ gt_pts3d = gt_info[i]["pts3d"]
+ gt_ray_directions = gt_info[i]["ray_directions"]
+ gt_depth = gt_info[i][self.depth_type_for_loss]
+ gt_pose_trans = gt_info[i]["pose_trans"]
+ gt_pose_quats = gt_info[i]["pose_quats"]
+
+ # Get the predicted factored quantities for the current view
+ pred_ray_directions = pred_info[i]["ray_directions"]
+ pred_depth = pred_info[i][self.depth_type_for_loss]
+ pred_pose_trans = pred_info[i]["pose_trans"]
+ pred_pose_quats = pred_info[i]["pose_quats"]
+
+ # Get the predicted world-frame pointmaps using the different factors
+ if self.depth_type_for_loss == "depth_along_ray":
+ pred_ray_directions_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ pred_ray_directions,
+ gt_depth,
+ gt_pose_trans,
+ gt_pose_quats,
+ )
+ )
+ pred_depth_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ gt_ray_directions,
+ pred_depth,
+ gt_pose_trans,
+ gt_pose_quats,
+ )
+ )
+ pred_pose_trans_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ gt_ray_directions,
+ gt_depth,
+ pred_pose_trans,
+ gt_pose_quats,
+ )
+ )
+ pred_pose_quats_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ gt_ray_directions,
+ gt_depth,
+ gt_pose_trans,
+ pred_pose_quats,
+ )
+ )
+ else:
+ raise NotImplementedError
+
+ # Mask out the valid quantities as required
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks
+ pred_ray_directions_pts3d = pred_ray_directions_pts3d[valid_masks[i]]
+ pred_depth_pts3d = pred_depth_pts3d[valid_masks[i]]
+ pred_pose_trans_pts3d = pred_pose_trans_pts3d[valid_masks[i]]
+ pred_pose_quats_pts3d = pred_pose_quats_pts3d[valid_masks[i]]
+ gt_pts3d = gt_pts3d[valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W
+ batch_size, _, _, pts_dim = gt_pts3d.shape
+ pred_ray_directions_pts3d = pred_ray_directions_pts3d.view(
+ batch_size, -1, pts_dim
+ )
+ pred_depth_pts3d = pred_depth_pts3d.view(batch_size, -1, pts_dim)
+ pred_pose_trans_pts3d = pred_pose_trans_pts3d.view(
+ batch_size, -1, pts_dim
+ )
+ pred_pose_quats_pts3d = pred_pose_quats_pts3d.view(
+ batch_size, -1, pts_dim
+ )
+ gt_pts3d = gt_pts3d.view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space if specified
+ if self.loss_in_log:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_ray_directions_pts3d = apply_log_to_norm(pred_ray_directions_pts3d)
+ pred_depth_pts3d = apply_log_to_norm(pred_depth_pts3d)
+ pred_pose_trans_pts3d = apply_log_to_norm(pred_pose_trans_pts3d)
+ pred_pose_quats_pts3d = apply_log_to_norm(pred_pose_quats_pts3d)
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_pose_trans_pts3d, gt_pts3d, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ pose_quats_loss = self.criterion(
+ pred_pose_quats_pts3d, gt_pts3d, factor="pose_quats"
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+
+ # Compute ray direction loss
+ ray_directions_loss = self.criterion(
+ pred_ray_directions_pts3d, gt_pts3d, factor="ray_directions"
+ )
+ ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
+ ray_directions_losses.append(ray_directions_loss)
+
+ # Compute depth loss
+ depth_loss = self.criterion(pred_depth_pts3d, gt_pts3d, factor="depth")
+ depth_loss = depth_loss * self.depth_loss_weight
+ depth_losses.append(depth_loss)
+
+ # Compute the scale loss
+ if gt_metric_norm_factor is not None:
+ if self.loss_in_log:
+ gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
+ pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
+ scale_loss = (
+ self.criterion(
+ pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
+ )
+ * self.scale_loss_weight
+ )
+ else:
+ scale_loss = None
+
+ # Use helper function to generate loss terms and details
+ losses_dict = {}
+ losses_dict.update(
+ {
+ self.depth_type_for_loss: {
+ "values": depth_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "ray_directions": {
+ "values": ray_directions_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "pose_quats": {
+ "values": pose_quats_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "pose_trans": {
+ "values": pose_trans_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "scale": {
+ "values": scale_loss,
+ "use_mask": False,
+ "is_multi_view": False,
+ },
+ }
+ )
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
+
+
+class DisentangledFactoredGeometryScaleRegr3DPlusNormalGMLoss(
+ DisentangledFactoredGeometryScaleRegr3D
+):
+ """
+ Disentangled Regression, Normals & Gradient Matching Loss for Factored Geometry & Scale.
+ """
+
+ def __init__(
+ self,
+ criterion,
+ norm_predictions=True,
+ norm_mode="avg_dis",
+ loss_in_log=True,
+ flatten_across_image_only=False,
+ depth_type_for_loss="depth_along_ray",
+ depth_loss_weight=1,
+ ray_directions_loss_weight=1,
+ pose_quats_loss_weight=1,
+ pose_trans_loss_weight=1,
+ scale_loss_weight=1,
+ apply_normal_and_gm_loss_to_synthetic_data_only=True,
+ normal_loss_weight=1,
+ gm_loss_weight=1,
+ ):
+ """
+ Initialize the disentangled loss criterion for Factored Geometry (Ray Directions, Depth, Pose) & Scale.
+ See parent class (DisentangledFactoredGeometryScaleRegr3D) for more details.
+ Additionally computes:
+ (1) Normal Loss over the Camera Frame Pointmaps in euclidean coordinates,
+ (2) Gradient Matching (GM) Loss over the Depth Z in log space. (MiDAS applied GM loss in disparity space)
+
+ Args:
+ criterion (BaseCriterion): The base criterion to use for computing the loss.
+ norm_predictions (bool): If True, normalize the predictions before computing the loss.
+ norm_mode (str): Normalization mode for the gt and predicted (optional) scene representation. Default: "avg_dis".
+ loss_in_log (bool): If True, apply logarithmic transformation to input before
+ computing the loss for depth, pointmaps and scale. Default: True.
+ flatten_across_image_only (bool): If True, flatten H x W dimensions only when computing
+ the loss. If False, flatten across batch and spatial dimensions. Default: False.
+ depth_type_for_loss (str): Type of depth to use for loss computation. Default: "depth_along_ray".
+ Options: "depth_along_ray", "depth_z"
+ depth_loss_weight (float): Weight to use for the depth loss. Default: 1.
+ ray_directions_loss_weight (float): Weight to use for the ray directions loss. Default: 1.
+ pose_quats_loss_weight (float): Weight to use for the pose quats loss. Default: 1.
+ pose_trans_loss_weight (float): Weight to use for the pose trans loss. Default: 1.
+ scale_loss_weight (float): Weight to use for the scale loss. Default: 1.
+ apply_normal_and_gm_loss_to_synthetic_data_only (bool): If True, apply the normal and gm loss only to synthetic data.
+ If False, apply the normal and gm loss to all data. Default: True.
+ normal_loss_weight (float): Weight to use for the normal loss. Default: 1.
+ gm_loss_weight (float): Weight to use for the gm loss. Default: 1.
+ """
+ super().__init__(
+ criterion=criterion,
+ norm_predictions=norm_predictions,
+ norm_mode=norm_mode,
+ loss_in_log=loss_in_log,
+ flatten_across_image_only=flatten_across_image_only,
+ depth_type_for_loss=depth_type_for_loss,
+ depth_loss_weight=depth_loss_weight,
+ ray_directions_loss_weight=ray_directions_loss_weight,
+ pose_quats_loss_weight=pose_quats_loss_weight,
+ pose_trans_loss_weight=pose_trans_loss_weight,
+ scale_loss_weight=scale_loss_weight,
+ )
+ self.apply_normal_and_gm_loss_to_synthetic_data_only = (
+ apply_normal_and_gm_loss_to_synthetic_data_only
+ )
+ self.normal_loss_weight = normal_loss_weight
+ self.gm_loss_weight = gm_loss_weight
+
+ def compute_loss(self, batch, preds, **kw):
+ (
+ gt_info,
+ pred_info,
+ valid_masks,
+ ambiguous_masks,
+ gt_metric_norm_factor,
+ pr_metric_norm_factor,
+ ) = self.get_all_info(batch, preds, **kw)
+ n_views = len(batch)
+
+ normal_losses = []
+ gradient_matching_losses = []
+ pose_trans_losses = []
+ pose_quats_losses = []
+ ray_directions_losses = []
+ depth_losses = []
+
+ for i in range(n_views):
+ # Get the camera frame points, log space depth_z & valid masks
+ pred_local_pts3d = pred_info[i]["pts3d_cam"]
+ pred_depth_z = pred_local_pts3d[..., 2:]
+ pred_depth_z = apply_log_to_norm(pred_depth_z)
+ gt_local_pts3d = gt_info[i]["pts3d_cam"]
+ gt_depth_z = gt_local_pts3d[..., 2:]
+ gt_depth_z = apply_log_to_norm(gt_depth_z)
+ valid_mask_for_normal_gm_loss = valid_masks[i].clone()
+
+ # Update the validity mask for normal & gm loss based on the synthetic data mask if required
+ if self.apply_normal_and_gm_loss_to_synthetic_data_only:
+ synthetic_mask = batch[i]["is_synthetic"] # (B, )
+ synthetic_mask = synthetic_mask.unsqueeze(-1).unsqueeze(-1) # (B, 1, 1)
+ synthetic_mask = synthetic_mask.expand(
+ -1, pred_depth_z.shape[1], pred_depth_z.shape[2]
+ ) # (B, H, W)
+ valid_mask_for_normal_gm_loss = (
+ valid_mask_for_normal_gm_loss & synthetic_mask
+ )
+
+ # Compute the normal loss
+ normal_loss = compute_normal_loss(
+ pred_local_pts3d, gt_local_pts3d, valid_mask_for_normal_gm_loss.clone()
+ )
+ normal_loss = normal_loss * self.normal_loss_weight
+ normal_losses.append(normal_loss)
+
+ # Compute the gradient matching loss
+ gradient_matching_loss = compute_gradient_matching_loss(
+ pred_depth_z, gt_depth_z, valid_mask_for_normal_gm_loss.clone()
+ )
+ gradient_matching_loss = gradient_matching_loss * self.gm_loss_weight
+ gradient_matching_losses.append(gradient_matching_loss)
+
+ # Get the GT factored quantities for the current view
+ gt_pts3d = gt_info[i]["pts3d"]
+ gt_ray_directions = gt_info[i]["ray_directions"]
+ gt_depth = gt_info[i][self.depth_type_for_loss]
+ gt_pose_trans = gt_info[i]["pose_trans"]
+ gt_pose_quats = gt_info[i]["pose_quats"]
+
+ # Get the predicted factored quantities for the current view
+ pred_ray_directions = pred_info[i]["ray_directions"]
+ pred_depth = pred_info[i][self.depth_type_for_loss]
+ pred_pose_trans = pred_info[i]["pose_trans"]
+ pred_pose_quats = pred_info[i]["pose_quats"]
+
+ # Get the predicted world-frame pointmaps using the different factors
+ if self.depth_type_for_loss == "depth_along_ray":
+ pred_ray_directions_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ pred_ray_directions,
+ gt_depth,
+ gt_pose_trans,
+ gt_pose_quats,
+ )
+ )
+ pred_depth_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ gt_ray_directions,
+ pred_depth,
+ gt_pose_trans,
+ gt_pose_quats,
+ )
+ )
+ pred_pose_trans_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ gt_ray_directions,
+ gt_depth,
+ pred_pose_trans,
+ gt_pose_quats,
+ )
+ )
+ pred_pose_quats_pts3d = (
+ convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ gt_ray_directions,
+ gt_depth,
+ gt_pose_trans,
+ pred_pose_quats,
+ )
+ )
+ else:
+ raise NotImplementedError
+
+ # Mask out the valid quantities as required
+ if not self.flatten_across_image_only:
+ # Flatten the points across the entire batch with the masks
+ pred_ray_directions_pts3d = pred_ray_directions_pts3d[valid_masks[i]]
+ pred_depth_pts3d = pred_depth_pts3d[valid_masks[i]]
+ pred_pose_trans_pts3d = pred_pose_trans_pts3d[valid_masks[i]]
+ pred_pose_quats_pts3d = pred_pose_quats_pts3d[valid_masks[i]]
+ gt_pts3d = gt_pts3d[valid_masks[i]]
+ else:
+ # Flatten the H x W dimensions to H*W
+ batch_size, _, _, pts_dim = gt_pts3d.shape
+ pred_ray_directions_pts3d = pred_ray_directions_pts3d.view(
+ batch_size, -1, pts_dim
+ )
+ pred_depth_pts3d = pred_depth_pts3d.view(batch_size, -1, pts_dim)
+ pred_pose_trans_pts3d = pred_pose_trans_pts3d.view(
+ batch_size, -1, pts_dim
+ )
+ pred_pose_quats_pts3d = pred_pose_quats_pts3d.view(
+ batch_size, -1, pts_dim
+ )
+ gt_pts3d = gt_pts3d.view(batch_size, -1, pts_dim)
+ valid_masks[i] = valid_masks[i].view(batch_size, -1)
+
+ # Apply loss in log space if specified
+ if self.loss_in_log:
+ gt_pts3d = apply_log_to_norm(gt_pts3d)
+ pred_ray_directions_pts3d = apply_log_to_norm(pred_ray_directions_pts3d)
+ pred_depth_pts3d = apply_log_to_norm(pred_depth_pts3d)
+ pred_pose_trans_pts3d = apply_log_to_norm(pred_pose_trans_pts3d)
+ pred_pose_quats_pts3d = apply_log_to_norm(pred_pose_quats_pts3d)
+
+ # Compute pose translation loss
+ pose_trans_loss = self.criterion(
+ pred_pose_trans_pts3d, gt_pts3d, factor="pose_trans"
+ )
+ pose_trans_loss = pose_trans_loss * self.pose_trans_loss_weight
+ pose_trans_losses.append(pose_trans_loss)
+
+ # Compute pose rotation loss
+ pose_quats_loss = self.criterion(
+ pred_pose_quats_pts3d, gt_pts3d, factor="pose_quats"
+ )
+ pose_quats_loss = pose_quats_loss * self.pose_quats_loss_weight
+ pose_quats_losses.append(pose_quats_loss)
+
+ # Compute ray direction loss
+ ray_directions_loss = self.criterion(
+ pred_ray_directions_pts3d, gt_pts3d, factor="ray_directions"
+ )
+ ray_directions_loss = ray_directions_loss * self.ray_directions_loss_weight
+ ray_directions_losses.append(ray_directions_loss)
+
+ # Compute depth loss
+ depth_loss = self.criterion(pred_depth_pts3d, gt_pts3d, factor="depth")
+ depth_loss = depth_loss * self.depth_loss_weight
+ depth_losses.append(depth_loss)
+
+ # Compute the scale loss
+ if gt_metric_norm_factor is not None:
+ if self.loss_in_log:
+ gt_metric_norm_factor = apply_log_to_norm(gt_metric_norm_factor)
+ pr_metric_norm_factor = apply_log_to_norm(pr_metric_norm_factor)
+ scale_loss = (
+ self.criterion(
+ pr_metric_norm_factor, gt_metric_norm_factor, factor="scale"
+ )
+ * self.scale_loss_weight
+ )
+ else:
+ scale_loss = None
+
+ # Use helper function to generate loss terms and details
+ losses_dict = {}
+ losses_dict.update(
+ {
+ self.depth_type_for_loss: {
+ "values": depth_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "ray_directions": {
+ "values": ray_directions_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "pose_quats": {
+ "values": pose_quats_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "pose_trans": {
+ "values": pose_trans_losses,
+ "use_mask": True,
+ "is_multi_view": True,
+ },
+ "scale": {
+ "values": scale_loss,
+ "use_mask": False,
+ "is_multi_view": False,
+ },
+ "normal": {
+ "values": normal_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ "gradient_matching": {
+ "values": gradient_matching_losses,
+ "use_mask": False,
+ "is_multi_view": True,
+ },
+ }
+ )
+ loss_terms, details = get_loss_terms_and_details(
+ losses_dict,
+ valid_masks,
+ type(self).__name__,
+ n_views,
+ self.flatten_across_image_only,
+ )
+ losses = Sum(*loss_terms)
+
+ return losses, (details | {})
diff --git a/mapanything/train/profile_dataloading.py b/mapanything/train/profile_dataloading.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fb9b7615be4cbb56e07d2d182cbbd16d103070e
--- /dev/null
+++ b/mapanything/train/profile_dataloading.py
@@ -0,0 +1,285 @@
+"""
+Debug script to profile dataloading for MapAnything training.
+
+This script measures and analyzes the performance of data loading operations
+for MapAnything training workflows. It simulates the training process without
+actual model training to isolate and profile the data loading components.
+"""
+
+import datetime
+import json
+import os
+import time
+from pathlib import Path
+from typing import Sized
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+from torch.utils.tensorboard import SummaryWriter
+
+import mapanything.utils.train_tools as train_tools
+from mapanything.datasets import get_test_data_loader, get_train_data_loader
+from mapanything.datasets.base.base_dataset import view_name
+
+# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12)
+if hasattr(torch.backends.cuda, "matmul") and hasattr(
+ torch.backends.cuda.matmul, "allow_tf32"
+):
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+
+def profile_dataloading(args):
+ """
+ Main profiling function that simulates the training process to measure data loading performance.
+
+ This function initializes the distributed environment, sets up datasets and data loaders,
+ and runs through training epochs to profile the data loading operations. It measures
+ the time taken for data loading without performing actual model training or optimization.
+
+ In this simulation, an epoch represents a complete pass through a chunk of the dataset.
+
+ Args:
+ args: Configuration object containing all parameters including:
+ - dataset: Dataset configuration (train_dataset, test_dataset, num_workers)
+ - train_params: Training parameters (batch_size, epochs, seed, etc.)
+ - distributed: Distributed training configuration
+ - output_dir: Directory for saving logs and profiling results
+ """
+ # Initialize distributed training if required
+ train_tools.init_distributed_mode(args.distributed)
+ global_rank = train_tools.get_rank()
+ world_size = train_tools.get_world_size() # noqa
+
+ # Init output directory and device
+ print("output_dir: " + args.output_dir)
+ if args.output_dir:
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
+ print("{}".format(args).replace(", ", ",\n"))
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Fix the seed
+ seed = args.train_params.seed + train_tools.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ cudnn.benchmark = not args.train_params.disable_cudnn_benchmark
+
+ # Datasets and Dataloaders
+ print("Building train dataset {:s}".format(args.dataset.train_dataset))
+ data_loader_train = build_dataset(
+ dataset=args.dataset.train_dataset,
+ num_workers=args.dataset.num_workers,
+ test=False,
+ max_num_of_imgs_per_gpu=args.train_params.max_num_of_imgs_per_gpu,
+ )
+ print("Building test dataset {:s}".format(args.dataset.test_dataset))
+ test_batch_size = 2 * (
+ args.train_params.max_num_of_imgs_per_gpu // args.dataset.num_views
+ ) # Since we don't have any backward overhead
+ data_loader_test = {
+ dataset.split("(")[0]: build_dataset(
+ dataset=dataset,
+ num_workers=args.dataset.num_workers,
+ test=True,
+ batch_size=test_batch_size,
+ )
+ for dataset in args.dataset.test_dataset.split("+")
+ if "(" in dataset
+ }
+
+ def write_log_stats(epoch, train_stats, test_stats):
+ """
+ Writes profiling statistics to log files and TensorBoard.
+
+ This function collects metrics from the training and testing phases and writes them
+ to log files and TensorBoard for visualization and analysis. It only executes on the
+ main process in a distributed setting.
+
+ Args:
+ epoch: int, current epoch number
+ train_stats: dict, containing training metrics and timing information
+ test_stats: dict, containing testing metrics for each test dataset
+ """
+ if train_tools.is_main_process():
+ if log_writer is not None:
+ log_writer.flush()
+
+ log_stats = dict(
+ epoch=epoch, **{f"train_{k}": v for k, v in train_stats.items()}
+ )
+ for test_name in data_loader_test:
+ if test_name not in test_stats:
+ continue
+ log_stats.update(
+ {test_name + "_" + k: v for k, v in test_stats[test_name].items()}
+ )
+
+ with open(
+ os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
+ ) as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ if global_rank == 0 and args.output_dir is not None:
+ log_writer = SummaryWriter(log_dir=args.output_dir)
+ else:
+ log_writer = None
+
+ print(f"Start training for {args.train_params.epochs} epochs")
+ start_time = time.time()
+ train_stats = test_stats = {}
+ args.train_params.start_epoch = 0
+ for epoch in range(args.train_params.start_epoch, args.train_params.epochs + 1):
+ # Save more stuff
+ write_log_stats(epoch, train_stats, test_stats)
+
+ if epoch >= args.train_params.epochs:
+ break # exit after writing last test to disk
+
+ # Train
+ train_stats = train_one_epoch(
+ data_loader_train,
+ device,
+ epoch,
+ log_writer=log_writer,
+ args=args,
+ )
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print("Training time {}".format(total_time_str))
+
+
+def build_dataset(
+ dataset, num_workers, test, batch_size=None, max_num_of_imgs_per_gpu=None
+):
+ """
+ Builds data loaders for training or testing.
+
+ Args:
+ dataset: Dataset specification string.
+ num_workers: Number of worker processes for data loading.
+ test: Boolean flag indicating whether this is a test dataset.
+ batch_size: Number of samples per batch. Defaults to None. Used only for testing.
+ max_num_of_imgs_per_gpu: Maximum number of images per GPU. Defaults to None. Used only for training.
+
+ Returns:
+ DataLoader: PyTorch DataLoader configured for the specified dataset.
+ """
+ split = ["Train", "Test"][test]
+ print(f"Building {split} Data loader for dataset: ", dataset)
+ if test:
+ assert batch_size is not None, (
+ "batch_size must be specified for testing dataloader"
+ )
+ loader = get_test_data_loader(
+ dataset=dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_mem=True,
+ shuffle=False,
+ drop_last=False,
+ )
+ else:
+ assert max_num_of_imgs_per_gpu is not None, (
+ "max_num_of_imgs_per_gpu must be specified for training dataloader"
+ )
+ loader = get_train_data_loader(
+ dataset=dataset,
+ max_num_of_imgs_per_gpu=max_num_of_imgs_per_gpu,
+ num_workers=num_workers,
+ pin_mem=True,
+ shuffle=True,
+ drop_last=True,
+ )
+
+ print(f"{split} dataset length: ", len(loader))
+ return loader
+
+
+def train_one_epoch(
+ data_loader: Sized,
+ device: torch.device,
+ epoch: int,
+ args,
+ log_writer=None,
+):
+ """
+ Simulates training for one epoch to profile data loading performance.
+
+ This function runs through a single epoch, simulating the data loading and device transfer
+ operations that would occur during actual training. It measures and logs the time taken
+ for these operations without performing actual model training.
+
+ Args:
+ data_loader: Sized, DataLoader providing the training data
+ device: torch.device, device to transfer data to (CPU or GPU)
+ epoch: int, current epoch number
+ args: object, configuration object containing training parameters including:
+ - train_params.print_freq: frequency of logging during the epoch
+ log_writer: Optional[SummaryWriter], TensorBoard SummaryWriter for logging metrics
+
+ Returns:
+ dict: Dictionary containing profiling metrics averaged over the epoch
+ """
+ metric_logger = train_tools.MetricLogger(delimiter=" ")
+ header = "Epoch: [{}]".format(epoch)
+
+ if log_writer is not None:
+ print("log_dir: {}".format(log_writer.log_dir))
+
+ if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"):
+ data_loader.dataset.set_epoch(epoch)
+ if hasattr(data_loader, "sampler") and hasattr(data_loader.sampler, "set_epoch"):
+ data_loader.sampler.set_epoch(epoch)
+ if hasattr(data_loader, "batch_sampler") and hasattr(
+ data_loader.batch_sampler, "set_epoch"
+ ):
+ data_loader.batch_sampler.set_epoch(epoch)
+
+ for data_iter_step, batch in enumerate(
+ metric_logger.log_every(data_loader, args.train_params.print_freq, header)
+ ):
+ epoch_f = epoch + data_iter_step / len(data_loader)
+
+ # Simulate the device loading in loss_of_one_batch_multi_view
+ ignore_keys = set(
+ [
+ "depthmap",
+ "dataset",
+ "label",
+ "instance",
+ "idx",
+ "true_shape",
+ "rng",
+ "data_norm_type",
+ ]
+ )
+ for view in batch:
+ for name in view.keys():
+ if name in ignore_keys:
+ continue
+ view[name] = view[name].to(device, non_blocking=True)
+
+ local_rank = train_tools.get_rank()
+ n_views = len(batch)
+ batch_shape = batch[0]["img"].shape
+ first_sample_name = view_name(batch[0], batch_index=0)
+ print(
+ f"Rank: {local_rank}, Num views: {n_views}, Batch Shape: {batch_shape}, First Sample Name: {first_sample_name}",
+ force=True,
+ )
+
+ del batch
+
+ metric_logger.update(epoch=epoch_f)
+ metric_logger.update(loss=0)
+
+ # # Gather the stats from all processes
+ # metric_logger.synchronize_between_processes()
+ # print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
diff --git a/mapanything/train/training.py b/mapanything/train/training.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5dd83cd08caa07c9a893aba6a8e07ebf7c05c26
--- /dev/null
+++ b/mapanything/train/training.py
@@ -0,0 +1,659 @@
+"""
+Training Code for MapAnything.
+
+References:
+DUSt3R: https://github.com/naver/dust3r
+"""
+
+import datetime
+import json
+import math
+import os
+import pickle
+import sys
+import time
+from collections import defaultdict
+from pathlib import Path
+from typing import Sized
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+from torch.utils.tensorboard import SummaryWriter
+
+import mapanything.utils.train_tools as train_tools
+from mapanything.datasets import get_test_data_loader, get_train_data_loader
+from mapanything.models import init_model
+from mapanything.train.losses import * # noqa
+from mapanything.utils.inference import loss_of_one_batch_multi_view
+from mapanything.utils.train_tools import NativeScalerWithGradNormCount as NativeScaler
+
+# Enable TF32 precision if supported (for GPU >= Ampere and PyTorch >= 1.12)
+if hasattr(torch.backends.cuda, "matmul") and hasattr(
+ torch.backends.cuda.matmul, "allow_tf32"
+):
+ torch.backends.cuda.matmul.allow_tf32 = True
+
+
+def train(args):
+ """
+ Main training function that handles the entire training process.
+
+ This function initializes the distributed training environment, sets up datasets,
+ initializes the model, optimizer, and loss functions, and manages the training
+ and evaluation loop across multiple epochs.
+
+ In this training, an epoch is just a chunk of the entire dataset.
+
+ Args:
+ args: Configuration object containing all training parameters including
+ dataset configs, model configs, training parameters, and loss functions.
+ """
+ # Initialize distributed training if required
+ train_tools.init_distributed_mode(args.distributed)
+ global_rank = train_tools.get_rank()
+ world_size = train_tools.get_world_size() # noqa
+
+ # Init output directory and device
+ print("output_dir: " + args.output_dir)
+ if args.output_dir:
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+ print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
+ print("{}".format(args).replace(", ", ",\n"))
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ # Fix the seed
+ seed = args.train_params.seed + train_tools.get_rank()
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+
+ cudnn.benchmark = not args.train_params.disable_cudnn_benchmark
+
+ # Datasets and Dataloaders
+ print("Building train dataset {:s}".format(args.dataset.train_dataset))
+ data_loader_train = build_dataset(
+ dataset=args.dataset.train_dataset,
+ num_workers=args.dataset.num_workers,
+ test=False,
+ max_num_of_imgs_per_gpu=args.train_params.max_num_of_imgs_per_gpu,
+ )
+ print("Building test dataset {:s}".format(args.dataset.test_dataset))
+ test_batch_size = 2 * (
+ args.train_params.max_num_of_imgs_per_gpu // args.dataset.num_views
+ ) # Since we don't have any backward overhead
+ data_loader_test = {
+ dataset.split("(")[0]: build_dataset(
+ dataset=dataset,
+ num_workers=args.dataset.num_workers,
+ test=True,
+ batch_size=test_batch_size,
+ )
+ for dataset in args.dataset.test_dataset.split("+")
+ if "(" in dataset
+ }
+
+ # Load Model
+ if global_rank == 0:
+ model = init_model(
+ args.model.model_str,
+ args.model.model_config,
+ torch_hub_force_reload=args.model.torch_hub_force_reload,
+ )
+ if torch.distributed.is_initialized():
+ torch.distributed.barrier() # Make sure the model is initialized before proceeding
+ if global_rank != 0:
+ model = init_model(
+ args.model.model_str, args.model.model_config, torch_hub_force_reload=False
+ )
+ model.to(device) # Move model to device
+ model_without_ddp = model
+ print("Model = %s" % str(model_without_ddp))
+
+ # Criterion
+ print(f">> Creating train criterion = {args.loss.train_criterion}")
+ train_criterion = eval(args.loss.train_criterion).to(device)
+ print(
+ f">> Creating test criterion = {args.loss.test_criterion or args.loss.train_criterion}"
+ )
+ test_criterion = eval(args.loss.test_criterion or args.loss.train_criterion).to(
+ device
+ )
+
+ # Load pretrained model if provided
+ if args.model.pretrained:
+ print("Loading pretrained: ", args.model.pretrained)
+ ckpt = torch.load(
+ args.model.pretrained, map_location=device, weights_only=False
+ )
+ print(model.load_state_dict(ckpt["model"], strict=False))
+ del ckpt # in case it occupies memory
+
+ # Init model for DDP training
+ if args.distributed.distributed:
+ model = torch.nn.parallel.DistributedDataParallel(
+ model,
+ device_ids=[args.distributed.gpu],
+ find_unused_parameters=True,
+ static_graph=False,
+ )
+ model_without_ddp = model.module
+
+ # Optimizer and loss scaler for gradient accumulation
+ # Following timm: set wd as 0 for bias and norm layers
+ param_groups, param_groups_name_to_idx_map, param_groups_idx_to_name_map = (
+ train_tools.get_parameter_groups(
+ model_without_ddp,
+ args.train_params.lr,
+ args.train_params.weight_decay,
+ submodule_configs=args.train_params.submodule_configs,
+ warn_not_in_submodule=args.train_params.warn_not_in_submodule,
+ )
+ )
+ optimizer = torch.optim.AdamW(
+ param_groups, lr=args.train_params.lr, betas=(0.9, 0.95)
+ )
+ print(optimizer)
+ loss_scaler = NativeScaler()
+
+ def write_log_stats(epoch, train_stats, test_stats):
+ """
+ Writes training and testing statistics to log files and TensorBoard.
+
+ Args:
+ epoch: Current epoch number.
+ train_stats: Dictionary containing training metrics.
+ test_stats: Dictionary containing testing metrics for each test dataset.
+ """
+ if train_tools.is_main_process():
+ if log_writer is not None:
+ log_writer.flush()
+
+ log_stats = dict(
+ epoch=epoch, **{f"train_{k}": v for k, v in train_stats.items()}
+ )
+ for test_name in data_loader_test:
+ if test_name not in test_stats:
+ continue
+ log_stats.update(
+ {test_name + "_" + k: v for k, v in test_stats[test_name].items()}
+ )
+
+ with open(
+ os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8"
+ ) as f:
+ f.write(json.dumps(log_stats) + "\n")
+
+ def save_model(epoch, fname, best_so_far):
+ """
+ Saves model checkpoint to disk.
+
+ Args:
+ epoch: Current epoch number.
+ fname: Filename or identifier for the checkpoint.
+ best_so_far: Best validation metric achieved so far.
+ """
+ train_tools.save_model(
+ args=args,
+ model_without_ddp=model_without_ddp,
+ optimizer=optimizer,
+ loss_scaler=loss_scaler,
+ epoch=epoch,
+ fname=fname,
+ best_so_far=best_so_far,
+ )
+
+ # Resume from a checkpoint if needed
+ last_ckpt_fname = os.path.join(args.output_dir, "checkpoint-last.pth")
+ if args.train_params.resume and os.path.isfile(last_ckpt_fname):
+ args.train_params.resume_ckpt = last_ckpt_fname
+ else:
+ args.train_params.resume_ckpt = None
+ best_so_far = train_tools.load_model(
+ train_args=args.train_params,
+ model_without_ddp=model_without_ddp,
+ optimizer=optimizer,
+ loss_scaler=loss_scaler,
+ )
+ if best_so_far is None:
+ best_so_far = float("inf")
+
+ if global_rank == 0 and args.output_dir is not None:
+ log_writer = SummaryWriter(log_dir=args.output_dir)
+ else:
+ log_writer = None
+
+ print(f"Start training for {args.train_params.epochs} epochs")
+ start_time = time.time()
+ train_stats = test_stats = {}
+ for epoch in range(args.train_params.start_epoch, args.train_params.epochs + 1):
+ # Save immediately the last checkpoint
+ if epoch > args.train_params.start_epoch:
+ if (
+ args.train_params.save_freq
+ and epoch % args.train_params.save_freq == 0
+ or epoch == args.train_params.epochs
+ ):
+ save_model(epoch - 1, "last", best_so_far)
+
+ # Test on multiple datasets
+ new_best = False
+ test_stats = {}
+ if (
+ args.train_params.eval_freq > 0
+ and epoch % args.train_params.eval_freq == 0
+ and epoch > 0
+ ):
+ for test_name, testset in data_loader_test.items():
+ print(f"Testing on {test_name} ...")
+ stats = test_one_epoch(
+ model,
+ test_criterion,
+ testset,
+ device,
+ epoch,
+ log_writer=log_writer,
+ args=args,
+ prefix=test_name,
+ )
+ test_stats[test_name] = stats
+
+ # Calculate average test loss median
+ avg_test_loss_med = np.mean(
+ [stats["loss_med"] for stats in test_stats.values()]
+ )
+ test_stats["Average Test Loss Median"] = avg_test_loss_med
+ # Save best
+ if avg_test_loss_med < best_so_far:
+ best_so_far = avg_test_loss_med
+ new_best = True
+
+ # Save more stuff
+ write_log_stats(epoch, train_stats, test_stats)
+
+ if epoch > args.train_params.start_epoch:
+ if args.train_params.keep_freq and epoch % args.train_params.keep_freq == 0:
+ save_model(epoch - 1, str(epoch), best_so_far)
+ if new_best:
+ save_model(epoch - 1, "best", best_so_far)
+ if epoch >= args.train_params.epochs:
+ break # exit after writing last test to disk
+
+ # Train
+ train_stats = train_one_epoch(
+ model,
+ train_criterion,
+ data_loader_train,
+ optimizer,
+ device,
+ epoch,
+ loss_scaler,
+ log_writer=log_writer,
+ args=args,
+ param_groups_name_to_idx_map=param_groups_name_to_idx_map,
+ param_groups_idx_to_name_map=param_groups_idx_to_name_map,
+ model_without_ddp=model_without_ddp,
+ )
+
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print("Training time {}".format(total_time_str))
+
+ save_final_model(
+ args, args.train_params.epochs, model_without_ddp, best_so_far=best_so_far
+ )
+
+
+def save_final_model(args, epoch, model_without_ddp, best_so_far=None):
+ """
+ Saves the final model checkpoint after training completion.
+
+ Args:
+ args: Configuration object containing output directory information.
+ epoch: Current epoch number.
+ model_without_ddp: Model state dictionary or model instance without DistributedDataParallel wrapper.
+ best_so_far: Optional; Best validation metric achieved during training.
+ """
+ output_dir = Path(args.output_dir)
+ checkpoint_path = output_dir / "checkpoint-final.pth"
+ to_save = {
+ "args": args,
+ "model": model_without_ddp
+ if isinstance(model_without_ddp, dict)
+ else model_without_ddp.cpu().state_dict(),
+ "epoch": epoch,
+ }
+ if best_so_far is not None:
+ to_save["best_so_far"] = best_so_far
+ print(f">> Saving model to {checkpoint_path} ...")
+ train_tools.save_on_master(to_save, checkpoint_path)
+
+
+def build_dataset(
+ dataset, num_workers, test, batch_size=None, max_num_of_imgs_per_gpu=None
+):
+ """
+ Builds data loaders for training or testing.
+
+ Args:
+ dataset: Dataset specification string.
+ num_workers: Number of worker processes for data loading.
+ test: Boolean flag indicating whether this is a test dataset.
+ batch_size: Number of samples per batch. Defaults to None. Used only for testing.
+ max_num_of_imgs_per_gpu: Maximum number of images per GPU. Defaults to None. Used only for training.
+
+ Returns:
+ DataLoader: PyTorch DataLoader configured for the specified dataset.
+ """
+ split = ["Train", "Test"][test]
+ print(f"Building {split} Data loader for dataset: ", dataset)
+ if test:
+ assert batch_size is not None, (
+ "batch_size must be specified for testing dataloader"
+ )
+ loader = get_test_data_loader(
+ dataset=dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ pin_mem=True,
+ shuffle=False,
+ drop_last=False,
+ )
+ else:
+ assert max_num_of_imgs_per_gpu is not None, (
+ "max_num_of_imgs_per_gpu must be specified for training dataloader"
+ )
+ loader = get_train_data_loader(
+ dataset=dataset,
+ max_num_of_imgs_per_gpu=max_num_of_imgs_per_gpu,
+ num_workers=num_workers,
+ pin_mem=True,
+ shuffle=True,
+ drop_last=True,
+ )
+
+ print(f"{split} dataset length: ", len(loader))
+ return loader
+
+
+def train_one_epoch(
+ model: torch.nn.Module,
+ criterion: torch.nn.Module,
+ data_loader: Sized,
+ optimizer: torch.optim.Optimizer,
+ device: torch.device,
+ epoch: int,
+ loss_scaler,
+ args,
+ log_writer=None,
+ param_groups_name_to_idx_map=None,
+ param_groups_idx_to_name_map=None,
+ model_without_ddp=None,
+):
+ """
+ Trains the model for one epoch.
+ Epoch is just a chunk of the entire dataset.
+
+ This function handles the training loop for a single epoch, including forward/backward passes,
+ gradient accumulation, learning rate scheduling, and logging metrics.
+
+ Args:
+ model: The neural network model to train.
+ criterion: Loss function to optimize.
+ data_loader: DataLoader providing the training data.
+ optimizer: Optimizer for updating model parameters.
+ device: Device to run training on (CPU or GPU).
+ epoch: Current epoch number.
+ loss_scaler: Scaler for gradient accumulation and mixed precision training.
+ args: Configuration object containing training parameters.
+ log_writer: Optional; TensorBoard SummaryWriter for logging.
+ param_groups_name_to_idx_map: Mapping from parameter group names to indices.
+ param_groups_idx_to_name_map: Mapping from parameter group indices to names.
+ model_without_ddp: Model without DistributedDataParallel wrapper for debugging.
+
+ Returns:
+ dict: Dictionary containing training metrics averaged over the epoch.
+ """
+ model.train(True)
+ metric_logger = train_tools.MetricLogger(delimiter=" ")
+ for submodule_name in param_groups_name_to_idx_map:
+ lr_name = f"lr_{submodule_name}" if submodule_name != "default" else "lr"
+ metric_logger.add_meter(
+ lr_name, train_tools.SmoothedValue(window_size=1, fmt="{value:.6f}")
+ )
+ header = "Epoch: [{}]".format(epoch)
+ accum_iter = args.train_params.accum_iter
+
+ if log_writer is not None:
+ print("log_dir: {}".format(log_writer.log_dir))
+
+ if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"):
+ data_loader.dataset.set_epoch(epoch)
+ if hasattr(data_loader, "sampler") and hasattr(data_loader.sampler, "set_epoch"):
+ data_loader.sampler.set_epoch(epoch)
+ if hasattr(data_loader, "batch_sampler") and hasattr(
+ data_loader.batch_sampler, "set_epoch"
+ ):
+ data_loader.batch_sampler.set_epoch(epoch)
+
+ optimizer.zero_grad()
+
+ for data_iter_step, batch in enumerate(
+ metric_logger.log_every(data_loader, args.train_params.print_freq, header)
+ ):
+ n_views = len(batch)
+ epoch_f = epoch + data_iter_step / len(data_loader)
+
+ # We use a per iteration (instead of per epoch) lr scheduler
+ if data_iter_step % accum_iter == 0:
+ train_tools.adjust_learning_rate(
+ optimizer,
+ epoch_f,
+ args.train_params,
+ param_groups_idx_to_name_map,
+ args.train_params.submodule_configs,
+ )
+
+ loss_tuple = loss_of_one_batch_multi_view(
+ batch,
+ model,
+ criterion,
+ device,
+ use_amp=bool(args.train_params.amp),
+ amp_dtype=args.train_params.amp_dtype,
+ ret="loss",
+ )
+ loss, loss_details = loss_tuple # criterion returns two values
+ if n_views > 2:
+ loss = loss * (
+ 2 / n_views
+ ) # scale the loss relative to the number of views (base is 2 views)
+ loss_value = float(loss)
+
+ if not math.isfinite(loss_value) or (loss_value > 1000):
+ print("Loss is {}, stopping training".format(loss_value), force=True)
+ print(f"Loss Details: {loss_details}", force=True)
+ print(f"Epoch: {epoch}, Data Iteration: {data_iter_step}", force=True)
+ # Save the current batch to the output folder for further inspection
+ for view_idx, view in enumerate(batch):
+ view_cpu = {}
+ for k, v in view.items():
+ view_cpu[k] = v.cpu() if isinstance(v, torch.Tensor) else v
+ with open(
+ os.path.join(args.output_dir, f"batch_view_{view_idx}.pkl"), "wb"
+ ) as f:
+ pickle.dump(view_cpu, f)
+ # Save the model to the output folder for further inspection
+ checkpoint_debug_path = os.path.join(
+ args.output_dir, "checkpoint-debug.pth"
+ )
+ to_save_debug = {
+ "args": args,
+ "model": (
+ model_without_ddp
+ if isinstance(model_without_ddp, dict)
+ else model_without_ddp.cpu().state_dict()
+ ),
+ "epoch": epoch,
+ "data_iter_step": data_iter_step,
+ }
+ torch.save(to_save_debug, checkpoint_debug_path)
+ print(f"Saved debugging material to {args.output_dir}", force=True)
+ sys.exit(1)
+
+ # Scale the loss by the number of gradient accumulation iterations
+ loss /= accum_iter
+
+ # Compute the scaled gradients (also clip the gradients to max norm of 1)
+ gradient_norm = loss_scaler(
+ loss,
+ optimizer,
+ parameters=model.parameters(),
+ update_grad=(data_iter_step + 1) % accum_iter == 0,
+ clip_grad=1.0,
+ )
+
+ # Zero out the gradients to prepare for the next iteration of gradient descent
+ if (data_iter_step + 1) % accum_iter == 0:
+ optimizer.zero_grad()
+
+ del loss
+ del batch
+
+ metric_logger.update(epoch=epoch_f)
+ for submodule_name in param_groups_name_to_idx_map:
+ lr_name = f"lr_{submodule_name}" if submodule_name != "default" else "lr"
+ log_lr = optimizer.param_groups[
+ param_groups_name_to_idx_map[submodule_name][0]
+ ]["lr"]
+ metric_logger.meters[lr_name].update(log_lr)
+ metric_logger.update(loss=loss_value, **loss_details)
+
+ if (data_iter_step + 1) % accum_iter == 0 and (
+ (data_iter_step + 1) % (accum_iter * args.train_params.print_freq)
+ ) == 0:
+ loss_value_reduce = train_tools.all_reduce_mean(
+ loss_value
+ ) # MUST BE EXECUTED BY ALL NODES
+ if log_writer is None:
+ continue
+ """
+ We use epoch_1000x as the x-axis in tensorboard.
+ This calibrates different curves when batch size changes.
+ """
+ epoch_1000x = int(epoch_f * 1000)
+ log_writer.add_scalar("train_loss", loss_value_reduce, epoch_1000x)
+ if gradient_norm is not None:
+ log_writer.add_scalar("train_grad_norm", gradient_norm, epoch_1000x)
+ for submodule_name in param_groups_name_to_idx_map:
+ lr_name = (
+ f"train_lr_{submodule_name}"
+ if submodule_name != "default"
+ else "train_lr"
+ )
+ log_lr = optimizer.param_groups[
+ param_groups_name_to_idx_map[submodule_name][0]
+ ]["lr"]
+ log_writer.add_scalar(lr_name, log_lr, epoch_1000x)
+ log_writer.add_scalar("train_iter", epoch_1000x, epoch_1000x)
+ for name, val in loss_details.items():
+ log_writer.add_scalar("train_" + name, val, epoch_1000x)
+
+ # # Gather the stats from all processes
+ # metric_logger.synchronize_between_processes()
+ # print("Averaged stats:", metric_logger)
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
+
+
+@torch.no_grad()
+def test_one_epoch(
+ model: torch.nn.Module,
+ criterion: torch.nn.Module,
+ data_loader: Sized,
+ device: torch.device,
+ epoch: int,
+ args,
+ log_writer=None,
+ prefix="test",
+):
+ """
+ Evaluates the model on a test dataset for one epoch.
+ Epoch is just a chunk of the entire dataset.
+
+ This function runs evaluation on the test dataset without computing gradients,
+ and collects metrics for model performance assessment.
+
+ Args:
+ model: The neural network model to evaluate.
+ criterion: Loss function for evaluation.
+ data_loader: DataLoader providing the test data.
+ device: Device to run evaluation on (CPU or GPU).
+ epoch: Current epoch number.
+ args: Configuration object containing evaluation parameters.
+ log_writer: Optional; TensorBoard SummaryWriter for logging.
+ prefix: String prefix for logging metrics.
+
+ Returns:
+ dict: Dictionary containing evaluation metrics (average and median values).
+ """
+ model.eval()
+ metric_logger = train_tools.MetricLogger(delimiter=" ")
+ metric_logger.meters = defaultdict(
+ lambda: train_tools.SmoothedValue(window_size=9**9)
+ )
+ header = "Test Epoch: [{}]".format(epoch)
+
+ if log_writer is not None:
+ print("log_dir: {}".format(log_writer.log_dir))
+
+ if args.train_params.freeze_val_samples_across_all_epochs:
+ dataloader_epoch = 0
+ else:
+ dataloader_epoch = epoch
+ if hasattr(data_loader, "dataset") and hasattr(data_loader.dataset, "set_epoch"):
+ data_loader.dataset.set_epoch(dataloader_epoch)
+ if hasattr(data_loader, "sampler") and hasattr(data_loader.sampler, "set_epoch"):
+ data_loader.sampler.set_epoch(dataloader_epoch)
+ if hasattr(data_loader, "batch_sampler") and hasattr(
+ data_loader.batch_sampler, "set_epoch"
+ ):
+ data_loader.batch_sampler.set_epoch(dataloader_epoch)
+
+ for _, batch in enumerate(
+ metric_logger.log_every(data_loader, args.train_params.print_freq, header)
+ ):
+ n_views = len(batch)
+ loss_tuple = loss_of_one_batch_multi_view(
+ batch,
+ model,
+ criterion,
+ device,
+ use_amp=bool(args.train_params.amp),
+ amp_dtype=args.train_params.amp_dtype,
+ ret="loss",
+ )
+ loss_value, loss_details = loss_tuple # criterion returns two values
+ if n_views > 2:
+ loss_value = loss_value * (
+ 2 / n_views
+ ) # scale the loss relative to the number of views (base is 2 views)
+ metric_logger.update(loss=float(loss_value), **loss_details)
+
+ # # Gather the stats from all processes
+ # metric_logger.synchronize_between_processes()
+ # print("Averaged stats:", metric_logger)
+
+ aggs = [("avg", "global_avg"), ("med", "median")]
+ results = {
+ f"{k}_{tag}": getattr(meter, attr)
+ for k, meter in metric_logger.meters.items()
+ for tag, attr in aggs
+ }
+
+ if log_writer is not None:
+ for name, val in results.items():
+ log_writer.add_scalar(prefix + "_" + name, val, 1000 * epoch)
+
+ return results
diff --git a/mapanything/utils/__init__.py b/mapanything/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/utils/cropping.py b/mapanything/utils/cropping.py
new file mode 100644
index 0000000000000000000000000000000000000000..709ff34f56276e80dc842d6689a88bc84363679b
--- /dev/null
+++ b/mapanything/utils/cropping.py
@@ -0,0 +1,462 @@
+"""
+Utility functions for cropping and resizing data while maintaining proper cameras.
+
+References: DUSt3R
+"""
+
+import cv2
+import numpy as np
+import PIL.Image
+
+try:
+ lanczos = PIL.Image.Resampling.LANCZOS
+ bicubic = PIL.Image.Resampling.BICUBIC
+except AttributeError:
+ lanczos = PIL.Image.LANCZOS
+ bicubic = PIL.Image.BICUBIC
+
+from mapanything.utils.geometry import (
+ colmap_to_opencv_intrinsics,
+ opencv_to_colmap_intrinsics,
+)
+
+
+class ImageList:
+ """
+ Convenience class to apply the same operation to a whole set of images.
+
+ This class wraps a list of PIL.Image objects and provides methods to perform
+ operations on all images simultaneously.
+ """
+
+ def __init__(self, images):
+ if not isinstance(images, (tuple, list, set)):
+ images = [images]
+ self.images = []
+ for image in images:
+ if not isinstance(image, PIL.Image.Image):
+ image = PIL.Image.fromarray(image)
+ self.images.append(image)
+
+ def __len__(self):
+ """Return the number of images in the list."""
+ return len(self.images)
+
+ def to_pil(self):
+ """
+ Convert ImageList back to PIL Image(s).
+
+ Returns:
+ PIL.Image.Image or tuple: Single PIL Image if list contains one image,
+ or tuple of PIL Images if multiple images
+ """
+ return tuple(self.images) if len(self.images) > 1 else self.images[0]
+
+ @property
+ def size(self):
+ """
+ Get the size of images in the list.
+
+ Returns:
+ tuple: (width, height) of the images
+
+ Raises:
+ AssertionError: If images have different sizes
+ """
+ sizes = [im.size for im in self.images]
+ assert all(sizes[0] == s for s in sizes), "All images must have the same size"
+ return sizes[0]
+
+ def resize(self, *args, **kwargs):
+ """
+ Resize all images with the same parameters.
+
+ Args:
+ *args, **kwargs: Arguments passed to PIL.Image.resize()
+
+ Returns:
+ ImageList: New ImageList containing resized images
+ """
+ return ImageList(self._dispatch("resize", *args, **kwargs))
+
+ def crop(self, *args, **kwargs):
+ """
+ Crop all images with the same parameters.
+
+ Args:
+ *args, **kwargs: Arguments passed to PIL.Image.crop()
+
+ Returns:
+ ImageList: New ImageList containing cropped images
+ """
+ return ImageList(self._dispatch("crop", *args, **kwargs))
+
+ def _dispatch(self, func, *args, **kwargs):
+ """
+ Apply a PIL.Image method to all images in the list.
+
+ Args:
+ func (str): Name of the PIL.Image method to call
+ *args, **kwargs: Arguments to pass to the method
+
+ Returns:
+ list: List of results from applying the method to each image
+ """
+ return [getattr(im, func)(*args, **kwargs) for im in self.images]
+
+
+def resize_with_nearest_interpolation_to_match_aspect_ratio(input_data, img_h, img_w):
+ """
+ Resize input map to match the aspect ratio of an image while ensuring
+ the input resolution never increases beyond the original.
+ Uses nearest interpolation for resizing.
+
+ Args:
+ input_data (np.ndarray): The input map to resize
+ img_h (int): Height of the target image
+ img_w (int): Width of the target image
+
+ Returns:
+ tuple: (resized_input, target_h, target_w)
+ - resized_input: The resized input map
+ - target_h: The target height used for resizing
+ - target_w: The target width used for resizing
+ """
+ # Get the dimensions of the input map
+ input_h, input_w = input_data.shape[:2]
+
+ # Calculate aspect ratios
+ img_aspect = img_w / img_h
+
+ # Option 1: Keep input_w fixed and calculate new height
+ option1_h = int(input_w / img_aspect)
+ # Option 2: Keep input_h fixed and calculate new width
+ option2_w = int(input_h * img_aspect)
+
+ # Check if either option would increase a dimension
+ option1_increases = option1_h > input_h
+ option2_increases = option2_w > input_w
+
+ if option1_increases and option2_increases:
+ # Both options would increase a dimension, so we need to scale down both dimensions
+ # Find the scaling factor that preserves aspect ratio and ensures no dimension increases
+ scale_h = input_h / img_h
+ scale_w = input_w / img_w
+ scale = min(scale_h, scale_w)
+
+ target_input_h = int(img_h * scale)
+ target_input_w = int(img_w * scale)
+ elif option1_increases:
+ # Option 1 would increase height, so use option 2
+ target_input_h = input_h
+ target_input_w = option2_w
+ elif option2_increases:
+ # Option 2 would increase width, so use option 1
+ target_input_w = input_w
+ target_input_h = option1_h
+ else:
+ # Neither option increases dimensions, choose the one that maintains resolution better
+ if abs(input_h * input_w - input_w * option1_h) < abs(
+ input_h * input_w - option2_w * input_h
+ ):
+ # Option 1 is better: keep width fixed, adjust height
+ target_input_w = input_w
+ target_input_h = option1_h
+ else:
+ # Option 2 is better: keep height fixed, adjust width
+ target_input_h = input_h
+ target_input_w = option2_w
+
+ # Resize input using nearest interpolation to maintain input values
+ if target_input_h != input_h or target_input_w != input_w:
+ resized_input = cv2.resize(
+ input_data,
+ (target_input_w, target_input_h),
+ interpolation=cv2.INTER_NEAREST,
+ )
+ else:
+ resized_input = input_data
+
+ return resized_input, target_input_h, target_input_w
+
+
+def rescale_image_and_other_optional_info(
+ image,
+ output_resolution,
+ depthmap=None,
+ camera_intrinsics=None,
+ force=True,
+ additional_quantities_to_be_resized_with_nearest=None,
+):
+ """
+ Rescale the image and depthmap to the output resolution.
+ If the image is larger than the output resolution, it is rescaled with lanczos interpolation.
+ If force is false and the image is smaller than the output resolution, it is not rescaled.
+ If force is true and the image is smaller than the output resolution, it is rescaled with bicubic interpolation.
+ Depth and other quantities are rescaled with nearest interpolation.
+
+ Args:
+ image (PIL.Image.Image or np.ndarray): The input image to be rescaled.
+ output_resolution (tuple): The desired output resolution as a tuple (width, height).
+ depthmap (np.ndarray, optional): The depth map associated with the image. Defaults to None.
+ camera_intrinsics (np.ndarray, optional): The camera intrinsics matrix. Defaults to None.
+ force (bool, optional): If True, force rescaling even if the image is smaller than the output resolution. Defaults to True.
+ additional_quantities_to_be_resized_with_nearest (list of np.ndarray, optional): Additional quantities to be rescaled using nearest interpolation. Defaults to None.
+
+ Returns:
+ tuple: A tuple containing:
+ - The rescaled image (PIL.Image.Image)
+ - The rescaled depthmap (numpy.ndarray or None)
+ - The updated camera intrinsics (numpy.ndarray or None)
+ - The list of rescaled additional quantities (list of numpy.ndarray or None)
+ """
+ image = ImageList(image)
+ input_resolution = np.array(image.size) # (W, H)
+ output_resolution = np.array(output_resolution)
+ if depthmap is not None:
+ assert tuple(depthmap.shape[:2]) == image.size[::-1]
+ if additional_quantities_to_be_resized_with_nearest is not None:
+ assert all(
+ tuple(additional_quantity.shape[:2]) == image.size[::-1]
+ for additional_quantity in additional_quantities_to_be_resized_with_nearest
+ )
+
+ # Define output resolution
+ assert output_resolution.shape == (2,)
+ scale_final = max(output_resolution / image.size) + 1e-8
+ if scale_final >= 1 and not force: # image is already smaller than what is asked
+ output = (
+ image.to_pil(),
+ depthmap,
+ camera_intrinsics,
+ additional_quantities_to_be_resized_with_nearest,
+ )
+ return output
+ output_resolution = np.floor(input_resolution * scale_final).astype(int)
+
+ # First rescale the image so that it contains the crop
+ image = image.resize(
+ tuple(output_resolution), resample=lanczos if scale_final < 1 else bicubic
+ )
+ if depthmap is not None:
+ depthmap = cv2.resize(
+ depthmap,
+ output_resolution,
+ fx=scale_final,
+ fy=scale_final,
+ interpolation=cv2.INTER_NEAREST,
+ )
+ if additional_quantities_to_be_resized_with_nearest is not None:
+ resized_additional_quantities = []
+ for quantity in additional_quantities_to_be_resized_with_nearest:
+ resized_additional_quantities.append(
+ cv2.resize(
+ quantity,
+ output_resolution,
+ fx=scale_final,
+ fy=scale_final,
+ interpolation=cv2.INTER_NEAREST,
+ )
+ )
+ additional_quantities_to_be_resized_with_nearest = resized_additional_quantities
+
+ # No offset here; simple rescaling
+ if camera_intrinsics is not None:
+ camera_intrinsics = camera_matrix_of_crop(
+ camera_intrinsics, input_resolution, output_resolution, scaling=scale_final
+ )
+
+ # Return
+ return (
+ image.to_pil(),
+ depthmap,
+ camera_intrinsics,
+ additional_quantities_to_be_resized_with_nearest,
+ )
+
+
+def camera_matrix_of_crop(
+ input_camera_matrix,
+ input_resolution,
+ output_resolution,
+ scaling=1,
+ offset_factor=0.5,
+ offset=None,
+):
+ """
+ Calculate the camera matrix for a cropped image.
+
+ Args:
+ input_camera_matrix (numpy.ndarray): Original camera intrinsics matrix
+ input_resolution (tuple or numpy.ndarray): Original image resolution as (width, height)
+ output_resolution (tuple or numpy.ndarray): Target image resolution as (width, height)
+ scaling (float, optional): Scaling factor for the image. Defaults to 1.
+ offset_factor (float, optional): Factor to determine crop offset. Defaults to 0.5 (centered).
+ offset (tuple or numpy.ndarray, optional): Explicit offset to use. If None, calculated from offset_factor.
+
+ Returns:
+ numpy.ndarray: Updated camera matrix for the cropped image
+ """
+ # Margins to offset the origin
+ margins = np.asarray(input_resolution) * scaling - output_resolution
+ assert np.all(margins >= 0.0)
+ if offset is None:
+ offset = offset_factor * margins
+
+ # Generate new camera parameters
+ output_camera_matrix_colmap = opencv_to_colmap_intrinsics(input_camera_matrix)
+ output_camera_matrix_colmap[:2, :] *= scaling
+ output_camera_matrix_colmap[:2, 2] -= offset
+ output_camera_matrix = colmap_to_opencv_intrinsics(output_camera_matrix_colmap)
+
+ return output_camera_matrix
+
+
+def crop_image_and_other_optional_info(
+ image,
+ crop_bbox,
+ depthmap=None,
+ camera_intrinsics=None,
+ additional_quantities=None,
+):
+ """
+ Return a crop of the input view and associated data.
+
+ Args:
+ image (PIL.Image.Image or numpy.ndarray): The input image to be cropped
+ crop_bbox (tuple): Crop bounding box as (left, top, right, bottom)
+ depthmap (numpy.ndarray, optional): Depth map associated with the image
+ camera_intrinsics (numpy.ndarray, optional): Camera intrinsics matrix
+ additional_quantities (list of numpy.ndarray, optional): Additional data arrays to crop
+
+ Returns:
+ tuple: A tuple containing:
+ - The cropped image
+ - The cropped depth map (if provided or None)
+ - Updated camera intrinsics (if provided or None)
+ - List of cropped additional quantities (if provided or None)
+ """
+ image = ImageList(image)
+ left, top, right, bottom = crop_bbox
+
+ image = image.crop((left, top, right, bottom))
+ if depthmap is not None:
+ depthmap = depthmap[top:bottom, left:right]
+ if additional_quantities is not None:
+ additional_quantities = [
+ quantity[top:bottom, left:right] for quantity in additional_quantities
+ ]
+
+ if camera_intrinsics is not None:
+ camera_intrinsics = camera_intrinsics.copy()
+ camera_intrinsics[0, 2] -= left
+ camera_intrinsics[1, 2] -= top
+
+ return (image.to_pil(), depthmap, camera_intrinsics, additional_quantities)
+
+
+def bbox_from_intrinsics_in_out(
+ input_camera_matrix, output_camera_matrix, output_resolution
+):
+ """
+ Calculate the bounding box for cropping based on input and output camera intrinsics.
+
+ Args:
+ input_camera_matrix (numpy.ndarray): Original camera intrinsics matrix
+ output_camera_matrix (numpy.ndarray): Target camera intrinsics matrix
+ output_resolution (tuple): Target resolution as (width, height)
+
+ Returns:
+ tuple: Crop bounding box as (left, top, right, bottom)
+ """
+ out_width, out_height = output_resolution
+ left, top = np.int32(
+ np.round(input_camera_matrix[:2, 2] - output_camera_matrix[:2, 2])
+ )
+ crop_bbox = (left, top, left + out_width, top + out_height)
+ return crop_bbox
+
+
+def crop_resize_if_necessary(
+ image,
+ resolution,
+ depthmap=None,
+ intrinsics=None,
+ additional_quantities=None,
+):
+ """
+ First downsample image using LANCZOS and then crop if necessary to achieve target resolution.
+
+ This function performs high-quality downsampling followed by cropping to achieve the
+ desired output resolution while maintaining proper camera intrinsics.
+
+ Args:
+ image (PIL.Image.Image or numpy.ndarray): The input image to be processed
+ resolution (tuple): Target resolution as (width, height)
+ depthmap (numpy.ndarray, optional): Depth map associated with the image
+ intrinsics (numpy.ndarray, optional): Camera intrinsics matrix
+ additional_quantities (list of numpy.ndarray, optional): Additional data arrays to process
+
+ Returns:
+ tuple: A tuple containing the processed image and any provided additional data
+ (depthmap, intrinsics, additional_quantities) that have been similarly processed
+ """
+ # Convert image to PIL.Image.Image if necessary
+ if not isinstance(image, PIL.Image.Image):
+ image = PIL.Image.fromarray(image)
+
+ # Get width and height of image
+ original_width, original_height = image.size
+
+ # High-quality Lanczos down-scaling
+ target_rescale_resolution = np.array(resolution)
+ image, depthmap, intrinsics, additional_quantities = (
+ rescale_image_and_other_optional_info(
+ image=image,
+ output_resolution=target_rescale_resolution,
+ depthmap=depthmap,
+ camera_intrinsics=intrinsics,
+ additional_quantities_to_be_resized_with_nearest=additional_quantities,
+ )
+ )
+
+ # Actual cropping (if necessary)
+ if intrinsics is not None:
+ new_intrinsics = camera_matrix_of_crop(
+ input_camera_matrix=intrinsics,
+ input_resolution=image.size,
+ output_resolution=resolution,
+ offset_factor=0.5,
+ )
+ crop_bbox = bbox_from_intrinsics_in_out(
+ input_camera_matrix=intrinsics,
+ output_camera_matrix=new_intrinsics,
+ output_resolution=resolution,
+ )
+ else:
+ # Create a centered crop if no intrinsics are available
+ w, h = image.size
+ target_w, target_h = resolution
+ left = (w - target_w) // 2
+ top = (h - target_h) // 2
+ crop_bbox = (left, top, left + target_w, top + target_h)
+
+ image, depthmap, new_intrinsics, additional_quantities = (
+ crop_image_and_other_optional_info(
+ image=image,
+ crop_bbox=crop_bbox,
+ depthmap=depthmap,
+ camera_intrinsics=intrinsics,
+ additional_quantities=additional_quantities,
+ )
+ )
+
+ # Return the output
+ output = (image,)
+ if depthmap is not None:
+ output += (depthmap,)
+ if new_intrinsics is not None:
+ output += (new_intrinsics,)
+ if additional_quantities is not None:
+ output += (additional_quantities,)
+ return output
diff --git a/mapanything/utils/device.py b/mapanything/utils/device.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc9ca662767457f89d8801bdfba786bb99c6ec2a
--- /dev/null
+++ b/mapanything/utils/device.py
@@ -0,0 +1,83 @@
+"""
+Utility functions for managing computation device
+"""
+
+import numpy as np
+import torch
+
+
+def to_device(batch, device, callback=None, non_blocking=False):
+ """
+ Transfer data to another device (i.e. GPU, CPU:torch, CPU:numpy).
+
+ This function recursively processes nested data structures (lists, tuples, dicts)
+ and transfers each tensor to the specified device.
+
+ Args:
+ batch: Data to transfer (list, tuple, dict of tensors or other objects)
+ device: Target device - pytorch device (e.g., 'cuda', 'cpu') or 'numpy'
+ callback: Optional function that would be called on every element before processing
+ non_blocking: If True, allows asynchronous copy to GPU (may be faster)
+
+ Returns:
+ Data with the same structure as input but with tensors transferred to target device
+ """
+ if callback:
+ batch = callback(batch)
+
+ if isinstance(batch, dict):
+ return {
+ k: to_device(v, device, non_blocking=non_blocking) for k, v in batch.items()
+ }
+
+ if isinstance(batch, (tuple, list)):
+ return type(batch)(
+ to_device(x, device, non_blocking=non_blocking) for x in batch
+ )
+
+ x = batch
+ if device == "numpy":
+ if isinstance(x, torch.Tensor):
+ x = x.detach().cpu().numpy()
+ elif x is not None:
+ if isinstance(x, np.ndarray):
+ x = torch.from_numpy(x)
+ if torch.is_tensor(x):
+ x = x.to(device, non_blocking=non_blocking)
+ return x
+
+
+def to_numpy(x):
+ """Convert data to numpy arrays.
+
+ Args:
+ x: Input data (can be tensor, array, or nested structure)
+
+ Returns:
+ Data with the same structure but with tensors converted to numpy arrays
+ """
+ return to_device(x, "numpy")
+
+
+def to_cpu(x):
+ """Transfer data to CPU.
+
+ Args:
+ x: Input data (can be tensor, array, or nested structure)
+
+ Returns:
+ Data with the same structure but with tensors moved to CPU
+ """
+ return to_device(x, "cpu")
+
+
+def to_cuda(x):
+ """Transfer data to CUDA device (GPU).
+
+ Args:
+ x: Input data (can be tensor, array, or nested structure)
+
+ Returns:
+ Data with the same structure but with tensors moved to GPU
+ """
+ return to_device(x, "cuda")
diff --git a/mapanything/utils/geometry.py b/mapanything/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..5431ba6c765337c4988d181c0d6f8502c20f6a78
--- /dev/null
+++ b/mapanything/utils/geometry.py
@@ -0,0 +1,2183 @@
+"""
+Utilities for geometry operations.
+
+References: DUSt3R, MoGe
+"""
+
+from numbers import Number
+from typing import Tuple, Union
+
+import einops as ein
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from mapanything.utils.misc import invalid_to_zeros
+from mapanything.utils.warnings import no_warnings
+
+
+def depthmap_to_camera_frame(depthmap, intrinsics):
+ """
+ Convert depth image to a pointcloud in camera frame.
+
+ Args:
+ - depthmap: HxW or BxHxW torch tensor
+ - intrinsics: 3x3 or Bx3x3 torch tensor
+
+ Returns:
+ pointmap in camera frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels.
+ """
+ # Add batch dimension if not present
+ if depthmap.dim() == 2:
+ depthmap = depthmap.unsqueeze(0)
+ intrinsics = intrinsics.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ batch_size, height, width = depthmap.shape
+ device = depthmap.device
+
+ # Compute 3D point in camera frame associated with each pixel
+ x_grid, y_grid = torch.meshgrid(
+ torch.arange(width, device=device).float(),
+ torch.arange(height, device=device).float(),
+ indexing="xy",
+ )
+ x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)
+ y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)
+
+ fx = intrinsics[:, 0, 0].view(-1, 1, 1)
+ fy = intrinsics[:, 1, 1].view(-1, 1, 1)
+ cx = intrinsics[:, 0, 2].view(-1, 1, 1)
+ cy = intrinsics[:, 1, 2].view(-1, 1, 1)
+
+ depth_z = depthmap
+ xx = (x_grid - cx) * depth_z / fx
+ yy = (y_grid - cy) * depth_z / fy
+ pts3d_cam = torch.stack((xx, yy, depth_z), dim=-1)
+
+ # Compute mask of valid non-zero depth pixels
+ valid_mask = depthmap > 0.0
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ pts3d_cam = pts3d_cam.squeeze(0)
+ valid_mask = valid_mask.squeeze(0)
+
+ return pts3d_cam, valid_mask
+
+
+def depthmap_to_world_frame(depthmap, intrinsics, camera_pose=None):
+ """
+ Convert depth image to a pointcloud in world frame.
+
+ Args:
+ - depthmap: HxW or BxHxW torch tensor
+ - intrinsics: 3x3 or Bx3x3 torch tensor
+ - camera_pose: 4x4 or Bx4x4 torch tensor
+
+ Returns:
+ pointmap in world frame (HxWx3 or BxHxWx3 tensor), and a mask specifying valid pixels.
+ """
+ pts3d_cam, valid_mask = depthmap_to_camera_frame(depthmap, intrinsics)
+
+ if camera_pose is not None:
+ # Add batch dimension if not present
+ if camera_pose.dim() == 2:
+ camera_pose = camera_pose.unsqueeze(0)
+ pts3d_cam = pts3d_cam.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Convert points from camera frame to world frame
+ pts3d_cam_homo = torch.cat(
+ [pts3d_cam, torch.ones_like(pts3d_cam[..., :1])], dim=-1
+ )
+ pts3d_world = ein.einsum(
+ camera_pose, pts3d_cam_homo, "b i k, b h w k -> b h w i"
+ )
+ pts3d_world = pts3d_world[..., :3]
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ pts3d_world = pts3d_world.squeeze(0)
+ else:
+ pts3d_world = pts3d_cam
+
+ return pts3d_world, valid_mask
+
+
+def transform_pts3d(pts3d, transformation):
+ """
+ Transform 3D points using a 4x4 transformation matrix.
+
+ Args:
+ - pts3d: HxWx3 or BxHxWx3 torch tensor
+ - transformation: 4x4 or Bx4x4 torch tensor
+
+ Returns:
+ transformed points (HxWx3 or BxHxWx3 tensor)
+ """
+ # Add batch dimension if not present
+ if pts3d.dim() == 3:
+ pts3d = pts3d.unsqueeze(0)
+ transformation = transformation.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Convert points to homogeneous coordinates
+ pts3d_homo = torch.cat([pts3d, torch.ones_like(pts3d[..., :1])], dim=-1)
+
+ # Transform points
+ transformed_pts3d = ein.einsum(
+ transformation, pts3d_homo, "b i k, b h w k -> b h w i"
+ )
+ transformed_pts3d = transformed_pts3d[..., :3]
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ transformed_pts3d = transformed_pts3d.squeeze(0)
+
+ return transformed_pts3d
+
+
+def project_pts3d_to_image(pts3d, intrinsics, return_z_dim):
+ """
+ Project 3D points to image plane (assumes pinhole camera model with no distortion).
+
+ Args:
+ - pts3d: HxWx3 or BxHxWx3 torch tensor
+ - intrinsics: 3x3 or Bx3x3 torch tensor
+ - return_z_dim: bool, whether to return the third dimension of the projected points
+
+ Returns:
+ projected points (HxWx2)
+ """
+ if pts3d.dim() == 3:
+ pts3d = pts3d.unsqueeze(0)
+ intrinsics = intrinsics.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Project points to image plane
+ projected_pts2d = ein.einsum(intrinsics, pts3d, "b i k, b h w k -> b h w i")
+ projected_pts2d[..., :2] /= projected_pts2d[..., 2].unsqueeze(-1).clamp(min=1e-6)
+
+ # Remove the z dimension if not required
+ if not return_z_dim:
+ projected_pts2d = projected_pts2d[..., :2]
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ projected_pts2d = projected_pts2d.squeeze(0)
+
+ return projected_pts2d
+
+
+def get_rays_in_camera_frame(intrinsics, height, width, normalize_to_unit_sphere):
+ """
+ Convert camera intrinsics to a raymap (ray origins + directions) in camera frame.
+ Note: Currently only supports pinhole camera model.
+
+ Args:
+ - intrinsics: 3x3 or Bx3x3 torch tensor
+ - height: int
+ - width: int
+ - normalize_to_unit_sphere: bool
+
+ Returns:
+ - ray_origins: (HxWx3 or BxHxWx3) tensor
+ - ray_directions: (HxWx3 or BxHxWx3) tensor
+ """
+ # Add batch dimension if not present
+ if intrinsics.dim() == 2:
+ intrinsics = intrinsics.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ batch_size = intrinsics.shape[0]
+ device = intrinsics.device
+
+ # Compute rays in camera frame associated with each pixel
+ x_grid, y_grid = torch.meshgrid(
+ torch.arange(width, device=device).float(),
+ torch.arange(height, device=device).float(),
+ indexing="xy",
+ )
+ x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1)
+ y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1)
+
+ fx = intrinsics[:, 0, 0].view(-1, 1, 1)
+ fy = intrinsics[:, 1, 1].view(-1, 1, 1)
+ cx = intrinsics[:, 0, 2].view(-1, 1, 1)
+ cy = intrinsics[:, 1, 2].view(-1, 1, 1)
+
+ ray_origins = torch.zeros((batch_size, height, width, 3), device=device)
+ xx = (x_grid - cx) / fx
+ yy = (y_grid - cy) / fy
+ ray_directions = torch.stack((xx, yy, torch.ones_like(xx)), dim=-1)
+
+ # Normalize ray directions to unit sphere if required (else rays will lie on unit plane)
+ if normalize_to_unit_sphere:
+ ray_directions = ray_directions / torch.norm(
+ ray_directions, dim=-1, keepdim=True
+ )
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ ray_origins = ray_origins.squeeze(0)
+ ray_directions = ray_directions.squeeze(0)
+
+ return ray_origins, ray_directions
+
+
+def get_rays_in_world_frame(
+ intrinsics, height, width, normalize_to_unit_sphere, camera_pose=None
+):
+ """
+ Convert camera intrinsics & camera_pose (if provided) to a raymap (ray origins + directions) in camera or world frame (if camera_pose is provided).
+ Note: Currently only supports pinhole camera model.
+
+ Args:
+ - intrinsics: 3x3 or Bx3x3 torch tensor
+ - height: int
+ - width: int
+ - normalize_to_unit_sphere: bool
+ - camera_pose: 4x4 or Bx4x4 torch tensor
+
+ Returns:
+ - ray_origins: (HxWx3 or BxHxWx3) tensor
+ - ray_directions: (HxWx3 or BxHxWx3) tensor
+ """
+ # Get rays in camera frame
+ ray_origins, ray_directions = get_rays_in_camera_frame(
+ intrinsics, height, width, normalize_to_unit_sphere
+ )
+
+ if camera_pose is not None:
+ # Add batch dimension if not present
+ if camera_pose.dim() == 2:
+ camera_pose = camera_pose.unsqueeze(0)
+ ray_origins = ray_origins.unsqueeze(0)
+ ray_directions = ray_directions.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Convert rays from camera frame to world frame
+ ray_origins_homo = torch.cat(
+ [ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1
+ )
+ ray_directions_homo = torch.cat(
+ [ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1
+ )
+ ray_origins_world = ein.einsum(
+ camera_pose, ray_origins_homo, "b i k, b h w k -> b h w i"
+ )
+ ray_directions_world = ein.einsum(
+ camera_pose, ray_directions_homo, "b i k, b h w k -> b h w i"
+ )
+ ray_origins_world = ray_origins_world[..., :3]
+ ray_directions_world = ray_directions_world[..., :3]
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ ray_origins_world = ray_origins_world.squeeze(0)
+ ray_directions_world = ray_directions_world.squeeze(0)
+ else:
+ ray_origins_world = ray_origins
+ ray_directions_world = ray_directions
+
+ return ray_origins_world, ray_directions_world
+
+
+def recover_pinhole_intrinsics_from_ray_directions(
+ ray_directions, use_geometric_calculation=False
+):
+ """
+ Recover pinhole camera intrinsics from ray directions, supporting both batched and non-batched inputs.
+
+ Args:
+ ray_directions: Tensor of shape [H, W, 3] or [B, H, W, 3] containing unit normalized ray directions
+
+ Returns:
+ Dictionary containing camera intrinsics (fx, fy, cx, cy) as tensors
+ """
+ # Add batch dimension if not present
+ if ray_directions.dim() == 3: # [H, W, 3]
+ ray_directions = ray_directions.unsqueeze(0) # [1, H, W, 3]
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ batch_size, height, width, _ = ray_directions.shape
+ device = ray_directions.device
+
+ # Create pixel coordinate grid
+ x_grid, y_grid = torch.meshgrid(
+ torch.arange(width, device=device).float(),
+ torch.arange(height, device=device).float(),
+ indexing="xy",
+ )
+
+ # Expand grid for all batches
+ x_grid = x_grid.unsqueeze(0).expand(batch_size, -1, -1) # [B, H, W]
+ y_grid = y_grid.unsqueeze(0).expand(batch_size, -1, -1) # [B, H, W]
+
+ # Determine if high resolution or not
+ is_high_res = height * width > 1000000
+
+ if is_high_res or use_geometric_calculation:
+ # For high-resolution cases, use direct geometric calculation
+ # Define key points
+ center_h, center_w = height // 2, width // 2
+ quarter_w, three_quarter_w = width // 4, 3 * width // 4
+ quarter_h, three_quarter_h = height // 4, 3 * height // 4
+
+ # Get rays at key points
+ center_rays = ray_directions[:, center_h, center_w, :].clone() # [B, 3]
+ left_rays = ray_directions[:, center_h, quarter_w, :].clone() # [B, 3]
+ right_rays = ray_directions[:, center_h, three_quarter_w, :].clone() # [B, 3]
+ top_rays = ray_directions[:, quarter_h, center_w, :].clone() # [B, 3]
+ bottom_rays = ray_directions[:, three_quarter_h, center_w, :].clone() # [B, 3]
+
+ # Normalize rays to have dz = 1
+ center_rays = center_rays / center_rays[:, 2].unsqueeze(1) # [B, 3]
+ left_rays = left_rays / left_rays[:, 2].unsqueeze(1) # [B, 3]
+ right_rays = right_rays / right_rays[:, 2].unsqueeze(1) # [B, 3]
+ top_rays = top_rays / top_rays[:, 2].unsqueeze(1) # [B, 3]
+ bottom_rays = bottom_rays / bottom_rays[:, 2].unsqueeze(1) # [B, 3]
+
+ # Calculate fx directly (vectorized across batch)
+ fx_left = (quarter_w - center_w) / (left_rays[:, 0] - center_rays[:, 0])
+ fx_right = (three_quarter_w - center_w) / (right_rays[:, 0] - center_rays[:, 0])
+ fx = (fx_left + fx_right) / 2 # Average for robustness
+
+ # Calculate cx
+ cx = center_w - fx * center_rays[:, 0]
+
+ # Calculate fy and cy
+ fy_top = (quarter_h - center_h) / (top_rays[:, 1] - center_rays[:, 1])
+ fy_bottom = (three_quarter_h - center_h) / (
+ bottom_rays[:, 1] - center_rays[:, 1]
+ )
+ fy = (fy_top + fy_bottom) / 2
+
+ cy = center_h - fy * center_rays[:, 1]
+ else:
+ # For standard resolution, use regression with sampling for efficiency
+ # Sample a grid of points (but more dense than for high-res)
+ step_h = max(1, height // 50)
+ step_w = max(1, width // 50)
+
+ h_indices = torch.arange(0, height, step_h, device=device)
+ w_indices = torch.arange(0, width, step_w, device=device)
+
+ # Extract subset of coordinates
+ x_sampled = x_grid[:, h_indices[:, None], w_indices[None, :]] # [B, H', W']
+ y_sampled = y_grid[:, h_indices[:, None], w_indices[None, :]] # [B, H', W']
+ rays_sampled = ray_directions[
+ :, h_indices[:, None], w_indices[None, :], :
+ ] # [B, H', W', 3]
+
+ # Reshape for linear regression
+ x_flat = x_sampled.reshape(batch_size, -1) # [B, N]
+ y_flat = y_sampled.reshape(batch_size, -1) # [B, N]
+
+ # Extract ray direction components
+ dx = rays_sampled[..., 0].reshape(batch_size, -1) # [B, N]
+ dy = rays_sampled[..., 1].reshape(batch_size, -1) # [B, N]
+ dz = rays_sampled[..., 2].reshape(batch_size, -1) # [B, N]
+
+ # Compute ratios for linear regression
+ ratio_x = dx / dz # [B, N]
+ ratio_y = dy / dz # [B, N]
+
+ # Since torch.linalg.lstsq doesn't support batched input, we'll use a different approach
+ # For x-direction: x = cx + fx * (dx/dz)
+ # We can solve this using normal equations: A^T A x = A^T b
+ # Create design matrices
+ ones = torch.ones_like(x_flat) # [B, N]
+ A_x = torch.stack([ones, ratio_x], dim=2) # [B, N, 2]
+ b_x = x_flat.unsqueeze(2) # [B, N, 1]
+
+ # Compute A^T A and A^T b for each batch
+ ATA_x = torch.bmm(A_x.transpose(1, 2), A_x) # [B, 2, 2]
+ ATb_x = torch.bmm(A_x.transpose(1, 2), b_x) # [B, 2, 1]
+
+ # Solve the system for each batch
+ solution_x = torch.linalg.solve(ATA_x, ATb_x).squeeze(2) # [B, 2]
+ cx, fx = solution_x[:, 0], solution_x[:, 1]
+
+ # Repeat for y-direction
+ A_y = torch.stack([ones, ratio_y], dim=2) # [B, N, 2]
+ b_y = y_flat.unsqueeze(2) # [B, N, 1]
+
+ ATA_y = torch.bmm(A_y.transpose(1, 2), A_y) # [B, 2, 2]
+ ATb_y = torch.bmm(A_y.transpose(1, 2), b_y) # [B, 2, 1]
+
+ solution_y = torch.linalg.solve(ATA_y, ATb_y).squeeze(2) # [B, 2]
+ cy, fy = solution_y[:, 0], solution_y[:, 1]
+
+ # Create intrinsics matrices
+ batch_size = fx.shape[0]
+ intrinsics = torch.zeros(batch_size, 3, 3, device=ray_directions.device)
+
+ # Fill in the intrinsics matrices
+ intrinsics[:, 0, 0] = fx # focal length x
+ intrinsics[:, 1, 1] = fy # focal length y
+ intrinsics[:, 0, 2] = cx # principal point x
+ intrinsics[:, 1, 2] = cy # principal point y
+ intrinsics[:, 2, 2] = 1.0 # bottom-right element is always 1
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ intrinsics = intrinsics.squeeze(0)
+
+ return intrinsics
+
+
+def transform_rays(ray_origins, ray_directions, transformation):
+ """
+ Transform 6D rays (ray origins and ray directions) using a 4x4 transformation matrix.
+
+ Args:
+ - ray_origins: HxWx3 or BxHxWx3 torch tensor
+ - ray_directions: HxWx3 or BxHxWx3 torch tensor
+ - transformation: 4x4 or Bx4x4 torch tensor
+ - normalize_to_unit_sphere: bool, whether to normalize the transformed ray directions to unit length
+
+ Returns:
+ transformed ray_origins (HxWx3 or BxHxWx3 tensor) and ray_directions (HxWx3 or BxHxWx3 tensor)
+ """
+ # Add batch dimension if not present
+ if ray_origins.dim() == 3:
+ ray_origins = ray_origins.unsqueeze(0)
+ ray_directions = ray_directions.unsqueeze(0)
+ transformation = transformation.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Transform ray origins and directions
+ ray_origins_homo = torch.cat(
+ [ray_origins, torch.ones_like(ray_origins[..., :1])], dim=-1
+ )
+ ray_directions_homo = torch.cat(
+ [ray_directions, torch.zeros_like(ray_directions[..., :1])], dim=-1
+ )
+ transformed_ray_origins = ein.einsum(
+ transformation, ray_origins_homo, "b i k, b h w k -> b h w i"
+ )
+ transformed_ray_directions = ein.einsum(
+ transformation, ray_directions_homo, "b i k, b h w k -> b h w i"
+ )
+ transformed_ray_origins = transformed_ray_origins[..., :3]
+ transformed_ray_directions = transformed_ray_directions[..., :3]
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ transformed_ray_origins = transformed_ray_origins.squeeze(0)
+ transformed_ray_directions = transformed_ray_directions.squeeze(0)
+
+ return transformed_ray_origins, transformed_ray_directions
+
+
+def convert_z_depth_to_depth_along_ray(z_depth, intrinsics):
+ """
+ Convert z-depth image to depth along camera rays.
+
+ Args:
+ - z_depth: HxW or BxHxW torch tensor
+ - intrinsics: 3x3 or Bx3x3 torch tensor
+
+ Returns:
+ - depth_along_ray: HxW or BxHxW torch tensor
+ """
+ # Add batch dimension if not present
+ if z_depth.dim() == 2:
+ z_depth = z_depth.unsqueeze(0)
+ intrinsics = intrinsics.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Get rays in camera frame
+ batch_size, height, width = z_depth.shape
+ _, ray_directions = get_rays_in_camera_frame(
+ intrinsics, height, width, normalize_to_unit_sphere=False
+ )
+
+ # Compute depth along ray
+ pts3d_cam = z_depth[..., None] * ray_directions
+ depth_along_ray = torch.norm(pts3d_cam, dim=-1)
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ depth_along_ray = depth_along_ray.squeeze(0)
+
+ return depth_along_ray
+
+
+def convert_raymap_z_depth_quats_to_pointmap(ray_origins, ray_directions, depth, quats):
+ """
+ Convert raymap (ray origins + directions on unit plane), z-depth and
+ unit quaternions (representing rotation) to a pointmap in world frame.
+
+ Args:
+ - ray_origins: (HxWx3 or BxHxWx3) torch tensor
+ - ray_directions: (HxWx3 or BxHxWx3) torch tensor
+ - depth: (HxWx1 or BxHxWx1) torch tensor
+ - quats: (HxWx4 or BxHxWx4) torch tensor (unit quaternions and notation is (x, y, z, w))
+
+ Returns:
+ - pointmap: (HxWx3 or BxHxWx3) torch tensor
+ """
+ # Add batch dimension if not present
+ if ray_origins.dim() == 3:
+ ray_origins = ray_origins.unsqueeze(0)
+ ray_directions = ray_directions.unsqueeze(0)
+ depth = depth.unsqueeze(0)
+ quats = quats.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ batch_size, height, width, _ = depth.shape
+ device = depth.device
+
+ # Normalize the quaternions to ensure they are unit quaternions
+ quats = quats / torch.norm(quats, dim=-1, keepdim=True)
+
+ # Convert quaternions to pixel-wise rotation matrices
+ qx, qy, qz, qw = quats[..., 0], quats[..., 1], quats[..., 2], quats[..., 3]
+ rot_mat = (
+ torch.stack(
+ [
+ qw**2 + qx**2 - qy**2 - qz**2,
+ 2 * (qx * qy - qw * qz),
+ 2 * (qw * qy + qx * qz),
+ 2 * (qw * qz + qx * qy),
+ qw**2 - qx**2 + qy**2 - qz**2,
+ 2 * (qy * qz - qw * qx),
+ 2 * (qx * qz - qw * qy),
+ 2 * (qw * qx + qy * qz),
+ qw**2 - qx**2 - qy**2 + qz**2,
+ ],
+ dim=-1,
+ )
+ .reshape(batch_size, height, width, 3, 3)
+ .to(device)
+ )
+
+ # Compute 3D points in local camera frame
+ pts3d_local = depth * ray_directions
+
+ # Rotate the local points using the quaternions
+ rotated_pts3d_local = ein.einsum(
+ rot_mat, pts3d_local, "b h w i k, b h w k -> b h w i"
+ )
+
+ # Compute 3D point in world frame associated with each pixel
+ pts3d = ray_origins + rotated_pts3d_local
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ pts3d = pts3d.squeeze(0)
+
+ return pts3d
+
+
+def quaternion_to_rotation_matrix(quat):
+ """
+ Convert a quaternion into a 3x3 rotation matrix.
+
+ Args:
+ - quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+
+ Returns:
+ - rot_matrix: 3x3 or Bx3x3 torch tensor
+ """
+ if quat.dim() == 1:
+ quat = quat.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Ensure the quaternion is normalized
+ quat = quat / quat.norm(dim=1, keepdim=True)
+ x, y, z, w = quat.unbind(dim=1)
+
+ # Compute the rotation matrix elements
+ xx = x * x
+ yy = y * y
+ zz = z * z
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ wx = w * x
+ wy = w * y
+ wz = w * z
+
+ # Construct the rotation matrix
+ rot_matrix = torch.stack(
+ [
+ 1 - 2 * (yy + zz),
+ 2 * (xy - wz),
+ 2 * (xz + wy),
+ 2 * (xy + wz),
+ 1 - 2 * (xx + zz),
+ 2 * (yz - wx),
+ 2 * (xz - wy),
+ 2 * (yz + wx),
+ 1 - 2 * (xx + yy),
+ ],
+ dim=1,
+ ).view(-1, 3, 3)
+
+ # Squeeze batch dimension if it was unsqueezed
+ if squeeze_batch_dim:
+ rot_matrix = rot_matrix.squeeze(0)
+
+ return rot_matrix
+
+
+def rotation_matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part last, as tensor of shape (..., 4).
+ Quaternion Order: XYZW or say ijkr, scalar-last
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
+ ].reshape(batch_dim + (4,))
+
+ # Convert from rijk to ijkr
+ out = out[..., [1, 2, 3, 0]]
+
+ out = standardize_quaternion(out)
+
+ return out
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ if torch.is_grad_enabled():
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ else:
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
+ return ret
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
+
+
+def quaternion_inverse(quat):
+ """
+ Compute the inverse of a quaternion.
+
+ Args:
+ - quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+
+ Returns:
+ - inv_quat: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+ """
+ # Unsqueeze batch dimension if not present
+ if quat.dim() == 1:
+ quat = quat.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Compute the inverse
+ quat_conj = quat.clone()
+ quat_conj[:, :3] = -quat_conj[:, :3]
+ quat_norm = torch.sum(quat * quat, dim=1, keepdim=True)
+ inv_quat = quat_conj / quat_norm
+
+ # Squeeze batch dimension if it was unsqueezed
+ if squeeze_batch_dim:
+ inv_quat = inv_quat.squeeze(0)
+
+ return inv_quat
+
+
+def quaternion_multiply(q1, q2):
+ """
+ Multiply two quaternions.
+
+ Args:
+ - q1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+ - q2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+
+ Returns:
+ - qm: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+ """
+ # Unsqueeze batch dimension if not present
+ if q1.dim() == 1:
+ q1 = q1.unsqueeze(0)
+ q2 = q2.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Unbind the quaternions
+ x1, y1, z1, w1 = q1.unbind(dim=1)
+ x2, y2, z2, w2 = q2.unbind(dim=1)
+
+ # Compute the product
+ x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
+ y = w1 * y2 - x1 * z2 + y1 * w2 + z1 * x2
+ z = w1 * z2 + x1 * y2 - y1 * x2 + z1 * w2
+ w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
+
+ # Stack the components
+ qm = torch.stack([x, y, z, w], dim=1)
+
+ # Squeeze batch dimension if it was unsqueezed
+ if squeeze_batch_dim:
+ qm = qm.squeeze(0)
+
+ return qm
+
+
+def transform_pose_using_quats_and_trans_2_to_1(quats1, trans1, quats2, trans2):
+ """
+ Transform quats and translation of pose2 from absolute frame (pose2 to world) to relative frame (pose2 to pose1).
+
+ Args:
+ - quats1: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+ - trans1: 3 or Bx3 torch tensor
+ - quats2: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+ - trans2: 3 or Bx3 torch tensor
+
+ Returns:
+ - quats: 4 or Bx4 torch tensor (unit quaternions and notation is (x, y, z, w))
+ - trans: 3 or Bx3 torch tensor
+ """
+ # Unsqueeze batch dimension if not present
+ if quats1.dim() == 1:
+ quats1 = quats1.unsqueeze(0)
+ trans1 = trans1.unsqueeze(0)
+ quats2 = quats2.unsqueeze(0)
+ trans2 = trans2.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ # Compute the inverse of view1's pose
+ inv_quats1 = quaternion_inverse(quats1)
+ R1_inv = quaternion_to_rotation_matrix(inv_quats1)
+ t1_inv = -1 * ein.einsum(R1_inv, trans1, "b i j, b j -> b i")
+
+ # Transform view2's pose to view1's frame
+ quats = quaternion_multiply(inv_quats1, quats2)
+ trans = ein.einsum(R1_inv, trans2, "b i j, b j -> b i") + t1_inv
+
+ # Squeeze batch dimension if it was unsqueezed
+ if squeeze_batch_dim:
+ quats = quats.squeeze(0)
+ trans = trans.squeeze(0)
+
+ return quats, trans
+
+
+def convert_ray_dirs_depth_along_ray_pose_trans_quats_to_pointmap(
+ ray_directions, depth_along_ray, pose_trans, pose_quats
+):
+ """
+ Convert ray directions, depth along ray, pose translation, and
+ unit quaternions (representing pose rotation) to a pointmap in world frame.
+
+ Args:
+ - ray_directions: (HxWx3 or BxHxWx3) torch tensor
+ - depth_along_ray: (HxWx1 or BxHxWx1) torch tensor
+ - pose_trans: (3 or Bx3) torch tensor
+ - pose_quats: (4 or Bx4) torch tensor (unit quaternions and notation is (x, y, z, w))
+
+ Returns:
+ - pointmap: (HxWx3 or BxHxWx3) torch tensor
+ """
+ # Add batch dimension if not present
+ if ray_directions.dim() == 3:
+ ray_directions = ray_directions.unsqueeze(0)
+ depth_along_ray = depth_along_ray.unsqueeze(0)
+ pose_trans = pose_trans.unsqueeze(0)
+ pose_quats = pose_quats.unsqueeze(0)
+ squeeze_batch_dim = True
+ else:
+ squeeze_batch_dim = False
+
+ batch_size, height, width, _ = depth_along_ray.shape
+ device = depth_along_ray.device
+
+ # Normalize the quaternions to ensure they are unit quaternions
+ pose_quats = pose_quats / torch.norm(pose_quats, dim=-1, keepdim=True)
+
+ # Convert quaternions to rotation matrices (B x 3 x 3)
+ rot_mat = quaternion_to_rotation_matrix(pose_quats)
+
+ # Get pose matrix (B x 4 x 4)
+ pose_mat = torch.eye(4, device=device).unsqueeze(0).repeat(batch_size, 1, 1)
+ pose_mat[:, :3, :3] = rot_mat
+ pose_mat[:, :3, 3] = pose_trans
+
+ # Compute 3D points in local camera frame
+ pts3d_local = depth_along_ray * ray_directions
+
+ # Compute 3D points in world frame
+ pts3d_homo = torch.cat([pts3d_local, torch.ones_like(pts3d_local[..., :1])], dim=-1)
+ pts3d_world = ein.einsum(pose_mat, pts3d_homo, "b i k, b h w k -> b h w i")
+ pts3d_world = pts3d_world[..., :3]
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ pts3d_world = pts3d_world.squeeze(0)
+
+ return pts3d_world
+
+
+def xy_grid(
+ W,
+ H,
+ device=None,
+ origin=(0, 0),
+ unsqueeze=None,
+ cat_dim=-1,
+ homogeneous=False,
+ **arange_kw,
+):
+ """
+ Generate a coordinate grid of shape (H,W,2) or (H,W,3) if homogeneous=True.
+
+ Args:
+ W (int): Width of the grid
+ H (int): Height of the grid
+ device (torch.device, optional): Device to place the grid on. If None, uses numpy arrays
+ origin (tuple, optional): Origin coordinates (x,y) for the grid. Default is (0,0)
+ unsqueeze (int, optional): Dimension to unsqueeze in the output tensors
+ cat_dim (int, optional): Dimension to concatenate the x,y coordinates. If None, returns tuple
+ homogeneous (bool, optional): If True, adds a third dimension of ones to make homogeneous coordinates
+ **arange_kw: Additional keyword arguments passed to np.arange or torch.arange
+
+ Returns:
+ numpy.ndarray or torch.Tensor: Coordinate grid where:
+ - output[j,i,0] = i + origin[0] (x-coordinate)
+ - output[j,i,1] = j + origin[1] (y-coordinate)
+ - output[j,i,2] = 1 (if homogeneous=True)
+ """
+ if device is None:
+ # numpy
+ arange, meshgrid, stack, ones = np.arange, np.meshgrid, np.stack, np.ones
+ else:
+ # torch
+ def arange(*a, **kw):
+ return torch.arange(*a, device=device, **kw)
+
+ meshgrid, stack = torch.meshgrid, torch.stack
+
+ def ones(*a):
+ return torch.ones(*a, device=device)
+
+ tw, th = [arange(o, o + s, **arange_kw) for s, o in zip((W, H), origin)]
+ grid = meshgrid(tw, th, indexing="xy")
+ if homogeneous:
+ grid = grid + (ones((H, W)),)
+ if unsqueeze is not None:
+ grid = (grid[0].unsqueeze(unsqueeze), grid[1].unsqueeze(unsqueeze))
+ if cat_dim is not None:
+ grid = stack(grid, cat_dim)
+
+ return grid
+
+
+def geotrf(Trf, pts, ncol=None, norm=False):
+ """
+ Apply a geometric transformation to a set of 3-D points.
+
+ Args:
+ Trf: 3x3 or 4x4 projection matrix (typically a Homography) or batch of matrices
+ with shape (B, 3, 3) or (B, 4, 4)
+ pts: numpy/torch/tuple of coordinates with shape (..., 2) or (..., 3)
+ ncol: int, number of columns of the result (2 or 3)
+ norm: float, if not 0, the result is projected on the z=norm plane
+ (homogeneous normalization)
+
+ Returns:
+ Array or tensor of projected points with the same type as input and shape (..., ncol)
+ """
+ assert Trf.ndim >= 2
+ if isinstance(Trf, np.ndarray):
+ pts = np.asarray(pts)
+ elif isinstance(Trf, torch.Tensor):
+ pts = torch.as_tensor(pts, dtype=Trf.dtype)
+
+ # Adapt shape if necessary
+ output_reshape = pts.shape[:-1]
+ ncol = ncol or pts.shape[-1]
+
+ # Optimized code
+ if (
+ isinstance(Trf, torch.Tensor)
+ and isinstance(pts, torch.Tensor)
+ and Trf.ndim == 3
+ and pts.ndim == 4
+ ):
+ d = pts.shape[3]
+ if Trf.shape[-1] == d:
+ pts = torch.einsum("bij, bhwj -> bhwi", Trf, pts)
+ elif Trf.shape[-1] == d + 1:
+ pts = (
+ torch.einsum("bij, bhwj -> bhwi", Trf[:, :d, :d], pts)
+ + Trf[:, None, None, :d, d]
+ )
+ else:
+ raise ValueError(f"bad shape, not ending with 3 or 4, for {pts.shape=}")
+ else:
+ if Trf.ndim >= 3:
+ n = Trf.ndim - 2
+ assert Trf.shape[:n] == pts.shape[:n], "batch size does not match"
+ Trf = Trf.reshape(-1, Trf.shape[-2], Trf.shape[-1])
+
+ if pts.ndim > Trf.ndim:
+ # Trf == (B,d,d) & pts == (B,H,W,d) --> (B, H*W, d)
+ pts = pts.reshape(Trf.shape[0], -1, pts.shape[-1])
+ elif pts.ndim == 2:
+ # Trf == (B,d,d) & pts == (B,d) --> (B, 1, d)
+ pts = pts[:, None, :]
+
+ if pts.shape[-1] + 1 == Trf.shape[-1]:
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
+ pts = pts @ Trf[..., :-1, :] + Trf[..., -1:, :]
+ elif pts.shape[-1] == Trf.shape[-1]:
+ Trf = Trf.swapaxes(-1, -2) # transpose Trf
+ pts = pts @ Trf
+ else:
+ pts = Trf @ pts.T
+ if pts.ndim >= 2:
+ pts = pts.swapaxes(-1, -2)
+
+ if norm:
+ pts = pts / pts[..., -1:] # DONT DO /=, it will lead to a bug
+ if norm != 1:
+ pts *= norm
+
+ res = pts[..., :ncol].reshape(*output_reshape, ncol)
+
+ return res
+
+
+def inv(mat):
+ """
+ Invert a torch or numpy matrix
+ """
+ if isinstance(mat, torch.Tensor):
+ return torch.linalg.inv(mat)
+ if isinstance(mat, np.ndarray):
+ return np.linalg.inv(mat)
+ raise ValueError(f"bad matrix type = {type(mat)}")
+
+
+def closed_form_pose_inverse(
+ pose_matrices, rotation_matrices=None, translation_vectors=None
+):
+ """
+ Compute the inverse of each 4x4 (or 3x4) SE3 pose matrices in a batch.
+
+ If `rotation_matrices` and `translation_vectors` are provided, they must correspond to the rotation and translation
+ components of `pose_matrices`. Otherwise, they will be extracted from `pose_matrices`.
+
+ Args:
+ pose_matrices: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+ rotation_matrices (optional): Nx3x3 array or tensor of rotation matrices.
+ translation_vectors (optional): Nx3x1 array or tensor of translation vectors.
+
+ Returns:
+ Inverted SE3 matrices with the same type and device as input `pose_matrices`.
+
+ Shapes:
+ pose_matrices: (N, 4, 4)
+ rotation_matrices: (N, 3, 3)
+ translation_vectors: (N, 3, 1)
+ """
+ # Check if pose_matrices is a numpy array or a torch tensor
+ is_numpy = isinstance(pose_matrices, np.ndarray)
+
+ # Validate shapes
+ if pose_matrices.shape[-2:] != (4, 4) and pose_matrices.shape[-2:] != (3, 4):
+ raise ValueError(
+ f"pose_matrices must be of shape (N,4,4), got {pose_matrices.shape}."
+ )
+
+ # Extract rotation_matrices and translation_vectors if not provided
+ if rotation_matrices is None:
+ rotation_matrices = pose_matrices[:, :3, :3]
+ if translation_vectors is None:
+ translation_vectors = pose_matrices[:, :3, 3:]
+
+ # Compute the inverse of input SE3 matrices
+ if is_numpy:
+ rotation_transposed = np.transpose(rotation_matrices, (0, 2, 1))
+ new_translation = -np.matmul(rotation_transposed, translation_vectors)
+ inverted_matrix = np.tile(np.eye(4), (len(rotation_matrices), 1, 1))
+ else:
+ rotation_transposed = rotation_matrices.transpose(1, 2)
+ new_translation = -torch.bmm(rotation_transposed, translation_vectors)
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(rotation_matrices), 1, 1)
+ inverted_matrix = inverted_matrix.to(rotation_matrices.dtype).to(
+ rotation_matrices.device
+ )
+ inverted_matrix[:, :3, :3] = rotation_transposed
+ inverted_matrix[:, :3, 3:] = new_translation
+
+ return inverted_matrix
+
+
+def relative_pose_transformation(trans_01, trans_02):
+ r"""
+ Function that computes the relative homogenous transformation from a
+ reference transformation :math:`T_1^{0} = \begin{bmatrix} R_1 & t_1 \\
+ \mathbf{0} & 1 \end{bmatrix}` to destination :math:`T_2^{0} =
+ \begin{bmatrix} R_2 & t_2 \\ \mathbf{0} & 1 \end{bmatrix}`.
+
+ The relative transformation is computed as follows:
+
+ .. math::
+
+ T_1^{2} = (T_0^{1})^{-1} \cdot T_0^{2}
+
+ Arguments:
+ trans_01 (torch.Tensor): reference transformation tensor of shape
+ :math:`(N, 4, 4)` or :math:`(4, 4)`.
+ trans_02 (torch.Tensor): destination transformation tensor of shape
+ :math:`(N, 4, 4)` or :math:`(4, 4)`.
+
+ Shape:
+ - Output: :math:`(N, 4, 4)` or :math:`(4, 4)`.
+
+ Returns:
+ torch.Tensor: the relative transformation between the transformations.
+
+ Example::
+ >>> trans_01 = torch.eye(4) # 4x4
+ >>> trans_02 = torch.eye(4) # 4x4
+ >>> trans_12 = relative_transformation(trans_01, trans_02) # 4x4
+ """
+ if not torch.is_tensor(trans_01):
+ raise TypeError(
+ "Input trans_01 type is not a torch.Tensor. Got {}".format(type(trans_01))
+ )
+ if not torch.is_tensor(trans_02):
+ raise TypeError(
+ "Input trans_02 type is not a torch.Tensor. Got {}".format(type(trans_02))
+ )
+ if trans_01.dim() not in (2, 3) and trans_01.shape[-2:] == (4, 4):
+ raise ValueError(
+ "Input must be a of the shape Nx4x4 or 4x4. Got {}".format(trans_01.shape)
+ )
+ if trans_02.dim() not in (2, 3) and trans_02.shape[-2:] == (4, 4):
+ raise ValueError(
+ "Input must be a of the shape Nx4x4 or 4x4. Got {}".format(trans_02.shape)
+ )
+ if not trans_01.dim() == trans_02.dim():
+ raise ValueError(
+ "Input number of dims must match. Got {} and {}".format(
+ trans_01.dim(), trans_02.dim()
+ )
+ )
+
+ # Convert to Nx4x4 if inputs are 4x4
+ squeeze_batch_dim = False
+ if trans_01.dim() == 2:
+ trans_01 = trans_01.unsqueeze(0)
+ trans_02 = trans_02.unsqueeze(0)
+ squeeze_batch_dim = True
+
+ # Compute inverse of trans_01 using closed form
+ trans_10 = closed_form_pose_inverse(trans_01)
+
+ # Compose transformations using matrix multiplication
+ trans_12 = torch.matmul(trans_10, trans_02)
+
+ # Remove batch dimension if it was added
+ if squeeze_batch_dim:
+ trans_12 = trans_12.squeeze(0)
+
+ return trans_12
+
+
+def depthmap_to_pts3d(depth, pseudo_focal, pp=None, **_):
+ """
+ Args:
+ - depthmap (BxHxW array):
+ - pseudo_focal: [B,H,W] ; [B,2,H,W] or [B,1,H,W]
+ Returns:
+ pointmap of absolute coordinates (BxHxWx3 array)
+ """
+
+ if len(depth.shape) == 4:
+ B, H, W, n = depth.shape
+ else:
+ B, H, W = depth.shape
+ n = None
+
+ if len(pseudo_focal.shape) == 3: # [B,H,W]
+ pseudo_focalx = pseudo_focaly = pseudo_focal
+ elif len(pseudo_focal.shape) == 4: # [B,2,H,W] or [B,1,H,W]
+ pseudo_focalx = pseudo_focal[:, 0]
+ if pseudo_focal.shape[1] == 2:
+ pseudo_focaly = pseudo_focal[:, 1]
+ else:
+ pseudo_focaly = pseudo_focalx
+ else:
+ raise NotImplementedError("Error, unknown input focal shape format.")
+
+ assert pseudo_focalx.shape == depth.shape[:3]
+ assert pseudo_focaly.shape == depth.shape[:3]
+ grid_x, grid_y = xy_grid(W, H, cat_dim=0, device=depth.device)[:, None]
+
+ # set principal point
+ if pp is None:
+ grid_x = grid_x - (W - 1) / 2
+ grid_y = grid_y - (H - 1) / 2
+ else:
+ grid_x = grid_x.expand(B, -1, -1) - pp[:, 0, None, None]
+ grid_y = grid_y.expand(B, -1, -1) - pp[:, 1, None, None]
+
+ if n is None:
+ pts3d = torch.empty((B, H, W, 3), device=depth.device)
+ pts3d[..., 0] = depth * grid_x / pseudo_focalx
+ pts3d[..., 1] = depth * grid_y / pseudo_focaly
+ pts3d[..., 2] = depth
+ else:
+ pts3d = torch.empty((B, H, W, 3, n), device=depth.device)
+ pts3d[..., 0, :] = depth * (grid_x / pseudo_focalx)[..., None]
+ pts3d[..., 1, :] = depth * (grid_y / pseudo_focaly)[..., None]
+ pts3d[..., 2, :] = depth
+ return pts3d
+
+
+def depthmap_to_camera_coordinates(depthmap, camera_intrinsics, pseudo_focal=None):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
+ """
+ camera_intrinsics = np.float32(camera_intrinsics)
+ H, W = depthmap.shape
+
+ # Compute 3D ray associated with each pixel
+ # Strong assumption: there are no skew terms
+ assert camera_intrinsics[0, 1] == 0.0
+ assert camera_intrinsics[1, 0] == 0.0
+ if pseudo_focal is None:
+ fu = camera_intrinsics[0, 0]
+ fv = camera_intrinsics[1, 1]
+ else:
+ assert pseudo_focal.shape == (H, W)
+ fu = fv = pseudo_focal
+ cu = camera_intrinsics[0, 2]
+ cv = camera_intrinsics[1, 2]
+
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+ z_cam = depthmap
+ x_cam = (u - cu) * z_cam / fu
+ y_cam = (v - cv) * z_cam / fv
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
+
+ # Mask for valid coordinates
+ valid_mask = depthmap > 0.0
+
+ return X_cam, valid_mask
+
+
+def depthmap_to_absolute_camera_coordinates(
+ depthmap, camera_intrinsics, camera_pose, **kw
+):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ - camera_pose: a 4x3 or 4x4 cam2world matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array), and a mask specifying valid pixels.
+ """
+ X_cam, valid_mask = depthmap_to_camera_coordinates(depthmap, camera_intrinsics)
+
+ X_world = X_cam # default
+ if camera_pose is not None:
+ # R_cam2world = np.float32(camera_params["R_cam2world"])
+ # t_cam2world = np.float32(camera_params["t_cam2world"]).squeeze()
+ R_cam2world = camera_pose[:3, :3]
+ t_cam2world = camera_pose[:3, 3]
+
+ # Express in absolute coordinates (invalid depth values)
+ X_world = (
+ np.einsum("ik, vuk -> vui", R_cam2world, X_cam) + t_cam2world[None, None, :]
+ )
+
+ return X_world, valid_mask
+
+
+def get_absolute_pointmaps_and_rays_info(
+ depthmap, camera_intrinsics, camera_pose, **kw
+):
+ """
+ Args:
+ - depthmap (HxW array):
+ - camera_intrinsics: a 3x3 matrix
+ - camera_pose: a 4x3 or 4x4 cam2world matrix
+ Returns:
+ pointmap of absolute coordinates (HxWx3 array),
+ a mask specifying valid pixels,
+ ray origins of absolute coordinates (HxWx3 array),
+ ray directions of absolute coordinates (HxWx3 array),
+ depth along ray (HxWx1 array),
+ ray directions of camera/local coordinates (HxWx3 array),
+ pointmap of camera/local coordinates (HxWx3 array).
+ """
+ camera_intrinsics = np.float32(camera_intrinsics)
+ H, W = depthmap.shape
+
+ # Compute 3D ray associated with each pixel
+ # Strong assumption: pinhole & there are no skew terms
+ assert camera_intrinsics[0, 1] == 0.0
+ assert camera_intrinsics[1, 0] == 0.0
+ fu = camera_intrinsics[0, 0]
+ fv = camera_intrinsics[1, 1]
+ cu = camera_intrinsics[0, 2]
+ cv = camera_intrinsics[1, 2]
+
+ # Get the rays on the unit plane
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
+ x_cam = (u - cu) / fu
+ y_cam = (v - cv) / fv
+ z_cam = np.ones_like(x_cam)
+ ray_dirs_cam_on_unit_plane = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(
+ np.float32
+ )
+
+ # Compute the 3d points in the local camera coordinate system
+ pts_cam = depthmap[..., None] * ray_dirs_cam_on_unit_plane
+
+ # Get the depth along the ray and compute the ray directions on the unit sphere
+ depth_along_ray = np.linalg.norm(pts_cam, axis=-1, keepdims=True)
+ ray_directions_cam = ray_dirs_cam_on_unit_plane / np.linalg.norm(
+ ray_dirs_cam_on_unit_plane, axis=-1, keepdims=True
+ )
+
+ # Mask for valid coordinates
+ valid_mask = depthmap > 0.0
+
+ # Get the ray origins in absolute coordinates and the ray directions in absolute coordinates
+ ray_origins_world = np.zeros_like(ray_directions_cam)
+ ray_directions_world = ray_directions_cam
+ pts_world = pts_cam
+ if camera_pose is not None:
+ R_cam2world = camera_pose[:3, :3]
+ t_cam2world = camera_pose[:3, 3]
+
+ # Express in absolute coordinates
+ ray_origins_world = ray_origins_world + t_cam2world[None, None, :]
+ ray_directions_world = np.einsum(
+ "ik, vuk -> vui", R_cam2world, ray_directions_cam
+ )
+ pts_world = ray_origins_world + ray_directions_world * depth_along_ray
+
+ return (
+ pts_world,
+ valid_mask,
+ ray_origins_world,
+ ray_directions_world,
+ depth_along_ray,
+ ray_directions_cam,
+ pts_cam,
+ )
+
+
+def adjust_camera_params_for_rotation(camera_params, original_size, k):
+ """
+ Adjust camera parameters for rotation.
+
+ Args:
+ camera_params: Camera parameters [fx, fy, cx, cy, ...]
+ original_size: Original image size as (width, height)
+ k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise)
+
+ Returns:
+ Adjusted camera parameters
+ """
+ fx, fy, cx, cy = camera_params[:4]
+ width, height = original_size
+
+ if k % 4 == 1: # 90 degrees counter-clockwise
+ new_fx, new_fy = fy, fx
+ new_cx, new_cy = height - cy, cx
+ elif k % 4 == 2: # 180 degrees
+ new_fx, new_fy = fx, fy
+ new_cx, new_cy = width - cx, height - cy
+ elif k % 4 == 3: # 90 degrees clockwise (270 counter-clockwise)
+ new_fx, new_fy = fy, fx
+ new_cx, new_cy = cy, width - cx
+ else: # No rotation
+ return camera_params
+
+ adjusted_params = [new_fx, new_fy, new_cx, new_cy]
+ if len(camera_params) > 4:
+ adjusted_params.extend(camera_params[4:])
+
+ return adjusted_params
+
+
+def adjust_pose_for_rotation(pose, k):
+ """
+ Adjust camera pose for rotation.
+
+ Args:
+ pose: 4x4 camera pose matrix (camera-to-world, OpenCV convention - X right, Y down, Z forward)
+ k: Number of 90-degree rotations counter-clockwise (k=3 means 90 degrees clockwise)
+
+ Returns:
+ Adjusted 4x4 camera pose matrix
+ """
+ # Create rotation matrices for different rotations
+ if k % 4 == 1: # 90 degrees counter-clockwise
+ rot_transform = np.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]])
+ elif k % 4 == 2: # 180 degrees
+ rot_transform = np.array([[-1, 0, 0], [0, -1, 0], [0, 0, 1]])
+ elif k % 4 == 3: # 90 degrees clockwise (270 counter-clockwise)
+ rot_transform = np.array([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
+ else: # No rotation
+ return pose
+
+ # Apply the transformation to the pose
+ adjusted_pose = pose
+ adjusted_pose[:3, :3] = adjusted_pose[:3, :3] @ rot_transform.T
+
+ return adjusted_pose
+
+
+def crop_to_aspect_ratio(image, depth, camera_params, target_ratio=1.5):
+ """
+ Crop image and depth to the largest possible target aspect ratio while
+ keeping the left side if aspect ratio is wider and the bottom of image if the aspect ratio is taller.
+
+ Args:
+ image: PIL image
+ depth: Depth map as numpy array
+ camera_params: Camera parameters [fx, fy, cx, cy, ...]
+ target_ratio: Target width/height ratio
+
+ Returns:
+ Cropped image, cropped depth, adjusted camera parameters
+ """
+ width, height = image.size
+ fx, fy, cx, cy = camera_params[:4]
+ current_ratio = width / height
+
+ if abs(current_ratio - target_ratio) < 1e-6:
+ # Already at target ratio
+ return image, depth, camera_params
+
+ if current_ratio > target_ratio:
+ # Image is wider than target ratio, crop width
+ new_width = int(height * target_ratio)
+ left = 0
+ right = new_width
+
+ # Crop image
+ cropped_image = image.crop((left, 0, right, height))
+
+ # Crop depth
+ if len(depth.shape) == 3:
+ cropped_depth = depth[:, left:right, :]
+ else:
+ cropped_depth = depth[:, left:right]
+
+ # Adjust camera parameters
+ new_cx = cx - left
+ adjusted_params = [fx, fy, new_cx, cy] + list(camera_params[4:])
+
+ else:
+ # Image is taller than target ratio, crop height
+ new_height = int(width / target_ratio)
+ top = max(0, height - new_height)
+ bottom = height
+
+ # Crop image
+ cropped_image = image.crop((0, top, width, bottom))
+
+ # Crop depth
+ if len(depth.shape) == 3:
+ cropped_depth = depth[top:bottom, :, :]
+ else:
+ cropped_depth = depth[top:bottom, :]
+
+ # Adjust camera parameters
+ new_cy = cy - top
+ adjusted_params = [fx, fy, cx, new_cy] + list(camera_params[4:])
+
+ return cropped_image, cropped_depth, adjusted_params
+
+
+def colmap_to_opencv_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] -= 0.5
+ K[1, 2] -= 0.5
+
+ return K
+
+
+def opencv_to_colmap_intrinsics(K):
+ """
+ Modify camera intrinsics to follow a different convention.
+ Coordinates of the center of the top-left pixels are by default:
+ - (0.5, 0.5) in Colmap
+ - (0,0) in OpenCV
+ """
+ K = K.copy()
+ K[0, 2] += 0.5
+ K[1, 2] += 0.5
+
+ return K
+
+
+def normalize_depth_using_non_zero_pixels(depth, return_norm_factor=False):
+ """
+ Normalize the depth by the average depth of non-zero depth pixels.
+
+ Args:
+ depth (torch.Tensor): Depth tensor of size [B, H, W, 1].
+ Returns:
+ normalized_depth (torch.Tensor): Normalized depth tensor.
+ norm_factor (torch.Tensor): Norm factor tensor of size B.
+ """
+ assert depth.ndim == 4 and depth.shape[3] == 1
+ # Calculate the sum and count of non-zero depth pixels for each batch
+ valid_depth_mask = depth > 0
+ valid_sum = torch.sum(depth * valid_depth_mask, dim=(1, 2, 3))
+ valid_count = torch.sum(valid_depth_mask, dim=(1, 2, 3))
+
+ # Calculate the norm factor
+ norm_factor = valid_sum / (valid_count + 1e-8)
+ while norm_factor.ndim < depth.ndim:
+ norm_factor.unsqueeze_(-1)
+
+ # Normalize the depth by the norm factor
+ norm_factor = norm_factor.clip(min=1e-8)
+ normalized_depth = depth / norm_factor
+
+ # Create the output tuple
+ output = (
+ (normalized_depth, norm_factor.squeeze(-1).squeeze(-1).squeeze(-1))
+ if return_norm_factor
+ else normalized_depth
+ )
+
+ return output
+
+
+def normalize_pose_translations(pose_translations, return_norm_factor=False):
+ """
+ Normalize the pose translations by the average norm of the non-zero pose translations.
+
+ Args:
+ pose_translations (torch.Tensor): Pose translations tensor of size [B, V, 3]. B is the batch size, V is the number of views.
+ Returns:
+ normalized_pose_translations (torch.Tensor): Normalized pose translations tensor of size [B, V, 3].
+ norm_factor (torch.Tensor): Norm factor tensor of size B.
+ """
+ assert pose_translations.ndim == 3 and pose_translations.shape[2] == 3
+ # Compute distance of all pose translations to origin
+ pose_translations_dis = pose_translations.norm(dim=-1) # [B, V]
+ non_zero_pose_translations_dis = pose_translations_dis > 0 # [B, V]
+
+ # Calculate the average norm of the translations across all views (considering only views with non-zero translations)
+ sum_of_all_views_pose_translations = pose_translations_dis.sum(dim=1) # [B]
+ count_of_all_views_with_non_zero_pose_translations = (
+ non_zero_pose_translations_dis.sum(dim=1)
+ ) # [B]
+ norm_factor = sum_of_all_views_pose_translations / (
+ count_of_all_views_with_non_zero_pose_translations + 1e-8
+ ) # [B]
+
+ # Normalize the pose translations by the norm factor
+ norm_factor = norm_factor.clip(min=1e-8)
+ normalized_pose_translations = pose_translations / norm_factor.unsqueeze(
+ -1
+ ).unsqueeze(-1)
+
+ # Create the output tuple
+ output = (
+ (normalized_pose_translations, norm_factor)
+ if return_norm_factor
+ else normalized_pose_translations
+ )
+
+ return output
+
+
+def normalize_multiple_pointclouds(
+ pts_list, valid_masks=None, norm_mode="avg_dis", ret_factor=False
+):
+ """
+ Normalize multiple point clouds using a joint normalization strategy.
+
+ Args:
+ pts_list: List of point clouds, each with shape (..., H, W, 3) or (B, H, W, 3)
+ valid_masks: Optional list of masks indicating valid points in each point cloud
+ norm_mode: String in format "{norm}_{dis}" where:
+ - norm: Normalization strategy (currently only "avg" is supported)
+ - dis: Distance transformation ("dis" for raw distance, "log1p" for log(1+distance),
+ "warp-log1p" to warp points using log distance)
+ ret_factor: If True, return the normalization factor as the last element in the result list
+
+ Returns:
+ List of normalized point clouds with the same shapes as inputs.
+ If ret_factor is True, the last element is the normalization factor.
+ """
+ assert all(pts.ndim >= 3 and pts.shape[-1] == 3 for pts in pts_list)
+ if valid_masks is not None:
+ assert len(pts_list) == len(valid_masks)
+
+ norm_mode, dis_mode = norm_mode.split("_")
+
+ # Gather all points together (joint normalization)
+ nan_pts_list = [
+ invalid_to_zeros(pts, valid_masks[i], ndim=3)
+ if valid_masks
+ else invalid_to_zeros(pts, None, ndim=3)
+ for i, pts in enumerate(pts_list)
+ ]
+ all_pts = torch.cat([nan_pts for nan_pts, _ in nan_pts_list], dim=1)
+ nnz_list = [nnz for _, nnz in nan_pts_list]
+
+ # Compute distance to origin
+ all_dis = all_pts.norm(dim=-1)
+ if dis_mode == "dis":
+ pass # do nothing
+ elif dis_mode == "log1p":
+ all_dis = torch.log1p(all_dis)
+ elif dis_mode == "warp-log1p":
+ # Warp input points before normalizing them
+ log_dis = torch.log1p(all_dis)
+ warp_factor = log_dis / all_dis.clip(min=1e-8)
+ for i, pts in enumerate(pts_list):
+ H, W = pts.shape[1:-1]
+ pts_list[i] = pts * warp_factor[:, i * (H * W) : (i + 1) * (H * W)].view(
+ -1, H, W, 1
+ )
+ all_dis = log_dis
+ else:
+ raise ValueError(f"bad {dis_mode=}")
+
+ # Compute normalization factor
+ norm_factor = all_dis.sum(dim=1) / (sum(nnz_list) + 1e-8)
+ norm_factor = norm_factor.clip(min=1e-8)
+ while norm_factor.ndim < pts_list[0].ndim:
+ norm_factor.unsqueeze_(-1)
+
+ # Normalize points
+ res = [pts / norm_factor for pts in pts_list]
+ if ret_factor:
+ res.append(norm_factor)
+
+ return res
+
+
+def apply_log_to_norm(input_data):
+ """
+ Normalize the input data and apply a logarithmic transformation based on the normalization factor.
+
+ Args:
+ input_data (torch.Tensor): The input tensor to be normalized and transformed.
+
+ Returns:
+ torch.Tensor: The transformed tensor after normalization and logarithmic scaling.
+ """
+ org_d = input_data.norm(dim=-1, keepdim=True)
+ input_data = input_data / org_d.clip(min=1e-8)
+ input_data = input_data * torch.log1p(org_d)
+ return input_data
+
+
+def angle_diff_vec3(v1, v2, eps=1e-12):
+ """
+ Compute angle difference between 3D vectors.
+
+ Args:
+ v1: torch.Tensor of shape (..., 3)
+ v2: torch.Tensor of shape (..., 3)
+ eps: Small epsilon value for numerical stability
+
+ Returns:
+ torch.Tensor: Angle differences in radians
+ """
+ cross_norm = torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps
+ dot_prod = (v1 * v2).sum(dim=-1)
+ return torch.atan2(cross_norm, dot_prod)
+
+
+def angle_diff_vec3_numpy(v1: np.ndarray, v2: np.ndarray, eps: float = 1e-12):
+ """
+ Compute angle difference between 3D vectors using NumPy.
+
+ Args:
+ v1 (np.ndarray): First vector of shape (..., 3)
+ v2 (np.ndarray): Second vector of shape (..., 3)
+ eps (float, optional): Small epsilon value for numerical stability. Defaults to 1e-12.
+
+ Returns:
+ np.ndarray: Angle differences in radians
+ """
+ return np.arctan2(
+ np.linalg.norm(np.cross(v1, v2, axis=-1), axis=-1) + eps, (v1 * v2).sum(axis=-1)
+ )
+
+
+@no_warnings(category=RuntimeWarning)
+def points_to_normals(
+ point: np.ndarray, mask: np.ndarray = None, edge_threshold: float = None
+) -> np.ndarray:
+ """
+ Calculate normal map from point map. Value range is [-1, 1].
+
+ Args:
+ point (np.ndarray): shape (height, width, 3), point map
+ mask (optional, np.ndarray): shape (height, width), dtype=bool. Mask of valid depth pixels. Defaults to None.
+ edge_threshold (optional, float): threshold for the angle (in degrees) between the normal and the view direction. Defaults to None.
+
+ Returns:
+ normal (np.ndarray): shape (height, width, 3), normal map.
+ """
+ height, width = point.shape[-3:-1]
+ has_mask = mask is not None
+
+ if mask is None:
+ mask = np.ones_like(point[..., 0], dtype=bool)
+ mask_pad = np.zeros((height + 2, width + 2), dtype=bool)
+ mask_pad[1:-1, 1:-1] = mask
+ mask = mask_pad
+
+ pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype)
+ pts[1:-1, 1:-1, :] = point
+ up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :]
+ left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :]
+ down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :]
+ right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :]
+ normal = np.stack(
+ [
+ np.cross(up, left, axis=-1),
+ np.cross(left, down, axis=-1),
+ np.cross(down, right, axis=-1),
+ np.cross(right, up, axis=-1),
+ ]
+ )
+ normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
+
+ valid = (
+ np.stack(
+ [
+ mask[:-2, 1:-1] & mask[1:-1, :-2],
+ mask[1:-1, :-2] & mask[2:, 1:-1],
+ mask[2:, 1:-1] & mask[1:-1, 2:],
+ mask[1:-1, 2:] & mask[:-2, 1:-1],
+ ]
+ )
+ & mask[None, 1:-1, 1:-1]
+ )
+ if edge_threshold is not None:
+ view_angle = angle_diff_vec3_numpy(pts[None, 1:-1, 1:-1, :], normal)
+ view_angle = np.minimum(view_angle, np.pi - view_angle)
+ valid = valid & (view_angle < np.deg2rad(edge_threshold))
+
+ normal = (normal * valid[..., None]).sum(axis=0)
+ normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12)
+
+ if has_mask:
+ normal_mask = valid.any(axis=0)
+ normal = np.where(normal_mask[..., None], normal, 0)
+ return normal, normal_mask
+ else:
+ return normal
+
+
+def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1):
+ """
+ Create a sliding window view of the input array along a specified axis.
+
+ This function creates a memory-efficient view of the input array with sliding windows
+ of the specified size and stride. The window dimension is appended to the end of the
+ output array's shape. This is useful for operations like convolution, pooling, or
+ any analysis that requires examining local neighborhoods in the data.
+
+ Args:
+ x (np.ndarray): Input array with shape (..., axis_size, ...)
+ window_size (int): Size of the sliding window
+ stride (int): Stride of the sliding window (step size between consecutive windows)
+ axis (int, optional): Axis to perform sliding window over. Defaults to -1 (last axis)
+
+ Returns:
+ np.ndarray: View of the input array with shape (..., n_windows, ..., window_size),
+ where n_windows = (axis_size - window_size + 1) // stride
+
+ Raises:
+ AssertionError: If window_size is larger than the size of the specified axis
+
+ Example:
+ >>> x = np.array([1, 2, 3, 4, 5, 6])
+ >>> sliding_window_1d(x, window_size=3, stride=2)
+ array([[1, 2, 3],
+ [3, 4, 5]])
+ """
+ assert x.shape[axis] >= window_size, (
+ f"kernel_size ({window_size}) is larger than axis_size ({x.shape[axis]})"
+ )
+ axis = axis % x.ndim
+ shape = (
+ *x.shape[:axis],
+ (x.shape[axis] - window_size + 1) // stride,
+ *x.shape[axis + 1 :],
+ window_size,
+ )
+ strides = (
+ *x.strides[:axis],
+ stride * x.strides[axis],
+ *x.strides[axis + 1 :],
+ x.strides[axis],
+ )
+ x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)
+ return x_sliding
+
+
+def sliding_window_nd(
+ x: np.ndarray,
+ window_size: Tuple[int, ...],
+ stride: Tuple[int, ...],
+ axis: Tuple[int, ...],
+) -> np.ndarray:
+ """
+ Create sliding windows along multiple dimensions of the input array.
+
+ This function applies sliding_window_1d sequentially along multiple axes to create
+ N-dimensional sliding windows. This is useful for operations that need to examine
+ local neighborhoods in multiple dimensions simultaneously.
+
+ Args:
+ x (np.ndarray): Input array
+ window_size (Tuple[int, ...]): Size of the sliding window for each axis
+ stride (Tuple[int, ...]): Stride of the sliding window for each axis
+ axis (Tuple[int, ...]): Axes to perform sliding window over
+
+ Returns:
+ np.ndarray: Array with sliding windows along the specified dimensions.
+ The window dimensions are appended to the end of the shape.
+
+ Note:
+ The length of window_size, stride, and axis tuples must be equal.
+
+ Example:
+ >>> x = np.random.rand(10, 10)
+ >>> windows = sliding_window_nd(x, window_size=(3, 3), stride=(2, 2), axis=(-2, -1))
+ >>> # Creates 3x3 sliding windows with stride 2 in both dimensions
+ """
+ axis = [axis[i] % x.ndim for i in range(len(axis))]
+ for i in range(len(axis)):
+ x = sliding_window_1d(x, window_size[i], stride[i], axis[i])
+ return x
+
+
+def sliding_window_2d(
+ x: np.ndarray,
+ window_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]],
+ axis: Tuple[int, int] = (-2, -1),
+) -> np.ndarray:
+ """
+ Create 2D sliding windows over the input array.
+
+ Convenience function for creating 2D sliding windows, commonly used for image
+ processing operations like convolution, pooling, or patch extraction.
+
+ Args:
+ x (np.ndarray): Input array
+ window_size (Union[int, Tuple[int, int]]): Size of the 2D sliding window.
+ If int, same size is used for both dimensions.
+ stride (Union[int, Tuple[int, int]]): Stride of the 2D sliding window.
+ If int, same stride is used for both dimensions.
+ axis (Tuple[int, int], optional): Two axes to perform sliding window over.
+ Defaults to (-2, -1) (last two dimensions).
+
+ Returns:
+ np.ndarray: Array with 2D sliding windows. The window dimensions (height, width)
+ are appended to the end of the shape.
+
+ Example:
+ >>> image = np.random.rand(100, 100)
+ >>> patches = sliding_window_2d(image, window_size=8, stride=4)
+ >>> # Creates 8x8 patches with stride 4 from the image
+ """
+ if isinstance(window_size, int):
+ window_size = (window_size, window_size)
+ if isinstance(stride, int):
+ stride = (stride, stride)
+ return sliding_window_nd(x, window_size, stride, axis)
+
+
+def max_pool_1d(
+ x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1
+):
+ """
+ Perform 1D max pooling on the input array.
+
+ Max pooling reduces the dimensionality of the input by taking the maximum value
+ within each sliding window. This is commonly used in neural networks and signal
+ processing for downsampling and feature extraction.
+
+ Args:
+ x (np.ndarray): Input array
+ kernel_size (int): Size of the pooling kernel
+ stride (int): Stride of the pooling operation
+ padding (int, optional): Amount of padding to add on both sides. Defaults to 0.
+ axis (int, optional): Axis to perform max pooling over. Defaults to -1.
+
+ Returns:
+ np.ndarray: Max pooled array with reduced size along the specified axis
+
+ Note:
+ - For floating point arrays, padding is done with np.nan values
+ - For integer arrays, padding is done with the minimum value of the dtype
+ - np.nanmax is used to handle NaN values in the computation
+
+ Example:
+ >>> x = np.array([1, 3, 2, 4, 5, 1, 2])
+ >>> max_pool_1d(x, kernel_size=3, stride=2)
+ array([3, 5, 2])
+ """
+ axis = axis % x.ndim
+ if padding > 0:
+ fill_value = np.nan if x.dtype.kind == "f" else np.iinfo(x.dtype).min
+ padding_arr = np.full(
+ (*x.shape[:axis], padding, *x.shape[axis + 1 :]),
+ fill_value=fill_value,
+ dtype=x.dtype,
+ )
+ x = np.concatenate([padding_arr, x, padding_arr], axis=axis)
+ a_sliding = sliding_window_1d(x, kernel_size, stride, axis)
+ max_pool = np.nanmax(a_sliding, axis=-1)
+ return max_pool
+
+
+def max_pool_nd(
+ x: np.ndarray,
+ kernel_size: Tuple[int, ...],
+ stride: Tuple[int, ...],
+ padding: Tuple[int, ...],
+ axis: Tuple[int, ...],
+) -> np.ndarray:
+ """
+ Perform N-dimensional max pooling on the input array.
+
+ This function applies max_pool_1d sequentially along multiple axes to perform
+ multi-dimensional max pooling. This is useful for downsampling multi-dimensional
+ data while preserving the most important features.
+
+ Args:
+ x (np.ndarray): Input array
+ kernel_size (Tuple[int, ...]): Size of the pooling kernel for each axis
+ stride (Tuple[int, ...]): Stride of the pooling operation for each axis
+ padding (Tuple[int, ...]): Amount of padding for each axis
+ axis (Tuple[int, ...]): Axes to perform max pooling over
+
+ Returns:
+ np.ndarray: Max pooled array with reduced size along the specified axes
+
+ Note:
+ The length of kernel_size, stride, padding, and axis tuples must be equal.
+ Max pooling is applied sequentially along each axis in the order specified.
+
+ Example:
+ >>> x = np.random.rand(10, 10, 10)
+ >>> pooled = max_pool_nd(x, kernel_size=(2, 2, 2), stride=(2, 2, 2),
+ ... padding=(0, 0, 0), axis=(-3, -2, -1))
+ >>> # Reduces each dimension by half with 2x2x2 max pooling
+ """
+ for i in range(len(axis)):
+ x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i])
+ return x
+
+
+def max_pool_2d(
+ x: np.ndarray,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]],
+ padding: Union[int, Tuple[int, int]],
+ axis: Tuple[int, int] = (-2, -1),
+):
+ """
+ Perform 2D max pooling on the input array.
+
+ Convenience function for 2D max pooling, commonly used in computer vision
+ and image processing for downsampling images while preserving important features.
+
+ Args:
+ x (np.ndarray): Input array
+ kernel_size (Union[int, Tuple[int, int]]): Size of the 2D pooling kernel.
+ If int, same size is used for both dimensions.
+ stride (Union[int, Tuple[int, int]]): Stride of the 2D pooling operation.
+ If int, same stride is used for both dimensions.
+ padding (Union[int, Tuple[int, int]]): Amount of padding for both dimensions.
+ If int, same padding is used for both dimensions.
+ axis (Tuple[int, int], optional): Two axes to perform max pooling over.
+ Defaults to (-2, -1) (last two dimensions).
+
+ Returns:
+ np.ndarray: 2D max pooled array with reduced size along the specified axes
+
+ Example:
+ >>> image = np.random.rand(64, 64)
+ >>> pooled = max_pool_2d(image, kernel_size=2, stride=2, padding=0)
+ >>> # Reduces image size from 64x64 to 32x32 with 2x2 max pooling
+ """
+ if isinstance(kernel_size, Number):
+ kernel_size = (kernel_size, kernel_size)
+ if isinstance(stride, Number):
+ stride = (stride, stride)
+ if isinstance(padding, Number):
+ padding = (padding, padding)
+ axis = tuple(axis)
+ return max_pool_nd(x, kernel_size, stride, padding, axis)
+
+
+@no_warnings(category=RuntimeWarning)
+def depth_edge(
+ depth: np.ndarray,
+ atol: float = None,
+ rtol: float = None,
+ kernel_size: int = 3,
+ mask: np.ndarray = None,
+) -> np.ndarray:
+ """
+ Compute the edge mask from depth map. The edge is defined as the pixels whose neighbors have large difference in depth.
+
+ Args:
+ depth (np.ndarray): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+ Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype torch.bool
+ """
+ if mask is None:
+ diff = max_pool_2d(
+ depth, kernel_size, stride=1, padding=kernel_size // 2
+ ) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)
+ else:
+ diff = max_pool_2d(
+ np.where(mask, depth, -np.inf),
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ ) + max_pool_2d(
+ np.where(mask, -depth, -np.inf),
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+
+ edge = np.zeros_like(depth, dtype=bool)
+ if atol is not None:
+ edge |= diff > atol
+
+ if rtol is not None:
+ edge |= diff / depth > rtol
+ return edge
+
+
+def depth_aliasing(
+ depth: np.ndarray,
+ atol: float = None,
+ rtol: float = None,
+ kernel_size: int = 3,
+ mask: np.ndarray = None,
+) -> np.ndarray:
+ """
+ Compute the map that indicates the aliasing of x depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors.
+ Args:
+ depth (np.ndarray): shape (..., height, width), linear depth map
+ atol (float): absolute tolerance
+ rtol (float): relative tolerance
+
+ Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype torch.bool
+ """
+ if mask is None:
+ diff_max = (
+ max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth
+ )
+ diff_min = (
+ max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth
+ )
+ else:
+ diff_max = (
+ max_pool_2d(
+ np.where(mask, depth, -np.inf),
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+ - depth
+ )
+ diff_min = (
+ max_pool_2d(
+ np.where(mask, -depth, -np.inf),
+ kernel_size,
+ stride=1,
+ padding=kernel_size // 2,
+ )
+ + depth
+ )
+ diff = np.minimum(diff_max, diff_min)
+
+ edge = np.zeros_like(depth, dtype=bool)
+ if atol is not None:
+ edge |= diff > atol
+ if rtol is not None:
+ edge |= diff / depth > rtol
+ return edge
+
+
+@no_warnings(category=RuntimeWarning)
+def normals_edge(
+ normals: np.ndarray, tol: float, kernel_size: int = 3, mask: np.ndarray = None
+) -> np.ndarray:
+ """
+ Compute the edge mask from normal map.
+
+ Args:
+ normal (np.ndarray): shape (..., height, width, 3), normal map
+ tol (float): tolerance in degrees
+
+ Returns:
+ edge (np.ndarray): shape (..., height, width) of dtype torch.bool
+ """
+ assert normals.ndim >= 3 and normals.shape[-1] == 3, (
+ "normal should be of shape (..., height, width, 3)"
+ )
+ normals = normals / (np.linalg.norm(normals, axis=-1, keepdims=True) + 1e-12)
+
+ padding = kernel_size // 2
+ normals_window = sliding_window_2d(
+ np.pad(
+ normals,
+ (
+ *([(0, 0)] * (normals.ndim - 3)),
+ (padding, padding),
+ (padding, padding),
+ (0, 0),
+ ),
+ mode="edge",
+ ),
+ window_size=kernel_size,
+ stride=1,
+ axis=(-3, -2),
+ )
+ if mask is None:
+ angle_diff = np.arccos(
+ (normals[..., None, None] * normals_window).sum(axis=-3)
+ ).max(axis=(-2, -1))
+ else:
+ mask_window = sliding_window_2d(
+ np.pad(
+ mask,
+ (*([(0, 0)] * (mask.ndim - 3)), (padding, padding), (padding, padding)),
+ mode="edge",
+ ),
+ window_size=kernel_size,
+ stride=1,
+ axis=(-3, -2),
+ )
+ angle_diff = np.where(
+ mask_window,
+ np.arccos((normals[..., None, None] * normals_window).sum(axis=-3)),
+ 0,
+ ).max(axis=(-2, -1))
+
+ angle_diff = max_pool_2d(
+ angle_diff, kernel_size, stride=1, padding=kernel_size // 2
+ )
+ edge = angle_diff > np.deg2rad(tol)
+ return edge
diff --git a/mapanything/utils/hf_utils/__init__.py b/mapanything/utils/hf_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mapanything/utils/hf_utils/css_and_html.py b/mapanything/utils/hf_utils/css_and_html.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9e0f9c44d4176eda37856775f0acd7b08e25ea9
--- /dev/null
+++ b/mapanything/utils/hf_utils/css_and_html.py
@@ -0,0 +1,207 @@
+"""
+CSS and HTML content for the MapAnything Gradio application.
+This module contains all the CSS styles and HTML content blocks
+used in the Gradio interface.
+"""
+
+# CSS Styles for the Gradio interface
+GRADIO_CSS = """
+.custom-log * {
+ font-style: italic;
+ font-size: 22px !important;
+ background-image: linear-gradient(120deg, #ffb366 0%, #ffa366 60%, #ff9966 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ font-weight: bold !important;
+ color: transparent !important;
+ text-align: center !important;
+}
+
+.example-log * {
+ font-style: italic;
+ font-size: 16px !important;
+ background-image: linear-gradient(120deg, #ffb366 0%, #ffa366 60%, #ff9966 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ color: transparent !important;
+}
+
+#my_radio .wrap {
+ display: flex;
+ flex-wrap: nowrap;
+ justify-content: center;
+ align-items: center;
+}
+
+#my_radio .wrap label {
+ display: flex;
+ width: 50%;
+ justify-content: center;
+ align-items: center;
+ margin: 0;
+ padding: 10px 0;
+ box-sizing: border-box;
+}
+
+/* Align navigation buttons with dropdown bottom */
+.navigation-row {
+ display: flex !important;
+ align-items: flex-end !important;
+ gap: 8px !important;
+}
+
+.navigation-row > div:nth-child(1),
+.navigation-row > div:nth-child(3) {
+ align-self: flex-end !important;
+}
+
+.navigation-row > div:nth-child(2) {
+ flex: 1 !important;
+}
+
+/* Make thumbnails clickable with pointer cursor */
+.clickable-thumbnail img {
+ cursor: pointer !important;
+}
+
+.clickable-thumbnail:hover img {
+ cursor: pointer !important;
+ opacity: 0.8;
+ transition: opacity 0.3s ease;
+}
+
+/* Make thumbnail containers narrower horizontally */
+.clickable-thumbnail {
+ padding: 5px 2px !important;
+ margin: 0 2px !important;
+}
+
+.clickable-thumbnail .image-container {
+ margin: 0 !important;
+ padding: 0 !important;
+}
+
+.scene-info {
+ text-align: center !important;
+ padding: 5px 2px !important;
+ margin: 0 !important;
+}
+"""
+
+
+def get_header_html(logo_base64=None):
+ """
+ Generate the main header HTML with logo and title.
+
+ Args:
+ logo_base64 (str, optional): Base64 encoded logo image
+
+ Returns:
+ str: HTML string for the header
+ """
+ logo_style = "display: none;" if not logo_base64 else ""
+ logo_src = logo_base64 or ""
+
+ return f"""
+
+

+
MapAnything: Metric 3D Scene Reconstruction
+
+
+ 🌟 GitHub Repository |
+ 🚀 Project Page
+
+ """
+
+
+def get_description_html():
+ """
+ Generate the main description and getting started HTML.
+
+ Returns:
+ str: HTML string for the description
+ """
+ return """
+
+
Upload a video or a set of images to create a 3D reconstruction of a scene or object. MapAnything takes these images and generates 3D point clouds directly from multi-view images.
+
This demo demonstrates the use of image inputs only. However, MapAnything is extremely flexible and supports any combination of inputs (images, calibration, poses & depth). For trying out memory efficient inference or additional inputs like cameras & depth, please check out the code in our Github repo.
+
+
Getting Started:
+
+ - Upload Your Data: Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).
+ - Preview: Your uploaded images will appear in the gallery on the left.
+ - Reconstruct: Click the "Reconstruct" button to start the 3D reconstruction process.
+ - Visualize: The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.
+ -
+ Adjust Reconstruction & Visualization (Optional):
+ You can fine-tune the visualization using the options below the viewer
+
+ (click to expand):
+
+ - Show Camera: Toggle the display of estimated camera positions.
+ - Show Mesh: Use meshes for the prediction visualization.
+ - Show Points from Frame: Select specific frames to display in the viewer.
+ - Filter Black Background: Remove black background pixels.
+ - Filter White Background: Remove white background pixels.
+
+
+
+
+
+
Please note: The inference time changes based on the amount of input images, for e.g., less than 1 second for up to 50 views. However, downloading model weights and visualizing 3D points may take tens of seconds. Please be patient or, for faster visualization, use a local machine to run our demo from our GitHub repository.
+
+ """
+
+
+def get_acknowledgements_html():
+ """
+ Generate the acknowledgements section HTML.
+
+ Returns:
+ str: HTML string for the acknowledgements
+ """
+ return """
+
+
+
Acknowledgements
+
This site builds upon code from:
+
+
We extend our gratitude to these projects for their valuable contributions to the research community.
+
+ """
+
+
+def get_gradio_theme():
+ """
+ Get the configured Gradio theme.
+
+ Returns:
+ gr.themes.Base: Configured Gradio theme
+ """
+ import gradio as gr
+
+ return gr.themes.Base(
+ primary_hue=gr.themes.Color(
+ c100="#ffedd5",
+ c200="#ffddb3",
+ c300="rgba(242.78125, 182.89427563548466, 120.32579495614034, 1)",
+ c400="#fb923c",
+ c50="#fff7ed",
+ c500="#f97316",
+ c600="#ea580c",
+ c700="#c2410c",
+ c800="#9a3412",
+ c900="#7c2d12",
+ c950="#6c2e12",
+ ),
+ secondary_hue="amber",
+ )
+
+
+# Measure tab instructions HTML
+MEASURE_INSTRUCTIONS_HTML = """
+### Click on the image to measure the distance between two points.
+"""
diff --git a/mapanything/utils/hf_utils/hf_helpers.py b/mapanything/utils/hf_utils/hf_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..62eb1aed3044cdf063407bdd49090f2f541304b2
--- /dev/null
+++ b/mapanything/utils/hf_utils/hf_helpers.py
@@ -0,0 +1,186 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Helper functions for HuggingFace integration and model initialization.
+"""
+
+import json
+import os
+
+
+def load_hf_token():
+ """Load HuggingFace access token from local file"""
+ token_file_paths = [
+ "/home/aknapitsch/hf_token.txt",
+ ]
+
+ for token_path in token_file_paths:
+ if os.path.exists(token_path):
+ try:
+ with open(token_path, "r") as f:
+ token = f.read().strip()
+ print(f"Loaded HuggingFace token from: {token_path}")
+ return token
+ except Exception as e:
+ print(f"Error reading token from {token_path}: {e}")
+ continue
+ else:
+ print(token_path, "token_path doesnt exist")
+
+ # Also try environment variable
+ # see https://huggingface.co/docs/hub/spaces-overview#managing-secrets on options
+ token = (
+ os.getenv("HF_TOKEN")
+ or os.getenv("HUGGING_FACE_HUB_TOKEN")
+ or os.getenv("HUGGING_FACE_MODEL_TOKEN")
+ )
+ if token:
+ print("Loaded HuggingFace token from environment variable")
+ return token
+
+ print(
+ "Warning: No HuggingFace token found. Model loading may fail for private repositories."
+ )
+ return None
+
+
+def init_hydra_config(config_path, overrides=None):
+ """Initialize Hydra config"""
+ import hydra
+
+ config_dir = os.path.dirname(config_path)
+ config_name = os.path.basename(config_path).split(".")[0]
+ relative_path = os.path.relpath(config_dir, os.path.dirname(__file__))
+ hydra.core.global_hydra.GlobalHydra.instance().clear()
+ hydra.initialize(version_base=None, config_path=relative_path)
+ if overrides is not None:
+ cfg = hydra.compose(config_name=config_name, overrides=overrides)
+ else:
+ cfg = hydra.compose(config_name=config_name)
+ return cfg
+
+
+def initialize_mapanything_model(high_level_config, device):
+ """
+ Initialize MapAnything model with three-tier fallback approach:
+ 1. Try HuggingFace from_pretrained()
+ 2. Download HF config + use local model factory + load HF weights
+ 3. Pure local configuration fallback
+
+ Args:
+ high_level_config (dict): Configuration dictionary containing model settings
+ device (torch.device): Device to load the model on
+
+ Returns:
+ torch.nn.Module: Initialized MapAnything model
+ """
+ import torch
+ from huggingface_hub import hf_hub_download
+ from mapanything.models import init_model, MapAnything
+
+ print("Initializing MapAnything model...")
+
+ # Initialize Hydra config and create model from configuration
+ cfg = init_hydra_config(
+ high_level_config["path"], overrides=high_level_config["config_overrides"]
+ )
+
+ # Try using from_pretrained first
+ try:
+ print("Loading MapAnything model from_pretrained...")
+ model = MapAnything.from_pretrained(high_level_config["hf_model_name"]).to(
+ device
+ )
+ print("Loading MapAnything model from_pretrained succeeded...")
+ return model
+ except Exception as e:
+ print(f"from_pretrained failed: {e}")
+ print("Falling back to local configuration approach using hf_hub_download...")
+
+ # Create model from local configuration instead of using from_pretrained
+ # Try to download and use the config from HuggingFace Hub
+ try:
+ print("Downloading model configuration from HuggingFace Hub...")
+ config_path = hf_hub_download(
+ repo_id=high_level_config["hf_model_name"],
+ filename=high_level_config["config_name"],
+ token=load_hf_token(),
+ )
+
+ # Load the config from the downloaded file
+ with open(config_path, "r") as f:
+ downloaded_config = json.load(f)
+
+ print("Using downloaded configuration for model initialization")
+ model = init_model(
+ model_str=downloaded_config.get(
+ "model_str", high_level_config["model_str"]
+ ),
+ model_config=downloaded_config.get(
+ "model_config", cfg.model.model_config
+ ),
+ torch_hub_force_reload=high_level_config.get(
+ "torch_hub_force_reload", False
+ ),
+ )
+ except Exception as config_e:
+ print(f"Failed to download/use HuggingFace config: {config_e}")
+ print("Falling back to local configuration...")
+ # Fall back to local configuration as before
+ model = init_model(
+ model_str=cfg.model.model_str,
+ model_config=cfg.model.model_config,
+ torch_hub_force_reload=high_level_config.get(
+ "torch_hub_force_reload", False
+ ),
+ )
+
+ # Load the pretrained weights from HuggingFace Hub
+ try:
+ # First, let's see what files are available in the repository
+ try:
+ checkpoint_filename = high_level_config["checkpoint_name"]
+ # Download the model weights
+ checkpoint_path = hf_hub_download(
+ repo_id=high_level_config["hf_model_name"],
+ filename=checkpoint_filename,
+ token=load_hf_token(),
+ )
+
+ # Load the weights
+ print("start loading checkpoint")
+ if checkpoint_filename.endswith(".safetensors"):
+ from safetensors.torch import load_file
+
+ checkpoint = load_file(checkpoint_path)
+ else:
+ checkpoint = torch.load(
+ checkpoint_path, map_location="cpu", weights_only=False
+ )
+
+ print("start loading state_dict")
+ if "model" in checkpoint:
+ model.load_state_dict(checkpoint["model"], strict=False)
+ elif "state_dict" in checkpoint:
+ model.load_state_dict(checkpoint["state_dict"], strict=False)
+ else:
+ model.load_state_dict(checkpoint, strict=False)
+
+ print(
+ f"Successfully loaded pretrained weights from HuggingFace Hub ({checkpoint_filename})"
+ )
+
+ except Exception as inner_e:
+ print(f"Error listing repository files or loading weights: {inner_e}")
+ raise inner_e
+
+ except Exception as e:
+ print(f"Warning: Could not load pretrained weights: {e}")
+ print("Proceeding with randomly initialized model...")
+
+ model = model.to(device)
+ return model
diff --git a/mapanything/utils/hf_utils/visual_util.py b/mapanything/utils/hf_utils/visual_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..39c0d3798f0ad7b98db261bedb809fac13895ae5
--- /dev/null
+++ b/mapanything/utils/hf_utils/visual_util.py
@@ -0,0 +1,679 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import copy
+import os
+from typing import Tuple
+
+import cv2
+import matplotlib
+import numpy as np
+import requests
+import trimesh
+from scipy.spatial.transform import Rotation
+
+
+def remove_unreferenced_vertices(
+ faces: np.ndarray, *vertice_attrs, return_indices: bool = False
+) -> Tuple[np.ndarray, ...]:
+ """
+ Remove unreferenced vertices of a mesh.
+ Unreferenced vertices are removed, and the face indices are updated accordingly.
+
+ Args:
+ faces (np.ndarray): [T, P] face indices
+ *vertice_attrs: vertex attributes
+
+ Returns:
+ faces (np.ndarray): [T, P] face indices
+ *vertice_attrs: vertex attributes
+ indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None.
+ """
+ P = faces.shape[-1]
+ fewer_indices, inv_map = np.unique(faces, return_inverse=True)
+ faces = inv_map.astype(np.int32).reshape(-1, P)
+ ret = [faces]
+ for attr in vertice_attrs:
+ ret.append(attr[fewer_indices])
+ if return_indices:
+ ret.append(fewer_indices)
+ return tuple(ret)
+
+
+def triangulate(
+ faces: np.ndarray, vertices: np.ndarray = None, backslash: np.ndarray = None
+) -> np.ndarray:
+ """
+ Triangulate a polygonal mesh.
+
+ Args:
+ faces (np.ndarray): [L, P] polygonal faces
+ vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices.
+ If given, the triangulation is performed according to the distance
+ between vertices. Defaults to None.
+ backslash (np.ndarray, optional): [L] boolean array indicating
+ how to triangulate the quad faces. Defaults to None.
+
+ Returns:
+ (np.ndarray): [L * (P - 2), 3] triangular faces
+ """
+ if faces.shape[-1] == 3:
+ return faces
+ P = faces.shape[-1]
+ if vertices is not None:
+ assert faces.shape[-1] == 4, "now only support quad mesh"
+ if backslash is None:
+ backslash = np.linalg.norm(
+ vertices[faces[:, 0]] - vertices[faces[:, 2]], axis=-1
+ ) < np.linalg.norm(vertices[faces[:, 1]] - vertices[faces[:, 3]], axis=-1)
+ if backslash is None:
+ loop_indice = np.stack(
+ [
+ np.zeros(P - 2, dtype=int),
+ np.arange(1, P - 1, 1, dtype=int),
+ np.arange(2, P, 1, dtype=int),
+ ],
+ axis=1,
+ )
+ return faces[:, loop_indice].reshape((-1, 3))
+ else:
+ assert faces.shape[-1] == 4, "now only support quad mesh"
+ faces = np.where(
+ backslash[:, None],
+ faces[:, [0, 1, 2, 0, 2, 3]],
+ faces[:, [0, 1, 3, 3, 1, 2]],
+ ).reshape((-1, 3))
+ return faces
+
+
+def image_mesh(
+ *image_attrs: np.ndarray,
+ mask: np.ndarray = None,
+ tri: bool = False,
+ return_indices: bool = False,
+) -> Tuple[np.ndarray, ...]:
+ """
+ Get a mesh regarding image pixel uv coordinates as vertices and image grid as faces.
+
+ Args:
+ *image_attrs (np.ndarray): image attributes in shape (height, width, [channels])
+ mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None.
+
+ Returns:
+ faces (np.ndarray): faces connecting neighboring pixels. shape (T, 4) if tri is False, else (T, 3)
+ *vertex_attrs (np.ndarray): vertex attributes in corresponding order with input image_attrs
+ indices (np.ndarray, optional): indices of vertices in the original mesh
+ """
+ assert (len(image_attrs) > 0) or (
+ mask is not None
+ ), "At least one of image_attrs or mask should be provided"
+ height, width = next(image_attrs).shape[:2] if mask is None else mask.shape
+ assert all(
+ img.shape[:2] == (height, width) for img in image_attrs
+ ), "All image_attrs should have the same shape"
+
+ row_faces = np.stack(
+ [
+ np.arange(0, width - 1, dtype=np.int32),
+ np.arange(width, 2 * width - 1, dtype=np.int32),
+ np.arange(1 + width, 2 * width, dtype=np.int32),
+ np.arange(1, width, dtype=np.int32),
+ ],
+ axis=1,
+ )
+ faces = (
+ np.arange(0, (height - 1) * width, width, dtype=np.int32)[:, None, None]
+ + row_faces[None, :, :]
+ ).reshape((-1, 4))
+ if mask is None:
+ if tri:
+ faces = triangulate(faces)
+ ret = [faces, *(img.reshape(-1, *img.shape[2:]) for img in image_attrs)]
+ if return_indices:
+ ret.append(np.arange(height * width, dtype=np.int32))
+ return tuple(ret)
+ else:
+ quad_mask = (
+ mask[:-1, :-1] & mask[1:, :-1] & mask[1:, 1:] & mask[:-1, 1:]
+ ).ravel()
+ faces = faces[quad_mask]
+ if tri:
+ faces = triangulate(faces)
+ return remove_unreferenced_vertices(
+ faces,
+ *(x.reshape(-1, *x.shape[2:]) for x in image_attrs),
+ return_indices=return_indices,
+ )
+
+
+def predictions_to_glb(
+ predictions,
+ filter_by_frames="all",
+ mask_black_bg=False,
+ mask_white_bg=False,
+ show_cam=True,
+ mask_ambiguous=False,
+ as_mesh=True,
+) -> trimesh.Scene:
+ """
+ Converts MapAnything predictions to a 3D scene represented as a GLB file.
+
+ Args:
+ predictions (dict): Dictionary containing model predictions with keys:
+ - world_points: 3D point coordinates (S, H, W, 3)
+ - images: Input images (S, H, W, 3)
+ - extrinsic: Camera extrinsic matrices (S, 3, 4)
+ filter_by_frames (str): Frame filter specification (default: "all")
+ mask_black_bg (bool): Mask out black background pixels (default: False)
+ mask_white_bg (bool): Mask out white background pixels (default: False)
+ show_cam (bool): Include camera visualization (default: True)
+ mask_ambiguous (bool): Apply final mask to filter ambiguous predictions (default: False)
+ as_mesh (bool): Represent the data as a mesh instead of point cloud (default: False)
+
+ Returns:
+ trimesh.Scene: Processed 3D scene containing point cloud/mesh and cameras
+
+ Raises:
+ ValueError: If input predictions structure is invalid
+ """
+ if not isinstance(predictions, dict):
+ raise ValueError("predictions must be a dictionary")
+
+ print("Building GLB scene")
+ selected_frame_idx = None
+ if filter_by_frames != "all" and filter_by_frames != "All":
+ try:
+ # Extract the index part before the colon
+ selected_frame_idx = int(filter_by_frames.split(":")[0])
+ except (ValueError, IndexError):
+ pass
+
+ # Always use Pointmap Branch
+ print("Using Pointmap Branch")
+ if "world_points" not in predictions:
+ raise ValueError(
+ "world_points not found in predictions. Pointmap Branch requires 'world_points' key. "
+ "Depthmap and Camera branches have been removed."
+ )
+
+ pred_world_points = predictions["world_points"]
+
+ # Get images from predictions
+ images = predictions["images"]
+ # Use extrinsic matrices instead of pred_extrinsic_list
+ camera_matrices = predictions["extrinsic"]
+
+ if selected_frame_idx is not None:
+ pred_world_points = pred_world_points[selected_frame_idx][None]
+ images = images[selected_frame_idx][None]
+ camera_matrices = camera_matrices[selected_frame_idx][None]
+
+ vertices_3d = pred_world_points.reshape(-1, 3)
+ # Handle different image formats - check if images need transposing
+ if images.ndim == 4 and images.shape[1] == 3: # NCHW format
+ colors_rgb = np.transpose(images, (0, 2, 3, 1))
+ else: # Assume already in NHWC format
+ colors_rgb = images
+ colors_rgb = (colors_rgb.reshape(-1, 3) * 255).astype(np.uint8)
+
+ # Create mask for filtering
+ mask = np.ones(len(vertices_3d), dtype=bool)
+ final_mask = predictions["final_mask"].reshape(-1)
+
+ if mask_black_bg:
+ black_bg_mask = colors_rgb.sum(axis=1) >= 16
+ mask = mask & black_bg_mask
+
+ if mask_white_bg:
+ # Filter out white background pixels (RGB values close to white)
+ # Consider pixels white if all RGB values are above 240
+ white_bg_mask = (
+ (colors_rgb[:, 0] > 240)
+ & (colors_rgb[:, 1] > 240)
+ & (colors_rgb[:, 2] > 240)
+ )
+ mask = mask & ~white_bg_mask
+
+ # Use final_mask when mask_ambiguous is checked
+ if mask_ambiguous:
+ mask = mask & final_mask
+
+ vertices_3d = vertices_3d[mask].copy()
+ colors_rgb = colors_rgb[mask].copy()
+
+ if vertices_3d is None or np.asarray(vertices_3d).size == 0:
+ vertices_3d = np.array([[1, 0, 0]])
+ colors_rgb = np.array([[255, 255, 255]])
+ scene_scale = 1
+ else:
+ # Calculate the 5th and 95th percentiles along each axis
+ lower_percentile = np.percentile(vertices_3d, 5, axis=0)
+ upper_percentile = np.percentile(vertices_3d, 95, axis=0)
+
+ # Calculate the diagonal length of the percentile bounding box
+ scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
+
+ colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
+
+ # Initialize a 3D scene
+ scene_3d = trimesh.Scene()
+
+ # Add point cloud data to the scene
+ if as_mesh:
+ # Create mesh from pointcloud
+ # try:
+ if selected_frame_idx is not None:
+ # Single frame case - we can create a proper mesh
+ H, W = pred_world_points.shape[1:3]
+
+ # Get original unfiltered data for mesh creation
+ original_points = pred_world_points.reshape(H, W, 3)
+
+ # Reshape original image data properly
+ if images.ndim == 4 and images.shape[1] == 3: # NCHW format
+ original_image_colors = np.transpose(images[0], (1, 2, 0))
+ else: # Assume already in HWC format
+ original_image_colors = images[0]
+ original_image_colors *= 255
+ # Get original final mask
+ original_final_mask = predictions["final_mask"][selected_frame_idx].reshape(
+ H, W
+ )
+
+ # Create mask based on final mask
+ mask = original_final_mask
+
+ # Additional background masks if needed
+ if mask_black_bg:
+ black_bg_mask = original_image_colors.sum(axis=2) >= 16
+ mask = mask & black_bg_mask
+
+ if mask_white_bg:
+ white_bg_mask = ~(
+ (original_image_colors[:, :, 0] > 240)
+ & (original_image_colors[:, :, 1] > 240)
+ & (original_image_colors[:, :, 2] > 240)
+ )
+ mask = mask & white_bg_mask
+
+ # Check if normals are available in predictions
+ vertex_normals = None
+ if "normal" in predictions and predictions["normal"] is not None:
+ # Get normals for the selected frame
+ frame_normals = (
+ predictions["normal"][selected_frame_idx]
+ if selected_frame_idx is not None
+ else predictions["normal"][0]
+ )
+
+ # Create faces and vertices using image_mesh with normals support
+ faces, vertices, vertex_colors, vertex_normals = image_mesh(
+ original_points * np.array([1, -1, 1], dtype=np.float32),
+ original_image_colors / 255.0,
+ frame_normals * np.array([1, -1, 1], dtype=np.float32),
+ mask=mask,
+ tri=True,
+ return_indices=False,
+ )
+
+ # Apply coordinate transformations to normals
+ vertex_normals = vertex_normals * np.array([1, -1, 1], dtype=np.float32)
+ else:
+ # Create faces and vertices using image_mesh without normals
+ faces, vertices, vertex_colors = image_mesh(
+ original_points * np.array([1, -1, 1], dtype=np.float32),
+ original_image_colors / 255.0,
+ mask=mask,
+ tri=True,
+ return_indices=False,
+ )
+
+ # vertices = vertices * np.array([1, -1, 1], dtype=np.float32)
+
+ # Create trimesh object with optional normals
+ mesh_data = trimesh.Trimesh(
+ vertices=vertices * np.array([1, -1, 1], dtype=np.float32),
+ faces=faces,
+ vertex_colors=(vertex_colors * 255).astype(np.uint8),
+ vertex_normals=(vertex_normals if vertex_normals is not None else None),
+ process=False,
+ )
+ scene_3d.add_geometry(mesh_data)
+
+ else:
+ # Multi-frame case - create separate meshes for each frame
+ print("Creating mesh for multi-frame data...")
+
+ for frame_idx in range(pred_world_points.shape[0]):
+ H, W = pred_world_points.shape[1:3]
+
+ # Get data for this frame
+ frame_points = pred_world_points[frame_idx]
+ frame_final_mask = predictions["final_mask"][frame_idx]
+
+ # Get frame image
+ if images.ndim == 4 and images.shape[1] == 3: # NCHW format
+ frame_image = np.transpose(images[frame_idx], (1, 2, 0))
+ else: # Assume already in HWC format
+ frame_image = images[frame_idx]
+ frame_image *= 255
+ # Create mask for this frame using final_mask
+ mask = frame_final_mask
+
+ # Additional background masks if needed
+ if mask_black_bg:
+ black_bg_mask = frame_image.sum(axis=2) >= 16
+ mask = mask & black_bg_mask
+
+ if mask_white_bg:
+ white_bg_mask = ~(
+ (frame_image[:, :, 0] > 240)
+ & (frame_image[:, :, 1] > 240)
+ & (frame_image[:, :, 2] > 240)
+ )
+ mask = mask & white_bg_mask
+
+ # Create mesh for this frame
+ faces, vertices, vertex_colors = image_mesh(
+ frame_points * np.array([1, -1, 1], dtype=np.float32),
+ frame_image / 255.0,
+ mask=mask,
+ tri=True,
+ return_indices=False,
+ )
+
+ vertices = vertices * np.array([1, -1, 1], dtype=np.float32)
+ # Create trimesh object for this frame
+ frame_mesh = trimesh.Trimesh(
+ vertices=vertices,
+ faces=faces,
+ vertex_colors=(vertex_colors * 255).astype(np.uint8),
+ process=False,
+ )
+ scene_3d.add_geometry(frame_mesh)
+ else:
+ point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
+ scene_3d.add_geometry(point_cloud_data)
+
+ # Prepare 4x4 matrices for camera extrinsics
+ num_cameras = len(camera_matrices)
+
+ if show_cam:
+ # Add camera models to the scene
+ for i in range(num_cameras):
+ world_to_camera = camera_matrices[i]
+ camera_to_world = np.linalg.inv(world_to_camera)
+ rgba_color = colormap(i / num_cameras)
+ current_color = tuple(int(255 * x) for x in rgba_color[:3])
+
+ integrate_camera_into_scene(
+ scene_3d, world_to_camera, current_color, scene_scale
+ )
+
+ # Align scene to the observation of the first camera
+ scene_3d = apply_scene_alignment(scene_3d, camera_matrices)
+
+ print("GLB Scene built")
+ return scene_3d
+
+
+def integrate_camera_into_scene(
+ scene: trimesh.Scene,
+ transform: np.ndarray,
+ face_colors: tuple,
+ scene_scale: float,
+):
+ """
+ Integrates a fake camera mesh into the 3D scene.
+
+ Args:
+ scene (trimesh.Scene): The 3D scene to add the camera model.
+ transform (np.ndarray): Transformation matrix for camera positioning.
+ face_colors (tuple): Color of the camera face.
+ scene_scale (float): Scale of the scene.
+ """
+ scene_scale = 12
+ cam_width = scene_scale * 0.05
+ cam_height = scene_scale * 0.1
+ # cam_width = scene_scale * 0.05
+ # cam_height = scene_scale * 0.1
+
+ # Create cone shape for camera
+ rot_45_degree = np.eye(4)
+ rot_45_degree[:3, :3] = Rotation.from_euler("z", 45, degrees=True).as_matrix()
+ rot_45_degree[2, 3] = -cam_height
+
+ opengl_transform = get_opengl_conversion_matrix()
+ # Combine transformations
+ complete_transform = transform @ opengl_transform @ rot_45_degree
+ camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)
+
+ # Generate mesh for the camera
+ slight_rotation = np.eye(4)
+ slight_rotation[:3, :3] = Rotation.from_euler("z", 2, degrees=True).as_matrix()
+
+ vertices_combined = np.concatenate(
+ [
+ camera_cone_shape.vertices,
+ 0.95 * camera_cone_shape.vertices,
+ transform_points(slight_rotation, camera_cone_shape.vertices),
+ ]
+ )
+ vertices_transformed = transform_points(complete_transform, vertices_combined)
+
+ mesh_faces = compute_camera_faces(camera_cone_shape)
+
+ # Add the camera mesh to the scene
+ camera_mesh = trimesh.Trimesh(vertices=vertices_transformed, faces=mesh_faces)
+ camera_mesh.visual.face_colors[:, :3] = face_colors
+ scene.add_geometry(camera_mesh)
+
+
+def apply_scene_alignment(
+ scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray
+) -> trimesh.Scene:
+ """
+ Aligns the 3D scene based on the extrinsics of the first camera.
+
+ Args:
+ scene_3d (trimesh.Scene): The 3D scene to be aligned.
+ extrinsics_matrices (np.ndarray): Camera extrinsic matrices.
+
+ Returns:
+ trimesh.Scene: Aligned 3D scene.
+ """
+ # Set transformations for scene alignment
+ opengl_conversion_matrix = get_opengl_conversion_matrix()
+
+ # Rotation matrix for alignment (180 degrees around the y-axis)
+ align_rotation = np.eye(4)
+ align_rotation[:3, :3] = Rotation.from_euler("y", 0, degrees=True).as_matrix()
+
+ # Apply transformation
+ initial_transformation = (
+ np.linalg.inv(extrinsics_matrices[0])
+ @ opengl_conversion_matrix
+ @ align_rotation
+ )
+ scene_3d.apply_transform(initial_transformation)
+ return scene_3d
+
+
+def get_opengl_conversion_matrix() -> np.ndarray:
+ """
+ Constructs and returns the OpenGL conversion matrix.
+
+ Returns:
+ numpy.ndarray: A 4x4 OpenGL conversion matrix.
+ """
+ # Create an identity matrix
+ matrix = np.identity(4)
+
+ # Flip the y and z axes
+ matrix[1, 1] = -1
+ matrix[2, 2] = -1
+
+ return matrix
+
+
+def transform_points(
+ transformation: np.ndarray, points: np.ndarray, dim: int = None
+) -> np.ndarray:
+ """
+ Applies a 4x4 transformation to a set of points.
+
+ Args:
+ transformation (np.ndarray): Transformation matrix.
+ points (np.ndarray): Points to be transformed.
+ dim (int, optional): Dimension for reshaping the result.
+
+ Returns:
+ np.ndarray: Transformed points.
+ """
+ points = np.asarray(points)
+ initial_shape = points.shape[:-1]
+ dim = dim or points.shape[-1]
+
+ # Apply transformation
+ transformation = transformation.swapaxes(
+ -1, -2
+ ) # Transpose the transformation matrix
+ points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]
+
+ # Reshape the result
+ result = points[..., :dim].reshape(*initial_shape, dim)
+ return result
+
+
+def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
+ """
+ Computes the faces for the camera mesh.
+
+ Args:
+ cone_shape (trimesh.Trimesh): The shape of the camera cone.
+
+ Returns:
+ np.ndarray: Array of faces for the camera mesh.
+ """
+ # Create pseudo cameras
+ faces_list = []
+ num_vertices_cone = len(cone_shape.vertices)
+
+ for face in cone_shape.faces:
+ if 0 in face:
+ continue
+ v1, v2, v3 = face
+ v1_offset, v2_offset, v3_offset = face + num_vertices_cone
+ v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
+
+ faces_list.extend(
+ [
+ (v1, v2, v2_offset),
+ (v1, v1_offset, v3),
+ (v3_offset, v2, v3),
+ (v1, v2, v2_offset_2),
+ (v1, v1_offset_2, v3),
+ (v3_offset_2, v2, v3),
+ ]
+ )
+
+ faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
+ return np.array(faces_list)
+
+
+def segment_sky(image_path, onnx_session, mask_filename=None):
+ """
+ Segments sky from an image using an ONNX model.
+ Thanks for the great model provided by https://github.com/xiongzhu666/Sky-Segmentation-and-Post-processing
+
+ Args:
+ image_path: Path to input image
+ onnx_session: ONNX runtime session with loaded model
+ mask_filename: Path to save the output mask
+
+ Returns:
+ np.ndarray: Binary mask where 255 indicates non-sky regions
+ """
+
+ assert mask_filename is not None
+ image = cv2.imread(image_path)
+
+ result_map = run_skyseg(onnx_session, [320, 320], image)
+ # resize the result_map to the original image size
+ result_map_original = cv2.resize(result_map, (image.shape[1], image.shape[0]))
+
+ # Fix: Invert the mask so that 255 = non-sky, 0 = sky
+ # The model outputs low values for sky, high values for non-sky
+ output_mask = np.zeros_like(result_map_original)
+ output_mask[result_map_original < 32] = 255 # Use threshold of 32
+
+ os.makedirs(os.path.dirname(mask_filename), exist_ok=True)
+ cv2.imwrite(mask_filename, output_mask)
+ return output_mask
+
+
+def run_skyseg(onnx_session, input_size, image):
+ """
+ Runs sky segmentation inference using ONNX model.
+
+ Args:
+ onnx_session: ONNX runtime session
+ input_size: Target size for model input (width, height)
+ image: Input image in BGR format
+
+ Returns:
+ np.ndarray: Segmentation mask
+ """
+
+ # Pre process:Resize, BGR->RGB, Transpose, PyTorch standardization, float32 cast
+ temp_image = copy.deepcopy(image)
+ resize_image = cv2.resize(temp_image, dsize=(input_size[0], input_size[1]))
+ x = cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB)
+ x = np.array(x, dtype=np.float32)
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+ x = (x / 255 - mean) / std
+ x = x.transpose(2, 0, 1)
+ x = x.reshape(-1, 3, input_size[0], input_size[1]).astype("float32")
+
+ # Inference
+ input_name = onnx_session.get_inputs()[0].name
+ output_name = onnx_session.get_outputs()[0].name
+ onnx_result = onnx_session.run([output_name], {input_name: x})
+
+ # Post process
+ onnx_result = np.array(onnx_result).squeeze()
+ min_value = np.min(onnx_result)
+ max_value = np.max(onnx_result)
+ onnx_result = (onnx_result - min_value) / (max_value - min_value)
+ onnx_result *= 255
+ onnx_result = onnx_result.astype("uint8")
+
+ return onnx_result
+
+
+def download_file_from_url(url, filename):
+ """Downloads a file from a Hugging Face model repo, handling redirects."""
+ try:
+ # Get the redirect URL
+ response = requests.get(url, allow_redirects=False)
+ response.raise_for_status() # Raise HTTPError for bad requests (4xx or 5xx)
+
+ if response.status_code == 302: # Expecting a redirect
+ redirect_url = response.headers["Location"]
+ response = requests.get(redirect_url, stream=True)
+ response.raise_for_status()
+ else:
+ print(f"Unexpected status code: {response.status_code}")
+ return
+
+ with open(filename, "wb") as f:
+ for chunk in response.iter_content(chunk_size=8192):
+ f.write(chunk)
+ print(f"Downloaded {filename} successfully.")
+
+ except requests.exceptions.RequestException as e:
+ print(f"Error downloading file: {e}")
diff --git a/mapanything/utils/image.py b/mapanything/utils/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc097c01148aa58e09734b0433d8a102ce1bf079
--- /dev/null
+++ b/mapanything/utils/image.py
@@ -0,0 +1,326 @@
+"""
+Utility functions for loading, converting, and manipulating images.
+
+This module provides functions for:
+- Converting between different image formats and representations
+- Resizing and cropping images to specific resolutions
+- Loading and normalizing images for model input
+- Handling various image file formats including HEIF/HEIC when available
+"""
+
+import os
+
+import numpy as np
+import PIL.Image
+import torch
+import torchvision.transforms as tvf
+from PIL.ImageOps import exif_transpose
+
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+import cv2
+
+try:
+ from pillow_heif import register_heif_opener
+
+ register_heif_opener()
+ heif_support_enabled = True
+except ImportError:
+ heif_support_enabled = False
+
+from mapanything.utils.cropping import crop_resize_if_necessary
+from uniception.models.encoders.image_normalizations import IMAGE_NORMALIZATION_DICT
+
+# Fixed resolution mappings with precomputed aspect ratios as keys
+RESOLUTION_MAPPINGS = {
+ 518: {
+ 1.000: (518, 518), # 1:1
+ 1.321: (518, 392), # 4:3
+ 1.542: (518, 336), # 3:2
+ 1.762: (518, 294), # 16:9
+ 2.056: (518, 252), # 2:1
+ 3.083: (518, 168), # 3.2:1
+ 0.757: (392, 518), # 3:4
+ 0.649: (336, 518), # 2:3
+ 0.567: (294, 518), # 9:16
+ 0.486: (252, 518), # 1:2
+ },
+ 512: {
+ 1.000: (512, 512), # 1:1
+ 1.333: (512, 384), # 4:3
+ 1.524: (512, 336), # 3:2
+ 1.778: (512, 288), # 16:9
+ 2.000: (512, 256), # 2:1
+ 3.200: (512, 160), # 3.2:1
+ 0.750: (384, 512), # 3:4
+ 0.656: (336, 512), # 2:3
+ 0.562: (288, 512), # 9:16
+ 0.500: (256, 512), # 1:2
+ },
+}
+
+# Precomputed sorted aspect ratio keys for efficient lookup
+ASPECT_RATIO_KEYS = {
+ 518: sorted(RESOLUTION_MAPPINGS[518].keys()),
+ 512: sorted(RESOLUTION_MAPPINGS[512].keys()),
+}
+
+
+def find_closest_aspect_ratio(aspect_ratio, resolution_set):
+ """
+ Find the closest aspect ratio from the resolution mappings using efficient key lookup.
+
+ Args:
+ aspect_ratio (float): Target aspect ratio
+ resolution_set (int): Resolution set to use (518 or 512)
+
+ Returns:
+ tuple: (target_width, target_height) from the resolution mapping
+ """
+ aspect_keys = ASPECT_RATIO_KEYS[resolution_set]
+
+ # Find the closest aspect ratio key using binary search approach
+ closest_key = min(aspect_keys, key=lambda x: abs(x - aspect_ratio))
+
+ return RESOLUTION_MAPPINGS[resolution_set][closest_key]
+
+
+def rgb(ftensor, norm_type, true_shape=None):
+ """
+ Convert normalized image tensor to RGB image for visualization.
+
+ Args:
+ ftensor (torch.Tensor or numpy.ndarray or list): Image tensor or list of image tensors
+ norm_type (str): Normalization type, see UniCeption IMAGE_NORMALIZATION_DICT keys or use "identity"
+ true_shape (tuple, optional): If provided, the image will be cropped to this shape (H, W)
+
+ Returns:
+ numpy.ndarray: RGB image with values in range [0, 1]
+ """
+ if isinstance(ftensor, list):
+ return [rgb(x, norm_type, true_shape=true_shape) for x in ftensor]
+ if isinstance(ftensor, torch.Tensor):
+ ftensor = ftensor.detach().cpu().numpy() # H,W,3
+ if ftensor.ndim == 3 and ftensor.shape[0] == 3:
+ ftensor = ftensor.transpose(1, 2, 0)
+ elif ftensor.ndim == 4 and ftensor.shape[1] == 3:
+ ftensor = ftensor.transpose(0, 2, 3, 1)
+ if true_shape is not None:
+ H, W = true_shape
+ ftensor = ftensor[:H, :W]
+ if ftensor.dtype == np.uint8:
+ img = np.float32(ftensor) / 255
+ else:
+ if norm_type in IMAGE_NORMALIZATION_DICT.keys():
+ img_norm = IMAGE_NORMALIZATION_DICT[norm_type]
+ mean = img_norm.mean.numpy()
+ std = img_norm.std.numpy()
+ elif norm_type == "identity":
+ mean = 0.0
+ std = 1.0
+ else:
+ raise ValueError(
+ f"Unknown image normalization type: {norm_type}. Available types: identity or {IMAGE_NORMALIZATION_DICT.keys()}"
+ )
+ img = ftensor * std + mean
+ return img.clip(min=0, max=1)
+
+
+def load_images(
+ folder_or_list,
+ resize_mode="fixed_mapping",
+ size=None,
+ norm_type="dinov2",
+ patch_size=14,
+ verbose=False,
+ bayer_format=False,
+ resolution_set=518,
+ stride=1,
+):
+ """
+ Open and convert all images in a list or folder to proper input format for model
+
+ Args:
+ folder_or_list (str or list): Path to folder or list of image paths.
+ resize_mode (str): Resize mode - "fixed_mapping", "longest_side", "square", or "fixed_size". Defaults to "fixed_mapping".
+ size (int or tuple, optional): Required for "longest_side", "square", and "fixed_size" modes.
+ - For "longest_side" and "square": int value for resize dimension
+ - For "fixed_size": tuple of (width, height)
+ norm_type (str, optional): Image normalization type. See UniCeption IMAGE_NORMALIZATION_DICT keys. Defaults to "dinov2".
+ patch_size (int, optional): Patch size for image processing. Defaults to 14.
+ verbose (bool, optional): If True, print progress messages. Defaults to False.
+ bayer_format (bool, optional): If True, read images in Bayer format. Defaults to False.
+ resolution_set (int, optional): Resolution set to use for "fixed_mapping" mode (518 or 512). Defaults to 518.
+ stride (int, optional): Load every nth image from the input. stride=1 loads all images, stride=2 loads every 2nd image, etc. Defaults to 1.
+
+ Returns:
+ list: List of dictionaries containing image data and metadata
+ """
+ # Validate resize_mode and size parameter requirements
+ valid_resize_modes = ["fixed_mapping", "longest_side", "square", "fixed_size"]
+ if resize_mode not in valid_resize_modes:
+ raise ValueError(
+ f"Resize_mode must be one of {valid_resize_modes}, got '{resize_mode}'"
+ )
+
+ if resize_mode in ["longest_side", "square", "fixed_size"] and size is None:
+ raise ValueError(f"Size parameter is required for resize_mode='{resize_mode}'")
+
+ # Validate size type based on resize mode
+ if resize_mode in ["longest_side", "square"]:
+ if not isinstance(size, int):
+ raise ValueError(
+ f"Size must be an int for resize_mode='{resize_mode}', got {type(size)}"
+ )
+ elif resize_mode == "fixed_size":
+ if not isinstance(size, (tuple, list)) or len(size) != 2:
+ raise ValueError(
+ f"Size must be a tuple/list of (width, height) for resize_mode='fixed_size', got {size}"
+ )
+ if not all(isinstance(x, int) for x in size):
+ raise ValueError(
+ f"Size values must be integers for resize_mode='fixed_size', got {size}"
+ )
+
+ # Get list of image paths
+ if isinstance(folder_or_list, str):
+ # If folder_or_list is a string, assume it's a path to a folder
+ if verbose:
+ print(f"Loading images from {folder_or_list}")
+ root, folder_content = folder_or_list, sorted(os.listdir(folder_or_list))
+ elif isinstance(folder_or_list, list):
+ # If folder_or_list is a list, assume it's a list of image paths
+ if verbose:
+ print(f"Loading a list of {len(folder_or_list)} images")
+ root, folder_content = "", folder_or_list
+ else:
+ # If folder_or_list is neither a string nor a list, raise an error
+ raise ValueError(f"Bad {folder_or_list=} ({type(folder_or_list)})")
+
+ # Define supported image extensions
+ supported_images_extensions = [".jpg", ".jpeg", ".png"]
+ if heif_support_enabled:
+ supported_images_extensions += [".heic", ".heif"]
+ supported_images_extensions = tuple(supported_images_extensions)
+
+ # First pass: Load all images and collect aspect ratios
+ loaded_images = []
+ aspect_ratios = []
+ for i, path in enumerate(folder_content):
+ # Skip images based on stride
+ if i % stride != 0:
+ continue
+
+ # Check if the file has a supported image extension
+ if not path.lower().endswith(supported_images_extensions):
+ continue
+
+ try:
+ if bayer_format:
+ # If bayer_format is True, read the image in Bayer format
+ color_bayer = cv2.imread(os.path.join(root, path), cv2.IMREAD_UNCHANGED)
+ color = cv2.cvtColor(color_bayer, cv2.COLOR_BAYER_RG2BGR)
+ img = PIL.Image.fromarray(color)
+ img = exif_transpose(img).convert("RGB")
+ else:
+ # Otherwise, read the image normally
+ img = exif_transpose(PIL.Image.open(os.path.join(root, path))).convert(
+ "RGB"
+ )
+
+ W1, H1 = img.size
+ aspect_ratios.append(W1 / H1)
+ loaded_images.append((path, img, W1, H1))
+
+ except Exception as e:
+ if verbose:
+ print(f"Warning: Could not load {path}: {e}")
+ continue
+
+ # Check if any images were loaded
+ if not loaded_images:
+ raise ValueError("No valid images found")
+
+ # Calculate average aspect ratio and determine target size
+ average_aspect_ratio = sum(aspect_ratios) / len(aspect_ratios)
+ if verbose:
+ print(
+ f"Calculated average aspect ratio: {average_aspect_ratio:.3f} from {len(aspect_ratios)} images"
+ )
+
+ # Determine target size for all images based on resize mode
+ if resize_mode == "fixed_mapping":
+ # Resolution mappings are already compatible with their respective patch sizes
+ # 518 mappings are divisible by 14, 512 mappings are divisible by 16
+ target_width, target_height = find_closest_aspect_ratio(
+ average_aspect_ratio, resolution_set
+ )
+ target_size = (target_width, target_height)
+ elif resize_mode == "square":
+ target_size = (
+ round((size // patch_size)) * patch_size,
+ round((size // patch_size)) * patch_size,
+ )
+ elif resize_mode == "longest_side":
+ # Use average aspect ratio to determine size for all images
+ # Longest side should be the input size
+ if average_aspect_ratio >= 1: # Landscape or square
+ # Width is the longest side
+ target_size = (
+ size,
+ round((size // patch_size) / average_aspect_ratio) * patch_size,
+ )
+ else: # Portrait
+ # Height is the longest side
+ target_size = (
+ round((size // patch_size) * average_aspect_ratio) * patch_size,
+ size,
+ )
+ elif resize_mode == "fixed_size":
+ # Use exact size provided, aligned to patch_size
+ target_size = (
+ (size[0] // patch_size) * patch_size,
+ (size[1] // patch_size) * patch_size,
+ )
+
+ if verbose:
+ print(
+ f"Using target resolution {target_size[0]}x{target_size[1]} (W x H) for all images"
+ )
+
+ # Get the image normalization function based on the norm_type
+ if norm_type in IMAGE_NORMALIZATION_DICT.keys():
+ img_norm = IMAGE_NORMALIZATION_DICT[norm_type]
+ ImgNorm = tvf.Compose(
+ [tvf.ToTensor(), tvf.Normalize(mean=img_norm.mean, std=img_norm.std)]
+ )
+ else:
+ raise ValueError(
+ f"Unknown image normalization type: {norm_type}. Available options: {list(IMAGE_NORMALIZATION_DICT.keys())}"
+ )
+
+ # Second pass: Resize all images to the same target size
+ imgs = []
+ for path, img, W1, H1 in loaded_images:
+ # Resize and crop the image to the target size
+ img = crop_resize_if_necessary(img, resolution=target_size)[0]
+
+ # Normalize image and add it to the list
+ W2, H2 = img.size
+ if verbose:
+ print(f" - Adding {path} with resolution {W1}x{H1} --> {W2}x{H2}")
+
+ imgs.append(
+ dict(
+ img=ImgNorm(img)[None],
+ true_shape=np.int32([img.size[::-1]]),
+ idx=len(imgs),
+ instance=str(len(imgs)),
+ data_norm_type=[norm_type],
+ )
+ )
+
+ assert imgs, "No images foud at " + root
+ if verbose:
+ print(f" (Found {len(imgs)} images)")
+
+ return imgs
diff --git a/mapanything/utils/inference.py b/mapanything/utils/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e6a84c46faf4e47e9344a0eae5ff86747e6f34b
--- /dev/null
+++ b/mapanything/utils/inference.py
@@ -0,0 +1,475 @@
+"""
+Inference utilities.
+"""
+
+import warnings
+from typing import Any, Dict, List
+
+import numpy as np
+import torch
+
+from mapanything.utils.geometry import (
+ depth_edge,
+ get_rays_in_camera_frame,
+ normals_edge,
+ points_to_normals,
+ quaternion_to_rotation_matrix,
+ recover_pinhole_intrinsics_from_ray_directions,
+ rotation_matrix_to_quaternion,
+)
+from mapanything.utils.image import rgb
+
+# Hard constraints - exactly what users can provide
+ALLOWED_VIEW_KEYS = {
+ "img", # Required - input images
+ "data_norm_type", # Required - normalization type of the input images
+ "depth_z", # Optional - Z depth maps
+ "ray_directions", # Optional - ray directions in camera frame
+ "intrinsics", # Optional - pinhole camera intrinsics (conflicts with ray_directions)
+ "camera_poses", # Optional - camera poses
+ "is_metric_scale", # Optional - whether inputs are metric scale
+ "true_shape", # Optional - original image shape
+ "idx", # Optional - index of the view
+ "instance", # Optional - instance info of the view
+}
+
+REQUIRED_KEYS = {"img", "data_norm_type"}
+
+# Define conflicting keys that cannot be used together
+CONFLICTING_KEYS = [
+ ("intrinsics", "ray_directions") # Both represent camera projection
+]
+
+
+def loss_of_one_batch_multi_view(
+ batch,
+ model,
+ criterion,
+ device,
+ use_amp=False,
+ amp_dtype="bf16",
+ ret=None,
+ ignore_keys=None,
+):
+ """
+ Calculate loss for a batch with multiple views.
+
+ Args:
+ batch (list): List of view dictionaries containing input data.
+ model (torch.nn.Module): Model to run inference with.
+ criterion (callable, optional): Loss function to compute the loss.
+ device (torch.device): Device to run the computation on.
+ use_amp (bool, optional): Whether to use automatic mixed precision. Defaults to False.
+ amp_dtype (str, optional): Floating point type to use for automatic mixed precision. Options: ["fp32", "fp16", "bf16"]. Defaults to "bf16".
+ ret (str, optional): If provided, return only the specified key from the result dictionary.
+ ignore_keys (set, optional): Set of keys to ignore when moving tensors to device.
+ Defaults to {"dataset", "label", "instance",
+ "idx", "true_shape", "rng", "data_norm_type"}.
+
+ Returns:
+ dict or Any: If ret is None, returns a dictionary containing views, predictions, and loss.
+ Otherwise, returns the value associated with the ret key.
+ """
+ # Move necessary tensors to device
+ if ignore_keys is None:
+ ignore_keys = set(
+ [
+ "depthmap",
+ "dataset",
+ "label",
+ "instance",
+ "idx",
+ "true_shape",
+ "rng",
+ "data_norm_type",
+ ]
+ )
+ for view in batch:
+ for name in view.keys():
+ if name in ignore_keys:
+ continue
+ view[name] = view[name].to(device, non_blocking=True)
+
+ # Determine the mixed precision floating point type
+ if use_amp:
+ if amp_dtype == "fp16":
+ amp_dtype = torch.float16
+ elif amp_dtype == "bf16":
+ if torch.cuda.is_bf16_supported():
+ amp_dtype = torch.bfloat16
+ else:
+ warnings.warn(
+ "bf16 is not supported on this device. Using fp16 instead."
+ )
+ amp_dtype = torch.float16
+ elif amp_dtype == "fp32":
+ amp_dtype = torch.float32
+ else:
+ amp_dtype = torch.float32
+
+ # Run model and compute loss
+ with torch.autocast("cuda", enabled=bool(use_amp), dtype=amp_dtype):
+ preds = model(batch)
+ with torch.autocast("cuda", enabled=False):
+ loss = criterion(batch, preds) if criterion is not None else None
+
+ result = {f"view{i + 1}": view for i, view in enumerate(batch)}
+ result.update({f"pred{i + 1}": pred for i, pred in enumerate(preds)})
+ result["loss"] = loss
+
+ return result[ret] if ret else result
+
+
+def validate_input_views_for_inference(
+ views: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ """
+ Strict validation and preprocessing of input views.
+
+ Args:
+ views: List of view dictionaries
+
+ Returns:
+ Validated and preprocessed views
+
+ Raises:
+ ValueError: For invalid keys, missing required keys, conflicting inputs, or invalid camera pose constraints
+ """
+ # Ensure input is not empty
+ if not views:
+ raise ValueError("At least one view must be provided")
+
+ # Track which views have camera poses
+ views_with_poses = []
+
+ # Validate each view
+ for view_idx, view in enumerate(views):
+ # Check for invalid keys
+ provided_keys = set(view.keys())
+ invalid_keys = provided_keys - ALLOWED_VIEW_KEYS
+ if invalid_keys:
+ raise ValueError(
+ f"View {view_idx} contains invalid keys: {invalid_keys}. "
+ f"Allowed keys are: {sorted(ALLOWED_VIEW_KEYS)}"
+ )
+
+ # Check for missing required keys
+ missing_keys = REQUIRED_KEYS - provided_keys
+ if missing_keys:
+ raise ValueError(f"View {view_idx} missing required keys: {missing_keys}")
+
+ # Check for conflicting keys
+ for conflict_set in CONFLICTING_KEYS:
+ present_conflicts = [key for key in conflict_set if key in provided_keys]
+ if len(present_conflicts) > 1:
+ raise ValueError(
+ f"View {view_idx} contains conflicting keys: {present_conflicts}. "
+ f"Only one of {conflict_set} can be provided at a time."
+ )
+
+ # Check depth constraint: If depth is provided, intrinsics or ray_directions must also be provided
+ if "depth_z" in provided_keys:
+ if (
+ "intrinsics" not in provided_keys
+ and "ray_directions" not in provided_keys
+ ):
+ raise ValueError(
+ f"View {view_idx} depth constraint violation: If 'depth_z' is provided, "
+ f"then 'intrinsics' or 'ray_directions' must also be provided. "
+ f"Z Depth values require camera calibration information to be meaningful for an image."
+ )
+
+ # Track views with camera poses
+ if "camera_poses" in provided_keys:
+ views_with_poses.append(view_idx)
+
+ # Cross-view constraint: If any view has camera_poses, view 0 must have them too
+ if views_with_poses and 0 not in views_with_poses:
+ raise ValueError(
+ f"Camera pose constraint violation: Views {views_with_poses} have camera_poses, "
+ f"but view 0 (reference view) does not. When using camera_poses, the first view "
+ f"must also provide camera_poses to serve as the reference frame."
+ )
+
+ return views
+
+
+def preprocess_input_views_for_inference(
+ views: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
+ """
+ Pre-process input views to match the expected internal input format.
+
+ The following steps are performed:
+ 1. Convert intrinsics to ray directions when required. If ray directions are already provided, unit normalize them.
+ 2. Convert depth_z to depth_along_ray
+ 3. Convert camera_poses to the expected input keys (camera_pose_quats and camera_pose_trans)
+ 4. Default is_metric_scale to True when not provided
+
+ Args:
+ views: List of view dictionaries
+
+ Returns:
+ Preprocessed views with consistent internal format
+ """
+ processed_views = []
+
+ for view_idx, view in enumerate(views):
+ # Copy the view dictionary to avoid modifying the original input
+ processed_view = dict(view)
+
+ # Step 1: Convert intrinsics to ray_directions when required. If ray_directions are provided, unit normalize them.
+ if "intrinsics" in view:
+ images = view["img"]
+ height, width = images.shape[-2:]
+ intrinsics = view["intrinsics"]
+ _, ray_directions = get_rays_in_camera_frame(
+ intrinsics=intrinsics,
+ height=height,
+ width=width,
+ normalize_to_unit_sphere=True,
+ )
+ processed_view["ray_directions"] = ray_directions
+ del processed_view["intrinsics"]
+ elif "ray_directions" in view:
+ ray_directions = view["ray_directions"]
+ ray_norm = torch.norm(ray_directions, dim=-1, keepdim=True)
+ processed_view["ray_directions"] = ray_directions / (ray_norm + 1e-8)
+
+ # Step 2: Convert depth_z to depth_along_ray
+ if "depth_z" in view:
+ depth_z = view["depth_z"]
+ ray_directions = processed_view["ray_directions"]
+ ray_directions_unit_plane = ray_directions / ray_directions[..., 2:3]
+ pts3d_cam = depth_z * ray_directions_unit_plane
+ depth_along_ray = torch.norm(pts3d_cam, dim=-1, keepdim=True)
+ processed_view["depth_along_ray"] = depth_along_ray
+ del processed_view["depth_z"]
+
+ # Step 3: Convert camera_poses to expected input keys
+ if "camera_poses" in view:
+ camera_poses = view["camera_poses"]
+ if isinstance(camera_poses, tuple) and len(camera_poses) == 2:
+ quats, trans = camera_poses
+ processed_view["camera_pose_quats"] = quats
+ processed_view["camera_pose_trans"] = trans
+ elif torch.is_tensor(camera_poses) and camera_poses.shape[-2:] == (4, 4):
+ rotation_matrices = camera_poses[:, :3, :3]
+ translation_vectors = camera_poses[:, :3, 3]
+ quats = rotation_matrix_to_quaternion(rotation_matrices)
+ processed_view["camera_pose_quats"] = quats
+ processed_view["camera_pose_trans"] = translation_vectors
+ else:
+ raise ValueError(
+ f"View {view_idx}: camera_poses must be either a tuple of (quats, trans) "
+ f"or a tensor of (B, 4, 4) transformation matrices."
+ )
+ del processed_view["camera_poses"]
+
+ # Step 4: Default is_metric_scale to True when not provided
+ if "is_metric_scale" not in processed_view:
+ # Get batch size from the image tensor
+ batch_size = view["img"].shape[0]
+ # Default to True for all samples in the batch
+ processed_view["is_metric_scale"] = torch.ones(
+ batch_size, dtype=torch.bool, device=view["img"].device
+ )
+
+ # Rename keys to match expected model input format
+ if "ray_directions" in processed_view:
+ processed_view["ray_directions_cam"] = processed_view["ray_directions"]
+ del processed_view["ray_directions"]
+
+ # Append the processed view to the list
+ processed_views.append(processed_view)
+
+ return processed_views
+
+
+def postprocess_model_outputs_for_inference(
+ raw_outputs: List[Dict[str, torch.Tensor]],
+ input_views: List[Dict[str, Any]],
+ apply_mask: bool = True,
+ mask_edges: bool = True,
+ edge_normal_threshold: float = 5.0,
+ edge_depth_threshold: float = 0.03,
+ apply_confidence_mask: bool = False,
+ confidence_percentile: float = 10,
+) -> List[Dict[str, torch.Tensor]]:
+ """
+ Post-process raw model outputs by copying raw outputs and adding essential derived fields.
+
+ This function simplifies the raw model outputs by:
+ 1. Copying all raw outputs as-is
+ 2. Adding denormalized images (img_no_norm)
+ 3. Adding Z depth (depth_z) from camera frame points
+ 4. Recovering pinhole camera intrinsics from ray directions
+ 5. Adding camera pose matrices (camera_poses) if pose data is available
+ 6. Applying mask to dense geometry outputs if requested (supports edge masking and confidence masking)
+
+ Args:
+ raw_outputs: List of raw model output dictionaries, one per view
+ input_views: List of original input view dictionaries, one per view
+ apply_mask: Whether to apply non-ambiguous mask to dense outputs. Defaults to True.
+ mask_edges: Whether to compute an edge mask based on normals and depth and apply it to the output. Defaults to True.
+ apply_confidence_mask: Whether to apply the confidence mask to the output. Defaults to False.
+ confidence_percentile: The percentile to use for the confidence threshold. Defaults to 10.
+
+ Returns:
+ List of processed output dictionaries containing:
+ - All original raw outputs (after masking dense geometry outputs if requested)
+ - 'img_no_norm': Denormalized RGB images (B, H, W, 3)
+ - 'depth_z': Z depth from camera frame (B, H, W, 1) if points in camera frame available
+ - 'intrinsics': Recovered pinhole camera intrinsics (B, 3, 3) if ray directions available
+ - 'camera_poses': 4x4 pose matrices (B, 4, 4) if pose data available
+ - 'mask': comprehensive mask for dense geometry outputs (B, H, W, 1) if requested
+
+ """
+ processed_outputs = []
+
+ for view_idx, (raw_output, original_view) in enumerate(
+ zip(raw_outputs, input_views)
+ ):
+ # Start by copying all raw outputs
+ processed_output = dict(raw_output)
+
+ # 1. Add denormalized images
+ img = original_view["img"] # Shape: (B, 3, H, W)
+ data_norm_type = original_view["data_norm_type"][0]
+ img_hwc = rgb(img, data_norm_type)
+
+ # Convert numpy back to torch if needed (rgb returns numpy)
+ if isinstance(img_hwc, np.ndarray):
+ img_hwc = torch.from_numpy(img_hwc).to(img.device)
+
+ processed_output["img_no_norm"] = img_hwc
+
+ # 2. Add Z depth if we have camera frame points
+ if "pts3d_cam" in processed_output:
+ processed_output["depth_z"] = processed_output["pts3d_cam"][..., 2:3]
+
+ # 3. Recover pinhole camera intrinsics from ray directions if available
+ if "ray_directions" in processed_output:
+ intrinsics = recover_pinhole_intrinsics_from_ray_directions(
+ processed_output["ray_directions"]
+ )
+ processed_output["intrinsics"] = intrinsics
+
+ # 4. Add camera pose matrices if both translation and quaternions are available
+ if "cam_trans" in processed_output and "cam_quats" in processed_output:
+ cam_trans = processed_output["cam_trans"] # (B, 3)
+ cam_quats = processed_output["cam_quats"] # (B, 4)
+ batch_size = cam_trans.shape[0]
+
+ # Convert quaternions to rotation matrices
+ rotation_matrices = quaternion_to_rotation_matrix(cam_quats) # (B, 3, 3)
+
+ # Create 4x4 pose matrices
+ pose_matrices = (
+ torch.eye(4, device=img.device).unsqueeze(0).repeat(batch_size, 1, 1)
+ )
+ pose_matrices[:, :3, :3] = rotation_matrices
+ pose_matrices[:, :3, 3] = cam_trans
+
+ processed_output["camera_poses"] = pose_matrices # (B, 4, 4)
+
+ # 5. Apply comprehensive mask to dense geometry outputs if requested
+ if apply_mask:
+ final_mask = None
+
+ # Start with non-ambiguous mask if available
+ if "non_ambiguous_mask" in processed_output:
+ non_ambiguous_mask = (
+ processed_output["non_ambiguous_mask"].cpu().numpy()
+ ) # (B, H, W)
+ final_mask = non_ambiguous_mask
+
+ # Apply confidence mask if requested and available
+ if apply_confidence_mask and "conf" in processed_output:
+ confidences = processed_output["conf"].cpu() # (B, H, W)
+ # Compute percentile threshold for each batch element
+ batch_size = confidences.shape[0]
+ conf_mask = torch.zeros_like(confidences, dtype=torch.bool)
+ percentile_threshold = (
+ torch.quantile(
+ confidences.reshape(batch_size, -1),
+ confidence_percentile / 100.0,
+ dim=1,
+ )
+ .unsqueeze(-1)
+ .unsqueeze(-1)
+ ) # Shape: (B, 1, 1)
+
+ # Compute mask for each batch element
+ conf_mask = confidences > percentile_threshold
+ conf_mask = conf_mask.numpy()
+
+ if final_mask is not None:
+ final_mask = final_mask & conf_mask
+ else:
+ final_mask = conf_mask
+
+ # Apply edge mask if requested and we have the required data
+ if mask_edges and final_mask is not None and "pts3d" in processed_output:
+ # Get 3D points for edge computation
+ pred_pts3d = processed_output["pts3d"].cpu().numpy() # (B, H, W, 3)
+ batch_size, height, width = final_mask.shape
+
+ edge_masks = []
+ for b in range(batch_size):
+ batch_final_mask = final_mask[b] # (H, W)
+ batch_pts3d = pred_pts3d[b] # (H, W, 3)
+
+ if batch_final_mask.any(): # Only compute if we have valid points
+ # Compute normals and normal-based edge mask
+ normals, normals_mask = points_to_normals(
+ batch_pts3d, mask=batch_final_mask
+ )
+ normal_edges = normals_edge(
+ normals, tol=edge_normal_threshold, mask=normals_mask
+ )
+
+ # Compute depth-based edge mask
+ depth_z = (
+ processed_output["depth_z"][b].squeeze(-1).cpu().numpy()
+ )
+ depth_edges = depth_edge(
+ depth_z, rtol=edge_depth_threshold, mask=batch_final_mask
+ )
+
+ # Combine both edge types
+ edge_mask = ~(depth_edges & normal_edges)
+ edge_masks.append(edge_mask)
+ else:
+ # No valid points, keep all as invalid
+ edge_masks.append(np.zeros_like(batch_final_mask, dtype=bool))
+
+ # Stack batch edge masks and combine with final mask
+ edge_mask = np.stack(edge_masks, axis=0) # (B, H, W)
+ final_mask = final_mask & edge_mask
+
+ # Apply final mask to dense geometry outputs if we have a mask
+ if final_mask is not None:
+ # Convert mask to torch tensor
+ final_mask_torch = torch.from_numpy(final_mask).to(
+ processed_output["pts3d"].device
+ )
+ final_mask_torch = final_mask_torch.unsqueeze(-1) # (B, H, W, 1)
+
+ # Apply mask to dense geometry outputs (zero out invalid regions)
+ dense_geometry_keys = [
+ "pts3d",
+ "pts3d_cam",
+ "depth_along_ray",
+ "depth_z",
+ ]
+ for key in dense_geometry_keys:
+ if key in processed_output:
+ processed_output[key] = processed_output[key] * final_mask_torch
+
+ # Add mask to processed output
+ processed_output["mask"] = final_mask_torch
+
+ processed_outputs.append(processed_output)
+
+ return processed_outputs
diff --git a/mapanything/utils/metrics.py b/mapanything/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..29d28b5a6856daface74ddebe50d51ca9548ee51
--- /dev/null
+++ b/mapanything/utils/metrics.py
@@ -0,0 +1,504 @@
+"""
+Utils for Metrics
+Source for Pose AUC Metrics: VGGT
+"""
+
+import math
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+def l2_distance_of_unit_quats_to_angular_error(l2_distance):
+ """
+ Converts a given L2 distance (for unit quaternions) to the angular error in degrees.
+ For two quaternions differing by an angle θ the relationship is:
+ L2 distance = 2 * sin(θ/4)
+ Hence, the angular error in degrees is computed as:
+ 4 * asin(l2_distance / 2) * (180/π)
+
+ Args:
+ l2_distance: L2 distance between two unit quaternions (torch.Tensor, shape: (N,))
+ Returns:
+ angular_error_degrees: Angular error in degrees (torch.Tensor, shape: (N,))
+ """
+ angular_error_radians = 4 * torch.asin(l2_distance / 2)
+ angular_error_degrees = angular_error_radians * 180.0 / math.pi
+
+ return angular_error_degrees
+
+
+def l2_distance_of_unit_ray_directions_to_angular_error(l2_distance):
+ """
+ Converts a given L2 distance (for unit ray directions) to the angular error in degrees.
+ For two unit ray directions differing by an angle θ the relationship is:
+ L2 distance = 2 * sin(θ/2)
+ Hence, the angular error in degrees is computed as:
+ 2 * asin(l2_distance / 2) * (180/π)
+
+ Args:
+ l2_distance: L2 distance between two unit ray directions (torch.Tensor, shape: (N,))
+ Returns:
+ angular_error_degrees: Angular error in degrees (torch.Tensor, shape: (N,))
+ """
+ angular_error_radians = 2 * torch.asin(l2_distance / 2)
+ angular_error_degrees = angular_error_radians * 180.0 / math.pi
+
+ return angular_error_degrees
+
+
+def valid_mean(arr, mask, axis=None, keepdims=np._NoValue):
+ """Compute mean of elements across given dimensions of an array, considering only valid elements.
+
+ Args:
+ arr: The array to compute the mean.
+ mask: Array with numerical or boolean values for element weights or validity. For bool, False means invalid.
+ axis: Dimensions to reduce.
+ keepdims: If true, retains reduced dimensions with length 1.
+
+ Returns:
+ Mean array/scalar and a valid array/scalar that indicates where the mean could be computed successfully.
+ """
+
+ mask = mask.astype(arr.dtype) if mask.dtype == bool else mask
+ num_valid = np.sum(mask, axis=axis, keepdims=keepdims)
+ masked_arr = arr * mask
+ masked_arr_sum = np.sum(masked_arr, axis=axis, keepdims=keepdims)
+
+ with np.errstate(divide="ignore", invalid="ignore"):
+ valid_mean = masked_arr_sum / num_valid
+ is_valid = np.isfinite(valid_mean)
+ valid_mean = np.nan_to_num(valid_mean, nan=0, posinf=0, neginf=0)
+
+ return valid_mean, is_valid
+
+
+def thresh_inliers(gt, pred, thresh=1.03, mask=None, output_scaling_factor=1.0):
+ """Computes the inlier (=error within a threshold) ratio for a predicted and ground truth dense map of size H x W x C.
+
+ Args:
+ gt: Ground truth depth map as numpy array of shape HxW. Negative or 0 values are invalid and ignored.
+ pred: Predicted depth map as numpy array of shape HxW.
+ thresh: Threshold for the relative difference between the prediction and ground truth. Default: 1.03
+ mask: Array of shape HxW with boolean values to indicate validity. For bool, False means invalid. Default: None
+ output_scaling_factor: Scaling factor that is applied after computing the metrics (e.g. to get [%]). Default: 1
+
+ Returns:
+ Scalar that indicates the inlier ratio. Scalar is np.nan if the result is invalid.
+ """
+ # Compute the norms
+ gt_norm = np.linalg.norm(gt, axis=-1)
+ pred_norm = np.linalg.norm(pred, axis=-1)
+
+ gt_norm_valid = (gt_norm) > 0
+ if mask is not None:
+ combined_mask = mask & gt_norm_valid
+ else:
+ combined_mask = gt_norm_valid
+
+ with np.errstate(divide="ignore", invalid="ignore"):
+ rel_1 = np.nan_to_num(
+ gt_norm / pred_norm, nan=thresh + 1, posinf=thresh + 1, neginf=thresh + 1
+ ) # pred=0 should be an outlier
+ rel_2 = np.nan_to_num(
+ pred_norm / gt_norm, nan=0, posinf=0, neginf=0
+ ) # gt=0 is masked out anyways
+
+ max_rel = np.maximum(rel_1, rel_2)
+ inliers = ((0 < max_rel) & (max_rel < thresh)).astype(
+ np.float32
+ ) # 1 for inliers, 0 for outliers
+
+ inlier_ratio, valid = valid_mean(inliers, combined_mask)
+
+ inlier_ratio = inlier_ratio * output_scaling_factor
+ inlier_ratio = inlier_ratio if valid else np.nan
+
+ return inlier_ratio
+
+
+def m_rel_ae(gt, pred, mask=None, output_scaling_factor=1.0):
+ """Computes the mean-relative-absolute-error for a predicted and ground truth dense map of size HxWxC.
+
+ Args:
+ gt: Ground truth map as numpy array of shape H x W x C.
+ pred: Predicted map as numpy array of shape H x W x C.
+ mask: Array of shape HxW with boolean values to indicate validity. For bool, False means invalid. Default: None
+ output_scaling_factor: Scaling factor that is applied after computing the metrics (e.g. to get [%]). Default: 1
+
+ Returns:
+ Scalar that indicates the mean-relative-absolute-error. Scalar is np.nan if the result is invalid.
+ """
+ error_norm = np.linalg.norm(pred - gt, axis=-1)
+ gt_norm = np.linalg.norm(gt, axis=-1)
+
+ gt_norm_valid = (gt_norm) > 0
+ if mask is not None:
+ combined_mask = mask & gt_norm_valid
+ else:
+ combined_mask = gt_norm_valid
+
+ with np.errstate(divide="ignore", invalid="ignore"):
+ rel_ae = np.nan_to_num(error_norm / gt_norm, nan=0, posinf=0, neginf=0)
+
+ m_rel_ae, valid = valid_mean(rel_ae, combined_mask)
+
+ m_rel_ae = m_rel_ae * output_scaling_factor
+ m_rel_ae = m_rel_ae if valid else np.nan
+
+ return m_rel_ae
+
+
+def align(model, data):
+ """Align two trajectories using the method of Horn (closed-form).
+
+ Args:
+ model -- first trajectory (3xn)
+ data -- second trajectory (3xn)
+
+ Returns:
+ rot -- rotation matrix (3x3)
+ trans -- translation vector (3x1)
+ trans_error -- translational error per point (1xn)
+
+ """
+ np.set_printoptions(precision=3, suppress=True)
+ model_zerocentered = model - model.mean(1).reshape((3, -1))
+ data_zerocentered = data - data.mean(1).reshape((3, -1))
+
+ W = np.zeros((3, 3))
+ for column in range(model.shape[1]):
+ W += np.outer(model_zerocentered[:, column], data_zerocentered[:, column])
+ U, d, Vh = np.linalg.linalg.svd(W.transpose())
+ S = np.matrix(np.identity(3))
+ if np.linalg.det(U) * np.linalg.det(Vh) < 0:
+ S[2, 2] = -1
+ rot = U * S * Vh
+ trans = data.mean(1).reshape((3, -1)) - rot * model.mean(1).reshape((3, -1))
+
+ model_aligned = rot * model + trans
+ alignment_error = model_aligned - data
+
+ trans_error = np.sqrt(np.sum(np.multiply(alignment_error, alignment_error), 0)).A[0]
+
+ return rot, trans, trans_error
+
+
+def evaluate_ate(gt_traj, est_traj):
+ """
+ Input :
+ gt_traj: list of 4x4 matrices
+ est_traj: list of 4x4 matrices
+ len(gt_traj) == len(est_traj)
+ """
+ gt_traj_pts = [gt_traj[idx][:3, 3] for idx in range(len(gt_traj))]
+ est_traj_pts = [est_traj[idx][:3, 3] for idx in range(len(est_traj))]
+
+ gt_traj_pts = torch.stack(gt_traj_pts).detach().cpu().numpy().T
+ est_traj_pts = torch.stack(est_traj_pts).detach().cpu().numpy().T
+
+ _, _, trans_error = align(gt_traj_pts, est_traj_pts)
+
+ avg_trans_error = trans_error.mean()
+
+ return avg_trans_error
+
+
+def build_pair_index(N, B=1):
+ """
+ Build indices for all possible pairs of frames.
+
+ Args:
+ N: Number of frames
+ B: Batch size
+
+ Returns:
+ i1, i2: Indices for all possible pairs
+ """
+ i1_, i2_ = torch.combinations(torch.arange(N), 2, with_replacement=False).unbind(-1)
+ i1, i2 = [(i[None] + torch.arange(B)[:, None] * N).reshape(-1) for i in [i1_, i2_]]
+ return i1, i2
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ if torch.is_grad_enabled():
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ else:
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
+ return ret
+
+
+def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part last, as tensor of shape (..., 4).
+ Quaternion Order: XYZW or say ijkr, scalar-last
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
+ ].reshape(batch_dim + (4,)) # pylint: disable=not-callable
+
+ # Convert from rijk to ijkr
+ out = out[..., [1, 2, 3, 0]]
+
+ out = standardize_quaternion(out)
+
+ return out
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
+
+
+def rotation_angle(rot_gt, rot_pred, batch_size=None, eps=1e-15):
+ """
+ Calculate rotation angle error between ground truth and predicted rotations.
+
+ Args:
+ rot_gt: Ground truth rotation matrices
+ rot_pred: Predicted rotation matrices
+ batch_size: Batch size for reshaping the result
+ eps: Small value to avoid numerical issues
+
+ Returns:
+ Rotation angle error in degrees
+ """
+ q_pred = mat_to_quat(rot_pred)
+ q_gt = mat_to_quat(rot_gt)
+
+ loss_q = (1 - (q_pred * q_gt).sum(dim=1) ** 2).clamp(min=eps)
+ err_q = torch.arccos(1 - 2 * loss_q)
+
+ rel_rangle_deg = err_q * 180 / np.pi
+
+ if batch_size is not None:
+ rel_rangle_deg = rel_rangle_deg.reshape(batch_size, -1)
+
+ return rel_rangle_deg
+
+
+def translation_angle(tvec_gt, tvec_pred, batch_size=None, ambiguity=True):
+ """
+ Calculate translation angle error between ground truth and predicted translations.
+
+ Args:
+ tvec_gt: Ground truth translation vectors
+ tvec_pred: Predicted translation vectors
+ batch_size: Batch size for reshaping the result
+ ambiguity: Whether to handle direction ambiguity
+
+ Returns:
+ Translation angle error in degrees
+ """
+ rel_tangle_deg = compare_translation_by_angle(tvec_gt, tvec_pred)
+ rel_tangle_deg = rel_tangle_deg * 180.0 / np.pi
+
+ if ambiguity:
+ rel_tangle_deg = torch.min(rel_tangle_deg, (180 - rel_tangle_deg).abs())
+
+ if batch_size is not None:
+ rel_tangle_deg = rel_tangle_deg.reshape(batch_size, -1)
+
+ return rel_tangle_deg
+
+
+def compare_translation_by_angle(t_gt, t, eps=1e-15, default_err=1e6):
+ """
+ Normalize the translation vectors and compute the angle between them.
+
+ Args:
+ t_gt: Ground truth translation vectors
+ t: Predicted translation vectors
+ eps: Small value to avoid division by zero
+ default_err: Default error value for invalid cases
+
+ Returns:
+ Angular error between translation vectors in radians
+ """
+ t_norm = torch.norm(t, dim=1, keepdim=True)
+ t = t / (t_norm + eps)
+
+ t_gt_norm = torch.norm(t_gt, dim=1, keepdim=True)
+ t_gt = t_gt / (t_gt_norm + eps)
+
+ loss_t = torch.clamp_min(1.0 - torch.sum(t * t_gt, dim=1) ** 2, eps)
+ err_t = torch.acos(torch.sqrt(1 - loss_t))
+
+ err_t[torch.isnan(err_t) | torch.isinf(err_t)] = default_err
+ return err_t
+
+
+def calculate_auc_np(r_error, t_error, max_threshold=30):
+ """
+ Calculate the Area Under the Curve (AUC) for the given error arrays using NumPy.
+
+ Args:
+ r_error: numpy array representing R error values (Degree)
+ t_error: numpy array representing T error values (Degree)
+ max_threshold: Maximum threshold value for binning the histogram
+
+ Returns:
+ AUC value and the normalized histogram
+ """
+ error_matrix = np.concatenate((r_error[:, None], t_error[:, None]), axis=1)
+ max_errors = np.max(error_matrix, axis=1)
+ bins = np.arange(max_threshold + 1)
+ histogram, _ = np.histogram(max_errors, bins=bins)
+ num_pairs = float(len(max_errors))
+ normalized_histogram = histogram.astype(float) / num_pairs
+ return np.mean(np.cumsum(normalized_histogram)), normalized_histogram
+
+
+def closed_form_inverse_se3(se3, R=None, T=None):
+ """
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
+
+ If `R` and `T` are provided, they must correspond to the rotation and translation
+ components of `se3`. Otherwise, they will be extracted from `se3`.
+
+ Args:
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+ R (optional): Nx3x3 array or tensor of rotation matrices.
+ T (optional): Nx3x1 array or tensor of translation vectors.
+
+ Returns:
+ Inverted SE3 matrices with the same type and device as `se3`.
+
+ Shapes:
+ se3: (N, 4, 4)
+ R: (N, 3, 3)
+ T: (N, 3, 1)
+ """
+ # Check if se3 is a numpy array or a torch tensor
+ is_numpy = isinstance(se3, np.ndarray)
+
+ # Validate shapes
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
+
+ # Extract R and T if not provided
+ if R is None:
+ R = se3[:, :3, :3] # (N,3,3)
+ if T is None:
+ T = se3[:, :3, 3:] # (N,3,1)
+
+ # Transpose R
+ if is_numpy:
+ # Compute the transpose of the rotation for NumPy
+ R_transposed = np.transpose(R, (0, 2, 1))
+ # -R^T t for NumPy
+ top_right = -np.matmul(R_transposed, T)
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
+ else:
+ R_transposed = R.transpose(1, 2) # (N,3,3)
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
+
+ inverted_matrix[:, :3, :3] = R_transposed
+ inverted_matrix[:, :3, 3:] = top_right
+
+ return inverted_matrix
+
+
+def se3_to_relative_pose_error(pred_se3, gt_se3, num_frames):
+ """
+ Compute rotation and translation errors between predicted and ground truth poses.
+
+ Args:
+ pred_se3: Predicted SE(3) transformations
+ gt_se3: Ground truth SE(3) transformations
+ num_frames: Number of frames
+
+ Returns:
+ Rotation and translation angle errors in degrees
+ """
+ pair_idx_i1, pair_idx_i2 = build_pair_index(num_frames)
+
+ # Compute relative camera poses between pairs
+ # We use closed_form_inverse to avoid potential numerical loss by torch.inverse()
+ relative_pose_gt = closed_form_inverse_se3(gt_se3[pair_idx_i1]).bmm(
+ gt_se3[pair_idx_i2]
+ )
+ relative_pose_pred = closed_form_inverse_se3(pred_se3[pair_idx_i1]).bmm(
+ pred_se3[pair_idx_i2]
+ )
+
+ # Compute the difference in rotation and translation
+ rel_rangle_deg = rotation_angle(
+ relative_pose_gt[:, :3, :3], relative_pose_pred[:, :3, :3]
+ )
+ rel_tangle_deg = translation_angle(
+ relative_pose_gt[:, :3, 3], relative_pose_pred[:, :3, 3]
+ )
+
+ return rel_rangle_deg, rel_tangle_deg
diff --git a/mapanything/utils/misc.py b/mapanything/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6b976fa1df7430465961bc946d1744ae5665faa
--- /dev/null
+++ b/mapanything/utils/misc.py
@@ -0,0 +1,109 @@
+"""
+Miscellaneous utility functions.
+"""
+
+import logging
+import os
+import random
+
+import numpy as np
+import torch
+
+
+class StreamToLogger:
+ """
+ A class that redirects stream writes to a logger.
+
+ This class can be used to redirect stdout or stderr to a logger
+ by implementing a file-like interface with write and flush methods.
+
+ Parameters:
+ - logger: A logger instance that will receive the log messages
+ - log_level: The logging level to use (default: logging.INFO)
+ """
+
+ def __init__(self, logger, log_level=logging.INFO):
+ self.logger = logger
+ self.log_level = log_level
+ self.linebuf = ""
+
+ def write(self, buf):
+ """
+ Write the buffer content to the logger.
+
+ Parameters:
+ - buf: The string buffer to write
+ """
+ for line in buf.rstrip().splitlines():
+ self.logger.log(self.log_level, line.rstrip())
+
+ def flush(self):
+ """
+ Flush method to comply with file-like object interface.
+ This method is required but does nothing in this implementation.
+ """
+ pass
+
+
+def seed_everything(seed: int = 42):
+ """
+ Set the `seed` value for torch and numpy seeds. Also turns on
+ deterministic execution for cudnn.
+
+ Parameters:
+ - seed: A hashable seed value
+ """
+ random.seed(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ print(f"Seed set to: {seed}")
+
+
+def invalid_to_nans(arr, valid_mask, ndim=999):
+ """
+ Replace invalid values in an array with NaN values based on a validity mask.
+
+ Parameters:
+ - arr: Input array (typically a PyTorch tensor)
+ - valid_mask: Boolean mask indicating valid elements (True) and invalid elements (False)
+ - ndim: Maximum number of dimensions to keep; flattens dimensions if arr.ndim > ndim
+
+ Returns:
+ - Modified array with invalid values replaced by NaN
+ """
+ if valid_mask is not None:
+ arr = arr.clone()
+ arr[~valid_mask] = float("nan")
+ if arr.ndim > ndim:
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+ return arr
+
+
+def invalid_to_zeros(arr, valid_mask, ndim=999):
+ """
+ Replace invalid values in an array with zeros based on a validity mask.
+
+ Parameters:
+ - arr: Input array (typically a PyTorch tensor)
+ - valid_mask: Boolean mask indicating valid elements (True) and invalid elements (False)
+ - ndim: Maximum number of dimensions to keep; flattens dimensions if arr.ndim > ndim
+
+ Returns:
+ - Tuple containing:
+ - Modified array with invalid values replaced by zeros
+ - nnz: Number of non-zero (valid) elements per sample in the batch
+ """
+ if valid_mask is not None:
+ arr = arr.clone()
+ arr[~valid_mask] = 0
+ nnz = valid_mask.view(len(valid_mask), -1).sum(1)
+ else:
+ nnz = (
+ arr[..., 0].numel() // len(arr) if len(arr) else 0
+ ) # Number of pixels per image
+ if arr.ndim > ndim:
+ arr = arr.flatten(-2 - (arr.ndim - ndim), -2)
+ return arr, nnz
diff --git a/mapanything/utils/parallel.py b/mapanything/utils/parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5c6f826291646a6c4059d3b6ae4c3d71d5e0104
--- /dev/null
+++ b/mapanything/utils/parallel.py
@@ -0,0 +1,158 @@
+"""
+Utility functions for multiprocessing
+"""
+
+import os
+from multiprocessing.dummy import Pool as ThreadPool
+
+import torch
+from torch.multiprocessing import Pool as TorchPool, set_start_method
+from tqdm import tqdm
+
+
+def cpu_count():
+ """
+ Returns the number of available CPUs for the python process
+ """
+ return len(os.sched_getaffinity(0))
+
+
+def parallel_threads(
+ function,
+ args,
+ workers=0,
+ star_args=False,
+ kw_args=False,
+ front_num=1,
+ Pool=ThreadPool,
+ ordered_res=True,
+ **tqdm_kw,
+):
+ """tqdm but with parallel execution.
+
+ Will essentially return
+ res = [ function(arg) # default
+ function(*arg) # if star_args is True
+ function(**arg) # if kw_args is True
+ for arg in args]
+
+ Note:
+ the first elements of args will not be parallelized.
+ This can be useful for debugging.
+ """
+ # Determine the number of workers
+ while workers <= 0:
+ workers += cpu_count()
+
+ # Convert args to an iterable
+ try:
+ n_args_parallel = len(args) - front_num
+ except TypeError:
+ n_args_parallel = None
+ args = iter(args)
+
+ # Sequential execution for the first few elements (useful for debugging)
+ front = []
+ while len(front) < front_num:
+ try:
+ a = next(args)
+ except StopIteration:
+ return front # end of the iterable
+ front.append(
+ function(*a) if star_args else function(**a) if kw_args else function(a)
+ )
+
+ # Parallel execution using multiprocessing.dummy
+ out = []
+ with Pool(workers) as pool:
+ if star_args:
+ map_func = pool.imap if ordered_res else pool.imap_unordered
+ futures = map_func(starcall, [(function, a) for a in args])
+ elif kw_args:
+ map_func = pool.imap if ordered_res else pool.imap_unordered
+ futures = map_func(starstarcall, [(function, a) for a in args])
+ else:
+ map_func = pool.imap if ordered_res else pool.imap_unordered
+ futures = map_func(function, args)
+ # Track progress with tqdm
+ for f in tqdm(futures, total=n_args_parallel, **tqdm_kw):
+ out.append(f)
+ return front + out
+
+
+def cuda_parallel_threads(
+ function,
+ args,
+ workers=0,
+ star_args=False,
+ kw_args=False,
+ front_num=1,
+ Pool=TorchPool,
+ ordered_res=True,
+ **tqdm_kw,
+):
+ """
+ Parallel execution of a function using torch.multiprocessing with CUDA support.
+ This is the CUDA variant of the parallel_threads function.
+ """
+ # Set the start method for multiprocessing
+ set_start_method("spawn", force=True)
+
+ # Determine the number of workers
+ while workers <= 0:
+ workers += torch.multiprocessing.cpu_count()
+
+ # Convert args to an iterable
+ try:
+ n_args_parallel = len(args) - front_num
+ except TypeError:
+ n_args_parallel = None
+ args = iter(args)
+
+ # Sequential execution for the first few elements (useful for debugging)
+ front = []
+ while len(front) < front_num:
+ try:
+ a = next(args)
+ except StopIteration:
+ return front # End of the iterable
+ front.append(
+ function(*a) if star_args else function(**a) if kw_args else function(a)
+ )
+
+ # Parallel execution using torch.multiprocessing
+ out = []
+ with Pool(workers) as pool:
+ if star_args:
+ map_func = pool.imap if ordered_res else pool.imap_unordered
+ futures = map_func(starcall, [(function, a) for a in args])
+ elif kw_args:
+ map_func = pool.imap if ordered_res else pool.imap_unordered
+ futures = map_func(starstarcall, [(function, a) for a in args])
+ else:
+ map_func = pool.imap if ordered_res else pool.imap_unordered
+ futures = map_func(function, args)
+ # Track progress with tqdm
+ for f in tqdm(futures, total=n_args_parallel, **tqdm_kw):
+ out.append(f)
+ return front + out
+
+
+def parallel_processes(*args, **kwargs):
+ """Same as parallel_threads, with processes"""
+ import multiprocessing as mp
+
+ kwargs["Pool"] = mp.Pool
+ return parallel_threads(*args, **kwargs)
+
+
+def starcall(args):
+ """convenient wrapper for Process.Pool"""
+ function, args = args
+ return function(*args)
+
+
+def starstarcall(args):
+ """convenient wrapper for Process.Pool"""
+ function, args = args
+ return function(**args)
diff --git a/mapanything/utils/timing.py b/mapanything/utils/timing.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba9ee891d33420c51644f11f38bf169a1a610c27
--- /dev/null
+++ b/mapanything/utils/timing.py
@@ -0,0 +1,304 @@
+"""
+Utility functions for timing code blocks
+"""
+
+import time
+from contextlib import ContextDecorator
+
+import numpy as np
+
+
+class BlockTimeManager:
+ """
+ Manages a collection of timers and their formatting options.
+
+ This class serves as a central registry for Timer objects, allowing them to be
+ accessed by name and maintaining their formatting preferences.
+
+ Attributes:
+ timers (dict): Dictionary mapping timer names to Timer objects
+ timer_fmts (dict): Dictionary mapping timer names to their display formats
+ window_size (int): Default window size for calculating windowed averages
+ buf_size (int): Default buffer size for storing timing measurements
+ """
+
+ def __init__(self, window_size=10, buf_size=100000):
+ self.timers = dict()
+ self.timer_fmts = dict()
+ self.window_size = window_size
+ self.buf_size = buf_size
+
+
+btm = BlockTimeManager(window_size=100000)
+
+
+class Timer:
+ """
+ Core timing class that tracks execution times.
+
+ This class provides the fundamental timing functionality, storing timing measurements
+ and calculating various statistics.
+
+ Attributes:
+ name (str): Identifier for this timer
+ buf_size (int): Maximum number of timing measurements to store
+ window_size (int): Number of most recent measurements to use for windowed statistics
+ measures_arr (numpy.ndarray): Array storing start and end times of measurements
+ current_start (float or None): Start time of current measurement
+ current_end (float or None): End time of current measurement
+ """
+
+ def __init__(self, name, window_size, buf_size=100000):
+ self.name = name
+ self.buf_size = buf_size
+ self.window_size = window_size
+ self.init()
+
+ def init(self):
+ """Initialize or reset the timer's state."""
+ self.measures_arr = np.empty((0, 2)) # LIFO
+ self.current_start = None
+ self.current_end = None
+
+ def reset(self):
+ """Reset the timer to its initial state."""
+ self.init()
+
+ def tic(self):
+ """Start a new timing measurement."""
+ if self.current_start is not None:
+ # another tic executed before a toc
+ self.toc()
+ self.current_start = time.perf_counter()
+
+ def toc(self):
+ """End the current timing measurement."""
+ self.current_end = time.perf_counter()
+ self._add_current_measure()
+
+ def _add_current_measure(self):
+ """Add the current timing measurement to the measurements array."""
+ self.measures_arr = np.concatenate(
+ [
+ np.array([[self.current_start, self.current_end]]),
+ self.measures_arr[: self.buf_size],
+ ]
+ )
+ self.current_start = None
+ self.current_end = None
+
+ @property
+ def avg(self) -> float:
+ """Calculate the average execution time across all measurements."""
+ return np.mean(self.measures_arr[:, 1] - self.measures_arr[:, 0])
+
+ @property
+ def wavg(self) -> float:
+ """Calculate the windowed average execution time using the most recent measurements."""
+ return np.mean(
+ self.measures_arr[: self.window_size, 1]
+ - self.measures_arr[: self.window_size, 0]
+ )
+
+ @property
+ def max(self) -> float:
+ """Return the maximum execution time."""
+ return np.max(self.measures_arr[:, 1] - self.measures_arr[:, 0])
+
+ @property
+ def min(self) -> float:
+ """Return the minimum execution time."""
+ return np.min(self.measures_arr[:, 1] - self.measures_arr[:, 0])
+
+ @property
+ def total(self) -> float:
+ """Return the total execution time across all measurements."""
+ return np.sum(self.measures_arr[:, 1] - self.measures_arr[:, 0])
+
+ @property
+ def latest(self) -> float:
+ """Return the most recent execution time."""
+ return self.measures_arr[0, 1] - self.measures_arr[0, 0]
+
+ @property
+ def median(self) -> float:
+ """Return the median execution time."""
+ return np.median(self.measures_arr[:, 1] - self.measures_arr[:, 0])
+
+ @property
+ def var(self) -> float:
+ """Return the variance of execution times."""
+ return np.var(self.measures_arr[:, 1] - self.measures_arr[:, 0])
+
+
+class BlockTimer(ContextDecorator):
+ """
+ A context manager and decorator for timing code blocks.
+
+ This class provides a convenient interface for timing code execution, either as a
+ context manager (with statement) or as a decorator. It uses the Timer class for
+ the actual timing functionality.
+
+ Attributes:
+ name (str): Identifier for this timer
+ fmt (str or None): Format string for displaying timing information
+ timer (Timer): The underlying Timer object
+ num_calls (int): Number of times this timer has been called
+ """
+
+ @staticmethod
+ def timers():
+ """Return a list of all registered timer names."""
+ return list(btm.timers.keys())
+
+ def __init__(self, name, fmt=None, window_size=100):
+ self.name = name
+ if name in btm.timers:
+ self.timer = btm.timers[name]
+ # restore format
+ self.fmt = fmt if fmt is not None else btm.timer_fmts[name]
+ else:
+ self.timer = Timer(name, btm.window_size, btm.buf_size)
+ btm.timers[name] = self.timer
+ btm.timer_fmts[name] = fmt
+ self.timer.window_size = window_size
+ self._default_fmt = "[{name}] num: {num} latest: {latest:.4f} --wind_avg: {wavg:.4f} -- avg: {avg:.4f} --var: {var:.4f} -- total: {total:.4f}"
+ if fmt == "default":
+ self.fmt = self._default_fmt
+ # extend here for new formats
+ else:
+ self.fmt = None
+
+ self.num_calls = 0
+
+ def __enter__(self) -> "Timer":
+ """Start timing when entering a context."""
+ self.tic()
+ return self
+
+ def __exit__(self, *args):
+ """End timing when exiting a context and optionally display results."""
+ self.toc()
+ if self.fmt is not None:
+ print(str(self))
+
+ def __str__(self) -> str:
+ """Return a string representation of the timer."""
+ return self.display()
+
+ def reset(self):
+ """Reset the timer and call counter."""
+ self.timer.reset()
+ self.num_calls = 0
+
+ def display(self, fmt=None):
+ """
+ Format and return timing information.
+
+ Args:
+ fmt (str, optional): Format string to use. If None, uses the timer's format.
+
+ Returns:
+ str: Formatted timing information
+ """
+ if fmt is None:
+ if self.fmt is not None:
+ fmt = self.fmt
+ else:
+ fmt = self._default_fmt
+ return fmt.format(
+ name=self.name,
+ num=self.num_calls,
+ latest=self.latest,
+ wavg=self.wavg,
+ avg=self.avg,
+ var=self.var,
+ total=self.total,
+ )
+
+ def tic(self):
+ """Start a new timing measurement and increment the call counter."""
+ self.timer.tic()
+ self.num_calls += 1
+
+ def toc(self, display=False):
+ """
+ End the current timing measurement.
+
+ Args:
+ display (bool): Whether to return a formatted display string
+
+ Returns:
+ str or None: Formatted timing information if display is True
+ """
+ self.timer.toc()
+ if display:
+ return self.display()
+
+ @property
+ def latest(self) -> float:
+ """Return the most recent execution time."""
+ return self.timer.latest
+
+ @property
+ def avg(self) -> float:
+ """Return the average execution time."""
+ return self.timer.avg
+
+ @property
+ def wavg(self) -> float:
+ """Return the windowed average execution time."""
+ return self.timer.wavg
+
+ @property
+ def max(self) -> float:
+ """Return the maximum execution time."""
+ return self.timer.max
+
+ @property
+ def min(self) -> float:
+ """Return the minimum execution time."""
+ return self.timer.min
+
+ @property
+ def total(self) -> float:
+ """Return the total execution time."""
+ return self.timer.total
+
+ @property
+ def median(self) -> float:
+ """Return the median execution time."""
+ return self.timer.median
+
+ @property
+ def var(self) -> float:
+ """Return the variance of execution times."""
+ return self.timer.var
+
+
+if __name__ == "__main__":
+
+ @BlockTimer("fct", "default")
+ def fct(bobo):
+ time.sleep(0.5)
+
+ fct(2)
+
+ for i in range(10):
+ with BlockTimer("affe", "default"):
+ time.sleep(0.1)
+ for i in range(1000):
+ with BlockTimer("test", None):
+ time.sleep(0.001)
+
+ # BlockTimer("test").display = f"""avg: {BlockTimer("test").avg} total: {BlockTimer("test").total}"""
+ # print(str(BlockTimer("test")))
+
+ print(BlockTimer("test"))
+ BlockTimer("test").tic()
+ BlockTimer("t2", "default").tic()
+ time.sleep(0.4)
+ print(BlockTimer("t2").toc(True))
+
+ time.sleep(0.4)
+ print(BlockTimer("test").toc(True))
diff --git a/mapanything/utils/train_tools.py b/mapanything/utils/train_tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f95d207f541a4d4de61f9111014d0927f91cd08
--- /dev/null
+++ b/mapanything/utils/train_tools.py
@@ -0,0 +1,978 @@
+"""
+Utility functions for training deep learning models, particularly focused on distributed training,
+metric logging, and gradient handling.
+
+This module provides tools for:
+- Tracking and logging metrics during training
+- Setting up distributed training environments
+- Handling gradient scaling and normalization
+- Managing learning rates and parameter groups
+- Saving and loading model checkpoints
+
+References: CroCo (https://github.com/naver/croco)
+"""
+
+import builtins
+import datetime
+import json
+import math
+import os
+import time
+from collections import defaultdict, deque
+from pathlib import Path
+
+import torch
+import torch.distributed as dist
+from torch import inf
+
+
+class SmoothedValue(object):
+ """
+ Track a series of values and provide access to smoothed values over a
+ window or the global series average.
+ """
+
+ def __init__(self, window_size=20, fmt=None):
+ if fmt is None:
+ fmt = "{median:.4f} ({global_avg:.4f})"
+ self.deque = deque(maxlen=window_size)
+ self.total = 0.0
+ self.count = 0
+ self.fmt = fmt
+
+ def update(self, value, n=1):
+ self.deque.append(value)
+ self.count += n
+ self.total += value * n
+
+ def synchronize_between_processes(self):
+ """
+ Warning: does not synchronize the deque!
+ """
+ if not is_dist_avail_and_initialized():
+ return
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
+ dist.barrier()
+ dist.all_reduce(t)
+ t = t.tolist()
+ self.count = int(t[0])
+ self.total = t[1]
+
+ @property
+ def median(self):
+ d = torch.tensor(list(self.deque))
+ return d.median().item()
+
+ @property
+ def avg(self):
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
+ return d.mean().item()
+
+ @property
+ def global_avg(self):
+ return self.total / self.count
+
+ @property
+ def max(self):
+ return max(self.deque)
+
+ @property
+ def value(self):
+ return self.deque[-1]
+
+ def __str__(self):
+ return self.fmt.format(
+ median=self.median,
+ avg=self.avg,
+ global_avg=self.global_avg,
+ max=self.max,
+ value=self.value,
+ )
+
+
+class MetricLogger(object):
+ """
+ Logger for tracking and displaying training metrics.
+
+ This class maintains a collection of metrics during training, provides
+ methods to update them, and formats them for display. It also handles
+ synchronization of metrics across processes in distributed training.
+ """
+
+ def __init__(self, delimiter="\t", print_per_view_stats=False):
+ """
+ Initialize the MetricLogger.
+
+ Args:
+ delimiter (str, optional): Delimiter for formatting output. Defaults to "\t".
+ print_per_view_stats (bool, optional): Whether to print per-view statistics. Defaults to False.
+ """
+ self.meters = defaultdict(SmoothedValue)
+ self.delimiter = delimiter
+ self.print_per_view_stats = print_per_view_stats
+
+ def update(self, **kwargs):
+ """
+ Update metrics with new values.
+
+ Args:
+ **kwargs: Key-value pairs where keys are metric names and values are metric values
+ Values can be tensors or numbers
+
+ Raises:
+ AssertionError: If a value is not a float or int after conversion from tensor
+ """
+ for k, v in kwargs.items():
+ if v is None:
+ continue
+ if isinstance(v, torch.Tensor):
+ v = v.item()
+ assert isinstance(v, (float, int))
+ self.meters[k].update(v)
+
+ def __getattr__(self, attr):
+ """
+ Get a meter by attribute name.
+
+ This allows accessing meters as attributes of the logger.
+
+ Args:
+ attr (str): Name of the attribute to get
+
+ Returns:
+ SmoothedValue: The meter corresponding to the attribute name
+
+ Raises:
+ AttributeError: If the attribute doesn't exist as a meter or regular attribute
+ """
+ if attr in self.meters:
+ return self.meters[attr]
+ if attr in self.__dict__:
+ return self.__dict__[attr]
+ raise AttributeError(
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+ )
+
+ def __str__(self):
+ """
+ Format all metrics as a string.
+
+ Returns:
+ str: Formatted string containing all metrics
+ """
+ loss_str = []
+ for name, meter in self.meters.items():
+ # Skip printing per-view stats if not enabled
+ if not self.print_per_view_stats and "view" in name:
+ continue
+ loss_str.append("{}: {}".format(name, str(meter)))
+ return self.delimiter.join(loss_str)
+
+ def synchronize_between_processes(self):
+ """
+ Synchronize metrics across processes in distributed training.
+
+ This method calls synchronize_between_processes on each meter to
+ ensure consistent values across all processes.
+ """
+ for meter in self.meters.values():
+ meter.synchronize_between_processes()
+
+ def add_meter(self, name, meter):
+ """
+ Add a custom meter to the logger.
+
+ Args:
+ name (str): Name of the meter
+ meter (SmoothedValue): The meter to add
+ """
+ self.meters[name] = meter
+
+ def log_every(self, iterable, print_freq, header=None, max_iter=None):
+ """
+ Log metrics at regular intervals while iterating.
+
+ This method wraps an iterable and logs metrics every print_freq iterations.
+ It also tracks iteration time, data loading time, and memory usage.
+
+ Args:
+ iterable: Iterable to iterate over (typically a data loader)
+ print_freq (int): How often to log metrics (in iterations)
+ header (str, optional): Header string to print before metrics. Defaults to None.
+ max_iter (int, optional): Maximum number of iterations. Defaults to None.
+
+ Yields:
+ object: Items from the original iterable
+ """
+ i = 0
+ if not header:
+ header = ""
+ start_time = time.time()
+ end = time.time()
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
+ data_time = SmoothedValue(fmt="{avg:.4f}")
+ len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable)
+ space_fmt = ":" + str(len(str(len_iterable))) + "d"
+ log_msg = [
+ header,
+ "[{0" + space_fmt + "}/{1}]",
+ "eta: {eta}",
+ "{meters}",
+ "time: {time}",
+ "data: {data}",
+ ]
+ if torch.cuda.is_available():
+ log_msg.append("max mem: {memory:.0f}")
+ log_msg = self.delimiter.join(log_msg)
+ MB = 1024.0 * 1024.0
+ for it, obj in enumerate(iterable):
+ data_time.update(time.time() - end)
+ yield obj
+ iter_time.update(time.time() - end)
+ if i % print_freq == 0 or i == len_iterable - 1:
+ eta_seconds = iter_time.global_avg * (len_iterable - i)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+ if torch.cuda.is_available():
+ print(
+ log_msg.format(
+ i,
+ len_iterable,
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ memory=torch.cuda.max_memory_allocated() / MB,
+ )
+ )
+ else:
+ print(
+ log_msg.format(
+ i,
+ len_iterable,
+ eta=eta_string,
+ meters=str(self),
+ time=str(iter_time),
+ data=str(data_time),
+ )
+ )
+ i += 1
+ end = time.time()
+ if max_iter and it >= max_iter:
+ break
+ total_time = time.time() - start_time
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+ print(
+ "{} Total time: {} ({:.4f} s / it)".format(
+ header, total_time_str, total_time / len_iterable
+ )
+ )
+
+
+def setup_for_distributed(is_master):
+ """
+ This function disables printing when not in master process.
+
+ It replaces the built-in print function with a custom version that only prints
+ when the current process is the master process or when explicitly forced.
+
+ Args:
+ is_master (bool): Whether the current process is the master process
+ """
+ builtin_print = builtins.print
+
+ def print(*args, **kwargs):
+ force = kwargs.pop("force", False)
+ # force = force or (get_world_size() > 8)
+ if is_master or force:
+ now = datetime.datetime.now().time()
+ builtin_print("[{}] ".format(now), end="") # print with time stamp
+ builtin_print(*args, **kwargs)
+
+ builtins.print = print
+
+
+def is_dist_avail_and_initialized():
+ """
+ Check if distributed training is available and initialized.
+
+ Returns:
+ bool: True if distributed training is available and initialized, False otherwise
+ """
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_world_size():
+ """
+ Get the number of processes in the distributed training group.
+
+ Returns:
+ int: Number of processes in the distributed group, or 1 if not using distributed training
+ """
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ """
+ Get the rank of the current process in the distributed training group.
+
+ Returns:
+ int: Rank of the current process, or 0 if not using distributed training
+ """
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ """
+ Check if the current process is the main process (rank 0).
+
+ Returns:
+ bool: True if the current process is the main process, False otherwise
+ """
+ return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+ """
+ Save a PyTorch object only on the master process.
+
+ This function is useful in distributed training to avoid multiple processes
+ trying to save the same file simultaneously.
+
+ Args:
+ *args: Positional arguments to pass to torch.save()
+ **kwargs: Keyword arguments to pass to torch.save()
+ """
+ if is_main_process():
+ torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+ """
+ Initialize distributed training mode.
+
+ This function sets up the distributed training environment based on environment
+ variables and command-line arguments. It initializes the process group,
+ sets the appropriate device, and configures printing for the distributed setup.
+
+ Args:
+ args: Arguments object containing distributed training configuration.
+ Expected to have attributes like dist_url, and will be modified
+ to include rank, world_size, gpu, and distributed flag.
+ """
+ nodist = args.nodist if hasattr(args, "nodist") else False
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ and not nodist:
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ["WORLD_SIZE"])
+ args.gpu = int(os.environ["LOCAL_RANK"])
+ else:
+ print("Not using distributed mode")
+ setup_for_distributed(is_master=True) # hack
+ args.distributed = False
+ return
+
+ args.distributed = True
+
+ torch.cuda.set_device(args.gpu)
+ args.dist_backend = "nccl"
+ print(
+ "| distributed init (rank {}): {}, gpu {}".format(
+ args.rank, args.dist_url, args.gpu
+ ),
+ flush=True,
+ )
+ torch.distributed.init_process_group(
+ backend=args.dist_backend,
+ init_method=args.dist_url,
+ world_size=args.world_size,
+ rank=args.rank,
+ )
+ torch.distributed.barrier()
+ setup_for_distributed(args.rank == 0)
+
+
+class NativeScalerWithGradNormCount:
+ """
+ A gradient scaler that handles gradient scaling and norm computation for mixed precision training.
+
+ This class wraps PyTorch's GradScaler to provide additional functionality for gradient norm tracking
+ and clipping during mixed precision training.
+ """
+
+ state_dict_key = "amp_scaler"
+
+ def __init__(self, enabled=True):
+ """Initialize the scaler.
+
+ Args:
+ enabled (bool): Whether to enable gradient scaling. Default: True
+ """
+ self._scaler = torch.GradScaler("cuda", enabled=enabled)
+
+ def __call__(
+ self,
+ loss,
+ optimizer,
+ clip_grad=None,
+ parameters=None,
+ create_graph=False,
+ update_grad=True,
+ ):
+ """Scales loss and performs backward pass with optional gradient clipping.
+
+ Args:
+ loss: The loss to backpropagate
+ optimizer: The optimizer being used
+ clip_grad: Max norm for gradient clipping. None means no clipping
+ parameters: Model parameters or list of parameters for gradient norm computation
+ create_graph: Whether to create graph during backward pass
+ update_grad: Whether to update gradients
+
+ Returns:
+ norm: The gradient norm if computed, else None. Returns list of norms if parameters is a list.
+ """
+ self._scaler.scale(loss).backward(create_graph=create_graph)
+ if update_grad:
+ if clip_grad is not None:
+ assert parameters is not None
+ self._scaler.unscale_(
+ optimizer
+ ) # unscale the gradients of optimizer's assigned params in-place
+ if isinstance(parameters, (list, tuple)):
+ norm = [
+ torch.nn.utils.clip_grad_norm_(p, clip_grad) for p in parameters
+ ]
+ else:
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
+ else:
+ self._scaler.unscale_(optimizer)
+ norm = get_grad_norm_(parameters)
+ self._scaler.step(optimizer)
+ self._scaler.update()
+ else:
+ norm = None
+ return norm
+
+ def state_dict(self):
+ """Returns the state dict of the underlying scaler.
+
+ Returns:
+ dict: The state dict of the gradient scaler
+ """
+ return self._scaler.state_dict()
+
+ def load_state_dict(self, state_dict):
+ """Loads the state dict into the underlying scaler.
+
+ Args:
+ state_dict: The state dict to load
+ """
+ self._scaler.load_state_dict(state_dict)
+
+
+def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
+ """
+ Calculate the gradient norm of parameters.
+
+ This function computes the norm of gradients for a set of parameters. It can handle
+ both single parameter groups and multiple parameter groups (list/tuple of parameters).
+
+ Args:
+ parameters: A tensor or iterable of tensors or iterable of iterables of tensors
+ containing model parameters for which to compute gradient norms
+ norm_type (float): Type of norm to use (e.g., 2.0 for L2 norm, inf for infinity norm)
+
+ Returns:
+ torch.Tensor: The computed gradient norm. If parameters is a list/tuple of parameter
+ groups, returns a list of norms, one for each group.
+ """
+ if isinstance(parameters, (list, tuple)):
+ # If parameters is already a list/tuple, process each parameter group
+ all_norms = []
+ for params in parameters:
+ if isinstance(params, torch.Tensor):
+ params = [params]
+ params = [p for p in params if p.grad is not None]
+ if len(params) > 0:
+ device = params[0].grad.device
+ if norm_type == inf:
+ group_norm = max(
+ p.grad.detach().abs().max().to(device) for p in params
+ )
+ else:
+ group_norm = torch.norm(
+ torch.stack(
+ [
+ torch.norm(p.grad.detach(), norm_type).to(device)
+ for p in params
+ ]
+ ),
+ norm_type,
+ )
+ else:
+ group_norm = torch.tensor(0.0)
+ all_norms.append(group_norm)
+ return all_norms
+
+ # Original logic for single parameter group
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = [p for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(parameters) == 0:
+ return torch.tensor(0.0)
+ device = parameters[0].grad.device
+ if norm_type == inf:
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
+ else:
+ total_norm = torch.norm(
+ torch.stack(
+ [torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]
+ ),
+ norm_type,
+ )
+ return total_norm
+
+
+def save_model(
+ args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None
+):
+ """
+ Save model checkpoint to disk.
+
+ This function saves the model state, optimizer state, loss scaler state,
+ training arguments, current epoch, and optionally the best metric value so far.
+ The checkpoint is only saved on the master process in distributed training.
+
+ Args:
+ args: Arguments containing output directory information
+ epoch (int): Current training epoch
+ model_without_ddp (torch.nn.Module): Model without DistributedDataParallel wrapper
+ optimizer (torch.optim.Optimizer): Optimizer instance
+ loss_scaler: Gradient scaler for mixed precision training
+ fname (str, optional): Custom filename suffix. If None, uses the epoch number. Defaults to None.
+ best_so_far (float, optional): Best metric value achieved so far. Defaults to None.
+ """
+ output_dir = Path(args.output_dir)
+ if fname is None:
+ fname = str(epoch)
+ checkpoint_path = output_dir / ("checkpoint-%s.pth" % fname)
+ to_save = {
+ "model": model_without_ddp.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "scaler": loss_scaler.state_dict(),
+ "args": args,
+ "epoch": epoch,
+ }
+ if best_so_far is not None:
+ to_save["best_so_far"] = best_so_far
+ print(f">> Saving model to {checkpoint_path} ...")
+ save_on_master(to_save, checkpoint_path)
+
+
+def load_model(train_args, model_without_ddp, optimizer, loss_scaler):
+ """
+ Load model checkpoint from disk or URL.
+
+ This function loads a saved checkpoint, restoring the model state, optimizer state,
+ loss scaler state, and training epoch. It can load from a local file or a URL.
+
+ Args:
+ train_args: Training arguments containing resume information
+ model_without_ddp (torch.nn.Module): Model without DistributedDataParallel wrapper
+ optimizer (torch.optim.Optimizer): Optimizer instance
+ loss_scaler: Gradient scaler for mixed precision training
+
+ Returns:
+ float or None: Best metric value from the checkpoint if available, otherwise None
+ """
+ train_args.start_epoch = 0
+ best_so_far = None
+ if train_args.resume and train_args.resume_ckpt is not None:
+ if train_args.resume_ckpt.startswith("https"):
+ checkpoint = torch.hub.load_state_dict_from_url(
+ train_args.resume_ckpt, map_location="cpu", check_hash=True
+ )
+ else:
+ checkpoint = torch.load(
+ train_args.resume_ckpt, map_location="cpu", weights_only=False
+ )
+ print("Resume checkpoint %s" % train_args.resume_ckpt)
+ model_without_ddp.load_state_dict(checkpoint["model"], strict=False)
+ train_args.start_epoch = checkpoint["epoch"] + 1
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ if "scaler" in checkpoint:
+ loss_scaler.load_state_dict(checkpoint["scaler"])
+ if "best_so_far" in checkpoint:
+ best_so_far = checkpoint["best_so_far"]
+ print(" & best_so_far={:g}".format(best_so_far))
+ else:
+ print("")
+ print(
+ "With optim & sched! start_epoch={:d}".format(train_args.start_epoch),
+ end="",
+ )
+ return best_so_far
+
+
+def all_reduce_mean(x):
+ """
+ Compute the mean of a value across all processes in distributed training.
+
+ This function takes a value, reduces it across all processes using all_reduce,
+ and returns the mean value.
+
+ Args:
+ x: The value to reduce (typically a scalar)
+
+ Returns:
+ float: The mean value across all processes
+ """
+ world_size = get_world_size()
+ if world_size > 1:
+ x_reduce = torch.tensor(x).cuda()
+ dist.all_reduce(x_reduce)
+ x_reduce /= world_size
+ return x_reduce.item()
+ else:
+ return x
+
+
+def _replace(text, src, tgt, rm=""):
+ """
+ Advanced string replacement utility.
+
+ Given a text:
+ - replace all elements in src by the corresponding element in tgt
+ - remove all elements in rm
+
+ Args:
+ text (str): The input text to modify
+ src (str): String of characters to replace
+ tgt (str): String of replacement characters (must be same length as src or length 1)
+ rm (str, optional): String of characters to remove. Defaults to "".
+
+ Returns:
+ str: The modified text after replacements and removals
+
+ Raises:
+ AssertionError: If src and tgt have different lengths (unless tgt has length 1)
+ """
+ if len(tgt) == 1:
+ tgt = tgt * len(src)
+ assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len"
+ for s, t in zip(src, tgt):
+ text = text.replace(s, t)
+ for c in rm:
+ text = text.replace(c, "")
+ return text
+
+
+def filename(obj):
+ """
+ Transform a Python object or command into a proper filename.
+
+ This function converts a Python object or command string into a valid filename
+ by replacing special characters and ensuring the filename is not too long.
+
+ Special replacements:
+ - \1 gets replaced by slash '/'
+ - \2 gets replaced by comma ','
+
+ Args:
+ obj: The Python object or string to convert to a filename
+
+ Returns:
+ str: A valid filename derived from the input object
+
+ Raises:
+ AssertionError: If any part of the resulting path is longer than 256 characters
+ """
+ if not isinstance(obj, str):
+ obj = repr(obj)
+ obj = str(obj).replace("()", "")
+ obj = _replace(obj, "_,(*/\1\2", "-__x%/,", rm=" )'\"")
+ assert all(len(s) < 256 for s in obj.split(os.sep)), (
+ "filename too long (>256 characters):\n" + obj
+ )
+ return obj
+
+
+def compute_effective_lrs(train_args):
+ """
+ Compute the effective learning rates based on batch size scaling.
+
+ This function calculates the effective learning rates for the main model and
+ any submodules based on the effective batch size (accounting for gradient accumulation
+ and distributed training) and the base learning rates.
+
+ Args:
+ train_args: Training arguments containing batch size, accumulation iterations,
+ learning rates, and submodule configurations
+
+ Returns:
+ train_args: Updated training arguments with computed effective learning rates
+ """
+
+ # Compute the effective batch size
+ eff_batch_size = train_args.batch_size * train_args.accum_iter * get_world_size()
+ print("Accumulate grad iterations: %d" % train_args.accum_iter)
+ print("Effective batch size: %d" % eff_batch_size)
+ # Compute the effective default learning rate
+ if train_args.lr is None: # only base_lr is specified
+ train_args.lr = train_args.blr * math.sqrt(
+ eff_batch_size / train_args.base_eff_batch_size
+ )
+ print(
+ f"Base default lr for effective batch size {eff_batch_size}: %.2e"
+ % (train_args.lr * math.sqrt(train_args.base_eff_batch_size / eff_batch_size))
+ )
+ print("Actual default lr: %.2e" % train_args.lr)
+ for submodule, config in train_args.submodule_configs.items():
+ if config.get("lr") is None: # only base_lr is specified
+ config["lr"] = config["blr"] * math.sqrt(
+ eff_batch_size / train_args.base_eff_batch_size
+ )
+ print(
+ f"Submodule {submodule} base lr for effective batch size {eff_batch_size}: %.2e"
+ % (
+ config["lr"]
+ * math.sqrt(train_args.base_eff_batch_size / eff_batch_size)
+ )
+ )
+ print(f"Submodule {submodule} actual lr: %.2e" % config["lr"])
+
+ return train_args
+
+
+def get_parameter_groups(
+ model,
+ lr,
+ weight_decay,
+ skip_list=[],
+ submodule_configs=None,
+ warn_not_in_submodule=False,
+):
+ """
+ Get parameter groups for optimizer with customized learning rates and weight decay.
+
+ This function organizes model parameters into groups for the optimizer, allowing
+ different learning rates and weight decay values for different parts of the model.
+ Parameters are grouped by:
+ 1. Whether they should have weight decay applied (bias terms and 1D tensors typically don't)
+ 2. Which submodule they belong to (if submodule_configs is provided)
+
+ Args:
+ model (torch.nn.Module): Model to get parameter groups for
+ lr (float): Default learning rate for parameters not in submodule_configs
+ weight_decay (float): Default weight decay for parameters not in submodule_configs
+ skip_list (list): List of parameter names to skip weight decay for
+ submodule_configs (dict, optional): Dictionary mapping submodule prefixes to configs
+ with 'lr' and 'weight_decay' keys
+ warn_not_in_submodule (bool, optional): Whether to warn if a parameter does not
+ belong to any submodule. Defaults to False.
+
+ Returns:
+ tuple: A tuple containing:
+ - parameter_group_vars (list): List of parameter groups for optimizer
+ - parameter_group_name_to_idx_map (dict): Mapping from submodule name to parameter group indices
+ - parameter_group_idx_to_name_map (dict): Mapping from parameter group index to submodule name
+ """
+
+ if submodule_configs is None:
+ submodule_configs = {}
+
+ parameter_group_names = {}
+ parameter_group_vars = {}
+ parameter_group_name_to_idx_map = {}
+ parameter_group_idx_to_name_map = {}
+ mapping_index = 0
+
+ for name, param in model.named_parameters():
+ # Skip frozen parameters
+ if not param.requires_grad:
+ continue
+
+ # Determine the submodule this parameter belongs to
+ submodule_name = None
+ for submodule, config in submodule_configs.items():
+ if name.startswith(submodule):
+ submodule_name = submodule
+ break
+
+ if submodule_name:
+ config = submodule_configs[submodule_name]
+ this_weight_decay = config.get("weight_decay", weight_decay)
+ this_lr = config.get("lr", lr)
+ # Freeze the parameters if lr is 0
+ if this_lr == 0:
+ param.requires_grad = False
+ continue
+ else:
+ this_weight_decay = weight_decay
+ this_lr = lr
+ if warn_not_in_submodule and submodule_configs is not None:
+ print(
+ f"Warning: Parameter {name} does not belong to any submodule in {submodule_configs.keys()}."
+ )
+
+ # Assign weight decay values
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
+ group_name = f"{submodule_name}_no_decay" if submodule_name else "no_decay"
+ this_weight_decay = 0.0
+ else:
+ group_name = f"{submodule_name}_decay" if submodule_name else "decay"
+
+ if group_name not in parameter_group_names:
+ parameter_group_names[group_name] = {
+ "weight_decay": this_weight_decay,
+ "lr": this_lr,
+ "params": [],
+ }
+ parameter_group_vars[group_name] = {
+ "weight_decay": this_weight_decay,
+ "lr": this_lr,
+ "params": [],
+ }
+ submodule_name_mapping = submodule_name if submodule_name else "default"
+ if submodule_name_mapping not in parameter_group_name_to_idx_map:
+ parameter_group_name_to_idx_map[submodule_name_mapping] = [
+ mapping_index
+ ]
+ else:
+ parameter_group_name_to_idx_map[submodule_name_mapping].append(
+ mapping_index
+ )
+ parameter_group_idx_to_name_map[mapping_index] = submodule_name_mapping
+ mapping_index += 1
+
+ parameter_group_vars[group_name]["params"].append(param)
+ parameter_group_names[group_name]["params"].append(name)
+
+ # Print the parameter groups
+ print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+
+ return (
+ list(parameter_group_vars.values()),
+ parameter_group_name_to_idx_map,
+ parameter_group_idx_to_name_map,
+ )
+
+
+def adjust_learning_rate(
+ optimizer,
+ epoch,
+ train_args,
+ parameter_group_idx_to_name_map,
+ submodule_configs=None,
+):
+ """
+ Adjust the learning rate based on the schedule type and current epoch.
+
+ This function updates the learning rates for all parameter groups in the optimizer
+ according to the specified learning rate schedule. Different submodules can have
+ different learning rate schedules.
+
+ Currently supported schedule types:
+ - linear_warmup_half_cycle_cosine_decay: Linear warmup followed by cosine decay
+
+ Args:
+ optimizer (torch.optim.Optimizer): The optimizer to update
+ epoch (int): Current training epoch
+ train_args: Training arguments containing schedule type, warmup epochs, etc.
+ parameter_group_idx_to_name_map (dict): Mapping from parameter group index to submodule name
+ submodule_configs (dict, optional): Dictionary of submodule-specific configurations
+ for learning rate schedules
+
+ Raises:
+ ValueError: If an unsupported schedule type is specified
+ """
+
+ if submodule_configs is None:
+ submodule_configs = {}
+
+ for group_num, param_group in enumerate(optimizer.param_groups):
+ submodule_name = parameter_group_idx_to_name_map.get(group_num)
+
+ if submodule_name in submodule_configs:
+ config = submodule_configs[submodule_name]
+ lr = config.get("lr", train_args.lr)
+ warmup_epochs = config.get("warmup_epochs", train_args.warmup_epochs)
+ min_lr = config.get("min_lr", train_args.min_lr)
+ schedule_type = config.get("schedule_type", train_args.schedule_type)
+ else:
+ lr = train_args.lr
+ warmup_epochs = train_args.warmup_epochs
+ min_lr = train_args.min_lr
+ schedule_type = train_args.schedule_type
+
+ if schedule_type == "linear_warmup_half_cycle_cosine_decay":
+ if epoch < warmup_epochs:
+ lr = lr * epoch / warmup_epochs
+ else:
+ lr = min_lr + (lr - min_lr) * 0.5 * (
+ 1.0
+ + math.cos(
+ math.pi
+ * (epoch - warmup_epochs)
+ / (train_args.epochs - warmup_epochs)
+ )
+ )
+ else:
+ raise ValueError(f"Schedule type {schedule_type} not implemented")
+
+ param_group["lr"] = lr
+
+
+def debug_after_backward(
+ model,
+ check_missing_gradients=True,
+ check_gradient_mismatch=False,
+ target_size=(256, 256, 1, 1),
+ target_stride=(256, 1, 256, 256),
+):
+ """
+ Debugging function to check for gradient issues after backward pass.
+
+ This function performs two types of gradient debugging:
+ 1. Gradient mismatch: Checks for parameters with specific gradient shapes and strides
+ that might indicate incorrect gradient computation.
+ 2. Missing gradients: Identifies parameters that require gradients but didn't receive any.
+
+ Args:
+ model (torch.nn.Module): The model to check gradients for
+ check_missing_gradients (bool, optional): Whether to check for missing gradients. Defaults to True.
+ check_gradient_mismatch (bool, optional): Whether to check for gradient mismatches. Defaults to False.
+ target_size (tuple, optional): Target tensor size to check for gradient mismatch. Defaults to (256, 256, 1, 1).
+ target_stride (tuple, optional): Target tensor stride to check for gradient mismatch. Defaults to (256, 1, 256, 256).
+ """
+ # Debug for missing gradients
+ if check_missing_gradients:
+ missing_grad_params = []
+ for name, param in model.named_parameters():
+ if param.requires_grad and param.grad is None:
+ missing_grad_params.append(name)
+
+ if missing_grad_params:
+ print("Parameters requiring gradients but missing gradients:")
+ for name in missing_grad_params:
+ print(f" - {name}")
+ else:
+ print("All parameters requiring gradients received gradients!")
+
+ # Debug for gradient mismatch
+ if check_gradient_mismatch:
+ for name, param in model.named_parameters():
+ grad = param.grad
+ if grad is None:
+ continue
+ if grad.size() == target_size and grad.stride() == target_stride:
+ print(f"Found parameter with incorrect gradient: '{name}'")
+ print(f"Gradient shape: {grad.size()}, strides: {grad.stride()}")
diff --git a/mapanything/utils/viz.py b/mapanything/utils/viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..e738f08cd3ad700ac6cf01fd8f8a1d3361fce46c
--- /dev/null
+++ b/mapanything/utils/viz.py
@@ -0,0 +1,167 @@
+"""
+Utility functions for visualization
+"""
+
+from argparse import ArgumentParser, Namespace
+from distutils.util import strtobool
+
+import rerun as rr
+
+
+def log_data_to_rerun(image, depthmap, pose, intrinsics, base_name, mask=None):
+ """
+ Log camera and image data to Rerun visualization tool.
+
+ Parameters
+ ----------
+ image : numpy.ndarray
+ RGB image to be logged
+ depthmap : numpy.ndarray
+ Depth map corresponding to the image
+ pose : numpy.ndarray
+ 4x4 camera pose matrix with rotation (3x3) and translation (3x1)
+ intrinsics : numpy.ndarray
+ Camera intrinsic matrix
+ base_name : str
+ Base name for the logged entities in Rerun
+ mask : numpy.ndarray, optional
+ Optional segmentation mask for the depth image
+ """
+ # Log camera info and loaded data
+ height, width = image.shape[0], image.shape[1]
+ rr.log(
+ base_name,
+ rr.Transform3D(
+ translation=pose[:3, 3],
+ mat3x3=pose[:3, :3],
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole",
+ rr.Pinhole(
+ image_from_camera=intrinsics,
+ height=height,
+ width=width,
+ camera_xyz=rr.ViewCoordinates.RDF,
+ ),
+ )
+ rr.log(
+ f"{base_name}/pinhole/rgb",
+ rr.Image(image),
+ )
+ rr.log(
+ f"{base_name}/pinhole/depth",
+ rr.DepthImage(depthmap),
+ )
+ if mask is not None:
+ rr.log(
+ f"{base_name}/pinhole/depth_mask",
+ rr.SegmentationImage(mask),
+ )
+
+
+def str2bool(v):
+ return bool(strtobool(v))
+
+
+def script_add_rerun_args(parser: ArgumentParser) -> None:
+ """
+ Add common Rerun script arguments to `parser`.
+
+ Change Log from https://github.com/rerun-io/rerun/blob/29eb8954b08e59ff96943dc0677f46f7ea4ea734/rerun_py/rerun_sdk/rerun/script_helpers.py#L65:
+ - Added default portforwarding url for ease of use
+ - Update parser types
+
+ Parameters
+ ----------
+ parser : ArgumentParser
+ The parser to add arguments to.
+
+ Returns
+ -------
+ None
+ """
+ parser.add_argument(
+ "--headless",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=True,
+ help="Don't show GUI",
+ )
+ parser.add_argument(
+ "--connect",
+ dest="connect",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=True,
+ help="Connect to an external viewer",
+ )
+ parser.add_argument(
+ "--serve",
+ dest="serve",
+ type=str2bool,
+ nargs="?",
+ const=True,
+ default=False,
+ help="Serve a web viewer (WARNING: experimental feature)",
+ )
+ parser.add_argument(
+ "--url",
+ type=str,
+ default="rerun+http://127.0.0.1:2004/proxy",
+ help="Connect to this HTTP(S) URL",
+ )
+ parser.add_argument(
+ "--save", type=str, default=None, help="Save data to a .rrd file at this path"
+ )
+ parser.add_argument(
+ "-o",
+ "--stdout",
+ dest="stdout",
+ action="store_true",
+ help="Log data to standard output, to be piped into a Rerun Viewer",
+ )
+
+
+def init_rerun_args(
+ headless=True,
+ connect=True,
+ serve=False,
+ url="rerun+http://127.0.0.1:2004/proxy",
+ save=None,
+ stdout=False,
+) -> Namespace:
+ """
+ Initialize common Rerun script arguments.
+
+ Parameters
+ ----------
+ headless : bool, optional
+ Don't show GUI, by default True
+ connect : bool, optional
+ Connect to an external viewer, by default True
+ serve : bool, optional
+ Serve a web viewer (WARNING: experimental feature), by default False
+ url : str, optional
+ Connect to this HTTP(S) URL, by default rerun+http://127.0.0.1:2004/proxy
+ save : str, optional
+ Save data to a .rrd file at this path, by default None
+ stdout : bool, optional
+ Log data to standard output, to be piped into a Rerun Viewer, by default False
+
+ Returns
+ -------
+ Namespace
+ The parsed arguments.
+ """
+ rerun_args = Namespace()
+ rerun_args.headless = headless
+ rerun_args.connect = connect
+ rerun_args.serve = serve
+ rerun_args.url = url
+ rerun_args.save = save
+ rerun_args.stdout = stdout
+
+ return rerun_args
diff --git a/mapanything/utils/wai/__init__.py b/mapanything/utils/wai/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c22e86818502c8a34ae795b710cafd1c3157c8f4
--- /dev/null
+++ b/mapanything/utils/wai/__init__.py
@@ -0,0 +1,3 @@
+"""
+This utils module contains PORTAGE of wai-core scripts/methods for MapAnything.
+"""
diff --git a/mapanything/utils/wai/basic_dataset.py b/mapanything/utils/wai/basic_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e8dd622afb4f4cf87e12a93c82a5526a74522c0
--- /dev/null
+++ b/mapanything/utils/wai/basic_dataset.py
@@ -0,0 +1,131 @@
+from pathlib import Path
+from typing import Any
+
+import torch
+from box import Box
+
+from mapanything.utils.wai.core import get_frame_index, load_data, load_frame
+from mapanything.utils.wai.ops import stack
+from mapanything.utils.wai.scene_frame import get_scene_frame_names
+
+
+class BasicSceneframeDataset(torch.utils.data.Dataset):
+ """Basic wai dataset to iterative over frames of scenes"""
+
+ @staticmethod
+ def collate_fn(batch: list[dict[str, Any]]) -> dict[str, Any]:
+ return stack(batch)
+
+ def __init__(
+ self,
+ cfg: Box,
+ ):
+ """
+ Initialize the BasicSceneframeDataset.
+
+ Args:
+ cfg (Box): Configuration object containing dataset parameters including:
+ - root: Root directory containing scene data
+ - frame_modalities: List of modalities to load for each frame
+ - key_remap: Optional dictionary mapping original keys to new keys
+ """
+ super().__init__()
+ self.cfg = cfg
+ self.root = cfg.root
+ keyframes = cfg.get("use_keyframes", True)
+ self.scene_frame_names = get_scene_frame_names(cfg, keyframes=keyframes)
+ self.scene_frame_list = [
+ (scene_name, frame_name)
+ for scene_name, frame_names in self.scene_frame_names.items()
+ for frame_name in frame_names
+ ]
+ self._scene_cache = {}
+
+ def __len__(self):
+ """
+ Get the total number of scene-frame pairs in the dataset.
+
+ Returns:
+ int: The number of scene-frame pairs.
+ """
+ return len(self.scene_frame_list)
+
+ def _load_scene(self, scene_name: str) -> dict[str, Any]:
+ """
+ Load scene data for a given scene name.
+
+ Args:
+ scene_name (str): The name of the scene to load.
+
+ Returns:
+ dict: A dictionary containing scene data, including scene metadata.
+ """
+ # load scene data
+ scene_data = {}
+ scene_data["meta"] = load_data(
+ Path(
+ self.root,
+ scene_name,
+ self.cfg.get("scene_meta_path", "scene_meta.json"),
+ ),
+ "scene_meta",
+ )
+
+ return scene_data
+
+ def _load_scene_frame(
+ self, scene_name: str, frame_name: str | float
+ ) -> dict[str, Any]:
+ """
+ Load data for a specific frame from a specific scene.
+
+ This method loads scene data if not already cached, then loads the specified frame
+ from that scene with the modalities specified in the configuration.
+
+ Args:
+ scene_name (str): The name of the scene containing the frame.
+ frame_name (str or float): The name/timestamp of the frame to load.
+
+ Returns:
+ dict: A dictionary containing the loaded frame data with requested modalities.
+ """
+ scene_frame_data = {}
+ if not (scene_data := self._scene_cache.get(scene_name)):
+ scene_data = self._load_scene(scene_name)
+ # for now only cache the last scene
+ self._scene_cache = {}
+ self._scene_cache[scene_name] = scene_data
+
+ frame_idx = get_frame_index(scene_data["meta"], frame_name)
+
+ scene_frame_data["scene_name"] = scene_name
+ scene_frame_data["frame_name"] = frame_name
+ scene_frame_data["scene_path"] = str(Path(self.root, scene_name))
+ scene_frame_data["frame_idx"] = frame_idx
+ scene_frame_data.update(
+ load_frame(
+ Path(self.root, scene_name),
+ frame_name,
+ modalities=self.cfg.frame_modalities,
+ scene_meta=scene_data["meta"],
+ )
+ )
+ # Remap key names
+ for key, new_key in self.cfg.get("key_remap", {}).items():
+ if key in scene_frame_data:
+ scene_frame_data[new_key] = scene_frame_data.pop(key)
+
+ return scene_frame_data
+
+ def __getitem__(self, index: int) -> dict[str, Any]:
+ """
+ Get a specific scene-frame pair by index.
+
+ Args:
+ index (int): The index of the scene-frame pair to retrieve.
+
+ Returns:
+ dict: A dictionary containing the loaded frame data with requested modalities.
+ """
+ scene_frame = self._load_scene_frame(*self.scene_frame_list[index])
+ return scene_frame
diff --git a/mapanything/utils/wai/camera.py b/mapanything/utils/wai/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eb65ea9edfb6b14f2f4765e5570fbacd96c26b0
--- /dev/null
+++ b/mapanything/utils/wai/camera.py
@@ -0,0 +1,263 @@
+"""
+This utils script contains PORTAGE of wai-core camera methods for MapAnything.
+"""
+
+from typing import Any
+
+import numpy as np
+import torch
+from scipy.spatial.transform import Rotation, Slerp
+
+from mapanything.utils.wai.ops import get_dtype_device
+
+# constants regarding camera models
+PINHOLE_CAM_KEYS = ["fl_x", "fl_y", "cx", "cy", "h", "w"]
+DISTORTION_PARAM_KEYS = [
+ "k1",
+ "k2",
+ "k3",
+ "k4",
+ "p1",
+ "p2",
+] # order corresponds to the OpenCV convention
+CAMERA_KEYS = PINHOLE_CAM_KEYS + DISTORTION_PARAM_KEYS
+
+
+def interpolate_intrinsics(
+ frame1: dict[str, Any],
+ frame2: dict[str, Any],
+ alpha: float,
+) -> dict[str, Any]:
+ """
+ Interpolate camera intrinsics linearly.
+ Args:
+ frame1: The first frame dictionary.
+ frame2: The second frame dictionary.
+ alpha: Interpolation parameter. alpha = 0 for frame1, alpha = 1 for frame2.
+ Returns:
+ frame_inter: dictionary with new intrinsics.
+ """
+ frame_inter = {}
+ for key in CAMERA_KEYS:
+ if key in frame1 and key in frame2:
+ p1 = frame1[key]
+ p2 = frame2[key]
+ frame_inter[key] = (1 - alpha) * p1 + alpha * p2
+ return frame_inter
+
+
+def interpolate_extrinsics(
+ matrix1: list | np.ndarray | torch.Tensor,
+ matrix2: list | np.ndarray | torch.Tensor,
+ alpha: float,
+) -> list | np.ndarray | torch.Tensor:
+ """
+ Interpolate camera extrinsics 4x4 matrices using SLERP.
+ Args:
+ matrix1: The first matrix.
+ matrix2: The second matrix.
+ alpha: Interpolation parameter. alpha = 0 for matrix1, alpha = 1 for matrix2.
+ Returns:
+ matrix: 4x4 interpolated matrix, same type.
+ Raises:
+ ValueError: If different type.
+ """
+ if not isinstance(matrix1, type(matrix2)):
+ raise ValueError("Both matrices should have the same type.")
+
+ dtype, device = get_dtype_device(matrix1)
+ if isinstance(matrix1, list):
+ mtype = "list"
+ matrix1 = np.array(matrix1)
+ matrix2 = np.array(matrix2)
+ elif isinstance(matrix1, np.ndarray):
+ mtype = "numpy"
+ elif isinstance(matrix1, torch.Tensor):
+ mtype = "torch"
+ matrix1 = matrix1.numpy()
+ matrix2 = matrix2.numpy()
+ else:
+ raise ValueError(
+ "Only list, numpy array and torch tensors are supported as inputs."
+ )
+
+ R1 = matrix1[:3, :3]
+ t1 = matrix1[:3, 3]
+ R2 = matrix2[:3, :3]
+ t2 = matrix2[:3, 3]
+
+ # interpolate translation
+ t = (1 - alpha) * t1 + alpha * t2
+
+ # interpolate rotations with SLERP
+ R1_quat = Rotation.from_matrix(R1).as_quat()
+ R2_quat = Rotation.from_matrix(R2).as_quat()
+ rotation_slerp = Slerp([0, 1], Rotation(np.stack([R1_quat, R2_quat])))
+ R = rotation_slerp(alpha).as_matrix()
+ matrix_inter = np.eye(4)
+
+ # combine together
+ matrix_inter[:3, :3] = R
+ matrix_inter[:3, 3] = t
+
+ if mtype == "list":
+ matrix_inter = matrix_inter.tolist()
+ elif mtype == "torch":
+ matrix_inter = torch.from_numpy(matrix_inter).to(dtype).to(device)
+ elif mtype == "numpy":
+ matrix_inter = matrix_inter.astype(dtype)
+
+ return matrix_inter
+
+
+def convert_camera_coeffs_to_pinhole_matrix(
+ scene_meta, frame, fmt="torch"
+) -> torch.Tensor | np.ndarray | list:
+ """
+ Convert camera intrinsics from NeRFStudio format to a 3x3 intrinsics matrix.
+
+ Args:
+ scene_meta: Scene metadata containing camera parameters
+ frame: Frame-specific camera parameters that override scene_meta
+
+ Returns:
+ torch.Tensor: 3x3 camera intrinsics matrix
+
+ Raises:
+ ValueError: If camera model is not PINHOLE or if distortion coefficients are present
+ """
+ # Check if camera model is supported
+ camera_model = frame.get("camera_model", scene_meta.get("camera_model"))
+ if camera_model != "PINHOLE":
+ raise ValueError("Only PINHOLE camera model supported")
+
+ # Check for unsupported distortion coefficients
+ if any(
+ (frame.get(coeff, 0) != 0) or (scene_meta.get(coeff, 0) != 0)
+ for coeff in DISTORTION_PARAM_KEYS
+ ):
+ raise ValueError(
+ "Pinhole camera does not support radial/tangential distortion -> Undistort first"
+ )
+
+ # Extract camera intrinsic parameters
+ camera_coeffs = {}
+ for coeff in ["fl_x", "fl_y", "cx", "cy"]:
+ camera_coeffs[coeff] = frame.get(coeff, scene_meta.get(coeff))
+ if camera_coeffs[coeff] is None:
+ raise ValueError(f"Missing required camera parameter: {coeff}")
+
+ # Create intrinsics matrix
+ intrinsics = [
+ [camera_coeffs["fl_x"], 0.0, camera_coeffs["cx"]],
+ [0.0, camera_coeffs["fl_y"], camera_coeffs["cy"]],
+ [0.0, 0.0, 1.0],
+ ]
+ if fmt == "torch":
+ intrinsics = torch.tensor(intrinsics)
+ elif fmt == "np":
+ intrinsics = np.array(intrinsics)
+
+ return intrinsics
+
+
+def rotate_pinhole_90degcw(
+ W: int, H: int, fx: float, fy: float, cx: float, cy: float
+) -> tuple[int, int, float, float, float, float]:
+ """Rotates the intrinsics of a pinhole camera model by 90 degrees clockwise."""
+ W_new = H
+ H_new = W
+ fx_new = fy
+ fy_new = fx
+ cy_new = cx
+ cx_new = H - 1 - cy
+ return W_new, H_new, fx_new, fy_new, cx_new, cy_new
+
+
+def _gl_cv_cmat() -> np.ndarray:
+ cmat = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
+ return cmat
+
+
+def _apply_transformation(
+ c2ws: torch.Tensor | np.ndarray, cmat: np.ndarray
+) -> torch.Tensor | np.ndarray:
+ """
+ Convert camera poses using a provided conversion matrix.
+
+ Args:
+ c2ws (torch.Tensor or np.ndarray): Camera poses (batch_size, 4, 4) or (4, 4)
+ cmat (torch.Tensor or np.ndarray): Conversion matrix (4, 4)
+
+ Returns:
+ torch.Tensor or np.ndarray: Transformed camera poses (batch_size, 4, 4) or (4, 4)
+ """
+ if isinstance(c2ws, torch.Tensor):
+ # Clone the input tensor to avoid modifying it in-place
+ c2ws_transformed = c2ws.clone()
+ # Apply the conversion matrix to the rotation part of the camera poses
+ if len(c2ws.shape) == 3:
+ c2ws_transformed[:, :3, :3] = c2ws_transformed[
+ :, :3, :3
+ ] @ torch.from_numpy(cmat[:3, :3]).to(c2ws).unsqueeze(0)
+ else:
+ c2ws_transformed[:3, :3] = c2ws_transformed[:3, :3] @ torch.from_numpy(
+ cmat[:3, :3]
+ ).to(c2ws)
+
+ elif isinstance(c2ws, np.ndarray):
+ # Clone the input array to avoid modifying it in-place
+ c2ws_transformed = c2ws.copy()
+ if len(c2ws.shape) == 3: # batched
+ # Apply the conversion matrix to the rotation part of the camera poses
+ c2ws_transformed[:, :3, :3] = np.einsum(
+ "ijk,lk->ijl", c2ws_transformed[:, :3, :3], cmat[:3, :3]
+ )
+ else: # single 4x4 matrix
+ # Apply the conversion matrix to the rotation part of the camera pose
+ c2ws_transformed[:3, :3] = np.dot(c2ws_transformed[:3, :3], cmat[:3, :3])
+
+ else:
+ raise ValueError("Input data type not supported.")
+
+ return c2ws_transformed
+
+
+def gl2cv(
+ c2ws: torch.Tensor | np.ndarray,
+ return_cmat: bool = False,
+) -> torch.Tensor | np.ndarray | tuple[torch.Tensor | np.ndarray, np.ndarray]:
+ """
+ Convert camera poses from OpenGL to OpenCV coordinate system.
+
+ Args:
+ c2ws (torch.Tensor or np.ndarray): Camera poses (batch_size, 4, 4) or (4, 4)
+ return_cmat (bool): If True, return the conversion matrix along with the transformed poses
+
+ Returns:
+ torch.Tensor or np.ndarray: Transformed camera poses (batch_size, 4, 4) or (4, 4)
+ np.ndarray (optional): Conversion matrix if return_cmat is True
+ """
+ cmat = _gl_cv_cmat()
+ if return_cmat:
+ return _apply_transformation(c2ws, cmat), cmat
+ return _apply_transformation(c2ws, cmat)
+
+
+def intrinsics_to_fov(
+ fx: torch.Tensor, fy: torch.Tensor, h: torch.Tensor, w: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Compute the horizontal and vertical fields of view in radians from camera intrinsics.
+
+ Args:
+ fx (torch.Tensor): focal x
+ fy (torch.Tensor): focal y
+ h (torch.Tensor): Image height(s) with shape (B,).
+ w (torch.Tensor): Image width(s) with shape (B,).
+
+ Returns:
+ tuple[torch.Tensor, torch.Tensor]: A tuple containing the horizontal and vertical fields
+ of view in radians, both with shape (N,).
+ """
+ return 2 * torch.atan((w / 2) / fx), 2 * torch.atan((h / 2) / fy)
diff --git a/mapanything/utils/wai/colormaps/colors_fps_5k.npz b/mapanything/utils/wai/colormaps/colors_fps_5k.npz
new file mode 100644
index 0000000000000000000000000000000000000000..7f259f39eb571a3c0848cc6a82d28a384e2254e6
--- /dev/null
+++ b/mapanything/utils/wai/colormaps/colors_fps_5k.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fae94fe5fb565ff40d1c556ae2640d00fc068e732cb4af5bb64eef034790e07c
+size 9478
diff --git a/mapanything/utils/wai/core.py b/mapanything/utils/wai/core.py
new file mode 100644
index 0000000000000000000000000000000000000000..56e22269cfd304c2d26e2b50e1b7754489dc3ba5
--- /dev/null
+++ b/mapanything/utils/wai/core.py
@@ -0,0 +1,492 @@
+"""
+This utils script contains PORTAGE of wai-core core methods for MapAnything.
+"""
+
+import logging
+import re
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import torch
+
+from mapanything.utils.wai.camera import (
+ CAMERA_KEYS,
+ convert_camera_coeffs_to_pinhole_matrix,
+ interpolate_extrinsics,
+ interpolate_intrinsics,
+)
+from mapanything.utils.wai.io import _get_method, _load_scene_meta
+from mapanything.utils.wai.ops import crop
+
+logger = logging.getLogger(__name__)
+
+WAI_COLORMAP_PATH = Path(__file__).parent / "colormaps"
+
+
+def load_data(fname: str | Path, format_type: str | None = None, **kwargs) -> Any:
+ """
+ Loads data from a file using the appropriate method based on the file format.
+
+ Args:
+ fname (str or Path): The filename or path to load data from.
+ format_type (str, optional): The format type of the data. If None, it will be inferred from the file extension if possible.
+ Supported formats include: 'readable', 'scalar', 'image', 'binary', 'depth', 'normals',
+ 'numpy', 'ptz', 'mmap', 'scene_meta', 'labeled_image', 'mesh', 'labeled_mesh', 'caption', "latents".
+ **kwargs: Additional keyword arguments to pass to the loading method.
+
+ Returns:
+ The loaded data in the format returned by the specific loading method.
+
+ Raises:
+ ValueError: If the format cannot be inferred from the file extension.
+ NotImplementedError: If the specified format is not supported.
+ FileExistsError: If the file does not exist.
+ """
+ load_method = _get_method(fname, format_type, load=True)
+ return load_method(fname, **kwargs)
+
+
+def store_data(
+ fname: str | Path,
+ data: Any,
+ format_type: str | None = None,
+ **kwargs,
+) -> Any:
+ """
+ Stores data to a file using the appropriate method based on the file format.
+
+ Args:
+ fname (str or Path): The filename or path to store data to.
+ data: The data to be stored.
+ format_type (str, optional): The format type of the data. If None, it will be inferred from the file extension.
+ **kwargs: Additional keyword arguments to pass to the storing method.
+
+ Returns:
+ The result of the storing method, which may vary depending on the method used.
+ """
+ store_method = _get_method(fname, format_type, load=False)
+ Path(fname).parent.mkdir(parents=True, exist_ok=True)
+ return store_method(fname, data, **kwargs)
+
+
+def get_frame(
+ scene_meta: dict[str, Any],
+ frame_key: int | str | float,
+) -> dict[str, Any]:
+ """
+ Get a frame from scene_meta based on name or index.
+
+ Args:
+ scene_meta: Dictionary containing scene metadata
+ frame_key: Either a string (frame name) or integer (frame index) or float (video timestamp)
+
+ Returns:
+ The frame data (dict)
+ """
+ frame_idx = get_frame_index(scene_meta, frame_key)
+ if isinstance(frame_idx, int):
+ frame = scene_meta["frames"][frame_idx]
+ frame["_is_interpolated"] = False
+ else:
+ frame = {}
+ frame["frame_name"] = frame_key
+ left = int(frame_idx) # it's floor operation
+ assert left >= 0 and left < (len(scene_meta["frames"]) - 1), "Wrong index"
+ frame_left = scene_meta["frames"][left]
+ frame_right = scene_meta["frames"][left + 1]
+ # Interpolate intrinsics and extrinsics
+ frame["transform_matrix"] = interpolate_extrinsics(
+ frame_left["transform_matrix"],
+ frame_right["transform_matrix"],
+ frame_idx - left,
+ )
+ frame.update(
+ interpolate_intrinsics(
+ frame_left,
+ frame_right,
+ frame_idx - left,
+ )
+ )
+ frame["_is_interpolated"] = True
+ return frame
+
+
+def get_intrinsics(
+ scene_meta,
+ frame_key,
+ fmt: str = "torch",
+) -> torch.Tensor | np.ndarray | list:
+ frame = get_frame(scene_meta, frame_key)
+ return convert_camera_coeffs_to_pinhole_matrix(scene_meta, frame, fmt=fmt)
+
+
+def get_extrinsics(
+ scene_meta,
+ frame_key,
+ fmt: str = "torch",
+) -> torch.Tensor | np.ndarray | list | None:
+ frame = get_frame(scene_meta, frame_key)
+ if "transform_matrix" in frame:
+ if fmt == "torch":
+ return torch.tensor(frame["transform_matrix"]).reshape(4, 4).float()
+ elif fmt == "np":
+ return np.array(frame["transform_matrix"]).reshape(4, 4)
+ return frame["transform_matrix"]
+ else:
+ # TODO: should not happen if we enable interpolation
+ return None
+
+
+def get_frame_index(
+ scene_meta: dict[str, Any],
+ frame_key: int | str | float,
+ frame_index_threshold_sec: float = 1e-4,
+ distance_threshold_sec: float = 2.0,
+) -> int | float:
+ """
+ Returns the frame index from scene_meta based on name (str) or index (int) or sub-frame index (float).
+
+ Args:
+ scene_meta: Dictionary containing scene metadata
+ frame_key: Either a string (frame name) or integer (frame index) or float (sub-frame index)
+ frame_index_threshold_sec: A threshold for nearest neighbor clipping for indexes (in seconds).
+ Default is 1e-4, which is 10000 fps.
+ distance_th: A threshold for maximum distance between interpolated frames (in seconds).
+
+ Returns:
+ Frame index (int)
+
+ Raises:
+ ValueError: If frame_key is not a string or integer or float
+ """
+ if isinstance(frame_key, str):
+ try:
+ return scene_meta["frame_names"][frame_key]
+ except KeyError as err:
+ error_message = (
+ f"Frame name not found: {frame_key} - "
+ f"Please verify scene_meta.json of scene: {scene_meta['dataset_name']}/{scene_meta['scene_name']}"
+ )
+ logger.error(error_message)
+ raise KeyError(error_message) from err
+
+ if isinstance(frame_key, int):
+ return frame_key
+
+ if isinstance(frame_key, float):
+ # If exact hit
+ if frame_key in scene_meta["frame_names"]:
+ return scene_meta["frame_names"][frame_key]
+
+ frame_names = sorted(list(scene_meta["frame_names"].keys()))
+ distances = np.array([frm - frame_key for frm in frame_names])
+ left = int(np.nonzero(distances <= 0)[0][-1])
+ right = left + 1
+
+ # The last frame or rounding errors
+ if (
+ left == distances.shape[0] - 1
+ or abs(distances[left]) < frame_index_threshold_sec
+ ):
+ return scene_meta["frame_names"][frame_names[int(left)]]
+ if abs(distances[right]) < frame_index_threshold_sec:
+ return scene_meta["frame_names"][frame_names[int(right)]]
+
+ interpolation_distance = distances[right] - distances[left]
+ if interpolation_distance > distance_threshold_sec:
+ raise ValueError(
+ f"Frame interpolation is forbidden for distances larger than {distance_threshold_sec}."
+ )
+ alpha = -distances[left] / interpolation_distance
+
+ return scene_meta["frame_names"][frame_names[int(left)]] + alpha
+
+ raise ValueError(f"Frame key type not supported: {frame_key} ({type(frame_key)}).")
+
+
+def load_modality_data(
+ scene_root: Path | str,
+ results: dict[str, Any],
+ modality_dict: dict[str, Any],
+ modality: str,
+ frame: dict[str, Any] | None = None,
+ fmt: str = "torch",
+) -> dict[str, Any]:
+ """
+ Processes a modality by loading data from a specified path and updating the results dictionary.
+ This function extracts the format and path from the given modality dictionary, loads the data
+ from the specified path, and updates the results dictionary with the loaded data.
+
+ Args:
+ scene_root (str or Path): The root directory of the scene where the data is located.
+ results (dict): A dictionary to store the loaded modality data and optional frame path.
+ modality_dict (dict): A dictionary containing the modality information, including 'format'
+ and the path to the data.
+ modality (str): The key under which the loaded modality data will be stored in the results.
+ frame (dict, optional): A dictionary containing frame information. If provided, that means we are loading
+ frame modalities, otherwise it is scene modalities.
+
+ Returns:
+ dict: The updated results dictionary containing the loaded modality data.
+ """
+ modality_format = modality_dict["format"]
+
+ # The modality is stored as a video
+ if "video" in modality_format:
+ assert isinstance(frame["frame_name"], float), "frame_name should be float"
+ video_file = None
+ if "chunks" in modality_dict:
+ video_list = modality_dict["chunks"]
+ # Get the correct chunk of the video
+ for video_chunk in video_list:
+ if video_chunk["start"] <= frame["frame_name"] <= video_chunk["end"]:
+ video_file = video_chunk
+ break
+ else:
+ # There is only one video (no chunks)
+ video_file = modality_dict
+ if "start" not in video_file:
+ video_file["start"] = 0
+ if "end" not in video_file:
+ video_file["end"] = float("inf")
+ if not (video_file["start"] <= frame["frame_name"] <= video_file["end"]):
+ video_file = None
+
+ # This timestamp is not available in any of the chunks
+ if video_file is None:
+ frame_name = frame["frame_name"]
+ logger.warning(
+ f"Modality {modality} ({modality_format}) is not available at time {frame_name}"
+ )
+ return results
+
+ # Load the modality from the video
+ loaded_modality = load_data(
+ Path(scene_root, video_file["file"]),
+ modality_format,
+ frame_key=frame["frame_name"] - video_file["start"],
+ )
+
+ if "bbox" in video_file:
+ loaded_modality = crop(loaded_modality, video_file["bbox"])
+
+ if loaded_modality is not None:
+ results[modality] = loaded_modality
+
+ if frame:
+ results[f"{modality}_fname"] = video_file["file"]
+ else:
+ modality_path = [v for k, v in modality_dict.items() if k != "format"][0]
+ if frame:
+ if modality_path in frame:
+ fname = frame[modality_path]
+ else:
+ fname = None
+ else:
+ fname = modality_path
+ if fname is not None:
+ loaded_modality = load_data(
+ Path(scene_root, fname),
+ modality_format,
+ frame_key=frame["frame_name"] if frame else None,
+ fmt=fmt,
+ )
+ results[modality] = loaded_modality
+ if frame:
+ results[f"{modality}_fname"] = frame[modality_path]
+ return results
+
+
+def load_modality(
+ scene_root: Path | str,
+ modality_meta: dict[str, Any],
+ modality: str,
+ frame: dict[str, Any] | None = None,
+ fmt: str = "torch",
+) -> dict[str, Any]:
+ """
+ Loads modality data based on the provided metadata and updates the results dictionary.
+ This function navigates through the modality metadata to find the specified modality,
+ then loads the data for each modality found.
+
+ Args:
+ scene_root (str or Path): The root directory of the scene where the data is located.
+ modality_meta (dict): A nested dictionary containing metadata for various modalities.
+ modality (str): A string representing the path to the desired modality within the metadata,
+ using '/' as a separator for nested keys.
+ frame (dict, optional): A dictionary containing frame information. If provided, we are operating
+ on frame modalities, otherwise it is scene modalities.
+
+ Returns:
+ dict: A dictionary containing the loaded modality data.
+ """
+ results = {}
+ # support for nested modalities like "pred_depth/metric3dv2"
+ modality_keys = modality.split("/")
+ current_modality = modality_meta
+ for key in modality_keys:
+ try:
+ current_modality = current_modality[key]
+ except KeyError as err:
+ error_message = (
+ f"Modality '{err.args[0]}' not found in modalities metadata. "
+ f"Please verify the scene_meta.json and the provided modalities in {scene_root}."
+ )
+ logger.error(error_message)
+ raise KeyError(error_message) from err
+ if "format" in current_modality:
+ results = load_modality_data(
+ scene_root, results, current_modality, modality, frame, fmt=fmt
+ )
+ else:
+ # nested modality, return last by default
+ logger.warning("Nested modality, returning last by default")
+ key = next(reversed(current_modality.keys()))
+ results = load_modality_data(
+ scene_root, results, current_modality[key], modality, frame, fmt=fmt
+ )
+ return results
+
+
+def load_frame(
+ scene_root: Path | str,
+ frame_key: int | str | float,
+ modalities: str | list[str] | None = None,
+ scene_meta: dict[str, Any] | None = None,
+ load_intrinsics: bool = True,
+ load_extrinsics: bool = True,
+ fmt: str = "torch",
+ interpolate: bool = False,
+) -> dict[str, Any]:
+ """
+ Load a single frame from a scene with specified modalities.
+
+ Args:
+ scene_root (str or Path): The root directory of the scene where the data is located.
+ frame_key (int or str or float): Either a string (frame name) or integer (frame index) or float (video timestamp).
+ modalities (str or list[str], optional): The modality or list of modalities to load.
+ If None, only basic frame information is loaded.
+ scene_meta (dict, optional): Dictionary containing scene metadata. If None, it will be loaded
+ from scene_meta.json in the scene_root.
+ interpolate (bool, optional): Allow interpolating frames?
+
+ Returns:
+ dict: A dictionary containing the loaded frame data with the requested modalities.
+ """
+ scene_root = Path(scene_root)
+ if scene_meta is None:
+ scene_meta = _load_scene_meta(scene_root / "scene_meta.json")
+ frame = get_frame(scene_meta, frame_key)
+ # compact, standarized frame representation
+ wai_frame = {}
+ if load_extrinsics:
+ extrinsics = get_extrinsics(
+ scene_meta,
+ frame_key,
+ fmt=fmt,
+ )
+ if extrinsics is not None:
+ wai_frame["extrinsics"] = extrinsics
+ if load_intrinsics:
+ camera_model = frame.get("camera_model", scene_meta.get("camera_model"))
+ wai_frame["camera_model"] = camera_model
+ if camera_model == "PINHOLE":
+ wai_frame["intrinsics"] = get_intrinsics(scene_meta, frame_key, fmt=fmt)
+ elif camera_model in ["OPENCV", "OPENCV_FISHEYE"]:
+ # optional per-frame intrinsics
+ for camera_key in CAMERA_KEYS:
+ if camera_key in frame:
+ wai_frame[camera_key] = float(frame[camera_key])
+ elif camera_key in scene_meta:
+ wai_frame[camera_key] = float(scene_meta[camera_key])
+ else:
+ error_message = (
+ f"Camera model not supported: {camera_model} - "
+ f"Please verify scene_meta.json of scene: {scene_meta['dataset_name']}/{scene_meta['scene_name']}"
+ )
+ logger.error(error_message)
+ raise NotImplementedError(error_message)
+ wai_frame["w"] = frame.get("w", scene_meta["w"] if "w" in scene_meta else None)
+ wai_frame["h"] = frame.get("h", scene_meta["h"] if "h" in scene_meta else None)
+ wai_frame["frame_name"] = frame["frame_name"]
+ wai_frame["frame_idx"] = get_frame_index(scene_meta, frame_key)
+ wai_frame["_is_interpolated"] = frame["_is_interpolated"]
+
+ if modalities is not None:
+ if isinstance(modalities, str):
+ modalities = [modalities]
+ for modality in modalities:
+ # Handle regex patterns in modality
+ if any(char in modality for char in ".|*+?()[]{}^$\\"):
+ # This is a regex pattern
+ pattern = re.compile(modality)
+ matching_modalities = [
+ m for m in scene_meta["frame_modalities"] if pattern.match(m)
+ ]
+ if not matching_modalities:
+ raise ValueError(
+ f"No modalities match the pattern: {modality} in scene: {scene_root}"
+ )
+ # Use the first matching modality
+ modality = matching_modalities[0]
+ current_modalities = load_modality(
+ scene_root, scene_meta["frame_modalities"], modality, frame, fmt=fmt
+ )
+ wai_frame.update(current_modalities)
+
+ return wai_frame
+
+
+def set_frame(
+ scene_meta: dict[str, Any],
+ frame_key: int | str,
+ new_frame: dict[str, Any],
+ sort: bool = False,
+) -> dict[str, Any]:
+ """
+ Replace a frame in scene_meta with a new frame.
+
+ Args:
+ scene_meta: Dictionary containing scene metadata.
+ frame_key: Either a string (frame name) or integer (frame index).
+ new_frame: New frame data to replace the existing frame.
+ sort: If True, sort the keys in the new_frame dictionary.
+
+ Returns:
+ Updated scene_meta dictionary.
+ """
+ frame_idx = get_frame_index(scene_meta, frame_key)
+ if isinstance(frame_idx, float):
+ raise ValueError(
+ f"Setting frame for sub-frame frame_key is not supported: {frame_key} ({type(frame_key)})."
+ )
+ if sort:
+ new_frame = {k: new_frame[k] for k in sorted(new_frame)}
+ scene_meta["frames"][frame_idx] = new_frame
+ return scene_meta
+
+
+def nest_modality(
+ frame_modalities: dict[str, Any],
+ modality_name: str,
+) -> dict[str, Any]:
+ """
+ Converts a flat modality structure into a nested one based on the modality name.
+
+ Args:
+ frame_modalities (dict): Dictionary containing frame modalities.
+ modality_name (str): The name of the modality to nest.
+
+ Returns:
+ dict: A dictionary with the nested modality structure.
+ """
+ frame_modality = {}
+ if modality_name in frame_modalities:
+ frame_modality = frame_modalities[modality_name]
+ if "frame_key" in frame_modality:
+ # required for backwards compatibility
+ # converting non-nested format into nested one based on name
+ modality_name = frame_modality["frame_key"].split("_")[0]
+ frame_modality = {modality_name: frame_modality}
+ return frame_modality
diff --git a/mapanything/utils/wai/intersection_check.py b/mapanything/utils/wai/intersection_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..c368f32da3f2f4606a7dcbce5be69e0af62e84ad
--- /dev/null
+++ b/mapanything/utils/wai/intersection_check.py
@@ -0,0 +1,462 @@
+import torch
+from einops import rearrange, repeat
+from tqdm import tqdm
+
+
+def create_frustum_from_intrinsics(
+ intrinsics: torch.Tensor,
+ near: torch.Tensor | float,
+ far: torch.Tensor | float,
+) -> torch.Tensor:
+ r"""
+ Create a frustum from camera intrinsics.
+
+ Args:
+ intrinsics (torch.Tensor): Bx3x3 Intrinsics of cameras.
+ near (torch.Tensor or float): [B] Near plane distance.
+ far (torch.Tensor or float): [B] Far plane distance.
+
+ Returns:
+ frustum (torch.Tensor): Bx8x3 batch of frustum points following the order:
+ 5 ---------- 4
+ |\ /|
+ 6 \ / 7
+ \ 1 ---- 0 /
+ \| |/
+ 2 ---- 3
+ """
+
+ fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1]
+ cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2]
+
+ # Calculate the offsets at the near plane
+ near_x = near * (cx / fx)
+ near_y = near * (cy / fy)
+ far_x = far * (cx / fx)
+ far_y = far * (cy / fy)
+
+ # Define frustum vertices in camera space
+ near_plane = torch.stack(
+ [
+ torch.stack([near_x, near_y, near * torch.ones_like(near_x)], dim=-1),
+ torch.stack([-near_x, near_y, near * torch.ones_like(near_x)], dim=-1),
+ torch.stack([-near_x, -near_y, near * torch.ones_like(near_x)], dim=-1),
+ torch.stack([near_x, -near_y, near * torch.ones_like(near_x)], dim=-1),
+ ],
+ dim=1,
+ )
+
+ far_plane = torch.stack(
+ [
+ torch.stack([far_x, far_y, far * torch.ones_like(far_x)], dim=-1),
+ torch.stack([-far_x, far_y, far * torch.ones_like(far_x)], dim=-1),
+ torch.stack([-far_x, -far_y, far * torch.ones_like(far_x)], dim=-1),
+ torch.stack([far_x, -far_y, far * torch.ones_like(far_x)], dim=-1),
+ ],
+ dim=1,
+ )
+
+ return torch.cat([near_plane, far_plane], dim=1)
+
+
+def _frustum_to_triangles(frustum: torch.Tensor) -> torch.Tensor:
+ """
+ Convert frustum to triangles.
+
+ Args:
+ frustums (torch.Tensor): Bx8 batch of frustum points.
+
+ Returns:
+ frustum_triangles (torch.Tensor): Bx3x3 batch of frustum triangles.
+ """
+
+ triangle_inds = torch.tensor(
+ [
+ [0, 1, 2],
+ [0, 2, 3],
+ [0, 3, 7],
+ [0, 7, 4],
+ [1, 2, 6],
+ [1, 6, 5],
+ [1, 4, 5],
+ [1, 0, 4],
+ [2, 6, 7],
+ [2, 3, 7],
+ [6, 7, 4],
+ [6, 5, 4],
+ ]
+ )
+ frustum_triangles = frustum[:, triangle_inds]
+ return frustum_triangles
+
+
+def segment_triangle_intersection_check(
+ start_points: torch.Tensor,
+ end_points: torch.Tensor,
+ triangles: torch.Tensor,
+) -> torch.Tensor:
+ """
+ Check if segments (lines with starting and end point) intersect triangles in 3D using the
+ Moller-Trumbore algorithm.
+
+ Args:
+ start_points (torch.Tensor): Bx3 Starting points of the segment.
+ end_points (torch.Tensor): Bx3 End points of the segment.
+ triangles (torch.Tensor): Bx3x3 Vertices of the triangles.
+
+ Returns:
+ intersects (torch.Tensor): B Boolean tensor indicating if each ray intersects its
+ corresponding triangle.
+ """
+ vertex0 = triangles[:, 0, :]
+ vertex1 = triangles[:, 1, :]
+ vertex2 = triangles[:, 2, :]
+ edge1 = vertex1 - vertex0
+ edge2 = vertex2 - vertex0
+ ray_vectors = end_points - start_points
+ max_lengths = torch.norm(ray_vectors, dim=1)
+ ray_vectors = ray_vectors / max_lengths[:, None]
+ h = torch.cross(ray_vectors, edge2, dim=1)
+ a = (edge1 * h).sum(dim=1)
+
+ epsilon = 1e-6
+ mask = torch.abs(a) > epsilon
+ f = torch.zeros_like(a)
+ f[mask] = 1.0 / a[mask]
+
+ s = start_points - vertex0
+ u = f * (s * h).sum(dim=1)
+ q = torch.cross(s, edge1, dim=1)
+ v = f * (ray_vectors * q).sum(dim=1)
+
+ t = f * (edge2 * q).sum(dim=1)
+
+ # Check conditions
+ intersects = (
+ (u >= 0)
+ & (u <= 1)
+ & (v >= 0)
+ & (u + v <= 1)
+ & (t >= epsilon)
+ & (t <= max_lengths)
+ )
+
+ return intersects
+
+
+def triangle_intersection_check(
+ triangles1: torch.Tensor,
+ triangles2: torch.Tensor,
+) -> torch.Tensor:
+ """
+ Check if two triangles intersect.
+
+ Args:
+ triangles1 (torch.Tensor): Bx3x3 Vertices of the first batch of triangles.
+ triangles2 (torch.Tensor): Bx3x3 Vertices of the first batch of triangles.
+
+ Returns:
+ triangle_intersection (torch.Tensor): B Boolean tensor indicating if triangles intersect.
+ """
+ n = triangles1.shape[1]
+ start_points1 = rearrange(triangles1, "B N C -> (B N) C")
+ end_points1 = rearrange(
+ triangles1[:, torch.arange(1, n + 1) % n], "B N C -> (B N) C"
+ )
+
+ start_points2 = rearrange(triangles2, "B N C -> (B N) C")
+ end_points2 = rearrange(
+ triangles2[:, torch.arange(1, n + 1) % n], "B N C -> (B N) C"
+ )
+ intersection_1_2 = segment_triangle_intersection_check(
+ start_points1, end_points1, repeat(triangles2, "B N C -> (B N2) N C", N2=3)
+ )
+ intersection_2_1 = segment_triangle_intersection_check(
+ start_points2, end_points2, repeat(triangles1, "B N C -> (B N2) N C", N2=3)
+ )
+ triangle_intersection = torch.any(
+ rearrange(intersection_1_2, "(B N N2) -> B (N N2)", B=triangles1.shape[0], N=n),
+ dim=1,
+ ) | torch.any(
+ rearrange(intersection_2_1, "(B N N2) -> B (N N2)", B=triangles1.shape[0], N=n),
+ dim=1,
+ )
+ return triangle_intersection
+
+
+def frustum_intersection_check(
+ frustums: torch.Tensor,
+ check_inside: bool = True,
+ chunk_size: int = 500,
+ device: str | None = None,
+) -> torch.Tensor:
+ """
+ Check if any pair of the frustums intersect with each other.
+
+ Args:
+ frustums (torch.Tensor): Bx8 batch of frustum points.
+ check_inside (bool): If True, also checks if one frustum is inside another.
+ Defaults to True.
+ chunk_size (Optional[int]): Number of chunks to split the computation into.
+ Defaults to 500.
+ device (Optional[str]): Device to store exhuastive frustum intersection matrix on.
+ Defaults to None.
+
+ Returns:
+ frustum_intersection (torch.Tensor): BxB tensor of Booleans indicating if any pair
+ of frustums intersect with each other.
+ """
+ B = frustums.shape[0]
+ if device is None:
+ device = frustums.device
+ frustum_triangles = _frustum_to_triangles(frustums)
+ T = frustum_triangles.shape[1]
+
+ # Perform frustum in frustum check if required
+ if check_inside:
+ frustum_intersection = frustums_in_frustum_check(
+ frustums=frustums, chunk_size=chunk_size, device=device
+ )
+ else:
+ frustum_intersection = torch.zeros((B, B), dtype=torch.bool, device=device)
+
+ # Check triangle intersections in chunks
+ for i in tqdm(range(0, B, chunk_size), desc="Checking triangle intersections"):
+ i_end = min(i + chunk_size, B)
+ chunk_i_size = i_end - i
+
+ for j in range(0, B, chunk_size):
+ j_end = min(j + chunk_size, B)
+ chunk_j_size = j_end - j
+
+ # Process all triangle pairs between the two chunks in a vectorized way
+ triangles_i = frustum_triangles[i:i_end] # [chunk_i, T, 3, 3]
+ triangles_j = frustum_triangles[j:j_end] # [chunk_j, T, 3, 3]
+
+ # Reshape to process all triangle pairs at once
+ tri_i = triangles_i.reshape(chunk_i_size * T, 3, 3)
+ tri_j = triangles_j.reshape(chunk_j_size * T, 3, 3)
+
+ # Expand for all pairs - explicitly specify dimensions instead of using ...
+ tri_i_exp = repeat(tri_i, "bt i j -> (bt bj_t) i j", bj_t=chunk_j_size * T)
+ tri_j_exp = repeat(tri_j, "bt i j -> (bi_t bt) i j", bi_t=chunk_i_size * T)
+
+ # Check intersection
+ batch_intersect = triangle_intersection_check(tri_i_exp, tri_j_exp)
+
+ # Reshape and check if any triangle pair intersects
+ batch_intersect = batch_intersect.reshape(chunk_i_size, T, chunk_j_size, T)
+ batch_intersect = batch_intersect.any(dim=(1, 3))
+
+ # Update result
+ frustum_intersection[i:i_end, j:j_end] |= batch_intersect.to(device)
+
+ return frustum_intersection
+
+
+def ray_triangle_intersection_check(
+ ray_origins: torch.Tensor,
+ ray_vectors: torch.Tensor,
+ triangles: torch.Tensor,
+ max_lengths: torch.Tensor | None = None,
+) -> torch.Tensor:
+ """
+ Check if rays intersect triangles in 3D using the Moller-Trumbore algorithm, considering the
+ finite length of rays.
+
+ Args:
+ ray_origins (torch.Tensor): Bx3 Origins of the rays.
+ ray_vectors (torch.Tensor): Bx3 Direction vectors of the rays.
+ triangles (torch.Tensor): Bx3x3 Vertices of the triangles.
+ max_lengths Optional[torch.Tensor]: B Maximum lengths of the rays.
+
+ Returns:
+ intersects (torch.Tensor): B Boolean tensor indicating if each ray intersects its
+ corresponding triangle.
+ """
+ vertex0 = triangles[:, 0, :]
+ vertex1 = triangles[:, 1, :]
+ vertex2 = triangles[:, 2, :]
+ edge1 = vertex1 - vertex0
+ edge2 = vertex2 - vertex0
+ h = torch.cross(ray_vectors, edge2, dim=1)
+ a = (edge1 * h).sum(dim=1)
+
+ epsilon = 1e-6
+ mask = torch.abs(a) > epsilon
+ f = torch.zeros_like(a)
+ f[mask] = 1.0 / a[mask]
+
+ s = ray_origins - vertex0
+ u = f * (s * h).sum(dim=1)
+ q = torch.cross(s, edge1, dim=1)
+ v = f * (ray_vectors * q).sum(dim=1)
+
+ t = f * (edge2 * q).sum(dim=1)
+
+ # Check conditions
+ intersects = (u >= 0) & (u <= 1) & (v >= 0) & (u + v <= 1) & (t >= epsilon)
+ if max_lengths is not None:
+ intersects &= t <= max_lengths
+
+ return intersects
+
+
+#### Checks for frustums
+def _frustum_to_planes(frustums: torch.Tensor) -> torch.Tensor:
+ r"""
+ Converts frustum parameters to plane representation.
+
+ Args:
+ frustums (torch.Tensor): Bx8 batch of frustum points following the order:
+ 5 ---------- 4
+ |\ /|
+ 6 \ / 7
+ \ 1 ---- 0 /
+ \| |/
+ 2 ---- 3
+
+ Returns:
+ planes (torch.Tensor): Bx6x4 where 6 represents the six frustum planes and
+ 4 represents plane parameters [a, b, c, d].
+ """
+ planes = []
+ for inds in [[0, 1, 3], [1, 6, 2], [0, 3, 7], [2, 6, 3], [0, 5, 1], [6, 5, 4]]:
+ normal = torch.cross(
+ frustums[:, inds[1]] - frustums[:, inds[0]],
+ frustums[:, inds[2]] - frustums[:, inds[0]],
+ dim=1,
+ )
+ normal = normal / torch.norm(normal, dim=1, keepdim=True)
+ d = -torch.sum(normal * frustums[:, inds[0]], dim=1, keepdim=True)
+ planes.append(torch.cat([normal, d], -1))
+ return torch.stack(planes, 1)
+
+
+def points_in_frustum_check(
+ frustums: torch.Tensor,
+ points: torch.Tensor,
+ chunk_size: int | None = None,
+ device: str | None = None,
+):
+ """
+ Check if points are inside frustums.
+
+ Args:
+ frustums (torch.Tensor): Bx8 batch of frustum points.
+ points (torch.Tensor): BxNx3 batch of points.
+ chunk_size (Optional[int]): Number of chunks to split the computation into. Defaults to None.
+ device (Optional[str]): Device to perfrom computation on. Defaults to None.
+
+ Returns:
+ inside (torch.Tensor): BxN batch of Booleans indicating if points are inside frustums.
+ """
+ if device is None:
+ device = frustums.device
+
+ if chunk_size is not None:
+ # Split computation into chunks to avoid OOM errors for large batch sizes
+ point_plane_direction = []
+ for chunk_idx in range(0, frustums.shape[0], chunk_size):
+ chunk_frustum_planes = _frustum_to_planes(
+ frustums[chunk_idx : chunk_idx + chunk_size]
+ )
+ # Bx8x4 tensor of plane parameters [a, b, c, d]
+ chunk_points = points[chunk_idx : chunk_idx + chunk_size]
+ chunk_point_plane_direction = torch.einsum(
+ "bij,bnj->bni", (chunk_frustum_planes[:, :, :-1], chunk_points)
+ ) + repeat(
+ chunk_frustum_planes[:, :, -1], "B P -> B N P", N=chunk_points.shape[1]
+ ) # BxMxN tensor
+ point_plane_direction.append(chunk_point_plane_direction.to(device))
+ point_plane_direction = torch.cat(point_plane_direction)
+ else:
+ # Convert frustums to planes
+ frustum_planes = _frustum_to_planes(
+ frustums
+ ) # Bx8x4 tensor of plane parameters [a, b, c, d]
+ # Compute dot product between each point and each plane
+ point_plane_direction = torch.einsum(
+ "bij,bnj->bni", (frustum_planes[:, :, :-1], points)
+ ) + repeat(frustum_planes[:, :, -1], "B P -> B N P", N=points.shape[1]).to(
+ device
+ ) # BxMxN tensor
+
+ inside = (point_plane_direction >= 0).all(-1)
+ return inside
+
+
+def frustums_in_frustum_check(
+ frustums: torch.Tensor,
+ chunk_size: int,
+ device: str | None = None,
+ use_double_chunking: bool = True,
+):
+ """
+ Check if frustums are contained within other frustums.
+
+ Args:
+ frustums (torch.Tensor): Bx8 batch of frustum points.
+ chunk_size (Optional[int]): Number of chunks to split the computation into.
+ Defaults to None.
+ device (Optional[str]): Device to store exhuastive frustum containment matrix on.
+ Defaults to None.
+ use_double_chunking (bool): If True, use double chunking to avoid OOM errors.
+ Defaults to True.
+
+ Returns:
+ frustum_contained (torch.Tensor): BxB batch of Booleans indiciating if frustums are inside
+ other frustums.
+ """
+ B = frustums.shape[0]
+ if device is None:
+ device = frustums.device
+
+ if use_double_chunking:
+ frustum_contained = torch.zeros((B, B), dtype=torch.bool, device=device)
+ # Check if frustums are containing each other by processing in chunks
+ for i in tqdm(range(0, B, chunk_size), desc="Checking frustum containment"):
+ i_end = min(i + chunk_size, B)
+ chunk_i_size = i_end - i
+
+ for j in range(0, B, chunk_size):
+ j_end = min(j + chunk_size, B)
+ chunk_j_size = j_end - j
+
+ # Process a chunk of frustums against another chunk
+ frustums_i = frustums[i:i_end]
+ frustums_j_vertices = frustums[
+ j:j_end, :1
+ ] # Just need one vertex to check containment
+
+ # Perform points in frustum check
+ contained = rearrange(
+ points_in_frustum_check(
+ repeat(frustums_i, "B ... -> (B B2) ...", B2=chunk_j_size),
+ repeat(
+ frustums_j_vertices, "B ... -> (B2 B) ...", B2=chunk_i_size
+ ),
+ )[:, 0],
+ "(B B2) -> B B2",
+ B=chunk_i_size,
+ ).to(device)
+
+ # Map results back to the full matrix
+ frustum_contained[i:i_end, j:j_end] |= contained
+ frustum_contained[j:j_end, i:i_end] |= contained.transpose(
+ 0, 1
+ ) # Symmetric relation
+ else:
+ # Perform points in frustum check with a single chunked loop
+ frustum_contained = rearrange(
+ points_in_frustum_check(
+ repeat(frustums, "B ... -> (B B2) ...", B2=B),
+ repeat(frustums[:, :1], "B ... -> (B2 B) ...", B2=B),
+ chunk_size=chunk_size,
+ )[:, 0],
+ "(B B2) -> B B2",
+ B=B,
+ ).to(device)
+ frustum_contained = frustum_contained | frustum_contained.T
+
+ return frustum_contained
diff --git a/mapanything/utils/wai/io.py b/mapanything/utils/wai/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d3359a6d6a92c50a755112d89f2fcfb7bf6dfe
--- /dev/null
+++ b/mapanything/utils/wai/io.py
@@ -0,0 +1,1373 @@
+"""
+This utils script contains PORTAGE of wai-core io methods for MapAnything.
+"""
+
+import gzip
+import io
+import json
+import logging
+import os
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Callable, cast, IO, Literal, overload
+
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+import cv2
+import numpy as np
+import torch
+import trimesh
+import yaml
+from PIL import Image, PngImagePlugin
+from plyfile import PlyData, PlyElement
+from safetensors.torch import load_file as load_sft, save_file as save_sft
+from torchvision.io import decode_image
+from yaml import CLoader
+
+from mapanything.utils.wai.ops import (
+ to_numpy,
+)
+from mapanything.utils.wai.semantics import (
+ apply_id_to_color_mapping,
+ INVALID_ID,
+ load_semantic_color_mapping,
+)
+
+# Try to use orjson for faster JSON processing
+try:
+ import orjson
+except ImportError:
+ orjson = None
+
+logger = logging.getLogger(__name__)
+
+
+@overload
+def _load_readable(
+ fname: Path | str, load_as_string: Literal[True], **kwargs
+) -> str: ...
+@overload
+def _load_readable(
+ fname: Path | str, load_as_string: Literal[False] = False, **kwargs
+) -> dict: ...
+
+
+def _load_readable(
+ fname: Path | str,
+ load_as_string: bool = False,
+ **kwargs,
+) -> Any | str:
+ """
+ Loads data from a human-readable file and will try to parse JSON or YAML files as a dict, list,
+ int, float, str, bool, or None object. Can optionally return the file contents as a string.
+
+ Args:
+ fname (str or Path): The filename to load data from.
+ load_as_string (bool, optional): Whether to return the loaded data as a string.
+ Defaults to False.
+
+ Returns:
+ The loaded data, which can be any type of object that can be represented in JSON or YAML.
+
+ Raises:
+ NotImplementedError: If the file suffix is not supported (i.e., not .json, .yaml, or .yml).
+ """
+ if load_as_string:
+ return _load_readable_string(fname, **kwargs)
+ else:
+ return _load_readable_structured(fname, **kwargs)
+
+
+def _load_readable_structured(
+ fname: Path | str,
+ **kwargs,
+) -> Any:
+ """
+ Loads data from a human-readable file and will try to parse JSON or YAML files as a dict, list,
+ int, float, str, bool, or None object.
+
+ Args:
+ fname (str or Path): The filename to load data from.
+
+ Returns:
+ The loaded data, which can be any type of object that can be represented in JSON or YAML.
+
+ Raises:
+ NotImplementedError: If the file suffix is not supported (i.e., not .json, .yaml, or .yml).
+ """
+ fname = Path(fname)
+ if not fname.exists():
+ raise FileNotFoundError(f"File does not exist: {fname}")
+
+ if fname.suffix == ".json":
+ # Use binary mode for JSON files
+ with open(fname, mode="rb") as f:
+ # Use orjson if available, otherwise use standard JSON
+ if orjson:
+ return orjson.loads(f.read())
+ return json.load(f)
+
+ if fname.suffix in [".yaml", ".yml"]:
+ # Use text mode with UTF-8 encoding for YAML files
+ with open(fname, mode="r", encoding="utf-8") as f:
+ return yaml.load(f, Loader=CLoader)
+
+ raise NotImplementedError(f"Readable format not supported: {fname.suffix}")
+
+
+def _load_readable_string(
+ fname: Path | str,
+ **kwargs,
+) -> str:
+ """
+ Loads data from a human-readable file as a string.
+
+ Args:
+ fname (str or Path): The filename to load data from.
+
+ Returns:
+ The file's contents, as a string.
+ """
+ fname = Path(fname)
+ if not fname.exists():
+ raise FileNotFoundError(f"File does not exist: {fname}")
+
+ with open(fname, mode="r", encoding="utf-8") as f:
+ contents = f.read()
+
+ return contents
+
+
+def _store_readable(
+ fname: Path | str,
+ data: Any,
+ **kwargs,
+) -> int:
+ """
+ Stores data in a human-readable file (JSON or YAML).
+
+ Args:
+ fname (str or Path): The filename to store data in.
+ data: The data to store, which can be any type of object that can be represented in JSON or YAML.
+
+ Returns:
+ The number of bytes written to the file.
+
+ Raises:
+ NotImplementedError: If the file suffix is not supported (i.e., not .json, .yaml, or .yml).
+ """
+ fname = Path(fname)
+
+ # Create parent directory if it doesn't exist
+ os.makedirs(fname.parent, exist_ok=True)
+
+ if fname.suffix == ".json":
+ if orjson:
+ # Define the operation for orjson
+ with open(fname, mode="wb") as f:
+ return f.write(orjson.dumps(data, option=orjson.OPT_INDENT_2))
+ else:
+ # Define the operation for standard json
+ with open(fname, mode="w", encoding="utf-8") as f:
+ json.dump(data, f, indent=2)
+ return f.tell()
+
+ elif fname.suffix in [".yaml", ".yml"]:
+ # Define the operation for YAML files
+ with open(fname, mode="w", encoding="utf-8") as f:
+ yaml.dump(data, f)
+ return f.tell()
+ else:
+ raise NotImplementedError(f"Writable format not supported: {fname.suffix}")
+
+
+def get_processing_state(scene_root: Path | str) -> dict:
+ """
+ Retrieves the processing state of a scene.
+
+ Args:
+ scene_root (Path or str): The root directory of the scene.
+
+ Returns:
+ dict: A dictionary containing the processing state of the scene.
+ If no processing log exists, or reading it fails, an empty
+ dictionary is returned.
+ """
+ process_log_path = Path(scene_root) / "_process_log.json"
+
+ try:
+ return _load_readable_structured(process_log_path)
+ except FileNotFoundError:
+ logger.debug(f"Log file not found, returning empty dict: {process_log_path}")
+ return {}
+ except Exception:
+ logger.error(
+ f"Could not parse, returning empty dict: {process_log_path}", exc_info=True
+ )
+ return {}
+
+
+def _write_exr(
+ fname: str | Path,
+ data: np.ndarray | torch.Tensor,
+ params: list | None = None,
+ **kwargs,
+) -> bool:
+ """
+ Writes an image as an EXR file using OpenCV.
+
+ Args:
+ fname (str or Path): The filename to save the image to.
+ data (numpy.ndarray, torch.Tensor): The image data to save. Must be a 2D or 3D array.
+ params (list, optional): A list of parameters to pass to OpenCV's imwrite function.
+ Defaults to None, which uses 32-bit with zip compression.
+
+ Returns:
+ bool: True if the image was saved successfully, False otherwise.
+
+ Raises:
+ ValueError: If the input data has less than two or more than three dimensions.
+
+ Notes:
+ Only 32-bit float (CV_32F) images can be saved.
+ For comparison of different compression methods, see P1732924327.
+ """
+ if Path(fname).suffix != ".exr":
+ raise ValueError(
+ f"Only filenames with suffix .exr allowed but received: {fname}"
+ )
+
+ ## Note: only 32-bit float (CV_32F) images can be saved
+ data_np = to_numpy(data, dtype=np.float32)
+ if (data_np.ndim > 3) or (data_np.ndim < 2):
+ raise ValueError(
+ f"Image needs to contain two or three dims but received: {data_np.shape}"
+ )
+
+ return cv2.imwrite(str(fname), data_np, params if params else [])
+
+
+@overload
+def _read_exr(fname: str | Path, fmt: Literal["np"], **kwargs) -> np.ndarray: ...
+@overload
+def _read_exr(fname: str | Path, fmt: Literal["PIL"], **kwargs) -> Image.Image: ...
+@overload
+def _read_exr(
+ fname: str | Path, fmt: Literal["torch"] = "torch", **kwargs
+) -> torch.Tensor: ...
+
+
+def _read_exr(
+ fname: str | Path, fmt: Literal["np", "PIL", "torch"] = "torch", **kwargs
+) -> np.ndarray | torch.Tensor | Image.Image:
+ """
+ Reads an EXR image file using OpenCV.
+
+ Args:
+ fname (str or Path): The filename of the EXR image to read.
+ fmt (str): The format of the output data. Can be one of:
+ - "torch": Returns a PyTorch tensor.
+ - "np": Returns a NumPy array.
+ - "PIL": Returns a PIL Image object.
+ Defaults to "torch".
+
+ Returns:
+ The EXR image data in the specified output format.
+
+ Raises:
+ NotImplementedError: If the specified output format is not supported.
+ ValueError: If data shape is not supported, e.g. multi-channel PIL float images.
+
+ Notes:
+ The EXR image is read in its original format, without any conversion or rescaling.
+ """
+ data = cv2.imread(str(fname), cv2.IMREAD_UNCHANGED)
+ if data is None:
+ raise FileNotFoundError(f"Failed to read EXR file: {fname}")
+ if fmt == "torch":
+ # Convert to PyTorch tensor with float32 dtype
+ data = torch.from_numpy(data).float()
+ elif fmt == "np":
+ # Convert to NumPy array with float32 dtype
+ data = np.array(data, dtype=np.float32)
+ elif fmt == "PIL":
+ if data.ndim != 2:
+ raise ValueError("PIL does not support multi-channel EXR images")
+
+ # Convert to PIL Image object
+ data = Image.fromarray(data)
+ else:
+ raise NotImplementedError(f"fmt not supported: {fmt}")
+ return data
+
+
+@overload
+def _load_image(
+ fname: str | Path,
+ fmt: Literal["np"],
+ resize: tuple[int, int] | None = None,
+ **kwargs,
+) -> np.ndarray: ...
+@overload
+def _load_image(
+ fname: str | Path,
+ fmt: Literal["pil"],
+ resize: tuple[int, int] | None = None,
+ **kwargs,
+) -> Image.Image: ...
+@overload
+def _load_image(
+ fname: str | Path,
+ fmt: Literal["torch"] = "torch",
+ resize: tuple[int, int] | None = None,
+ **kwargs,
+) -> torch.Tensor: ...
+
+
+def _load_image(
+ fname: str | Path,
+ fmt: Literal["np", "pil", "torch"] = "torch",
+ resize: tuple[int, int] | None = None,
+ **kwargs,
+) -> np.ndarray | torch.Tensor | Image.Image:
+ """
+ Loads an image from a file.
+
+ Args:
+ fname (str or Path): The filename to load the image from.
+ fmt (str): The format of the output data. Can be one of:
+ - "torch": Returns a PyTorch tensor with shape (C, H, W).
+ - "np": Returns a NumPy array with shape (H, W, C).
+ - "pil": Returns a PIL Image object.
+ Defaults to "torch".
+ resize (tuple, optional): A tuple of two integers representing the desired width and height of the image.
+ If None, the image is not resized. Defaults to None.
+
+ Returns:
+ The loaded image in the specified output format.
+
+ Raises:
+ NotImplementedError: If the specified output format is not supported.
+
+ Notes:
+ This function loads non-binary images in RGB mode and normalizes pixel values to the range [0, 1].
+ """
+
+ # Fastest way to load into torch tensor
+ if resize is None and fmt == "torch":
+ return decode_image(str(fname)).float() / 255.0
+
+ # Load using PIL
+ with open(fname, "rb") as f:
+ pil_image = Image.open(f)
+ pil_image.load()
+
+ if pil_image.mode not in ["RGB", "RGBA"]:
+ raise OSError(
+ f"Expected a RGB or RGBA image in {fname}, but instead found an image with mode {pil_image.mode}"
+ )
+
+ if resize is not None:
+ pil_image = pil_image.resize(resize)
+
+ if fmt == "torch":
+ return (
+ torch.from_numpy(np.array(pil_image)).permute(2, 0, 1).float() / 255.0
+ )
+ elif fmt == "np":
+ return np.array(pil_image, dtype=np.float32) / 255.0
+ elif fmt == "pil":
+ return pil_image
+ else:
+ raise NotImplementedError(f"Image format not supported: {fmt}")
+
+
+def _store_image(
+ fname: str | Path, img_data: np.ndarray | torch.Tensor | Image.Image, **kwargs
+) -> None:
+ """
+ Stores an image in a file.
+
+ Args:
+ fname (str or Path): The filename to store the image in.
+ img_data (numpy.ndarray, torch.tensor or PIL.Image.Image): The image data to store.
+
+ Notes (for numpy.ndarray or torch.tensor inputs):
+ This function assumes that the input image data is in the range [0, 1], and has shape
+ (H, W, C), or (C, H, W) for PyTorch tensors, with C being 3 or 4.
+ It converts the image data to uint8 format and saves it as a compressed image file.
+ """
+ if isinstance(img_data, torch.Tensor):
+ if img_data.ndim != 3:
+ raise ValueError(f"Tensor needs to be 3D but received: {img_data.shape=}")
+
+ if img_data.shape[0] in [3, 4]:
+ # Convert to HWC format expected by pillow `Image.save` below
+ img_data = img_data.permute(1, 2, 0)
+
+ img_data = img_data.contiguous()
+
+ if isinstance(img_data, (np.ndarray, torch.Tensor)):
+ if img_data.shape[-1] not in [3, 4]:
+ raise ValueError(
+ f"Image must have 3 or 4 channels, but received: {img_data.shape=}"
+ )
+
+ img_data_np = to_numpy(img_data, dtype=np.float32)
+ img_data = Image.fromarray((255 * img_data_np).round().astype(np.uint8))
+
+ with open(fname, "wb") as f:
+ pil_kwargs = {
+ # Make PNGs faster to save using minimal compression
+ "optimize": False,
+ "compress_level": 1,
+ # Higher JPEG image quality
+ "quality": "high",
+ }
+ pil_kwargs.update(kwargs)
+ img_data.save(cast(IO[bytes], f), **pil_kwargs)
+
+
+def _load_binary_mask(
+ fname: str | Path,
+ fmt: str = "torch",
+ resize: tuple[int, int] | None = None,
+ **kwargs,
+) -> np.ndarray | torch.Tensor | Image.Image:
+ """
+ Loads a binary image from a file.
+
+ Args:
+ fname (str or Path): The filename to load the binary image from.
+ fmt (str): The format of the output data. Can be one of:
+ - "torch": Returns a PyTorch Boolean tensor with shape H x W.
+ - "np": Returns a NumPy Boolean array with shape H x W.
+ - "pil": Returns a PIL Image object.
+ Defaults to "torch".
+ resize (tuple, optional): A tuple of two integers representing the desired width and height of the binary image.
+ If None, the image is not resized. Defaults to None.
+
+ Returns:
+ The loaded binary image in the specified output format.
+
+ Raises:
+ NotImplementedError: If the specified output format is not supported.
+ """
+ if fmt not in ["pil", "np", "torch"]:
+ raise NotImplementedError(f"Image format not supported: {fmt}")
+
+ with open(fname, "rb") as f:
+ pil_image = Image.open(f)
+ pil_image.load()
+
+ if pil_image.mode == "L":
+ pil_image = pil_image.convert("1")
+
+ elif pil_image.mode != "1":
+ raise OSError(
+ f"Expected a binary or grayscale image in {fname}, but instead found an image with mode {pil_image.mode}"
+ )
+
+ if resize is not None:
+ pil_image = pil_image.resize(resize)
+
+ if fmt == "pil":
+ return pil_image
+
+ mask = np.array(pil_image, copy=True)
+ return mask if fmt == "np" else torch.from_numpy(mask)
+
+
+def _store_binary_mask(
+ fname: str | Path, img_data: np.ndarray | torch.Tensor | Image.Image, **kwargs
+) -> None:
+ """
+ Stores a binary image in a compressed image file.
+
+ Args:
+ fname (str or Path): The filename to store the binary image in.
+ img_data (numpy.ndarray, torch.tensor or PIL.Image.Image): The binary image data to store.
+ """
+ if isinstance(img_data, Image.Image):
+ if img_data.mode not in ["1", "L"]:
+ raise RuntimeError(
+ f'Expected a PIL image with mode "1" or "L", but instead got a PIL image with mode {img_data.mode}'
+ )
+ elif isinstance(img_data, np.ndarray) or isinstance(img_data, torch.Tensor):
+ if len(img_data.squeeze().shape) != 2:
+ raise RuntimeError(
+ f"Expected a PyTorch tensor or NumPy array with shape (H, W, 1), (1, H, W) or (H, W), but the shape is {img_data.shape}"
+ )
+ img_data = img_data.squeeze()
+ else:
+ raise NotImplementedError(f"Input format not supported: {type(img_data)}")
+
+ if not isinstance(img_data, Image.Image):
+ img_data = to_numpy(img_data, dtype=bool)
+ img_data = Image.fromarray(img_data)
+
+ img_data = img_data.convert("1")
+ with open(fname, "wb") as f:
+ img_data.save(f, compress_level=1, optimize=False)
+
+
+def _load_sft(
+ fname: str | Path,
+ fmt: str = "torch",
+ **kwargs,
+) -> torch.Tensor:
+ """
+ Loads a tensor from a safetensor file.
+
+ Args:
+ fname (str | Path): The filename of the safetensor file to load.
+ fmt (str, optional): The format of the output data. Currently only "torch" is supported.
+ **kwargs: Additional keyword arguments (unused).
+
+ Returns:
+ torch.Tensor: The loaded tensor.
+
+ Raises:
+ AssertionError: If the file extension is not .sft or if fmt is not "torch".
+ """
+ assert Path(fname).suffix == ".sft", "Only .sft (safetensor) is supported"
+ assert fmt == "torch", "Only torch format is supported for latent"
+ out = load_sft(str(fname))
+ return out["latent"]
+
+
+def _store_sft(fname: str | Path, data: torch.Tensor, **kwargs) -> None:
+ """
+ Stores a tensor to a safetensor file.
+
+ Args:
+ fname (str | Path): The filename to store the latent in.
+ data (torch.Tensor): The latent tensor to store.
+ **kwargs: Additional keyword arguments (unused).
+
+ Raises:
+ AssertionError: If the file extension is not .sft or if data is not a torch.Tensor.
+ """
+ assert Path(fname).suffix == ".sft", "Only .sft (safetensor) is supported"
+ assert isinstance(data, torch.Tensor)
+ save_sft(tensors={"latent": data}, filename=str(fname))
+
+
+def _store_depth(fname: str | Path, data: np.ndarray | torch.Tensor, **kwargs) -> bool:
+ """
+ Stores a depth map in an EXR file.
+
+ Args:
+ fname (str or Path): The filename to save the depth map to.
+ data (numpy.ndarray, torch.tensor): The depth map to save.
+
+ Returns:
+ bool: True if the depth map was saved successfully, False otherwise.
+
+ Raises:
+ ValueError: If the input data does not have two dimensions after removing singleton dimensions.
+ """
+ data_np = to_numpy(data, dtype=np.float32)
+ data_np = data_np.squeeze() # remove all 1-dim entries
+ if data_np.ndim != 2:
+ raise ValueError(f"Depth image needs to be 2d, but received: {data_np.shape}")
+
+ if "params" in kwargs:
+ params = kwargs["params"]
+ else:
+ # use 16-bit with zip compression for depth maps
+ params = [
+ cv2.IMWRITE_EXR_TYPE,
+ cv2.IMWRITE_EXR_TYPE_HALF,
+ cv2.IMWRITE_EXR_COMPRESSION,
+ cv2.IMWRITE_EXR_COMPRESSION_ZIP,
+ ]
+
+ return _write_exr(fname, data_np, params=params)
+
+
+def _load_depth(
+ fname: str | Path, fmt: str = "torch", **kwargs
+) -> np.ndarray | torch.Tensor | Image.Image:
+ """
+ Loads a depth image from an EXR file.
+
+ Args:
+ fname (str or Path): The filename of the EXR file to load.
+ fmt (str): The format of the output data. Can be one of:
+ - "torch": Returns a PyTorch tensor.
+ - "np": Returns a NumPy array.
+ - "PIL": Returns a PIL Image object.
+ Defaults to "torch".
+
+ Returns:
+ The loaded depth image in the specified output format.
+
+ Raises:
+ ValueError: If the loaded depth image does not have two dimensions.
+
+ Notes:
+ This function assumes that the EXR file contains a single-channel depth image.
+ """
+ data = _read_exr(fname, fmt)
+ if (fmt != "PIL") and (data.ndim != 2):
+ raise ValueError(f"Depth image needs to be 2D, but loaded: {data.shape}")
+ return data
+
+
+def _store_normals(
+ fname: str | Path, data: np.ndarray | torch.Tensor, **kwargs
+) -> bool:
+ """
+ Stores a normals image in an EXR file.
+
+ Args:
+ fname (str or Path): The filename to save the normals image to.
+ data (numpy.ndarray): The normals image data to save. Will be converted to a 32-bit float array.
+
+ Returns:
+ bool: True if the normals image was saved successfully, False otherwise.
+
+ Raises:
+ ValueError: If the input data has more than three dimensions after removing singleton dimensions.
+ ValueError: If the input data does not have exactly three channels.
+ ValueError: If the input data is not normalized (i.e., maximum absolute value exceeds 1).
+
+ Notes:
+ This function assumes that the input data is in HWC (height, width, channels) format.
+ If the input data is in CHW (channels, height, width) format, it will be automatically transposed to HWC.
+ """
+ data_np = to_numpy(data, dtype=np.float32)
+ data_np = data_np.squeeze() # remove all singleton dimensions
+
+ if data_np.ndim != 3:
+ raise ValueError(
+ f"Normals image needs to be 3-dim but received: {data_np.shape}"
+ )
+
+ if (data_np.shape[0] == 3) and (data_np.shape[2] != 3):
+ # ensure HWC format
+ data_np = data_np.transpose(1, 2, 0)
+
+ if data_np.shape[2] != 3:
+ raise ValueError(
+ f"Normals image needs have 3 channels but received: {data_np.shape}"
+ )
+
+ # We want to check that the norm values are either 1 (valid) or 0 (invalid values are 0s)
+ norm = np.linalg.norm(data_np, axis=-1)
+ is_one = np.isclose(norm, 1.0, atol=1e-3)
+ is_zero = np.isclose(norm, 0.0)
+ if not np.all([is_one | is_zero]):
+ raise ValueError("Normals image must be normalized")
+
+ return _write_exr(fname, data_np)
+
+
+def _load_normals(
+ fname: str | Path, fmt: str = "torch", **kwargs
+) -> np.ndarray | torch.Tensor | Image.Image:
+ """
+ Loads a normals image from an EXR file.
+
+ Args:
+ fname (str or Path): The filename of the EXR file to load.
+ fmt (str): The format of the output data. Can be one of:
+ - "torch": Returns a PyTorch tensor.
+ - "np": Returns a NumPy array.
+ - "PIL": Returns a PIL Image object.
+ Defaults to "torch".
+
+ Returns:
+ The loaded normals image in the specified output format.
+
+ Raises:
+ Warning: If the loaded normals image has more than two dimensions.
+
+ Notes:
+ This function assumes that the EXR file contains a 3-channel normals image.
+ """
+ data = _read_exr(fname, fmt)
+
+ if data.ndim != 3:
+ raise ValueError(f"Normals image needs to be 3-dim but received: {data.shape}")
+
+ if data.shape[2] != 3:
+ raise ValueError(
+ f"Normals image needs have 3 channels but received: {data.shape}"
+ )
+
+ return data
+
+
+def _load_numpy(fname: str | Path, allow_pickle: bool = False, **kwargs) -> np.ndarray:
+ """
+ Loads a NumPy array from a file.
+
+ Args:
+ fname (str or Path): The filename to load the NumPy array from.
+ allow_pickle (bool, optional): Whether to allow pickled objects in the NumPy file.
+ Defaults to False.
+
+ Returns:
+ numpy.ndarray: The loaded NumPy array.
+
+ Raises:
+ NotImplementedError: If the file suffix is not supported (i.e., not .npy or .npz).
+
+ Notes:
+ This function supports loading NumPy arrays from .npy and .npz files.
+ For .npz files, it assumes that the array is stored under the key "arr_0".
+ """
+ fname = Path(fname)
+ with open(fname, "rb") as fid:
+ if fname.suffix == ".npy":
+ return np.load(fid, allow_pickle=allow_pickle)
+ elif fname.suffix == ".npz":
+ return np.load(fid, allow_pickle=allow_pickle).get("arr_0")
+ else:
+ raise NotImplementedError(f"Numpy format not supported: {fname.suffix}")
+
+
+def _store_numpy(fname: str | Path, data: np.ndarray, **kwargs) -> None:
+ """
+ Stores a NumPy array in a file.
+
+ Args:
+ fname (str or Path): The filename to store the NumPy array in.
+ data (numpy.ndarray): The NumPy array to store.
+
+ Raises:
+ NotImplementedError: If the file suffix is not supported (i.e., not .npy or .npz).
+
+ Notes:
+ This function supports storing NumPy arrays in .npy and .npz files.
+ For .npz files, it uses compression to reduce the file size.
+ """
+ fname = Path(fname)
+ with open(fname, "wb") as fid:
+ if fname.suffix == ".npy":
+ np.save(fid, data)
+ elif fname.suffix == ".npz":
+ np.savez_compressed(fid, arr_0=data)
+ else:
+ raise NotImplementedError(f"Numpy format not supported: {fname.suffix}")
+
+
+def _load_ptz(fname: str | Path, **kwargs) -> torch.Tensor:
+ """
+ Loads a PyTorch tensor from a PTZ file.
+
+ Args:
+ fname (str or Path): The filename to load the tensor from.
+
+ Returns:
+ torch.Tensor: The loaded PyTorch tensor.
+
+ Notes:
+ This function assumes that the PTZ file contains a PyTorch tensor saved using `torch.save`.
+ If the tensor was saved in a different format, this function may fail.
+ """
+ with open(fname, "rb") as fid:
+ data = gzip.decompress(fid.read())
+ ## Note: if the following line fails, save PyTorch tensors in PTZ instead of NumPy
+ return torch.load(io.BytesIO(data), map_location="cpu", weights_only=True)
+
+
+def _store_ptz(fname: str | Path, data: torch.Tensor, **kwargs) -> None:
+ """
+ Stores a PyTorch tensor in a PTZ file.
+
+ Args:
+ fname (str or Path): The filename to store the tensor in.
+ data (torch.Tensor): The PyTorch tensor to store.
+
+ Notes:
+ This function saves the tensor using `torch.save` and compresses it using gzip.
+ """
+ with open(fname, "wb") as fid:
+ with gzip.open(fid, "wb") as gfid:
+ torch.save(data, gfid)
+
+
+def _store_mmap(fname: str | Path, data: np.ndarray | torch.Tensor, **kwargs) -> str:
+ """
+ Stores matrix-shaped data in a memory-mapped file.
+
+ Args:
+ fname (str or Path): The filename to store the data in.
+ data (numpy.ndarray): The matrix-shaped data to store.
+
+ Returns:
+ str: The name of the stored memory-mapped file.
+
+ Notes:
+ This function stores the data in a .npy file with a modified filename that includes the shape of the data.
+ The data is converted to float32 format before storing.
+ """
+ fname = Path(fname)
+ # add dimensions to the file name for loading
+ data_np = to_numpy(data, dtype=np.float32)
+ shape_string = "x".join([str(dim) for dim in data_np.shape])
+ mmap_name = f"{fname.stem}--{shape_string}.npy"
+ with open(fname.parent / mmap_name, "wb") as fid:
+ np.save(fid, data_np)
+ return mmap_name
+
+
+def _load_mmap(fname: str | Path, **kwargs) -> np.memmap:
+ """
+ Loads matrix-shaped data from a memory-mapped file.
+
+ Args:
+ fname (str or Path): The filename of the memory-mapped file to load.
+
+ Returns:
+ numpy.memmap: A memory-mapped array containing the loaded data.
+
+ Notes:
+ This function assumes that the filename contains the shape of the data, separated by 'x' or ','.
+ It uses this information to create a memory-mapped array with the correct shape.
+ """
+ shape_string = Path(Path(fname).name.split("--")[1]).stem
+ shape = [int(dim) for dim in shape_string.replace(",", "x").split("x")]
+ with open(fname, "rb") as fid:
+ return np.memmap(fid, dtype=np.float32, mode="r", shape=shape, offset=128)
+
+
+def _store_scene_meta(fname: Path | str, scene_meta: dict[str, Any], **kwargs) -> None:
+ """
+ Stores scene metadata in a readable file.
+
+ Args:
+ fname (str or Path): The filename to store the scene metadata in.
+ scene_meta (dict): The scene metadata to store.
+
+ Notes:
+ This function updates the "last_modified" field of the scene metadata to the current date and time before storing it.
+ It also removes the "frame_names" field from the scene metadata, as it is not necessary to store this information.
+ Creates a backup of the existing file before overwriting it.
+ """
+ # update the modified date
+ scene_meta["last_modified"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+ if "frame_names" in scene_meta:
+ del scene_meta["frame_names"]
+
+ # create/overwrite backup
+ fname_path = Path(fname)
+ if fname_path.exists():
+ backup_fname = fname_path.parent / f"_{fname_path.stem}_backup.json"
+ if backup_fname.exists():
+ backup_fname.unlink()
+ fname_path.rename(backup_fname)
+
+ _store_readable(fname, scene_meta)
+
+
+def _load_scene_meta(fname: Path | str, **kwargs) -> dict[str, Any]:
+ """
+ Loads scene metadata from a readable file.
+
+ Args:
+ fname (str or Path): The filename to load the scene metadata from.
+
+ Returns:
+ dict: The loaded scene metadata, including an additional "frame_names" field that maps frame names to their indices.
+
+ Notes:
+ This function creates the "frame_names" field in the scene metadata for efficient lookup of frame indices by name.
+ """
+ scene_meta = _load_readable_structured(fname)
+ # create the frame_name -> frame_idx for efficiency
+ scene_meta["frame_names"] = {
+ frame["frame_name"]: frame_idx
+ for frame_idx, frame in enumerate(scene_meta["frames"])
+ }
+ return scene_meta
+
+
+def _load_labeled_image(
+ fname: str | Path,
+ fmt: str = "torch",
+ resize: tuple[int, int] | None = None,
+ **kwargs,
+) -> np.ndarray | torch.Tensor | Image.Image:
+ """
+ Loads a labeled image from a PNG file.
+
+ Args:
+ fname (str or Path): The filename to load the image from.
+ fmt (str): The format of the output data. Can be one of:
+ - "torch": Returns a PyTorch int32 tensor with shape (H, W).
+ - "np": Returns a NumPy int32 array with shape (H, W).
+ - "pil": Returns a PIL Image object.
+ Defaults to "torch".
+ resize (tuple, optional): A tuple of two integers representing the desired width and height of the image.
+ If None, the image is not resized. Defaults to None.
+
+ Returns:
+ The loaded image in the specified output format.
+
+ Raises:
+ NotImplementedError: If the specified output format is not supported.
+ RuntimeError: If the 'id_to_color_mapping' is missing in the PNG metadata.
+
+ Notes:
+ The function expects the PNG file to contain metadata with a key 'id_to_color_mapping',
+ which maps from label ids to tuples of RGB values.
+ """
+ with open(fname, "rb") as f:
+ pil_image = Image.open(f)
+ pil_image.load()
+ if pil_image.mode != "RGB":
+ raise OSError(
+ f"Expected a RGB image in {fname}, but instead found an image with mode {pil_image.mode}"
+ )
+
+ # Load id to RGB mapping
+ color_palette_json = pil_image.info.get("id_to_color_mapping", None)
+ if color_palette_json is None:
+ raise RuntimeError("'id_to_color_mapping' is missing in the PNG metadata.")
+ color_palette = json.loads(color_palette_json)
+ color_to_id_mapping = {
+ tuple(color): int(id) for id, color in color_palette.items()
+ }
+
+ if resize is not None:
+ pil_image = pil_image.resize(resize, Image.NEAREST)
+
+ if fmt == "pil":
+ return pil_image
+
+ # Reverse the color mapping: map from RGB colors to ids
+ img_data = np.array(pil_image)
+
+ # Create a lookup table for fast mapping
+ max_color_value = 256 # Assuming 8-bit per channel
+ lookup_table = np.full(
+ (max_color_value, max_color_value, max_color_value),
+ INVALID_ID,
+ dtype=np.int32,
+ )
+ for color, index in color_to_id_mapping.items():
+ lookup_table[color] = index
+ # Map colors to ids using the lookup table
+ img_data = lookup_table[img_data[..., 0], img_data[..., 1], img_data[..., 2]]
+
+ if fmt == "np":
+ return img_data
+ elif fmt == "torch":
+ return torch.from_numpy(img_data)
+ else:
+ raise NotImplementedError(f"Image format not supported: {fmt}")
+
+
+def _store_labeled_image(
+ fname: str | Path,
+ img_data: np.ndarray | torch.Tensor | Image.Image,
+ semantic_color_mapping: np.ndarray | None = None,
+ **kwargs,
+) -> None:
+ """
+ Stores a labeled image as a uint8 RGB PNG file.
+
+ Args:
+ fname (str or Path): The filename to store the image in.
+ img_data (numpy.ndarray, torch.Tensor or PIL.Image.Image): The per-pixel label ids to store.
+ semantic_color_mapping (np.ndarray): Optional, preloaded NumPy array of semantic colors.
+
+ Raises:
+ ValueError: If the file suffix is not supported (i.e., not .png).
+ RuntimeError: If the type of the image data is different from uint16, int16 or int32.
+
+ Notes:
+ The function takes an image with per-pixel label ids and converts it into an RGB image
+ using a specified mapping from label ids to RGB colors. The resulting image is saved as
+ a PNG file, with the mapping stored as metadata.
+ """
+ if Path(fname).suffix != ".png":
+ raise ValueError(
+ f"Only filenames with suffix .png allowed but received: {fname}"
+ )
+
+ if isinstance(img_data, Image.Image) and img_data.mode != "I;16":
+ raise RuntimeError(
+ f"The provided image does not seem to be a labeled image. The provided PIL image has mode {img_data.mode}."
+ )
+
+ if isinstance(img_data, np.ndarray) and img_data.dtype not in [
+ np.uint16,
+ np.int16,
+ np.int32,
+ ]:
+ raise RuntimeError(
+ f"The provided NumPy array has type {img_data.dtype} but the expected type is np.uint16, np.int16 or np.int32."
+ )
+
+ if isinstance(img_data, torch.Tensor):
+ if img_data.dtype not in [torch.uint16, torch.int16, torch.int32]:
+ raise RuntimeError(
+ f"The provided PyTorch tensor has type {img_data.dtype} but the expected type is torch.uint16, torch.int16 or torch.int32."
+ )
+ img_data = img_data.numpy()
+
+ if semantic_color_mapping is None:
+ # Mapping from ids to colors not provided, load it now
+ semantic_color_mapping = load_semantic_color_mapping()
+
+ img_data, color_palette = apply_id_to_color_mapping(
+ img_data, semantic_color_mapping
+ )
+ pil_image = Image.fromarray(img_data, "RGB")
+
+ # Create a PngInfo object to store metadata
+ meta = PngImagePlugin.PngInfo()
+ meta.add_text("id_to_color_mapping", json.dumps(color_palette))
+
+ pil_image.save(fname, pnginfo=meta)
+
+
+def _load_generic_mesh(mesh_path: str | Path, **kwargs) -> trimesh.Trimesh:
+ """Load mesh with the trimesh library.
+
+ Args:
+ mesh_path (str): Path to the mesh file
+
+ Returns:
+ The trimesh object from trimesh.load().
+
+ Raises:
+ ValueError: If the file format is not supported.
+ """
+
+ # needed to load big texture files
+ Image.MAX_IMAGE_PIXELS = None
+
+ # load mesh with trimesh
+ mesh_data = trimesh.load(mesh_path, process=False)
+
+ return mesh_data
+
+
+def _store_generic_mesh(
+ file_path: str | Path, mesh_data: dict | trimesh.Trimesh, **kwargs
+) -> None:
+ """
+ Dummy function for storing generic mesh data.
+
+ Args:
+ file_path (str): The filename to store the mesh in.
+ mesh_data (dict): Dictionary containing mesh data.
+ **kwargs: Additional keyword arguments.
+
+ Raises:
+ NotImplementedError: This function is not implemented yet.
+ """
+ raise NotImplementedError("Storing generic meshes is not implemented yet.")
+
+
+def _load_labeled_mesh(
+ file_path: str | Path,
+ fmt: str = "torch",
+ palette: str = "rgb",
+ **kwargs,
+) -> dict | trimesh.Trimesh:
+ """
+ Loads a mesh from a labeled mesh file (PLY binary format).
+
+ Args:
+ file_path (str): The path to the labeled mesh file (.ply).
+ fmt (str): Output format of the mesh data. Can be one of:
+ - "torch": Returns a dict of PyTorch tensors containing mesh data.
+ - "np": Returns a dict of NumPy arrays containing mesh data.
+ - "trimesh": Returns a trimesh mesh object.
+ Defaults to "torch".
+ palette (str): Output color of the trimesh mesh data. Can be one of:
+ - "rgb": Colors the mesh with original rgb colors
+ - "semantic_class": Colors the mesh with semantic class colors
+ - "instance": Colors the mesh with semantic instance colors
+ Applied only when fmt is "trimesh".
+
+ Returns:
+ The loaded mesh in the specified output format.
+
+ Raises:
+ NotImplementedError: If the specified output format is not supported.
+
+ Notes:
+ This function reads a binary PLY file with vertex position, color, and optional
+ semantic class and instance IDs. The faces are stored as lists of vertex indices.
+ """
+ # load data (NOTE: define known_list_len to enable faster read)
+ ply_data = PlyData.read(file_path, known_list_len={"face": {"vertex_indices": 3}})
+
+ # get vertices
+ vertex_data = ply_data["vertex"].data
+ vertices = np.column_stack(
+ (vertex_data["x"], vertex_data["y"], vertex_data["z"])
+ ).astype(np.float32)
+
+ # initialize output data
+ mesh_data = {}
+ mesh_data["is_labeled_mesh"] = True
+ mesh_data["vertices"] = vertices
+
+ # get faces if available
+ if "face" in ply_data:
+ faces = np.asarray(ply_data["face"].data["vertex_indices"]).astype(np.int32)
+ mesh_data["faces"] = faces
+
+ # get rgb colors if available
+ if all(color in vertex_data.dtype.names for color in ["red", "green", "blue"]):
+ vertices_color = np.column_stack(
+ (vertex_data["red"], vertex_data["green"], vertex_data["blue"])
+ ).astype(np.uint8)
+ mesh_data["vertices_color"] = vertices_color
+
+ # get vertices class and instance if available
+ if "semantic_class_id" in vertex_data.dtype.names:
+ vertices_class = vertex_data["semantic_class_id"].astype(np.int32)
+ mesh_data["vertices_semantic_class_id"] = vertices_class
+
+ if "instance_id" in vertex_data.dtype.names:
+ vertices_instance = vertex_data["instance_id"].astype(np.int32)
+ mesh_data["vertices_instance_id"] = vertices_instance
+
+ # get class colors if available
+ if all(
+ color in vertex_data.dtype.names
+ for color in [
+ "semantic_class_red",
+ "semantic_class_green",
+ "semantic_class_blue",
+ ]
+ ):
+ vertices_semantic_class_color = np.column_stack(
+ (
+ vertex_data["semantic_class_red"],
+ vertex_data["semantic_class_green"],
+ vertex_data["semantic_class_blue"],
+ )
+ ).astype(np.uint8)
+ mesh_data["vertices_semantic_class_color"] = vertices_semantic_class_color
+
+ # get instance colors if available
+ if all(
+ color in vertex_data.dtype.names
+ for color in ["instance_red", "instance_green", "instance_blue"]
+ ):
+ vertices_instance_color = np.column_stack(
+ (
+ vertex_data["instance_red"],
+ vertex_data["instance_green"],
+ vertex_data["instance_blue"],
+ )
+ ).astype(np.uint8)
+ mesh_data["vertices_instance_color"] = vertices_instance_color
+
+ # convert data into output format (if needed)
+ if fmt == "np":
+ return mesh_data
+ elif fmt == "torch":
+ return {k: torch.tensor(v) for k, v in mesh_data.items()}
+ elif fmt == "trimesh":
+ trimesh_mesh = trimesh.Trimesh(
+ vertices=mesh_data["vertices"], faces=mesh_data["faces"]
+ )
+ # color the mesh according to the palette
+ if palette == "rgb":
+ # original rgb colors
+ if "vertices_color" in mesh_data:
+ trimesh_mesh.visual.vertex_colors = mesh_data["vertices_color"]
+ else:
+ raise ValueError(
+ f"Palette {palette} could not be applied. Missing vertices_color in mesh data."
+ )
+ elif palette == "semantic_class":
+ # semantic class colors
+ if "vertices_semantic_class_color" in mesh_data:
+ trimesh_mesh.visual.vertex_colors = mesh_data[
+ "vertices_semantic_class_color"
+ ]
+ else:
+ raise ValueError(
+ f"Palette {palette} could not be applied. Missing vertices_semantic_class_color in mesh data."
+ )
+ elif palette == "instance":
+ # semantic instance colors
+ if "vertices_instance_color" in mesh_data:
+ trimesh_mesh.visual.vertex_colors = mesh_data["vertices_instance_color"]
+ else:
+ raise ValueError(
+ f"Palette {palette} could not be applied. Missing vertices_instance_color in mesh data."
+ )
+ else:
+ raise ValueError(f"Invalid palette: {palette}.")
+ return trimesh_mesh
+ else:
+ raise NotImplementedError(f"Labeled mesh format not supported: {fmt}")
+
+
+def _store_labeled_mesh(file_path: str | Path, mesh_data: dict, **kwargs) -> None:
+ """
+ Stores a mesh in WAI format (PLY binary format).
+
+ Args:
+ file_path (str): The filename to store the mesh in.
+ mesh_data (dict): Dictionary containing mesh data with keys:
+ - 'vertices' (numpy.ndarray): Array of vertex coordinates with shape (N, 3).
+ - 'faces' (numpy.ndarray, optional): Array of face indices.
+ - 'vertices_color' (numpy.ndarray, optional): Array of vertex colors with shape (N, 3).
+ - 'vertices_semantic_class_id' (numpy.ndarray, optional): Array of semantic classes for each vertex with shape (N).
+ - 'vertices_instance_id' (numpy.ndarray, optional): Array of instance IDs for each vertex with shape (N).
+ - 'vertices_semantic_class_color' (numpy.ndarray, optional): Array of vertex semantic class colors with shape (N, 3).
+ - 'vertices_instance_color' (numpy.ndarray, optional): Array of vertex instance colors with shape (N, 3).
+
+ Notes:
+ This function writes a binary PLY file with vertex position, color, and optional
+ semantic class and instance IDs. The faces are stored as lists of vertex indices.
+ """
+ # Validate input data
+ if "vertices" not in mesh_data:
+ raise ValueError("Mesh data must contain 'vertices'")
+
+ # create vertex data with properties
+ vertex_dtype = [("x", "f4"), ("y", "f4"), ("z", "f4")]
+ if "vertices_color" in mesh_data:
+ vertex_dtype.extend([("red", "u1"), ("green", "u1"), ("blue", "u1")])
+ if "vertices_semantic_class_id" in mesh_data:
+ vertex_dtype.append(("semantic_class_id", "i4"))
+ if "vertices_instance_id" in mesh_data:
+ vertex_dtype.append(("instance_id", "i4"))
+ if "vertices_semantic_class_color" in mesh_data:
+ vertex_dtype.extend(
+ [
+ ("semantic_class_red", "u1"),
+ ("semantic_class_green", "u1"),
+ ("semantic_class_blue", "u1"),
+ ]
+ )
+ if "vertices_instance_color" in mesh_data:
+ vertex_dtype.extend(
+ [("instance_red", "u1"), ("instance_green", "u1"), ("instance_blue", "u1")]
+ )
+ vertex_count = len(mesh_data["vertices"])
+ vertex_data = np.zeros(vertex_count, dtype=vertex_dtype)
+
+ # vertex positions
+ vertex_data["x"] = mesh_data["vertices"][:, 0]
+ vertex_data["y"] = mesh_data["vertices"][:, 1]
+ vertex_data["z"] = mesh_data["vertices"][:, 2]
+
+ # vertex colors
+ if "vertices_color" in mesh_data:
+ vertex_data["red"] = mesh_data["vertices_color"][:, 0]
+ vertex_data["green"] = mesh_data["vertices_color"][:, 1]
+ vertex_data["blue"] = mesh_data["vertices_color"][:, 2]
+
+ # vertex class
+ if "vertices_semantic_class_id" in mesh_data:
+ vertex_data["semantic_class_id"] = mesh_data["vertices_semantic_class_id"]
+
+ # vertex instance
+ if "vertices_instance_id" in mesh_data:
+ vertex_data["instance_id"] = mesh_data["vertices_instance_id"]
+
+ # vertex class colors
+ if "vertices_semantic_class_color" in mesh_data:
+ vertex_data["semantic_class_red"] = mesh_data["vertices_semantic_class_color"][
+ :, 0
+ ]
+ vertex_data["semantic_class_green"] = mesh_data[
+ "vertices_semantic_class_color"
+ ][:, 1]
+ vertex_data["semantic_class_blue"] = mesh_data["vertices_semantic_class_color"][
+ :, 2
+ ]
+
+ # vertex instance colors
+ if "vertices_instance_color" in mesh_data:
+ vertex_data["instance_red"] = mesh_data["vertices_instance_color"][:, 0]
+ vertex_data["instance_green"] = mesh_data["vertices_instance_color"][:, 1]
+ vertex_data["instance_blue"] = mesh_data["vertices_instance_color"][:, 2]
+
+ # initialize data to save
+ vertex_element = PlyElement.describe(vertex_data, "vertex")
+ data_to_save = [vertex_element]
+
+ # faces data
+ if "faces" in mesh_data:
+ face_dtype = [("vertex_indices", "i4", (3,))]
+ face_data = np.zeros(len(mesh_data["faces"]), dtype=face_dtype)
+ face_data["vertex_indices"] = mesh_data["faces"]
+ face_element = PlyElement.describe(face_data, "face")
+ data_to_save.append(face_element)
+
+ # Create and write a binary PLY file
+ ply_data = PlyData(data_to_save, text=False)
+ ply_data.write(file_path)
+
+
+def _get_method(
+ fname: Path | str, format_type: str | None = None, load: bool = True
+) -> Callable:
+ """
+ Returns a method for loading or storing data in a specific format.
+
+ Args:
+ fname (str or Path): The filename to load or store data from/to.
+ format_type (str, optional): The format of the data. If None, it will be inferred from the file extension.
+ Defaults to None.
+ load (bool, optional): Whether to return a method for loading or storing data.
+ Defaults to True.
+
+ Returns:
+ callable: A method for loading or storing data in the specified format.
+
+ Raises:
+ ValueError: If the format cannot be inferred from the file extension.
+ NotImplementedError: If the specified format is not supported.
+
+ Notes:
+ This function supports various formats, including readable files (JSON, YAML), images, NumPy arrays,
+ PyTorch tensors, memory-mapped files, and scene metadata.
+ """
+ fname = Path(fname)
+ if format_type is None:
+ # use default formats
+ if fname.suffix in [".json", ".yaml", ".yml"]:
+ format_type = "readable"
+ elif fname.suffix in [".jpg", ".jpeg", ".png", ".webp"]:
+ format_type = "image"
+ elif fname.suffix in [".npy", ".npz"]:
+ format_type = "numpy"
+ elif fname.suffix == ".ptz":
+ format_type = "ptz"
+ elif fname.suffix == ".sft":
+ format_type = "sft"
+ elif fname.suffix == ".exr":
+ format_type = "scalar"
+ elif fname.suffix in [".glb", ".obj", ".ply"]:
+ format_type = "mesh"
+ else:
+ raise ValueError(f"Cannot infer format for {fname}")
+ methods = {
+ "readable": (_load_readable, _store_readable),
+ "scalar": (_read_exr, _write_exr),
+ "image": (_load_image, _store_image),
+ "binary": (_load_binary_mask, _store_binary_mask),
+ "latent": (_load_sft, _store_sft),
+ "depth": (_load_depth, _store_depth),
+ "normals": (_load_normals, _store_normals),
+ "numpy": (_load_numpy, _store_numpy),
+ "ptz": (_load_ptz, _store_ptz),
+ "sft": (_load_sft, _store_sft),
+ "mmap": (_load_mmap, _store_mmap),
+ "scene_meta": (_load_scene_meta, _store_scene_meta),
+ "labeled_image": (_load_labeled_image, _store_labeled_image),
+ "mesh": (_load_generic_mesh, _store_generic_mesh),
+ "labeled_mesh": (_load_labeled_mesh, _store_labeled_mesh),
+ }
+ try:
+ return methods[format_type][0 if load else 1]
+ except KeyError as e:
+ raise NotImplementedError(f"Format not supported: {format_type}") from e
diff --git a/mapanything/utils/wai/m_ops.py b/mapanything/utils/wai/m_ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..8aaf8de57122b52b8c87baadfa31353b55400912
--- /dev/null
+++ b/mapanything/utils/wai/m_ops.py
@@ -0,0 +1,346 @@
+import numpy as np
+import torch
+
+
+def m_dot(
+ transform: torch.Tensor,
+ points: torch.Tensor | list,
+ maintain_shape: bool = False,
+) -> torch.Tensor | list:
+ """
+ Apply batch matrix multiplication between transform matrices and points.
+
+ Args:
+ transform: Batch of transformation matrices [..., 3/4, 3/4]
+ points: Batch of points [..., N, 3] or a list of points
+ maintain_shape: If True, preserves the original shape of points
+
+ Returns:
+ Transformed points with shape [..., N, 3] or a list of transformed points
+ """
+ if isinstance(points, list):
+ return [m_dot(t, p, maintain_shape) for t, p in zip(transform, points)]
+
+ # Store original shape and flatten batch dimensions
+ orig_shape = points.shape
+ batch_dims = points.shape[:-3]
+
+ # Reshape to standard batch format
+ transform_flat = transform.reshape(-1, transform.shape[-2], transform.shape[-1])
+ points_flat = points.reshape(transform_flat.shape[0], -1, points.shape[-1])
+
+ # Apply transformation
+ pts = torch.bmm(
+ transform_flat[:, :3, :3],
+ points_flat[..., :3].permute(0, 2, 1).to(transform_flat.dtype),
+ ).permute(0, 2, 1)
+
+ if transform.shape[-1] == 4:
+ pts = pts + transform_flat[:, :3, 3].unsqueeze(1)
+
+ # Restore original shape
+ if maintain_shape:
+ return pts.reshape(orig_shape)
+ else:
+ return pts.reshape(*batch_dims, -1, 3)
+
+
+def m_unproject(
+ depth: torch.Tensor,
+ intrinsic: torch.Tensor,
+ cam2world: torch.Tensor = None,
+ img_grid: torch.Tensor = None,
+ valid: torch.Tensor = None,
+ H: int | None = None,
+ W: int | None = None,
+ img_feats: torch.Tensor = None,
+ maintain_shape: bool = False,
+) -> torch.Tensor:
+ """
+ Unproject 2D image points with depth values to 3D points in camera or world space.
+
+ Args:
+ depth: Depth values, either a tensor of shape ...xHxW or a float value
+ intrinsic: Camera intrinsic matrix of shape ...x3x3
+ cam2world: Optional camera-to-world transformation matrix of shape ...x4x4
+ img_grid: Optional pre-computed image grid. If None, will be created
+ valid: Optional mask for valid depth values or minimum depth threshold
+ H: Image height (required if depth is a scalar)
+ W: Image width (required if depth is a scalar)
+ img_feats: Optional image features to append to 3D points
+ maintain_shape: If True, preserves the original shape of points
+
+ Returns:
+ 3D points in camera or world space, with optional features appended
+ """
+ # Get device and shape information from intrinsic matrix
+ device = intrinsic.device
+ pre_shape = intrinsic.shape[:-2] # Batch dimensions
+
+ # Validate inputs
+ if isinstance(depth, (int, float)) and H is None:
+ raise ValueError("H must be provided if depth is a scalar")
+
+ # Determine image dimensions from depth if not provided
+ if isinstance(depth, torch.Tensor) and H is None:
+ H, W = depth.shape[-2:]
+
+ # Create image grid if not provided
+ if img_grid is None:
+ # Create coordinate grid with shape HxWx3 (last dimension is homogeneous)
+ img_grid = _create_image_grid(H, W, device)
+ # Add homogeneous coordinate
+ img_grid = torch.cat([img_grid, torch.ones_like(img_grid[..., :1])], -1)
+
+ # Expand img_grid to match batch dimensions of intrinsic
+ if img_grid.dim() <= intrinsic.dim():
+ img_grid = img_grid.unsqueeze(0)
+ img_grid = img_grid.expand(*pre_shape, *img_grid.shape[-3:])
+
+ # Handle valid mask or minimum depth threshold
+ depth_mask = None
+ if valid is not None:
+ if isinstance(valid, float):
+ # Create mask for minimum depth value
+ depth_mask = depth > valid
+ elif isinstance(valid, torch.Tensor):
+ depth_mask = valid
+
+ # Apply mask to image grid and other inputs
+ img_grid = masking(img_grid, depth_mask, dim=intrinsic.dim())
+ if not isinstance(depth, (int, float)):
+ depth = masking(depth, depth_mask, dim=intrinsic.dim() - 1)
+ if img_feats is not None:
+ img_feats = masking(img_feats, depth_mask, dim=intrinsic.dim() - 1)
+
+ # Unproject 2D points to 3D camera space
+ cam_pts: torch.Tensor = m_dot(
+ m_inverse_intrinsics(intrinsic),
+ img_grid[..., [1, 0, 2]],
+ maintain_shape=True,
+ )
+ # Scale by depth values
+ cam_pts = mult(cam_pts, depth.unsqueeze(-1))
+
+ # Transform to world space if cam2world is provided
+ if cam2world is not None:
+ cam_pts = m_dot(cam2world, cam_pts, maintain_shape=True)
+
+ # Append image features if provided
+ if img_feats is not None:
+ if isinstance(cam_pts, list):
+ if isinstance(cam_pts[0], list):
+ # Handle nested list case
+ result = []
+ for batch_idx, batch in enumerate(cam_pts):
+ batch_result = []
+ for view_idx, view in enumerate(batch):
+ batch_result.append(
+ torch.cat([view, img_feats[batch_idx][view_idx]], -1)
+ )
+ result.append(batch_result)
+ cam_pts = result
+ else:
+ # Handle single list case
+ cam_pts = [
+ torch.cat([pts, feats], -1)
+ for pts, feats in zip(cam_pts, img_feats)
+ ]
+ else:
+ # Handle tensor case
+ cam_pts = torch.cat([cam_pts, img_feats], -1)
+
+ if maintain_shape:
+ return cam_pts
+
+ # Flatten last dimension
+ return cam_pts.reshape(*pre_shape, -1, 3)
+
+
+def m_project(
+ world_pts: torch.Tensor,
+ intrinsic: torch.Tensor,
+ world2cam: torch.Tensor | None = None,
+ maintain_shape: bool = False,
+) -> torch.Tensor:
+ """
+ Project 3D world points to 2D image coordinates.
+
+ Args:
+ world_pts: 3D points in world coordinates
+ intrinsic: Camera intrinsic matrix
+ world2cam: Optional transformation from world to camera coordinates
+ maintain_shape: If True, preserves the original shape of points
+
+ Returns:
+ Image points with coordinates in img_y,img_x,z order
+ """
+ # Transform points from world to camera space if world2cam is provided
+ cam_pts: torch.Tensor = world_pts
+ if world2cam is not None:
+ cam_pts = m_dot(world2cam, world_pts, maintain_shape=maintain_shape)
+
+ # Get shapes to properly expand intrinsics
+ shared_dims = intrinsic.shape[:-2]
+ extra_dims = cam_pts.shape[len(shared_dims) : -1]
+
+ # Expand intrinsics to match cam_pts shape
+ expanded_intrinsic = intrinsic.view(*shared_dims, *([1] * len(extra_dims)), 3, 3)
+ expanded_intrinsic = expanded_intrinsic.expand(*shared_dims, *extra_dims, 3, 3)
+
+ # Project points from camera space to image space
+ depth_abs = cam_pts[..., 2].abs().clamp(min=1e-5)
+ return torch.stack(
+ [
+ expanded_intrinsic[..., 1, 1] * cam_pts[..., 1] / depth_abs
+ + expanded_intrinsic[..., 1, 2],
+ expanded_intrinsic[..., 0, 0] * cam_pts[..., 0] / depth_abs
+ + expanded_intrinsic[..., 0, 2],
+ cam_pts[..., 2],
+ ],
+ -1,
+ )
+
+
+def in_image(
+ image_pts: torch.Tensor | list,
+ H: int,
+ W: int,
+ min_depth: float = 0.0,
+) -> torch.Tensor | list:
+ """
+ Check if image points are within the image boundaries.
+
+ Args:
+ image_pts: Image points in pixel coordinates
+ H: Image height
+ W: Image width
+ min_depth: Minimum valid depth
+
+ Returns:
+ Boolean mask indicating which points are within the image
+ """
+ is_list = isinstance(image_pts, list)
+ if is_list:
+ return [in_image(pts, H, W, min_depth=min_depth) for pts in image_pts]
+
+ in_image_mask = (
+ torch.all(image_pts >= 0, -1)
+ & (image_pts[..., 0] < H)
+ & (image_pts[..., 1] < W)
+ )
+ if (min_depth is not None) and image_pts.shape[-1] == 3:
+ in_image_mask &= image_pts[..., 2] > min_depth
+ return in_image_mask
+
+
+def _create_image_grid(H: int, W: int, device: torch.device) -> torch.Tensor:
+ """
+ Create a coordinate grid for image pixels.
+
+ Args:
+ H: Image height
+ W: Image width
+ device: Computation device
+
+ Returns:
+ Image grid with shape HxWx3 (last dimension is homogeneous)
+ """
+ y_coords = torch.arange(H, device=device)
+ x_coords = torch.arange(W, device=device)
+
+ # Use meshgrid with indexing="ij" for correct orientation
+ y_grid, x_grid = torch.meshgrid(y_coords, x_coords, indexing="ij")
+
+ # Stack coordinates and add homogeneous coordinate
+ img_grid = torch.stack([y_grid, x_grid, torch.ones_like(y_grid)], dim=-1)
+
+ return img_grid
+
+
+def masking(
+ X: torch.Tensor | list,
+ mask: torch.Tensor | list,
+ dim: int = 3,
+) -> torch.Tensor | list:
+ """
+ Apply a Boolean mask to tensor or list elements.
+ Handles nested structures by recursively applying the mask.
+
+ Args:
+ X: Input tensor or list to be masked
+ mask: Boolean mask to apply
+ dim: Dimension threshold for recursive processing
+
+ Returns:
+ Masked tensor or list with the same structure as input
+ """
+ if isinstance(X, list) or (isinstance(X, torch.Tensor) and X.dim() >= dim):
+ return [masking(x, m, dim) for x, m in zip(X, mask)]
+ return X[mask]
+
+
+def m_inverse_intrinsics(intrinsics: torch.Tensor) -> torch.Tensor:
+ """
+ Compute the inverse of camera intrinsics matrices analytically.
+ This is much faster than using torch.inverse() for intrinsics matrices.
+
+ The intrinsics matrix has the form:
+ K = [fx s cx]
+ [0 fy cy]
+ [0 0 1]
+
+ And its inverse is:
+ K^-1 = [1/fx -s/(fx*fy) (s*cy-cx*fy)/(fx*fy)]
+ [0 1/fy -cy/fy ]
+ [0 0 1 ]
+
+ Args:
+ intrinsics: Camera intrinsics matrices of shape [..., 3, 3]
+
+ Returns:
+ Inverse intrinsics matrices of shape [..., 3, 3]
+ """
+ # Extract the components of the intrinsics matrix
+ fx = intrinsics[..., 0, 0]
+ s = intrinsics[..., 0, 1] # skew, usually 0
+ cx = intrinsics[..., 0, 2]
+ fy = intrinsics[..., 1, 1]
+ cy = intrinsics[..., 1, 2]
+
+ # Create output tensor with same shape and device
+ inv_intrinsics = torch.zeros_like(intrinsics)
+
+ # Compute the inverse analytically
+ inv_intrinsics[..., 0, 0] = 1.0 / fx
+ inv_intrinsics[..., 0, 1] = -s / (fx * fy)
+ inv_intrinsics[..., 0, 2] = (s * cy - cx * fy) / (fx * fy)
+ inv_intrinsics[..., 1, 1] = 1.0 / fy
+ inv_intrinsics[..., 1, 2] = -cy / fy
+ inv_intrinsics[..., 2, 2] = 1.0
+
+ return inv_intrinsics
+
+
+def mult(
+ A: torch.Tensor | np.ndarray | list | float | int,
+ B: torch.Tensor | np.ndarray | list | float | int,
+) -> torch.Tensor | np.ndarray | list | float | int:
+ """
+ Multiply two objects with support for lists, tensors, arrays, and scalars.
+ Handles nested structures by recursively applying multiplication.
+
+ Args:
+ A: First operand (tensor, array, list, or scalar)
+ B: Second operand (tensor, array, list, or scalar)
+
+ Returns:
+ Result of multiplication with the same structure as inputs
+ """
+ if isinstance(A, list) and isinstance(B, (int, float)):
+ return [mult(a, B) for a in A]
+ if isinstance(B, list) and isinstance(A, (int, float)):
+ return [mult(A, b) for b in B]
+ if isinstance(A, list) and isinstance(B, list):
+ return [mult(a, b) for a, b in zip(A, B)]
+ return A * B
diff --git a/mapanything/utils/wai/ops.py b/mapanything/utils/wai/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d0b615eafd1bf0cb7e7f3286a7574bfe72c9a02
--- /dev/null
+++ b/mapanything/utils/wai/ops.py
@@ -0,0 +1,368 @@
+"""
+This utils script contains PORTAGE of wai-core ops methods for MapAnything.
+"""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+
+
+def to_numpy(
+ data: torch.Tensor | np.ndarray | int | float,
+ dtype: np.dtype | str | type = np.float32,
+) -> np.ndarray:
+ """
+ Convert data to a NumPy array with the specified dtype (default: float32).
+
+ This function handles conversion from NumPy arrays and PyTorch tensors to a NumPy array.
+
+ Args:
+ data: Input data (torch.Tensor, np.ndarray, or scalar)
+ dtype: Target data type (NumPy dtype, str, or type). Default: np.float32.
+
+ Returns:
+ Converted data as NumPy array with specified dtype.
+ """
+ # Set default dtype if not defined
+ assert dtype is not None, "dtype cannot be None"
+ dtype = np.dtype(dtype)
+
+ # Handle torch.Tensor
+ if isinstance(data, torch.Tensor):
+ return data.detach().cpu().numpy().astype(dtype)
+
+ # Handle numpy.ndarray
+ if isinstance(data, np.ndarray):
+ return data.astype(dtype)
+
+ # Handle scalar values
+ if isinstance(data, (int, float)):
+ return np.array(data, dtype=dtype)
+
+ raise NotImplementedError(f"Unsupported data type: {type(data)}")
+
+
+def get_dtype_device(
+ data: torch.Tensor | np.ndarray | dict | list,
+) -> tuple[torch.dtype | np.dtype | None, torch.device | str | type | None]:
+ """
+ Determine the data type and device of the input data.
+
+ This function recursively inspects the input data and determines its data type
+ and device. It handles PyTorch tensors, NumPy arrays, dictionaries, and lists.
+
+ Args:
+ data: Input data (torch.Tensor, np.ndarray, dict, list, or other)
+
+ Returns:
+ tuple: (dtype, device) where:
+ - dtype: The data type (torch.dtype or np.dtype)
+ - device: The device (torch.device, 'cpu', 'cuda:X', or np.ndarray)
+
+ Raises:
+ ValueError: If tensors in a dictionary are on different CUDA devices
+ """
+ if isinstance(data, torch.Tensor):
+ return data.dtype, data.device
+
+ if isinstance(data, np.ndarray):
+ return data.dtype, np.ndarray
+
+ if isinstance(data, dict):
+ dtypes = {get_dtype_device(v)[0] for v in data.values()}
+ devices = {get_dtype_device(v)[1] for v in data.values()}
+ cuda_devices = {device for device in devices if str(device).startswith("cuda")}
+ cpu_devices = {device for device in devices if str(device).startswith("cpu")}
+ if (len(cuda_devices) > 0) or (len(cpu_devices) > 0):
+ # torch.tensor
+ dtype = torch.float
+ if all(dtype == torch.half for dtype in dtypes):
+ dtype = torch.half
+ device = None
+ if len(cuda_devices) > 1:
+ raise ValueError("All tensors must be on the same device")
+ if len(cuda_devices) == 1:
+ device = list(cuda_devices)[0]
+ if (device is None) and (len(cpu_devices) == 1):
+ device = list(cpu_devices)[0]
+ else:
+ dtype = np.float32
+ # Fix typo in numpy float16 check
+ if all(dtype == np.float16 for dtype in dtypes):
+ dtype = np.float16
+ device = np.ndarray
+
+ elif isinstance(data, list):
+ if not data: # Handle empty list case
+ return None, None
+ dtype, device = get_dtype_device(data[0])
+
+ else:
+ return np.float32, np.ndarray
+
+ return dtype, device
+
+
+def crop(
+ data: np.ndarray | torch.Tensor | Image.Image,
+ bbox: tuple[int, int, int, int] | tuple[int, int],
+) -> np.ndarray | torch.Tensor | Image.Image:
+ """
+ Crop data of different formats (numpy arrays, PyTorch tensors, PIL Images) to a target size.
+
+ Args:
+ data: Input data to resize (numpy.ndarray, torch.Tensor, or PIL.Image.Image)
+ size: Target size as tuple (offset_height, offset_width, height, width) or tuple (height, width)
+
+ Returns:
+ Cropped data in the same format as the input
+ """
+ if len(bbox) == 4:
+ offset_height, offset_width, target_height, target_width = bbox
+ elif len(bbox) == 2:
+ target_height, target_width = bbox
+ offset_height, offset_width = 0, 0
+ else:
+ raise ValueError(f"Unsupported size length {len(bbox)}.")
+
+ end_height = offset_height + target_height
+ end_width = offset_width + target_width
+
+ if any([sz < 0 for sz in bbox]):
+ raise ValueError("Bounding box can't have negative values.")
+
+ if isinstance(data, np.ndarray):
+ if (
+ max(offset_height, end_height) > data.shape[0]
+ or max(offset_width, end_width) > data.shape[1]
+ ):
+ raise ValueError("Invalid bounding box.")
+ cropped_data = data[offset_height:end_height, offset_width:end_width, ...]
+ return cropped_data
+
+ # Handle PIL images
+ elif isinstance(data, Image.Image):
+ if (
+ max(offset_height, end_height) > data.size[1]
+ or max(offset_width, end_width) > data.size[0]
+ ):
+ raise ValueError("Invalid bounding box.")
+ return data.crop((offset_width, offset_height, end_width, end_height))
+
+ # Handle PyTorch tensors
+ elif isinstance(data, torch.Tensor):
+ if data.is_nested:
+ # special handling for nested tensors
+ return torch.stack([crop(nested_tensor, bbox) for nested_tensor in data])
+ if (
+ max(offset_height, end_height) > data.shape[-2]
+ or max(offset_width, end_width) > data.shape[-1]
+ ):
+ raise ValueError("Invalid bounding box.")
+ cropped_data = data[..., offset_height:end_height, offset_width:end_width]
+ return cropped_data
+ else:
+ raise TypeError(f"Unsupported data type '{type(data)}'.")
+
+
+def stack(
+ data: list[
+ dict[str, torch.Tensor | np.ndarray]
+ | list[torch.Tensor | np.ndarray]
+ | tuple[torch.Tensor | np.ndarray]
+ ],
+) -> dict[str, torch.Tensor | np.ndarray] | list[torch.Tensor | np.ndarray]:
+ """
+ Stack a list of dictionaries into a single dictionary with stacked values.
+ Or when given a list of sublists, stack the sublists using torch or numpy stack
+ if the items are of equal size, or nested tensors if the items are PyTorch tensors
+ of different size.
+
+ This utility function is similar to PyTorch's collate function, but specifically
+ designed for stacking dictionaries of numpy arrays or PyTorch tensors.
+
+ Args:
+ data (list): A list of dictionaries with the same keys, where values are
+ either numpy arrays or PyTorch tensors.
+ OR
+ A list of sublist, where the values of sublists are PyTorch tensors
+ or np arrrays.
+
+ Returns:
+ dict: A dictionary with the same keys as input dictionaries, but with values
+ stacked along a new first dimension.
+ OR
+ list: If the input was a list with sublists, it returns a list with a stacked
+ output for each original input sublist.
+
+ Raises:
+ ValueError: If dictionaries in the list have inconsistent keys.
+ NotImplementedError: If input is not a list or contains non-dictionary elements.
+ """
+ if not isinstance(data, list):
+ raise NotImplementedError(f"Stack: Data type not supported: {data}")
+
+ if len(data) == 0:
+ return data
+
+ if all(isinstance(entry, dict) for entry in data):
+ stacked_data = {}
+ keys = list(data[0].keys())
+ if any(set(entry.keys()) != set(keys) for entry in data):
+ raise ValueError("Data not consistent for stacking")
+
+ for key in keys:
+ stacked_data[key] = []
+ for entry in data:
+ stacked_data[key].append(entry[key])
+
+ # stack it according to data format
+ if all(isinstance(v, np.ndarray) for v in stacked_data[key]):
+ stacked_data[key] = np.stack(stacked_data[key])
+ elif all(isinstance(v, torch.Tensor) for v in stacked_data[key]):
+ # Check if all tensors have the same shape
+ first_shape = stacked_data[key][0].shape
+ if all(tensor.shape == first_shape for tensor in stacked_data[key]):
+ stacked_data[key] = torch.stack(stacked_data[key])
+ else:
+ # Use nested tensors if shapes are not consistent
+ stacked_data[key] = torch.nested.nested_tensor(stacked_data[key])
+ return stacked_data
+
+ if all(isinstance(entry, list) for entry in data):
+ # new stacked data will be a list with all of the sublist
+ stacked_data = []
+ for sublist in data:
+ # stack it according to data format
+ if all(isinstance(v, np.ndarray) for v in sublist):
+ stacked_data.append(np.stack(sublist))
+ elif all(isinstance(v, torch.Tensor) for v in sublist):
+ # Check if all tensors have the same shape
+ first_shape = sublist[0].shape
+ if all(tensor.shape == first_shape for tensor in sublist):
+ stacked_data.append(torch.stack(sublist))
+ else:
+ # Use nested tensors if shapes are not consistent
+ stacked_data.append(torch.nested.nested_tensor(sublist))
+ return stacked_data
+
+ raise NotImplementedError(f"Stack: Data type not supported: {data}")
+
+
+def resize(
+ data: np.ndarray | torch.Tensor | Image.Image,
+ size: tuple[int, int] | int | None = None,
+ scale: float | None = None,
+ modality_format: str | None = None,
+) -> np.ndarray | torch.Tensor | Image.Image:
+ """
+ Resize data of different formats (numpy arrays, PyTorch tensors, PIL Images) to a target size.
+
+ Args:
+ data: Input data to resize (numpy.ndarray, torch.Tensor, or PIL.Image.Image)
+ size: Target size as tuple (height, width) or single int for long-side scaling
+ scale: Scale factor to apply to the original dimensions
+ modality_format: Type of data being resized ('depth', 'normals', or None)
+ Affects interpolation method used
+
+ Returns:
+ Resized data in the same format as the input
+
+ Raises:
+ ValueError: If neither size nor scale is provided, or if both are provided
+ TypeError: If data is not a supported type
+ """
+ # Validate input parameters
+ if size is not None and scale is not None:
+ raise ValueError("Only one of size or scale should be provided.")
+
+ # Calculate size from scale if needed
+ if size is None:
+ if scale is None:
+ raise ValueError("Either size or scale must be provided.")
+
+ size = (1, 1)
+ if isinstance(data, (np.ndarray, torch.Tensor)):
+ size = (int(data.shape[-2] * scale), int(data.shape[-1] * scale))
+ elif isinstance(data, Image.Image):
+ size = (int(data.size[1] * scale), int(data.size[0] * scale))
+ else:
+ raise TypeError(f"Unsupported data type '{type(data)}'.")
+
+ # Handle long-side scaling when size is a single integer
+ elif isinstance(size, int):
+ long_side = size
+ if isinstance(data, (np.ndarray, torch.Tensor)):
+ if isinstance(data, torch.Tensor) and data.is_nested:
+ raise ValueError(
+ "Long-side scaling not support for nested tensors, use fixed size instead."
+ )
+ h, w = data.shape[-2], data.shape[-1]
+ elif isinstance(data, Image.Image):
+ w, h = data.size
+ else:
+ raise TypeError(f"Unsupported data type '{type(data)}'.")
+ if h > w:
+ size = (long_side, int(w * long_side / h))
+ else:
+ size = (int(h * long_side / w), long_side)
+
+ target_height, target_width = size
+
+ # Set interpolation method based on modality
+ if modality_format in ["depth", "normals"]:
+ interpolation = Image.Resampling.NEAREST
+ torch_interpolation = "nearest"
+ else:
+ interpolation = Image.Resampling.LANCZOS
+ torch_interpolation = "bilinear"
+
+ # Handle numpy arrays
+ if isinstance(data, np.ndarray):
+ pil_image = Image.fromarray(data)
+ resized_image = pil_image.resize((target_width, target_height), interpolation)
+ return np.array(resized_image)
+
+ # Handle PIL images
+ elif isinstance(data, Image.Image):
+ return data.resize((target_width, target_height), interpolation)
+
+ # Handle PyTorch tensors
+ elif isinstance(data, torch.Tensor):
+ if data.is_nested:
+ # special handling for nested tensors
+ return torch.stack(
+ [
+ resize(nested_tensor, size, scale, modality_format)
+ for nested_tensor in data
+ ]
+ )
+ original_dim = data.ndim
+ if original_dim == 2: # (H, W)
+ data = data.unsqueeze(0).unsqueeze(0) # Add channel and batch dimensions
+ elif original_dim == 3: # (C/B, H W)
+ if modality_format == "depth":
+ data = data.unsqueeze(1) # channel batch dimension
+ else:
+ data = data.unsqueeze(0) # Add batch dimension
+ resized_tensor = F.interpolate(
+ data,
+ size=(target_height, target_width),
+ mode=torch_interpolation,
+ align_corners=False if torch_interpolation != "nearest" else None,
+ )
+ if original_dim == 2:
+ return resized_tensor.squeeze(0).squeeze(
+ 0
+ ) # Remove batch and channel dimensions
+ elif original_dim == 3:
+ if modality_format == "depth":
+ return resized_tensor.squeeze(1) # Remove channel dimension
+
+ return resized_tensor.squeeze(0) # Remove batch dimension
+ else:
+ return resized_tensor
+
+ else:
+ raise TypeError(f"Unsupported data type '{type(data)}'.")
diff --git a/mapanything/utils/wai/scene_frame.py b/mapanything/utils/wai/scene_frame.py
new file mode 100644
index 0000000000000000000000000000000000000000..a801cfc3743369980bb16bafbada20bdfbe2dabb
--- /dev/null
+++ b/mapanything/utils/wai/scene_frame.py
@@ -0,0 +1,431 @@
+import logging
+import os
+import random
+import re
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+
+from mapanything.utils.wai.io import (
+ _load_readable,
+ _load_scene_meta,
+ get_processing_state,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def get_scene_frame_names(
+ cfg: dict | object,
+ root: Path | str | None = None,
+ scene_frames_fn: str | None = None,
+ keyframes: bool = True,
+) -> dict[str, list[str | float]] | None:
+ """
+ Retrieve scene frame names based on configuration and optional parameters.
+
+ This function determines the scene frame names by resolving the scene frame file
+ and applying any necessary filters based on the provided configuration.
+
+ Args:
+ cfg: Configuration object containing settings and parameters.
+ root: Optional root directory path. If not provided, it will be fetched from cfg.
+ scene_frames_fn: Optional scene frames file name. If not provided, it will be fetched from cfg.
+ keyframes: Optional, used only for a video. If True (default), return only keyframes (with camera poses).
+
+ Returns:
+ A dictionary mapping scene names to their respective frame names.
+ """
+ scene_frames_fn = (
+ cfg.get("scene_frames_fn") if scene_frames_fn is None else scene_frames_fn
+ )
+ scene_frame_names = None
+ if scene_frames_fn is not None:
+ # load scene_frames based on scene_frame file
+ scene_frame_names = _resolve_scene_frames_fn(scene_frames_fn)
+
+ scene_names = get_scene_names(
+ cfg,
+ root=root,
+ scene_names=(
+ list(scene_frame_names.keys()) if scene_frame_names is not None else None
+ ),
+ )
+ scene_frame_names = _resolve_scene_frame_names(
+ cfg,
+ scene_names,
+ root=root,
+ scene_frame_names=scene_frame_names,
+ keyframes=keyframes,
+ )
+ return scene_frame_names
+
+
+def get_scene_names(
+ cfg: dict | object,
+ root: Path | str | None = None,
+ scene_names: list[str] | None = None,
+ shuffle: bool = False,
+) -> list[str]:
+ """
+ Retrieve scene names based on the provided configuration and optional parameters.
+
+ This function determines the scene names by checking the root directory for subdirectories
+ and applying any necessary filters based on the provided configuration.
+
+ Args:
+ cfg: Configuration object containing settings and parameters.
+ root: Optional root directory path. If not provided, it will be fetched from cfg.
+ scene_names: Optional list of scene names. If not provided, it will be determined from the root directory.
+ shuffle: Optional bool. Default to False. If True, it will return the list of scene names in random order.
+
+ Returns:
+ A list of scene names after applying any filters specified in the configuration.
+ """
+ root = cfg.get("root") if root is None else root
+ if root is not None:
+ # Check if the root exists
+ if not Path(root).exists():
+ raise IOError(f"Root directory does not exist: {root}")
+
+ # Check if the root is a directory
+ if not Path(root).is_dir():
+ raise IOError(f"Root directory is not a directory: {root}")
+
+ if scene_names is None:
+ scene_filters = cfg.get("scene_filters")
+ if (
+ scene_filters
+ and len(scene_filters) == 1
+ and isinstance(scene_filters[0], list)
+ and all(isinstance(entry, str) for entry in scene_filters[0])
+ ):
+ # Shortcut the scene_names if the scene_filters is only a list of scene names
+ scene_names = scene_filters[0]
+ else:
+ # List all subdirectories in the root as scenes
+ scene_names = sorted(
+ [entry.name for entry in os.scandir(root) if entry.is_dir()]
+ )
+ # Filter scenes based on scene_filters
+ scene_names = _filter_scenes(root, scene_names, cfg.get("scene_filters"))
+
+ # shuffle the list if needed (in place)
+ if shuffle:
+ random.shuffle(scene_names)
+
+ return scene_names
+
+
+def _filter_scenes(
+ root: Path | str,
+ scene_names: list[str],
+ scene_filters: tuple | list | None,
+) -> list[str]:
+ if scene_filters is None:
+ return scene_names
+
+ if not isinstance(scene_filters, (tuple, list)):
+ raise ValueError("scene_filters must be a list or tuple")
+
+ for scene_filter in scene_filters:
+ if scene_filter in [None, "all"]:
+ pass
+
+ elif isinstance(scene_filter, (tuple, list)):
+ if len(scene_filter) == 0:
+ raise ValueError("scene_filter cannot be empty")
+
+ elif all(isinstance(x, int) for x in scene_filter):
+ if len(scene_filter) == 2:
+ # start/end index
+ scene_names = scene_names[scene_filter[0] : scene_filter[1]]
+ elif len(scene_filter) == 3:
+ # start/end/step
+ scene_names = scene_names[
+ scene_filter[0] : scene_filter[1] : scene_filter[2]
+ ]
+ else:
+ # omegaconf conversion issue (converts strings to integers whenever possible)
+ if str(scene_filter[0]) in scene_names:
+ scene_names = [str(s) for s in scene_filter]
+ else:
+ raise ValueError(
+ "scene_filter format [start_idx, end_idx] or [start_idx, end_idx, step_size] or [scene_name1, scene_name2, ...]"
+ )
+
+ elif all(isinstance(x, str) for x in scene_filter):
+ # explicit scene names
+ if set(scene_filter).issubset(set(scene_names)):
+ scene_names = list(scene_filter)
+ else:
+ logger.warning(
+ f"Scene(s) not available: {set(scene_filter) - set(scene_names)}"
+ )
+ scene_names = list(set(scene_names) & set(scene_filter))
+ else:
+ raise TypeError(
+ f"Scene filter type not supported: {type(scene_filter)}"
+ )
+
+ elif isinstance(scene_filter, dict):
+ # reserved key words
+ if modality := scene_filter.get("exists"):
+ scene_names = [
+ scene_name
+ for scene_name in scene_names
+ if Path(root, scene_name, modality).exists()
+ ]
+
+ elif modality := scene_filter.get("exists_not"):
+ scene_names = [
+ scene_name
+ for scene_name in scene_names
+ if not Path(root, scene_name, modality).exists()
+ ]
+
+ elif process_filter := scene_filter.get("process_state"):
+ # filter for where has
+ (process_key, process_state) = process_filter
+ filtered_scene_names = []
+ for scene_name in scene_names:
+ # load processing state and check for
+ processing_state = get_processing_state(Path(root, scene_name))
+ if "*" in process_key: # regex matching
+ for process_name in processing_state:
+ if re.match(process_key, process_name):
+ process_key = process_name
+ break
+ if process_key not in processing_state:
+ continue
+ if processing_state[process_key]["state"] == process_state:
+ filtered_scene_names.append(scene_name)
+ scene_names = filtered_scene_names
+
+ elif process_filter := scene_filter.get("process_state_not"):
+ # filter for where does not have
+ (process_key, process_state) = process_filter
+ filtered_scene_names = []
+ for scene_name in scene_names:
+ # load processing state and check for
+ try:
+ processing_state = get_processing_state(Path(root, scene_name))
+ except Exception:
+ filtered_scene_names.append(scene_name)
+ continue
+ if "*" in process_key: # regex matching
+ for process_name in processing_state:
+ if re.match(process_key, process_name):
+ process_key = process_name
+ break
+ if (process_key not in processing_state) or (
+ processing_state[process_key]["state"] != process_state
+ ):
+ filtered_scene_names.append(scene_name)
+ scene_names = filtered_scene_names
+
+ else:
+ raise ValueError(f"Scene filter not supported: {scene_filter}")
+
+ elif isinstance(scene_filter, str):
+ # regex
+ scene_names = [
+ scene_name
+ for scene_name in scene_names
+ if re.fullmatch(scene_filter, scene_name)
+ ]
+ else:
+ raise ValueError(f"Scene filter not supported: {scene_filter}")
+
+ return scene_names
+
+
+def _resolve_scene_frames_fn(scene_frames_fn: str) -> dict[str, list[str] | None]:
+ # support for file list in forms of lists or dicts
+ # containing scene_names [-> frames]
+ scene_frames_list = _load_readable(scene_frames_fn)
+ scene_frame_names = {}
+
+ # TODO: The following code seems unreachable as scene_frames_list is always a dict
+ if isinstance(scene_frames_list, (list, tuple)):
+ for entry in scene_frames_list:
+ if isinstance(entry, (tuple, list)):
+ if (
+ (len(entry) != 2)
+ or (not isinstance(entry[0], str))
+ or (not isinstance(entry[1], list))
+ ):
+ raise NotImplementedError(
+ "Only supports lists of [, [frame_names]]"
+ )
+ scene_frame_names[entry[0]] = entry[1]
+ elif isinstance(entry, str):
+ scene_frame_names[entry] = None
+ elif isinstance(entry, dict):
+ # scene_name -> frames
+ raise NotImplementedError("Dict entry not supported yet")
+ else:
+ raise IOError(f"File list contains an entry of wrong format: {entry}")
+
+ elif isinstance(scene_frames_list, dict):
+ # scene_name -> frames
+ for scene_name, frame in scene_frames_list.items():
+ if isinstance(frame, (tuple, list)):
+ scene_frame_names[scene_name] = frame
+ elif isinstance(frame, dict):
+ if "frame_names" in frame:
+ scene_frame_names[scene_name] = frame["frame_names"]
+ else:
+ raise IOError(f"Scene frames format not supported: {frame}")
+ elif frame is None:
+ scene_frame_names[scene_name] = frame
+ else:
+ raise IOError(f"Scene frames format not supported: {frame}")
+
+ else:
+ raise IOError(f"Scene frames format not supported: {scene_frames_list}")
+
+ return scene_frame_names
+
+
+def _resolve_scene_frame_names(
+ cfg: dict | object,
+ scene_names: list[str],
+ root: Path | str | None = None,
+ scene_frame_names: dict[str, list[str | float] | None] | None = None,
+ keyframes: bool = True,
+) -> dict[str, list[str]]:
+ root = cfg.get("root") if root is None else root
+ if scene_frame_names is not None:
+ # restrict to the additional scene-level prefiltering
+ scene_frame_names = {
+ scene_name: scene_frame_names[scene_name] for scene_name in scene_names
+ }
+ # dict already loaded, apply additional filters
+ for scene_name, frame_names in scene_frame_names.items():
+ if frame_names is None:
+ scene_meta = _load_scene_meta(
+ Path(
+ root, scene_name, cfg.get("scene_meta_path", "scene_meta.json")
+ )
+ )
+ frame_names = [frame["frame_name"] for frame in scene_meta["frames"]]
+ # TODO: add some logic for video keyframes
+
+ scene_frame_names[scene_name] = _filter_frame_names(
+ root, frame_names, scene_name, cfg.get("frame_filters")
+ )
+ else:
+ scene_frame_names = {}
+ for scene_name in scene_names:
+ scene_meta = _load_scene_meta(
+ Path(root, scene_name, cfg.get("scene_meta_path", "scene_meta.json"))
+ )
+ if not keyframes:
+ frame_names = get_video_frames(scene_meta)
+ if frame_names is None:
+ keyframes = True
+ if keyframes:
+ frame_names = [frame["frame_name"] for frame in scene_meta["frames"]]
+ frame_names = _filter_frame_names(
+ root, frame_names, scene_name, cfg.get("frame_filters")
+ )
+ scene_frame_names[scene_name] = frame_names
+ return scene_frame_names
+
+
+def _filter_frame_names(
+ root: Path | str,
+ frame_names: list[str],
+ scene_name: str,
+ frame_filters: list | tuple | None,
+) -> list[str]:
+ if frame_filters is None:
+ return frame_names
+
+ if not isinstance(frame_filters, (tuple, list)):
+ raise ValueError("frame_filters must be a list or tuple")
+
+ for frame_filter in frame_filters:
+ if frame_filter in [None, "all"]:
+ pass
+
+ elif isinstance(frame_filter, (tuple, list)):
+ if len(frame_filter) == 0:
+ raise ValueError("frame_filter cannot be empty")
+
+ if isinstance(frame_filter[0], int):
+ if len(frame_filter) == 2:
+ # start/end index
+ frame_names = frame_names[frame_filter[0] : frame_filter[1]]
+
+ elif len(frame_filter) == 3:
+ # start/end/step
+ frame_names = frame_names[
+ frame_filter[0] : frame_filter[1] : frame_filter[2]
+ ]
+
+ else:
+ raise ValueError(
+ "frame_filter format [start_idx, end_idx] or [start_idx, end_idx,step_size]"
+ )
+ else:
+ raise TypeError(
+ f"frame_filter[0] type not supported: {type(frame_filter[0])}"
+ )
+
+ elif isinstance(frame_filter, str):
+ # reserved key words
+ if match := re.match("exists: (.+)", frame_filter):
+ modality = match.group(1)
+ frame_names = [
+ frame_name
+ for frame_name in frame_names
+ if any(Path(root, scene_name, modality).glob(f"{frame_name}.*"))
+ ]
+
+ elif match := re.match("!exists: (.+)", frame_filter):
+ modality = match.group(1)
+ frame_names = [
+ frame_name
+ for frame_name in frame_names
+ if not any(Path(root, scene_name, modality).glob(f"{frame_name}.*"))
+ ]
+
+ else: # general regex
+ frame_names = [
+ frame_name
+ for frame_name in frame_names
+ if re.match(frame_filter, frame_name)
+ ]
+
+ else:
+ raise ValueError(f"frame_filter type not supported: {type(frame_filter)}")
+
+ return frame_names
+
+
+def get_video_frames(scene_meta: dict[str, Any]):
+ """
+ Return names of video frames.
+ Args:
+ scene_meta: dictionary with scene_meat data.
+
+ Returns:
+ A list of video frame names.
+ """
+ image_modality = [mod for mod in scene_meta["frame_modalities"] if "image" in mod]
+ if len(image_modality) > 0:
+ image_modality = scene_meta["frame_modalities"][image_modality[0]]
+ if "chunks" in image_modality:
+ file_list = image_modality["chunks"]
+ else:
+ file_list = [image_modality]
+ frame_names = []
+ for chunk in file_list:
+ start, end, fps = chunk["start"], chunk["end"], chunk["fps"]
+ chunk_frame_names = np.arange(start, end, 1.0 / fps).tolist()
+ frame_names += chunk_frame_names
+ return frame_names
+ return None
diff --git a/mapanything/utils/wai/semantics.py b/mapanything/utils/wai/semantics.py
new file mode 100644
index 0000000000000000000000000000000000000000..324b28ed0ede3f2bc1a3c7b50e015af41f7157a4
--- /dev/null
+++ b/mapanything/utils/wai/semantics.py
@@ -0,0 +1,40 @@
+"""
+This utils script contains PORTAGE of wai-core semantics methods for MapAnything.
+"""
+
+import numpy as np
+from PIL import Image
+
+INVALID_ID = 0
+INVALID_COLOR = (0, 0, 0)
+
+
+def load_semantic_color_mapping(filename: str = "colors_fps_5k.npz") -> np.ndarray:
+ """Loads a precomputed colormap."""
+ from mapanything.utils.wai.core import WAI_COLORMAP_PATH
+
+ return np.load(WAI_COLORMAP_PATH / filename).get("arr_0")
+
+
+def apply_id_to_color_mapping(
+ data_id: np.ndarray | Image.Image,
+ semantic_color_mapping: np.ndarray,
+) -> tuple[np.ndarray, dict[int, tuple[int, int, int]]]:
+ """Maps semantic class/instance IDs to RGB colors."""
+ if isinstance(data_id, Image.Image):
+ data_id = np.array(data_id)
+
+ max_color_id = semantic_color_mapping.shape[0] - 1
+ max_data_id = data_id.max()
+ if max_data_id > max_color_id:
+ raise ValueError("The provided color palette does not have enough colors!")
+
+ # Create palette containing the id->color mappings of the input data IDs
+ unique_indices = np.unique(data_id).tolist()
+ color_palette = {
+ index: semantic_color_mapping[index, :].tolist() for index in unique_indices
+ }
+
+ data_colors = semantic_color_mapping[data_id]
+
+ return data_colors, color_palette
diff --git a/mapanything/utils/warnings.py b/mapanything/utils/warnings.py
new file mode 100644
index 0000000000000000000000000000000000000000..8422416bac8ba5893f6a50f2b32125f4f9ab65bb
--- /dev/null
+++ b/mapanything/utils/warnings.py
@@ -0,0 +1,41 @@
+"""
+Wrapper utilities for warnings.
+"""
+
+import warnings
+from functools import wraps
+
+
+def suppress_traceback(fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ try:
+ return fn(*args, **kwargs)
+ except Exception as e:
+ e.__traceback__ = e.__traceback__.tb_next.tb_next
+ raise
+
+ return wrapper
+
+
+class no_warnings:
+ def __init__(self, action: str = "ignore", **kwargs):
+ self.action = action
+ self.filter_kwargs = kwargs
+
+ def __call__(self, fn):
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ with warnings.catch_warnings():
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+ return fn(*args, **kwargs)
+
+ return wrapper
+
+ def __enter__(self):
+ self.warnings_manager = warnings.catch_warnings()
+ self.warnings_manager.__enter__()
+ warnings.simplefilter(self.action, **self.filter_kwargs)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..90e37a3ef4741cff67e074ea9b8047fe64d5cc47
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,23 @@
+--extra-index-url https://download.pytorch.org/whl/cu113
+torch
+torchvision
+torchaudio
+gradio
+huggingface-hub
+numpy
+opencv-python-headless
+Pillow
+matplotlib
+scikit-learn
+scipy
+spaces
+hydra-core
+omegaconf
+trimesh
+einops
+requests
+psutil
+pillow-heif
+tqdm
+safetensors
+uniception==0.1.4
\ No newline at end of file