Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| import tempfile | |
| import time | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Optional | |
| from uuid import uuid4 | |
| from fastapi import BackgroundTasks, FastAPI, File, Form, HTTPException, UploadFile | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse, HTMLResponse, JSONResponse | |
| import uvicorn | |
| from inference import run_inference | |
| from mission_planner import MissionPlan, get_mission_plan | |
logging.basicConfig(level="INFO")

app = FastAPI(title="Video Processing Backend")

# Permissive CORS: any origin, method and header is accepted, and the custom
# ``x-mission-summary`` response header is exposed to browser scripts.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive; confirm this is intended before adding cookies or auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
    expose_headers=["x-mission-summary"],
)
def _save_upload_to_tmp(upload: UploadFile) -> str:
    """Persist an uploaded file to a new temporary file under ``/tmp``.

    The client-supplied filename is used only for its extension (``.mp4`` when
    absent), so the saved file keeps a container-appropriate suffix.

    Args:
        upload: Incoming multipart upload; its ``file`` stream is copied from
            the current position to EOF.

    Returns:
        Path of the newly created temporary file. The caller owns it and is
        responsible for deleting it.

    Raises:
        Exception: whatever creating the file or copying the stream raises;
            the partially written temporary file is removed before re-raising.
    """
    import shutil  # local import: only needed for the streaming copy below

    suffix = Path(upload.filename or "upload.mp4").suffix or ".mp4"
    fd, path = tempfile.mkstemp(prefix="input_", suffix=suffix, dir="/tmp")
    try:
        # Reuse mkstemp's descriptor instead of closing it and reopening by name.
        with os.fdopen(fd, "wb") as buffer:
            # Copy in chunks so a large video is never held fully in memory.
            shutil.copyfileobj(upload.file, buffer)
    except Exception:
        # Don't leave a truncated file behind when the copy fails.
        _safe_delete(path)
        raise
    return path
| def _safe_delete(path: str) -> None: | |
| try: | |
| os.remove(path) | |
| except FileNotFoundError: | |
| return | |
| except Exception: | |
| logging.exception("Failed to remove temporary file: %s", path) | |
def _schedule_cleanup(background_tasks: BackgroundTasks, path: str) -> None:
    """Queue a best-effort deletion of *path* on *background_tasks*.

    The path is captured at scheduling time, so rebinding the caller's
    variable afterwards does not change which file gets removed.
    """
    background_tasks.add_task(_safe_delete, path)
@dataclass
class CachedMission:
    """A mission plan stored server-side so later requests can refer to it by id.

    ``_store_mission_plan`` constructs instances with keyword arguments, which
    requires the ``__init__`` generated by ``@dataclass``. Without the
    decorator, that call raises ``TypeError``.
    """

    prompt: str  # prompt the plan was generated from
    detector: Optional[str]  # detector name supplied alongside the prompt, if any
    latitude: float
    longitude: float
    plan: MissionPlan
    created_at: float  # time.time() at insertion; compared against the cache TTL
# Lifetime of a stored mission plan, in seconds (one hour).
MISSION_CACHE_TTL_SECONDS = 60.0 * 60.0
# In-process store of mission plans keyed by hex mission id. Expired entries
# are evicted lazily by _prune_mission_cache rather than on a timer.
MISSION_CACHE: dict[str, CachedMission] = {}
def _prune_mission_cache() -> None:
    """Evict entries older than ``MISSION_CACHE_TTL_SECONDS`` from ``MISSION_CACHE`` in place."""
    current = time.time()
    stale_ids = []
    for cached_id, cached_entry in MISSION_CACHE.items():
        age = current - cached_entry.created_at
        if age > MISSION_CACHE_TTL_SECONDS:
            stale_ids.append(cached_id)
    # Collect first, delete afterwards: a dict must not change size while iterated.
    for cached_id in stale_ids:
        MISSION_CACHE.pop(cached_id, None)
