"""
FleetMind MCP Authentication Proxy
Captures API keys from initial SSE connections and injects them into tool requests.

This proxy sits between MCP clients and the FastMCP server, solving the
multi-tenant authentication problem by:
1. Capturing api_key from initial /sse?api_key=xxx connection
2. Storing api_key mapped to session_id
3. Injecting api_key into subsequent /messages/?session_id=xxx requests

Architecture:
    MCP Client -> Proxy (port 7860) -> FastMCP (port 7861)
"""

import asyncio
import logging
import os
import re
import sys
import traceback
from typing import Optional
from urllib.parse import urlencode, parse_qs, urlparse, urlunparse

from aiohttp import web, ClientSession, ClientTimeout
from aiohttp.client_exceptions import ClientConnectionResetError

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Proxy configuration
# On HuggingFace, PORT env var is set to 7860
PROXY_PORT = int(os.getenv("PORT", 7860))  # Public-facing port
FASTMCP_PORT = 7861  # Internal FastMCP server port (fixed)
FASTMCP_HOST = "localhost"

# Session storage: session_id -> api_key
# The special key '_pending_api_key' holds the most recently seen key and is
# only used as a legacy fallback (see proxy_handler).
session_api_keys = {}

# Headers that describe a single hop and must not be relayed by a proxy.
_HOP_BY_HOP_HEADERS = frozenset({
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
})

# Host is rewritten for the upstream target; Content-Length is recomputed by
# the client library for the body we actually forward.
_EXCLUDED_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {"host", "content-length"}

# Dropped from a response whose body the HTTP client has already decoded.
_DECODED_BODY_HEADERS = frozenset({"content-encoding", "content-length"})

# Finds "session_id=<id>" followed by '&' or whitespace, so an id that is cut
# in half by a chunk boundary is never matched prematurely.
_SESSION_ID_PATTERN = re.compile(rb"session_id=([A-Za-z0-9_.~-]+)(?=[&\s])")

# Only the first bytes of an SSE stream are inspected for the session id.
_SESSION_SNIFF_LIMIT = 4096


def _mask_secret(value: str, visible: int = 4) -> str:
    """Return a log-safe rendering of *value* that never exposes the whole secret."""
    if len(value) <= visible * 2:
        return "***"
    return f"{value[:visible]}..."


def _filter_headers(headers, excluded) -> dict:
    """Copy *headers* into a dict, dropping names listed in *excluded*.

    Names are compared case-insensitively. When a name occurs more than once
    the first value is kept.
    """
    filtered = {}
    for name, value in headers.items():
        if name.lower() not in excluded:
            filtered.setdefault(name, value)
    return filtered


def _response_headers(upstream) -> dict:
    """Return the upstream response headers that are safe to relay downstream."""
    excluded = _HOP_BY_HOP_HEADERS
    if "Content-Encoding" in upstream.headers:
        # The client library hands us the decoded body, so the original
        # encoding and length no longer describe what we send.
        excluded = excluded | _DECODED_BODY_HEADERS
    return _filter_headers(upstream.headers, excluded)


def _extract_session_id(data: bytes) -> Optional[str]:
    """Return the first complete ``session_id`` value found in *data*, if any."""
    match = _SESSION_ID_PATTERN.search(data)
    if match is None:
        return None
    return match.group(1).decode("ascii")


def _build_target_url(path: str, query_params: dict) -> str:
    """Build the FastMCP URL for *path* with *query_params* appended."""
    target_url = f"http://{FASTMCP_HOST}:{FASTMCP_PORT}{path}"
    query_string = urlencode(query_params) if query_params else ""
    if query_string:
        target_url += f"?{query_string}"
    return target_url


async def _buffered_response(upstream):
    """Read the whole upstream body and wrap it in a regular response."""
    body = await upstream.read()
    return web.Response(
        body=body,
        status=upstream.status,
        headers=_response_headers(upstream),
    )


