Alovestocode commited on
Commit
eb7b063
·
verified ·
1 Parent(s): d90f2a7

Fix: Use direct FastAPI decorators with JSONResponse to avoid Content-Length conflicts

Browse files
Files changed (1) hide show
  1. app.py +16 -14
app.py CHANGED
@@ -5,7 +5,7 @@ from functools import lru_cache
5
  from typing import List, Optional, Tuple
6
 
7
  import torch
8
- from fastapi import APIRouter, HTTPException
9
  from pydantic import BaseModel
10
 
11
  try:
@@ -483,24 +483,26 @@ with gr.Blocks(
483
  outputs=[prompt_input, output, status_display],
484
  )
485
 
486
- # Attach API routes directly onto Gradio's FastAPI instance using APIRouter
487
- api_router = APIRouter()
 
 
488
 
489
-
490
- @api_router.get("/health")
491
- def api_health() -> dict[str, str]:
492
  """API health check endpoint."""
493
  return healthcheck()
494
 
495
-
496
- @api_router.post("/v1/generate", response_model=GenerateResponse)
497
- def api_generate(payload: GeneratePayload) -> GenerateResponse:
498
  """API generate endpoint."""
499
- return generate_endpoint(payload)
500
-
501
-
502
- # Include the router in Gradio's FastAPI app
503
- gradio_app.app.include_router(api_router)
 
 
504
 
505
  # Call warm start
506
  warm_start()
 
5
  from typing import List, Optional, Tuple
6
 
7
  import torch
8
+ from fastapi import HTTPException
9
  from pydantic import BaseModel
10
 
11
  try:
 
483
  outputs=[prompt_input, output, status_display],
484
  )
485
 
486
+ # Add API routes directly to Gradio's FastAPI app
487
+ # These routes are added after Gradio Blocks context but before queue/launch
488
+ from fastapi.responses import JSONResponse
489
+ from fastapi import Request
490
 
491
+ @gradio_app.app.get("/health", response_class=JSONResponse)
492
+ def api_health():
 
493
  """API health check endpoint."""
494
  return healthcheck()
495
 
496
+ @gradio_app.app.post("/v1/generate", response_class=JSONResponse)
497
+ async def api_generate(request: Request):
 
498
  """API generate endpoint."""
499
+ try:
500
+ data = await request.json()
501
+ payload = GeneratePayload(**data)
502
+ result = generate_endpoint(payload)
503
+ return {"text": result.text}
504
+ except Exception as exc:
505
+ raise HTTPException(status_code=500, detail=str(exc))
506
 
507
  # Call warm start
508
  warm_start()