#!/usr/bin/env python3 """ Helper functions to consolidate duplicate code in als_agent_app.py Refactored to improve efficiency and reduce redundancy """ import asyncio import httpx import logging import os from typing import AsyncGenerator, List, Dict, Any, Optional, Tuple from llm_client import UnifiedLLMClient logger = logging.getLogger(__name__) async def stream_with_retry( client, messages: List[Dict], tools: List[Dict], system_prompt: str, max_retries: int = 2, model: str = None, max_tokens: int = 8192, stream_name: str = "API call", temperature: float = 0.7 ) -> AsyncGenerator[Tuple[str, List[Dict], str], None]: """ Simplified wrapper that delegates to UnifiedLLMClient. The client parameter can be: - An Anthropic client (for backward compatibility) - A UnifiedLLMClient instance - None (will create a UnifiedLLMClient) Yields: (response_text, tool_calls, provider_used) tuples """ # If client is None or is an Anthropic client, use UnifiedLLMClient if client is None or not hasattr(client, 'stream'): # Create or get a UnifiedLLMClient instance llm_client = UnifiedLLMClient() logger.info(f"Using {llm_client.get_provider_display_name()} for {stream_name}") try: # Use the unified client's stream method async for text, tool_calls, provider in llm_client.stream( messages=messages, tools=tools, system_prompt=system_prompt, model=model, max_tokens=max_tokens, temperature=temperature ): yield (text, tool_calls, provider) finally: # Clean up if we created the client await llm_client.cleanup() else: # Client is already a UnifiedLLMClient logger.info(f"Using provided {client.get_provider_display_name()} for {stream_name}") async for text, tool_calls, provider in client.stream( messages=messages, tools=tools, system_prompt=system_prompt, model=model, max_tokens=max_tokens, temperature=temperature ): yield (text, tool_calls, provider) async def execute_tool_calls( tool_calls: List[Dict], call_mcp_tool_func ) -> Tuple[str, List[Dict]]: """ Execute tool calls and collect results. Consolidates duplicate tool execution logic. Now includes self-correction hints for zero results. Returns: (progress_text, tool_results_content) """ progress_text = "" tool_results_content = [] zero_result_tools = [] for tool_call in tool_calls: tool_name = tool_call["name"] tool_args = tool_call["input"] # Single, clean execution marker with search info tool_display = tool_name.replace('__', ' → ') # Show key search parameters search_info = "" if "query" in tool_args: search_info = f" `{tool_args['query'][:50]}{'...' if len(tool_args['query']) > 50 else ''}`" elif "condition" in tool_args: search_info = f" `{tool_args['condition'][:50]}{'...' if len(tool_args['condition']) > 50 else ''}`" progress_text += f"\nšŸ”§ **Searching:** {tool_display}{search_info}\n" # Call MCP tool tool_result = await call_mcp_tool_func(tool_name, tool_args) # Check for zero results to enable self-correction if isinstance(tool_result, str): result_lower = tool_result.lower() if any(phrase in result_lower for phrase in [ "no results found", "0 results", "no papers found", "no trials found", "no preprints found", "not found", "zero results", "no matches" ]): zero_result_tools.append((tool_name, tool_args)) # Add self-correction hint to the result tool_result += "\n\n**SELF-CORRECTION HINT:** No results found with this query. Consider:\n" tool_result += "1. Broadening search terms (remove qualifiers)\n" tool_result += "2. Using alternative terminology or synonyms\n" tool_result += "3. Searching related concepts\n" tool_result += "4. Checking for typos in search terms" # Add to results array tool_results_content.append({ "type": "tool_result", "tool_use_id": tool_call["id"], "content": tool_result }) return progress_text, tool_results_content def build_assistant_message( text_content: str, tool_calls: List[Dict], strip_markers: List[str] = None ) -> List[Dict]: """ Build assistant message content with text and tool uses. Consolidates duplicate message building logic. Args: text_content: Text content to include tool_calls: List of tool calls to include strip_markers: Optional list of text markers to strip from content Returns: List of content blocks for assistant message """ assistant_content = [] # Process text content if text_content and text_content.strip(): processed_text = text_content # Strip any specified markers if strip_markers: for marker in strip_markers: processed_text = processed_text.replace(marker, "") processed_text = processed_text.strip() if processed_text: assistant_content.append({ "type": "text", "text": processed_text }) # Add tool uses for tc in tool_calls: assistant_content.append({ "type": "tool_use", "id": tc["id"], "name": tc["name"], "input": tc["input"] }) return assistant_content def should_continue_iterations( iteration_count: int, max_iterations: int, tool_calls: List[Dict] ) -> bool: """ Check if tool iterations should continue. Centralizes iteration control logic. """ if not tool_calls: return False if iteration_count >= max_iterations: logger.warning(f"Reached maximum tool iterations ({max_iterations})") return False return True