# agent_logic.py (Fixed: OpenAI and Nebius Initialization + Robust Unpacking)
import asyncio
from typing import AsyncGenerator, Dict, Optional, List, Tuple
import json
import os
import uuid
import datetime
import google.generativeai as genai
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
import re
from personas import PERSONAS_DATA
import config
from utils import load_prompt
from mcp_servers import AgentCalibrator, BusinessSolutionEvaluator, get_llm_response
from self_correction import SelfCorrector
CLASSIFIER_SYSTEM_PROMPT = load_prompt(config.PROMPT_FILES["classifier"])
HOMOGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_homogeneous"])
HETEROGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_heterogeneous"])
METRIC_MAPPING = {
    "novelty": "Novelty",
    "usefulness": "Usefulness_Feasibility",
    "feasibility": "Usefulness_Feasibility",
    "usefulness_feasibility": "Usefulness_Feasibility",
    "usefulness/feasibility": "Usefulness_Feasibility",
    "flexibility": "Flexibility",
    "elaboration": "Elaboration",
    "cultural_appropriateness": "Cultural_Appropriateness",
    "cultural_sensitivity": "Cultural_Appropriateness",
    "cultural appropriateness": "Cultural_Appropriateness",
    "cultural appropriateness/sensitivity": "Cultural_Appropriateness"
}
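# Example: an evaluator key like "Usefulness/Feasibility" is lowercased and
# stripped in solve() before lookup, so it normalizes to "Usefulness_Feasibility".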
# --- HELPER CLASSES ---
class Baseline_Single_Agent:
    def __init__(self, api_clients: dict):
        self.gemini_client = api_clients.get("Gemini")

    async def solve(self, problem: str, persona_prompt: str) -> Tuple[str, dict]:
        if not self.gemini_client:
            raise ValueError("Single_Agent requires a Google/Gemini client.")
        return await get_llm_response("Gemini", self.gemini_client, persona_prompt, problem)

class Baseline_Static_Homogeneous:
    def __init__(self, api_clients: dict):
        self.api_clients = {name: client for name, client in api_clients.items() if client}
        self.gemini_client = api_clients.get("Gemini")

    async def solve(self, problem: str, persona_prompt: str) -> Tuple[str, List[dict]]:
        if not self.gemini_client:
            raise ValueError("Homogeneous_Team requires a Google/Gemini client.")
        system_prompt = persona_prompt
        user_prompt = f"As an expert Implementer, generate a detailed plan for this problem: {problem}"
        tasks = [get_llm_response(llm, client, system_prompt, user_prompt) for llm, client in self.api_clients.items()]
        results = await asyncio.gather(*tasks)
        responses = [r[0] for r in results]
        usages = [r[1] for r in results]
        manager_system_prompt = HOMOGENEOUS_MANAGER_PROMPT
        reports_str = "\n\n".join(f"Report from Team Member {i+1}:\n{resp}" for i, resp in enumerate(responses))
        manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these reports into one final, comprehensive solution."
        final_text, final_usage = await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)
        usages.append(final_usage)
        return final_text, usages

class Baseline_Static_Heterogeneous:
    def __init__(self, api_clients: dict):
        self.api_clients = api_clients
        self.gemini_client = api_clients.get("Gemini")

    async def solve(self, problem: str, team_plan: dict) -> Tuple[str, List[dict]]:
        if not self.gemini_client:
            raise ValueError("Heterogeneous_Team requires a Google/Gemini client.")
        tasks = []
        for role, config_data in team_plan.items():
            llm_name = config_data["llm"]
            persona_key = config_data["persona"]
            client = self.api_clients.get(llm_name)
            if not client:
                # Fall back to Gemini when the planned provider is unavailable
                llm_name = "Gemini"
                client = self.gemini_client
            system_prompt = PERSONAS_DATA[persona_key]["description"]
            user_prompt = f"As the team's '{role}', provide your unique perspective on how to solve this problem: {problem}"
            tasks.append(get_llm_response(llm_name, client, system_prompt, user_prompt))
        results = await asyncio.gather(*tasks)
        responses = [r[0] for r in results]
        usages = [r[1] for r in results]
        manager_system_prompt = HETEROGENEOUS_MANAGER_PROMPT
        reports_str = "\n\n".join(f"Report from {team_plan[role]['llm']} (as {role}):\n{resp}" for role, resp in zip(team_plan.keys(), responses))
        manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these specialist reports into one final, comprehensive solution."
        final_text, final_usage = await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)
        usages.append(final_usage)
        return final_text, usages
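
# For reference, a team_plan produced by AgentCalibrator.calibrate_team maps each
# role to a provider and persona key, e.g. (illustrative role/persona names):
#   {"Strategist": {"llm": "Anthropic", "persona": "visionary"},
#    "Implementer": {"llm": "Gemini", "persona": "pragmatist"}}
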
# --- MAIN AGENT ---
class StrategicSelectorAgent:
    def __init__(self, api_keys: Dict[str, Optional[str]]):
        self.api_keys = api_keys
        # Initialize potential clients including new providers
        self.api_clients = {
            "Gemini": None,
            "Anthropic": None,
            "SambaNova": None,
            "OpenAI": None,
            "Nebius": None
        }
        # --- INIT GEMINI ---
        if api_keys.get("google") and api_keys["google"].strip():
            try:
                genai.configure(api_key=api_keys["google"])
                self.api_clients["Gemini"] = genai.GenerativeModel(config.MODELS["Gemini"]["default"])
            except Exception as e:
                print(f"Warning: Gemini init failed: {e}")
        # --- INIT ANTHROPIC ---
        if api_keys.get("anthropic") and api_keys["anthropic"].strip():
            try:
                self.api_clients["Anthropic"] = AsyncAnthropic(api_key=api_keys["anthropic"])
            except Exception as e:
                print(f"Warning: Anthropic init failed: {e}")
        # --- INIT SAMBANOVA ---
        if api_keys.get("sambanova") and api_keys["sambanova"].strip():
            try:
                self.api_clients["SambaNova"] = AsyncOpenAI(
                    api_key=api_keys["sambanova"],
                    base_url=os.getenv("SAMBANOVA_BASE_URL", "https://api.sambanova.ai/v1")
                )
            except Exception as e:
                print(f"Warning: SambaNova init failed: {e}")
        # --- INIT OPENAI (NEW) ---
        if api_keys.get("openai") and api_keys["openai"].strip():
            try:
                self.api_clients["OpenAI"] = AsyncOpenAI(api_key=api_keys["openai"])
            except Exception as e:
                print(f"Warning: OpenAI init failed: {e}")
        # --- INIT NEBIUS (NEW) ---
        if api_keys.get("nebius") and api_keys["nebius"].strip():
            try:
                self.api_clients["Nebius"] = AsyncOpenAI(
                    api_key=api_keys["nebius"],
                    base_url="https://api.studio.nebius.ai/v1/"
                )
            except Exception as e:
                print(f"Warning: Nebius init failed: {e}")
        if not self.api_clients["Gemini"]:
            raise ValueError("Google API Key is required.")
        self.evaluator = BusinessSolutionEvaluator(self.api_clients["Gemini"])
        self.calibrator = AgentCalibrator(self.api_clients, self.evaluator)
        self.corrector = SelfCorrector(threshold=3.0)
        self.single_agent = Baseline_Single_Agent(self.api_clients)
        self.homo_team = Baseline_Static_Homogeneous(self.api_clients)
        self.hetero_team = Baseline_Static_Heterogeneous(self.api_clients)
        self.current_team_plan = None
        if "ERROR:" in CLASSIFIER_SYSTEM_PROMPT:
            raise FileNotFoundError(CLASSIFIER_SYSTEM_PROMPT)
    async def solve(self, problem: str) -> AsyncGenerator[str, None]:
        run_id = str(uuid.uuid4())[:8]
        # Initialize Financial Tracking
        financial_report = {
            "calibration_cost": 0.0,
            "generation_cost": 0.0,
            "total_cost": 0.0,
            "usage_breakdown": []
        }
        debug_log = {
            "run_id": run_id,
            "timestamp": datetime.datetime.now().isoformat(),
            "problem": problem,
            "classification": "",
            "trace": [],
            "financial_report": financial_report
        }

        # Helper to add usage and calculate cost
        def add_usage(usage_list):
            if isinstance(usage_list, dict):
                usage_list = [usage_list]
            current_step_cost = 0.0
            for u in usage_list:
                financial_report["usage_breakdown"].append(u)
                # Lookup pricing
                model_name = u.get("model", "Gemini")
                # Safely get pricing with default fallbacks
                pricing = config.PRICING.get(model_name, {"input": 0, "output": 0})
                # Calculate Cost
                cost = (u.get("input", 0) / 1_000_000 * pricing["input"]) + \
                       (u.get("output", 0) / 1_000_000 * pricing["output"])
                financial_report["total_cost"] += cost
                current_step_cost += cost
            return current_step_cost
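
        # Worked example (illustrative prices, not the real config.PRICING values):
        # a usage entry {"model": "Gemini", "input": 1200, "output": 800} priced at
        # $0.075/M input and $0.30/M output costs
        #   1200/1e6 * 0.075 + 800/1e6 * 0.30 = $0.00033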
        try:
            yield "Classifying problem archetype (live)..."
            classification, cls_usage = await get_llm_response("Gemini", self.api_clients["Gemini"], CLASSIFIER_SYSTEM_PROMPT, problem)
            classification = classification.strip().replace("\"", "")
            yield f"Diagnosis: {classification}"
            add_usage(cls_usage)
            debug_log["classification"] = classification
            if "Error generating response" in classification:
                yield "Classifier failed. Defaulting to Single Agent."
                classification = "Direct_Procedure"

            solution_draft = ""
            v_fitness_json = {}
            scores = {}

            # --- MAIN LOOP ---
            for i in range(2):
                current_problem = problem
                if i > 0:
                    yield f"--- (Loop {i}) Score is too low. Initiating Self-Correction... ---"
                    correction_prompt_text = self.corrector.get_correction_plan(v_fitness_json)
                    # Surface the plan's diagnosis line, guarding against plans shorter than four lines
                    plan_lines = correction_prompt_text.splitlines()
                    diagnosis_line = plan_lines[3].strip() if len(plan_lines) > 3 else correction_prompt_text.strip()
                    yield f"Diagnosis: {diagnosis_line}"
                    current_problem = f"{problem}\n\n{correction_prompt_text}"
                    debug_log["trace"].append({
                        "step_type": "correction_plan",
                        "loop_index": i,
                        "prompt": correction_prompt_text
                    })
                # --- DEPLOY ---
                default_persona = PERSONAS_DATA[config.DEFAULT_PERSONA_KEY]["description"]
                current_usages = []
                if classification in ("Direct_Procedure", "Holistic_Abstract_Reasoning"):
                    if i == 0:
                        yield "Deploying: Baseline Single Agent (Simplicity Hypothesis)..."
                    solution_draft, u = await self.single_agent.solve(current_problem, default_persona)
                    current_usages.append(u)
                elif classification == "Local_Geometric_Procedural":
                    if i == 0:
                        yield "Deploying: Static Homogeneous Team (Expert Anomaly)..."
                    solution_draft, u_list = await self.homo_team.solve(current_problem, default_persona)
                    current_usages.extend(u_list)
                elif classification == "Cognitive_Labyrinth":
                    if i == 0:
                        yield "Deploying: Static Heterogeneous Team (Cognitive Diversity)..."
                        # --- UNPACK 4 VALUES FROM CALIBRATOR ---
                        # Safely call and unpack, providing debugging if it fails
                        calib_result = await self.calibrator.calibrate_team(current_problem)
                        if calib_result is None:
                            raise ValueError("CRITICAL ERROR: calibrate_team returned None. Please verify mcp_servers.py on the server.")
                        if len(calib_result) != 4:
                            # Fallback logic if the server has an old mcp_servers.py (e.g. 2 or 3 values)
                            if len(calib_result) == 2:
                                team_plan, calibration_errors = calib_result
                                calib_details, calib_usage = [], []  # Defaults
                            elif len(calib_result) == 3:
                                team_plan, calibration_errors, calib_details = calib_result
                                calib_usage = []  # Default
                            else:
                                raise ValueError(f"Calibrator returned {len(calib_result)} values, expected 4.")
                        else:
                            team_plan, calibration_errors, calib_details, calib_usage = calib_result
                        # Track Calibration Cost explicitly
                        calib_step_cost = add_usage(calib_usage)
                        financial_report["calibration_cost"] += calib_step_cost
                        debug_log["trace"].append({
                            "step_type": "calibration",
                            "details": calib_details,
                            "errors": calibration_errors,
                            "selected_plan": team_plan
                        })
                        if calibration_errors:
                            yield "--- CALIBRATION WARNINGS ---"
                            for err in calibration_errors:
                                yield err
                            yield "-----------------------------"
                        yield f"Calibration complete. Best Team: {json.dumps({k: v['llm'] for k, v in team_plan.items()})}"
                        self.current_team_plan = team_plan
                    solution_draft, u_list = await self.hetero_team.solve(current_problem, self.current_team_plan)
                    current_usages.extend(u_list)
                else:
                    if i == 0:
                        yield f"Diagnosis '{classification}' is unknown. Defaulting to Single Agent."
                    solution_draft, u = await self.single_agent.solve(current_problem, default_persona)
                    current_usages.append(u)
                add_usage(current_usages)
                if "Error generating response" in solution_draft:
                    raise Exception(f"The specialist team failed to generate a solution. Error: {solution_draft}")
                yield f"Draft solution received: '{solution_draft[:60]}...'"

                # --- EVALUATE ---
                yield "Evaluating draft (live)..."
                v_fitness_json, eval_usage = await self.evaluator.evaluate(current_problem, solution_draft)
                add_usage(eval_usage)
                if isinstance(v_fitness_json, list):
                    if len(v_fitness_json) > 0 and isinstance(v_fitness_json[0], dict):
                        v_fitness_json = v_fitness_json[0]
                    else:
                        v_fitness_json = {}
                normalized_fitness = {}
                if isinstance(v_fitness_json, dict):
                    for k, v in v_fitness_json.items():
                        clean_k = k.lower().strip()
                        canonical_key = METRIC_MAPPING.get(clean_k)
                        if not canonical_key:
                            continue
                        if isinstance(v, dict):
                            score_value = v.get('score')
                            justification_value = v.get('justification', str(v))
                        elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
                            score_value = v[0].get('score')
                            justification_value = v[0].get('justification', str(v[0]))
                        else:
                            score_value = v
                            justification_value = "Score extracted directly."
                        if isinstance(score_value, str):
                            # Extract the first integer from string scores such as "4/5"
                            match = re.search(r'\d+', score_value)
                            score_value = int(match.group()) if match else 0
                        try:
                            score_value = int(score_value)
                        except (ValueError, TypeError):
                            score_value = 0
                        normalized_fitness[canonical_key] = {'score': score_value, 'justification': justification_value}
                else:
                    normalized_fitness = {k: {'score': 0, 'justification': "Invalid JSON structure"} for k in ["Novelty", "Usefulness_Feasibility", "Flexibility", "Elaboration", "Cultural_Appropriateness"]}
                v_fitness_json = normalized_fitness
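                # After normalization, v_fitness_json has a uniform shape (illustrative values):
                #   {"Novelty": {"score": 4, "justification": "..."},
                #    "Usefulness_Feasibility": {"score": 3, "justification": "..."}, ...}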
                scores = {k: v.get('score', 0) for k, v in v_fitness_json.items()}
                yield f"Evaluation Score: {scores}"
                debug_log["trace"].append({
                    "step_type": "attempt",
                    "loop_index": i,
                    "draft": solution_draft,
                    "scores": scores,
                    "full_evaluation": v_fitness_json
                })
                if scores.get('Novelty', 0) <= 1:
                    yield f"⚠️ Low Score Detected. Reason: {v_fitness_json.get('Novelty', {}).get('justification', 'Unknown')}"
                if self.corrector.is_good_enough(scores):
                    yield "--- Solution approved by self-corrector. ---"
                    break
                elif i == 1:
                    yield "--- Max correction loops reached. Accepting best effort. ---"
            # --- FINALIZE ---
            financial_report["generation_cost"] = financial_report["total_cost"] - financial_report["calibration_cost"]
            debug_log["financial_report"] = financial_report
            await asyncio.sleep(0.5)
            solution_draft_json_safe = json.dumps(solution_draft)
            debug_log_json_safe = json.dumps(debug_log)
            yield f"FINAL: {{\"text\": {solution_draft_json_safe}, \"audio\": null, \"log\": {debug_log_json_safe}}}"
        except Exception as e:
            error_msg = f"An error occurred in the agent's solve loop: {e}"
            print(error_msg)
            debug_log["error"] = str(e)
            yield error_msg
        finally:
            try:
                os.makedirs("logs", exist_ok=True)
                log_path = f"logs/run_{run_id}.json"
                with open(log_path, "w", encoding="utf-8") as f:
                    json.dump(debug_log, f, indent=2)
                print(f"Detailed execution log saved to {log_path}")
            except Exception as log_err:
                print(f"Failed to save log: {log_err}")