#!/usr/bin/env python3
"""
Helper functions to consolidate duplicate code in als_agent_app.py
Refactored to improve efficiency and reduce redundancy
"""
import asyncio
import httpx
import logging
import os
from typing import AsyncGenerator, List, Dict, Any, Optional, Tuple
from llm_client import UnifiedLLMClient
logger = logging.getLogger(__name__)


async def stream_with_retry(
client,
messages: List[Dict],
tools: List[Dict],
system_prompt: str,
max_retries: int = 2,
    model: Optional[str] = None,
max_tokens: int = 8192,
stream_name: str = "API call",
temperature: float = 0.7
) -> AsyncGenerator[Tuple[str, List[Dict], str], None]:
"""
Simplified wrapper that delegates to UnifiedLLMClient.
The client parameter can be:
- An Anthropic client (for backward compatibility)
- A UnifiedLLMClient instance
- None (will create a UnifiedLLMClient)
Yields: (response_text, tool_calls, provider_used) tuples
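    Example (illustrative sketch; assumes Anthropic-style message/tool schemas
    and that UnifiedLLMClient reads provider credentials from the environment):
        async def demo():
            messages = [{"role": "user", "content": "Summarize recent ALS trials."}]
            async for text, tool_calls, provider in stream_with_retry(
                client=None,           # helper creates and cleans up a UnifiedLLMClient
                messages=messages,
                tools=[],              # no MCP tools exposed in this sketch
                system_prompt="You are a research assistant.",
                stream_name="demo call",
            ):
                print(f"[{provider}] {text}", end="")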
"""
# If client is None or is an Anthropic client, use UnifiedLLMClient
if client is None or not hasattr(client, 'stream'):
# Create or get a UnifiedLLMClient instance
llm_client = UnifiedLLMClient()
logger.info(f"Using {llm_client.get_provider_display_name()} for {stream_name}")
try:
# Use the unified client's stream method
async for text, tool_calls, provider in llm_client.stream(
messages=messages,
tools=tools,
system_prompt=system_prompt,
model=model,
max_tokens=max_tokens,
temperature=temperature
):
yield (text, tool_calls, provider)
finally:
# Clean up if we created the client
await llm_client.cleanup()
else:
# Client is already a UnifiedLLMClient
logger.info(f"Using provided {client.get_provider_display_name()} for {stream_name}")
async for text, tool_calls, provider in client.stream(
messages=messages,
tools=tools,
system_prompt=system_prompt,
model=model,
max_tokens=max_tokens,
temperature=temperature
):
yield (text, tool_calls, provider)


async def execute_tool_calls(
tool_calls: List[Dict],
call_mcp_tool_func
) -> Tuple[str, List[Dict]]:
"""
Execute tool calls and collect results.
Consolidates duplicate tool execution logic.
Now includes self-correction hints for zero results.
Returns: (progress_text, tool_results_content)
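    Example (illustrative sketch; ``fake_mcp_call`` is a stand-in for the real
    MCP dispatch coroutine used by als_agent_app.py, and the tool name is
    hypothetical):
        async def fake_mcp_call(name: str, args: dict) -> str:
            return "No results found"  # forces the self-correction path
        tool_calls = [{
            "id": "tu_1",
            "name": "pubmed__search_papers",
            "input": {"query": "ALS biomarkers"},
        }]
        progress, results = await execute_tool_calls(tool_calls, fake_mcp_call)
        # results[0]["content"] now ends with the SELF-CORRECTION HINT block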
"""
progress_text = ""
tool_results_content = []
zero_result_tools = []
for tool_call in tool_calls:
tool_name = tool_call["name"]
tool_args = tool_call["input"]
# Single, clean execution marker with search info
tool_display = tool_name.replace('__', ' → ')
# Show key search parameters
search_info = ""
if "query" in tool_args:
search_info = f" `{tool_args['query'][:50]}{'...' if len(tool_args['query']) > 50 else ''}`"
elif "condition" in tool_args:
search_info = f" `{tool_args['condition'][:50]}{'...' if len(tool_args['condition']) > 50 else ''}`"
progress_text += f"\n🔧 **Searching:** {tool_display}{search_info}\n"
# Call MCP tool
tool_result = await call_mcp_tool_func(tool_name, tool_args)
# Check for zero results to enable self-correction
if isinstance(tool_result, str):
result_lower = tool_result.lower()
if any(phrase in result_lower for phrase in [
"no results found", "0 results", "no papers found",
"no trials found", "no preprints found", "not found",
"zero results", "no matches"
]):
zero_result_tools.append((tool_name, tool_args))
# Add self-correction hint to the result
tool_result += "\n\n**SELF-CORRECTION HINT:** No results found with this query. Consider:\n"
tool_result += "1. Broadening search terms (remove qualifiers)\n"
tool_result += "2. Using alternative terminology or synonyms\n"
tool_result += "3. Searching related concepts\n"
tool_result += "4. Checking for typos in search terms"
# Add to results array
tool_results_content.append({
"type": "tool_result",
"tool_use_id": tool_call["id"],
"content": tool_result
})
return progress_text, tool_results_content


def build_assistant_message(
text_content: str,
tool_calls: List[Dict],
    strip_markers: Optional[List[str]] = None
) -> List[Dict]:
"""
Build assistant message content with text and tool uses.
Consolidates duplicate message building logic.
Args:
text_content: Text content to include
tool_calls: List of tool calls to include
strip_markers: Optional list of text markers to strip from content
Returns: List of content blocks for assistant message
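    Example (illustrative; the tool id, tool name, and marker are placeholders):
        content = build_assistant_message(
            "Here is what I found [DONE]",
            [{"id": "tu_1", "name": "pubmed__search_papers", "input": {"query": "ALS"}}],
            strip_markers=["[DONE]"],
        )
        # -> [{"type": "text", "text": "Here is what I found"},
        #     {"type": "tool_use", "id": "tu_1", "name": "pubmed__search_papers",
        #      "input": {"query": "ALS"}}]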
"""
assistant_content = []
# Process text content
if text_content and text_content.strip():
processed_text = text_content
# Strip any specified markers
if strip_markers:
for marker in strip_markers:
processed_text = processed_text.replace(marker, "")
processed_text = processed_text.strip()
if processed_text:
assistant_content.append({
"type": "text",
"text": processed_text
})
# Add tool uses
for tc in tool_calls:
assistant_content.append({
"type": "tool_use",
"id": tc["id"],
"name": tc["name"],
"input": tc["input"]
})
return assistant_content


def should_continue_iterations(
iteration_count: int,
max_iterations: int,
tool_calls: List[Dict]
) -> bool:
"""
Check if tool iterations should continue.
Centralizes iteration control logic.
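    Example (illustrative):
        should_continue_iterations(1, 5, [{"id": "tu_1"}])  # True
        should_continue_iterations(5, 5, [{"id": "tu_1"}])  # False (limit reached, logs a warning)
        should_continue_iterations(3, 5, [])                # False (no pending tool calls)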
"""
if not tool_calls:
return False
if iteration_count >= max_iterations:
logger.warning(f"Reached maximum tool iterations ({max_iterations})")
return False
return True |