"""
Custom SQL evaluation metrics without RAGAS dependency.

Provides comprehensive evaluation using only local models and basic metrics.
"""

import os
import re
import time
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import torch
from transformers import AutoModel, AutoTokenizer, pipeline

from langchain_models import langchain_models_registry


@dataclass
class EvaluationResult:
    """Result of SQL evaluation for a single generated query."""

    model_name: str
    dataset: str
    case_id: str
    dialect: str
    question: str
    raw_sql: str  # Raw SQL from model (before cleaning)
    generated_sql: str  # Cleaned SQL (after cleaning)
    reference_sql: str
    correctness_exact: float  # 1.0 if normalized SQL strings are equal, else 0.0
    result_match_f1: float  # row-level F1 between generated and reference result rows
    exec_success: float  # 1.0 if the (re-cleaned) generated SQL executed, else 0.0
    latency_ms: float  # wall-clock ms spent computing metrics inside evaluate_sql
    readability: float  # formatting heuristic in [0, 1]
    dialect_ok: float  # dialect-feature heuristic in [0, 1]
    # Custom metrics without RAGAS
    sql_quality: float  # question/schema alignment heuristic in [0, 1]
    semantic_similarity: float  # embedding cosine (or keyword Jaccard fallback) in [0, 1]
    structural_similarity: float  # Jaccard over extracted clauses/functions/operators
    composite_score: float  # weighted mix; see CustomEvaluator.evaluate_sql
    timestamp: str  # ISO-8601 string from pd.Timestamp.now() (timezone-naive)


