| """ | |
| RAGAS-based Evaluator | |
| Uses RAGAS for comprehensive SQL evaluation metrics. | |
| """ | |
| import os | |
| import time | |
| import pandas as pd | |
| from typing import Dict, List, Any, Optional | |
| from dataclasses import dataclass | |
| import duckdb | |
| import sqlglot | |
| from ragas import evaluate | |
| from ragas.metrics import ( | |
| faithfulness, | |
| answer_relevancy, | |
| context_precision, | |
| context_recall | |
| ) | |
| from ragas.testset import TestsetGenerator | |
| from datasets import Dataset | |
| import numpy as np | |
| # HuggingFace LLM for RAGAS | |
| from ragas.llms import LangchainLLMWrapper | |
| from langchain_huggingface import HuggingFacePipeline | |
| from transformers import pipeline | |

@dataclass
class EvaluationResult:
    """Result of a single evaluation."""

    model_name: str
    dataset_name: str
    dialect: str
    case_id: str
    question: str
    reference_sql: str
    generated_sql: str
    correctness_exact: float
    result_match_f1: float
    exec_success: float
    latency_ms: float
    readability: float
    dialect_ok: float
    ragas_faithfulness: float
    ragas_relevancy: float
    ragas_precision: float
    ragas_recall: float
    composite_score: float

class RAGASEvaluator:
    """RAGAS-based evaluator for SQL generation."""

    def __init__(self):
        # Initialize the HuggingFace LLM used by RAGAS
        self.hf_llm = None
        self._setup_huggingface_llm()
        self.ragas_metrics = [
            faithfulness,
            answer_relevancy,
            context_precision,
            context_recall,
        ]

    def _setup_huggingface_llm(self):
        """Setup HuggingFace LLM for RAGAS evaluation."""
        try:
            # Create a lightweight HuggingFace pipeline for evaluation tasks
            hf_pipeline = pipeline(
                "text-generation",
                model="microsoft/DialoGPT-small",
                max_new_tokens=256,
                temperature=0.1,
                do_sample=True,
                device=-1,  # Use CPU for evaluation
            )
            # Wrap the pipeline in LangChain, then wrap the LangChain LLM for RAGAS
            langchain_llm = HuggingFacePipeline(pipeline=hf_pipeline)
            self.hf_llm = LangchainLLMWrapper(langchain_llm=langchain_llm)
            print("✅ HuggingFace LLM configured for RAGAS evaluation")
        except Exception as e:
            print(f"⚠️ Could not setup HuggingFace LLM for RAGAS: {e}")
            print("   RAGAS metrics will be skipped")
            self.hf_llm = None

    def evaluate_sql(
        self,
        model_name: str,
        dataset_name: str,
        dialect: str,
        case_id: str,
        question: str,
        reference_sql: str,
        generated_sql: str,
        schema: str,
        db_path: str,
    ) -> EvaluationResult:
        """Evaluate a single SQL generation."""
        start_time = time.time()

        # Basic metrics
        correctness_exact = self._calculate_exact_match(reference_sql, generated_sql)
        result_match_f1 = self._calculate_result_match_f1(
            reference_sql, generated_sql, db_path
        )
        exec_success = self._calculate_execution_success(generated_sql, db_path)
        readability = self._calculate_readability(generated_sql)
        dialect_ok = self._calculate_dialect_compliance(generated_sql, dialect)

        # RAGAS metrics
        ragas_metrics = self._calculate_ragas_metrics(
            question, generated_sql, reference_sql, schema
        )

        latency_ms = (time.time() - start_time) * 1000

        # Composite score
        composite_score = self._calculate_composite_score(
            correctness_exact, result_match_f1, exec_success,
            latency_ms, readability, dialect_ok, ragas_metrics
        )

        return EvaluationResult(
            model_name=model_name,
            dataset_name=dataset_name,
            dialect=dialect,
            case_id=case_id,
            question=question,
            reference_sql=reference_sql,
            generated_sql=generated_sql,
            correctness_exact=correctness_exact,
            result_match_f1=result_match_f1,
            exec_success=exec_success,
            latency_ms=latency_ms,
            readability=readability,
            dialect_ok=dialect_ok,
            ragas_faithfulness=ragas_metrics.get('faithfulness', 0.0),
            ragas_relevancy=ragas_metrics.get('answer_relevancy', 0.0),
            ragas_precision=ragas_metrics.get('context_precision', 0.0),
            ragas_recall=ragas_metrics.get('context_recall', 0.0),
            composite_score=composite_score,
        )

    def _calculate_exact_match(self, reference_sql: str, generated_sql: str) -> float:
        """Calculate exact match score."""
        try:
            # Normalize SQL for comparison
            ref_normalized = sqlglot.parse_one(reference_sql).sql()
            gen_normalized = sqlglot.parse_one(generated_sql).sql()
            return 1.0 if ref_normalized.lower() == gen_normalized.lower() else 0.0
        except Exception:
            return 0.0

    def _calculate_result_match_f1(self, reference_sql: str, generated_sql: str, db_path: str) -> float:
        """Calculate F1 score based on query results."""
        try:
            # Execute both queries
            ref_results = self._execute_sql(reference_sql, db_path)
            gen_results = self._execute_sql(generated_sql, db_path)
            if ref_results is None or gen_results is None:
                return 0.0

            # Convert rows to string sets for comparison
            ref_set = set(str(row) for row in ref_results)
            gen_set = set(str(row) for row in gen_results)

            if not ref_set and not gen_set:
                return 1.0
            if not ref_set or not gen_set:
                return 0.0

            # Calculate F1
            intersection = len(ref_set & gen_set)
            precision = intersection / len(gen_set)
            recall = intersection / len(ref_set)
            if precision + recall == 0:
                return 0.0
            return 2 * (precision * recall) / (precision + recall)
        except Exception as e:
            print(f"⚠️ Error calculating result match F1: {e}")
            return 0.0

    def _calculate_execution_success(self, sql: str, db_path: str) -> float:
        """Calculate execution success rate."""
        try:
            result = self._execute_sql(sql, db_path)
            return 1.0 if result is not None else 0.0
        except Exception:
            return 0.0

    def _calculate_readability(self, sql: str) -> float:
        """Calculate SQL readability score."""
        try:
            # Simple readability heuristic based on line and query length
            lines = sql.strip().split('\n')
            avg_line_length = sum(len(line) for line in lines) / len(lines)
            # Penalize very long lines and very short queries
            if avg_line_length > 100 or len(sql.strip()) < 20:
                return 0.5
            elif avg_line_length > 80:
                return 0.7
            else:
                return 1.0
        except Exception:
            return 0.5

    def _calculate_dialect_compliance(self, sql: str, dialect: str) -> float:
        """Calculate dialect compliance score."""
        try:
            # Parse and transpile to check dialect compliance;
            # if transpilation succeeds without errors, the SQL is compliant.
            parsed = sqlglot.parse_one(sql)
            transpiled = parsed.sql(dialect=dialect)
            return 1.0 if transpiled else 0.0
        except Exception:
            return 0.0

    def _calculate_ragas_metrics(
        self,
        question: str,
        generated_sql: str,
        reference_sql: str,
        schema: str
    ) -> Dict[str, float]:
        """Calculate RAGAS metrics using HuggingFace models."""
        zero_scores = {
            'faithfulness': 0.0,
            'answer_relevancy': 0.0,
            'context_precision': 0.0,
            'context_recall': 0.0
        }
        try:
            # Check if the HuggingFace LLM is available
            if self.hf_llm is None:
                print("⚠️ No HuggingFace LLM configured - skipping RAGAS metrics")
                return dict(zero_scores)

            # Check if an OpenAI API key is available (still required by RAGAS internals)
            if not os.getenv("OPENAI_API_KEY"):
                print("⚠️ No OpenAI API key found - RAGAS still requires it for internal operations")
                return dict(zero_scores)

            # Create the dataset for RAGAS evaluation
            dataset = Dataset.from_dict({
                "question": [question],
                "answer": [generated_sql],
                "contexts": [[schema]],
                "ground_truth": [reference_sql]
            })

            # Create fresh metric instances wired to the HuggingFace LLM
            metrics_with_hf = []
            for metric in self.ragas_metrics:
                new_metric = metric.__class__()
                if hasattr(new_metric, 'llm'):
                    new_metric.llm = self.hf_llm
                metrics_with_hf.append(new_metric)

            # Evaluate with RAGAS using the HuggingFace LLM
            result = evaluate(
                dataset,
                metrics=metrics_with_hf
            )

            return {
                'faithfulness': result['faithfulness'][0] if 'faithfulness' in result else 0.0,
                'answer_relevancy': result['answer_relevancy'][0] if 'answer_relevancy' in result else 0.0,
                'context_precision': result['context_precision'][0] if 'context_precision' in result else 0.0,
                'context_recall': result['context_recall'][0] if 'context_recall' in result else 0.0
            }
        except Exception as e:
            print(f"⚠️ Error calculating RAGAS metrics with HuggingFace: {e}")
            return dict(zero_scores)

    def _execute_sql(self, sql: str, db_path: str) -> Optional[List]:
        """Execute SQL query and return results."""
        try:
            conn = duckdb.connect(db_path)
            result = conn.execute(sql).fetchall()
            conn.close()
            return result
        except Exception as e:
            print(f"⚠️ SQL execution error: {e}")
            return None

    def _calculate_composite_score(
        self,
        correctness_exact: float,
        result_match_f1: float,
        exec_success: float,
        latency_ms: float,
        readability: float,
        dialect_ok: float,
        ragas_metrics: Dict[str, float]
    ) -> float:
        """Calculate composite score with RAGAS metrics."""
        # Weights for the different metrics
        weights = {
            'correctness_exact': 0.25,
            'result_match_f1': 0.20,
            'exec_success': 0.15,
            'latency': 0.10,
            'readability': 0.05,
            'dialect_ok': 0.05,
            'ragas_faithfulness': 0.10,
            'ragas_relevancy': 0.10
        }

        # Normalize latency (lower is better); anything above 5 seconds scores 0
        latency_score = max(0, 1 - (latency_ms / 5000))

        # Calculate the weighted score
        score = (
            weights['correctness_exact'] * correctness_exact +
            weights['result_match_f1'] * result_match_f1 +
            weights['exec_success'] * exec_success +
            weights['latency'] * latency_score +
            weights['readability'] * readability +
            weights['dialect_ok'] * dialect_ok +
            weights['ragas_faithfulness'] * ragas_metrics.get('faithfulness', 0.0) +
            weights['ragas_relevancy'] * ragas_metrics.get('answer_relevancy', 0.0)
        )
        return min(1.0, max(0.0, score))

    def evaluate_batch(
        self,
        evaluations: List[Dict[str, Any]]
    ) -> List[EvaluationResult]:
        """Evaluate a batch of SQL generations."""
        results = []
        for eval_data in evaluations:
            result = self.evaluate_sql(
                model_name=eval_data['model_name'],
                dataset_name=eval_data['dataset_name'],
                dialect=eval_data['dialect'],
                case_id=eval_data['case_id'],
                question=eval_data['question'],
                reference_sql=eval_data['reference_sql'],
                generated_sql=eval_data['generated_sql'],
                schema=eval_data['schema'],
                db_path=eval_data['db_path']
            )
            results.append(result)
        return results

    def save_results(self, results: List[EvaluationResult], filepath: str):
        """Save evaluation results to file."""
        data = []
        for result in results:
            data.append({
                'model_name': result.model_name,
                'dataset_name': result.dataset_name,
                'dialect': result.dialect,
                'case_id': result.case_id,
                'question': result.question,
                'reference_sql': result.reference_sql,
                'generated_sql': result.generated_sql,
                'correctness_exact': result.correctness_exact,
                'result_match_f1': result.result_match_f1,
                'exec_success': result.exec_success,
                'latency_ms': result.latency_ms,
                'readability': result.readability,
                'dialect_ok': result.dialect_ok,
                'ragas_faithfulness': result.ragas_faithfulness,
                'ragas_relevancy': result.ragas_relevancy,
                'ragas_precision': result.ragas_precision,
                'ragas_recall': result.ragas_recall,
                'composite_score': result.composite_score
            })
        df = pd.DataFrame(data)
        df.to_parquet(filepath, index=False)
        print(f"💾 Results saved to {filepath}")

# Global instance
ragas_evaluator = RAGASEvaluator()
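
# --- Usage sketch (illustrative only) ---
# A minimal example of how this evaluator might be invoked. The model name,
# dataset name, question, SQL strings, schema, and "example.duckdb" path are
# hypothetical placeholders, not values defined elsewhere in this project.
if __name__ == "__main__":
    demo_result = ragas_evaluator.evaluate_sql(
        model_name="demo-model",
        dataset_name="demo-dataset",
        dialect="duckdb",
        case_id="case-001",
        question="How many users signed up in 2023?",
        reference_sql="SELECT COUNT(*) FROM users WHERE signup_year = 2023",
        generated_sql="SELECT COUNT(*) FROM users WHERE signup_year = 2023",
        schema="CREATE TABLE users (id INTEGER, signup_year INTEGER)",
        db_path="example.duckdb",
    )
    print(f"Composite score: {demo_result.composite_score:.3f}")
    ragas_evaluator.save_results([demo_result], "demo_results.parquet")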