Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -442,38 +442,50 @@ def run_agent(
|
|
| 442 |
# MAIN INTERFACE FUNCTION
|
| 443 |
# =============================================================================
|
| 444 |
|
| 445 |
-
def
|
| 446 |
-
"""
|
| 447 |
-
messages
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
def convert_messages_to_tuples(messages: List[Dict]) -> List[Tuple]:
|
| 458 |
-
"""Convert messages format back to Gradio tuple format"""
|
| 459 |
-
tuples = []
|
| 460 |
-
i = 0
|
| 461 |
-
while i < len(messages):
|
| 462 |
-
user_msg = ""
|
| 463 |
-
bot_msg = ""
|
| 464 |
-
|
| 465 |
-
if i < len(messages) and messages[i]["role"] == "user":
|
| 466 |
-
user_msg = messages[i]["content"]
|
| 467 |
-
i += 1
|
| 468 |
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
|
| 478 |
|
| 479 |
def get_second_opinion(
|
|
@@ -482,49 +494,46 @@ def get_second_opinion(
|
|
| 482 |
provider: str,
|
| 483 |
model: str,
|
| 484 |
api_key: str,
|
| 485 |
-
chatbot_history
|
| 486 |
temperature: float = 0.7,
|
| 487 |
max_tokens: int = 4000
|
| 488 |
-
) -> Tuple[str, List[
|
| 489 |
"""
|
| 490 |
Get a second opinion from the AI agent using MCP tools
|
| 491 |
|
| 492 |
Returns:
|
| 493 |
-
Tuple of (response,
|
| 494 |
"""
|
|
|
|
|
|
|
|
|
|
| 495 |
if not api_key:
|
| 496 |
env_key = LLM_PROVIDERS.get(provider, {}).get("env_key", "")
|
| 497 |
api_key = os.environ.get(env_key, "")
|
| 498 |
if not api_key:
|
| 499 |
error_msg = f"⚠️ API key required. Please enter your {provider} API key or set {env_key} in HuggingFace Spaces Settings."
|
| 500 |
-
return error_msg,
|
| 501 |
|
| 502 |
client, error = get_client(provider, api_key)
|
| 503 |
if error:
|
| 504 |
-
return f"⚠️ Error: {error}",
|
| 505 |
-
|
| 506 |
-
# Convert tuple history to messages format for agent
|
| 507 |
-
messages_history = convert_tuples_to_messages(chatbot_history)
|
| 508 |
|
| 509 |
# Run the agent
|
| 510 |
-
response,
|
| 511 |
user_input=user_input,
|
| 512 |
persona=persona,
|
| 513 |
client=client,
|
| 514 |
provider=provider,
|
| 515 |
model=model,
|
| 516 |
-
conversation_history=
|
| 517 |
temperature=temperature,
|
| 518 |
max_tokens=max_tokens
|
| 519 |
)
|
| 520 |
|
| 521 |
-
# Convert back to tuple format for Gradio
|
| 522 |
-
updated_tuples = convert_messages_to_tuples(updated_messages)
|
| 523 |
-
|
| 524 |
# Format tool log for display
|
| 525 |
tool_log_display = "\n".join(tool_log) if tool_log else "No tools called"
|
| 526 |
|
| 527 |
-
return response,
|
| 528 |
|
| 529 |
|
| 530 |
# =============================================================================
|
|
@@ -912,7 +921,7 @@ def create_interface():
|
|
| 912 |
def chat_interaction(user_msg, persona, provider, model, api_key,
|
| 913 |
history, temp, max_tok):
|
| 914 |
if not user_msg.strip():
|
| 915 |
-
return history, "", "No input provided"
|
| 916 |
|
| 917 |
response, updated_history, tool_log = get_second_opinion(
|
| 918 |
user_msg,
|
|
@@ -920,16 +929,17 @@ def create_interface():
|
|
| 920 |
provider,
|
| 921 |
model,
|
| 922 |
api_key,
|
| 923 |
-
history
|
| 924 |
temperature=temp,
|
| 925 |
max_tokens=max_tok
|
| 926 |
)
|
| 927 |
|
| 928 |
if response.startswith("⚠️"):
|
| 929 |
-
# Error occurred - append to history
|
| 930 |
-
error_history =
|
| 931 |
-
error_history.append(
|
| 932 |
-
|
|
|
|
| 933 |
|
| 934 |
return updated_history, "", tool_log
|
| 935 |
|
|
|
|
| 442 |
# MAIN INTERFACE FUNCTION
|
| 443 |
# =============================================================================
|
| 444 |
|
| 445 |
+
def normalize_history(history) -> List[Dict]:
|
| 446 |
+
"""
|
| 447 |
+
Convert history to messages format regardless of input format.
|
| 448 |
+
Handles: None, empty list, tuples format, messages format
|
| 449 |
+
"""
|
| 450 |
+
if history is None or history == []:
|
| 451 |
+
return []
|
| 452 |
+
|
| 453 |
+
# Check if it's already in messages format
|
| 454 |
+
if isinstance(history, list) and len(history) > 0:
|
| 455 |
+
first_item = history[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
|
| 457 |
+
# Messages format: [{"role": "...", "content": "..."}, ...]
|
| 458 |
+
if isinstance(first_item, dict) and "role" in first_item:
|
| 459 |
+
return [msg for msg in history if isinstance(msg, dict) and "role" in msg and "content" in msg]
|
| 460 |
|
| 461 |
+
# Tuple format: [("user msg", "assistant msg"), ...]
|
| 462 |
+
if isinstance(first_item, (tuple, list)) and len(first_item) == 2:
|
| 463 |
+
messages = []
|
| 464 |
+
for item in history:
|
| 465 |
+
if isinstance(item, (tuple, list)) and len(item) == 2:
|
| 466 |
+
user_msg, assistant_msg = item
|
| 467 |
+
if user_msg:
|
| 468 |
+
messages.append({"role": "user", "content": str(user_msg)})
|
| 469 |
+
if assistant_msg:
|
| 470 |
+
messages.append({"role": "assistant", "content": str(assistant_msg)})
|
| 471 |
+
return messages
|
| 472 |
+
|
| 473 |
+
return []
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def format_history_for_gradio(messages: List[Dict]) -> List[Dict]:
|
| 477 |
+
"""
|
| 478 |
+
Format messages for Gradio Chatbot (messages format).
|
| 479 |
+
Ensures all messages have proper structure.
|
| 480 |
+
"""
|
| 481 |
+
result = []
|
| 482 |
+
for msg in messages:
|
| 483 |
+
if isinstance(msg, dict) and "role" in msg and "content" in msg:
|
| 484 |
+
result.append({
|
| 485 |
+
"role": str(msg["role"]),
|
| 486 |
+
"content": str(msg["content"]) if msg["content"] else ""
|
| 487 |
+
})
|
| 488 |
+
return result
|
| 489 |
|
| 490 |
|
| 491 |
def get_second_opinion(
|
|
|
|
| 494 |
provider: str,
|
| 495 |
model: str,
|
| 496 |
api_key: str,
|
| 497 |
+
chatbot_history, # Can be any format
|
| 498 |
temperature: float = 0.7,
|
| 499 |
max_tokens: int = 4000
|
| 500 |
+
) -> Tuple[str, List[Dict], str]:
|
| 501 |
"""
|
| 502 |
Get a second opinion from the AI agent using MCP tools
|
| 503 |
|
| 504 |
Returns:
|
| 505 |
+
Tuple of (response, updated_history, tool_log_display)
|
| 506 |
"""
|
| 507 |
+
# Normalize history to messages format
|
| 508 |
+
history = normalize_history(chatbot_history)
|
| 509 |
+
|
| 510 |
if not api_key:
|
| 511 |
env_key = LLM_PROVIDERS.get(provider, {}).get("env_key", "")
|
| 512 |
api_key = os.environ.get(env_key, "")
|
| 513 |
if not api_key:
|
| 514 |
error_msg = f"⚠️ API key required. Please enter your {provider} API key or set {env_key} in HuggingFace Spaces Settings."
|
| 515 |
+
return error_msg, format_history_for_gradio(history), ""
|
| 516 |
|
| 517 |
client, error = get_client(provider, api_key)
|
| 518 |
if error:
|
| 519 |
+
return f"⚠️ Error: {error}", format_history_for_gradio(history), ""
|
|
|
|
|
|
|
|
|
|
| 520 |
|
| 521 |
# Run the agent
|
| 522 |
+
response, updated_history, tool_log = run_agent(
|
| 523 |
user_input=user_input,
|
| 524 |
persona=persona,
|
| 525 |
client=client,
|
| 526 |
provider=provider,
|
| 527 |
model=model,
|
| 528 |
+
conversation_history=history,
|
| 529 |
temperature=temperature,
|
| 530 |
max_tokens=max_tokens
|
| 531 |
)
|
| 532 |
|
|
|
|
|
|
|
|
|
|
| 533 |
# Format tool log for display
|
| 534 |
tool_log_display = "\n".join(tool_log) if tool_log else "No tools called"
|
| 535 |
|
| 536 |
+
return response, format_history_for_gradio(updated_history), tool_log_display
|
| 537 |
|
| 538 |
|
| 539 |
# =============================================================================
|
|
|
|
| 921 |
def chat_interaction(user_msg, persona, provider, model, api_key,
|
| 922 |
history, temp, max_tok):
|
| 923 |
if not user_msg.strip():
|
| 924 |
+
return format_history_for_gradio(normalize_history(history)), "", "No input provided"
|
| 925 |
|
| 926 |
response, updated_history, tool_log = get_second_opinion(
|
| 927 |
user_msg,
|
|
|
|
| 929 |
provider,
|
| 930 |
model,
|
| 931 |
api_key,
|
| 932 |
+
history,
|
| 933 |
temperature=temp,
|
| 934 |
max_tokens=max_tok
|
| 935 |
)
|
| 936 |
|
| 937 |
if response.startswith("⚠️"):
|
| 938 |
+
# Error occurred - append error to history
|
| 939 |
+
error_history = normalize_history(history)
|
| 940 |
+
error_history.append({"role": "user", "content": user_msg})
|
| 941 |
+
error_history.append({"role": "assistant", "content": response})
|
| 942 |
+
return format_history_for_gradio(error_history), "", tool_log or "Error occurred"
|
| 943 |
|
| 944 |
return updated_history, "", tool_log
|
| 945 |
|