async def _stream_sse(request, upstream, api_key):
    """Relay an upstream SSE stream to the client.

    When *api_key* is given, the start of the stream is scanned for a
    ``session_id=<id>`` token and that session is bound to *api_key* for the
    lifetime of this connection.

    NOTE(review): assumes the upstream announces its message endpoint as a URL
    containing ``session_id=<id>`` near the start of the stream (the MCP SSE
    transport ``endpoint`` event) - confirm against the FastMCP version in use.
    If nothing is found, only the legacy fallback in proxy_handler applies.
    """
    response = web.StreamResponse(
        status=upstream.status,
        reason=upstream.reason,
        headers=_response_headers(upstream),
    )
    await response.prepare(request)

    # Background task to send keep-alive pings (prevents timeout)
    # NOTE(review): a ping written between two partial upstream chunks could
    # split an event - TODO confirm whether upstream chunks are event-aligned.
    async def send_keepalive():
        try:
            while True:
                await asyncio.sleep(30)  # Send ping every 30 seconds
                await response.write(b":\n\n")  # SSE comment (ignored by client)
        except asyncio.CancelledError:
            pass

    keepalive_task = asyncio.create_task(send_keepalive())
    sniffed = b""
    linked_session_id = None
    try:
        # Stream chunks from FastMCP to client
        async for chunk in upstream.content.iter_any():
            if (
                api_key
                and linked_session_id is None
                and len(sniffed) < _SESSION_SNIFF_LIMIT
            ):
                sniffed += chunk
                linked_session_id = _extract_session_id(sniffed)
                if linked_session_id is not None:
                    session_api_keys[linked_session_id] = api_key
                    logger.info(f"[AUTH] Linked session {linked_session_id[:12]}... to API key")
            await response.write(chunk)
        await response.write_eof()
    except (ClientConnectionResetError, ConnectionResetError) as e:
        # Client disconnected - this is normal for SSE connections.
        # The response is already prepared, so it is returned as-is below
        # instead of being replaced by a second response.
        logger.debug(f"[SSE] Client disconnected: {e}")
    finally:
        # The session ends with its SSE connection; drop the mapping so the
        # store does not grow without bound.
        if linked_session_id is not None:
            session_api_keys.pop(linked_session_id, None)
        # Cancel keep-alive task when streaming completes
        keepalive_task.cancel()
        try:
            await keepalive_task
        except (asyncio.CancelledError, ClientConnectionResetError, ConnectionResetError):
            pass

    return response


