|
|
""" |
|
|
FleetMind MCP Authentication Proxy |
|
|
Captures API keys from initial SSE connections and injects them into tool requests. |
|
|
|
|
|
This proxy sits between MCP clients and the FastMCP server, solving the |
|
|
multi-tenant authentication problem by: |
|
|
1. Capturing api_key from initial /sse?api_key=xxx connection |
|
|
2. Storing api_key mapped to session_id |
|
|
3. Injecting api_key into subsequent /messages/?session_id=xxx requests |
|
|
|
|
|
Architecture: |
|
|
MCP Client -> Proxy (port 7860) -> FastMCP (port 7861) |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import logging |
|
|
import os |
|
|
from aiohttp import web, ClientSession, ClientTimeout |
|
|
from aiohttp.client_exceptions import ClientConnectionResetError |
|
|
from urllib.parse import urlencode, parse_qs, urlparse, urlunparse |
|
|
import sys |
|
|
|
|
|
|
|
|
# Root logging: INFO and above through a single StreamHandler (stderr by default).
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format=_LOG_FORMAT,
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Port this proxy binds to; read from the PORT environment variable, default 7860.
PROXY_PORT = int(os.getenv("PORT", 7860))

# Address of the FastMCP server that every proxied request is forwarded to.
FASTMCP_HOST = "localhost"
FASTMCP_PORT = 7861

# MCP session_id -> API key to inject into that session's requests.
# Kept in this process's memory only.
session_api_keys = {}
|
|
|
|
|
|
|
|
# Hop-by-hop headers (RFC 7230 section 6.1) describe one connection only and
# must not be forwarded by a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset({
    'connection',
    'keep-alive',
    'proxy-authenticate',
    'proxy-authorization',
    'te',
    'trailer',
    'trailers',
    'transfer-encoding',
    'upgrade',
})

# Request headers the outgoing aiohttp client computes itself from the URL/body.
_REQUEST_HEADERS_TO_DROP = frozenset({'host', 'content-length'})

# Response headers that no longer describe the body we send back: the aiohttp
# client decompresses upstream bodies by default, and aiohttp's server side
# computes its own framing (Content-Length or chunked) for what it actually sends.
_RESPONSE_HEADERS_TO_DROP = frozenset({'content-length', 'content-encoding'})

# Seconds of upstream silence on an SSE stream before a keepalive comment is sent.
_SSE_KEEPALIVE_INTERVAL = 30

# How many leading bytes of an SSE stream are scanned for the session endpoint.
_SESSION_SCAN_LIMIT = 4096


def _forwardable_headers(headers, drop):
    """Return the (name, value) pairs of *headers* that may be forwarded.

    Hop-by-hop headers and every name in *drop* (lower-case) are removed.
    A list of pairs is returned instead of a dict so repeated headers
    (e.g. several Set-Cookie lines) survive.
    """
    return [
        (name, value)
        for name, value in headers.items()
        if name.lower() not in _HOP_BY_HOP_HEADERS and name.lower() not in drop
    ]


def _extract_session_id(sse_bytes):
    """Return the ``session_id`` announced in the SSE bytes seen so far, or None.

    Looks for a ``data:`` line carrying a URL such as
    ``/messages/?session_id=xxx`` (the message endpoint shape this proxy
    handles). NOTE(review): this relies on FastMCP's SSE transport announcing
    that URL early in the stream, which is what its endpoint event is expected
    to contain -- confirm against the FastMCP version in use.
    Only complete lines are inspected, so an id split across network chunks is
    never captured half-read.
    """
    text = sse_bytes.decode('utf-8', errors='replace')
    complete_lines, _newline, _partial_line = text.rpartition('\n')
    for line in complete_lines.splitlines():
        line = line.strip()
        if not line.startswith('data:'):
            continue
        endpoint = line[len('data:'):].strip()
        values = parse_qs(urlparse(endpoint).query).get('session_id')
        if values:
            return values[0]
    return None


def _ends_sse_event(tail):
    """True if *tail* (the last bytes forwarded) ends with a blank line, i.e. an event boundary."""
    return tail.endswith((b'\n\n', b'\r\n\r\n', b'\r\r'))


async def _stream_sse(request, upstream, api_key):
    """Relay an upstream SSE response to the client; return the prepared StreamResponse.

    If *api_key* is set, the first bytes of the stream are scanned for the
    session id FastMCP assigns to this connection, and
    ``session_api_keys[session_id] = api_key`` is stored. The binding therefore
    belongs to this connection only. It is removed again when the stream ends,
    so closed sessions do not accumulate.

    A ``:`` keepalive comment is written only after the upstream has been silent
    for _SSE_KEEPALIVE_INTERVAL seconds AND the last forwarded byte closed an
    event, so it is never spliced into the middle of an event.

    Errors after the response has started cannot become an HTTP error status;
    they are logged and the (already prepared) response is returned as-is.
    """
    response = web.StreamResponse(
        status=upstream.status,
        reason=upstream.reason,
        headers=_forwardable_headers(upstream.headers, _RESPONSE_HEADERS_TO_DROP),
    )
    await response.prepare(request)

    linked_session_id = None
    scan_buffer = b''
    tail = b''  # last few forwarded bytes, used to detect event boundaries
    pending_read = None
    try:
        while True:
            if pending_read is None:
                pending_read = asyncio.create_task(upstream.content.readany())
            # asyncio.wait (unlike wait_for) does not cancel the read on timeout,
            # and keeps aiohttp's own sock_read timeout -- itself a TimeoutError
            # subclass -- distinguishable from this idle timer.
            done, _ = await asyncio.wait({pending_read}, timeout=_SSE_KEEPALIVE_INTERVAL)
            if not done:
                if not tail or _ends_sse_event(tail):
                    await response.write(b":\n\n")
                continue

            chunk = pending_read.result()  # re-raises upstream read errors
            pending_read = None
            if not chunk:
                break  # upstream closed the stream

            if api_key and linked_session_id is None and len(scan_buffer) < _SESSION_SCAN_LIMIT:
                scan_buffer += chunk
                linked_session_id = _extract_session_id(scan_buffer)
                if linked_session_id:
                    session_api_keys[linked_session_id] = api_key
                    logger.info(f"[AUTH] Linked session {linked_session_id[:12]}... to API key")

            await response.write(chunk)
            tail = (tail + chunk)[-4:]

        await response.write_eof()
    except (ClientConnectionResetError, ConnectionResetError) as e:
        logger.debug(f"[SSE] Client disconnected: {e}")
    except Exception as e:
        logger.exception(f"[ERROR] Proxy error: {type(e).__name__}: {e}")
    finally:
        if pending_read is not None and not pending_read.done():
            pending_read.cancel()
        if linked_session_id is not None:
            # The SSE connection is gone, so its session can no longer be used.
            session_api_keys.pop(linked_session_id, None)
    return response


async def proxy_handler(request):
    """Main proxy handler: forward any request to the FastMCP server.

    * ``GET /sse?api_key=...`` -- the key is bound to the session id that
      FastMCP announces on that same SSE stream (see _stream_sse).
    * ``/messages...?session_id=...`` -- if a key is bound to that session it
      is injected as the ``api_key`` query parameter, overriding any value the
      client sent. Unknown sessions get no key injected.

    Binding is per SSE connection rather than through a shared "pending" slot,
    so concurrent clients cannot have their sessions linked to each other's keys.

    Returns the upstream response (streamed for ``text/event-stream``).
    Returns 499 if the connection is reset before a response was started, and
    502 for any other proxy failure.
    """
    path = request.path
    query_params = dict(request.query)

    sse_api_key = query_params.get('api_key') if path == '/sse' else None
    if sse_api_key:
        # Log only a short prefix: this value is a credential.
        logger.info(f"[AUTH] Captured API key from SSE connection: {sse_api_key[:8]}...")

    session_id = query_params.get('session_id')
    if session_id and path.startswith('/messages'):
        stored_api_key = session_api_keys.get(session_id)
        if stored_api_key:
            query_params['api_key'] = stored_api_key
            logger.debug(f"[AUTH] Injected API key into request for session {session_id[:12]}...")

    target_url = f"http://{FASTMCP_HOST}:{FASTMCP_PORT}{path}"
    if query_params:
        target_url += f"?{urlencode(query_params)}"

    headers = _forwardable_headers(request.headers, _REQUEST_HEADERS_TO_DROP)
    timeout = ClientTimeout(total=None, sock_connect=30, sock_read=300)

    async with ClientSession(timeout=timeout) as session:
        try:
            # GET is forwarded without a body; every other method with the raw body.
            body = None if request.method == 'GET' else await request.read()
            async with session.request(
                request.method, target_url, data=body, headers=headers
            ) as upstream:
                if request.method == 'GET' and 'text/event-stream' in upstream.content_type:
                    return await _stream_sse(request, upstream, sse_api_key)

                payload = await upstream.read()
                return web.Response(
                    body=payload,
                    status=upstream.status,
                    headers=_forwardable_headers(upstream.headers, _RESPONSE_HEADERS_TO_DROP),
                )
        except (ClientConnectionResetError, ConnectionResetError) as e:
            logger.debug(f"[SSE] Client disconnected: {e}")
            return web.Response(text="Client disconnected", status=499)
        except Exception as e:
            logger.exception(f"[ERROR] Proxy error: {type(e).__name__}: {e}")
            return web.Response(
                text=f"Proxy error: {type(e).__name__}: {str(e)}",
                status=502,
            )
|
|
|
|
|
|
|
|
async def health_check(request):
    """Liveness probe: answer every call with HTTP 200 and a fixed plain-text body."""
    ok_response = web.Response(status=200, text="FleetMind Proxy OK")
    return ok_response
|
|
|
|
|
|
|
|
def create_app():
    """Build the aiohttp application: a /health probe plus a catch-all proxy route."""
    application = web.Application()
    routes = application.router

    # Registered first so the explicit health route wins over the catch-all below.
    routes.add_get('/health', health_check)

    # Every other method and path is handed to the proxy.
    routes.add_route('*', '/{path:.*}', proxy_handler)

    return application
|
|
|
|
|
|
|
|
async def main():
    """Start the proxy server and serve until this task is cancelled.

    The aiohttp runner is always cleaned up on exit, including when startup
    fails (e.g. the port is already in use).
    """
    separator = "=" * 70
    print("\n" + separator)
    print("FleetMind MCP Authentication Proxy")
    print(separator)
    print(f"Proxy listening on: http://0.0.0.0:{PROXY_PORT}")
    print(f"Forwarding to FastMCP: http://{FASTMCP_HOST}:{FASTMCP_PORT}")
    print(separator)
    print("[OK] Multi-tenant authentication enabled")
    print("[OK] API keys captured from SSE connections")
    print("[OK] Sessions automatically linked to API keys")
    print(separator + "\n")

    app = create_app()
    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, '0.0.0.0', PROXY_PORT)
        await site.start()

        logger.info(f"[OK] Proxy server started on port {PROXY_PORT}")
        logger.info(f"[OK] Forwarding to FastMCP on {FASTMCP_HOST}:{FASTMCP_PORT}")

        # Block forever. Under asyncio.run, Ctrl+C reaches this coroutine as task
        # cancellation rather than KeyboardInterrupt, so cleanup lives in `finally`
        # instead of an `except KeyboardInterrupt` that would never fire.
        await asyncio.Event().wait()
    finally:
        logger.info("Shutting down proxy server...")
        await runner.cleanup()
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the server; on Ctrl+C print a short notice and exit with status 0.
    interrupted = False
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        interrupted = True
    if interrupted:
        print("\nProxy server stopped.")
        sys.exit(0)
|
|
|