DataEngEval / src /demo.py
uparekh01151's picture
Initial commit for DataEngEval
acd8e16
#!/usr/bin/env python3
"""
Demo script for the NL→SQL Leaderboard
Shows how the system works without requiring API keys.
"""
import os
import time
from evaluator import evaluator, DatasetManager
from models_registry import models_registry
from scoring import scoring_engine
def demo_dataset_loading():
    """Demonstrate dataset loading.

    Lists the datasets reported by ``DatasetManager.get_datasets()`` and, if
    ``nyc_taxi_small`` is among them, loads its cases and prints a preview of
    the first three (id, question, difficulty and Presto reference SQL).
    If the dataset is not available, says so instead of ending silently.
    """
    print("📊 Dataset Loading Demo")
    print("-" * 30)
    dataset_manager = DatasetManager()
    datasets = dataset_manager.get_datasets()
    print(f"Available datasets: {list(datasets.keys())}")

    dataset_name = "nyc_taxi_small"
    if dataset_name not in datasets:
        # Explain why no case preview follows rather than printing nothing.
        print(f"\nDataset '{dataset_name}' not found; skipping case preview.")
        return

    # Load NYC Taxi dataset
    print("\nLoading NYC Taxi dataset...")
    cases = dataset_manager.load_cases(dataset_name)
    print(f"Found {len(cases)} test cases:")
    for i, case in enumerate(cases[:3], 1):  # Show first 3 cases
        print(f" {i}. {case.id}: {case.question}")
        print(f" Difficulty: {case.difficulty}")
        print(f" Reference SQL (Presto): {case.reference_sql.get('presto', 'N/A')}")
        print()
def demo_models_loading():
    """Print how many models the registry holds, then each model's details.

    For every entry returned by ``models_registry.get_models()`` this shows
    its name and provider, its model id and its description, followed by a
    blank separator line.
    """
    print("🤖 Models Loading Demo")
    print("-" * 30)
    registered = models_registry.get_models()
    print(f"Available models: {len(registered)}")
    for entry in registered:
        print(f" - {entry.name} ({entry.provider})")
        print(f" Model ID: {entry.model_id}")
        print(f" Description: {entry.description}")
        # Blank line between models for readability.
        print()
def demo_database_creation():
    """Demonstrate database creation.

    Asks ``DatasetManager`` to build the ``nyc_taxi_small`` database, opens it
    with DuckDB, and prints:

    * its table names;
    * row counts for ``trips`` and ``zones``;
    * the result of a sample query.

    The database file is then deleted. The connection is always closed and the
    file always removed, even if one of the queries fails; the query error
    still propagates to the caller.
    """
    print("🗄️ Database Creation Demo")
    print("-" * 30)
    dataset_manager = DatasetManager()
    print("Creating NYC Taxi database...")
    db_path = dataset_manager.create_database("nyc_taxi_small")
    # Check for a falsy path first: os.path.exists(None) raises TypeError
    # instead of returning False.
    if not db_path or not os.path.exists(db_path):
        print("✗ Database creation failed")
        return

    print(f"✓ Database created: {db_path}")
    try:
        import duckdb  # local import: only this demo needs DuckDB

        conn = duckdb.connect(db_path)
        try:
            # Show table info
            tables = conn.execute("SHOW TABLES").fetchall()
            print(f"Tables: {[table[0] for table in tables]}")
            # Show sample data
            trips_count = conn.execute("SELECT COUNT(*) FROM trips").fetchone()[0]
            zones_count = conn.execute("SELECT COUNT(*) FROM zones").fetchone()[0]
            print(f"Sample data: {trips_count} trips, {zones_count} zones")
            # Show a sample query result (fetchdf returns a pandas DataFrame)
            result = conn.execute("SELECT COUNT(*) as total_trips FROM trips").fetchdf()
            print(f"Sample query result: {result.iloc[0, 0]} total trips")
        finally:
            # Close before deleting: the file may be locked while open (e.g. on Windows).
            conn.close()
    finally:
        # Clean up the demo database even when a query above failed.
        os.remove(db_path)
        print("✓ Database cleaned up")
def demo_sql_transpilation():
    """Parse one sample aggregate query with sqlglot and print it per dialect.

    The query is parsed once (sqlglot's default dialect) and the resulting
    expression is rendered for Presto, BigQuery and Snowflake in turn.
    """
    print("🔄 SQL Transpilation Demo")
    print("-" * 30)
    import sqlglot

    # Sample SQL query, assembled line by line; leading and trailing newlines
    # are kept so the text matches a triple-quoted literal.
    query_lines = (
        "SELECT",
        "passenger_count,",
        "COUNT(*) as trip_count,",
        "AVG(fare_amount) as avg_fare",
        "FROM trips",
        "WHERE total_amount > 20.0",
        "GROUP BY passenger_count",
        "ORDER BY trip_count DESC",
    )
    sample_sql = "\n" + "\n".join(query_lines) + "\n"
    print(f"Original SQL:\n{sample_sql.strip()}")

    # Parse once, then render the same expression tree for each target.
    expression = sqlglot.parse_one(sample_sql)
    for target in ("presto", "bigquery", "snowflake"):
        rendered = expression.sql(dialect=target)
        print(f"\n{target.upper()} SQL:")
        print(rendered)
