""" SHARP Gradio Demo - Standard Native Layout - Fixed: Added @spaces.GPU for ZeroGPU compatibility (Fixes 'dummy' output) - Fixed: Download Button visibility logic """ from __future__ import annotations import warnings import json from pathlib import Path from typing import Final import gradio as gr # --- 1. Import Spaces for ZeroGPU Support --- try: import spaces except ImportError: # Fallback for local testing if spaces is not installed class spaces: @staticmethod def GPU(func): return func # Suppress internal warnings warnings.filterwarnings("ignore", category=FutureWarning, module="torch.distributed") # Ensure model_utils is present in your directory from model_utils import TrajectoryType, predict_and_maybe_render_gpu # ----------------------------------------------------------------------------- # Paths & Config # ----------------------------------------------------------------------------- APP_DIR: Final[Path] = Path(__file__).resolve().parent OUTPUTS_DIR: Final[Path] = APP_DIR / "outputs" ASSETS_DIR: Final[Path] = APP_DIR / "assets" EXAMPLES_DIR: Final[Path] = ASSETS_DIR / "examples" IMAGE_EXTS: Final[tuple[str, ...]] = (".png", ".jpg", ".jpeg", ".webp") # ----------------------------------------------------------------------------- # SEO # ----------------------------------------------------------------------------- SEO_HEAD = """ """ # ----------------------------------------------------------------------------- # Helpers # ----------------------------------------------------------------------------- def _ensure_dir(path: Path) -> Path: path.mkdir(parents=True, exist_ok=True) return path def get_example_files() -> list[list[str]]: """Discover images in assets/examples for the UI.""" _ensure_dir(EXAMPLES_DIR) # Check manifest.json first manifest_path = EXAMPLES_DIR / "manifest.json" if manifest_path.exists(): try: data = json.loads(manifest_path.read_text(encoding="utf-8")) examples = [] for entry in data: if "image" in entry: img_path = EXAMPLES_DIR / 
entry["image"] if img_path.exists(): examples.append([str(img_path)]) if examples: return examples except Exception as e: print(f"Manifest error: {e}") # Fallback: simple file scan examples = [] for ext in IMAGE_EXTS: for img in sorted(EXAMPLES_DIR.glob(f"*{ext}")): examples.append([str(img)]) return examples # --- 2. Apply @spaces.GPU Decorator --- @spaces.GPU(duration=120) def run_sharp( image_path: str | None, trajectory_type: str, output_long_side: int, num_frames: int, fps: int, render_video: bool, progress=gr.Progress() ) -> tuple[str | None, dict, str]: """ Main Inference Function Decorated with @spaces.GPU to ensure it runs on the GPU node. """ if not image_path: raise gr.Error("Please upload an image first.") # Validate inputs out_long_side_val = None if int(output_long_side) <= 0 else int(output_long_side) # Convert trajectory string to Enum safely traj_key = trajectory_type.upper() if hasattr(TrajectoryType, traj_key): traj_enum = TrajectoryType[traj_key] else: traj_enum = trajectory_type try: progress(0.1, desc="Initializing SHARP model on GPU...") # Call the backend model video_path, ply_path = predict_and_maybe_render_gpu( image_path, trajectory_type=traj_enum, num_frames=int(num_frames), fps=int(fps), output_long_side=out_long_side_val, render_video=bool(render_video), ) # Prepare outputs status_msg = f"### ✅ Success\nGenerated: `{ply_path.name}`" video_result = str(video_path) if video_path else None if video_path: status_msg += f"\nVideo: `{video_path.name}`" # Explicitly update the Download Button download_btn_update = gr.DownloadButton( value=str(ply_path), visible=True, label=f"Download {ply_path.name}" ) return ( video_result, download_btn_update, status_msg ) except Exception as e: # If it fails, we return None for video, hide button, and show error return ( None, gr.DownloadButton(visible=False), f"### ❌ Error\n{str(e)}" ) # ----------------------------------------------------------------------------- # UI Construction # 
# -----------------------------------------------------------------------------
# UI Construction
# -----------------------------------------------------------------------------
def build_demo() -> gr.Blocks:
    """Assemble the SHARP Gradio interface.

    Layout: a title row, then two columns. The left column holds the input
    image, generation settings, the generate button and (when any are found)
    clickable examples. The right column holds the preview video, a status
    panel and the .PLY download button, which starts hidden. Clicking the
    generate button calls ``run_sharp`` with at most one concurrent run.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    camera_moves = ["swipe", "shake", "rotate", "rotate_forward"]
    resolution_options = [("Original", 0), ("512px", 512), ("1024px", 1024)]

    with gr.Blocks(
        theme=gr.themes.Default(),
        head=SEO_HEAD,
        title="SHARP 3D Generator",
    ) as app:
        # Title row
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(
                    "# SHARP: Single-Image 3D Generator\n"
                    "Convert any static image into a 3D Gaussian Splat scene instantly."
                )

        # Main area: inputs on the left, results on the right.
        with gr.Row(equal_height=False):
            # Left: image input, generation settings, run button, examples.
            with gr.Column(scale=1):
                source_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                    sources=["upload", "clipboard"],
                    interactive=True,
                )

                # Generation settings
                with gr.Group():
                    with gr.Row():
                        camera_choice = gr.Dropdown(
                            label="Camera Movement",
                            choices=camera_moves,
                            value="rotate_forward",
                            scale=2,
                        )
                        resolution_choice = gr.Dropdown(
                            label="Output Resolution",
                            choices=resolution_options,
                            value=0,
                            scale=1,
                        )
                    with gr.Row():
                        frame_count = gr.Slider(
                            label="Frames", minimum=24, maximum=120, step=1, value=60
                        )
                        frame_rate = gr.Slider(
                            label="FPS", minimum=8, maximum=60, step=1, value=30
                        )
                    want_video = gr.Checkbox(label="Render Video Preview", value=True)

                generate_button = gr.Button(
                    "🚀 Generate 3D Scene", variant="primary", size="lg"
                )

                # Example gallery, only rendered when example images exist.
                example_rows = get_example_files()
                if example_rows:
                    gr.Examples(
                        examples=example_rows,
                        inputs=[source_image],
                        label="Examples",
                        run_on_click=False,
                        cache_examples=False,
                    )

            # Right: preview video, status text, download button.
            with gr.Column(scale=1):
                preview = gr.Video(
                    label="3D Preview",
                    autoplay=True,
                    elem_id="output-video",
                    interactive=False,
                )
                with gr.Group():
                    status_panel = gr.Markdown("Ready to generate.")
                    # Hidden until run_sharp returns a visible update.
                    ply_button = gr.DownloadButton(
                        label="Download .PLY File",
                        variant="secondary",
                        visible=False,
                    )

        # Wire the generate button to inference.
        generate_button.click(
            fn=run_sharp,
            inputs=[
                source_image,
                camera_choice,
                resolution_choice,
                frame_count,
                frame_rate,
                want_video,
            ],
            outputs=[preview, ply_button, status_panel],
            concurrency_limit=1,
        )

    return app
# -----------------------------------------------------------------------------
# Entry Point
# -----------------------------------------------------------------------------
# Runs at import time, so the outputs directory exists whether this module is
# executed directly or imported.
_ensure_dir(OUTPUTS_DIR)

if __name__ == "__main__":
    demo = build_demo()
    # Enable the request queue, then serve the app. Files under ASSETS_DIR
    # (which contains the example images) are explicitly allowed to be
    # served; server-side rendering is disabled.
    demo.queue()
    demo.launch(ssr_mode=False, allowed_paths=[str(ASSETS_DIR)])