def _store_mission_plan(
    prompt: str,
    detector: Optional[str],
    plan: MissionPlan,
    latitude: float,
    longitude: float,
) -> str:
    """Cache *plan* with its request context under a fresh id and return the id.

    Expired entries are pruned before the new one is inserted.
    """
    _prune_mission_cache()
    new_id = uuid4().hex
    entry = CachedMission(
        prompt=prompt,
        detector=detector,
        latitude=latitude,
        longitude=longitude,
        plan=plan,
        created_at=time.time(),
    )
    MISSION_CACHE[new_id] = entry
    return new_id
def _get_cached_mission(mission_id: str) -> CachedMission:
    """Return the live cache entry for *mission_id*.

    Expired entries are pruned before the lookup, so an id older than the TTL
    is reported as missing.

    Raises:
        HTTPException: 404 when the id is unknown or has expired.
    """
    _prune_mission_cache()
    if (found := MISSION_CACHE.get(mission_id)) is not None:
        return found
    raise HTTPException(status_code=404, detail="Mission prompt not found. Please set it again.")
| def _require_coordinates( | |
| latitude: Optional[float], | |
| longitude: Optional[float], | |
| ) -> tuple[float, float]: | |
| if latitude is None or longitude is None: | |
| raise HTTPException(status_code=400, detail="Mission location coordinates are required.") | |
| return float(latitude), float(longitude) | |
def _resolve_mission_plan(
    prompt: Optional[str],
    mission_id: Optional[str],
    latitude: Optional[float],
    longitude: Optional[float],
) -> tuple[MissionPlan, str]:
    """Return the ``(plan, prompt)`` pair a video request should run with.

    A non-empty ``mission_id`` takes precedence. The cached plan and its stored
    prompt are returned, and ``prompt`` and the coordinates are ignored. Only
    plan and prompt are returned; the cached ``detector`` and coordinates are not.
    Without a ``mission_id``, the prompt is stripped, both coordinates are
    required, and a fresh plan is generated.

    Raises:
        HTTPException: 404 for an unknown or expired ``mission_id``; 400 for a
            blank prompt or missing coordinates; 500 if planning fails.
    """
    if mission_id:
        cached = _get_cached_mission(mission_id)
        return cached.plan, cached.prompt
    normalized_prompt = (prompt or "").strip()
    if not normalized_prompt:
        raise HTTPException(status_code=400, detail="Mission prompt is required.")
    # Validation only: get_mission_plan is called with the prompt alone.
    _require_coordinates(latitude, longitude)
    try:
        plan = get_mission_plan(normalized_prompt)
    except Exception as exc:
        # Match the planning endpoints: log the traceback and report a 500 with
        # the error text, instead of letting an unhandled exception escape.
        logging.exception("Mission planning failed.")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return plan, normalized_prompt
def _validate_inputs(
    video: UploadFile | None,
    prompt: str | None,
    mission_id: Optional[str],
    latitude: Optional[float],
    longitude: Optional[float],
) -> None:
    """Cheaply reject a malformed video request before the upload is saved.

    A request must carry a video. It must also carry either a ``mission_id``
    (looked up later by ``_resolve_mission_plan``) or a non-blank ``prompt``
    together with both coordinates.

    Raises:
        HTTPException: 400 if the video is missing; if no ``mission_id`` is given
            and the prompt is missing or blank; or if a coordinate is missing.
    """
    if video is None:
        raise HTTPException(status_code=400, detail="Video file is required.")
    if mission_id:
        return
    # Strip like _resolve_mission_plan does. A whitespace-only prompt must fail
    # here, not after the caller has already written temporary files.
    if not (prompt or "").strip():
        raise HTTPException(status_code=400, detail="Mission prompt is required.")
    _require_coordinates(latitude, longitude)
