Spaces:
Running
Running
| """ | |
| Custom MCP client using direct subprocess communication. | |
| This bypasses the buggy stdio_client from mcp.client.stdio. | |
| """ | |
| import asyncio | |
| import json | |
| import logging | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| logger = logging.getLogger(__name__) | |
class MCPClient:
    """Minimal MCP client that speaks JSON-RPC 2.0 to a server subprocess.

    The server script is launched with the current Python interpreter and
    exchanges newline-delimited JSON messages over its stdin/stdout.
    Requests are handled one at a time: each call writes a request and then
    reads stdout until the response carrying the same ``id`` arrives.
    """

    # Seconds to wait for the response to a single request. Extended timeout
    # for LlamaIndex/RAG server initialization.
    REQUEST_TIMEOUT = 60.0

    def __init__(self, server_script: str, server_name: str):
        """Remember how to launch the server; nothing is started yet.

        Args:
            server_script: Path of the Python script implementing the server.
            server_name: Human-readable name used in log and error messages.
        """
        self.server_script = server_script
        self.server_name = server_name
        # Handle of the running server, or None while stopped.
        self.process: Optional[subprocess.Popen] = None
        # Last JSON-RPC request id handed out by _next_id().
        self.message_id = 0
        # True once the initialize handshake has completed.
        self._initialized = False
        self.script_path = server_script  # Store for potential restart

    async def start(self):
        """Start the MCP server subprocess and perform the handshake.

        Raises:
            Exception: If the handshake fails. The subprocess is shut down
                again before the error propagates, so no process is leaked.
        """
        # Local import: only this method needs it.
        import threading

        logger.info(f"Starting MCP server: {self.server_name}")
        self.process = subprocess.Popen(
            [sys.executable, self.server_script],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            # The MCP stdio transport is UTF-8; do not depend on the locale.
            encoding="utf-8",
            errors="replace",
            bufsize=1,  # Line-buffered I/O to prevent 8KB truncation
        )
        # An unread stderr pipe eventually fills up and blocks the server, so
        # drain it continuously in the background.
        threading.Thread(
            target=self._drain_stderr,
            args=(self.process.stderr,),
            name=f"{self.server_name}-stderr",
            daemon=True,
        ).start()

        # Initialize the session
        try:
            await self._initialize()
        except BaseException:
            # Do not leave a half-started server behind.
            await self.close()
            raise
        logger.info(f"Successfully started MCP server: {self.server_name}")

    def _drain_stderr(self, stream) -> None:
        """Forward every stderr line of the server to the debug log.

        Runs in a daemon thread until the stream reaches EOF, then closes it.
        """
        try:
            with stream:
                for line in stream:
                    logger.debug("[%s stderr] %s", self.server_name, line.rstrip())
        except (OSError, ValueError) as exc:
            # The pipe can be torn down underneath us during shutdown.
            logger.debug("Stopped reading stderr of %s: %s", self.server_name, exc)

    async def _initialize(self):
        """Run the MCP initialize handshake.

        Sends the ``initialize`` request and, on success, the
        ``notifications/initialized`` notification that tells the server the
        client is ready for normal requests.

        Raises:
            Exception: If the server's reply contains no ``result``.
        """
        init_message = {
            "jsonrpc": "2.0",
            "id": self._next_id(),
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {},
                "clientInfo": {
                    "name": "als-research-agent",
                    "version": "1.0.0",
                },
            },
        }
        response = await self._send_request(init_message)
        if "result" not in response:
            raise Exception(f"Initialization failed: {response}")

        # Notifications carry no id and get no reply.
        self._send_message({"jsonrpc": "2.0", "method": "notifications/initialized"})
        self._initialized = True
        logger.info(
            f"Initialized {self.server_name}: {response['result'].get('serverInfo', {})}"
        )

    def _next_id(self) -> int:
        """Return the next JSON-RPC request id (1, 2, 3, ...)."""
        self.message_id += 1
        return self.message_id

    def _send_message(self, message: Dict[str, Any]) -> None:
        """Write one JSON-RPC message to the server's stdin as a single line.

        Raises:
            RuntimeError: If the server was never started, has exited, or its
                stdin can no longer be written to.
        """
        if not self.process:
            raise RuntimeError("Server not started")
        # Check if process is still alive
        if self.process.poll() is not None:
            raise RuntimeError(f"Server {self.server_name} has terminated unexpectedly")
        try:
            self.process.stdin.write(json.dumps(message) + "\n")
            self.process.stdin.flush()
        except (OSError, ValueError) as exc:
            # Broken or closed pipe: the server went away mid-write.
            raise RuntimeError(
                f"Server {self.server_name} has terminated unexpectedly"
            ) from exc

    async def _read_response(self, request_id: Any) -> Dict[str, Any]:
        """Read stdout until the response for ``request_id`` arrives.

        Blank lines, lines that are not JSON, server notifications/requests
        (objects with a ``method``) and responses with a different id are
        skipped, so they cannot be mistaken for the awaited reply.

        Raises:
            Exception: If the server closes stdout first.
        """
        stdout = self.process.stdout
        while True:
            line = await asyncio.to_thread(stdout.readline)
            if not line:
                raise Exception("Server closed stdout")
            text = line.strip()
            if not text:
                continue
            try:
                payload = json.loads(text)
            except json.JSONDecodeError:
                logger.warning(
                    "Ignoring non-JSON output from %s: %.200s", self.server_name, text
                )
                continue
            if (
                isinstance(payload, dict)
                and "method" not in payload
                and payload.get("id") == request_id
            ):
                return payload
            logger.debug(
                "Ignoring unrelated message from %s: %.200s", self.server_name, text
            )

    async def _send_request(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Send a JSON-RPC request and wait for its response.

        Args:
            message: Complete JSON-RPC request object, including ``id``.

        Returns:
            The decoded response object whose ``id`` matches the request.

        Raises:
            RuntimeError: If the server is not running.
            Exception: On timeout ("Request timed out") or if the server
                closes stdout before answering.
        """
        self._send_message(message)
        try:
            return await asyncio.wait_for(
                self._read_response(message.get("id")),
                timeout=self.REQUEST_TIMEOUT,
            )
        except asyncio.TimeoutError as exc:
            # NOTE: the worker thread may still be blocked in readline(); it
            # is released once the server writes a line or is closed.
            raise Exception("Request timed out") from exc

    async def list_tools(self) -> List[Dict[str, Any]]:
        """Return the tool descriptions advertised by the server.

        Raises:
            RuntimeError: If the client has not been initialized.
            Exception: If the server's reply contains no ``result``.
        """
        if not self._initialized:
            raise RuntimeError("Client not initialized")
        message = {
            "jsonrpc": "2.0",
            "id": self._next_id(),
            "method": "tools/list",
            "params": {},
        }
        response = await self._send_request(message)
        if "result" in response:
            return response["result"].get("tools", [])
        raise Exception(f"List tools failed: {response}")

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten a tool-call ``content`` value into a single string.

        The ``text`` of every content item that has one is joined with
        newlines; if no item has text, the whole value is stringified.
        """
        if isinstance(content, list) and content:
            texts = [
                str(item["text"])
                for item in content
                if isinstance(item, dict) and "text" in item
            ]
            if texts:
                return "\n".join(texts)
        return str(content)

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Call a tool on the server and return its output.

        Args:
            tool_name: Name of the tool to invoke.
            arguments: Arguments object passed to the tool.

        Returns:
            The tool output: the inner ``result`` field if present, otherwise
            the text of the ``content`` items, otherwise ``str(result)``.

        Raises:
            RuntimeError: If the client has not been initialized.
            Exception: If the server answers with an error.
        """
        if not self._initialized:
            raise RuntimeError("Client not initialized")
        message = {
            "jsonrpc": "2.0",
            "id": self._next_id(),
            "method": "tools/call",
            "params": {
                "name": tool_name,
                "arguments": arguments,
            },
        }
        response = await self._send_request(message)
        if "result" not in response:
            error = response.get("error") or {}
            detail = error.get("message", response) if isinstance(error, dict) else error
            raise Exception(f"Tool call failed: {detail}")

        # Handle different response formats
        result = response["result"]
        if not isinstance(result, dict):
            return str(result)
        # New format with 'result' field
        if "result" in result:
            return result["result"]
        # Content array format
        if "content" in result:
            return self._content_to_text(result["content"])
        return str(result)

    async def close(self):
        """Terminate the server subprocess and release its pipes.

        Safe to call when the server is not running. The process gets five
        seconds to exit after ``terminate()`` before it is killed. Waiting
        happens in a worker thread so the event loop is not blocked.
        """
        proc = self.process
        self.process = None
        self._initialized = False
        if proc is None:
            return

        logger.info(f"Closing MCP server: {self.server_name}")
        try:
            if proc.stdin:
                proc.stdin.close()
        except (OSError, ValueError):
            pass  # Pipe already broken; the process is being stopped anyway.
        try:
            proc.terminate()
        except OSError:
            pass  # Process already gone.
        try:
            await asyncio.to_thread(proc.wait, 5)
        except subprocess.TimeoutExpired:
            proc.kill()
            await asyncio.to_thread(proc.wait)
        finally:
            # stderr is closed by the drain thread once it reaches EOF.
            try:
                if proc.stdout:
                    proc.stdout.close()
            except (OSError, ValueError):
                pass
class MCPClientManager:
    """Registry of named MCP clients with fan-out helpers.

    Each entry maps a server name to a started client object exposing
    ``list_tools()``, ``call_tool()``, ``close()`` and ``script_path``.
    """

    def __init__(self):
        """Create an empty registry."""
        # Server name -> started client.
        self.clients: Dict[str, MCPClient] = {}

    async def add_server(self, name: str, script_path: str):
        """Start an MCP server and register it under ``name``.

        If a client is already registered under ``name`` it is closed after
        the new one has started, so its subprocess is not leaked. If starting
        the new server fails, the existing registration is left untouched.

        Raises:
            Exception: Whatever the client's ``start()`` raises.
        """
        client = MCPClient(script_path, name)
        await client.start()
        previous = self.clients.get(name)
        self.clients[name] = client
        if previous is not None and previous is not client:
            await self._close_quietly(name, previous)
        logger.info(f"Added MCP server: {name}")

    async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Call ``tool_name`` on the server registered as ``server_name``.

        Raises:
            ValueError: If no server is registered under ``server_name``.
        """
        if server_name not in self.clients:
            raise ValueError(f"Server not found: {server_name}")
        return await self.clients[server_name].call_tool(tool_name, arguments)

    @staticmethod
    async def _list_tagged_tools(name: str, client: Any) -> List[Dict[str, Any]]:
        """List a client's tools, adding a ``server`` key to each tool dict."""
        tools = await client.list_tools()
        for tool in tools:
            tool['server'] = name  # Add server info to each tool
        return tools

    async def _close_quietly(self, name: str, client: Any) -> None:
        """Close ``client``, logging instead of raising if closing fails."""
        try:
            await client.close()
        except Exception:
            logger.exception("Error while closing MCP server %s", name)

    async def _restart_and_list(self, server_name: str) -> Optional[List[Dict[str, Any]]]:
        """Restart a failed server and list its tools.

        Returns:
            The tagged tool list after a successful restart, or None if the
            server has no known script path or the restart failed. After a
            failed restart the server is closed and removed from the registry.
        """
        client = self.clients.get(server_name)
        script_path = getattr(client, 'script_path', None)
        if not script_path:
            return None
        try:
            logger.info(f"Attempting to restart {server_name} server...")
            await client.close()
            self.clients.pop(server_name, None)
            # Re-add the server (which will restart it)
            await self.add_server(server_name, script_path)
            # Try listing tools again after restart
            tools = await self._list_tagged_tools(server_name, self.clients[server_name])
            logger.info(f"Successfully restarted {server_name} server")
            return tools
        except Exception as restart_error:
            logger.error(f"Failed to restart {server_name}: {restart_error}")
            # Remove the failed server from clients to prevent further errors,
            # closing whatever client is registered so no subprocess lingers.
            leftover = self.clients.pop(server_name, None)
            if leftover is not None:
                await self._close_quietly(server_name, leftover)
            return None

    async def list_all_tools(self) -> Dict[str, List[Dict[str, Any]]]:
        """List tools from all servers, handling failures gracefully.

        A server whose listing fails gets an empty list and one restart
        attempt. If the restart succeeds, its tools replace the empty list.

        Returns:
            Mapping of server name to its tool dicts, each tagged with a
            ``server`` key.
        """
        all_tools: Dict[str, List[Dict[str, Any]]] = {}
        failed_servers: List[str] = []
        # Iterate over a snapshot: the registry may change while we await.
        for name, client in list(self.clients.items()):
            try:
                all_tools[name] = await self._list_tagged_tools(name, client)
            except Exception as e:
                logger.error(f"Failed to list tools from server {name}: {e}")
                failed_servers.append(name)
                # Continue with other servers instead of failing entirely
                all_tools[name] = []

        if failed_servers:
            logger.warning(f"Some servers failed to respond: {', '.join(failed_servers)}")
            # Try to restart failed servers
            for server_name in failed_servers:
                tools = await self._restart_and_list(server_name)
                if tools is not None:
                    all_tools[server_name] = tools
        return all_tools

    async def close_all(self):
        """Close every registered client and empty the registry.

        A failure to close one client is logged and does not stop the
        remaining clients from being closed.
        """
        for name, client in list(self.clients.items()):
            await self._close_quietly(name, client)
        self.clients.clear()
        logger.info("All MCP servers closed")