Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
Demo script for the NL→SQL Leaderboard
Shows how the system works without requiring API keys.
"""
| import os | |
| import time | |
| from evaluator import evaluator, DatasetManager | |
| from models_registry import models_registry | |
| from scoring import scoring_engine | |
def demo_dataset_loading():
    """Print the available datasets and preview the NYC Taxi test cases.

    Lists the dataset names returned by ``DatasetManager.get_datasets()``.
    When ``"nyc_taxi_small"`` is among them, loads its cases and prints the
    id, question, difficulty and Presto reference SQL of the first three.
    """
    print("📊 Dataset Loading Demo")
    print("-" * 30)

    manager = DatasetManager()
    available = manager.get_datasets()
    print(f"Available datasets: {list(available.keys())}")

    # Nothing further to show unless the NYC Taxi dataset is listed.
    if "nyc_taxi_small" not in available:
        return

    print("\nLoading NYC Taxi dataset...")
    taxi_cases = manager.load_cases("nyc_taxi_small")
    print(f"Found {len(taxi_cases)} test cases:")

    # Preview only the first three cases, numbered from 1.
    preview = taxi_cases[:3]
    for number, entry in enumerate(preview, start=1):
        print(f" {number}. {entry.id}: {entry.question}")
        print(f" Difficulty: {entry.difficulty}")
        presto_sql = entry.reference_sql.get('presto', 'N/A')
        print(f" Reference SQL (Presto): {presto_sql}")
        print()
def demo_models_loading():
    """Print every model returned by the registry with its provider, id and description."""
    print("🤖 Models Loading Demo")
    print("-" * 30)

    registered = models_registry.get_models()
    print(f"Available models: {len(registered)}")

    # (label, attribute) pairs for the detail lines printed under each model.
    detail_fields = (
        ("Model ID", "model_id"),
        ("Description", "description"),
    )
    for entry in registered:
        print(f" - {entry.name} ({entry.provider})")
        for label, attribute in detail_fields:
            print(f" {label}: {getattr(entry, attribute)}")
        print()
def demo_database_creation():
    """Create the NYC Taxi database, inspect it with DuckDB, then delete it.

    Calls ``DatasetManager.create_database("nyc_taxi_small")``. If the
    returned path exists on disk, the database is opened with DuckDB and the
    table names plus the row counts of ``trips`` and ``zones`` are printed.
    Otherwise a failure message is printed and the function returns.

    The connection is closed and the database file is removed even when one
    of the queries raises, so a failing demo does not leave an open handle
    or a stray file behind. The exception itself still propagates to the
    caller.
    """
    print("🗄️ Database Creation Demo")
    print("-" * 30)

    dataset_manager = DatasetManager()
    print("Creating NYC Taxi database...")
    db_path = dataset_manager.create_database("nyc_taxi_small")

    if not os.path.exists(db_path):
        print("✗ Database creation failed")
        return

    print(f"✓ Database created: {db_path}")

    try:
        # duckdb is only used by this demo, so it is imported locally.
        import duckdb

        conn = duckdb.connect(db_path)
        try:
            # Show table info
            tables = conn.execute("SHOW TABLES").fetchall()
            print(f"Tables: {[table[0] for table in tables]}")

            # Show sample data
            trips_count = conn.execute("SELECT COUNT(*) FROM trips").fetchone()[0]
            zones_count = conn.execute("SELECT COUNT(*) FROM zones").fetchone()[0]
            print(f"Sample data: {trips_count} trips, {zones_count} zones")

            # Show a sample query result
            result = conn.execute("SELECT COUNT(*) as total_trips FROM trips").fetchdf()
            print(f"Sample query result: {result.iloc[0, 0]} total trips")
        finally:
            # Release the file handle before the file is deleted below.
            conn.close()
    finally:
        # Clean up the temporary database whether or not the queries succeeded.
        os.remove(db_path)
        print("✓ Database cleaned up")
def demo_sql_transpilation():
    """Parse one sample query with sqlglot and print it in three SQL dialects."""
    print("🔄 SQL Transpilation Demo")
    print("-" * 30)

    import sqlglot

    # Sample query, assembled line by line.
    # NOTE(review): whitespace inside this literal follows the text as it
    # appears in the source; confirm against the intended layout.
    sample_sql = (
        "\n"
        "SELECT\n"
        "passenger_count,\n"
        "COUNT(*) as trip_count,\n"
        "AVG(fare_amount) as avg_fare\n"
        "FROM trips\n"
        "WHERE total_amount > 20.0\n"
        "GROUP BY passenger_count\n"
        "ORDER BY trip_count DESC\n"
    )
    print(f"Original SQL:\n{sample_sql.strip()}")

    # Parse once, then render the same expression tree for each target dialect.
    tree = sqlglot.parse_one(sample_sql)
    for target in ["presto", "bigquery", "snowflake"]:
        rendered = tree.sql(dialect=target)
        print(f"\n{target.upper()} SQL:")
        print(rendered)
def demo_scoring():
    """Score three hand-built metric sets and print each composite score and breakdown."""
    print("📈 Scoring System Demo")
    print("-" * 30)

    from scoring import Metrics

    # (label, metrics) pairs simulating a perfect, a good and a poor evaluation.
    scenarios = (
        (
            "Perfect Result",
            Metrics(
                correctness_exact=1.0,
                result_match_f1=1.0,
                exec_success=1.0,
                latency_ms=100.0,
                readability=0.9,
                dialect_ok=1.0,
            ),
        ),
        (
            "Good Result",
            Metrics(
                correctness_exact=0.0,
                result_match_f1=0.8,
                exec_success=1.0,
                latency_ms=500.0,
                readability=0.7,
                dialect_ok=1.0,
            ),
        ),
        (
            "Poor Result",
            Metrics(
                correctness_exact=0.0,
                result_match_f1=0.2,
                exec_success=0.0,
                latency_ms=2000.0,
                readability=0.3,
                dialect_ok=0.0,
            ),
        ),
    )

    for label, metrics in scenarios:
        composite = scoring_engine.compute_composite_score(metrics)
        parts = scoring_engine.get_score_breakdown(metrics)

        print(f"\n{label}:")
        print(f" Composite Score: {composite:.4f}")
        print(" Breakdown:")
        for key, amount in parts.items():
            # The composite score is already printed above; skip it here.
            if key == "composite_score":
                continue
            print(f" {key}: {amount:.4f}")
def demo_prompt_templates():
    """Fill each dialect's prompt template with a sample schema and question.

    Reads ``tasks/nyc_taxi_small/schema.sql`` (relative to the current
    working directory), then for each of the presto, bigquery and snowflake
    dialects looks for ``prompts/template_<dialect>.txt``. Templates that
    exist are formatted with ``schema`` and ``question`` and the first 200
    characters of the resulting prompt are printed; missing templates are
    skipped silently.

    Both files are read as UTF-8 explicitly so that non-ASCII text decodes
    the same way regardless of the platform's default encoding.

    Raises:
        FileNotFoundError: if the schema file does not exist.
    """
    print("📝 Prompt Templates Demo")
    print("-" * 30)

    # Load a sample schema
    with open("tasks/nyc_taxi_small/schema.sql", "r", encoding="utf-8") as f:
        schema = f.read()

    question = "How many total trips are there in the dataset?"

    # Show how templates work
    dialects = ["presto", "bigquery", "snowflake"]
    for dialect in dialects:
        template_path = f"prompts/template_{dialect}.txt"
        if not os.path.exists(template_path):
            continue

        with open(template_path, "r", encoding="utf-8") as f:
            template = f.read()

        prompt = template.format(schema=schema, question=question)
        print(f"\n{dialect.upper()} Prompt Template:")
        print("-" * 20)
        # Truncate long prompts to 200 characters followed by an ellipsis.
        print((prompt[:200] + "...") if len(prompt) > 200 else prompt)
def main():
    """Run every demo in order, reporting a failed demo without stopping the rest."""
    rule = "=" * 50

    print("🎯 NL→SQL Leaderboard Demo")
    print(rule)
    print("This demo shows how the system works without requiring API keys.")
    print(rule)

    steps = (
        demo_dataset_loading,
        demo_models_loading,
        demo_database_creation,
        demo_sql_transpilation,
        demo_scoring,
        demo_prompt_templates,
    )

    for step in steps:
        try:
            step()
            print("\n" + rule)
        except Exception as exc:
            # Top-level boundary: report the failure and move on to the next demo.
            print(f"❌ Demo failed: {exc}")
            print(rule)

    # Closing hints, printed one line at a time.
    closing_lines = (
        "\n🎉 Demo completed!",
        "\nTo run the full application:",
        " python launch.py",
        "\nTo test the system:",
        " python test_system.py",
    )
    for line in closing_lines:
        print(line)
# Run the demos only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()