Alovestocode commited on
Commit
add063a
·
verified ·
1 Parent(s): d283cc4

Add Gradio interface for ZeroGPU detection - fixes 'No @spaces.GPU function detected' warning

Browse files
Files changed (2) hide show
  1. app.py +31 -4
  2. requirements.txt +1 -0
app.py CHANGED
@@ -378,10 +378,37 @@ def interactive_ui() -> str:
378
  """
379
 
380
 
381
- app = fastapi_app
 
382
 
 
 
 
 
 
 
 
 
 
383
 
384
- if __name__ == "__main__": # pragma: no cover
385
- import uvicorn
 
 
 
 
 
 
 
 
 
 
 
386
 
387
- uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
 
 
 
 
 
 
 
378
  """
379
 
380
 
381
+ # Gradio interface for ZeroGPU detection - ZeroGPU requires Gradio SDK
382
+ import gradio as gr
383
 
384
+ @spaces.GPU(duration=300)
385
+ def gradio_generate(
386
+ prompt: str,
387
+ max_new_tokens: int = MAX_NEW_TOKENS,
388
+ temperature: float = DEFAULT_TEMPERATURE,
389
+ top_p: float = DEFAULT_TOP_P,
390
+ ) -> str:
391
+ """Gradio interface function with GPU decorator for ZeroGPU detection."""
392
+ return _generate(prompt, max_new_tokens, temperature, top_p)
393
 
394
+ # Create Gradio interface - this ensures ZeroGPU detects the GPU function
395
+ gradio_interface = gr.Interface(
396
+ fn=gradio_generate,
397
+ inputs=[
398
+ gr.Textbox(label="Prompt", lines=5, placeholder="Enter your router prompt here..."),
399
+ gr.Slider(minimum=64, maximum=2048, value=MAX_NEW_TOKENS, step=16, label="Max New Tokens"),
400
+ gr.Slider(minimum=0.0, maximum=2.0, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature"),
401
+ gr.Slider(minimum=0.0, maximum=1.0, value=DEFAULT_TOP_P, step=0.05, label="Top-p"),
402
+ ],
403
+ outputs=gr.Textbox(label="Generated Response", lines=10),
404
+ title="Router Model API - ZeroGPU",
405
+ description=f"Model: {MODEL_ID} | Strategy: {ACTIVE_STRATEGY or 'pending'}",
406
+ )
407
 
408
+ # Mount FastAPI app within Gradio for API endpoints
409
+ # The main app must be Gradio for ZeroGPU detection, with FastAPI mounted for API routes
410
+ app = gr.mount_gradio_app(fastapi_app, gradio_interface, path="/")
411
+
412
+
413
+ if __name__ == "__main__": # pragma: no cover
414
+ app.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
requirements.txt CHANGED
@@ -6,3 +6,4 @@ torch>=2.1.0
6
  transformers>=4.40.0
7
  uvicorn>=0.22.0
8
  sentencepiece>=0.1.99
 
 
6
  transformers>=4.40.0
7
  uvicorn>=0.22.0
8
  sentencepiece>=0.1.99
9
+ gradio>=4.0.0