def demo_scoring():
    """Score three hand-built metric sets and print the results.

    The sets simulate a perfect, a good and a poor evaluation. For each one,
    the composite score from ``scoring_engine`` is printed, followed by every
    breakdown entry except ``composite_score``.
    """
    print("📈 Scoring System Demo")
    print("-" * 30)
    from scoring import Metrics

    # (label, metrics) pairs spanning high, middle and low quality results.
    scenarios = [
        (
            "Perfect Result",
            Metrics(
                correctness_exact=1.0,
                result_match_f1=1.0,
                exec_success=1.0,
                latency_ms=100.0,
                readability=0.9,
                dialect_ok=1.0,
            ),
        ),
        (
            "Good Result",
            Metrics(
                correctness_exact=0.0,
                result_match_f1=0.8,
                exec_success=1.0,
                latency_ms=500.0,
                readability=0.7,
                dialect_ok=1.0,
            ),
        ),
        (
            "Poor Result",
            Metrics(
                correctness_exact=0.0,
                result_match_f1=0.2,
                exec_success=0.0,
                latency_ms=2000.0,
                readability=0.3,
                dialect_ok=0.0,
            ),
        ),
    ]

    for label, metrics in scenarios:
        composite = scoring_engine.compute_composite_score(metrics)
        parts = scoring_engine.get_score_breakdown(metrics)
        print(f"\n{label}:")
        print(f" Composite Score: {composite:.4f}")
        print(" Breakdown:")
        for metric_name, metric_value in parts.items():
            # The composite is already shown above; skip it here.
            if metric_name == "composite_score":
                continue
            print(f" {metric_name}: {metric_value:.4f}")
def demo_prompt_templates():
    """Demonstrate prompt templates.

    Reads the schema from ``tasks/nyc_taxi_small/schema.sql``. For each of the
    presto, bigquery and snowflake dialects, it fills
    ``prompts/template_<dialect>.txt`` with that schema and a sample question
    via ``str.format`` and prints the prompt. Prompts longer than 200
    characters are truncated and followed by ``...``.

    Dialects whose template file does not exist are reported and skipped.
    All paths are relative to the current working directory, and all files are
    read as UTF-8.

    Raises:
        FileNotFoundError: if the schema file does not exist.
    """
    print("📝 Prompt Templates Demo")
    print("-" * 30)
    # Load a sample schema. Read explicitly as UTF-8 so the result does not
    # depend on the platform's locale encoding.
    with open("tasks/nyc_taxi_small/schema.sql", "r", encoding="utf-8") as f:
        schema = f.read()
    question = "How many total trips are there in the dataset?"
    preview_len = 200
    # Show how templates work
    for dialect in ("presto", "bigquery", "snowflake"):
        template_path = f"prompts/template_{dialect}.txt"
        if not os.path.exists(template_path):
            # Report missing templates instead of skipping them silently.
            print(f"\n⚠️ {dialect.upper()} template not found: {template_path}")
            continue
        with open(template_path, "r", encoding="utf-8") as f:
            template = f.read()
        # str.format: any literal braces in a template must be written as {{ }}.
        prompt = template.format(schema=schema, question=question)
        print(f"\n{dialect.upper()} Prompt Template:")
        print("-" * 20)
        print(prompt[:preview_len] + "..." if len(prompt) > preview_len else prompt)
def main(demos=None):
    """Run all demos and report which, if any, failed.

    Args:
        demos: Optional iterable of zero-argument callables to run instead of
            the built-in demo list, e.g. to run only a subset. Defaults to
            every demo in this script.

    Each demo is isolated. If it raises, the failure is printed with the
    demo's name and exception type, and the run continues with the next demo.
    The closing message states whether all demos succeeded.
    """
    print("🎯 NL→SQL Leaderboard Demo")
    print("=" * 50)
    print("This demo shows how the system works without requiring API keys.")
    print("=" * 50)
    if demos is None:
        demos = [
            demo_dataset_loading,
            demo_models_loading,
            demo_database_creation,
            demo_sql_transpilation,
            demo_scoring,
            demo_prompt_templates,
        ]

    failures = []
    for demo in demos:
        name = getattr(demo, "__name__", repr(demo))
        try:
            demo()
        except Exception as e:  # top-level boundary: report and keep going
            failures.append(name)
            # Include the demo name and exception type; e.g. a bare KeyError
            # would otherwise print only "'trips'".
            print(f"❌ Demo {name} failed: {type(e).__name__}: {e}")
            print("=" * 50)
        else:
            print("\n" + "=" * 50)

    if failures:
        print(f"\n⚠️ Demo completed with {len(failures)} failure(s): {', '.join(failures)}")
    else:
        print("\n🎉 Demo completed!")
    print("\nTo run the full application:")
    print(" python launch.py")
    print("\nTo test the system:")
    print(" python test_system.py")
if __name__ == "__main__":
    # Entry point: run the demos only when executed as a script, not on import.
    main()