async def proxy_handler(request):
    """
    Main proxy handler - forwards all requests to FastMCP server.
    Captures API keys from SSE connections and injects them into tool calls.

    Args:
        request: Incoming aiohttp request.

    Returns:
        The relayed upstream response; a 499 response if the connection was
        reset while proxying; a 502 response for any other proxy failure.
    """
    path = request.path
    query_params = dict(request.query)

    # Extract API key if present (initial SSE connection)
    api_key = query_params.get('api_key')
    session_id = query_params.get('session_id')

    # STEP 1: Capture API key from initial SSE connection
    is_sse_connect = bool(api_key) and path == '/sse'
    if is_sse_connect:
        logger.info(f"[AUTH] Captured API key from SSE connection: {_mask_secret(api_key)}")
        # Legacy fallback slot. It is shared by all clients and therefore not
        # tenant-safe; the precise link is made in _stream_sse.
        session_api_keys['_pending_api_key'] = api_key

    # STEP 2: Link session_id to API key (from /messages requests)
    if session_id and path.startswith('/messages'):
        if session_id not in session_api_keys:
            # Fallback only: reached when the session id could not be read
            # from the SSE stream of the connection that created it.
            if '_pending_api_key' in session_api_keys:
                session_api_keys[session_id] = session_api_keys['_pending_api_key']
                logger.info(f"[AUTH] Linked session {session_id[:12]}... to API key")

        # STEP 3: Inject API key into request for FastMCP
        stored_api_key = session_api_keys.get(session_id)
        if stored_api_key:
            query_params['api_key'] = stored_api_key
            logger.debug(f"[AUTH] Injected API key into request for session {session_id[:12]}...")

    # Build target URL for FastMCP server
    target_url = _build_target_url(path, query_params)

    # Forward request to FastMCP
    # For SSE connections: total=None disables overall timeout (keeps connection alive)
    # Still use socket timeouts for safety (sock_connect, sock_read)
    timeout = ClientTimeout(
        total=None,       # No total timeout for long-lived SSE connections
        sock_connect=30,  # 30 seconds for initial connection
        sock_read=300,    # 5 minutes for individual socket reads
    )
    async with ClientSession(timeout=timeout) as session:
        try:
            # Copy end-to-end headers only; Host is set for the upstream target
            headers = _filter_headers(request.headers, _EXCLUDED_REQUEST_HEADERS)

            if request.method == 'GET':
                async with session.get(target_url, headers=headers) as resp:
                    # For SSE, stream the response
                    if 'text/event-stream' in resp.content_type:
                        return await _stream_sse(
                            request,
                            resp,
                            api_key if is_sse_connect else None,
                        )
                    # For regular responses, read entire body
                    return await _buffered_response(resp)

            # POST and all other methods: forward the body unchanged
            body = await request.read()
            async with session.request(
                request.method,
                target_url,
                data=body,
                headers=headers,
            ) as resp:
                return await _buffered_response(resp)

        except (ClientConnectionResetError, ConnectionResetError) as e:
            # Client disconnected - this is normal for SSE connections
            # Log at DEBUG level to reduce noise
            logger.debug(f"[SSE] Client disconnected: {e}")
            return web.Response(text="Client disconnected", status=499)

        except Exception as e:
            error_details = traceback.format_exc()
            logger.error(f"[ERROR] Proxy error: {type(e).__name__}: {e}")
            logger.error(f"[ERROR] Traceback:\n{error_details}")
            return web.Response(
                text=f"Proxy error: {type(e).__name__}: {str(e)}",
                status=502
            )


async def health_check(request):
    """Health check endpoint"""
    return web.Response(text="FleetMind Proxy OK", status=200)


def create_app():
    """Create and configure the proxy application"""
    app = web.Application()

    # Health check endpoint
    app.router.add_get('/health', health_check)

    # Proxy all other requests
    app.router.add_route('*', '/{path:.*}', proxy_handler)

    return app


async def main():
    """Start the proxy server and run until the task is cancelled."""
    print("\n" + "=" * 70)
    print("FleetMind MCP Authentication Proxy")
    print("=" * 70)
    print(f"Proxy listening on: http://0.0.0.0:{PROXY_PORT}")
    print(f"Forwarding to FastMCP: http://{FASTMCP_HOST}:{FASTMCP_PORT}")
    print("=" * 70)
    print("[OK] Multi-tenant authentication enabled")
    print("[OK] API keys captured from SSE connections")
    print("[OK] Sessions automatically linked to API keys")
    print("=" * 70 + "\n")

    app = create_app()
    runner = web.AppRunner(app)
    await runner.setup()

    try:
        site = web.TCPSite(runner, '0.0.0.0', PROXY_PORT)
        await site.start()

        logger.info(f"[OK] Proxy server started on port {PROXY_PORT}")
        logger.info(f"[OK] Forwarding to FastMCP on {FASTMCP_HOST}:{FASTMCP_PORT}")

        # Keep running
        await asyncio.Event().wait()
    finally:
        # Under asyncio.run, Ctrl+C reaches this coroutine as a cancellation
        # rather than KeyboardInterrupt, so cleanup must live in `finally`.
        logger.info("Shutting down proxy server...")
        await runner.cleanup()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nProxy server stopped.")
        sys.exit(0)