File size: 6,870 Bytes
acd8e16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
#!/usr/bin/env python3
"""
Demo script for the NL→SQL Leaderboard
Shows how the system works without requiring API keys.
"""

import os
import time
from evaluator import evaluator, DatasetManager
from models_registry import models_registry
from scoring import scoring_engine


def demo_dataset_loading():
    """Print the available datasets and preview the NYC Taxi test cases.

    Lists the keys returned by ``DatasetManager.get_datasets()``. When the
    ``nyc_taxi_small`` dataset is present, loads its cases and prints the id,
    question, difficulty and Presto reference SQL of the first three cases;
    otherwise prints a note that the preview was skipped.
    """
    print("📊 Dataset Loading Demo")
    print("-" * 30)

    dataset_manager = DatasetManager()
    datasets = dataset_manager.get_datasets()

    print(f"Available datasets: {list(datasets.keys())}")

    if "nyc_taxi_small" not in datasets:
        print("Dataset 'nyc_taxi_small' not found; skipping case preview.")
        return

    print("\nLoading NYC Taxi dataset...")
    cases = dataset_manager.load_cases("nyc_taxi_small")
    print(f"Found {len(cases)} test cases:")

    # Preview only the first three cases to keep the output short.
    for i, case in enumerate(cases[:3], 1):
        print(f"  {i}. {case.id}: {case.question}")
        print(f"     Difficulty: {case.difficulty}")
        # reference_sql appears to be keyed by dialect name; 'N/A' is shown
        # when there is no Presto entry.
        print(f"     Reference SQL (Presto): {case.reference_sql.get('presto', 'N/A')}")
        print()


def demo_models_loading():
    """Print every model returned by ``models_registry.get_models()``.

    Shows the total count, then each model's name, provider, model ID and
    description.
    """
    print("🤖 Models Loading Demo")
    print("-" * 30)

    models = models_registry.get_models()
    print(f"Available models: {len(models)}")

    for model in models:
        print(f"  - {model.name} ({model.provider})")
        print(f"    Model ID: {model.model_id}")
        print(f"    Description: {model.description}")
        print()


def demo_database_creation():
    """Create the NYC Taxi DuckDB database, inspect it, then delete it.

    Builds the database via ``DatasetManager.create_database("nyc_taxi_small")``,
    prints its table names, the row counts of ``trips`` and ``zones`` and the
    result of a sample query. The DuckDB connection is always closed and the
    database file is always removed, even if one of the queries raises.
    """
    print("🗄️ Database Creation Demo")
    print("-" * 30)

    dataset_manager = DatasetManager()

    print("Creating NYC Taxi database...")
    db_path = dataset_manager.create_database("nyc_taxi_small")

    # Check for a falsy path first: os.path.exists(None) raises TypeError.
    if not db_path or not os.path.exists(db_path):
        print("✗ Database creation failed")
        return

    print(f"✓ Database created: {db_path}")

    try:
        # Local import; it sits inside the try so the file is still cleaned
        # up by the ``finally`` below if duckdb cannot be imported.
        import duckdb

        conn = duckdb.connect(db_path)
        try:
            # Show table info
            tables = conn.execute("SHOW TABLES").fetchall()
            print(f"Tables: {[table[0] for table in tables]}")

            # Show sample data
            trips_count = conn.execute("SELECT COUNT(*) FROM trips").fetchone()[0]
            zones_count = conn.execute("SELECT COUNT(*) FROM zones").fetchone()[0]
            print(f"Sample data: {trips_count} trips, {zones_count} zones")

            # Show a sample query result (fetchdf() returns a pandas DataFrame).
            result = conn.execute("SELECT COUNT(*) as total_trips FROM trips").fetchdf()
            print(f"Sample query result: {result.iloc[0, 0]} total trips")
        finally:
            # Close the connection before removing its database file.
            conn.close()
    finally:
        # Clean up the temporary database regardless of query success.
        os.remove(db_path)
        print("✓ Database cleaned up")


def demo_sql_transpilation():
    """Transpile a sample aggregate query into several SQL dialects.

    The query is parsed once with ``sqlglot.parse_one`` (no ``read`` dialect,
    so sqlglot's default dialect is used) and the resulting expression tree is
    rendered for Presto, BigQuery and Snowflake.
    """
    print("🔄 SQL Transpilation Demo")
    print("-" * 30)

    import sqlglot

    # Sample SQL query
    sample_sql = """
    SELECT 
        passenger_count,
        COUNT(*) as trip_count,
        AVG(fare_amount) as avg_fare
    FROM trips 
    WHERE total_amount > 20.0
    GROUP BY passenger_count
    ORDER BY trip_count DESC
    """

    print(f"Original SQL:\n{sample_sql.strip()}")

    # Parse once; the same expression tree is rendered for every target dialect.
    parsed = sqlglot.parse_one(sample_sql)

    dialects = ["presto", "bigquery", "snowflake"]
    for dialect in dialects:
        transpiled = parsed.sql(dialect=dialect)
        print(f"\n{dialect.upper()} SQL:")
        print(transpiled)


