| """ | |
| Custom SQL evaluation metrics without RAGAS dependency. | |
| Provides comprehensive evaluation using only local models and basic metrics. | |
| """ | |
| import os | |
| import time | |
| import re | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Any, Optional | |
| import pandas as pd | |
| import numpy as np | |
| from transformers import pipeline, AutoTokenizer, AutoModel | |
| import torch | |
| from langchain_models import langchain_models_registry | |
@dataclass
class EvaluationResult:
    """Result of a single SQL evaluation."""
    model_name: str
    dataset: str
    case_id: str
    dialect: str
    question: str
    raw_sql: str        # Raw SQL from the model (before cleaning)
    generated_sql: str  # Cleaned SQL (after cleaning)
    reference_sql: str
    correctness_exact: float
    result_match_f1: float
    exec_success: float
    latency_ms: float
    readability: float
    dialect_ok: float
    # Custom metrics without RAGAS
    sql_quality: float
    semantic_similarity: float
    structural_similarity: float
    composite_score: float
    timestamp: str
class CustomEvaluator:
    """Custom evaluator for SQL generation without RAGAS dependency."""

    def __init__(self):
        self.similarity_model = None
        self._setup_similarity_model()

    def _setup_similarity_model(self):
        """Set up a local model for semantic similarity."""
        try:
            print("📥 Setting up local similarity model...")
            self.similarity_model = pipeline(
                "feature-extraction",
                model="sentence-transformers/all-MiniLM-L6-v2",
                device=-1,  # Use CPU
            )
            print("✅ Local similarity model configured")
        except Exception as e:
            print(f"⚠️ Could not set up similarity model: {e}")
            self.similarity_model = None
    def evaluate_sql(
        self,
        model_name: str,
        dataset: str,
        case_id: str,
        dialect: str,
        question: str,
        raw_sql: str,
        generated_sql: str,
        reference_sql: str,
        schema: str,
        db_conn,
    ) -> EvaluationResult:
        """Evaluate generated SQL against the reference query."""
        start_time = time.time()

        # Basic metrics
        correctness_exact = self._calculate_exact_correctness(generated_sql, reference_sql)
        result_match_f1 = self._calculate_result_match_f1(generated_sql, reference_sql, db_conn)
        exec_success = self._calculate_execution_success(generated_sql, db_conn)
        readability = self._calculate_readability(generated_sql)
        dialect_ok = self._calculate_dialect_compliance(generated_sql, dialect)

        # Custom metrics
        sql_quality = self._calculate_sql_quality(generated_sql, question, schema)
        semantic_similarity = self._calculate_semantic_similarity(generated_sql, reference_sql)
        structural_similarity = self._calculate_structural_similarity(generated_sql, reference_sql)

        # Note: this measures evaluation latency, not model generation latency
        latency_ms = (time.time() - start_time) * 1000

        # Weighted composite score (weights sum to 1.0)
        composite_score = (
            correctness_exact * 0.3 +
            result_match_f1 * 0.3 +
            exec_success * 0.2 +
            sql_quality * 0.1 +
            semantic_similarity * 0.1
        )

        return EvaluationResult(
            model_name=model_name,
            dataset=dataset,
            case_id=case_id,
            dialect=dialect,
            question=question,
            raw_sql=raw_sql,
            generated_sql=generated_sql,
            reference_sql=reference_sql,
            correctness_exact=correctness_exact,
            result_match_f1=result_match_f1,
            exec_success=exec_success,
            latency_ms=latency_ms,
            readability=readability,
            dialect_ok=dialect_ok,
            sql_quality=sql_quality,
            semantic_similarity=semantic_similarity,
            structural_similarity=structural_similarity,
            composite_score=composite_score,
            timestamp=pd.Timestamp.now().isoformat(),
        )
    def _calculate_exact_correctness(self, generated_sql: str, reference_sql: str) -> float:
        """Calculate exact string-match correctness on normalized SQL."""
        gen_norm = self._normalize_sql(generated_sql)
        ref_norm = self._normalize_sql(reference_sql)
        return 1.0 if gen_norm == ref_norm else 0.0
    def _calculate_result_match_f1(self, generated_sql: str, reference_sql: str, db_conn) -> float:
        """Calculate F1 score over the rows returned by both queries."""
        try:
            # Clean the generated SQL before execution
            clean_generated_sql = langchain_models_registry.clean_sql(generated_sql)

            # Execute both queries
            gen_result = db_conn.execute(clean_generated_sql).fetchall()
            ref_result = db_conn.execute(reference_sql).fetchall()

            # Convert rows to string sets for order-insensitive comparison
            gen_set = set(str(row) for row in gen_result)
            ref_set = set(str(row) for row in ref_result)

            if not ref_set:
                return 1.0 if not gen_set else 0.0

            # Calculate F1 from set precision and recall
            intersection = gen_set & ref_set
            precision = len(intersection) / len(gen_set) if gen_set else 0.0
            recall = len(intersection) / len(ref_set)
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
            return f1
        except Exception as e:
            print(f"⚠️ Error calculating result match F1: {e}")
            return 0.0
    def _calculate_execution_success(self, generated_sql: str, db_conn) -> float:
        """Check whether the SQL executes without error."""
        try:
            # Clean the generated SQL before execution
            clean_generated_sql = langchain_models_registry.clean_sql(generated_sql)
            db_conn.execute(clean_generated_sql)
            return 1.0
        except Exception as e:
            print(f"⚠️ SQL execution error: {e}")
            return 0.0
    def _calculate_readability(self, generated_sql: str) -> float:
        """Calculate a heuristic SQL readability score."""
        try:
            # Basic readability metrics
            lines = generated_sql.strip().split('\n')
            avg_line_length = sum(len(line.strip()) for line in lines) / len(lines) if lines else 0

            # Check for indentation after the first line, and for keywords
            # actually written in uppercase in the raw (non-uppercased) SQL
            has_proper_indentation = any(line.startswith(' ') or line.startswith('\t') for line in lines[1:])
            has_keywords_capitalized = any(keyword in generated_sql for keyword in ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY'])

            # Score based on formatting
            score = 0.0
            if has_keywords_capitalized:
                score += 0.4
            if has_proper_indentation:
                score += 0.3
            if 20 <= avg_line_length <= 80:  # Reasonable line length
                score += 0.3
            return min(score, 1.0)
        except Exception:
            return 0.0
    def _calculate_dialect_compliance(self, generated_sql: str, dialect: str) -> float:
        """Calculate a heuristic dialect-compliance score."""
        try:
            sql_upper = generated_sql.upper()
            score = 0.0

            # Basic SQL compliance
            if any(keyword in sql_upper for keyword in ['SELECT', 'FROM']):
                score += 0.3

            # Dialect-specific checks
            if dialect.lower() == 'presto':
                # Presto-specific features
                if 'ARRAY' in sql_upper or 'MAP' in sql_upper:
                    score += 0.2
                if 'APPROX_DISTINCT' in sql_upper:
                    score += 0.2
            elif dialect.lower() == 'bigquery':
                # BigQuery-specific features
                if 'ARRAY_AGG' in sql_upper or 'STRUCT' in sql_upper:
                    score += 0.2
                if 'QUALIFY' in sql_upper:
                    score += 0.2
            elif dialect.lower() == 'snowflake':
                # Snowflake-specific features
                if 'QUALIFY' in sql_upper:
                    score += 0.2
                if 'ARRAY_CONSTRUCT' in sql_upper:
                    score += 0.2

            # General SQL quality
            if 'WHERE' in sql_upper or 'GROUP BY' in sql_upper or 'ORDER BY' in sql_upper:
                score += 0.3

            return min(score, 1.0)
        except Exception:
            return 0.0
    def _calculate_sql_quality(self, generated_sql: str, question: str, schema: str) -> float:
        """Calculate an overall SQL quality score."""
        try:
            score = 0.0

            # Check if the SQL addresses the question
            question_lower = question.lower()
            sql_lower = generated_sql.lower()

            # Question-SQL alignment
            if 'count' in question_lower and 'count(' in sql_lower:
                score += 0.2
            if 'average' in question_lower and 'avg(' in sql_lower:
                score += 0.2
            if 'sum' in question_lower and 'sum(' in sql_lower:
                score += 0.2
            if 'group' in question_lower and 'group by' in sql_lower:
                score += 0.2

            # Schema usage: match tables referenced after FROM/JOIN against the
            # schema's table names, case-insensitively
            schema_tables = re.findall(r'CREATE TABLE (\w+)', schema, re.IGNORECASE)
            used_tables = re.findall(r'\b(?:from|join)\s+(\w+)', sql_lower)
            if any(table.lower() in used_tables for table in schema_tables):
                score += 0.2

            return min(score, 1.0)
        except Exception:
            return 0.0
    def _calculate_semantic_similarity(self, generated_sql: str, reference_sql: str) -> float:
        """Calculate semantic similarity between SQL queries."""
        try:
            if not self.similarity_model:
                # Fall back to keyword-overlap similarity
                return self._basic_similarity(generated_sql, reference_sql)

            # The feature-extraction pipeline returns per-token embeddings, so
            # mean-pool each query down to a single 1-D vector before comparing;
            # pooling first also handles queries of different token lengths.
            gen_emb = np.array(self.similarity_model(generated_sql))
            ref_emb = np.array(self.similarity_model(reference_sql))
            while gen_emb.ndim > 1:
                gen_emb = gen_emb.mean(axis=0)
            while ref_emb.ndim > 1:
                ref_emb = ref_emb.mean(axis=0)

            if gen_emb.shape != ref_emb.shape:
                # Use basic similarity if the pooled shapes still disagree
                return self._basic_similarity(generated_sql, reference_sql)

            # Cosine similarity, clamped to [0, 1] so it composes with the other scores
            similarity = np.dot(gen_emb, ref_emb) / (np.linalg.norm(gen_emb) * np.linalg.norm(ref_emb))
            return float(max(0.0, similarity))
        except Exception as e:
            print(f"⚠️ Error calculating semantic similarity: {e}")
            return self._basic_similarity(generated_sql, reference_sql)
    def _calculate_structural_similarity(self, generated_sql: str, reference_sql: str) -> float:
        """Calculate structural similarity between SQL queries."""
        try:
            # Extract SQL structure elements
            gen_structure = self._extract_sql_structure(generated_sql)
            ref_structure = self._extract_sql_structure(reference_sql)

            # Calculate Jaccard similarity over the element sets
            gen_set = set(gen_structure)
            ref_set = set(ref_structure)

            if not gen_set and not ref_set:
                return 1.0
            if not gen_set or not ref_set:
                return 0.0

            intersection = gen_set & ref_set
            union = gen_set | ref_set
            return len(intersection) / len(union)
        except Exception:
            return 0.0
    def _basic_similarity(self, sql1: str, sql2: str) -> float:
        """Keyword-overlap similarity used as a fallback."""
        try:
            # Extract major SQL keywords from both queries
            keywords1 = set(re.findall(r'\b(SELECT|FROM|WHERE|GROUP BY|ORDER BY|HAVING|JOIN|UNION)\b', sql1.upper()))
            keywords2 = set(re.findall(r'\b(SELECT|FROM|WHERE|GROUP BY|ORDER BY|HAVING|JOIN|UNION)\b', sql2.upper()))

            if not keywords1 and not keywords2:
                return 1.0
            if not keywords1 or not keywords2:
                return 0.0

            intersection = keywords1 & keywords2
            union = keywords1 | keywords2
            return len(intersection) / len(union)
        except Exception:
            return 0.0
    def _extract_sql_structure(self, sql: str) -> List[str]:
        """Extract SQL structure elements (clauses, functions, operators)."""
        try:
            structure = []
            sql_upper = sql.upper()

            # Extract main clauses
            clauses = ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'HAVING', 'LIMIT']
            for clause in clauses:
                if clause in sql_upper:
                    structure.append(clause)

            # Extract aggregate functions and DISTINCT
            functions = re.findall(r'\b(COUNT|SUM|AVG|MIN|MAX|DISTINCT)\b', sql_upper)
            structure.extend(functions)

            # Extract logical and comparison operators
            operators = re.findall(r'\b(AND|OR|IN|NOT IN|BETWEEN|LIKE)\b', sql_upper)
            structure.extend(operators)

            return structure
        except Exception:
            return []
    def _normalize_sql(self, sql: str) -> str:
        """Normalize SQL for comparison."""
        try:
            # Collapse whitespace, uppercase, and drop trailing semicolons
            normalized = re.sub(r'\s+', ' ', sql.strip())
            normalized = normalized.upper()
            normalized = normalized.rstrip(';')
            return normalized
        except Exception:
            return sql
# Global instance
custom_evaluator = CustomEvaluator()
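

# --- Usage sketch ---
# A minimal, illustrative demo of the evaluator against an in-memory SQLite
# database. The schema, question, and queries below are hypothetical examples,
# not part of the module's datasets; any DB-API connection whose execute()
# returns an object with fetchall() (e.g. sqlite3 or DuckDB) should work.
# Running this still requires the langchain_models module imported above.
if __name__ == "__main__":
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE users (id INTEGER, country TEXT)")
    conn.executemany("INSERT INTO users VALUES (?, ?)", [(1, "US"), (2, "DE")])

    result = custom_evaluator.evaluate_sql(
        model_name="demo-model",
        dataset="demo",
        case_id="case-1",
        dialect="presto",  # only used by the heuristic dialect check
        question="How many users are there?",
        raw_sql="```sql\nSELECT COUNT(*) FROM users;\n```",
        generated_sql="SELECT COUNT(*) FROM users",
        reference_sql="SELECT COUNT(*) FROM users",
        schema="CREATE TABLE users (id INTEGER, country TEXT)",
        db_conn=conn,
    )
    print(f"Composite score: {result.composite_score:.2f}")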