gagndeep committed
Commit 671c57a · 1 Parent(s): 81e21b3
Files changed (1):
  1. model_utils.py +569 -231
model_utils.py CHANGED
@@ -1,274 +1,612 @@
- """
- SHARP Gradio Demo (Fixed)
- - Standard Two-Column Layout
- - Robust Error Handling
- - Glitch-free Examples (Load-only)
  """

  from __future__ import annotations

- import warnings
- import json
  from pathlib import Path
- from typing import Final
- import gradio as gr

- # Suppress internal warnings
- warnings.filterwarnings("ignore", category=FutureWarning, module="torch.distributed")

- # Ensure model_utils is present
- # We wrap this import to prevent app crash if model_utils is missing during UI dev
  try:
-     from model_utils import TrajectoryType, predict_and_maybe_render_gpu
- except ImportError:
-     # Dummy mocks for testing/building UI without backend
-     class TrajectoryType:
-         pass
-     def predict_and_maybe_render_gpu(*args, **kwargs):
-         return None, Path("dummy.ply")

- # -----------------------------------------------------------------------------
- # Paths & Config
- # -----------------------------------------------------------------------------

- APP_DIR: Final[Path] = Path(__file__).resolve().parent
- OUTPUTS_DIR: Final[Path] = APP_DIR / "outputs"
- ASSETS_DIR: Final[Path] = APP_DIR / "assets"
- EXAMPLES_DIR: Final[Path] = ASSETS_DIR / "examples"

- IMAGE_EXTS: Final[tuple[str, ...]] = (".png", ".jpg", ".jpeg", ".webp")

  # -----------------------------------------------------------------------------
  # Helpers
  # -----------------------------------------------------------------------------

  def _ensure_dir(path: Path) -> Path:
      path.mkdir(parents=True, exist_ok=True)
      return path

- def get_example_files() -> list[list[str]]:
-     """Discover images in assets/examples for the UI."""
-     _ensure_dir(EXAMPLES_DIR)
-
-     # Check manifest.json first
-     manifest_path = EXAMPLES_DIR / "manifest.json"
-     if manifest_path.exists():
          try:
-             data = json.loads(manifest_path.read_text(encoding="utf-8"))
-             examples = []
-             for entry in data:
-                 if "image" in entry:
-                     img_path = EXAMPLES_DIR / entry["image"]
-                     if img_path.exists():
-                         examples.append([str(img_path)])
-             if examples:
-                 return examples
-         except Exception as e:
-             print(f"Manifest error: {e}")
-
-     # Fallback: simple file scan
-     examples = []
-     for ext in IMAGE_EXTS:
-         for img in sorted(EXAMPLES_DIR.glob(f"*{ext}")):
-             examples.append([str(img)])
-     return examples
-
- def run_sharp(
-     image_path: str | None,
-     trajectory_preset: str,
-     output_long_side: int | float | None,
-     num_frames: int | float,
-     fps: int | float,
-     render_video: bool,
-     progress=gr.Progress()
- ) -> tuple[str | None, str | None, str]:
-     """
-     Main Inference Function
-     """
-     if not image_path:
-         raise gr.Error("Please upload an image first.")
-
-     # 1. Safe Integer Conversion (Handle None or Float inputs from sliders)
-     try:
-         out_long_side_val = int(output_long_side) if output_long_side and int(output_long_side) > 0 else None
-         n_frames = int(num_frames)
-         fps_val = int(fps)
-     except (TypeError, ValueError):
-         # Fallbacks if UI sends weird data
-         out_long_side_val = None
-         n_frames = 60
-         fps_val = 30
-
-     # 2. Safe Trajectory Mapping
-     # Map UI friendly names to internal keys
-     traj_map = {
-         "Orbit (Standard)": "rotate",
-         "Orbit (Forward)": "rotate_forward",
-         "Swipe Left": "swipe",
-         "Shake": "shake",
-         "Zoom In": "zoom",
-         "Dolly": "dolly"
-     }
-
-     internal_name = traj_map.get(trajectory_preset, "rotate")
-
-     # Try to find the Enum member safely
-     traj_enum = internal_name  # Default to string if Enum logic fails
-     try:
-         if hasattr(TrajectoryType, internal_name.upper()):
-             traj_enum = getattr(TrajectoryType, internal_name.upper())
-         elif hasattr(TrajectoryType, internal_name):
-             traj_enum = getattr(TrajectoryType, internal_name)
-     except Exception:
-         print(f"Warning: Could not resolve TrajectoryType.{internal_name}, passing string '{internal_name}'")
-         traj_enum = internal_name
-
-     # 3. Run Inference
-     try:
-         progress(0.1, desc="Initializing model...")
-
-         video_path, ply_path = predict_and_maybe_render_gpu(
-             image_path,
-             trajectory_type=traj_enum,
-             num_frames=n_frames,
-             fps=fps_val,
-             output_long_side=out_long_side_val,
-             render_video=bool(render_video),
-         )

-         status_msg = f"✅ **Success**\n\nPLY: `{ply_path.name}`"
-         if video_path:
-             status_msg += f"\nVideo: `{video_path.name}`"
-
-         return (
-             str(video_path) if video_path else None,
-             str(ply_path),
-             status_msg
-         )

-     except Exception as e:
-         # Catch all errors to prevent UI crash
-         raise gr.Error(f"Generation failed: {str(e)}")

  # -----------------------------------------------------------------------------
- # UI Construction
  # -----------------------------------------------------------------------------

- def build_demo() -> gr.Blocks:
-     theme = gr.themes.Default()
-
-     css = """
-     .container { max-width: 1200px; margin: auto; }
-     #header { text-align: center; margin-bottom: 20px; }
-     """
-
-     with gr.Blocks(theme=theme, css=css, title="SHARP 3D") as demo:
-
-         # --- Header ---
-         with gr.Column(elem_id="header"):
-             gr.Markdown("# SHARP: Single-Image 3D Generator")
-             gr.Markdown("Convert any static image into a 3D Gaussian Splat scene instantly.")
-
-         # --- Main Two-Column Layout ---
-         with gr.Row(equal_height=False):
-
-             # --- LEFT COLUMN: Input & Controls ---
-             with gr.Column():
-                 image_in = gr.Image(
-                     label="Input Image",
-                     type="filepath",
-                     sources=["upload", "clipboard"],
-                     height=350
-                 )

-                 # Controls are visible (no accordion)
-                 with gr.Group():
-                     gr.Markdown("### 🎥 Settings")
-                     trajectory_preset = gr.Dropdown(
-                         label="Camera Movement",
-                         choices=[
-                             "Orbit (Standard)",
-                             "Orbit (Forward)",
-                             "Swipe Left",
-                             "Shake",
-                             "Zoom In",
-                             "Dolly"
-                         ],
-                         value="Orbit (Forward)",
-                         interactive=True
                      )
-
-                     output_res = gr.Dropdown(
-                         label="Output Resolution",
-                         choices=[("Original", 0), ("512px", 512), ("1024px", 1024)],
-                         value=0,
-                         interactive=True
                      )

-                 # Advanced (Collapsible)
-                 with gr.Accordion("Advanced Options", open=False):
-                     frames = gr.Slider(label="Frames", minimum=24, maximum=120, step=1, value=60)
-                     fps_in = gr.Slider(label="FPS", minimum=8, maximum=60, step=1, value=30)
-                     render_toggle = gr.Checkbox(label="Render Video Preview", value=True)

-                 run_btn = gr.Button("🚀 Generate 3D Scene", variant="primary", size="lg")

-             # --- RIGHT COLUMN: Output ---
-             with gr.Column():
-                 video_out = gr.Video(
-                     label="3D Preview",
-                     autoplay=True,
-                     height=350
-                 )
-
-                 with gr.Group():
-                     ply_download = gr.DownloadButton(
-                         label="Download .PLY File",
-                         variant="secondary",
-                         visible=True
                      )
-                     status_md = gr.Markdown("Waiting for input...")
-
-         # --- Footer: Examples ---
-         gr.Markdown("### 📝 Examples")
-         example_files = get_example_files()
-
-         if example_files:
-             gr.Examples(
-                 examples=example_files,
-                 inputs=[image_in],
-                 # CRITICAL FIX: We do NOT set fn=run_sharp here.
-                 # This ensures clicking an example ONLY fills the image input.
-                 # The user must click "Generate" to run (prevents the 'None' arguments crash).
-                 label="Click an image to load it:"
              )
-
-         # --- Event Binding ---
-         run_btn.click(
-             fn=run_sharp,
-             inputs=[
-                 image_in,
-                 trajectory_preset,
-                 output_res,
-                 frames,
-                 fps_in,
-                 render_toggle
              ],
-             outputs=[video_out, ply_download, status_md],
-             concurrency_limit=1
          )

-     return demo

  # -----------------------------------------------------------------------------
- # Entry Point
  # -----------------------------------------------------------------------------

- _ensure_dir(OUTPUTS_DIR)

- if __name__ == "__main__":
-     demo = build_demo()
-     demo.queue().launch(
-         allowed_paths=[str(ASSETS_DIR)],
-         ssr_mode=False
-     )
+ """SHARP inference + optional CUDA video rendering utilities.
+
+ Design goals:
+ - Reuse SHARP's own predict/render pipeline (no subprocess calls).
+ - Be robust on Hugging Face Spaces + ZeroGPU.
+ - Cache model weights and predictor construction across requests.
+
+ Public API (used by the Gradio app):
+ - TrajectoryType
+ - predict_and_maybe_render_gpu(...)
  """

  from __future__ import annotations

+ import os
+ import threading
+ import time
+ import uuid
+ from contextlib import contextmanager
+ from dataclasses import dataclass
  from pathlib import Path
+ from typing import Final, Literal

+ import torch

  try:
+     import spaces
+ except Exception:  # pragma: no cover
+     spaces = None  # type: ignore[assignment]

+ try:
+     # Prefer HF cache / Hub downloads (works with Spaces `preload_from_hub`).
+     from huggingface_hub import hf_hub_download, try_to_load_from_cache
+ except Exception:  # pragma: no cover
+     hf_hub_download = None  # type: ignore[assignment]
+     try_to_load_from_cache = None  # type: ignore[assignment]

+ from sharp.cli.predict import DEFAULT_MODEL_URL, predict_image
+ from sharp.cli.render import render_gaussians as sharp_render_gaussians
+ from sharp.models import PredictorParams, create_predictor
+ from sharp.utils import camera, io
+ from sharp.utils.gaussians import Gaussians3D, SceneMetaData, save_ply
+ from sharp.utils.gsplat import GSplatRenderer

+ TrajectoryType = Literal["swipe", "shake", "rotate", "rotate_forward"]

  # -----------------------------------------------------------------------------
  # Helpers
  # -----------------------------------------------------------------------------

+
+ def _now_ms() -> int:
+     return int(time.time() * 1000)
+
+
  def _ensure_dir(path: Path) -> Path:
      path.mkdir(parents=True, exist_ok=True)
      return path

+
+ def _make_even(x: int) -> int:
+     return x if x % 2 == 0 else x + 1
+
+
+ def _select_device(preference: str = "auto") -> torch.device:
+     """Select the best available device for inference (CPU/CUDA/MPS)."""
+     if preference not in {"auto", "cpu", "cuda", "mps"}:
+         raise ValueError("device preference must be one of: auto|cpu|cuda|mps")
+
+     if preference == "cpu":
+         return torch.device("cpu")
+     if preference == "cuda":
+         return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     if preference == "mps":
+         return torch.device("mps" if torch.backends.mps.is_available() else "cpu")
+
+     # auto
+     if torch.cuda.is_available():
+         return torch.device("cuda")
+     if torch.backends.mps.is_available():
+         return torch.device("mps")
+     return torch.device("cpu")
+
+
+ # -----------------------------------------------------------------------------
+ # Prediction outputs
+ # -----------------------------------------------------------------------------
+
+
+ @dataclass(frozen=True, slots=True)
+ class PredictionOutputs:
+     """Outputs of SHARP inference (plus derived metadata for rendering)."""
+
+     ply_path: Path
+     gaussians: Gaussians3D
+     metadata_for_render: SceneMetaData
+     input_resolution_hw: tuple[int, int]
+     focal_length_px: float
+
+
+ # -----------------------------------------------------------------------------
+ # Patch SHARP VideoWriter to properly close the optional depth writer
+ # -----------------------------------------------------------------------------
+
+
+ class _PatchedVideoWriter(io.VideoWriter):
+     """Ensure depth writer is closed so files can be safely cleaned up."""
+
+     def __init__(
+         self, output_path: Path, fps: float = 30.0, render_depth: bool = True
+     ) -> None:
+         super().__init__(output_path, fps=fps, render_depth=render_depth)
+         # Ensure attribute exists for downstream code paths.
+         if not hasattr(self, "depth_writer"):
+             self.depth_writer = None  # type: ignore[attribute-defined-outside-init]
+
+     def close(self):
+         super().close()
+         depth_writer = getattr(self, "depth_writer", None)
          try:
+             if depth_writer is not None:
+                 depth_writer.close()
+         except Exception:
+             pass


+ @contextmanager
+ def _patched_sharp_videowriter():
+     """Temporarily patch `sharp.utils.io.VideoWriter` used by `sharp.cli.render`."""
+     original = io.VideoWriter
+     io.VideoWriter = _PatchedVideoWriter  # type: ignore[assignment]
+     try:
+         yield
+     finally:
+         io.VideoWriter = original  # type: ignore[assignment]
+

  # -----------------------------------------------------------------------------
+ # Model wrapper
  # -----------------------------------------------------------------------------


+ class ModelWrapper:
+     """Cached SHARP model wrapper for Gradio/Spaces."""
+
+     def __init__(
+         self,
+         *,
+         outputs_dir: str | Path = "outputs",
+         checkpoint_url: str = DEFAULT_MODEL_URL,
+         checkpoint_path: str | Path | None = None,
+         device_preference: str = "auto",
+         keep_model_on_device: bool | None = None,
+         hf_repo_id: str | None = None,
+         hf_filename: str | None = None,
+         hf_revision: str | None = None,
+     ) -> None:
+         self.outputs_dir = _ensure_dir(Path(outputs_dir))
+         self.checkpoint_url = checkpoint_url
+
+         env_ckpt = os.getenv("SHARP_CHECKPOINT_PATH") or os.getenv("SHARP_CHECKPOINT")
+         if checkpoint_path:
+             self.checkpoint_path = Path(checkpoint_path)
+         elif env_ckpt:
+             self.checkpoint_path = Path(env_ckpt)
+         else:
+             self.checkpoint_path = None
+
+         # Optional Hugging Face Hub fallback (useful when direct CDN download fails).
+         self.hf_repo_id = hf_repo_id or os.getenv("SHARP_HF_REPO_ID", "apple/Sharp")
+         self.hf_filename = hf_filename or os.getenv(
+             "SHARP_HF_FILENAME", "sharp_2572gikvuh.pt"
+         )
+         self.hf_revision = hf_revision or os.getenv("SHARP_HF_REVISION") or None
+
+         self.device_preference = device_preference
+
+         # For ZeroGPU, it's safer to not keep large tensors on CUDA across calls.
+         if keep_model_on_device is None:
+             keep_env = os.getenv("SHARP_KEEP_MODEL_ON_DEVICE")
+             self.keep_model_on_device = keep_env == "1"
+         else:
+             self.keep_model_on_device = keep_model_on_device
+
+         self._lock = threading.RLock()
+         self._predictor: torch.nn.Module | None = None
+         self._predictor_device: torch.device | None = None
+         self._state_dict: dict | None = None
+
+     def has_cuda(self) -> bool:
+         return torch.cuda.is_available()
+
+     def _load_state_dict(self) -> dict:
+         with self._lock:
+             if self._state_dict is not None:
+                 return self._state_dict
+
+             # 1) Explicit local checkpoint path
+             if self.checkpoint_path is not None:
+                 try:
+                     self._state_dict = torch.load(
+                         self.checkpoint_path,
+                         weights_only=True,
+                         map_location="cpu",
                      )
+                     return self._state_dict
+                 except Exception as e:
+                     raise RuntimeError(
+                         "Failed to load SHARP checkpoint from local path.\n\n"
+                         f"Path:\n {self.checkpoint_path}\n\n"
+                         f"Original error:\n {type(e).__name__}: {e}"
+                     ) from e
+
+             # 2) HF cache (no-network): best match for Spaces `preload_from_hub`.
+             hf_cache_error: Exception | None = None
+             if try_to_load_from_cache is not None:
+                 try:
+                     cached = try_to_load_from_cache(
+                         repo_id=self.hf_repo_id,
+                         filename=self.hf_filename,
+                         revision=self.hf_revision,
+                         repo_type="model",
                      )
+                 except TypeError:
+                     cached = try_to_load_from_cache(self.hf_repo_id, self.hf_filename)  # type: ignore[misc]

+                 try:
+                     if isinstance(cached, str) and Path(cached).exists():
+                         self._state_dict = torch.load(
+                             cached, weights_only=True, map_location="cpu"
+                         )
+                         return self._state_dict
+                 except Exception as e:
+                     hf_cache_error = e

+             # 3) HF Hub download (reuse cache when available; may download otherwise).
+             hf_error: Exception | None = None
+             if hf_hub_download is not None:
+                 # Attempt "local only" mode if supported (avoids network).
+                 try:
+                     import inspect

+                     if "local_files_only" in inspect.signature(hf_hub_download).parameters:
+                         ckpt_path = hf_hub_download(
+                             repo_id=self.hf_repo_id,
+                             filename=self.hf_filename,
+                             revision=self.hf_revision,
+                             local_files_only=True,
+                         )
+                         if Path(ckpt_path).exists():
+                             self._state_dict = torch.load(
+                                 ckpt_path, weights_only=True, map_location="cpu"
+                             )
+                             return self._state_dict
+                 except Exception:
+                     pass
+
+                 try:
+                     ckpt_path = hf_hub_download(
+                         repo_id=self.hf_repo_id,
+                         filename=self.hf_filename,
+                         revision=self.hf_revision,
                      )
+                     self._state_dict = torch.load(
+                         ckpt_path,
+                         weights_only=True,
+                         map_location="cpu",
+                     )
+                     return self._state_dict
+                 except Exception as e:
+                     hf_error = e
+
+             # 4) Default upstream CDN (torch hub cache). Last resort.
+             url_error: Exception | None = None
+             try:
+                 self._state_dict = torch.hub.load_state_dict_from_url(
+                     self.checkpoint_url,
+                     progress=True,
+                     map_location="cpu",
+                 )
+                 return self._state_dict
+             except Exception as e:
+                 url_error = e
+
+             # If we got here: all options failed.
+             hint_lines = [
+                 "Failed to load SHARP checkpoint.",
+                 "",
+                 "Tried (in order):",
+                 f" 1) HF cache (preload_from_hub): repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
+                 f" 2) HF Hub download: repo_id={self.hf_repo_id}, filename={self.hf_filename}, revision={self.hf_revision or 'None'}",
+                 f" 3) URL (torch hub): {self.checkpoint_url}",
+                 "",
+                 "If network access is restricted, set a local checkpoint path:",
+                 " - SHARP_CHECKPOINT_PATH=/path/to/sharp_2572gikvuh.pt",
+                 "",
+                 "Original errors:",
+             ]
+             if try_to_load_from_cache is None:
+                 hint_lines.append(" HF cache: huggingface_hub not installed")
+             elif hf_cache_error is not None:
+                 hint_lines.append(
+                     f" HF cache: {type(hf_cache_error).__name__}: {hf_cache_error}"
+                 )
+             else:
+                 hint_lines.append(" HF cache: (not found in cache)")
+
+             if hf_hub_download is None:
+                 hint_lines.append(" HF download: huggingface_hub not installed")
+             else:
+                 hint_lines.append(f" HF download: {type(hf_error).__name__}: {hf_error}")
+
+             hint_lines.append(f" URL: {type(url_error).__name__}: {url_error}")
+
+             raise RuntimeError("\n".join(hint_lines))
+
+     def _get_predictor(self, device: torch.device) -> torch.nn.Module:
+         with self._lock:
+             if self._predictor is None:
+                 state_dict = self._load_state_dict()
+                 predictor = create_predictor(PredictorParams())
+                 predictor.load_state_dict(state_dict)
+                 predictor.eval()
+                 self._predictor = predictor
+                 self._predictor_device = torch.device("cpu")
+
+             assert self._predictor is not None
+             assert self._predictor_device is not None
+
+             if self._predictor_device != device:
+                 self._predictor.to(device)
+                 self._predictor_device = device
+
+             return self._predictor
+
+     def _maybe_move_model_back_to_cpu(self) -> None:
+         if self.keep_model_on_device:
+             return
+         with self._lock:
+             if self._predictor is not None and self._predictor_device is not None:
+                 if self._predictor_device.type != "cpu":
+                     self._predictor.to("cpu")
+                     self._predictor_device = torch.device("cpu")
+                     if torch.cuda.is_available():
+                         torch.cuda.empty_cache()
+
+     def _make_output_stem(self, input_path: Path) -> str:
+         return f"{input_path.stem}-{_now_ms()}-{uuid.uuid4().hex[:8]}"
+
+     def predict_to_ply(self, image_path: str | Path) -> PredictionOutputs:
+         """Run SHARP inference and export a .ply file."""
+         image_path = Path(image_path)
+         if not image_path.exists():
+             raise FileNotFoundError(f"Image does not exist: {image_path}")
+
+         device = _select_device(self.device_preference)
+         predictor = self._get_predictor(device)
+
+         image_np, _, f_px = io.load_rgb(image_path)
+         height, width = image_np.shape[:2]
+
+         with torch.no_grad():
+             gaussians = predict_image(predictor, image_np, f_px, device)
+
+         stem = self._make_output_stem(image_path)
+         ply_path = self.outputs_dir / f"{stem}.ply"
+
+         # save_ply expects (height, width).
+         save_ply(gaussians, f_px, (height, width), ply_path)
+
+         # SceneMetaData expects (width, height) for resolution.
+         metadata_for_render = SceneMetaData(
+             focal_length_px=float(f_px),
+             resolution_px=(int(width), int(height)),
+             color_space="linearRGB",
+         )
+
+         self._maybe_move_model_back_to_cpu()
+
+         return PredictionOutputs(
+             ply_path=ply_path,
+             gaussians=gaussians,
+             metadata_for_render=metadata_for_render,
+             input_resolution_hw=(int(height), int(width)),
+             focal_length_px=float(f_px),
+         )
+
+     def _render_video_impl(
+         self,
+         *,
+         gaussians: Gaussians3D,
+         metadata: SceneMetaData,
+         output_path: Path,
+         trajectory_type: TrajectoryType,
+         num_frames: int,
+         fps: int,
+         output_long_side: int | None,
+     ) -> Path:
+         if not torch.cuda.is_available():
+             raise RuntimeError("Rendering requires CUDA (gsplat).")
+
+         if num_frames < 2:
+             raise ValueError("num_frames must be >= 2")
+         if fps < 1:
+             raise ValueError("fps must be >= 1")
+
+         # Keep aligned with upstream CLI pipeline where possible.
+         if output_long_side is None and int(fps) == 30:
+             params = camera.TrajectoryParams(
+                 type=trajectory_type,
+                 num_steps=int(num_frames),
+                 num_repeats=1,
              )
+             with _patched_sharp_videowriter():
+                 sharp_render_gaussians(
+                     gaussians=gaussians,
+                     metadata=metadata,
+                     params=params,
+                     output_path=output_path,
+                 )
+             depth_path = output_path.with_suffix(".depth.mp4")
+             try:
+                 if depth_path.exists():
+                     depth_path.unlink()
+             except Exception:
+                 pass
+             return output_path
+
+         # Adapted pipeline for custom output resolution / FPS.
+         src_w, src_h = metadata.resolution_px
+         src_f = float(metadata.focal_length_px)
+
+         if output_long_side is None:
+             out_w, out_h, out_f = src_w, src_h, src_f
+         else:
+             long_side = max(src_w, src_h)
+             scale = float(output_long_side) / float(long_side)
+             out_w = _make_even(max(2, int(round(src_w * scale))))
+             out_h = _make_even(max(2, int(round(src_h * scale))))
+             out_f = src_f * scale
+
+         traj_params = camera.TrajectoryParams(
+             type=trajectory_type,
+             num_steps=int(num_frames),
+             num_repeats=1,
+         )
+
+         device = torch.device("cuda")
+         gaussians_cuda = gaussians.to(device)
+
+         intrinsics = torch.tensor(
+             [
+                 [out_f, 0.0, (out_w - 1) / 2.0, 0.0],
+                 [0.0, out_f, (out_h - 1) / 2.0, 0.0],
+                 [0.0, 0.0, 1.0, 0.0],
+                 [0.0, 0.0, 0.0, 1.0],
              ],
+             device=device,
+             dtype=torch.float32,
+         )
+
+         cam_model = camera.create_camera_model(
+             gaussians_cuda,
+             intrinsics,
+             resolution_px=(out_w, out_h),
+             lookat_mode=traj_params.lookat_mode,
          )

+         trajectory = camera.create_eye_trajectory(
+             gaussians_cuda,
+             traj_params,
+             resolution_px=(out_w, out_h),
+             f_px=out_f,
+         )
+
+         renderer = GSplatRenderer(color_space=metadata.color_space)
+
+         # IMPORTANT: Keep render_depth=True (avoids upstream AttributeError).
+         video_writer = _PatchedVideoWriter(output_path, fps=float(fps), render_depth=True)
+
+         for eye_position in trajectory:
+             cam_info = cam_model.compute(eye_position)
+             rendering = renderer(
+                 gaussians_cuda,
+                 extrinsics=cam_info.extrinsics[None].to(device),
+                 intrinsics=cam_info.intrinsics[None].to(device),
+                 image_width=cam_info.width,
+                 image_height=cam_info.height,
+             )
+             color = (rendering.color[0].permute(1, 2, 0) * 255.0).to(dtype=torch.uint8)
+             depth = rendering.depth[0]
+             video_writer.add_frame(color, depth)
+
+         video_writer.close()
+
+         depth_path = output_path.with_suffix(".depth.mp4")
+         try:
+             if depth_path.exists():
+                 depth_path.unlink()
+         except Exception:
+             pass
+
+         return output_path
+
+     def render_video(
+         self,
+         *,
+         gaussians: Gaussians3D,
+         metadata: SceneMetaData,
+         output_stem: str,
+         trajectory_type: TrajectoryType = "rotate_forward",
+         num_frames: int = 60,
+         fps: int = 30,
+         output_long_side: int | None = None,
+     ) -> Path:
+         """Render a camera trajectory as an MP4 (CUDA-only)."""
+         output_path = self.outputs_dir / f"{output_stem}.mp4"
+         return self._render_video_impl(
+             gaussians=gaussians,
+             metadata=metadata,
+             output_path=output_path,
+             trajectory_type=trajectory_type,
+             num_frames=num_frames,
+             fps=fps,
+             output_long_side=output_long_side,
+         )
+
+     def predict_and_maybe_render(
+         self,
+         image_path: str | Path,
+         *,
+         trajectory_type: TrajectoryType,
+         num_frames: int,
+         fps: int,
+         output_long_side: int | None,
+         render_video: bool = True,
+     ) -> tuple[Path | None, Path]:
+         """One-shot helper for the UI: returns (video_path, ply_path)."""
+         pred = self.predict_to_ply(image_path)
+
+         if not render_video:
+             return None, pred.ply_path
+
+         if not torch.cuda.is_available():
+             return None, pred.ply_path
+
+         output_stem = pred.ply_path.with_suffix("").name
+         video_path = self.render_video(
+             gaussians=pred.gaussians,
+             metadata=pred.metadata_for_render,
+             output_stem=output_stem,
+             trajectory_type=trajectory_type,
+             num_frames=num_frames,
+             fps=fps,
+             output_long_side=output_long_side,
+         )
+         return video_path, pred.ply_path
+

  # -----------------------------------------------------------------------------
+ # ZeroGPU entrypoints
  # -----------------------------------------------------------------------------
+ #
+ # IMPORTANT: Do NOT decorate bound instance methods with `@spaces.GPU` on ZeroGPU.
+ # The wrapper uses multiprocessing queues and pickles args/kwargs. If `self` is
+ # included, Python will try to pickle the whole instance. ModelWrapper contains
+ # a threading.RLock (not pickleable) and the model itself should not be pickled.
+ #
+ # Expose module-level functions that accept only pickleable arguments and
+ # create/cache the ModelWrapper inside the GPU worker process.

+ DEFAULT_OUTPUTS_DIR: Final[Path] = _ensure_dir(Path(__file__).resolve().parent / "outputs")
+
+ _GLOBAL_MODEL: ModelWrapper | None = None
+ _GLOBAL_MODEL_INIT_LOCK: Final[threading.Lock] = threading.Lock()
+
+
+ def get_global_model(*, outputs_dir: str | Path = DEFAULT_OUTPUTS_DIR) -> ModelWrapper:
+     global _GLOBAL_MODEL
+     with _GLOBAL_MODEL_INIT_LOCK:
+         if _GLOBAL_MODEL is None:
+             _GLOBAL_MODEL = ModelWrapper(outputs_dir=outputs_dir)
+         return _GLOBAL_MODEL
+
+
+ def predict_and_maybe_render(
+     image_path: str | Path,
+     *,
+     trajectory_type: TrajectoryType,
+     num_frames: int,
+     fps: int,
+     output_long_side: int | None,
+     render_video: bool = True,
+ ) -> tuple[Path | None, Path]:
+     model = get_global_model()
+     return model.predict_and_maybe_render(
+         image_path,
+         trajectory_type=trajectory_type,
+         num_frames=num_frames,
+         fps=fps,
+         output_long_side=output_long_side,
+         render_video=render_video,
+     )


+ # Export the GPU-wrapped callable (or a no-op wrapper locally).
+ if spaces is not None:
+     predict_and_maybe_render_gpu = spaces.GPU(duration=180)(predict_and_maybe_render)
+ else:  # pragma: no cover
+     predict_and_maybe_render_gpu = predict_and_maybe_render
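For reference, a minimal usage sketch of the exported entrypoint (not part of the commit; the image path is illustrative, while the parameter names, defaults, and return values are taken from the diff above):

    from model_utils import predict_and_maybe_render_gpu

    # Runs inference, writes outputs/<stem>-<timestamp>-<id>.ply, and, when CUDA
    # is available and render_video=True, renders an MP4 preview next to it.
    video_path, ply_path = predict_and_maybe_render_gpu(
        "assets/examples/kitchen.png",     # hypothetical input image
        trajectory_type="rotate_forward",  # "swipe" | "shake" | "rotate" | "rotate_forward"
        num_frames=60,
        fps=30,
        output_long_side=None,             # None keeps the source resolution
        render_video=True,                 # falls back to PLY-only without CUDA
    )
    print(ply_path, video_path)            # video_path is None when no video was rendered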