from typing import List, Dict, Any

import asyncio
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

from schemas.data_models import EvaluationRequest, EvaluationSummary, APIProvider, MetricType
from .graph_builder import EvaluationGraphBuilder
from config import settings


class EvaluationAgent:
    def __init__(self):
        self.graph_builder = None

    async def evaluate_async(self, request: EvaluationRequest) -> EvaluationSummary:
        """Evaluate questions asynchronously using LangGraph."""
        start_time = asyncio.get_running_loop().time()

        # Validate that the parallel input lists line up
        if len(request.questions) != len(request.ground_truths):
            raise ValueError("Questions and ground truths must have the same length")
        if request.model_responses and len(request.questions) != len(request.model_responses):
            raise ValueError("Questions and model responses must have the same length")

        # Initialize graph builder with the judge model and API provider
        self.graph_builder = EvaluationGraphBuilder(
            model_name=request.judge_model,
            api_provider=request.api_provider.value
        )

        # Build evaluation graph
        graph = self.graph_builder.build_graph()

        # Process evaluations in a thread pool, one graph invocation per question
        results = []
        with ThreadPoolExecutor(max_workers=request.max_concurrent) as executor:
            futures = []
            for i in range(len(request.questions)):
                state = {
                    "question": request.questions[i],
                    "ground_truth": request.ground_truths[i],
                    "model_response": request.model_responses[i] if request.model_responses else "",
                    "metrics": [m.value for m in request.metrics],
                }

                # Add context only when context-based metrics are requested and contexts are provided
                context_metrics = {"context_precision", "context_recall"}
                requested_metrics = {metric.value for metric in request.metrics}
                if requested_metrics & context_metrics and getattr(request, "contexts", None):
                    state["context"] = request.contexts[i] if i < len(request.contexts) else "No context provided."
                future = executor.submit(self._run_evaluation, graph, state)
                # Keep each future paired with its state so failures can be
                # reported against the correct question
                futures.append((future, state))

            # Collect results in submission order with a progress bar
            for future, state in tqdm(futures, desc="Evaluating responses"):
                try:
                    result = future.result()
                    results.append(result["final_result"])
                except Exception as e:
                    print(f"Evaluation failed: {e}")
                    # Add a failed result with default values
                    failed_result = {
                        "question": state["question"],
                        "ground_truth": state["ground_truth"],
                        "model_response": state["model_response"],
                        "metrics": {m.value: 0 for m in request.metrics},
                        "explanations": {m.value: f"Evaluation failed: {str(e)}" for m in request.metrics},
                        "processing_time": 0,
                        "overall_score": 0,
                    }
                    results.append(failed_result)

        # Calculate summary
        avg_scores = self._calculate_average_scores(results, request.metrics)
        overall_score = self._calculate_overall_score(results)

        return EvaluationSummary(
            total_questions=len(request.questions),
            average_scores=avg_scores,
            individual_results=results,
            total_processing_time=asyncio.get_running_loop().time() - start_time,
            model_used=request.judge_model,
            api_provider=request.api_provider.value,
            overall_score=overall_score,
        )

    def _run_evaluation(self, graph, state):
        """Run a single evaluation synchronously (for ThreadPoolExecutor)."""
        return graph.invoke(state)

    @staticmethod
    def _get_field(result: Any, name: str, default: Any = 0) -> Any:
        """Read a field from a result that may be a dict (failed results) or an object."""
        if isinstance(result, dict):
            return result.get(name, default)
        return getattr(result, name, default)

    def _calculate_average_scores(self, results: List[Any], metrics: List[MetricType]) -> Dict[MetricType, float]:
        """Calculate the average score per metric across all results."""
        avg_scores = {}
        for metric in metrics:
            scores = [self._get_field(result, "metrics", {}).get(metric.value, 0) for result in results]
            avg_scores[metric] = sum(scores) / len(scores) if scores else 0
        return avg_scores

    def _calculate_overall_score(self, results: List[Any]) -> float:
        """Calculate the overall score across all results, counting failed results as 0."""
        if not results:
            return 0
        overall_scores = [self._get_field(result, "overall_score", 0) for result in results]
        return sum(overall_scores) / len(overall_scores) if overall_scores else 0
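

# Usage sketch: how this agent is typically driven from an async entry point.
# The EvaluationRequest field names below mirror the attributes referenced in
# evaluate_async; the MetricType.ACCURACY and APIProvider.OPENAI members and the
# judge model name are hypothetical placeholders, so substitute whatever your
# schemas.data_models actually defines. Kept as a comment because this module
# uses a relative import and cannot be executed as a standalone script.
#
# async def main():
#     request = EvaluationRequest(
#         questions=["What is the capital of France?"],
#         ground_truths=["Paris"],
#         model_responses=["The capital of France is Paris."],
#         metrics=[MetricType.ACCURACY],      # hypothetical enum member
#         judge_model="gpt-4o",               # hypothetical judge model name
#         api_provider=APIProvider.OPENAI,    # hypothetical enum member
#         max_concurrent=4,
#     )
#     summary = await EvaluationAgent().evaluate_async(request)
#     print(summary.overall_score)
#
# asyncio.run(main())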