Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Request, Response, WebSocket, WebSocketDisconnect | |
| from fastapi.responses import HTMLResponse | |
| from fastapi.concurrency import run_in_threadpool | |
| from gradio_client import Client | |
| import gradio as gr | |
| import uvicorn | |
| import httpx | |
| import websockets | |
| import asyncio | |
| from urllib.parse import urljoin, urlparse, unquote | |
| import logging | |
# Root logging at INFO for the whole process; this module logs via its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

# The ASGI application object handed to uvicorn.run() in the __main__ guard.
app = FastAPI()
class HttpClient:
    """Forward incoming requests to arbitrary URLs over one shared ``httpx.AsyncClient``."""

    # Hop-by-hop headers (RFC 7230, section 6.1) describe a single connection,
    # not the message, so a proxy must not relay them in either direction.
    _HOP_BY_HOP_HEADERS = frozenset({
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "trailers",
        "transfer-encoding",
        "upgrade",
    })

    # Request headers that are not relayed upstream:
    # * "host" / "content-length": httpx derives them from the target URL and body;
    # * "accept-encoding": let httpx advertise only encodings it can decode,
    #   because the body is relayed in its decoded form (see below);
    # * "x-target-url": a routing instruction for this proxy, not for the target.
    _REQUEST_HEADERS_TO_DROP = _HOP_BY_HOP_HEADERS | {
        "host",
        "content-length",
        "accept-encoding",
        "x-target-url",
    }

    # httpx's ``response.content`` is already decompressed, so the upstream
    # "content-encoding" and "content-length" no longer describe the body that
    # is sent back; Starlette recomputes content-length from the body itself.
    _RESPONSE_HEADERS_TO_DROP = _HOP_BY_HOP_HEADERS | {
        "content-encoding",
        "content-length",
    }

    def __init__(self):
        # 30 s timeout; redirects are handed back to the caller, not followed here.
        self.client = httpx.AsyncClient(
            timeout=httpx.Timeout(30.0),
            follow_redirects=False,
        )

    async def forward_request(self, request: Request, target_url: str):
        """
        Forward an incoming request to a target URL and relay the response.

        The method, body and end-to-end headers of ``request`` are sent to
        ``target_url``. The upstream status, body and end-to-end headers are
        returned. Repeated headers such as ``Set-Cookie`` are preserved.

        :param request: the incoming request to replay.
        :param target_url: absolute URL to send it to.
        :returns: a ``Response`` mirroring the upstream one, or a plain-text
            error response with status 504 (timeout), 502 (network error) or
            500 (any other failure).
        """
        try:
            method = request.method
            # A list of pairs keeps repeated request headers (a dict kept only one).
            headers = [
                (name, value)
                for name, value in request.headers.items()
                if name.lower() not in self._REQUEST_HEADERS_TO_DROP
            ]
            body = await request.body()
            logger.info(f"Forwarding {method} request to {target_url}")
            response = await self.client.request(
                method=method,
                url=target_url,
                headers=headers,
                content=body,
            )
            proxied = Response(
                content=response.content,
                status_code=response.status_code,
            )
            # append() keeps duplicates such as several Set-Cookie headers, which
            # dict(response.headers) would have merged into one comma-joined value.
            for name, value in response.headers.multi_items():
                if name.lower() not in self._RESPONSE_HEADERS_TO_DROP:
                    proxied.headers.append(name, value)
            return proxied
        except httpx.TimeoutException:
            logger.error(f"Timeout error while forwarding request to {target_url}")
            return Response(
                content="Request timeout error",
                status_code=504,
            )
        except httpx.NetworkError as e:
            logger.error(f"Network error while forwarding request: {str(e)}")
            return Response(
                content=f"Network error: {str(e)}",
                status_code=502,
            )
        except Exception as e:
            # Top-level boundary of the proxy: log with traceback, answer 500.
            logger.exception(f"Error forwarding request: {str(e)}")
            return Response(
                content=f"Request error: {str(e)}",
                status_code=500,
            )

    async def close(self):
        """Close the underlying ``httpx.AsyncClient`` and its connections."""
        await self.client.aclose()