| def _location_only_prompt(latitude: float, longitude: float) -> str: | |
| return ( | |
| "Threat reconnaissance mission. " | |
| f"Identify and prioritize potential hostile or suspicious object classes around latitude {latitude:.4f}, " | |
| f"longitude {longitude:.4f}. Consider common threats for this environment when selecting classes." | |
| ) | |
# NOTE(review): no route decorator is visible on this handler in this copy; confirm it is registered on `app`.
async def location_context(
    latitude: float = Form(...),
    longitude: float = Form(...),
):
    """Generate a mission plan from coordinates alone, without a user prompt.

    Returns:
        A dict with ``mission_plan`` (``plan.to_dict()``), the echoed
        ``location`` and the synthesized ``prompt_used``.

    Raises:
        HTTPException: 500 with the error text as ``detail`` if planning fails.
    """
    synthesized = _location_only_prompt(latitude, longitude)
    try:
        location_plan = get_mission_plan(synthesized)
    except Exception as exc:
        logging.exception("Location-only planning failed.")
        raise HTTPException(status_code=500, detail=str(exc))
    payload = {"mission_plan": location_plan.to_dict()}
    payload["location"] = {"latitude": latitude, "longitude": longitude}
    payload["prompt_used"] = synthesized
    return payload
# NOTE(review): no route decorator is visible on this handler in this copy; confirm it is registered on `app`.
async def mission_plan_endpoint(
    prompt: str = Form(...),
    detector: Optional[str] = Form(None),
    latitude: float = Form(...),
    longitude: float = Form(...),
):
    """Plan a mission from a prompt and cache it for later video requests.

    Returns:
        A dict with the new ``mission_id``, the ``mission_plan``
        (``plan.to_dict()``) and the echoed ``location``.

    Raises:
        HTTPException: 400 for a blank prompt; 500 with the error text as
            ``detail`` if planning fails.
    """
    cleaned = (prompt or "").strip()
    if not cleaned:
        raise HTTPException(status_code=400, detail="Prompt is required.")
    try:
        generated = get_mission_plan(cleaned)
    except Exception as exc:
        logging.exception("Mission planning failed.")
        raise HTTPException(status_code=500, detail=str(exc))
    new_id = _store_mission_plan(cleaned, detector, generated, latitude, longitude)
    payload = {"mission_id": new_id, "mission_plan": generated.to_dict()}
    payload["location"] = {"latitude": latitude, "longitude": longitude}
    return payload
# NOTE(review): no route decorator is visible on this handler in this copy; confirm it is registered on `app`.
async def process_video(
    background_tasks: BackgroundTasks,
    video: UploadFile = File(...),
    prompt: Optional[str] = Form(None),
    mission_id: Optional[str] = Form(None),
    detector: Optional[str] = Form(None),
    latitude: Optional[float] = Form(None),
    longitude: Optional[float] = Form(None),
):
    """Run detection on an uploaded video and return the annotated MP4.

    The mission comes from ``mission_id`` (a cached plan) or from ``prompt``
    plus ``latitude``/``longitude``. On success, the input and output temporary
    files are scheduled for deletion via ``background_tasks``.

    Returns:
        A ``FileResponse`` with the processed video named ``processed.mp4``.
        On a non-``ValueError`` inference failure, a ``JSONResponse`` with
        ``{"error": ...}`` and status 500 instead.

    Raises:
        HTTPException: from input validation and plan resolution; 500 when the
            upload cannot be saved or ``run_inference`` raises ``ValueError``.
    """
    # NOTE(review): run_inference is a synchronous call inside this async
    # handler, so it occupies the event loop until it returns. Consider
    # fastapi.concurrency.run_in_threadpool once it is confirmed safe to run
    # concurrently.
    _validate_inputs(video, prompt, mission_id, latitude, longitude)
    # Resolve the plan before writing anything to disk. Previously a 400/404 or
    # planning error raised here left both temporary files behind.
    mission_plan, mission_prompt = _resolve_mission_plan(prompt, mission_id, latitude, longitude)

    try:
        input_path = _save_upload_to_tmp(video)
    except Exception as exc:
        logging.exception("Failed to save uploaded file.")
        raise HTTPException(status_code=500, detail="Failed to save uploaded video.") from exc
    finally:
        await video.close()

    try:
        fd, requested_output_path = tempfile.mkstemp(prefix="output_", suffix=".mp4", dir="/tmp")
    except OSError:
        _safe_delete(input_path)
        raise
    os.close(fd)

    try:
        output_path, _, _ = run_inference(
            input_path,
            requested_output_path,
            mission_prompt,
            detector_name=detector,
            generate_summary=False,
            mission_plan=mission_plan,
        )
    except ValueError as exc:
        logging.exception("Video decoding failed.")
        _safe_delete(input_path)
        _safe_delete(requested_output_path)
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    except Exception as exc:
        logging.exception("Inference failed.")
        _safe_delete(input_path)
        _safe_delete(requested_output_path)
        return JSONResponse(status_code=500, content={"error": str(exc)})

    _schedule_cleanup(background_tasks, input_path)
    _schedule_cleanup(background_tasks, output_path)
    if output_path != requested_output_path:
        # run_inference reported a different path. Without this, the placeholder
        # created by mkstemp above would never be removed.
        _schedule_cleanup(background_tasks, requested_output_path)
    return FileResponse(
        path=output_path,
        media_type="video/mp4",
        filename="processed.mp4",
    )
