.gitignore CHANGED
@@ -48,4 +48,3 @@ flagged/
48
 
49
  # Logs
50
  *.log
51
- memory/chroma.sqlite3
 
48
 
49
  # Logs
50
  *.log
 
513935c4d2db2d2d/query_results_661f24f3.csv DELETED
@@ -1,5 +0,0 @@
1
- id,title,source_url,author,published_date,image_url,type
2
- 1242,These preteen go-kart drivers are spending millions for a shot at F1 racing,https://www.washingtonpost.com/world/interactive/2024/formula-1-karting-children-parents-racing-costs/,The Washington Post,2025-07-17,,spotlight
3
- 1912,A Formula 1 pistop: 2 seconds of adrenaline and pressure,https://www.washingtonpost.com/sports/interactive/2023/formula-one-pitstop-haas-red-bull/,The Washington Post,2023-07-17,,spotlight
4
- 7047,Racing Against History,http://www.nytimes.com/interactive/2012/08/01/sports/olympics/racing-against-history.html?gwh=2D12538F1CD4F05B39F50285EFA1313E,The New York Times,2012-07-17,,spotlight
5
- 442,75 years of innovation: How F1 has evolved since 1950 and where it's headed,https://www.espn.com/espn/feature/story/_/id/43832710/how-f1-evolved-1950-where-headed-2026,ESPN,,,spotlight
 
 
 
 
 
 
513935c4d2db2d2d/query_results_8b61c5d0.csv DELETED
@@ -1,2 +0,0 @@
1
- id,title,source_url,author,published_date,image_url,type
2
- 391,Our World | Justdiggit,https://ourworld.justdiggit.org/en/,Just Digg It,2024-01-19,https://towumekminbldlabbyss.supabase.co/storage/v1/object/public/images/posts/share-ourworld-justdiggit.jpg,spotlight
 
 
 
513935c4d2db2d2d/query_results_c6e0aed3.csv DELETED
@@ -1,9 +0,0 @@
1
- id,title,source_url,author,published_date,image_url,type
2
- 1242,These preteen go-kart drivers are spending millions for a shot at F1 racing,https://www.washingtonpost.com/world/interactive/2024/formula-1-karting-children-parents-racing-costs/,The Washington Post,2025-07-17,,spotlight
3
- 925,Weed drinks are everywhere in Minnesota. Other states are now embracing them.,https://www.politico.com/news/2024/07/10/minnesota-weed-drinks-00165375,POLITICO,2025-07-17,,spotlight
4
- 1912,A Formula 1 pistop: 2 seconds of adrenaline and pressure,https://www.washingtonpost.com/sports/interactive/2023/formula-one-pitstop-haas-red-bull/,The Washington Post,2023-07-17,,spotlight
5
- 3122,Rising Reality: A look at the difficulties facing communities bracing for climate change all along San Francisco Bay,https://www.sfchronicle.com/projects/2021/san-francisco-bay-area-sea-level-rise-2021/mission-creek,San Francisco Chronicle,2021-07-17,,spotlight
6
- 7047,Racing Against History,http://www.nytimes.com/interactive/2012/08/01/sports/olympics/racing-against-history.html?gwh=2D12538F1CD4F05B39F50285EFA1313E,The New York Times,2012-07-17,,spotlight
7
- 3754,For embracing responsive design,http://www.bostonglobe.com/arts/specials/gardner,Boston Globe,2011-07-17,,spotlight
8
- 46,Privacy Preserving Proximity Tracing,https://tracing.ft0.ch/#/,Privacy Preserving Proximity Tracing,,,spotlight
9
- 442,75 years of innovation: How F1 has evolved since 1950 and where it's headed,https://www.espn.com/espn/feature/story/_/id/43832710/how-f1-evolved-1950-where-headed-2026,ESPN,,,spotlight
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -10,7 +10,6 @@ Now with Datawrapper integration for chart generation!
10
  import os
11
  import io
12
  import asyncio
13
- import time
14
  import pandas as pd
15
  import gradio as gr
16
  from dotenv import load_dotenv
@@ -19,7 +18,6 @@ from src.datawrapper_client import create_and_publish_chart, get_iframe_html
19
  from datetime import datetime, timedelta
20
  from collections import defaultdict
21
  from src.vanna import VannaComponent
22
- from src.query_intent_classifier import classify_query, IntentClassifier
23
 
24
  # Load environment variables
25
  load_dotenv()
@@ -56,32 +54,6 @@ except Exception as e:
56
  print(f"✗ Error initializing Vanna: {e}")
57
  raise
58
 
59
- # CSV cleanup function
60
- def cleanup_old_csv_files():
61
- """Delete CSV files older than 24 hours to prevent accumulation"""
62
- folder = "513935c4d2db2d2d"
63
- if not os.path.exists(folder):
64
- return
65
-
66
- cleaned = 0
67
- for file in os.listdir(folder):
68
- if file.endswith(".csv"):
69
- file_path = os.path.join(folder, file)
70
- try:
71
- # Check if file is older than 24 hours
72
- if os.path.getmtime(file_path) < time.time() - 86400:
73
- os.remove(file_path)
74
- cleaned += 1
75
- except Exception as e:
76
- print(f"Warning: Could not delete {file_path}: {e}")
77
-
78
- if cleaned > 0:
79
- print(f"✓ Cleaned up {cleaned} old CSV files")
80
-
81
- # Run cleanup on startup
82
- print("Cleaning up old CSV files...")
83
- cleanup_old_csv_files()
84
-
85
  def check_rate_limit(request: gr.Request) -> tuple[bool, int]:
86
  """Check if user has exceeded rate limit"""
87
  if request is None:
@@ -138,41 +110,23 @@ def recommend_stream(message: str, history: list, request: gr.Request):
138
  yield f"Error generating response: {str(e)}\n\nPlease check your environment variables (HF_TOKEN, SUPABASE_URL, SUPABASE_KEY) and try again."
139
 
140
 
141
- def generate_chart_from_csv(csv_file, user_prompt, api_key):
142
  """
143
- Generate a Datawrapper chart from uploaded CSV and user prompt using user's API key.
144
 
145
  Args:
146
  csv_file: Uploaded CSV file
147
  user_prompt: User's description of the chart
148
- api_key: User's Datawrapper API key
149
 
150
  Returns:
151
  HTML string with iframe or error message
152
  """
153
- # Validate API key first
154
- if not api_key or api_key.strip() == "":
155
- return """
156
- <div style='padding: 50px; text-align: center; color: #d9534f;'>
157
- <h3>❌ No API Key Provided</h3>
158
- <p>Please enter your Datawrapper API key above to generate charts.</p>
159
- <p style='margin-top: 15px;'>
160
- <a href='https://app.datawrapper.de/account/api-tokens' target='_blank'
161
- style='color: #1976d2; text-decoration: underline;'>Get your API key →</a>
162
- </p>
163
- </div>
164
- """
165
-
166
  if not csv_file:
167
  return "<div style='padding: 50px; text-align: center;'>Please upload a CSV file to generate a chart.</div>"
168
 
169
  if not user_prompt or user_prompt.strip() == "":
170
  return "<div style='padding: 50px; text-align: center;'>Please describe what chart you want to create.</div>"
171
 
172
- # Temporarily set the API key in environment for this request
173
- original_key = os.environ.get("DATAWRAPPER_ACCESS_TOKEN")
174
- os.environ["DATAWRAPPER_ACCESS_TOKEN"] = api_key
175
-
176
  try:
177
  # Show loading message
178
  loading_html = """
@@ -238,15 +192,9 @@ def generate_chart_from_csv(csv_file, user_prompt, api_key):
238
  <div style='padding: 50px; text-align: center; color: red;'>
239
  <h3>❌ Error</h3>
240
  <p>{str(e)}</p>
241
- <p style='font-size: 0.9em; color: #666;'>Please ensure your CSV is properly formatted and your API key is correct.</p>
242
  </div>
243
  """
244
- finally:
245
- # Restore original API key or remove it
246
- if original_key:
247
- os.environ["DATAWRAPPER_ACCESS_TOKEN"] = original_key
248
- elif "DATAWRAPPER_ACCESS_TOKEN" in os.environ:
249
- del os.environ["DATAWRAPPER_ACCESS_TOKEN"]
250
 
251
  def csv_to_cards_html(csv_text: str) -> str:
252
  """
@@ -263,7 +211,11 @@ def csv_to_cards_html(csv_text: str) -> str:
263
  source_url = row.get("source_url", "#")
264
  author = row.get("author", "Inconnu")
265
  published_date = row.get("published_date", "")
266
- image_url = row.get("image_url", "https://fpoimg.com/800x600?text=Image+not+found")
 
 
 
 
267
 
268
  cards_html += f"""
269
  <div style="background: white; border-radius: 10px; box-shadow: 0 2px 8px rgba(0,0,0,0.1);