# Single module-wide forwarder shared by every proxied HTTP request.
http_client = HttpClient()
async def read_root():
    """
    Return the text of ``index.html`` from the current working directory.

    The file is decoded explicitly as UTF-8, so the result does not depend on
    the host's locale encoding (the default for ``open()``).
    """
    with open("index.html", encoding="utf-8") as f:
        return f.read()
async def gradio_client(repo_id: str, api_name: str, request: Request):
    """
    Call endpoint ``/{api_name}`` of the Gradio app ``repo_id`` and return its result.

    The request body must be a JSON object whose ``"args"`` list is passed
    positionally to ``Client.predict``. A body that is not JSON, or that lacks
    an ``"args"`` list, gets a 400 response instead of an unhandled exception.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError / UnicodeDecodeError from a bad body are ValueErrors.
        data = None
    if not isinstance(data, dict) or not isinstance(data.get("args"), list):
        return Response(
            content="Request body must be a JSON object with an 'args' list",
            status_code=400,
        )
    # Constructing a gradio_client Client makes synchronous HTTP requests (it
    # fetches the app's config), and predict() blocks until the result arrives.
    # Run both in the threadpool so the event loop is never blocked.
    client = await run_in_threadpool(Client, repo_id)
    return await run_in_threadpool(
        client.predict, *data["args"], api_name=f"/{api_name}"
    )
async def web_client(request: Request, path: str):
    """
    Main web client endpoint that forwards all requests to the target URL
    specified in the path or the 'X-Target-Url' header.

    Target resolution:
    * ``path`` itself, when it starts with ``http://`` or ``https://``;
    * otherwise the ``X-Target-Url`` header, with ``path`` (if any) appended.

    The incoming query string is re-attached to the resolved URL, because the
    ``path`` parameter never contains it. Returns a 400 response when no
    usable absolute http(s) URL can be resolved. Otherwise returns whatever
    ``http_client.forward_request`` produces.
    """
    if path.lower().startswith(("http://", "https://")):
        target_url = path
    else:
        target_url = request.headers.get("X-Target-Url")
        if target_url and path:
            if not _is_http_url(target_url):
                return Response(
                    content="Invalid X-Target-Url header",
                    status_code=400,
                )
            target_url = urljoin(target_url.rstrip("/") + "/", path.lstrip("/"))

    if not target_url:
        return Response(
            content="Missing X-Target-Url header or URL in path",
            status_code=400,
        )
    if not _is_http_url(target_url):
        return Response(
            content="Invalid X-Target-Url header or URL in path",
            status_code=400,
        )

    # Previously the client's ?query was silently lost; merge it with any
    # query already present on the target URL.
    incoming_query = request.url.query
    if incoming_query:
        parsed = urlparse(target_url)
        merged = f"{parsed.query}&{incoming_query}" if parsed.query else incoming_query
        target_url = parsed._replace(query=merged).geturl()

    return await http_client.forward_request(request, target_url)


def _is_http_url(url: str) -> bool:
    """Return True when *url* is an absolute http(s) URL with a host part."""
    try:
        parsed = urlparse(url)
    except ValueError:
        # urlparse raises ValueError e.g. for a malformed IPv6 netloc.
        return False
    return parsed.scheme.lower() in ("http", "https") and bool(parsed.netloc)
async def websocket_client(websocket: WebSocket, path: str):
    """
    WebSocket endpoint that forwards WebSocket connections to the target URL
    specified in the 'X-Target-Url' header or in the path.

    Target resolution:
    * the ``X-Target-Url`` header, if present. ``http(s)://`` is rewritten to
      ``ws(s)://``. ``path`` is appended unless it is itself a
      ``ws://``/``wss://`` URL;
    * otherwise ``path`` (percent-decoded if it contains ``%``), which must
      start with ``ws://`` or ``wss://``.

    The client's query string is re-attached to the target URL. Text and
    binary frames are relayed in both directions until either peer stops.
    The client socket is then closed once: with 1008 for a bad URL or a failed
    handshake, 1011 for unexpected errors, and 1000 otherwise.
    """
    target_url = websocket.headers.get("X-Target-Url")
    # True when the target URL came from ``path``, so ``path`` must not be
    # appended to it a second time below.
    target_from_path = False
    if not target_url:
        decoded_path = unquote(path) if path and "%" in path else path
        if decoded_path and decoded_path.lower().startswith(("ws://", "wss://")):
            target_url = decoded_path
            target_from_path = True
        else:
            await websocket.close(code=1008, reason="Missing X-Target-Url header or invalid URL in path")
            return

    try:
        parsed_url = urlparse(target_url)
    except ValueError:
        parsed_url = None
    if parsed_url is None or not parsed_url.scheme or not parsed_url.netloc:
        await websocket.close(code=1008, reason="Invalid target URL")
        return

    await websocket.accept()

    # Convert an HTTP/HTTPS target to the matching WebSocket scheme.
    lowered = target_url.lower()
    if lowered.startswith("https://"):
        ws_target_url = "wss://" + target_url[len("https://"):]
    elif lowered.startswith("http://"):
        ws_target_url = "ws://" + target_url[len("http://"):]
    else:
        ws_target_url = target_url

    if not target_from_path and path and not path.lower().startswith(("ws://", "wss://")):
        ws_target_url = urljoin(ws_target_url.rstrip("/") + "/", path.lstrip("/"))

    client_query = websocket.url.query
    if client_query:
        parsed_ws = urlparse(ws_target_url)
        merged = f"{parsed_ws.query}&{client_query}" if parsed_ws.query else client_query
        ws_target_url = parsed_ws._replace(query=merged).geturl()

    close_code, close_reason = 1000, None
    try:
        async with websockets.connect(ws_target_url) as target_ws:
            await _relay_websockets(websocket, target_ws)
    except websockets.InvalidURI:
        close_code, close_reason = 1008, "Invalid WebSocket URL"
    except websockets.InvalidHandshake:
        close_code, close_reason = 1008, "WebSocket handshake failed"
    except Exception as e:
        logger.exception(f"Error in WebSocket connection: {str(e)}")
        close_code, close_reason = 1011, "Internal server error"
    finally:
        # Exactly one close, also on cancellation.
        await _close_websocket_quietly(websocket, close_code, close_reason)


async def _relay_websockets(websocket: WebSocket, target_ws) -> None:
    """Relay text and binary frames between the two sockets until one side stops.

    When either direction finishes, the other is cancelled, so neither socket
    is left waiting on a peer that has gone away. An unexpected error in a
    direction is re-raised for the caller to log.
    """

    async def client_to_target():
        try:
            while True:
                message = await websocket.receive()
                if message["type"] == "websocket.disconnect":
                    return
                # A receive message carries its payload under "text" or "bytes".
                if message.get("text") is not None:
                    await target_ws.send(message["text"])
                elif message.get("bytes") is not None:
                    await target_ws.send(message["bytes"])
        except websockets.ConnectionClosed:
            pass

    async def target_to_client():
        try:
            while True:
                data = await target_ws.recv()
                if isinstance(data, bytes):
                    await websocket.send_bytes(data)
                else:
                    await websocket.send_text(data)
        except (websockets.ConnectionClosed, WebSocketDisconnect):
            pass

    pumps = [
        asyncio.create_task(client_to_target()),
        asyncio.create_task(target_to_client()),
    ]
    try:
        done, _ = await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for pump in pumps:
            pump.cancel()
        await asyncio.gather(*pumps, return_exceptions=True)
    for pump in done:
        error = pump.exception()
        if error is not None:
            raise error


async def _close_websocket_quietly(websocket: WebSocket, code: int = 1000, reason=None) -> None:
    """Close *websocket*, ignoring errors from a socket that is already closed or gone."""
    try:
        await websocket.close(code=code, reason=reason)
    except Exception:
        # A repeated close or a vanished client can make close() raise
        # (e.g. RuntimeError from Starlette); there is nothing left to do.
        pass
async def health_check():
    """Liveness probe: always answers with a fixed ``{"status": "ok"}`` payload."""
    payload = {"status": "ok"}
    return payload
async def shutdown_event():
    """Close the shared ``http_client`` so its underlying connections are released."""
    client = http_client
    await client.close()
if __name__ == "__main__":
    # When executed directly, serve the app on every interface at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")