# NOTE(review): no route decorator is visible on this handler in this copy; confirm it is registered on `app`.
async def mission_summary(
    video: UploadFile = File(...),
    prompt: Optional[str] = Form(None),
    mission_id: Optional[str] = Form(None),
    detector: Optional[str] = Form(None),
    latitude: Optional[float] = Form(None),
    longitude: Optional[float] = Form(None),
):
    """Run detection on an uploaded video and return a text mission summary.

    No output video is written (``write_output_video=False``). The upload is
    saved to a temporary file for inference and always removed afterwards.

    Returns:
        ``{"mission_plan": plan.to_dict(), "mission_summary": str}`` on success,
        with an empty string when no summary is produced. On a non-``ValueError``
        inference failure, a ``JSONResponse`` with ``{"error": ...}`` and
        status 500 instead.

    Raises:
        HTTPException: from input validation and plan resolution; 500 when the
            upload cannot be saved or ``run_inference`` raises ``ValueError``.
    """
    # NOTE(review): run_inference is a synchronous call inside this async
    # handler, so it occupies the event loop until it returns.
    _validate_inputs(video, prompt, mission_id, latitude, longitude)
    mission_plan, mission_prompt = _resolve_mission_plan(prompt, mission_id, latitude, longitude)
    try:
        input_path = _save_upload_to_tmp(video)
    except Exception as exc:
        logging.exception("Failed to save uploaded file.")
        raise HTTPException(status_code=500, detail="Failed to save uploaded video.") from exc
    finally:
        await video.close()

    try:
        # Named summary_text so it does not shadow this handler's own name.
        _, _, summary_text = run_inference(
            input_path,
            output_video_path=None,
            mission_prompt=mission_prompt,
            detector_name=detector,
            write_output_video=False,
            generate_summary=True,
            mission_plan=mission_plan,
        )
    except ValueError as exc:
        logging.exception("Video decoding failed.")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    except Exception as exc:
        logging.exception("Summary generation failed.")
        return JSONResponse(status_code=500, content={"error": str(exc)})
    finally:
        # The input is only needed during inference; remove it on every path.
        _safe_delete(input_path)

    return {
        "mission_plan": mission_plan.to_dict(),
        "mission_summary": summary_text or "",
    }
# NOTE(review): no route decorator is visible on this handler in this copy; confirm it is registered on `app`.
async def demo_page() -> str:
    """Return the contents of ``demo.html`` next to this module.

    A short placeholder page is returned when the file does not exist.
    """
    html_file = Path(__file__).parent / "demo.html"
    try:
        page = html_file.read_text(encoding="utf-8")
    except FileNotFoundError:
        page = "<h1>Demo page missing</h1>"
    return page


if __name__ == "__main__":
    # uvicorn loads the application from the "app:app" import string, so this
    # assumes the module is importable as ``app``.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)