class CustomEvaluator:
    """Custom evaluator for SQL generation without RAGAS dependency.

    Combines string-level, execution-level and embedding-based heuristics into
    an :class:`EvaluationResult`. The embedding model is optional; when it
    cannot be loaded, a keyword-based similarity is used instead.
    """

    def __init__(self) -> None:
        self.similarity_model = None
        self._setup_similarity_model()

    def _setup_similarity_model(self) -> None:
        """Load a local feature-extraction pipeline, or leave it as ``None``.

        Any failure is reported with ``print`` and swallowed on purpose so the
        evaluator still works with the keyword-based fallback similarity.
        """
        try:
            print("📥 Setting up local similarity model...")
            self.similarity_model = pipeline(
                "feature-extraction",
                model="sentence-transformers/all-MiniLM-L6-v2",
                device=-1,  # Use CPU
            )
            print("✅ Local similarity model configured")
        except Exception as e:
            print(f"⚠️ Could not setup similarity model: {e}")
            self.similarity_model = None

    def evaluate_sql(
        self,
        model_name: str,
        dataset: str,
        case_id: str,
        dialect: str,
        question: str,
        raw_sql: str,
        generated_sql: str,
        reference_sql: str,
        schema: str,
        db_conn,
    ) -> EvaluationResult:
        """Evaluate generated SQL against a reference query.

        Args:
            model_name, dataset, case_id: identifiers copied into the result.
            dialect: target SQL dialect name (e.g. ``"presto"``, ``"bigquery"``).
            question: the natural-language question the SQL should answer.
            raw_sql: model output before cleaning (stored only).
            generated_sql: cleaned model SQL that is scored.
            reference_sql: gold SQL.
            schema: DDL text; ``CREATE TABLE`` names are used by ``sql_quality``.
            db_conn: object exposing ``execute(sql)`` returning something with
                ``fetchall()`` (e.g. a DB-API connection).

        Returns:
            An :class:`EvaluationResult`. ``composite_score`` is
            0.3*correctness_exact + 0.3*result_match_f1 + 0.2*exec_success
            + 0.1*sql_quality + 0.1*semantic_similarity; readability,
            dialect_ok and structural_similarity are reported but not weighted.

        NOTE(review): model-generated SQL is executed against ``db_conn`` (twice:
        once for F1, once for exec success). Use a read-only connection or
        disposable database if the SQL may contain writes/DDL.
        """
        start_time = time.time()

        # Basic metrics
        correctness_exact = self._calculate_exact_correctness(generated_sql, reference_sql)
        result_match_f1 = self._calculate_result_match_f1(generated_sql, reference_sql, db_conn)
        exec_success = self._calculate_execution_success(generated_sql, db_conn)
        readability = self._calculate_readability(generated_sql)
        dialect_ok = self._calculate_dialect_compliance(generated_sql, dialect)

        # Custom metrics
        sql_quality = self._calculate_sql_quality(generated_sql, question, schema)
        semantic_similarity = self._calculate_semantic_similarity(generated_sql, reference_sql)
        structural_similarity = self._calculate_structural_similarity(generated_sql, reference_sql)

        # This is the time spent evaluating, not the model's generation latency.
        latency_ms = (time.time() - start_time) * 1000

        composite_score = (
            correctness_exact * 0.3
            + result_match_f1 * 0.3
            + exec_success * 0.2
            + sql_quality * 0.1
            + semantic_similarity * 0.1
        )

        return EvaluationResult(
            model_name=model_name,
            dataset=dataset,
            case_id=case_id,
            dialect=dialect,
            question=question,
            raw_sql=raw_sql,
            generated_sql=generated_sql,
            reference_sql=reference_sql,
            correctness_exact=correctness_exact,
            result_match_f1=result_match_f1,
            exec_success=exec_success,
            latency_ms=latency_ms,
            readability=readability,
            dialect_ok=dialect_ok,
            sql_quality=sql_quality,
            semantic_similarity=semantic_similarity,
            structural_similarity=structural_similarity,
            composite_score=composite_score,
            timestamp=pd.Timestamp.now().isoformat(),
        )

    def _calculate_exact_correctness(self, generated_sql: str, reference_sql: str) -> float:
        """Return 1.0 if both queries are equal after :meth:`_normalize_sql`, else 0.0."""
        gen_norm = self._normalize_sql(generated_sql)
        ref_norm = self._normalize_sql(reference_sql)
        return 1.0 if gen_norm == ref_norm else 0.0

    def _calculate_result_match_f1(self, generated_sql: str, reference_sql: str, db_conn) -> float:
        """Row-level F1 between the result sets of the generated and reference SQL.

        Rows are compared by their ``str()`` form as a multiset, so duplicate
        rows count (e.g. a missing ``DISTINCT`` lowers precision). An empty
        reference result scores 1.0 only if the generated result is also
        empty. Any execution error is reported and scores 0.0.
        """
        try:
            # Clean the generated SQL before execution
            clean_generated_sql = langchain_models_registry.clean_sql(generated_sql)

            gen_result = db_conn.execute(clean_generated_sql).fetchall()
            ref_result = db_conn.execute(reference_sql).fetchall()

            gen_rows = Counter(str(row) for row in gen_result)
            ref_rows = Counter(str(row) for row in ref_result)

            if not ref_rows:
                return 1.0 if not gen_rows else 0.0

            # Counter & Counter keeps the per-row minimum count (multiset intersection)
            overlap = sum((gen_rows & ref_rows).values())
            if overlap == 0:
                return 0.0

            precision = overlap / sum(gen_rows.values())
            recall = overlap / sum(ref_rows.values())
            return 2 * precision * recall / (precision + recall)
        except Exception as e:
            print(f"⚠️ Error calculating result match F1: {e}")
            return 0.0

    def _calculate_execution_success(self, generated_sql: str, db_conn) -> float:
        """Return 1.0 if the cleaned generated SQL executes without raising, else 0.0."""
        try:
            # Clean the generated SQL before execution
            clean_generated_sql = langchain_models_registry.clean_sql(generated_sql)
            db_conn.execute(clean_generated_sql)
            return 1.0
        except Exception as e:
            print(f"⚠️ SQL execution error: {e}")
            return 0.0

    def _calculate_readability(self, generated_sql: str) -> float:
        """Heuristic formatting score in [0, 1].

        +0.4 if at least one major keyword appears in upper case,
        +0.3 if any line after the first is indented,
        +0.3 if the average stripped line length is between 20 and 80.
        """
        try:
            lines = generated_sql.strip().split('\n')
            avg_line_length = sum(len(line.strip()) for line in lines) / len(lines) if lines else 0

            has_proper_indentation = any(line.startswith((' ', '\t')) for line in lines[1:])
            # Case-sensitive on purpose: this checks the SQL as written, so that
            # lowercase keywords do not count as "capitalized".
            has_keywords_capitalized = any(
                keyword in generated_sql
                for keyword in ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY']
            )

            score = 0.0
            if has_keywords_capitalized:
                score += 0.4
            if has_proper_indentation:
                score += 0.3
            if 20 <= avg_line_length <= 80:  # Reasonable line length
                score += 0.3

            return min(score, 1.0)
        except Exception:
            return 0.0

    def _calculate_dialect_compliance(self, generated_sql: str, dialect: str) -> float:
        """Heuristic dialect score in [0, 1] based on keyword substrings.

        0.3 for SELECT/FROM, up to 0.4 for dialect-specific features
        (presto, bigquery, snowflake only), 0.3 for WHERE/GROUP BY/ORDER BY.
        Other dialects can therefore reach at most 0.6.
        """
        try:
            sql_upper = generated_sql.upper()
            score = 0.0

            # Basic SQL compliance
            if any(keyword in sql_upper for keyword in ['SELECT', 'FROM']):
                score += 0.3

            # Dialect-specific checks (plain substring matches, e.g. 'ARRAY'
            # also matches 'ARRAY_AGG')
            if dialect.lower() == 'presto':
                if 'ARRAY' in sql_upper or 'MAP' in sql_upper:
                    score += 0.2
                if 'APPROX_DISTINCT' in sql_upper:
                    score += 0.2
            elif dialect.lower() == 'bigquery':
                if 'ARRAY_AGG' in sql_upper or 'STRUCT' in sql_upper:
                    score += 0.2
                if 'QUALIFY' in sql_upper:
                    score += 0.2
            elif dialect.lower() == 'snowflake':
                if 'QUALIFY' in sql_upper:
                    score += 0.2
                if 'ARRAY_CONSTRUCT' in sql_upper:
                    score += 0.2

            # General SQL quality
            if 'WHERE' in sql_upper or 'GROUP BY' in sql_upper or 'ORDER BY' in sql_upper:
                score += 0.3

            return min(score, 1.0)
        except Exception:
            return 0.0

    def _calculate_sql_quality(self, generated_sql: str, question: str, schema: str) -> float:
        """Heuristic question/schema alignment score in [0, 1].

        +0.2 for each aggregate the question asks for and the SQL uses
        (count, average, sum, group), and +0.2 if the SQL reads (FROM/JOIN)
        at least one table declared via ``CREATE TABLE`` in ``schema``.
        """
        try:
            score = 0.0

            question_lower = question.lower()
            sql_lower = generated_sql.lower()

            # Question-SQL alignment
            if 'count' in question_lower and 'count(' in sql_lower:
                score += 0.2
            if 'average' in question_lower and 'avg(' in sql_lower:
                score += 0.2
            if 'sum' in question_lower and 'sum(' in sql_lower:
                score += 0.2
            if 'group' in question_lower and 'group by' in sql_lower:
                score += 0.2

            # Schema usage. Both patterns are case-insensitive (the previous
            # uppercase 'FROM' pattern run against lowercased SQL never matched).
            schema_tables = {
                name.lower()
                for name in re.findall(
                    r'CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?["`\[]?(\w+)',
                    schema,
                    re.IGNORECASE,
                )
            }
            used_tables = {
                name.lower()
                for name in re.findall(
                    r'\b(?:FROM|JOIN)\s+["`\[]?(\w+)', generated_sql, re.IGNORECASE
                )
            }
            if schema_tables & used_tables:
                score += 0.2

            return min(score, 1.0)
        except Exception:
            return 0.0

    @staticmethod
    def _pool_embedding(embedding: Any) -> Optional[np.ndarray]:
        """Mean-pool an embedding of any rank down to a 1-D vector.

        The transformers ``feature-extraction`` pipeline returns token-level
        hidden states as nested lists whose last axis is the hidden size;
        averaging over every other axis yields one sentence vector. A 1-D
        input is returned unchanged. Returns ``None`` for scalar/empty input.
        """
        arr = np.asarray(embedding, dtype=float)
        if arr.ndim == 0 or arr.size == 0:
            return None
        return arr.reshape(-1, arr.shape[-1]).mean(axis=0)

    def _calculate_semantic_similarity(self, generated_sql: str, reference_sql: str) -> float:
        """Cosine similarity of mean-pooled embeddings, clamped to [0, 1].

        Falls back to :meth:`_basic_similarity` when no model is loaded, the
        output is not two embeddings, the pooled vectors differ in size or
        have zero norm, or the model call raises.
        """
        try:
            if not self.similarity_model:
                return self._basic_similarity(generated_sql, reference_sql)

            embeddings = self.similarity_model([generated_sql, reference_sql])
            if len(embeddings) != 2:
                return self._basic_similarity(generated_sql, reference_sql)

            # Pool BEFORE comparing shapes: token counts differ between the two
            # queries, so raw per-token outputs almost never share a shape.
            gen_vec = self._pool_embedding(embeddings[0])
            ref_vec = self._pool_embedding(embeddings[1])
            if gen_vec is None or ref_vec is None or gen_vec.shape != ref_vec.shape:
                return self._basic_similarity(generated_sql, reference_sql)

            denom = np.linalg.norm(gen_vec) * np.linalg.norm(ref_vec)
            if denom == 0:
                return self._basic_similarity(generated_sql, reference_sql)

            similarity = float(np.dot(gen_vec, ref_vec) / denom)
            # Keep this on the same [0, 1] scale as the other composite inputs.
            return min(max(similarity, 0.0), 1.0)
        except Exception as e:
            print(f"⚠️ Error calculating semantic similarity: {e}")
            return self._basic_similarity(generated_sql, reference_sql)

    def _calculate_structural_similarity(self, generated_sql: str, reference_sql: str) -> float:
        """Jaccard similarity of the structure elements of both queries.

        Returns 1.0 if neither query has any recognized element and 0.0 if
        only one does.
        """
        try:
            gen_set = set(self._extract_sql_structure(generated_sql))
            ref_set = set(self._extract_sql_structure(reference_sql))

            if not gen_set and not ref_set:
                return 1.0
            if not gen_set or not ref_set:
                return 0.0

            intersection = gen_set & ref_set
            union = gen_set | ref_set
            return len(intersection) / len(union)
        except Exception:
            return 0.0

    def _basic_similarity(self, sql1: str, sql2: str) -> float:
        """Keyword-set Jaccard similarity used when embeddings are unavailable."""
        try:
            keywords1 = set(re.findall(r'\b(SELECT|FROM|WHERE|GROUP BY|ORDER BY|HAVING|JOIN|UNION)\b', sql1.upper()))
            keywords2 = set(re.findall(r'\b(SELECT|FROM|WHERE|GROUP BY|ORDER BY|HAVING|JOIN|UNION)\b', sql2.upper()))

            if not keywords1 and not keywords2:
                return 1.0
            if not keywords1 or not keywords2:
                return 0.0

            intersection = keywords1 & keywords2
            union = keywords1 | keywords2
            return len(intersection) / len(union)
        except Exception:
            return 0.0

    def _extract_sql_structure(self, sql: str) -> List[str]:
        """List clauses (substring match), aggregate functions and operators found in ``sql``."""
        try:
            structure = []
            sql_upper = sql.upper()

            # Main clauses
            clauses = ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'HAVING', 'LIMIT']
            for clause in clauses:
                if clause in sql_upper:
                    structure.append(clause)

            # Functions
            functions = re.findall(r'\b(COUNT|SUM|AVG|MIN|MAX|DISTINCT)\b', sql_upper)
            structure.extend(functions)

            # Operators
            operators = re.findall(r'\b(AND|OR|IN|NOT IN|BETWEEN|LIKE)\b', sql_upper)
            structure.extend(operators)

            return structure
        except Exception:
            return []

    def _normalize_sql(self, sql: str) -> str:
        """Collapse whitespace, uppercase, and drop trailing semicolons.

        Note that uppercasing also affects string literals, so ``'a'`` and
        ``'A'`` compare equal. Non-string input is returned unchanged.
        """
        try:
            normalized = re.sub(r'\s+', ' ', sql.strip()).upper()
            # Strip trailing terminators together with surrounding whitespace so
            # "SELECT 1 ;" and "SELECT 1" normalize identically.
            return re.sub(r'[\s;]+$', '', normalized)
        except (AttributeError, TypeError):
            return sql


# Global instance (constructing it loads the similarity model at import time)
custom_evaluator = CustomEvaluator()