youssefleb committed on
Commit
762edf7
·
verified ·
1 Parent(s): a4f27db

Update agent_logic.py

Browse files
Files changed (1) hide show
  1. agent_logic.py +87 -46
agent_logic.py CHANGED
@@ -1,4 +1,4 @@
1
- # agent_logic.py (Milestone 5 - FINAL & ROBUST + LOGGING + NATURAL TEXT + ALLOWLIST FILTER)
2
  import asyncio
3
  from typing import AsyncGenerator, Dict, Optional
4
  import json
@@ -12,7 +12,6 @@ import re
12
  from personas import PERSONAS_DATA
13
  import config
14
  from utils import load_prompt
15
- # Removed extract_json_str as we no longer need to parse the solution
16
  from mcp_servers import AgentCalibrator, BusinessSolutionEvaluator, get_llm_response
17
  from self_correction import SelfCorrector
18
 
@@ -20,9 +19,6 @@ CLASSIFIER_SYSTEM_PROMPT = load_prompt(config.PROMPT_FILES["classifier"])
20
  HOMOGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_homogeneous"])
21
  HETEROGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_heterogeneous"])
22
 
23
- # --- METRIC BOUNCER (Allowlist) ---
24
- # We map any variation of the key to the canonical internal name.
25
- # If a key isn't in here, it gets dropped.
26
  METRIC_MAPPING = {
27
  "novelty": "Novelty",
28
  "usefulness": "Usefulness_Feasibility",
@@ -42,6 +38,7 @@ class Baseline_Single_Agent:
42
  self.gemini_client = api_clients.get("Gemini")
43
  async def solve(self, problem: str, persona_prompt: str):
44
  if not self.gemini_client: raise ValueError("Single_Agent requires a Google/Gemini client.")
 
45
  return await get_llm_response("Gemini", self.gemini_client, persona_prompt, problem)
46
 
47
  class Baseline_Static_Homogeneous:
@@ -55,14 +52,19 @@ class Baseline_Static_Homogeneous:
55
  user_prompt = f"As an expert Implementer, generate a detailed plan for this problem: {problem}"
56
 
57
  tasks = [get_llm_response(llm, client, system_prompt, user_prompt) for llm, client in self.api_clients.items()]
58
- responses = await asyncio.gather(*tasks)
 
 
 
59
 
60
  manager_system_prompt = HOMOGENEOUS_MANAGER_PROMPT
61
  reports_str = "\n\n".join(f"Report from Team Member {i+1}:\n{resp}" for i, resp in enumerate(responses))
62
-
63
  manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these reports into one final, comprehensive solution."
64
 
65
- return await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)
 
 
 
66
 
67
  class Baseline_Static_Heterogeneous:
68
  def __init__(self, api_clients: dict):
@@ -83,14 +85,19 @@ class Baseline_Static_Heterogeneous:
83
  user_prompt = f"As the team's '{role}', provide your unique perspective on how to solve this problem: {problem}"
84
  tasks.append(get_llm_response(llm_name, client, system_prompt, user_prompt))
85
 
86
- responses = await asyncio.gather(*tasks)
 
 
87
 
88
  manager_system_prompt = HETEROGENEOUS_MANAGER_PROMPT
89
  reports_str = "\n\n".join(f"Report from {team_plan[role]['llm']} (as {role}):\n{resp}" for (role, resp) in zip(team_plan.keys(), responses))
90
 
91
  manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these specialist reports into one final, comprehensive solution."
92
 
93
- return await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)
 
 
 
94
 
95
  class StrategicSelectorAgent:
96
  def __init__(self, api_keys: Dict[str, Optional[str]]):
@@ -124,30 +131,56 @@ class StrategicSelectorAgent:
124
 
125
  if "ERROR:" in CLASSIFIER_SYSTEM_PROMPT: raise FileNotFoundError(CLASSIFIER_SYSTEM_PROMPT)
126
 
127
- async def _classify_problem(self, problem: str) -> AsyncGenerator[str, None]:
128
- yield "Classifying problem archetype (live)..."
129
- classification = await get_llm_response("Gemini", self.api_clients["Gemini"], CLASSIFIER_SYSTEM_PROMPT, problem)
130
- classification = classification.strip().replace("\"", "")
131
- yield f"Diagnosis: {classification}"
132
 
