# DataEngEval/src/custom_evaluator.py
"""
Custom SQL evaluation metrics without RAGAS dependency.
Provides comprehensive evaluation using only local models and basic metrics.
"""
import re
import time
from dataclasses import dataclass
from typing import List

import numpy as np
import pandas as pd
from transformers import pipeline

from langchain_models import langchain_models_registry
@dataclass
class EvaluationResult:
"""Result of SQL evaluation."""
model_name: str
dataset: str
case_id: str
dialect: str
question: str
raw_sql: str # Raw SQL from model (before cleaning)
generated_sql: str # Cleaned SQL (after cleaning)
reference_sql: str
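    # Metric fields below are scores in [0, 1] (cosine-based semantic
    # similarity can in principle dip below 0); latency_ms is milliseconds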
correctness_exact: float
result_match_f1: float
exec_success: float
latency_ms: float
readability: float
dialect_ok: float
# Custom metrics without RAGAS
sql_quality: float
semantic_similarity: float
structural_similarity: float
composite_score: float
timestamp: str
class CustomEvaluator:
"""Custom evaluator for SQL generation without RAGAS dependency."""
def __init__(self):
self.similarity_model = None
self._setup_similarity_model()
def _setup_similarity_model(self):
"""Setup a local model for semantic similarity."""
try:
print("📥 Setting up local similarity model...")
self.similarity_model = pipeline(
"feature-extraction",
model="sentence-transformers/all-MiniLM-L6-v2",
device=-1 # Use CPU
)
print("✅ Local similarity model configured")
except Exception as e:
print(f"⚠️ Could not setup similarity model: {e}")
self.similarity_model = None
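    # Note: the feature-extraction pipeline emits one embedding per token;
    # _calculate_semantic_similarity mean-pools these into a single vector
    # per query before taking cosine similarity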
def evaluate_sql(
self,
model_name: str,
dataset: str,
case_id: str,
dialect: str,
question: str,
raw_sql: str,
generated_sql: str,
reference_sql: str,
schema: str,
db_conn
) -> EvaluationResult:
"""Evaluate generated SQL against reference."""
start_time = time.time()
# Basic metrics
correctness_exact = self._calculate_exact_correctness(generated_sql, reference_sql)
result_match_f1 = self._calculate_result_match_f1(generated_sql, reference_sql, db_conn)
exec_success = self._calculate_execution_success(generated_sql, db_conn)
readability = self._calculate_readability(generated_sql)
dialect_ok = self._calculate_dialect_compliance(generated_sql, dialect)
# Custom metrics
sql_quality = self._calculate_sql_quality(generated_sql, question, schema)
semantic_similarity = self._calculate_semantic_similarity(generated_sql, reference_sql)
structural_similarity = self._calculate_structural_similarity(generated_sql, reference_sql)
        # Wall-clock time spent computing these metrics (not model generation latency)
        latency_ms = (time.time() - start_time) * 1000
# Calculate composite score
composite_score = (
correctness_exact * 0.3 +
result_match_f1 * 0.3 +
exec_success * 0.2 +
sql_quality * 0.1 +
semantic_similarity * 0.1
)
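        # Worked example (weights sum to 1.0): exact=1.0, f1=1.0, exec=1.0,
        # quality=0.6, semantic=0.9 gives 0.3 + 0.3 + 0.2 + 0.06 + 0.09 = 0.95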
return EvaluationResult(
model_name=model_name,
dataset=dataset,
case_id=case_id,
dialect=dialect,
question=question,
raw_sql=raw_sql,
generated_sql=generated_sql,
reference_sql=reference_sql,
correctness_exact=correctness_exact,
result_match_f1=result_match_f1,
exec_success=exec_success,
latency_ms=latency_ms,
readability=readability,
dialect_ok=dialect_ok,
sql_quality=sql_quality,
semantic_similarity=semantic_similarity,
structural_similarity=structural_similarity,
composite_score=composite_score,
timestamp=pd.Timestamp.now().isoformat()
)
def _calculate_exact_correctness(self, generated_sql: str, reference_sql: str) -> float:
"""Calculate exact string match correctness."""
# Normalize SQL for comparison
gen_norm = self._normalize_sql(generated_sql)
ref_norm = self._normalize_sql(reference_sql)
return 1.0 if gen_norm == ref_norm else 0.0
def _calculate_result_match_f1(self, generated_sql: str, reference_sql: str, db_conn) -> float:
"""Calculate F1 score based on query results."""
try:
# Clean the generated SQL before execution
clean_generated_sql = langchain_models_registry.clean_sql(generated_sql)
# Execute both queries
gen_result = db_conn.execute(clean_generated_sql).fetchall()
ref_result = db_conn.execute(reference_sql).fetchall()
            # Compare result sets as stringified rows; note this ignores
            # duplicate rows and row order
            gen_set = set(str(row) for row in gen_result)
            ref_set = set(str(row) for row in ref_result)
if not ref_set:
return 1.0 if not gen_set else 0.0
# Calculate F1
intersection = gen_set & ref_set
precision = len(intersection) / len(gen_set) if gen_set else 0.0
recall = len(intersection) / len(ref_set)
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
return f1
except Exception as e:
print(f"⚠️ Error calculating result match F1: {e}")
return 0.0
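    # Worked example: 3 generated rows and 4 reference rows sharing 2 rows
    # give precision 2/3, recall 1/2, and F1 = 2*(2/3)*(1/2)/(2/3 + 1/2) = 4/7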
def _calculate_execution_success(self, generated_sql: str, db_conn) -> float:
"""Calculate if SQL executes successfully."""
try:
# Clean the generated SQL before execution
clean_generated_sql = langchain_models_registry.clean_sql(generated_sql)
db_conn.execute(clean_generated_sql)
return 1.0
except Exception as e:
print(f"⚠️ SQL execution error: {e}")
return 0.0
def _calculate_readability(self, generated_sql: str) -> float:
"""Calculate SQL readability score."""
try:
# Basic readability metrics
lines = generated_sql.strip().split('\n')
avg_line_length = sum(len(line.strip()) for line in lines) / len(lines) if lines else 0
            # Check for proper formatting
            has_proper_indentation = any(line.startswith((' ', '\t')) for line in lines[1:])
            # Check the original (non-uppercased) SQL so this actually tests
            # capitalization rather than mere keyword presence
            has_keywords_capitalized = any(keyword in generated_sql for keyword in ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY'])
# Score based on formatting
score = 0.0
if has_keywords_capitalized:
score += 0.4
if has_proper_indentation:
score += 0.3
if 20 <= avg_line_length <= 80: # Reasonable line length
score += 0.3
return min(score, 1.0)
except Exception:
return 0.0
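    # Scoring example: capitalized keywords (+0.4), indented continuation
    # lines (+0.3), and a 20-80 char average line length (+0.3) total 1.0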
def _calculate_dialect_compliance(self, generated_sql: str, dialect: str) -> float:
"""Calculate dialect compliance score."""
try:
sql_upper = generated_sql.upper()
score = 0.0
# Basic SQL compliance
if any(keyword in sql_upper for keyword in ['SELECT', 'FROM']):
score += 0.3
# Dialect-specific checks
if dialect.lower() == 'presto':
# Presto-specific features
if 'ARRAY' in sql_upper or 'MAP' in sql_upper:
score += 0.2
if 'APPROX_DISTINCT' in sql_upper:
score += 0.2
elif dialect.lower() == 'bigquery':
# BigQuery-specific features
if 'ARRAY_AGG' in sql_upper or 'STRUCT' in sql_upper:
score += 0.2
if 'QUALIFY' in sql_upper:
score += 0.2
elif dialect.lower() == 'snowflake':
# Snowflake-specific features
if 'QUALIFY' in sql_upper:
score += 0.2
if 'ARRAY_CONSTRUCT' in sql_upper:
score += 0.2
# General SQL quality
if 'WHERE' in sql_upper or 'GROUP BY' in sql_upper or 'ORDER BY' in sql_upper:
score += 0.3
return min(score, 1.0)
except Exception:
return 0.0
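    # Scoring example: a Snowflake query with SELECT/FROM (+0.3), QUALIFY
    # (+0.2), and a WHERE clause (+0.3) scores 0.8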
def _calculate_sql_quality(self, generated_sql: str, question: str, schema: str) -> float:
"""Calculate overall SQL quality score."""
try:
score = 0.0
# Check if SQL addresses the question
question_lower = question.lower()
sql_lower = generated_sql.lower()
# Question-SQL alignment
if 'count' in question_lower and 'count(' in sql_lower:
score += 0.2
if 'average' in question_lower and 'avg(' in sql_lower:
score += 0.2
if 'sum' in question_lower and 'sum(' in sql_lower:
score += 0.2
if 'group' in question_lower and 'group by' in sql_lower:
score += 0.2
# Schema usage
schema_tables = re.findall(r'CREATE TABLE (\w+)', schema, re.IGNORECASE)
            # Match against sql_lower, so the pattern must be lowercase too
            used_tables = re.findall(r'from (\w+)', sql_lower)
if any(table.lower() in used_tables for table in schema_tables):
score += 0.2
return min(score, 1.0)
except Exception:
return 0.0
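    # Scoring example: for "count orders in each group", COUNT( (+0.2),
    # GROUP BY (+0.2), and a FROM over a schema table (+0.2) yield 0.6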
def _calculate_semantic_similarity(self, generated_sql: str, reference_sql: str) -> float:
"""Calculate semantic similarity between SQL queries."""
try:
if not self.similarity_model:
# Fallback to basic similarity
return self._basic_similarity(generated_sql, reference_sql)
            # The feature-extraction pipeline returns token-level embeddings
            # (nested lists or arrays), one entry per input query
            embeddings = self.similarity_model([generated_sql, reference_sql])
            if len(embeddings) != 2:
                return self._basic_similarity(generated_sql, reference_sql)
            gen_emb = np.asarray(embeddings[0], dtype=float)
            ref_emb = np.asarray(embeddings[1], dtype=float)
            # Mean-pool token embeddings down to one fixed-size vector per
            # query; pooling must happen before the shape check, because
            # queries of different lengths yield different token counts
            while gen_emb.ndim > 1:
                gen_emb = gen_emb.mean(axis=0)
            while ref_emb.ndim > 1:
                ref_emb = ref_emb.mean(axis=0)
            if gen_emb.shape != ref_emb.shape:
                return self._basic_similarity(generated_sql, reference_sql)
            # Cosine similarity of the pooled vectors
            denom = np.linalg.norm(gen_emb) * np.linalg.norm(ref_emb)
            if denom == 0:
                return 0.0
            return float(np.dot(gen_emb, ref_emb) / denom)
except Exception as e:
print(f"⚠️ Error calculating semantic similarity: {e}")
return self._basic_similarity(generated_sql, reference_sql)
def _calculate_structural_similarity(self, generated_sql: str, reference_sql: str) -> float:
"""Calculate structural similarity between SQL queries."""
try:
# Extract SQL structure
gen_structure = self._extract_sql_structure(generated_sql)
ref_structure = self._extract_sql_structure(reference_sql)
# Calculate Jaccard similarity
gen_set = set(gen_structure)
ref_set = set(ref_structure)
if not gen_set and not ref_set:
return 1.0
if not gen_set or not ref_set:
return 0.0
intersection = gen_set & ref_set
union = gen_set | ref_set
return len(intersection) / len(union)
except Exception:
return 0.0
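    # Jaccard example: structures {SELECT, FROM, WHERE, COUNT} and
    # {SELECT, FROM, GROUP BY, COUNT} share 3 of 5 unique elements -> 0.6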
def _basic_similarity(self, sql1: str, sql2: str) -> float:
"""Basic similarity calculation as fallback."""
try:
# Extract keywords
keywords1 = set(re.findall(r'\b(SELECT|FROM|WHERE|GROUP BY|ORDER BY|HAVING|JOIN|UNION)\b', sql1.upper()))
keywords2 = set(re.findall(r'\b(SELECT|FROM|WHERE|GROUP BY|ORDER BY|HAVING|JOIN|UNION)\b', sql2.upper()))
if not keywords1 and not keywords2:
return 1.0
if not keywords1 or not keywords2:
return 0.0
intersection = keywords1 & keywords2
union = keywords1 | keywords2
return len(intersection) / len(union)
except Exception:
return 0.0
def _extract_sql_structure(self, sql: str) -> List[str]:
"""Extract SQL structure elements."""
try:
structure = []
sql_upper = sql.upper()
# Extract main clauses
clauses = ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'HAVING', 'LIMIT']
for clause in clauses:
if clause in sql_upper:
structure.append(clause)
            # Extract aggregate functions and the DISTINCT keyword
            functions = re.findall(r'\b(COUNT|SUM|AVG|MIN|MAX|DISTINCT)\b', sql_upper)
            structure.extend(functions)
# Extract operators
operators = re.findall(r'\b(AND|OR|IN|NOT IN|BETWEEN|LIKE)\b', sql_upper)
structure.extend(operators)
return structure
except Exception:
return []
def _normalize_sql(self, sql: str) -> str:
"""Normalize SQL for comparison."""
try:
# Remove extra whitespace
normalized = re.sub(r'\s+', ' ', sql.strip())
# Convert to uppercase
normalized = normalized.upper()
# Remove semicolons
normalized = normalized.rstrip(';')
return normalized
except Exception:
return sql
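    # Example: _normalize_sql("select id\n  from t;") -> "SELECT ID FROM T"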
# Global instance
custom_evaluator = CustomEvaluator()
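# A minimal smoke-test sketch (not part of the original module): it assumes a
# DB-API-style connection whose execute() returns a cursor exposing fetchall()
# (sqlite3 satisfies this), and that langchain_models' clean_sql() passes a
# plain SELECT through unchanged. Note that importing this module already
# downloads the MiniLM similarity model on first run.
if __name__ == "__main__":
    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE users (id INTEGER, name TEXT)")
    conn.execute("INSERT INTO users VALUES (1, 'alice'), (2, 'bob')")

    result = custom_evaluator.evaluate_sql(
        model_name="demo-model",
        dataset="demo",
        case_id="case-1",
        dialect="presto",
        question="How many users are there?",
        raw_sql="SELECT COUNT(*) FROM users;",
        generated_sql="SELECT COUNT(*) FROM users;",
        reference_sql="SELECT COUNT(*) FROM users",
        schema="CREATE TABLE users (id INTEGER, name TEXT)",
        db_conn=conn,
    )
    print(f"composite={result.composite_score:.3f}, f1={result.result_match_f1:.3f}")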