| """ | |
| Evaluator Module | |
| Handles dataset loading, SQL execution, and metrics computation. | |
| """ | |
| import os | |
| import time | |
| import yaml | |
| import duckdb | |
| import sqlglot | |
| import pandas as pd | |
| from typing import Dict, Any, List, Tuple, Optional | |
| from dataclasses import dataclass | |
| from models_registry import models_registry, model_interface | |
| from scoring import Metrics, scoring_engine | |
@dataclass
class DatasetConfig:
    """Configuration for a dataset."""
    name: str
    schema_path: Optional[str]
    loader_path: Optional[str]
    cases_path: str


@dataclass
class CaseConfig:
    """Configuration for a test case."""
    id: str
    question: str
    reference_sql: Dict[str, str]  # dialect -> SQL
    difficulty: str
    description: str


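# For reference, DatasetManager.load_cases below expects cases.yaml to follow roughly
# this shape (a sketch inferred from CaseConfig; the concrete values are hypothetical):
#
#     cases:
#       - id: case_001
#         question: "How many orders were placed in 2023?"
#         reference_sql:
#           duckdb: "SELECT COUNT(*) FROM orders WHERE year = 2023"
#         difficulty: easy
#         description: "Simple aggregate over one table"

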
class DatasetManager:
    """Manages datasets and their configurations."""

    def __init__(self, tasks_dir: str = "tasks"):
        self.tasks_dir = tasks_dir
        self.datasets = self._discover_datasets()

    def _discover_datasets(self) -> Dict[str, DatasetConfig]:
        """Discover available datasets in the tasks directory."""
        datasets = {}
        if not os.path.exists(self.tasks_dir):
            return datasets

        # Look for datasets in the multi-use-case structure: tasks/<use_case>/<dataset>
        for use_case in os.listdir(self.tasks_dir):
            use_case_path = os.path.join(self.tasks_dir, use_case)
            if not os.path.isdir(use_case_path):
                continue

            # Look for datasets within each use case
            for dataset_name in os.listdir(use_case_path):
                dataset_path = os.path.join(use_case_path, dataset_name)
                if not os.path.isdir(dataset_path):
                    continue

                schema_path = os.path.join(dataset_path, "schema.sql")
                loader_path = os.path.join(dataset_path, "loader.py")
                cases_path = os.path.join(dataset_path, "cases.yaml")

                # Required files depend on the use case
                if use_case == "sql_generation":
                    # SQL generation needs schema, loader, and cases
                    required_files = [schema_path, loader_path, cases_path]
                elif use_case == "code_generation":
                    # Code generation needs loader and cases
                    required_files = [loader_path, cases_path]
                elif use_case == "documentation":
                    # Documentation only needs cases
                    required_files = [cases_path]
                else:
                    # Default: require all files
                    required_files = [schema_path, loader_path, cases_path]

                if all(os.path.exists(p) for p in required_files):
                    # Use "<use_case>/<dataset>" as the dataset name
                    full_name = f"{use_case}/{dataset_name}"
                    datasets[full_name] = DatasetConfig(
                        name=full_name,
                        schema_path=schema_path if os.path.exists(schema_path) else None,
                        loader_path=loader_path if os.path.exists(loader_path) else None,
                        cases_path=cases_path,
                    )
        return datasets

    def get_datasets(self) -> Dict[str, DatasetConfig]:
        """Get all available datasets."""
        return self.datasets

    def get_dataset(self, name: str) -> Optional[DatasetConfig]:
        """Get a specific dataset by name."""
        return self.datasets.get(name)

    def load_cases(self, dataset_name: str) -> List[CaseConfig]:
        """Load test cases for a dataset."""
        dataset = self.get_dataset(dataset_name)
        if not dataset:
            raise ValueError(f"Dataset not found: {dataset_name}")

        with open(dataset.cases_path, 'r') as f:
            cases_data = yaml.safe_load(f)

        cases = []
        for case_data in cases_data.get('cases', []):
            try:
                case = CaseConfig(
                    id=case_data['id'],
                    question=case_data['question'],
                    reference_sql=case_data['reference_sql'],  # Human-provided ground truth SQL
                    difficulty=case_data.get('difficulty', 'medium'),
                    description=case_data.get('description', ''),
                )
                cases.append(case)
            except KeyError as e:
                print(f"Missing key in case data: {e}")
                print(f"Available keys: {list(case_data.keys())}")
                raise
        return cases

    def create_database(self, dataset_name: str) -> str:
        """Create the database for a dataset by running its loader module."""
        dataset = self.get_dataset(dataset_name)
        if not dataset:
            raise ValueError(f"Dataset not found: {dataset_name}")
        if not dataset.loader_path:
            raise ValueError(f"Dataset has no loader: {dataset_name}")

        # Import and run the loader
        loader_dir = os.path.dirname(dataset.loader_path)
        loader_module_name = os.path.basename(dataset.loader_path).replace('.py', '')

        sys.path.insert(0, loader_dir)
        try:
            loader_module = __import__(loader_module_name)
            db_path = loader_module.create_database()
            return db_path
        finally:
            sys.path.remove(loader_dir)


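# Note: each dataset's loader.py is expected to expose a create_database() function
# that builds the dataset's DuckDB file and returns its path (this is the contract
# used by DatasetManager.create_database above). A minimal loader might look roughly
# like this (sketch; the table and file names are hypothetical):
#
#     import duckdb
#
#     def create_database() -> str:
#         db_path = "ecommerce.duckdb"
#         conn = duckdb.connect(db_path)
#         conn.execute("CREATE TABLE orders (id INTEGER, year INTEGER)")
#         conn.close()
#         return db_path

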
class SQLExecutor:
    """Handles SQL execution and result comparison."""

    def __init__(self):
        self.conn = None

    def connect(self, db_path: str):
        """Connect to a DuckDB database."""
        self.conn = duckdb.connect(db_path)

    def disconnect(self):
        """Disconnect from the database."""
        if self.conn:
            self.conn.close()
            self.conn = None

    def execute_sql(self, sql: str) -> Tuple[bool, Optional[pd.DataFrame], str]:
        """Execute SQL and return (success, result DataFrame, error message)."""
        if not self.conn:
            return False, None, "No database connection"
        try:
            result = self.conn.execute(sql).fetchdf()
            return True, result, ""
        except Exception as e:
            return False, None, str(e)

    def transpile_sql(self, sql: str, target_dialect: str) -> Tuple[bool, str, str]:
        """Transpile SQL to the target dialect using sqlglot."""
        try:
            # Parse the SQL, then render it in the target dialect
            parsed = sqlglot.parse_one(sql)
            transpiled = parsed.sql(dialect=target_dialect)
            return True, transpiled, ""
        except Exception as e:
            return False, sql, str(e)


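# Example usage of SQLExecutor (sketch; the database path and queries are hypothetical):
#
#     executor = SQLExecutor()
#     executor.connect("example.duckdb")
#     ok, df, err = executor.execute_sql("SELECT 1 AS one")
#     ok, transpiled, err = executor.transpile_sql("SELECT 1 AS one", "postgres")
#     executor.disconnect()

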
class MetricsComputer:
    """Computes evaluation metrics for SQL queries."""

    def __init__(self):
        self.executor = SQLExecutor()

    def compute_result_match_f1(self, reference_df: pd.DataFrame, candidate_df: pd.DataFrame) -> float:
        """Compute an F1 score over the rows of the reference and candidate results."""
        if reference_df is None or candidate_df is None:
            return 0.0
        try:
            # Convert result rows to sets of tuples for comparison
            reference_set = set(tuple(row) for row in reference_df.values)
            candidate_set = set(tuple(row) for row in candidate_df.values)

            if not reference_set and not candidate_set:
                return 1.0
            if not reference_set or not candidate_set:
                return 0.0

            # Precision and recall over matching rows
            intersection = reference_set.intersection(candidate_set)
            precision = len(intersection) / len(candidate_set)
            recall = len(intersection) / len(reference_set)

            # Compute F1
            if precision + recall == 0:
                return 0.0
            return 2 * (precision * recall) / (precision + recall)
        except Exception:
            return 0.0

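    # Worked example: if the reference result has 4 rows, the candidate has 5, and
    # 3 rows appear in both, then precision = 3/5 = 0.6, recall = 3/4 = 0.75, and
    # F1 = 2 * 0.6 * 0.75 / (0.6 + 0.75) ≈ 0.667.
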
    def compute_metrics(self, reference_sql: str, candidate_sql: str,
                        target_dialect: str, db_path: str) -> Metrics:
        """Compute all metrics for a candidate SQL query."""
        # Connect to the database
        self.executor.connect(db_path)
        try:
            # Execute the reference SQL
            ref_success, ref_result, ref_error = self.executor.execute_sql(reference_sql)

            # Transpile the candidate SQL to the target dialect
            transpile_success, transpiled_sql, transpile_error = self.executor.transpile_sql(
                candidate_sql, target_dialect
            )

            # Execute the candidate SQL
            if transpile_success:
                cand_success, cand_result, cand_error = self.executor.execute_sql(transpiled_sql)
            else:
                cand_success, cand_result, cand_error = False, None, transpile_error

            # Compute metrics
            correctness_exact = 1.0 if (ref_success and cand_success and
                                        self._results_equal(ref_result, cand_result)) else 0.0

            result_match_f1 = 0.0
            if ref_success and cand_success:
                result_match_f1 = self.compute_result_match_f1(ref_result, cand_result)

            exec_success = 1.0 if cand_success else 0.0
            dialect_ok = 1.0 if transpile_success else 0.0

            # For now, use a default readability (proper computation would analyse the candidate SQL)
            readability = 0.8

            # Latency is not measured here; the caller fills it in from generation timing
            latency_ms = 0.0

            return Metrics(
                correctness_exact=correctness_exact,
                result_match_f1=result_match_f1,
                exec_success=exec_success,
                latency_ms=latency_ms,
                readability=readability,
                dialect_ok=dialect_ok,
            )
        finally:
            self.executor.disconnect()

    def _results_equal(self, df1: pd.DataFrame, df2: pd.DataFrame) -> bool:
        """Check whether two DataFrames are equal (same shape and values, ignoring the index)."""
        if df1 is None and df2 is None:
            return True
        if df1 is None or df2 is None:
            return False
        try:
            # Reset indices, then compare shapes and values
            df1_reset = df1.reset_index(drop=True)
            df2_reset = df2.reset_index(drop=True)
            if df1_reset.shape != df2_reset.shape:
                return False
            return df1_reset.equals(df2_reset)
        except Exception:
            return False


class Evaluator:
    """Main evaluator class that orchestrates the evaluation process."""

    def __init__(self):
        self.dataset_manager = DatasetManager()
        self.metrics_computer = MetricsComputer()

    def evaluate_model_on_case(self, model_name: str, dataset_name: str,
                               case_id: str, dialect: str, prompt_template: str) -> Dict[str, Any]:
        """Evaluate a model on a specific case."""
        # Get model configuration
        model_config = models_registry.get_model_by_name(model_name)
        if not model_config:
            raise ValueError(f"Model not found: {model_name}")

        # Get dataset and case
        cases = self.dataset_manager.load_cases(dataset_name)
        case = next((c for c in cases if c.id == case_id), None)
        if not case:
            raise ValueError(f"Case not found: {case_id}")

        # Get the reference SQL for the requested dialect
        reference_sql = case.reference_sql.get(dialect)
        if not reference_sql:
            raise ValueError(f"Reference SQL not found for dialect: {dialect}")

        # Create the database
        db_path = self.dataset_manager.create_database(dataset_name)

        try:
            # Load the schema for the prompt
            dataset = self.dataset_manager.get_dataset(dataset_name)
            with open(dataset.schema_path, 'r') as f:
                schema = f.read()

            # Build the prompt
            prompt = prompt_template.format(schema=schema, question=case.question)

            # Generate SQL
            start_time = time.time()
            try:
                candidate_sql = model_interface.generate_sql(model_config, prompt)
                generation_time = (time.time() - start_time) * 1000  # Convert to ms
            except Exception as e:
                candidate_sql = ""
                generation_time = 0.0
                print(f"Error generating SQL: {e}")

            # Compute metrics
            metrics = self.metrics_computer.compute_metrics(
                reference_sql, candidate_sql, dialect, db_path
            )

            # Record the generation latency
            metrics.latency_ms = generation_time

            # Compute the composite score
            composite_score = scoring_engine.compute_composite_score(metrics)
        finally:
            # Clean up the temporary database, even if evaluation failed
            if os.path.exists(db_path):
                os.remove(db_path)

        return {
            'model_name': model_name,
            'provider': model_config.provider,
            'dataset_name': dataset_name,
            'case_id': case_id,
            'dialect': dialect,
            'question': case.question,
            'reference_sql': reference_sql,
            'candidate_sql': candidate_sql,
            'correctness_exact': metrics.correctness_exact,
            'result_match_f1': metrics.result_match_f1,
            'exec_success': metrics.exec_success,
            'latency_ms': metrics.latency_ms,
            'readability': metrics.readability,
            'dialect_ok': metrics.dialect_ok,
            'composite_score': composite_score,
            'timestamp': time.time(),
        }


# Global evaluator instance
evaluator = Evaluator()
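
# Example invocation (sketch; the model, dataset, case, and template values below are
# hypothetical and must match entries in the models registry and tasks directory):
#
#     result = evaluator.evaluate_model_on_case(
#         model_name="gpt-4o-mini",
#         dataset_name="sql_generation/ecommerce",
#         case_id="case_001",
#         dialect="duckdb",
#         prompt_template="Schema:\n{schema}\n\nQuestion: {question}\nSQL:",
#     )
#     print(result["composite_score"])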