133
  async def solve(self, problem: str) -> AsyncGenerator[str, None]:
134
- # --- 1. Initialize Logging ---
135
  run_id = str(uuid.uuid4())[:8]
 
 
 
 
 
 
 
 
 
136
  debug_log = {
137
  "run_id": run_id,
138
  "timestamp": datetime.datetime.now().isoformat(),
139
  "problem": problem,
140
  "classification": "",
141
- "trace": []
 
142
  }
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  try:
145
- classification_generator = self._classify_problem(problem)
146
- classification = ""
147
- async for status_update in classification_generator:
148
- yield status_update
149
- if "Diagnosis: " in status_update:
150
- classification = status_update.split(": ")[-1]
151
 
152
  debug_log["classification"] = classification
153
 
@@ -159,7 +192,7 @@ class StrategicSelectorAgent:
159
  v_fitness_json = {}
160
  scores = {}
161
 
162
- # --- MAIN LOOP (Self-Correction) ---
163
  for i in range(2):
164
  current_problem = problem
165
  if i > 0:
@@ -176,20 +209,30 @@ class StrategicSelectorAgent:
176
 
177
  # --- DEPLOY ---
178
  default_persona = PERSONAS_DATA[config.DEFAULT_PERSONA_KEY]["description"]
 
179
 
180
  if classification == "Direct_Procedure" or classification == "Holistic_Abstract_Reasoning":
181
  if i == 0: yield "Deploying: Baseline Single Agent (Simplicity Hypothesis)..."
182
- solution_draft = await self.single_agent.solve(current_problem, default_persona)
 
183
 
184
  elif classification == "Local_Geometric_Procedural":
185
  if i == 0: yield "Deploying: Static Homogeneous Team (Expert Anomaly)..."
186
- solution_draft = await self.homo_team.solve(current_problem, default_persona)
 
187
 
188
  elif classification == "Cognitive_Labyrinth":
189
  if i == 0:
190
  yield "Deploying: Static Heterogeneous Team (Cognitive Diversity)..."
191
- team_plan, calibration_errors, calib_details = await self.calibrator.calibrate_team(current_problem)
192
 
 
 
 
 
 
 
 
 