def demo_scoring():
    """Score three hand-built ``Metrics`` samples and print the results.

    Each sample ("Perfect", "Good", "Poor") is passed to
    ``scoring_engine.compute_composite_score`` and
    ``scoring_engine.get_score_breakdown``. The composite score and every
    breakdown entry except ``composite_score`` are printed to 4 decimals.
    """
    print("📈 Scoring System Demo")
    print("-" * 30)

    from scoring import Metrics

    # Simulate different evaluation results
    test_cases = [
        {
            "name": "Perfect Result",
            "metrics": Metrics(
                correctness_exact=1.0,
                result_match_f1=1.0,
                exec_success=1.0,
                latency_ms=100.0,
                readability=0.9,
                dialect_ok=1.0
            )
        },
        {
            "name": "Good Result",
            "metrics": Metrics(
                correctness_exact=0.0,
                result_match_f1=0.8,
                exec_success=1.0,
                latency_ms=500.0,
                readability=0.7,
                dialect_ok=1.0
            )
        },
        {
            "name": "Poor Result",
            "metrics": Metrics(
                correctness_exact=0.0,
                result_match_f1=0.2,
                exec_success=0.0,
                latency_ms=2000.0,
                readability=0.3,
                dialect_ok=0.0
            )
        }
    ]

    for case in test_cases:
        score = scoring_engine.compute_composite_score(case["metrics"])
        breakdown = scoring_engine.get_score_breakdown(case["metrics"])

        print(f"\n{case['name']}:")
        print(f"  Composite Score: {score:.4f}")
        print("  Breakdown:")
        # Assumes every breakdown value is numeric (formatted with .4f).
        for metric, value in breakdown.items():
            # The composite score was already printed above.
            if metric != "composite_score":
                print(f"    {metric}: {value:.4f}")


def demo_prompt_templates():
    """Render each dialect's prompt template for a sample question and preview it.

    Reads the NYC Taxi schema from ``tasks/nyc_taxi_small/schema.sql``. For
    each of Presto, BigQuery and Snowflake, it fills
    ``prompts/template_<dialect>.txt`` via ``str.format(schema=..., question=...)``
    and prints the result. Output is truncated to 200 characters plus ``...``.
    Templates that do not exist are reported and skipped. All paths are
    relative to the current working directory, and files are read as UTF-8.

    Raises:
        OSError: if the schema file or a template cannot be read.
        UnicodeDecodeError: if a file is not valid UTF-8.
        KeyError / IndexError / ValueError: if a template contains braces
            that ``str.format`` cannot resolve.
    """
    print("📝 Prompt Templates Demo")
    print("-" * 30)

    # Decode explicitly as UTF-8 so the result doesn't depend on the
    # platform's locale encoding.
    with open("tasks/nyc_taxi_small/schema.sql", "r", encoding="utf-8") as f:
        schema = f.read()

    question = "How many total trips are there in the dataset?"

    # Show how templates work
    dialects = ["presto", "bigquery", "snowflake"]
    for dialect in dialects:
        template_path = f"prompts/template_{dialect}.txt"
        if not os.path.exists(template_path):
            print(f"\nNo {dialect} template found at {template_path}; skipping.")
            continue

        with open(template_path, "r", encoding="utf-8") as f:
            template = f.read()

        # NOTE: str.format treats every {...} in the template as a placeholder,
        # so literal braces in a template must be written as {{ and }}.
        prompt = template.format(schema=schema, question=question)
        print(f"\n{dialect.upper()} Prompt Template:")
        print("-" * 20)
        print(prompt[:200] + "..." if len(prompt) > 200 else prompt)


def main():
    """Run every demo in order, printing a separator after each one.

    A demo that raises is reported with its exception message, and the
    remaining demos still run. Finishes by printing how to launch and test
    the full application.
    """
    print("🎯 NL→SQL Leaderboard Demo")
    print("=" * 50)
    print("This demo shows how the system works without requiring API keys.")
    print("=" * 50)

    demos = [
        demo_dataset_loading,
        demo_models_loading,
        demo_database_creation,
        demo_sql_transpilation,
        demo_scoring,
        demo_prompt_templates
    ]

    for demo in demos:
        try:
            demo()
            print("\n" + "=" * 50)
        except Exception as e:
            # Top-level boundary: report and keep going so one failing demo
            # doesn't prevent the others from running.
            print(f"❌ Demo failed: {e}")
            print("=" * 50)

    print("\n🎉 Demo completed!")
    print("\nTo run the full application:")
    print("  python launch.py")
    print("\nTo test the system:")
    print("  python test_system.py")


# Run all demos only when executed as a script, not when imported.
if __name__ == "__main__":
    main()