@@ -275,7 +227,7 @@ def csv_to_cards_html(csv_text: str) -> str:
275
  <p style="margin:0; color:#999; font-size:0.8em;">{published_date}</p>
276
  <a href="{source_url}" target="_blank"
277
  style="display:inline-block; margin-top:8px; font-size:0.9em; color:#1976d2; text-decoration:none;">
278
- 🔗 Source
279
  </a>
280
  </div>
281
  </div>
@@ -310,60 +262,20 @@ async def search_inspiration_from_database(user_prompt):
310
  """
311
 
312
  try:
313
- # Classify user intent
314
- print(f"\n{'='*60}")
315
- print(f"[SEARCH] User prompt: {user_prompt}")
316
-
317
- classifier = IntentClassifier()
318
- classification = classifier.classify(user_prompt)
319
-
320
- print(f"[INTENT] Type: {classification['intent'].value}")
321
- print(f"[INTENT] Keywords: {classification['keywords']}")
322
- print(f"[INTENT] Inferred tags: {classification['tags']}")
323
- print(f"[INTENT] Short query: {classification['is_short_query']}")
324
-
325
- # Enhance prompt with intent guidance
326
- enhanced_prompt = classifier.format_for_vanna(classification)
327
- full_prompt = f"{user_prompt}\n\n{enhanced_prompt}"
328
-
329
- print(f"[VANNA] Sending enhanced prompt to Vanna...")
330
- response = await vanna.ask(full_prompt)
331
- print(f"[VANNA] Response received: {repr(response)[:200]}...")
332
- print(f"{'='*60}\n")
333
 
334
  clean_response = response.strip()
335
 
336
- # Check for empty query results (0 rows returned)
337
- if "No rows returned" in clean_response or "0 rows" in clean_response.lower():
338
- return f"""
339
- <div style='padding: 50px; text-align: center; color: #f0ad4e;'>
340
- <h3>🔍 No Results Found</h3>
341
- <p>Your query was executed successfully, but no posts matched your criteria.</p>
342
- <p style='margin-top: 15px; font-weight: 600;'>Suggestions:</p>
343
- <ul style='list-style: none; padding: 0; text-align: left; display: inline-block;'>
344
- <li>• Try broader keywords (e.g., "visualization" instead of "F1 dataviz")</li>
345
- <li>• Search by author names (e.g., "New York Times")</li>
346
- <li>• Use simple terms (e.g., "interactive", "maps")</li>
347
- </ul>
348
- <p style='margin-top: 15px; font-style: italic; color: #666; font-size: 0.9em;'>
349
- <strong>Note:</strong> Most posts are currently being enriched with tags.<br/>
350
- Keyword search works for all {classification.get('total_posts', '7,000+')} posts in the database.
351
- </p>
352
- </div>
353
- """
354
-
355
- # Check for errors or warnings
356
- if clean_response.startswith("⚠️") or clean_response.startswith("❌") or "Aucun CSV détecté" in clean_response:
357
  return f"""
358
  <div style='padding: 50px; text-align: center; color: #d9534f;'>
359
- <h3>❌ Query Error</h3>
360
- <p>The AI encountered an issue processing your request.</p>
361
- <p style='margin-top: 10px; font-size: 0.9em; color: #666;'>{clean_response[:200]}</p>
362
- <p style='margin-top: 15px;'>Try rephrasing your query or being more specific.</p>
363
  </div>
364
  """
365
 
366
- # Process CSV response
367
  csv_text = (
368
  clean_response
369
  .strip("```")
@@ -371,15 +283,11 @@ async def search_inspiration_from_database(user_prompt):
371
  .replace("CSV", "")
372
  )
373
 
374
- # Check if response contains CSV data
375
- if "," not in csv_text or "id,title" not in csv_text.lower():
376
  return f"""
377
  <div style='padding: 50px; text-align: center; color: #d9534f;'>
378
- <h3>❌ Invalid Response Format</h3>
379
- <p>The database query didn't return structured data.</p>
380
- <p style='margin-top: 10px; font-size: 0.9em; color: #666;'>
381
- This might be a temporary issue. Please try again.
382
- </p>
383
  </div>
384
  """
385
 
@@ -387,17 +295,11 @@ async def search_inspiration_from_database(user_prompt):
387
  return cards_html
388
 
389
  except Exception as e:
390
- print(f"❌ Exception in search_inspiration_from_database: {str(e)}")
391
- import traceback
392
- traceback.print_exc()
393
  return f"""
394
  <div style='padding: 50px; text-align: center; color: red;'>
395
- <h3>❌ System Error</h3>
396
- <p style='margin-bottom: 10px;'>An unexpected error occurred:</p>
397
- <p style='font-family: monospace; font-size: 0.85em; color: #666;'>{str(e)}</p>
398
- <p style='margin-top: 15px; font-size: 0.9em; color: #666;'>
399
- Please check the console logs for more details.
400
- </p>
401
  </div>
402
  """
403
 
@@ -430,63 +332,18 @@ with gr.Blocks(
430
  gr.Markdown("""
431
  # 📊 Viz LLM
432
 