193
  debug_log["trace"].append({
194
  "step_type": "calibration",
195
  "details": calib_details,
@@ -204,11 +247,16 @@ class StrategicSelectorAgent:
204
  yield f"Calibration complete. Best Team: {json.dumps({k: v['llm'] for k, v in team_plan.items()})}"
205
  self.current_team_plan = team_plan
206
 
207
- solution_draft = await self.hetero_team.solve(current_problem, self.current_team_plan)
 
208
 
209
  else:
210
  if i == 0: yield f"Diagnosis '{classification}' is unknown. Defaulting to Single Agent."
211
- solution_draft = await self.single_agent.solve(current_problem, default_persona)
 
 
 
 
212
 
213
  if "Error generating response" in solution_draft:
214
  raise Exception(f"The specialist team failed to generate a solution. Error: {solution_draft}")
@@ -218,32 +266,25 @@ class StrategicSelectorAgent:
218
  # --- EVALUATE ---
219
  yield "Evaluating draft (live)..."
220
 
221
- v_fitness_json = await self.evaluator.evaluate(current_problem, solution_draft)
 
222
 
223
- # --- Safety Check for List ---
224
  if isinstance(v_fitness_json, list):
225
  if len(v_fitness_json) > 0 and isinstance(v_fitness_json[0], dict):
226
  v_fitness_json = v_fitness_json[0]
227
  else:
228
  v_fitness_json = {}
229
 
230
- # --- ROBUST NORMALIZATION WITH ALLOWLIST FILTER ---
231
  normalized_fitness = {}
232
  if isinstance(v_fitness_json, dict):
233
  for k, v in v_fitness_json.items():
234
- # 1. Map fuzzy keys to canonical keys
235
  canonical_key = None
236
  clean_k = k.lower().strip()
237
-
238
- # Check exact match or known variation
239
- if clean_k in METRIC_MAPPING:
240
- canonical_key = METRIC_MAPPING[clean_k]
241
-
242
- # If we couldn't map it to a valid metric, SKIP IT.
243
- if not canonical_key:
244
- continue
245
 
246
- # 2. Extract Score Value
247
  if isinstance(v, dict):
248
  score_value = v.get('score')
249
  justification_value = v.get('justification', str(v))
@@ -251,14 +292,11 @@ class StrategicSelectorAgent:
251
  score_value = v[0].get('score')
252
  justification_value = v[0].get('justification', str(v[0]))
253
  else:
254
- # Flat value case
255
  score_value = v
256
  justification_value = "Score extracted directly."
257
 
258
- # 3. Clean Score (handle "4/5" strings)
259
  if isinstance(score_value, str):
260
  try:
261
- # Looks for the first number in the string
262
  match = re.search(r'\d+', score_value)
263
  score_value = int(match.group()) if match else 0
264
  except:
@@ -271,7 +309,6 @@ class StrategicSelectorAgent:
271
 
272
  normalized_fitness[canonical_key] = {'score': score_value, 'justification': justification_value}
273
  else:
274
- # Fallback for total failure
275
  normalized_fitness = {k: {'score': 0, 'justification': "Invalid JSON structure"} for k in ["Novelty", "Usefulness_Feasibility", "Flexibility", "Elaboration", "Cultural_Appropriateness"]}
276
 
277
  v_fitness_json = normalized_fitness
@@ -296,6 +333,10 @@ class StrategicSelectorAgent:
296
  yield "--- Max correction loops reached. Accepting best effort. ---"
297
 
298
  # --- FINALIZE ---
 
 
 
 
299
  await asyncio.sleep(0.5)
300
  yield "Milestone 5 Complete. Self-Correction loop is live."
301
 
 
1
+ # agent_logic.py (Milestone 5 - FINAL & ROBUST + LOGGING + COST TRACKING)
2
  import asyncio
3
  from typing import AsyncGenerator, Dict, Optional
4
  import json
 
12
  from personas import PERSONAS_DATA
13
  import config
14
  from utils import load_prompt
 
15
  from mcp_servers import AgentCalibrator, BusinessSolutionEvaluator, get_llm_response
16
  from self_correction import SelfCorrector
17
 
 
19
  HOMOGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_homogeneous"])
20
  HETEROGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_heterogeneous"])
21
 
 
 
 
22
  METRIC_MAPPING = {
23
  "novelty": "Novelty",
24
  "usefulness": "Usefulness_Feasibility",
 
38
  self.gemini_client = api_clients.get("Gemini")
39
  async def solve(self, problem: str, persona_prompt: str):
40
  if not self.gemini_client: raise ValueError("Single_Agent requires a Google/Gemini client.")
41
+ # Returns (text, usage)
42
  return await get_llm_response("Gemini", self.gemini_client, persona_prompt, problem)
43
 
44
  class Baseline_Static_Homogeneous:
 
52
  user_prompt = f"As an expert Implementer, generate a detailed plan for this problem: {problem}"
53
 
54
  tasks = [get_llm_response(llm, client, system_prompt, user_prompt) for llm, client in self.api_clients.items()]
55
+ results = await asyncio.gather(*tasks)
56
+
57
+ responses = [r[0] for r in results]
58
+ usages = [r[1] for r in results]
59
 
60
  manager_system_prompt = HOMOGENEOUS_MANAGER_PROMPT
61
  reports_str = "\n\n".join(f"Report from Team Member {i+1}:\n{resp}" for i, resp in enumerate(responses))
 
62
  manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these reports into one final, comprehensive solution."
63
 
64
+ final_text, final_usage = await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)
65
+ usages.append(final_usage)
66
+
67
+ return final_text, usages
68
 
69
  class Baseline_Static_Heterogeneous:
70
  def __init__(self, api_clients: dict):
 
85
  user_prompt = f"As the team's '{role}', provide your unique perspective on how to solve this problem: {problem}"
86
  tasks.append(get_llm_response(llm_name, client, system_prompt, user_prompt))
87
 
88
+ results = await asyncio.gather(*tasks)
89
+ responses = [r[0] for r in results]
90
+ usages = [r[1] for r in results]
91
 
92
  manager_system_prompt = HETEROGENEOUS_MANAGER_PROMPT
93
  reports_str = "\n\n".join(f"Report from {team_plan[role]['llm']} (as {role}):\n{resp}" for (role, resp) in zip(team_plan.keys(), responses))
94
 
95
  manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these specialist reports into one final, comprehensive solution."
96
 
97
+ final_text, final_usage = await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)
98
+ usages.append(final_usage)
99
+
100
+ return final_text, usages
101
 
102
  class StrategicSelectorAgent:
103
  def __init__(self, api_keys: Dict[str, Optional[str]]):
 
131
 
132
  if "ERROR:" in CLASSIFIER_SYSTEM_PROMPT: raise FileNotFoundError(CLASSIFIER_SYSTEM_PROMPT)
133
 
134
+ # Removed unused _classify_problem generator for a cleaner code structure in solve()
 
 
 
 
135
 
136
  async def solve(self, problem: str) -> AsyncGenerator[str, None]:
 
137
  run_id = str(uuid.uuid4())[:8]
138
+
139
+ # Initialize Financial Tracking
140
+ financial_report = {
141
+ "calibration_cost": 0.0,
142
+ "generation_cost": 0.0,
143
+ "total_cost": 0.0,
144
+ "usage_breakdown": []
145
+ }
146
+
147
  debug_log = {
148
  "run_id": run_id,
149
  "timestamp": datetime.datetime.now().isoformat(),
150
  "problem": problem,
151
  "classification": "",
152
+ "trace": [],
153
+ "financial_report": financial_report
154
  }
155
 
156
+ # Helper to add usage and calculate cost
157
+ def add_usage(usage_list):
158
+ if isinstance(usage_list, dict): usage_list = [usage_list]
159
+
160
+ current_step_cost = 0.0
161
+ for u in usage_list:
162
+ financial_report["usage_breakdown"].append(u)
163
+
164
+ # Lookup pricing
165
+ model_name = u.get("model", "Gemini")
166
+ # Default to 0 if model not found in config
167
+ pricing = config.PRICING.get(model_name, {"input": 0, "output": 0})
168
+
169
+ # Calculate Cost: (Tokens / 1M) * Price
170
+ cost = (u.get("input", 0) / 1_000_000 * pricing["input"]) + \
171
+ (u.get("output", 0) / 1_000_000 * pricing["output"])
172
+
173
+ financial_report["total_cost"] += cost
174
+ current_step_cost += cost
175
+ return current_step_cost
176
+
177
  try:
178
+ yield "Classifying problem archetype (live)..."
179
+ # Get classification and its usage
180
+ classification, cls_usage = await get_llm_response("Gemini", self.api_clients["Gemini"], CLASSIFIER_SYSTEM_PROMPT, problem)
181
+ classification = classification.strip().replace("\"", "")
182
+ yield f"Diagnosis: {classification}"
183
+ add_usage(cls_usage)
184
 
185
  debug_log["classification"] = classification
186
 
 
192
  v_fitness_json = {}
193
  scores = {}
194
 
195
+ # --- MAIN LOOP ---
196
  for i in range(2):
197
  current_problem = problem
198
  if i > 0:
 
209
 
210
  # --- DEPLOY ---
211
  default_persona = PERSONAS_DATA[config.DEFAULT_PERSONA_KEY]["description"]
212
+ current_usages = [] # Track usage for this specific generation step
213
 
214
  if classification == "Direct_Procedure" or classification == "Holistic_Abstract_Reasoning":
215
  if i == 0: yield "Deploying: Baseline Single Agent (Simplicity Hypothesis)..."
216
+ solution_draft, u = await self.single_agent.solve(current_problem, default_persona)
217
+ current_usages.append(u)
218
 
219
  elif classification == "Local_Geometric_Procedural":
220
  if i == 0: yield "Deploying: Static Homogeneous Team (Expert Anomaly)..."
221
+ solution_draft, u_list = await self.homo_team.solve(current_problem, default_persona)
222
+ current_usages.extend(u_list)
223
 
224
  elif classification == "Cognitive_Labyrinth":
225
  if i == 0:
226
  yield "Deploying: Static Heterogeneous Team (Cognitive Diversity)..."
 
227
 
228
+ # --- UNPACK 4 VALUES FROM CALIBRATOR ---
229
+ # (Plan, Errors, Details, UsageStats)
230
+ team_plan, calibration_errors, calib_details, calib_usage = await self.calibrator.calibrate_team(current_problem)
231
+
232
+ # Track Calibration Cost explicitly
233
+ calib_step_cost = add_usage(calib_usage)
234
+ financial_report["calibration_cost"] += calib_step_cost
235
+
236
  debug_log["trace"].append({
237
  "step_type": "calibration",
238
  "details": calib_details,
 
247
  yield f"Calibration complete. Best Team: {json.dumps({k: v['llm'] for k, v in team_plan.items()})}"
248
  self.current_team_plan = team_plan
249
 
250
+ solution_draft, u_list = await self.hetero_team.solve(current_problem, self.current_team_plan)
251
+ current_usages.extend(u_list)
252
 
253
  else:
254
  if i == 0: yield f"Diagnosis '{classification}' is unknown. Defaulting to Single Agent."
255
+ solution_draft, u = await self.single_agent.solve(current_problem, default_persona)
256
+ current_usages.append(u)
257
+
258
+ # Add generation usage to total
259
+ add_usage(current_usages)
260
 
261
  if "Error generating response" in solution_draft:
262
  raise Exception(f"The specialist team failed to generate a solution. Error: {solution_draft}")
 
266
  # --- EVALUATE ---
267
  yield "Evaluating draft (live)..."
268
 
269
+ v_fitness_json, eval_usage = await self.evaluator.evaluate(current_problem, solution_draft)
270
+ add_usage(eval_usage)
271
 
272
+ # Safety Check
273
  if isinstance(v_fitness_json, list):
274
  if len(v_fitness_json) > 0 and isinstance(v_fitness_json[0], dict):
275
  v_fitness_json = v_fitness_json[0]
276
  else:
277
  v_fitness_json = {}
278
 
279
+ # Normalization with Allowlist
280
  normalized_fitness = {}
281
  if isinstance(v_fitness_json, dict):
282
  for k, v in v_fitness_json.items():
 
283
  canonical_key = None
284
  clean_k = k.lower().strip()
285
+ if clean_k in METRIC_MAPPING: canonical_key = METRIC_MAPPING[clean_k]
286
+ if not canonical_key: continue
 
 
 
 
 
 
287
 
 
288
  if isinstance(v, dict):
289
  score_value = v.get('score')
290
  justification_value = v.get('justification', str(v))
 
292
  score_value = v[0].get('score')
293
  justification_value = v[0].get('justification', str(v[0]))
294
  else:
 
295
  score_value = v
296
  justification_value = "Score extracted directly."
297
 
 
298
  if isinstance(score_value, str):
299
  try:
 
300
  match = re.search(r'\d+', score_value)
301
  score_value = int(match.group()) if match else 0
302
  except:
 
309
 
310
  normalized_fitness[canonical_key] = {'score': score_value, 'justification': justification_value}
311
  else:
 
312
  normalized_fitness = {k: {'score': 0, 'justification': "Invalid JSON structure"} for k in ["Novelty", "Usefulness_Feasibility", "Flexibility", "Elaboration", "Cultural_Appropriateness"]}
313
 
314
  v_fitness_json = normalized_fitness
 
333
  yield "--- Max correction loops reached. Accepting best effort. ---"
334
 
335
  # --- FINALIZE ---
336
+ # Calculate Generation Cost (Total - Calibration)
337
+ # This captures initial generation + any re-generations + evaluations
338
+ financial_report["generation_cost"] = financial_report["total_cost"] - financial_report["calibration_cost"]
339
+
340
  await asyncio.sleep(0.5)
341
  yield "Milestone 5 Complete. Self-Correction loop is live."
342