#!/usr/bin/env python3
"""
Parallel tool execution optimization for ALS Research Agent

This module replaces sequential tool execution with parallel execution
to reduce response time by ~60-70% for multi-tool queries.
"""
import asyncio
import logging
import re
from typing import List, Dict, Tuple, Any

logger = logging.getLogger(__name__)
async def execute_single_tool(
    tool_call: Dict,
    call_mcp_tool_func,
    index: int
) -> Tuple[int, str, Dict]:
    """
    Execute a single tool call asynchronously.

    Returns (index, progress_text, result_dict) so callers can restore the
    original order after parallel execution.
    """
    tool_name = tool_call["name"]
    tool_args = tool_call["input"]

    # Show search info in progress text
    tool_display = tool_name.replace('__', ' → ')
    search_info = ""
    if "query" in tool_args:
        search_info = f" `{tool_args['query'][:50]}{'...' if len(tool_args['query']) > 50 else ''}`"
    elif "condition" in tool_args:
        search_info = f" `{tool_args['condition'][:50]}{'...' if len(tool_args['condition']) > 50 else ''}`"

    try:
        # Call the MCP tool and time how long it takes
        start_time = asyncio.get_event_loop().time()
        tool_result = await call_mcp_tool_func(tool_name, tool_args)
        elapsed = asyncio.get_event_loop().time() - start_time
        logger.info(f"Tool {tool_name} completed in {elapsed:.2f}s")

        # Check for zero results to provide clear indicators
        has_results = True
        results_count = 0
        if isinstance(tool_result, str):
            result_lower = tool_result.lower()

            # Check for an explicit result count, e.g. "Found 12 papers"
            count_matches = re.findall(r'found (\d+) (?:papers?|trials?|preprints?|results?)', result_lower)
            if count_matches:
                results_count = int(count_matches[0])

            # Flag as no-results on explicit phrasing, or when a count was
            # reported and it is zero (an absent count is not treated as zero)
            no_results_phrases = [
                "no results found", "0 results", "no papers found",
                "no trials found", "no preprints found", "not found",
                "zero results", "no matches"
            ]
            if any(phrase in result_lower for phrase in no_results_phrases) or (count_matches and results_count == 0):
                has_results = False

        # Create a clear success/failure indicator
        if has_results:
            if results_count > 0:
                progress_text = f"\n✅ **Found {results_count} results:** {tool_display}{search_info}"
            else:
                progress_text = f"\n✅ **Success:** {tool_display}{search_info}"
        else:
            progress_text = f"\n⚠️ **No results:** {tool_display}{search_info} - will try alternatives"

        # Add timing for long operations
        if elapsed > 5:
            progress_text += f" (took {elapsed:.1f}s)"

        # On zero results, append a self-correction hint so the agent can retry
        if not has_results:
            tool_result += "\n\n**SELF-CORRECTION HINT:** No results found with this query. Consider:\n"
            tool_result += "1. Broadening search terms (remove qualifiers)\n"
            tool_result += "2. Using alternative terminology or synonyms\n"
            tool_result += "3. Searching related concepts\n"
            tool_result += "4. Checking for typos in search terms"

        result_dict = {
            "type": "tool_result",
            "tool_use_id": tool_call["id"],
            "content": tool_result
        }
        return index, progress_text, result_dict

    except Exception as e:
        logger.error(f"Error executing tool {tool_name}: {e}")

        # Clear failure indicator for errors
        progress_text = f"\n❌ **Failed:** {tool_display}{search_info} - {str(e)[:50]}"
        error_result = {
            "type": "tool_result",
            "tool_use_id": tool_call["id"],
            "content": f"Error executing tool: {str(e)}"
        }
        return index, progress_text, error_result
async def execute_tool_calls_parallel(
    tool_calls: List[Dict],
    call_mcp_tool_func
) -> Tuple[str, List[Dict]]:
    """
    Execute tool calls in parallel and collect results.
    Maintains the original order of tool calls in results.

    Returns: (progress_text, tool_results_content)
    """
    if not tool_calls:
        return "", []

    # Track execution time for progress reporting
    start_time = asyncio.get_event_loop().time()
    logger.info(f"Executing {len(tool_calls)} tools in parallel")

    # Create one task per tool call and run them concurrently
    tasks = [
        execute_single_tool(tool_call, call_mcp_tool_func, i)
        for i, tool_call in enumerate(tool_calls)
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    # Sort successful results by index to maintain the original order
    sorted_results = sorted(
        [r for r in results if not isinstance(r, Exception)],
        key=lambda x: x[0]
    )

    # Build a progress summary with timing info
    completed_count = len(sorted_results)
    total_count = len(tool_calls)
    elapsed_time = asyncio.get_event_loop().time() - start_time
    timing_info = f" in {elapsed_time:.1f}s" if elapsed_time > 5 else ""
    progress_text = f"\n📊 **Search Progress:** Completed {completed_count}/{total_count} searches{timing_info}\n"

    tool_results_content = []
    for index, prog_text, result_dict in sorted_results:
        progress_text += prog_text
        tool_results_content.append(result_dict)

    # Handle any exceptions that escaped execute_single_tool
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            logger.error(f"Task {i} failed with exception: {result}")
            # Insert an error result at the failed task's original position
            if i < len(tool_calls):
                tool_results_content.insert(i, {
                    "type": "tool_result",
                    "tool_use_id": tool_calls[i]["id"],
                    "content": f"Tool execution failed: {str(result)}"
                })

    return progress_text, tool_results_content
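
# --- Illustrative usage sketch ---
# A minimal example of how execute_tool_calls_parallel is expected to be
# driven. The mock tool function and the tool names/ids below are
# hypothetical placeholders, not part of the production agent; the only
# assumption is that call_mcp_tool_func is an async callable taking
# (tool_name, tool_args) and returning a string, matching how it is
# awaited in execute_single_tool above.
async def _demo_parallel_execution() -> None:
    async def mock_call_mcp_tool(tool_name: str, tool_args: Dict) -> str:
        # Simulate tool latency so the parallel speedup is observable
        await asyncio.sleep(0.2)
        query = tool_args.get("query") or tool_args.get("condition", "")
        return f"Found 3 papers matching '{query}' via {tool_name}"

    demo_calls = [
        {"id": "demo_1", "name": "pubmed__search_papers", "input": {"query": "ALS biomarkers"}},
        {"id": "demo_2", "name": "clinicaltrials__search_trials", "input": {"condition": "amyotrophic lateral sclerosis"}},
    ]
    progress, results = await execute_tool_calls_parallel(demo_calls, mock_call_mcp_tool)
    print(progress)
    print(f"Collected {len(results)} tool result blocks")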
# Backward compatibility wrapper
async def execute_tool_calls_optimized(
    tool_calls: List[Dict],
    call_mcp_tool_func,
    parallel: bool = True
) -> Tuple[str, List[Dict]]:
    """
    Execute tool calls with optional parallel execution.

    Args:
        tool_calls: List of tool calls to execute
        call_mcp_tool_func: Function to call MCP tools
        parallel: If True, execute tools in parallel; if False, execute sequentially

    Returns: (progress_text, tool_results_content)
    """
    if parallel and len(tool_calls) > 1:
        # Use parallel execution for multiple tools
        return await execute_tool_calls_parallel(tool_calls, call_mcp_tool_func)
    else:
        # Fall back to the original sequential implementation
        from refactored_helpers import execute_tool_calls
        return await execute_tool_calls(tool_calls, call_mcp_tool_func)
def estimate_time_savings(num_tools: int, avg_tool_time: float = 3.5) -> Dict[str, float]:
    """
    Estimate time savings from parallel execution.

    Args:
        num_tools: Number of tools to execute
        avg_tool_time: Average time per tool in seconds

    Returns: Dictionary with timing estimates
    """
    sequential_time = num_tools * avg_tool_time
    # Parallel time is roughly the time of the slowest tool plus overhead
    parallel_time = avg_tool_time + 0.5  # 0.5s overhead for coordination
    savings = sequential_time - parallel_time
    savings_percent = (savings / sequential_time) * 100 if sequential_time > 0 else 0

    return {
        "sequential_time": sequential_time,
        "parallel_time": parallel_time,
        "time_saved": savings,
        "savings_percent": savings_percent
    }
# Test the optimization
if __name__ == "__main__":
    # Test time savings estimation
    for n in [2, 3, 4, 5]:
        estimates = estimate_time_savings(n)
        print(f"\n{n} tools:")
        print(f"  Sequential: {estimates['sequential_time']:.1f}s")
        print(f"  Parallel: {estimates['parallel_time']:.1f}s")
        print(f"  Savings: {estimates['time_saved']:.1f}s ({estimates['savings_percent']:.0f}%)")
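
    # Run the illustrative parallel-execution sketch defined above. It uses a
    # mock tool function and hypothetical tool names, so it exercises only the
    # orchestration logic, not any real MCP server.
    asyncio.run(_demo_parallel_execution())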