# agent_logic.py (Fixed: OpenAI and Nebius Initialization + Robust Unpacking)
import asyncio
from typing import AsyncGenerator, Dict, Optional, List, Tuple
import json
import os
import uuid
import datetime
import google.generativeai as genai
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
import re 
from personas import PERSONAS_DATA 
import config
from utils import load_prompt
from mcp_servers import AgentCalibrator, BusinessSolutionEvaluator, get_llm_response
from self_correction import SelfCorrector

CLASSIFIER_SYSTEM_PROMPT = load_prompt(config.PROMPT_FILES["classifier"])
HOMOGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_homogeneous"])
HETEROGENEOUS_MANAGER_PROMPT = load_prompt(config.PROMPT_FILES["manager_heterogeneous"])

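# Maps the evaluator's variously-phrased metric labels (casing, slashes,
# synonyms) onto the five canonical score keys used by the scoring loop below.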
METRIC_MAPPING = {
    "novelty": "Novelty",
    "usefulness": "Usefulness_Feasibility",
    "feasibility": "Usefulness_Feasibility",
    "usefulness_feasibility": "Usefulness_Feasibility",
    "usefulness/feasibility": "Usefulness_Feasibility",
    "flexibility": "Flexibility",
    "elaboration": "Elaboration",
    "cultural_appropriateness": "Cultural_Appropriateness",
    "cultural_sensitivity": "Cultural_Appropriateness",
    "cultural appropriateness": "Cultural_Appropriateness",
    "cultural appropriateness/sensitivity": "Cultural_Appropriateness"
}

# --- HELPER CLASSES ---

class Baseline_Single_Agent:
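    """Baseline: a single Gemini agent answers the problem in one call."""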
    def __init__(self, api_clients: dict):
        self.gemini_client = api_clients.get("Gemini")
    async def solve(self, problem: str, persona_prompt: str) -> Tuple[str, dict]:
        if not self.gemini_client: raise ValueError("Single_Agent requires a Google/Gemini client.")
        return await get_llm_response("Gemini", self.gemini_client, persona_prompt, problem)

class Baseline_Static_Homogeneous:
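    """Baseline: every available LLM answers with the same persona prompt,
    then a Gemini manager synthesizes the reports into one solution."""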
    def __init__(self, api_clients: dict):
        self.api_clients = {name: client for name, client in api_clients.items() if client}
        self.gemini_client = api_clients.get("Gemini")
    
    async def solve(self, problem: str, persona_prompt: str) -> Tuple[str, List[dict]]:
        if not self.gemini_client: raise ValueError("Homogeneous_Team requires a Google/Gemini client.")
        system_prompt = persona_prompt
        user_prompt = f"As an expert Implementer, generate a detailed plan for this problem: {problem}"
        
        tasks = [get_llm_response(llm, client, system_prompt, user_prompt) for llm, client in self.api_clients.items()]
        results = await asyncio.gather(*tasks)
        
        responses = [r[0] for r in results]
        usages = [r[1] for r in results]
        
        manager_system_prompt = HOMOGENEOUS_MANAGER_PROMPT
        reports_str = "\n\n".join(f"Report from Team Member {i+1}:\n{resp}" for i, resp in enumerate(responses))
        manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these reports into one final, comprehensive solution."
        
        final_text, final_usage = await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)
        usages.append(final_usage)
        
        return final_text, usages

class Baseline_Static_Heterogeneous:
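    """Baseline: each role in a pre-calibrated team plan is answered by its
    assigned LLM and persona (falling back to Gemini when a client is missing),
    then a Gemini manager synthesizes the specialist reports."""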
    def __init__(self, api_clients: dict):
        self.api_clients = api_clients
        self.gemini_client = api_clients.get("Gemini")
    
    async def solve(self, problem: str, team_plan: dict) -> Tuple[str, List[dict]]:
        if not self.gemini_client: raise ValueError("Heterogeneous_Team requires a Google/Gemini client.")
        tasks = []
        for role, config_data in team_plan.items():
            llm_name = config_data["llm"]
            persona_key = config_data["persona"]
            client = self.api_clients.get(llm_name)
            if not client:
                llm_name = "Gemini"
                client = self.gemini_client
            system_prompt = PERSONAS_DATA[persona_key]["description"]
            user_prompt = f"As the team's '{role}', provide your unique perspective on how to solve this problem: {problem}"
            tasks.append(get_llm_response(llm_name, client, system_prompt, user_prompt))
        
        results = await asyncio.gather(*tasks)
        responses = [r[0] for r in results]
        usages = [r[1] for r in results]
        
        manager_system_prompt = HETEROGENEOUS_MANAGER_PROMPT
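        # team_plan preserves insertion order (dicts are ordered in Python 3.7+),
        # so zipping its keys with the gathered results re-associates each
        # response with the role that produced it.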
        reports_str = "\n\n".join(f"Report from {team_plan[role]['llm']} (as {role}):\n{resp}" for (role, resp) in zip(team_plan.keys(), responses))
        
        manager_user_prompt = f"Original Problem: {problem}\n\n{reports_str}\n\nPlease synthesize these specialist reports into one final, comprehensive solution."
        
        final_text, final_usage = await get_llm_response("Gemini", self.gemini_client, manager_system_prompt, manager_user_prompt)
        usages.append(final_usage)
        
        return final_text, usages

# --- MAIN AGENT ---

class StrategicSelectorAgent:
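    """Classifies a problem into an archetype, deploys the matching agent
    topology (single agent, homogeneous team, or calibrated heterogeneous
    team), evaluates the draft, and self-corrects for at most one extra loop
    while tracking per-call token usage and cost."""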
    def __init__(self, api_keys: Dict[str, Optional[str]]):
        self.api_keys = api_keys
        # Initialize potential clients including new providers
        self.api_clients = { 
            "Gemini": None, 
            "Anthropic": None, 
            "SambaNova": None,
            "OpenAI": None,
            "Nebius": None
        }
        
        # --- INIT GEMINI ---
        if api_keys.get("google") and api_keys["google"].strip():
            try:
                genai.configure(api_key=api_keys["google"])
                self.api_clients["Gemini"] = genai.GenerativeModel(config.MODELS["Gemini"]["default"])
            except Exception as e: print(f"Warning: Gemini init failed: {e}")
            
        # --- INIT ANTHROPIC ---
        if api_keys.get("anthropic") and api_keys["anthropic"].strip():
            try:
                self.api_clients["Anthropic"] = AsyncAnthropic(api_key=api_keys["anthropic"])
            except Exception as e: print(f"Warning: Anthropic init failed: {e}")
            
        # --- INIT SAMBANOVA ---
        if api_keys.get("sambanova") and api_keys["sambanova"].strip():
            try:
                self.api_clients["SambaNova"] = AsyncOpenAI(
                    api_key=api_keys["sambanova"], 
                    base_url=os.getenv("SAMBANOVA_BASE_URL", "https://api.sambanova.ai/v1")
                )
            except Exception as e: print(f"Warning: SambaNova init failed: {e}")

        # --- INIT OPENAI (NEW) ---
        if api_keys.get("openai") and api_keys["openai"].strip():
            try:
                self.api_clients["OpenAI"] = AsyncOpenAI(api_key=api_keys["openai"])
            except Exception as e: print(f"Warning: OpenAI init failed: {e}")

        # --- INIT NEBIUS (NEW) ---
        if api_keys.get("nebius") and api_keys["nebius"].strip():
            try:
                self.api_clients["Nebius"] = AsyncOpenAI(
                    api_key=api_keys["nebius"],
                    base_url="https://api.studio.nebius.ai/v1/"
                )
            except Exception as e: print(f"Warning: Nebius init failed: {e}")
        
        if not self.api_clients["Gemini"]:
            raise ValueError("Google API Key is required.")

        self.evaluator = BusinessSolutionEvaluator(self.api_clients["Gemini"])
        self.calibrator = AgentCalibrator(self.api_clients, self.evaluator)
        self.corrector = SelfCorrector(threshold=3.0)
        self.single_agent = Baseline_Single_Agent(self.api_clients)
        self.homo_team = Baseline_Static_Homogeneous(self.api_clients)
        self.hetero_team = Baseline_Static_Heterogeneous(self.api_clients)
        self.current_team_plan = None

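        # load_prompt signals a missing file by returning an "ERROR:"-prefixed
        # string rather than raising; treat that as a hard failure for the
        # required classifier prompt.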
        if "ERROR:" in CLASSIFIER_SYSTEM_PROMPT: raise FileNotFoundError(CLASSIFIER_SYSTEM_PROMPT)

    async def solve(self, problem: str) -> AsyncGenerator[str, None]:
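        """Stream progress messages; on success the last yield is a
        'FINAL: {...}' JSON payload containing the solution and debug log."""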
        run_id = str(uuid.uuid4())[:8]
        
        # Initialize Financial Tracking
        financial_report = {
            "calibration_cost": 0.0,
            "generation_cost": 0.0,
            "total_cost": 0.0,
            "usage_breakdown": [] 
        }
        
        debug_log = {
            "run_id": run_id,
            "timestamp": datetime.datetime.now().isoformat(),
            "problem": problem,
            "classification": "",
            "trace": [],
            "financial_report": financial_report 
        }

        # Helper to add usage and calculate cost
        def add_usage(usage_list):
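            """Record usage entries in the financial report, update the running
            total cost, and return the cost incurred by this step alone."""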
            if isinstance(usage_list, dict): usage_list = [usage_list]
            
            current_step_cost = 0.0
            for u in usage_list:
                financial_report["usage_breakdown"].append(u)
                
                # Lookup pricing
                model_name = u.get("model", "Gemini")
                # Safely get pricing with default fallbacks
                pricing = config.PRICING.get(model_name, {"input": 0, "output": 0})
                
                # Cost: the PRICING table is priced per 1M tokens, so scale
                # token counts by 1e-6 before multiplying by the rate.
                cost = (u.get("input", 0) / 1_000_000 * pricing["input"]) + \
                       (u.get("output", 0) / 1_000_000 * pricing["output"])
                
                financial_report["total_cost"] += cost
                current_step_cost += cost
            return current_step_cost

        try:
            yield "Classifying problem archetype (live)..."
            classification, cls_usage = await get_llm_response("Gemini", self.api_clients["Gemini"], CLASSIFIER_SYSTEM_PROMPT, problem)
            classification = classification.strip().replace("\"", "") 
            yield f"Diagnosis: {classification}"
            add_usage(cls_usage)
            
            debug_log["classification"] = classification

            if "Error generating response" in classification:
                yield "Classifier failed. Defaulting to Single Agent."
                classification = "Direct_Procedure"

            solution_draft = "" 
            v_fitness_json = {}
            scores = {}
            
            # --- MAIN LOOP ---
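            # Two passes at most: the initial attempt (i == 0) plus one
            # self-correction retry when the evaluator's scores miss the
            # corrector's threshold.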
            for i in range(2): 
                current_problem = problem
                if i > 0:
                    yield f"--- (Loop {i}) Score is too low. Initiating Self-Correction... ---"
                    correction_prompt_text = self.corrector.get_correction_plan(v_fitness_json)
                    yield f"Diagnosis: {correction_prompt_text.splitlines()[3].strip()}"
                    current_problem = f"{problem}\n\n{correction_prompt_text}"
                    
                    debug_log["trace"].append({
                        "step_type": "correction_plan",
                        "loop_index": i,
                        "prompt": correction_prompt_text
                    })
                
                # --- DEPLOY ---
                default_persona = PERSONAS_DATA[config.DEFAULT_PERSONA_KEY]["description"]
                current_usages = []

                if classification == "Direct_Procedure" or classification == "Holistic_Abstract_Reasoning":
                    if i == 0: yield "Deploying: Baseline Single Agent (Simplicity Hypothesis)..."
                    solution_draft, u = await self.single_agent.solve(current_problem, default_persona)
                    current_usages.append(u)
                    
                elif classification == "Local_Geometric_Procedural":
                    if i == 0: yield "Deploying: Static Homogeneous Team (Expert Anomaly)..."
                    solution_draft, u_list = await self.homo_team.solve(current_problem, default_persona)
                    current_usages.extend(u_list)
                    
                elif classification == "Cognitive_Labyrinth":
                    if i == 0:
                        yield "Deploying: Static Heterogeneous Team (Cognitive Diversity)..."
                        
                        # --- UNPACK 4 VALUES FROM CALIBRATOR ---
                        # Safely call and unpack, providing debugging if it fails
                        calib_result = await self.calibrator.calibrate_team(current_problem)
                        
                        if calib_result is None:
                            raise ValueError("CRITICAL ERROR: calibrate_team returned None. Please verify mcp_servers.py on the server.")

                        if len(calib_result) != 4:
                            # Fallback logic if server has old mcp_servers.py (e.g. 2 or 3 values)
                            if len(calib_result) == 2:
                                team_plan, calibration_errors = calib_result
                                calib_details, calib_usage = [], []  # Defaults
                            elif len(calib_result) == 3:
                                team_plan, calibration_errors, calib_details = calib_result
                                calib_usage = []  # Default
                            else:
                                raise ValueError(f"Calibrator returned {len(calib_result)} values, expected 4.")
                        else:
                            team_plan, calibration_errors, calib_details, calib_usage = calib_result
                        
                        # Track Calibration Cost explicitly
                        calib_step_cost = add_usage(calib_usage)
                        financial_report["calibration_cost"] += calib_step_cost

                        debug_log["trace"].append({
                            "step_type": "calibration",
                            "details": calib_details,
                            "errors": calibration_errors,
                            "selected_plan": team_plan
                        })

                        if calibration_errors:
                            yield "--- CALIBRATION WARNINGS ---"
                            for err in calibration_errors: yield err
                            yield "-----------------------------"
                        yield f"Calibration complete. Best Team: {json.dumps({k: v['llm'] for k, v in team_plan.items()})}"
                        self.current_team_plan = team_plan
                    
                    solution_draft, u_list = await self.hetero_team.solve(current_problem, self.current_team_plan)
                    current_usages.extend(u_list)
                
                else:
                    if i == 0: yield f"Diagnosis '{classification}' is unknown. Defaulting to Single Agent."
                    solution_draft, u = await self.single_agent.solve(current_problem, default_persona)
                    current_usages.append(u)

                add_usage(current_usages)

                if "Error generating response" in solution_draft:
                    raise Exception(f"The specialist team failed to generate a solution. Error: {solution_draft}")

                yield f"Draft solution received: '{solution_draft[:60]}...'"

                # --- EVALUATE ---
                yield "Evaluating draft (live)..."
                
                v_fitness_json, eval_usage = await self.evaluator.evaluate(current_problem, solution_draft)
                add_usage(eval_usage)
                
                if isinstance(v_fitness_json, list):
                    if len(v_fitness_json) > 0 and isinstance(v_fitness_json[0], dict):
                        v_fitness_json = v_fitness_json[0]
                    else:
                        v_fitness_json = {}

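                # Normalize the evaluator's free-form JSON: map metric-name
                # aliases onto canonical keys and coerce each score to an int,
                # tolerating dicts, lists of dicts, and bare values.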
                normalized_fitness = {}
                if isinstance(v_fitness_json, dict):
                    for k, v in v_fitness_json.items():
                        canonical_key = None
                        clean_k = k.lower().strip()
                        if clean_k in METRIC_MAPPING: canonical_key = METRIC_MAPPING[clean_k]
                        if not canonical_key: continue

                        if isinstance(v, dict):
                            score_value = v.get('score')
                            justification_value = v.get('justification', str(v))
                        elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict):
                            score_value = v[0].get('score')
                            justification_value = v[0].get('justification', str(v[0]))
                        else:
                            score_value = v 
                            justification_value = "Score extracted directly."
                        
                        if isinstance(score_value, str):
                            try:
                                match = re.search(r'\d+', score_value)
                                score_value = int(match.group()) if match else 0
                            except Exception:
                                score_value = 0
                        
                        try:
                            score_value = int(score_value)
                        except (ValueError, TypeError):
                            score_value = 0

                        normalized_fitness[canonical_key] = {'score': score_value, 'justification': justification_value}
                else:
                    normalized_fitness = {k: {'score': 0, 'justification': "Invalid JSON structure"} for k in ["Novelty", "Usefulness_Feasibility", "Flexibility", "Elaboration", "Cultural_Appropriateness"]}
                
                v_fitness_json = normalized_fitness
                scores = {k: v.get('score', 0) for k, v in v_fitness_json.items()}
                yield f"Evaluation Score: {scores}"
                
                debug_log["trace"].append({
                    "step_type": "attempt",
                    "loop_index": i,
                    "draft": solution_draft,
                    "scores": scores,
                    "full_evaluation": v_fitness_json
                })

                if scores.get('Novelty', 0) <= 1:
                     yield f"⚠️ Low Score Detected. Reason: {v_fitness_json.get('Novelty', {}).get('justification', 'Unknown')}"

                if self.corrector.is_good_enough(scores):
                    yield "--- Solution approved by self-corrector. ---"
                    break
                elif i == 1:
                    yield "--- Max correction loops reached. Accepting best effort. ---"
            
            # --- FINALIZE ---
            financial_report["generation_cost"] = financial_report["total_cost"] - financial_report["calibration_cost"]
            debug_log["financial_report"] = financial_report

            await asyncio.sleep(0.5)
            
            solution_draft_json_safe = json.dumps(solution_draft)
            debug_log_json_safe = json.dumps(debug_log)
            
            yield f"FINAL: {{\"text\": {solution_draft_json_safe}, \"audio\": null, \"log\": {debug_log_json_safe}}}"
        
        except Exception as e:
            error_msg = f"An error occurred in the agent's solve loop: {e}"
            print(error_msg)
            debug_log["error"] = str(e)
            yield error_msg
        
        finally:
            try:
                os.makedirs("logs", exist_ok=True)
                log_path = f"logs/run_{run_id}.json"
                with open(log_path, "w", encoding="utf-8") as f:
                    json.dump(debug_log, f, indent=2)
                print(f"Detailed execution log saved to {log_path}")
            except Exception as log_err:
                print(f"Failed to save log: {log_err}")