433
- Discover inspiring visualizations, refine your design ideas, or generate charts using Datawrapper.
434
  """)
435
 
436
- # JavaScript for localStorage persistence
437
- gr.HTML("""
438
- <script>
439
- // Save API key to localStorage when it changes
440
- function saveApiKeyToStorage(key) {
441
- if (key && key.trim() !== '') {
442
- localStorage.setItem('datawrapper_api_key', key);
443
- }
444
- }
445
-
446
- // Load API key from localStorage on page load
447
- function loadApiKeyFromStorage() {
448
- return localStorage.getItem('datawrapper_api_key') || '';
449
- }
450
-
451
- // Auto-load API key when the page loads
452
- window.addEventListener('DOMContentLoaded', function() {
453
- setTimeout(function() {
454
- const savedKey = loadApiKeyFromStorage();
455
- if (savedKey) {
456
- const apiKeyInput = document.querySelector('input[type="password"]');
457
- if (apiKeyInput) {
458
- apiKeyInput.value = savedKey;
459
- // Trigger change event to update Gradio state
460
- apiKeyInput.dispatchEvent(new Event('input', { bubbles: true }));
461
- }
462
- }
463
- }, 1000);
464
- });
465
- </script>
466
- """)
467
-
468
- # Mode selector buttons (reordered: Inspiration, Refinement, Chart)
469
  with gr.Row():
470
- inspiration_btn = gr.Button(" Inspiration", variant="primary", elem_classes="mode-button")
471
- ideation_btn = gr.Button("💡 Refinement", variant="secondary", elem_classes="mode-button")
472
- chart_gen_btn = gr.Button("📊 Chart", variant="secondary", elem_classes="mode-button")
473
 
474
 
475
- # Inspiration Mode: Search interface (shown by default)
476
- with gr.Column(visible=True) as inspiration_container:
477
- with gr.Row():
478
- inspiration_prompt_input = gr.Textbox(
479
- placeholder="Search for inspiration (e.g., 'F1', 'interactive maps')...",
480
- show_label=False,
481
- scale=4,
482
- container=False
483
- )
484
- inspiration_search_btn = gr.Button("🔍 Search", variant="primary", scale=1)
485
-
486
- inspiration_cards_html = gr.HTML("")
487
-
488
- # Refinement Mode: Chat interface (hidden by default, wrapped in Column)
489
- with gr.Column(visible=False) as ideation_container:
490
  ideation_interface = gr.ChatInterface(
491
  fn=recommend_stream,
492
  type="messages",
@@ -503,32 +360,6 @@ with gr.Blocks(
503
 
504
  # Chart Generation Mode: Chart controls and output (hidden by default)
505
  with gr.Column(visible=False) as chart_gen_container:
506
- gr.Markdown("### Chart Generator")
507
-
508
- # API Key Input (collapsible)
509
- with gr.Accordion("🔑 Datawrapper API Key", open=False):
510
- gr.Markdown("""
511
- Enter your Datawrapper API key to generate charts. Your key is stored in your browser and persists across sessions.
512
-
513
- **Get your key**: [Datawrapper Account Settings](https://app.datawrapper.de/account/api-tokens)
514
- """)
515
-
516
- # Warning about permissions
517
- gr.HTML("""
518
- <div style="background: #fff3cd; border: 1px solid #ffc107; border-radius: 5px; padding: 12px; margin: 10px 0;">
519
- <strong>⚠️ Important:</strong> When creating your API key, toggle <strong>ALL permissions</strong> (Read & Write for Charts, Tables, Folders, etc.) otherwise chart generation will fail.
520
- </div>
521
- """)
522
-
523
- api_key_input = gr.Textbox(
524
- label="API Key",
525
- placeholder="Paste your Datawrapper API key here...",
526
- type="password",
527
- value=""
528
- )
529
-
530
- api_key_status = gr.Markdown("⚠️ Status: No API key provided")
531
-
532
  csv_upload = gr.File(
533
  label="📁 Upload CSV File",
534
  file_types=[".csv"],
@@ -548,111 +379,79 @@ with gr.Blocks(
548
  label="Generated Chart"
549
  )
550
 
551
- # API key state management
552
- api_key_state = gr.State(value="")
553
-
554
- def validate_api_key(api_key: str) -> tuple[str, str]:
555
- """Validate and store API key"""
556
- if not api_key or api_key.strip() == "":
557
- return "", "⚠️ Status: No API key provided"
558
-
559
- # Basic validation (check format)
560
- if len(api_key) < 20:
561
- return "", "❌ Status: Invalid API key format (too short)"
562
-
563
- # Key looks valid - it will be saved to localStorage via JavaScript
564
- masked_key = f"...{api_key[-6:]}" if len(api_key) > 6 else "***"
565
- return api_key, f"✅ Status: API key saved to browser storage (ends with {masked_key})"
566
 
567
- # Mode switching functions (updated for new order: Inspiration, Refinement, Chart)
568
- def switch_to_inspiration():
569
- return [
570
- gr.update(variant="primary"), # inspiration_btn
571
- gr.update(variant="secondary"), # ideation_btn
572
- gr.update(variant="secondary"), # chart_gen_btn
573
- gr.update(visible=True), # inspiration_container
574
- gr.update(visible=False), # ideation_container
575
- gr.update(visible=False), # chart_gen_container
576
- ]
577
 
 
578
  def switch_to_ideation():
579
  return [
580
- gr.update(variant="secondary"), # inspiration_btn
581
  gr.update(variant="primary"), # ideation_btn
582
  gr.update(variant="secondary"), # chart_gen_btn
583
- gr.update(visible=False), # inspiration_container
584
  gr.update(visible=True), # ideation_container
585
  gr.update(visible=False), # chart_gen_container
 
586
  ]
587
 
588
  def switch_to_chart_gen():
589
  return [
590
- gr.update(variant="secondary"), # inspiration_btn
591
  gr.update(variant="secondary"), # ideation_btn
592
  gr.update(variant="primary"), # chart_gen_btn
593
- gr.update(visible=False), # inspiration_container
594
  gr.update(visible=False), # ideation_container
595
  gr.update(visible=True), # chart_gen_container
 
596
  ]
597
 
598
- # Wire up mode switching (updated order: inspiration, ideation, chart)
599
- inspiration_btn.click(
600
- fn=switch_to_inspiration,
601
- inputs=[],
602
- outputs=[inspiration_btn, ideation_btn, chart_gen_btn, inspiration_container, ideation_container, chart_gen_container]
603
- )
 
 
 
604
 
 
605
  ideation_btn.click(
606
  fn=switch_to_ideation,
607
  inputs=[],
608
- outputs=[inspiration_btn, ideation_btn, chart_gen_btn, inspiration_container, ideation_container, chart_gen_container]
609
  )
610
 
611
  chart_gen_btn.click(
612
  fn=switch_to_chart_gen,
613
  inputs=[],
614
- outputs=[inspiration_btn, ideation_btn, chart_gen_btn, inspiration_container, ideation_container, chart_gen_container]
615
  )
616
 
617
- # Connect API key validation and localStorage save
618
- api_key_input.change(
619
- fn=validate_api_key,
620
- inputs=[api_key_input],
621
- outputs=[api_key_state, api_key_status],
622
- js="(key) => { saveApiKeyToStorage(key); return key; }"
623
  )
624
 
625
- # Generate chart when button is clicked (now with API key)
626
  generate_chart_btn.click(
627
  fn=generate_chart_from_csv,
628
- inputs=[csv_upload, chart_prompt_input, api_key_state],
629
  outputs=[chart_output]
630
  )
631
 
632
- # Search inspiration with loading state
633
- def search_with_loading(prompt):
634
- """Wrapper to show loading state"""
635
- if not prompt or not prompt.strip():
636
- return """
637
- <div style='padding: 50px; text-align: center;'>
638
- Please enter a search query.
639
- </div>
640
- """
641
- # Show loading immediately (Gradio will display this first)
642
- yield """
643
- <div style='padding: 50px; text-align: center;'>
644
- <div style='font-size: 2em; margin-bottom: 20px;'>🔍</div>
645
- <h3>Searching database...</h3>
646
- <p style='color: #666;'>Analyzing your query and generating SQL...</p>
647
- </div>
648
- """
649
- # Run the actual search
650
- import asyncio
651
- result = asyncio.run(search_inspiration_from_database(prompt))
652
- yield result
653
-
654
  inspiration_search_btn.click(
655
- fn=search_with_loading,
656
  inputs=[inspiration_prompt_input],
657
  outputs=[inspiration_cards_html]
658
  )
@@ -661,6 +460,12 @@ with gr.Blocks(
661
  gr.Markdown("""
662
  ### About Viz LLM
663
 
 
 
 
 
 
 
664
  **Credits:** Special thanks to the researchers whose work informed this model: Robert Kosara, Edward Segel, Jeffrey Heer, Matthew Conlen, John Maeda, Kennedy Elliott, Scott McCloud, and many others.
665
 
666
  ---
@@ -668,19 +473,21 @@ with gr.Blocks(
668
  **Usage Limits:** This service is limited to 20 queries per day per user to manage costs. Responses are optimized for English.
669
 
670
  <div style="text-align: center; margin-top: 20px; opacity: 0.6; font-size: 0.9em;">
671
- Embeddings: Jina-CLIP-v2 | Charts: Datawrapper API | Database: Nuanced
672
  </div>
673
  """)
674
 
675
  # Launch configuration
676
  if __name__ == "__main__":
677
- # Check for required environment variables (Datawrapper key now user-provided)
678
- required_vars = ["SUPABASE_URL", "SUPABASE_KEY", "HF_TOKEN"]
679
  missing_vars = [var for var in required_vars if not os.getenv(var)]
680
 
681
  if missing_vars:
682
  print(f"⚠️ Warning: Missing environment variables: {', '.join(missing_vars)}")
683
  print("Please set these in your .env file or as environment variables")
 
 
684
 
685
  # Launch the app
686
  demo.launch(
 
10
  import os
11
  import io
12
  import asyncio
 
13
  import pandas as pd
14
  import gradio as gr
15
  from dotenv import load_dotenv
 
18
  from datetime import datetime, timedelta
19
  from collections import defaultdict
20
  from src.vanna import VannaComponent
 
21
 
22
  # Load environment variables
23
  load_dotenv()
 
54
  print(f"✗ Error initializing Vanna: {e}")
55
  raise
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def check_rate_limit(request: gr.Request) -> tuple[bool, int]:
58
  """Check if user has exceeded rate limit"""
59
  if request is None:
 
110
  yield f"Error generating response: {str(e)}\n\nPlease check your environment variables (HF_TOKEN, SUPABASE_URL, SUPABASE_KEY) and try again."
111
 
112
 
113
+ def generate_chart_from_csv(csv_file, user_prompt):
114
  """
115
+ Generate a Datawrapper chart from uploaded CSV and user prompt.
116
 
117
  Args:
118
  csv_file: Uploaded CSV file
119
  user_prompt: User's description of the chart
 
120
 
121
  Returns:
122
  HTML string with iframe or error message
123
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  if not csv_file:
125
  return "<div style='padding: 50px; text-align: center;'>Please upload a CSV file to generate a chart.</div>"
126
 
127
  if not user_prompt or user_prompt.strip() == "":
128
  return "<div style='padding: 50px; text-align: center;'>Please describe what chart you want to create.</div>"
129
 
 
 
 
 
130
  try:
131
  # Show loading message
132
  loading_html = """
 
192
  <div style='padding: 50px; text-align: center; color: red;'>
193
  <h3>❌ Error</h3>
194
  <p>{str(e)}</p>
195
+ <p style='font-size: 0.9em; color: #666;'>Please ensure your CSV is properly formatted and try again.</p>
196
  </div>
197
  """
 
 
 
 
 
 
198
 
199
  def csv_to_cards_html(csv_text: str) -> str:
200
  """
 
211
  source_url = row.get("source_url", "#")
212
  author = row.get("author", "Inconnu")
213
  published_date = row.get("published_date", "")
214
+ if not published_date == "nan":
215
+ published_date = ""
216
+ image_url = row.get("image_url", "")
217
+ if not image_url == "nan":
218
+ image_url = "https://fpoimg.com/800x600?text=Image+not+found"
219
 
220
  cards_html += f"""
221
  <div style="background: white; border-radius: 10px; box-shadow: 0 2px 8px rgba(0,0,0,0.1);
 
227
  <p style="margin:0; color:#999; font-size:0.8em;">{published_date}</p>
228
  <a href="{source_url}" target="_blank"
229
  style="display:inline-block; margin-top:8px; font-size:0.9em; color:#1976d2; text-decoration:none;">
230
+ 🔗 Voir la source
231
  </a>
232
  </div>
233
  </div>
 
262
  """
263
 
264
  try:
265
+ response = await vanna.ask(user_prompt)
266
+ print("response :", repr(response))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
  clean_response = response.strip()
269
 
270
+ if clean_response.startswith("⚠️") or "Aucun CSV détecté" in clean_response:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  return f"""
272
  <div style='padding: 50px; text-align: center; color: #d9534f;'>
273
+ <h3>❌ No valid data found</h3>
274
+ <p>The AI couldn't generate any data for this request. Try being more specific — for example:
275
+ <em>"Show me spotlights from 2020 about design"</em>.</p>
 
276
  </div>
277
  """
278
 
 
279
  csv_text = (
280
  clean_response
281
  .strip("```")
 
283
  .replace("CSV", "")
284
  )
285
 
286
+ if "," not in csv_text:
 
287
  return f"""
288
  <div style='padding: 50px; text-align: center; color: #d9534f;'>
289
+ <h3>❌ No valid CSV detected</h3>
290
+ <p>The model didn't return any structured data. Try rephrasing your query to be more precise.</p>
 
 
 
291
  </div>
292
  """
293
 
 
295
  return cards_html
296
 
297
  except Exception as e:
 
 
 
298
  return f"""
299
  <div style='padding: 50px; text-align: center; color: red;'>
300
+ <h3>❌ Error</h3>
301
+ <p>{str(e)}</p>
302
+ <p style='font-size: 0.9em; color: #666;'>Please try again.</p>
 
 
 
303
  </div>
304
  """
305
 
 
332
  gr.Markdown("""
333
  # 📊 Viz LLM
334
 
335
+ Get design recommendations or generate charts with AI-powered data visualization assistance.
336
  """)
337
 
338
+ # Mode selector buttons
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
  with gr.Row():
340
+ ideation_btn = gr.Button("💡 Ideation Mode", variant="primary", elem_classes="mode-button")
341
+ chart_gen_btn = gr.Button("📊 Chart Generation Mode", variant="secondary", elem_classes="mode-button")
342
+ inspiration_btn = gr.Button(" Inspiration Mode", variant="secondary", elem_classes="mode-button")
343
 
344
 
345
+ # Ideation Mode: Chat interface (shown by default, wrapped in Column)
346
+ with gr.Column(visible=True) as ideation_container:
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  ideation_interface = gr.ChatInterface(
348
  fn=recommend_stream,
349
  type="messages",
 
360
 
361
  # Chart Generation Mode: Chart controls and output (hidden by default)
362
  with gr.Column(visible=False) as chart_gen_container:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  csv_upload = gr.File(
364
  label="📁 Upload CSV File",
365
  file_types=[".csv"],
 
379
  label="Generated Chart"
380
  )
381
 
382
+ # Inspiration Mode:
383
+ with gr.Column(visible=False) as inspiration_container:
384
+ with gr.Row():
385
+ inspiration_prompt_input = gr.Textbox(
386
+ placeholder="Ask for an inspiration...",
387
+ show_label=False,
388
+ scale=4,
389
+ container=False
390
+ )
391
+ inspiration_search_btn = gr.Button("🔍 Search", variant="primary", scale=1)
 
 
 
 
 
392
 
393
+ inspiration_cards_html = gr.HTML("")
 
 
 
 
 
 
 
 
 
394
 
395
+ # Mode switching functions
396
  def switch_to_ideation():
397
  return [
 
398
  gr.update(variant="primary"), # ideation_btn
399
  gr.update(variant="secondary"), # chart_gen_btn
400
+ gr.update(variant="secondary"), # inspiration_btn
401
  gr.update(visible=True), # ideation_container
402
  gr.update(visible=False), # chart_gen_container
403
+ gr.update(visible=False), # inspiration_container
404
  ]
405
 
406
  def switch_to_chart_gen():
407
  return [
 
408
  gr.update(variant="secondary"), # ideation_btn
409
  gr.update(variant="primary"), # chart_gen_btn
410
+ gr.update(variant="secondary"), # inspiration_btn
411
  gr.update(visible=False), # ideation_container
412
  gr.update(visible=True), # chart_gen_container
413
+ gr.update(visible=False), # inspiration_container
414
  ]
415
 
416
+ def switch_to_inspiration():
417
+ return [
418
+ gr.update(variant="secondary"), # ideation_btn
419
+ gr.update(variant="secondary"), # chart_gen_btn
420
+ gr.update(variant="primary"), # inspiration_btn
421
+ gr.update(visible=False), # ideation_container
422
+ gr.update(visible=False), # chart_gen_container
423
+ gr.update(visible=True), # inspiration_container
424
+ ]
425
 
426
+ # Wire up mode switching
427
  ideation_btn.click(
428
  fn=switch_to_ideation,
429
  inputs=[],
430
+ outputs=[ideation_btn, chart_gen_btn, inspiration_btn, ideation_container, chart_gen_container, inspiration_container]
431
  )
432
 
433
  chart_gen_btn.click(
434
  fn=switch_to_chart_gen,
435
  inputs=[],
436
+ outputs=[ideation_btn, chart_gen_btn, inspiration_btn, ideation_container, chart_gen_container, inspiration_container]
437
  )
438
 
439
+ inspiration_btn.click(
440
+ fn=switch_to_inspiration,
441
+ inputs=[],
442
+ outputs=[ideation_btn, chart_gen_btn, inspiration_btn, ideation_container, chart_gen_container, inspiration_container]
 
 
443
  )
444
 
445
+ # Generate chart when button is clicked
446
  generate_chart_btn.click(
447
  fn=generate_chart_from_csv,
448
+ inputs=[csv_upload, chart_prompt_input],
449
  outputs=[chart_output]
450
  )
451
 
452
+ # Search inspiration when button is clicked
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  inspiration_search_btn.click(
454
+ fn=search_inspiration_from_database,
455
  inputs=[inspiration_prompt_input],
456
  outputs=[inspiration_cards_html]
457
  )
 
460
  gr.Markdown("""
461
  ### About Viz LLM
462
 
463
+ **Ideation Mode:** Get design recommendations based on research papers, design principles, and examples from the field of information graphics and data visualization.
464
+
465
+ **Chart Generation Mode:** Upload your CSV data and describe your visualization goal. The AI will analyze your data, select the optimal chart type, and generate a publication-ready chart using Datawrapper.
466
+
467
+ **Inspiration Mode:** Coming soon.
468
+
469
  **Credits:** Special thanks to the researchers whose work informed this model: Robert Kosara, Edward Segel, Jeffrey Heer, Matthew Conlen, John Maeda, Kennedy Elliott, Scott McCloud, and many others.
470
 
471
  ---
 
473
  **Usage Limits:** This service is limited to 20 queries per day per user to manage costs. Responses are optimized for English.
474
 
475
  <div style="text-align: center; margin-top: 20px; opacity: 0.6; font-size: 0.9em;">
476
+ Embeddings: Jina-CLIP-v2 | Charts: Datawrapper API
477
  </div>
478
  """)
479
 
480
  # Launch configuration
481
  if __name__ == "__main__":
482
+ # Check for required environment variables
483
+ required_vars = ["SUPABASE_URL", "SUPABASE_KEY", "HF_TOKEN", "DATAWRAPPER_ACCESS_TOKEN"]
484
  missing_vars = [var for var in required_vars if not os.getenv(var)]
485
 
486
  if missing_vars:
487
  print(f"⚠️ Warning: Missing environment variables: {', '.join(missing_vars)}")
488
  print("Please set these in your .env file or as environment variables")
489
+ if "DATAWRAPPER_ACCESS_TOKEN" in missing_vars:
490
+ print("Note: DATAWRAPPER_ACCESS_TOKEN is required for chart generation mode")
491
 
492
  # Launch the app
493
  demo.launch(
src/query_intent_classifier.py DELETED
@@ -1,240 +0,0 @@
1
- """
2
- Query Intent Classifier for Hybrid Search
3
-
4
- Analyzes user queries to determine the best search strategy:
5
- - keyword: Full-text search on title/author/provider (works for all posts)
6
- - tag: Tag-based search (works only for tagged posts)
7
- - hybrid: Try both approaches
8
- """
9
-
10
- import re
11
- from typing import Dict, List
12
- from enum import Enum
13
-
14
-
15
- class QueryIntent(Enum):
16
- KEYWORD = "keyword"
17
- TAG = "tag"
18
- HYBRID = "hybrid"
19
-
20
-
21
- class IntentClassifier:
22
- """
23
- Classifies user queries and extracts relevant search parameters.
24
- """
25
-
26
- # Keywords that suggest tag search
27
- TAG_INDICATORS = ["tagged", "category", "topic", "theme", "type", "about"]
28
-
29
- # Common keywords to expand for better matching
30
- KEYWORD_EXPANSIONS = {
31
- "f1": ["f1", "formula 1", "formula one", "racing"],
32
- "dataviz": ["dataviz", "data visualization", "visualization", "chart", "graph"],
33
- "interactive": ["interactive", "interaction", "explore"],
34
- "map": ["map", "maps", "mapping", "geographic", "geo"],
35
- "nyt": ["new york times", "nyt", "ny times"],
36
- }
37
-
38
- def __init__(self):
39
- pass
40
-
41
- def classify(self, user_prompt: str) -> Dict:
42
- """
43
- Classify user intent and extract search parameters.
44
-
45
- Args:
46
- user_prompt: The user's search query
47
-
48
- Returns:
49
- Dict with:
50
- - intent: QueryIntent enum
51
- - keywords: List of keywords to search
52
- - tags: List of potential tags to search
53
- - original_query: Original user prompt
54
- """
55
- prompt_lower = user_prompt.lower().strip()
56
-
57
- # Detect intent
58
- intent = self._detect_intent(prompt_lower)
59
-
60
- # Extract keywords
61
- keywords = self._extract_keywords(prompt_lower)
62
-
63
- # Infer potential tags
64
- tags = self._infer_tags(prompt_lower, keywords)
65
-
66
- return {
67
- "intent": intent,
68
- "keywords": keywords,
69
- "tags": tags,
70
- "original_query": user_prompt,
71
- "is_short_query": len(prompt_lower.split()) <= 3
72
- }
73
-
74
- def _detect_intent(self, prompt: str) -> QueryIntent:
75
- """
76
- Determine if user wants tag search, keyword search, or hybrid.
77
- """
78
- # Check for tag indicators
79
- has_tag_indicator = any(indicator in prompt for indicator in self.TAG_INDICATORS)
80
-
81
- # Short queries (1-3 words) should try hybrid approach
82
- word_count = len(prompt.split())
83
-
84
- if has_tag_indicator:
85
- return QueryIntent.TAG
86
- elif word_count <= 3:
87
- # Short queries: try both tag and keyword search
88
- return QueryIntent.HYBRID
89
- else:
90
- # Longer natural language queries: keyword search first
91
- return QueryIntent.KEYWORD
92
-
93
- def _extract_keywords(self, prompt: str) -> List[str]:
94
- """
95
- Extract meaningful keywords from the prompt.
96
- """
97
- # Remove common stop words
98
- stop_words = {
99
- "show", "me", "find", "get", "search", "for", "the", "a", "an",
100
- "with", "about", "of", "in", "on", "at", "to", "from", "by",
101
- "what", "where", "when", "who", "how", "is", "are", "was", "were"
102
- }
103
-
104
- # Split and clean
105
- words = re.findall(r'\b\w+\b', prompt.lower())
106
- # Allow 2-character words like "F1", "AI", "3D"
107
- keywords = [w for w in words if w not in stop_words and len(w) >= 2]
108
-
109
- # Expand known keywords
110
- expanded_keywords = []
111
- for keyword in keywords:
112
- if keyword in self.KEYWORD_EXPANSIONS:
113
- expanded_keywords.extend(self.KEYWORD_EXPANSIONS[keyword])
114
- else:
115
- expanded_keywords.append(keyword)
116
-
117
- # Remove duplicates while preserving order
118
- return list(dict.fromkeys(expanded_keywords))
119
-
120
- def _infer_tags(self, prompt: str, keywords: List[str]) -> List[str]:
121
- """
122
- Infer potential tag names from keywords.
123
-
124
- Since we have limited tags in the database, we map common terms
125
- to likely tag names.
126
- """
127
- # Common tag mappings based on the database
128
- tag_mappings = {
129
- "f1": ["f1", "racing", "motorsport", "sports"],
130
- "formula": ["f1", "racing", "motorsport"],
131
- "racing": ["racing", "motorsport", "f1"],
132
- "dataviz": ["dataviz", "visualization"],
133
- "visualization": ["dataviz", "visualization"],
134
- "interactive": ["interactive"],
135
- "map": ["maps", "geographic"],
136
- "maps": ["maps", "geographic"],
137
- "math": ["mathematics", "statistics"],
138
- "statistics": ["statistics", "mathematics"],
139
- "africa": ["africa", "kenya", "tanzania"],
140
- "sustainability": ["sustainability", "regreening"],
141
- "documentary": ["documentary", "cinematic"],
142
- "education": ["students", "researchers"],
143
- }
144
-
145
- inferred_tags = []
146
- for keyword in keywords:
147
- if keyword in tag_mappings:
148
- inferred_tags.extend(tag_mappings[keyword])
149
-
150
- # If no specific mapping, use the keyword as-is
151
- if not inferred_tags:
152
- inferred_tags = keywords[:3] # Limit to top 3 keywords
153
-
154
- # Remove duplicates
155
- return list(dict.fromkeys(inferred_tags))
156
-
157
- def format_for_vanna(self, classification: Dict) -> str:
158
- """
159
- Format the classification result for Vanna's prompt.
160
-
161
- Returns a string that guides Vanna's SQL generation.
162
- """
163
- intent = classification["intent"]
164
- keywords = classification["keywords"]
165
- tags = classification["tags"]
166
-
167
- if intent == QueryIntent.KEYWORD:
168
- keyword_example = keywords[0] if keywords else "keyword"
169
- return f"""
170
- Search using KEYWORD approach:
171
- - Search terms: {', '.join(keywords)}
172
- - Search in: posts.title, posts.author, providers.name
173
- - Use LOWER(column) LIKE '%keyword%' for flexible matching
174
- - Example: LOWER(p.title) LIKE '%{keyword_example}%'
175
- - This matches word variants: '{keyword_example}', '{keyword_example}n', '{keyword_example}\\'s', etc.
176
- """
177
-
178
- elif intent == QueryIntent.TAG:
179
- return f"""
180
- Search using TAG approach:
181
- - Tag names: {', '.join(tags)}
182
- - 88% of posts (3,362) have tags - tag search is highly effective!
183
- - Use LOWER(t.name) LIKE '%tagname%' for flexible matching
184
- - Join with post_tags and tags tables
185
- """
186
-
187
- else: # HYBRID
188
- return f"""
189
- Search using HYBRID approach (RECOMMENDED):
190
- - Tags to search: {', '.join(tags)}
191
- - Keywords to search: {', '.join(keywords)}
192
- - Use OR logic: tag matches OR keyword matches in title/author
193
- - 88% of posts have tags, so tag search is primary
194
-
195
- Recommended SQL pattern:
196
- SELECT DISTINCT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type
197
- FROM posts p
198
- LEFT JOIN post_tags pt ON p.id = pt.post_id
199
- LEFT JOIN tags t ON pt.tag_id = t.id
200
- LEFT JOIN providers pr ON p.provider_id = pr.id
201
- WHERE (
202
- {' OR '.join(f"LOWER(t.name) LIKE '%{tag}%'" for tag in tags)}
203
- OR {' OR '.join(f"LOWER(p.title) LIKE '%{kw}%'" for kw in keywords)}
204
- OR {' OR '.join(f"LOWER(p.author) LIKE '%{kw}%'" for kw in keywords)}
205
- )
206
- ORDER BY p.published_date DESC NULLS LAST
207
- LIMIT 50
208
- """
209
-
210
-
211
- # Convenience function
212
- def classify_query(user_prompt: str) -> Dict:
213
- """
214
- Classify a user query and return search parameters.
215
- """
216
- classifier = IntentClassifier()
217
- return classifier.classify(user_prompt)
218
-
219
-
220
- # Example usage
221
- if __name__ == "__main__":
222
- # Test cases
223
- test_queries = [
224
- "F1",
225
- "Show me F1 content",
226
- "interactive visualizations",
227
- "New York Times articles",
228
- "content tagged with dataviz",
229
- "recent sustainability projects in Africa",
230
- ]
231
-
232
- classifier = IntentClassifier()
233
-
234
- for query in test_queries:
235
- result = classifier.classify(query)
236
- print(f"\nQuery: '{query}'")
237
- print(f"Intent: {result['intent'].value}")
238
- print(f"Keywords: {result['keywords']}")
239
- print(f"Tags: {result['tags']}")
240
- print(f"Short query: {result['is_short_query']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/vanna.py CHANGED
@@ -55,6 +55,9 @@ class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
55
  "- Never use SELECT *\n"
56
  "- Prefer window functions over subqueries when possible\n"
57
  "- Always include a LIMIT for exploratory queries\n"
 
 
 
58
  "- Format dates and numbers for readability\n"
59
  )
60
 
@@ -68,7 +71,7 @@ class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
68
  prompt += (
69
  "\n## Database Schema\n"
70
  "Tables:\n"
71
- "- posts (id, title, source_url, author, published_date, image_url, type, provider_id, created_at, updated_at, content_markdown, fts)\n"
72
  "- providers (id, name)\n"
73
  "- provider_attributes (id, provider_id, type, name)\n"
74
  "- post_provider_attributes (post_id, attribute_id)\n"
@@ -96,7 +99,6 @@ class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
96
  "- `providers.name`: name of the publishing organization (e.g., 'Nuanced', 'SND').\n"
97
  "- `tags.name`: thematic keyword or topic (e.g., '3D', 'AI', 'Design').\n"
98
  "- `post_tags.weight`: relevance score between a post and a tag.\n"
99
- "- `posts.fts`: tsvector column for full-text search (auto-generated from title and author).\n"
100
  )
101
 
102
  # ======================
@@ -104,38 +106,15 @@ class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
104
  # ======================
105
  prompt += (
106
  "\n## Business Logic\n"
 
107
  "- A query mentioning an organization (e.g., 'New York Times') should search both `posts.author` and `providers.name`.\n"
108
- "- Return all post types (spotlight, resource, insight) unless the user specifies otherwise.\n"
 
109
  "- Tags link posts to specific themes or disciplines.\n"
110
  "- A single post may have multiple tags, awards, or categories.\n"
111
  "- If the user mentions a year (e.g., 'in 2021'), filter with `EXTRACT(YEAR FROM published_date) = 2021`.\n"
112
  "- If the user says 'recently', filter posts from the last 90 days.\n"
113
- "- Default limit is 50 rows for search results. Use OFFSET for pagination if needed.\n"
114
- "\n"
115
- "## Search Strategy\n"
116
- "**TAG COVERAGE**: 3,362 posts (88%) have tags. Tag-based search is highly effective!\n"
117
- "- 9,105 tags available including countries (russia, china, usa), topics (climate change, politics), and formats (interactive, dataviz)\n"
118
- "- Use tag matching as PRIMARY search for topic-based queries\n"
119
- "\n"
120
- "**Hybrid Search Approach (RECOMMENDED)**:\n"
121
- "- Combine tag search AND keyword search with OR logic for maximum coverage\n"
122
- "- Use LEFT JOINs for tags to also include the 12% of untagged posts\n"
123
- "\n"
124
- "**Keyword Matching - Use ILIKE for Flexible Matching**:\n"
125
- "- Use LOWER(column) LIKE '%keyword%' for case-insensitive substring matching\n"
126
- "- Example: LOWER(p.title) LIKE '%russia%' matches 'Russia', 'Russian', 'Russia\\'s', etc.\n"
127
- "- This ensures word variants are captured (much better than exact word boundary matching)\n"
128
- "- For multi-word searches: LOWER(p.title) LIKE '%new york%'\n"
129
- "\n"
130
- "**Full-Text Search (for relevance ranking)**:\n"
131
- "- The posts table has an 'fts' column (tsvector) for full-text search\n"
132
- "- Use: p.fts @@ plainto_tsquery('english', 'search terms')\n"
133
- "- For relevance-ranked results: ORDER BY ts_rank(p.fts, plainto_tsquery('english', 'search terms')) DESC\n"
134
- "- FTS handles stemming automatically: 'visualization' matches 'visualizations'\n"
135
- "- Combine FTS with ILIKE fallback: WHERE p.fts @@ query OR LOWER(p.title) LIKE '%keyword%'\n"
136
- "\n"
137
- "**When to use tag-only search**: Only if user explicitly mentions 'tagged with' or 'tag:'.\n"
138
- "**When to use keyword-only search**: For author/organization names.\n"
139
  )
140
 
141
  # ======================
@@ -166,38 +145,21 @@ class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
166
  # ======================
167
  prompt += (
168
  "\n## Example Interactions\n"
169
- "User: 'F1' or 'Show me F1 content'\n"
170
- "Assistant: [call run_sql with \"SELECT DISTINCT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
171
  "FROM posts p "
172
- "LEFT JOIN post_tags pt ON p.id = pt.post_id "
173
- "LEFT JOIN tags t ON pt.tag_id = t.id "
174
- "LEFT JOIN providers pr ON p.provider_id = pr.id "
175
- "WHERE (LOWER(t.name) LIKE '%f1%' OR LOWER(t.name) LIKE '%formula%' "
176
- "OR LOWER(p.title) LIKE '%f1%' OR LOWER(p.title) LIKE '%formula%' "
177
- "OR LOWER(p.author) LIKE '%f1%') "
178
- "ORDER BY p.published_date DESC NULLS LAST LIMIT 50;\"]\n"
179
  "\nUser: 'Show me posts from The New York Times'\n"
180
- "Assistant: [call run_sql with \"SELECT DISTINCT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
181
  "FROM posts p "
182
- "LEFT JOIN providers pr ON p.provider_id = pr.id "
183
- "WHERE (LOWER(p.author) LIKE '%new york times%' OR LOWER(pr.name) LIKE '%new york times%') "
184
- "ORDER BY p.published_date DESC NULLS LAST LIMIT 50;\"]\n"
185
- "\nUser: 'Russia' or 'Show me Russia content'\n"
186
- "Assistant: [call run_sql with \"SELECT DISTINCT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
187
- "FROM posts p "
188
- "LEFT JOIN post_tags pt ON p.id = pt.post_id "
189
- "LEFT JOIN tags t ON pt.tag_id = t.id "
190
- "WHERE (LOWER(t.name) LIKE '%russia%' "
191
- "OR LOWER(p.title) LIKE '%russia%' OR LOWER(p.author) LIKE '%russia%') "
192
- "ORDER BY p.published_date DESC NULLS LAST LIMIT 50;\"]\n"
193
- "\nUser: 'interactive visualizations'\n"
194
- "Assistant: [call run_sql with \"SELECT DISTINCT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
195
- "FROM posts p "
196
- "LEFT JOIN post_tags pt ON p.id = pt.post_id "
197
- "LEFT JOIN tags t ON pt.tag_id = t.id "
198
- "WHERE (LOWER(t.name) LIKE '%interactive%' OR LOWER(p.title) LIKE '%interactive%' "
199
- "OR LOWER(p.title) LIKE '%visualization%' OR LOWER(t.name) LIKE '%dataviz%') "
200
- "ORDER BY p.published_date DESC NULLS LAST LIMIT 50;\"]\n"
201
  )
202
 
203
  # ======================
@@ -205,6 +167,8 @@ class CustomSQLSystemPromptBuilder(SystemPromptBuilder):
205
  # ======================
206
  prompt += (
207
  "\nIMPORTANT:\n"
 
 
208
  "- Always return **only the raw CSV result** — no explanations, no JSON, no commentary.\n"
209
  "- Stop tool execution once the query result is obtained.\n"
210
  )
@@ -233,8 +197,8 @@ class VannaComponent:
233
  db_tool = RunSqlTool(sql_runner=self.sql_runner)
234
 
235
  agent_memory = DemoAgentMemory(max_items=1000)
236
- save_memory_tool = SaveQuestionToolArgsTool()
237
- search_memory_tool = SearchSavedCorrectToolUsesTool()
238
 
239
  self.user_resolver = SimpleUserResolver()
240
 
@@ -247,46 +211,32 @@ class VannaComponent:
247
  llm_service=llm,
248
  tool_registry=tools,
249
  user_resolver=self.user_resolver,
250
- agent_memory=agent_memory,
251
  system_prompt_builder=CustomSQLSystemPromptBuilder("CoJournalist", self.sql_runner),
252
- config=AgentConfig(stream_responses=False, max_tool_iterations=3)
253
  )
254
 
255
  async def ask(self, prompt_for_llm: str):
256
  ctx = RequestContext()
257
- print(f"\n{'='*80}")
258
- print(f"🙋 User Query: {prompt_for_llm}")
259
- print(f"{'='*80}\n")
260
 
261
  final_text = ""
262
  seen_texts = set()
263
- query_executed = False
264
- result_row_count = 0
265
 
266
  async for component in self.agent.send_message(request_context=ctx, message=prompt_for_llm):
267
  simple = getattr(component, "simple_component", None)
268
  text = getattr(simple, "text", "") if simple else ""
269
  if text and text not in seen_texts:
270
- print(f"💬 LLM Response: {text[:300]}...")
271
  final_text += text + "\n"
272
  seen_texts.add(text)
273
 
274
  sql_query = getattr(component, "sql", None)
275
  if sql_query:
276
- query_executed = True
277
- print(f"\n🧾 SQL Query Generated:")
278
- print(f"{'-'*80}")
279
- print(f"{sql_query}")
280
- print(f"{'-'*80}\n")
281
 
282
  metadata = getattr(component, "metadata", None)
283
  if metadata:
284
- print(f"📋 Query Metadata: {metadata}")
285
- result_row_count = metadata.get("row_count", 0)
286
- if result_row_count == 0:
287
- print(f"⚠️ Query returned 0 rows - no data matched the criteria")
288
- else:
289
- print(f"✅ Query returned {result_row_count} rows")
290
 
291
  component_type = getattr(component, "type", None)
292
  if component_type:
@@ -295,36 +245,16 @@ class VannaComponent:
295
  match = re.search(r"query_results_[\w-]+\.csv", final_text)
296
  if match:
297
  filename = match.group(0)
298
- # Calculate the user-specific folder based on the default user ID
299
- import hashlib
300
- user_hash = hashlib.sha256("guest@example.com".encode()).hexdigest()[:16]
301
- folder = user_hash
302
  full_path = os.path.join(folder, filename)
303
 
304
- print(f"\n📁 Looking for CSV file: {full_path}")
305
-
306
- # Create folder if it doesn't exist
307
- if not os.path.exists(folder):
308
- print(f"📂 Creating user directory: {folder}")
309
- os.makedirs(folder, exist_ok=True)
310
-
311
  if os.path.exists(full_path):
312
- print(f" Found CSV file, reading contents...")
313
  with open(full_path, "r", encoding="utf-8") as f:
314
  csv_data = f.read().strip()
315
- print(f"📊 CSV Data Preview: {csv_data[:200]}...")
316
- print(f"{'='*80}\n")
317
  return csv_data
318
  else:
319
- print(f" CSV file not found at: {full_path}")
320
- # List files in the directory to help debug
321
- if os.path.exists(folder):
322
- files = os.listdir(folder)
323
- print(f"📂 Files in {folder}: {files}")
324
-
325
- print(f"\n{'='*80}")
326
- if not query_executed:
327
- print(f"⚠️ No SQL query was executed by the LLM")
328
- print(f"📤 Returning final response to user")
329
- print(f"{'='*80}\n")
330
  return final_text
 
55
  "- Never use SELECT *\n"
56
  "- Prefer window functions over subqueries when possible\n"
57
  "- Always include a LIMIT for exploratory queries\n"
58
+ "- Exclude posts where provider = 'SND'\n"
59
+ "- Exclude posts where type = 'resource'\n"
60
+ "- Exclude posts where type = 'insight'\n"
61
  "- Format dates and numbers for readability\n"
62
  )
63
 
 
71
  prompt += (
72
  "\n## Database Schema\n"
73
  "Tables:\n"
74
+ "- posts (id, title, source_url, author, published_date, image_url, type, provider_id, created_at, updated_at)\n"
75
  "- providers (id, name)\n"
76
  "- provider_attributes (id, provider_id, type, name)\n"
77
  "- post_provider_attributes (post_id, attribute_id)\n"
 
99
  "- `providers.name`: name of the publishing organization (e.g., 'Nuanced', 'SND').\n"
100
  "- `tags.name`: thematic keyword or topic (e.g., '3D', 'AI', 'Design').\n"
101
  "- `post_tags.weight`: relevance score between a post and a tag.\n"
 
102
  )
103
 
104
  # ======================
 
106
  # ======================
107
  prompt += (
108
  "\n## Business Logic\n"
109
+ "- Providers named 'SND' must always be excluded.\n"
110
  "- A query mentioning an organization (e.g., 'New York Times') should search both `posts.author` and `providers.name`.\n"
111
+ "- By default, only posts with `type = 'spotlight'` are returned.\n"
112
+ "- Posts of type `resource` or `insight` are excluded unless explicitly requested.\n"
113
  "- Tags link posts to specific themes or disciplines.\n"
114
  "- A single post may have multiple tags, awards, or categories.\n"
115
  "- If the user mentions a year (e.g., 'in 2021'), filter with `EXTRACT(YEAR FROM published_date) = 2021`.\n"
116
  "- If the user says 'recently', filter posts from the last 90 days.\n"
117
+ "- Always limit exploratory results to 9 rows.\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  )
119
 
120
  # ======================
 
145
  # ======================
146
  prompt += (
147
  "\n## Example Interactions\n"
148
+ "User: 'Show me posts related to 3D'\n"
149
+ "Assistant: [call run_sql with \"SELECT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
150
  "FROM posts p "
151
+ "JOIN post_tags pt ON p.id = pt.post_id "
152
+ "JOIN tags t ON pt.tag_id = t.id "
153
+ "JOIN providers pr ON p.provider_id = pr.id "
154
+ "WHERE t.name ILIKE '%3D%' AND pr.name != 'SND' AND p.type = 'spotlight' "
155
+ "LIMIT 9;\"]\n"
 
 
156
  "\nUser: 'Show me posts from The New York Times'\n"
157
+ "Assistant: [call run_sql with \"SELECT p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type "
158
  "FROM posts p "
159
+ "LEFT JOIN providers pr ON pr.id = p.provider_id "
160
+ "WHERE LOWER(p.author) LIKE '%new york times%' OR LOWER(pr.name) LIKE '%new york times%' "
161
+ "AND pr.name != 'SND' AND p.type = 'spotlight' "
162
+ "LIMIT 9;\"]\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  )
164
 
165
  # ======================
 
167
  # ======================
168
  prompt += (
169
  "\nIMPORTANT:\n"
170
+ "- Always exclude posts with provider = 'SND'.\n"
171
+ "- Always exclude posts with type = 'resource' or 'insight'.\n"
172
  "- Always return **only the raw CSV result** — no explanations, no JSON, no commentary.\n"
173
  "- Stop tool execution once the query result is obtained.\n"
174
  )
 
197
  db_tool = RunSqlTool(sql_runner=self.sql_runner)
198
 
199
  agent_memory = DemoAgentMemory(max_items=1000)
200
+ save_memory_tool = SaveQuestionToolArgsTool(agent_memory)
201
+ search_memory_tool = SearchSavedCorrectToolUsesTool(agent_memory)
202
 
203
  self.user_resolver = SimpleUserResolver()
204
 
 
211
  llm_service=llm,
212
  tool_registry=tools,
213
  user_resolver=self.user_resolver,
 
214
  system_prompt_builder=CustomSQLSystemPromptBuilder("CoJournalist", self.sql_runner),
215
+ config=AgentConfig(stream_responses=False, max_tool_iterations=1)
216
  )
217
 
218
  async def ask(self, prompt_for_llm: str):
219
  ctx = RequestContext()
220
+ print(f"🙋 Prompt sent to LLM: {prompt_for_llm}")
 
 
221
 
222
  final_text = ""
223
  seen_texts = set()
 
 
224
 
225
  async for component in self.agent.send_message(request_context=ctx, message=prompt_for_llm):
226
  simple = getattr(component, "simple_component", None)
227
  text = getattr(simple, "text", "") if simple else ""
228
  if text and text not in seen_texts:
229
+ print(f"💬 LLM says (part): {text[:200]}...")
230
  final_text += text + "\n"
231
  seen_texts.add(text)
232
 
233
  sql_query = getattr(component, "sql", None)
234
  if sql_query:
235
+ print(f"🧾 SQL Query Generated: {sql_query}")
 
 
 
 
236
 
237
  metadata = getattr(component, "metadata", None)
238
  if metadata:
239
+ print(f"📋 Metadata: {metadata}")
 
 
 
 
 
240
 
241
  component_type = getattr(component, "type", None)
242
  if component_type:
 
245
  match = re.search(r"query_results_[\w-]+\.csv", final_text)
246
  if match:
247
  filename = match.group(0)
248
+ folder = "513935c4d2db2d2d"
 
 
 
249
  full_path = os.path.join(folder, filename)
250
 
 
 
 
 
 
 
 
251
  if os.path.exists(full_path):
252
+ print(f"📂 Reading result file: {full_path}")
253
  with open(full_path, "r", encoding="utf-8") as f:
254
  csv_data = f.read().strip()
255
+ print("🤖 Response sent to user (from file):", csv_data[:300])
 
256
  return csv_data
257
  else:
258
+ print(f"⚠️ File not found: {full_path}")
259
+
 
 
 
 
 
 
 
 
 
260
  return final_text
src/vanna_query_functions.py DELETED
@@ -1,300 +0,0 @@
1
- """
2
- Vanna Query Function Templates
3
-
4
- Defines SQL templates for different search strategies.
5
- These are used by Vanna to generate accurate, performant SQL queries.
6
- """
7
-
8
- from typing import Dict, List
9
-
10
-
11
- class QueryFunctions:
12
- """
13
- Collection of SQL query templates for different search strategies.
14
- """
15
-
16
- @staticmethod
17
- def keyword_search(keywords: List[str], limit: int = 9) -> str:
18
- """
19
- Full-text keyword search across title, author, and provider.
20
-
21
- Works for all posts in the database (7,248 posts).
22
-
23
- Args:
24
- keywords: List of keywords to search for
25
- limit: Maximum number of results
26
-
27
- Returns:
28
- SQL query string
29
- """
30
- # Build regex conditions for each keyword with word boundaries
31
- # Use PostgreSQL ~* operator for case-insensitive regex matching
32
- # \m and \M are word boundary markers (start/end of word)
33
- keyword_conditions = []
34
- for keyword in keywords:
35
- keyword_lower = keyword.lower()
36
- # Escape special regex characters
37
- keyword_escaped = keyword_lower.replace('\\', '\\\\').replace('.', '\\.').replace('+', '\\+')
38
- keyword_conditions.append(f"""
39
- (p.title ~* '\\m{keyword_escaped}\\M'
40
- OR p.author ~* '\\m{keyword_escaped}\\M'
41
- OR pr.name ~* '\\m{keyword_escaped}\\M')
42
- """)
43
-
44
- where_clause = " OR ".join(keyword_conditions)
45
-
46
- return f"""
47
- SELECT DISTINCT
48
- p.id,
49
- p.title,
50
- p.source_url,
51
- p.author,
52
- p.published_date,
53
- p.image_url,
54
- p.type,
55
- pr.name as provider_name
56
- FROM posts p
57
- LEFT JOIN providers pr ON p.provider_id = pr.id
58
- WHERE {where_clause}
59
- ORDER BY p.published_date DESC NULLS LAST
60
- LIMIT {limit};
61
- """
62
-
63
- @staticmethod
64
- def tag_search(tags: List[str], limit: int = 9) -> str:
65
- """
66
- Tag-based search.
67
-
68
- Currently works for only 3 posts with tags.
69
- As more posts are tagged, this will return more results.
70
-
71
- Args:
72
- tags: List of tag names to search for
73
- limit: Maximum number of results
74
-
75
- Returns:
76
- SQL query string
77
- """
78
- # Format tag array for SQL
79
- tags_lower = [f"'{tag.lower()}'" for tag in tags]
80
- tags_array = f"ARRAY[{', '.join(tags_lower)}]"
81
-
82
- return f"""
83
- SELECT DISTINCT
84
- p.id,
85
- p.title,
86
- p.source_url,
87
- p.author,
88
- p.published_date,
89
- p.image_url,
90
- p.type,
91
- pr.name as provider_name,
92
- string_agg(DISTINCT t.name, ', ') as tags
93
- FROM posts p
94
- JOIN post_tags pt ON p.id = pt.post_id
95
- JOIN tags t ON pt.tag_id = t.id
96
- LEFT JOIN providers pr ON p.provider_id = pr.id
97
- WHERE LOWER(t.name) = ANY({tags_array})
98
- GROUP BY p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type, pr.name
99
- ORDER BY p.published_date DESC NULLS LAST
100
- LIMIT {limit};
101
- """
102
-
103
- @staticmethod
104
- def hybrid_search(keywords: List[str], tags: List[str], limit: int = 9) -> str:
105
- """
106
- Hybrid search combining tags AND keywords.
107
-
108
- Best of both worlds:
109
- - Finds tagged posts (currently 3)
110
- - Falls back to keyword search for untagged posts (7,245)
111
-
112
- Args:
113
- keywords: List of keywords to search for
114
- tags: List of tag names to search for
115
- limit: Maximum number of results
116
-
117
- Returns:
118
- SQL query string
119
- """
120
- # Build tag conditions
121
- tags_lower = [f"'{tag.lower()}'" for tag in tags]
122
- tags_array = f"ARRAY[{', '.join(tags_lower)}]"
123
-
124
- # Build regex keyword conditions with word boundaries
125
- keyword_conditions = []
126
- for keyword in keywords:
127
- keyword_lower = keyword.lower()
128
- # Escape special regex characters
129
- keyword_escaped = keyword_lower.replace('\\', '\\\\').replace('.', '\\.').replace('+', '\\+')
130
- keyword_conditions.append(f"""
131
- (p.title ~* '\\m{keyword_escaped}\\M'
132
- OR p.author ~* '\\m{keyword_escaped}\\M'
133
- OR pr.name ~* '\\m{keyword_escaped}\\M')
134
- """)
135
-
136
- keyword_where = " OR ".join(keyword_conditions)
137
-
138
- return f"""
139
- SELECT DISTINCT
140
- p.id,
141
- p.title,
142
- p.source_url,
143
- p.author,
144
- p.published_date,
145
- p.image_url,
146
- p.type,
147
- pr.name as provider_name,
148
- string_agg(DISTINCT t.name, ', ') as tags
149
- FROM posts p
150
- LEFT JOIN post_tags pt ON p.id = pt.post_id
151
- LEFT JOIN tags t ON pt.tag_id = t.id
152
- LEFT JOIN providers pr ON p.provider_id = pr.id
153
- WHERE
154
- LOWER(t.name) = ANY({tags_array})
155
- OR ({keyword_where})
156
- GROUP BY p.id, p.title, p.source_url, p.author, p.published_date, p.image_url, p.type, pr.name
157
- ORDER BY p.published_date DESC NULLS LAST
158
- LIMIT {limit};
159
- """
160
-
161
- @staticmethod
162
- def search_by_author(author: str, limit: int = 9) -> str:
163
- """
164
- Search posts by specific author or organization.
165
-
166
- Args:
167
- author: Author name to search for
168
- limit: Maximum number of results
169
-
170
- Returns:
171
- SQL query string
172
- """
173
- # Escape special regex characters
174
- author_escaped = author.lower().replace('\\', '\\\\').replace('.', '\\.').replace('+', '\\+')
175
-
176
- return f"""
177
- SELECT DISTINCT
178
- p.id,
179
- p.title,
180
- p.source_url,
181
- p.author,
182
- p.published_date,
183
- p.image_url,
184
- p.type,
185
- pr.name as provider_name
186
- FROM posts p
187
- LEFT JOIN providers pr ON p.provider_id = pr.id
188
- WHERE
189
- p.author ~* '\\m{author_escaped}\\M'
190
- OR pr.name ~* '\\m{author_escaped}\\M'
191
- ORDER BY p.published_date DESC NULLS LAST
192
- LIMIT {limit};
193
- """
194
-
195
- @staticmethod
196
- def search_recent(days: int = 90, limit: int = 9) -> str:
197
- """
198
- Search for recent posts within the last N days.
199
-
200
- Args:
201
- days: Number of days to look back
202
- limit: Maximum number of results
203
-
204
- Returns:
205
- SQL query string
206
- """
207
- return f"""
208
- SELECT DISTINCT
209
- p.id,
210
- p.title,
211
- p.source_url,
212
- p.author,
213
- p.published_date,
214
- p.image_url,
215
- p.type,
216
- pr.name as provider_name
217
- FROM posts p
218
- LEFT JOIN providers pr ON p.provider_id = pr.id
219
- WHERE
220
- p.published_date >= CURRENT_DATE - INTERVAL '{days} days'
221
- ORDER BY p.published_date DESC
222
- LIMIT {limit};
223
- """
224
-
225
- @staticmethod
226
- def search_by_type(post_type: str, limit: int = 9) -> str:
227
- """
228
- Search by post type (spotlight, insight, resource).
229
-
230
- Args:
231
- post_type: Type of post (spotlight, insight, resource)
232
- limit: Maximum number of results
233
-
234
- Returns:
235
- SQL query string
236
- """
237
- return f"""
238
- SELECT DISTINCT
239
- p.id,
240
- p.title,
241
- p.source_url,
242
- p.author,
243
- p.published_date,
244
- p.image_url,
245
- p.type,
246
- pr.name as provider_name
247
- FROM posts p
248
- LEFT JOIN providers pr ON p.provider_id = pr.id
249
- WHERE p.type = '{post_type}'
250
- ORDER BY p.published_date DESC NULLS LAST
251
- LIMIT {limit};
252
- """
253
-
254
-
255
- def generate_query(search_type: str, **kwargs) -> str:
256
- """
257
- Generate SQL query based on search type.
258
-
259
- Args:
260
- search_type: Type of search (keyword, tag, hybrid, author, recent, type)
261
- **kwargs: Parameters for the specific search type
262
-
263
- Returns:
264
- SQL query string
265
- """
266
- functions = {
267
- "keyword": QueryFunctions.keyword_search,
268
- "tag": QueryFunctions.tag_search,
269
- "hybrid": QueryFunctions.hybrid_search,
270
- "author": QueryFunctions.search_by_author,
271
- "recent": QueryFunctions.search_recent,
272
- "type": QueryFunctions.search_by_type,
273
- }
274
-
275
- if search_type not in functions:
276
- raise ValueError(f"Unknown search type: {search_type}")
277
-
278
- return functions[search_type](**kwargs)
279
-
280
-
281
- # Example usage
282
- if __name__ == "__main__":
283
- # Test keyword search
284
- print("=== KEYWORD SEARCH ===")
285
- print(QueryFunctions.keyword_search(["F1", "racing"]))
286
-
287
- print("\n=== TAG SEARCH ===")
288
- print(QueryFunctions.tag_search(["dataviz", "interactive"]))
289
-
290
- print("\n=== HYBRID SEARCH ===")
291
- print(QueryFunctions.hybrid_search(
292
- keywords=["visualization"],
293
- tags=["dataviz", "interactive"]
294
- ))
295
-
296
- print("\n=== AUTHOR SEARCH ===")
297
- print(QueryFunctions.search_by_author("New York Times"))
298
-
299
- print("\n=== RECENT POSTS ===")
300
- print(QueryFunctions.search_recent(days=30))