Spaces:
Running
Running
| """ | |
| Viz LLM - Gradio App | |
| A RAG-powered assistant for data visualization guidance, powered by Jina-CLIP-v2 | |
| embeddings and research from the field of information graphics. | |
| Now with Datawrapper integration for chart generation! | |
| """ | |
| import os | |
| import io | |
| import asyncio | |
| import pandas as pd | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| from src.rag_pipeline import create_pipeline | |
| from src.datawrapper_client import create_and_publish_chart, get_iframe_html | |
| from datetime import datetime, timedelta | |
| from collections import defaultdict | |
| from src.vanna import VannaComponent | |
# Load environment variables (via python-dotenv) before any os.getenv() lookups below.
load_dotenv()

# Rate limiting: accepted requests per user, keyed by client host (IP).
# Kept in process memory only, so counts reset whenever the app restarts.
# Format: {ip: [datetime, datetime, ...]} -- one timestamp per accepted request.
rate_limit_tracker: defaultdict[str, list[datetime]] = defaultdict(list)
# Maximum accepted requests per user within any rolling 24-hour window.
DAILY_LIMIT: int = 20
# Build the RAG pipeline once at import time. Model and temperature come from
# LLM_MODEL / LLM_TEMPERATURE; any failure is reported and aborts startup.
print("Initializing Graphics Design Pipeline...")
try:
    pipeline = create_pipeline(
        model=os.getenv("LLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct"),
        temperature=float(os.getenv("LLM_TEMPERATURE", "0.2")),
        retrieval_k=5,
    )
except Exception as exc:
    print(f"โ Error initializing pipeline: {exc}")
    raise
else:
    print("โ Pipeline initialized successfully")
# Initialize the Vanna component queried by inspiration mode.
# Like the RAG pipeline's LLM_MODEL, the model and inference provider can be
# overridden through the environment (VANNA_MODEL / VANNA_PROVIDER); the
# defaults are the previously hard-coded values. Failure aborts startup.
print("Initializing Vanna...")
try:
    vanna = VannaComponent(
        hf_model=os.getenv("VANNA_MODEL", "Qwen/Qwen3-VL-30B-A3B-Instruct"),
        hf_token=os.getenv("HF_TOKEN_VANNA"),
        hf_provider=os.getenv("VANNA_PROVIDER", "novita"),
        connection_string=os.getenv("SUPABASE_CONNECTION")
    )
    print("โ Vanna initialized successfully")
except Exception as e:
    print(f"โ Error initializing Vanna: {e}")
    raise
def check_rate_limit(request: gr.Request) -> tuple[bool, int]:
    """Record a request against the caller's rolling 24-hour quota.

    The caller is identified by ``request.client.host``.
    NOTE(review): behind a reverse proxy this may be the proxy's address
    rather than the end user's -- confirm for the deployment.

    Args:
        request: Gradio request for the current call. May be None, and its
            ``client`` may be None when no peer address is available.

    Returns:
        ``(allowed, remaining)``. ``allowed`` is False once the caller already
        has ``DAILY_LIMIT`` requests in the last 24 hours; in that case
        ``remaining`` is 0 and nothing is recorded. Otherwise the request is
        recorded and ``remaining`` is the number still available after it.
        When there is no request or no client address, the call is allowed
        and ``DAILY_LIMIT`` is returned.
    """
    client = getattr(request, "client", None)
    if client is None:
        # Nothing to key the quota on (no request object, or no peer address);
        # previously a missing client crashed with AttributeError.
        return True, DAILY_LIMIT

    user_id = client.host
    now = datetime.now()
    cutoff = now - timedelta(days=1)

    # Keep only timestamps inside the rolling 24-hour window.
    recent = [ts for ts in rate_limit_tracker[user_id] if ts > cutoff]
    rate_limit_tracker[user_id] = recent

    remaining = DAILY_LIMIT - len(recent)
    if remaining <= 0:
        return False, 0

    # Record the current (accepted) request.
    recent.append(now)
    return True, remaining - 1
def recommend_stream(message: str, history: list, request: gr.Request):
    """Stream a design recommendation for ``message`` while enforcing the quota.

    Args:
        message: User's design query.
        history: Chat history supplied by the chat interface. It is not
            forwarded to the pipeline; only ``message`` is used.
        request: Gradio request object used for rate limiting.

    Yields:
        The accumulated response text after each streamed chunk. When the
        caller has 5 or fewer queries left, one final yield appends a
        remaining-quota notice. If the quota is exhausted, or the pipeline
        raises, an explanatory message is yielded instead.
    """
    allowed, remaining = check_rate_limit(request)
    if not allowed:
        # Use DAILY_LIMIT so the message stays correct if the limit changes.
        yield (
            "โ ๏ธ **Rate limit exceeded.** You've reached the maximum of "
            f"{DAILY_LIMIT} queries per day. Please try again in 24 hours."
        )
        return

    try:
        full_response = ""
        for chunk in pipeline.generate_recommendations(message, stream=True):
            full_response += chunk
            # Yield the cumulative text so the chat shows the growing answer.
            yield full_response

        # Warn users who are close to their daily limit.
        if remaining <= 5:
            yield full_response + f"\n\n---\n*You have {remaining} queries remaining today.*"
    except Exception as e:
        yield f"Error generating response: {str(e)}\n\nPlease check your environment variables (HF_TOKEN, SUPABASE_URL, SUPABASE_KEY) and try again."
def _chart_error_html(heading, detail, hint):
    """Render the red error panel shown in the chart output; ``detail`` must be pre-escaped."""
    return f"""
    <div style='padding: 50px; text-align: center; color: red;'>
        <h3>โ {heading}</h3>
        <p>{detail}</p>
        <p style='font-size: 0.9em; color: #666;'>{hint}</p>
    </div>
    """


def generate_chart_from_csv(csv_file, user_prompt):
    """Generate a Datawrapper chart from an uploaded CSV and a user prompt.

    Args:
        csv_file: Path to the uploaded CSV file (the upload component is
            created with ``type="filepath"``), or a falsy value if none.
        user_prompt: User's description of the chart.

    Returns:
        HTML with the embedded chart, the "Why this chart?" reasoning and an
        "Open in Datawrapper" link; or an HTML message explaining what is
        missing or what went wrong.
    """
    from html import escape

    if not csv_file:
        return "<div style='padding: 50px; text-align: center;'>Please upload a CSV file to generate a chart.</div>"
    if not user_prompt or user_prompt.strip() == "":
        return "<div style='padding: 50px; text-align: center;'>Please describe what chart you want to create.</div>"

    try:
        df = pd.read_csv(csv_file)

        # create_and_publish_chart is a coroutine. asyncio.run executes it on
        # a fresh event loop and always closes that loop, even when the
        # coroutine raises (the previous manual loop leaked in that case).
        result = asyncio.run(create_and_publish_chart(df, user_prompt, pipeline))

        if not result.get("success"):
            # The error text comes from an external service: escape it.
            return _chart_error_html(
                "Chart Generation Failed",
                escape(str(result.get("error", "Unknown error"))),
                "Please check your CSV format and try again.",
            )

        iframe_html = get_iframe_html(result.get("public_url"), height=500)
        # Reasoning is model output and the URL is interpolated into an
        # attribute, so both are escaped before being embedded in HTML.
        reasoning = escape(str(result.get("reasoning", "")))
        edit_url = escape(str(result.get("edit_url", "#")))

        return f"""
        <div style='padding: 20px;'>
            <!-- Chart iframe -->
            <div style='margin-bottom: 20px;'>
                {iframe_html}
            </div>
            <!-- Why this chart? -->
            <div style='background: #f9f9f9; padding: 15px; border-radius: 5px; margin-bottom: 15px;'>
                <strong>Why this chart?</strong><br>
                <p style='margin: 10px 0 0 0;'>{reasoning}</p>
            </div>
            <!-- Edit button -->
            <div>
                <a href="{edit_url}" target="_blank"
                   style="display: inline-block; padding: 12px 24px; background: #1976d2; color: white;
                          text-decoration: none; border-radius: 5px; font-weight: bold;">
                    โ๏ธ Open in Datawrapper
                </a>
            </div>
        </div>
        """
    except Exception as e:
        return _chart_error_html(
            "Error",
            escape(str(e)),
            "Please ensure your CSV is properly formatted and try again.",
        )
def csv_to_cards_html(csv_text: str) -> str:
    """Render the raw CSV text returned by Vanna as a wrapping grid of HTML cards.

    One card is produced per row from the ``title``, ``source_url``,
    ``author``, ``published_date`` and ``image_url`` columns. A missing column
    or an empty/"nan" cell (both NaN after pandas parsing) falls back to a
    default: "Sans titre", "#", "Inconnu", an empty date, or a placeholder
    image. Every value is HTML-escaped because the CSV is model output.

    Args:
        csv_text: CSV text including a header row.

    Returns:
        HTML for the card grid, a "no data" message when there are no rows,
        or an error message when the text cannot be parsed.
    """
    from html import escape

    placeholder_image = "https://fpoimg.com/800x600?text=Image+not+found"

    def _field(row, column, default):
        # Absent column -> default; NaN cell -> default. (The previous
        # `if not value == "nan"` check was inverted and compared a float
        # NaN with a string, so dates were always blanked and images were
        # always replaced by the placeholder.)
        value = row.get(column, default)
        return default if pd.isna(value) else str(value)

    try:
        df = pd.read_csv(io.StringIO(csv_text.strip()))
        if df.empty:
            return "<div style='padding: 50px; text-align: center;'>Aucune donnรฉe trouvรฉe.</div>"

        cards = []
        for _, row in df.iterrows():
            title = escape(_field(row, "title", "Sans titre"))
            source_url = escape(_field(row, "source_url", "#"))
            author = escape(_field(row, "author", "Inconnu"))
            published_date = escape(_field(row, "published_date", ""))
            image_url = escape(_field(row, "image_url", placeholder_image))
            cards.append(f"""
            <div style="background: white; border-radius: 10px; box-shadow: 0 2px 8px rgba(0,0,0,0.1);
                        overflow: hidden; margin: 10px; width: 320px; flex: 0 0 auto;">
                <img src="{image_url}" alt="{title}" style="width:100%; height:180px; object-fit:cover;">
                <div style="padding: 12px 16px;">
                    <h3 style="margin:0; font-size:1.1em; color:#222;">{title}</h3>
                    <p style="margin:6px 0; color:#555; font-size:0.9em;">{author}</p>
                    <p style="margin:0; color:#999; font-size:0.8em;">{published_date}</p>
                    <a href="{source_url}" target="_blank"
                       style="display:inline-block; margin-top:8px; font-size:0.9em; color:#1976d2; text-decoration:none;">
                        ๐ Voir la source
                    </a>
                </div>
            </div>
            """)

        cards_html = "".join(cards)
        return f"""
        <div style="display:flex; flex-wrap:wrap; justify-content:center; padding:20px;">
            {cards_html}
        </div>
        """
    except Exception as e:
        return f"<div style='padding: 50px; text-align: center; color:red;'>Erreur lors du parsing du CSV : {escape(str(e))}</div>"
async def search_inspiration_from_database(user_prompt):
    """Search inspiration posts matching ``user_prompt`` and render them as cards.

    The prompt is sent to ``vanna.ask``. Its reply is expected to contain CSV,
    optionally wrapped in a Markdown code fence such as ```csv ... ```; the
    CSV is rendered with ``csv_to_cards_html``.

    Args:
        user_prompt: User's description of the inspiration query.

    Returns:
        HTML string displaying cards, or an HTML message explaining why no
        cards could be shown.
    """
    from html import escape

    if not user_prompt or user_prompt.strip() == "":
        return """
        <div style='padding: 50px; text-align: center;'>
            Please describe what kind of inspiration you want to search for.
        </div>
        """

    try:
        response = await vanna.ask(user_prompt)
        print("response :", repr(response))
        # Treat a missing reply as empty text rather than failing on .strip().
        clean_response = "" if response is None else str(response).strip()

        # Replies starting with a warning sign, or containing the "no CSV
        # detected" notice, are treated as "no data".
        if clean_response.startswith("โ ๏ธ") or "Aucun CSV dรฉtectรฉ" in clean_response:
            return """
            <div style='padding: 50px; text-align: center; color: #d9534f;'>
                <h3>โ No valid data found</h3>
                <p>The AI couldn't generate any data for this request. Try being more specific โ for example:
                <em>"Show me spotlights from 2020 about design"</em>.</p>
            </div>
            """

        # Remove surrounding code-fence backticks, then drop a first line that
        # is only the fence's language tag ("csv"/"CSV"). The previous
        # .replace("csv", "") removed those letters everywhere, corrupting
        # values such as URLs ending in ".csv".
        csv_text = clean_response.strip("`").strip()
        first_line, _, rest = csv_text.partition("\n")
        if first_line.strip().lower() == "csv":
            csv_text = rest.strip()

        if "," not in csv_text:
            return """
            <div style='padding: 50px; text-align: center; color: #d9534f;'>
                <h3>โ No valid CSV detected</h3>
                <p>The model didn't return any structured data. Try rephrasing your query to be more precise.</p>
            </div>
            """

        return csv_to_cards_html(csv_text)
    except Exception as e:
        return f"""
        <div style='padding: 50px; text-align: center; color: red;'>
            <h3>โ Error</h3>
            <p>{escape(str(e))}</p>
            <p style='font-size: 0.9em; color: #666;'>Please try again.</p>
        </div>
        """
# Minimal CSS passed to gr.Blocks(css=...): hides the chatbot's Retry/Undo
# buttons, removes the vertical scrollbar from text areas, and sizes the
# mode-selector buttons (created with elem_classes="mode-button").
# NOTE(review): overflow-y: hidden also hides the scrollbar when a textarea's
# content is taller than its visible area -- confirm that trade-off is intended.
custom_css = """
/* Hide retry/undo buttons that appear as artifacts */
.chatbot button[aria-label="Retry"],
.chatbot button[aria-label="Undo"] {
display: none !important;
}
/* Remove overflow-y scroll from textarea */
textarea[data-testid="textbox"] {
overflow-y: hidden !important;
}
/* Mode selector buttons */
.mode-button {
font-size: 1.1em;
padding: 12px 24px;
margin: 5px;
}
"""
# Create the Gradio interface. Three modes (ideation, chart generation,
# inspiration) share one page; the mode buttons toggle which container is shown.
with gr.Blocks(
    title="Viz LLM",
    css=custom_css
) as demo:
    gr.Markdown("""
# ๐ Viz LLM

Get design recommendations, generate charts, or find inspiration with AI-powered data visualization assistance.
""")

    # Mode selector buttons (ideation is selected by default).
    with gr.Row():
        ideation_btn = gr.Button("๐ก Ideation Mode", variant="primary", elem_classes="mode-button")
        chart_gen_btn = gr.Button("๐ Chart Generation Mode", variant="secondary", elem_classes="mode-button")
        inspiration_btn = gr.Button("โจ Inspiration Mode", variant="secondary", elem_classes="mode-button")

    # Ideation Mode: streaming chat interface (visible by default).
    with gr.Column(visible=True) as ideation_container:
        ideation_interface = gr.ChatInterface(
            fn=recommend_stream,
            type="messages",
            examples=[
                "What's the best chart type for showing trends over time?",
                "How do I create an effective infographic for complex data?",
                "What are best practices for data visualization accessibility?",
                "How should I design a dashboard for storytelling?",
                "What visualization works best for comparing categories?"
            ],
            cache_examples=False,
            api_name="recommend"
        )

    # Chart Generation Mode: CSV upload, prompt, and rendered chart (hidden by default).
    with gr.Column(visible=False) as chart_gen_container:
        csv_upload = gr.File(
            label="๐ Upload CSV File",
            file_types=[".csv"],
            type="filepath"
        )
        chart_prompt_input = gr.Textbox(
            label="Describe your chart",
            placeholder="E.g., 'Show sales trends over time' or 'Compare revenue by category'",
            lines=2
        )
        generate_chart_btn = gr.Button("Generate Chart", variant="primary", size="lg")
        chart_output = gr.HTML(
            value="<div style='text-align:center; padding:100px; color: #666;'>Upload a CSV file and describe your visualization above, then click Generate Chart.</div>",
            label="Generated Chart"
        )

    # Inspiration Mode: free-text search whose results render as cards (hidden by default).
    with gr.Column(visible=False) as inspiration_container:
        with gr.Row():
            inspiration_prompt_input = gr.Textbox(
                placeholder="Ask for an inspiration...",
                show_label=False,
                scale=4,
                container=False
            )
            inspiration_search_btn = gr.Button("๐ Search", variant="primary", scale=1)
        inspiration_cards_html = gr.HTML("")

    # Mode switching. Index i of mode_buttons belongs with index i of mode_containers.
    mode_buttons = [ideation_btn, chart_gen_btn, inspiration_btn]
    mode_containers = [ideation_container, chart_gen_container, inspiration_container]

    def _select_mode(active_index):
        """Return updates highlighting button ``active_index`` and showing only its container."""
        button_updates = [
            gr.update(variant="primary" if i == active_index else "secondary")
            for i in range(len(mode_buttons))
        ]
        container_updates = [
            gr.update(visible=(i == active_index))
            for i in range(len(mode_containers))
        ]
        return button_updates + container_updates

    def switch_to_ideation():
        return _select_mode(0)

    def switch_to_chart_gen():
        return _select_mode(1)

    def switch_to_inspiration():
        return _select_mode(2)

    # Wire every mode button to its switcher; all update the same outputs.
    for mode_button, switch_fn in zip(
        mode_buttons, (switch_to_ideation, switch_to_chart_gen, switch_to_inspiration)
    ):
        mode_button.click(
            fn=switch_fn,
            inputs=[],
            outputs=mode_buttons + mode_containers
        )

    # Generate chart when button is clicked
    generate_chart_btn.click(
        fn=generate_chart_from_csv,
        inputs=[csv_upload, chart_prompt_input],
        outputs=[chart_output]
    )

    # Search inspiration when button is clicked
    inspiration_search_btn.click(
        fn=search_inspiration_from_database,
        inputs=[inspiration_prompt_input],
        outputs=[inspiration_cards_html]
    )

    # About section, shown below all three modes. Blank lines matter here:
    # without one, "---" directly under a text line becomes a setext heading.
    gr.Markdown(f"""
### About Viz LLM

**Ideation Mode:** Get design recommendations based on research papers, design principles, and examples from the field of information graphics and data visualization.

**Chart Generation Mode:** Upload your CSV data and describe your visualization goal. The AI will analyze your data, select the optimal chart type, and generate a publication-ready chart using Datawrapper.

**Inspiration Mode:** Describe the kind of visualization you are looking for and browse matching examples from our database, each shown with its title, author, date, and a link to the source.

**Credits:** Special thanks to the researchers whose work informed this model: Robert Kosara, Edward Segel, Jeffrey Heer, Matthew Conlen, John Maeda, Kennedy Elliott, Scott McCloud, and many others.

---

**Usage Limits:** This service is limited to {DAILY_LIMIT} queries per day per user to manage costs. Responses are optimized for English.

<div style="text-align: center; margin-top: 20px; opacity: 0.6; font-size: 0.9em;">
Embeddings: Jina-CLIP-v2 | Charts: Datawrapper API
</div>
""")
# Launch configuration
if __name__ == "__main__":
    # Warn about missing environment variables. This runs after the
    # module-level pipeline/Vanna initialisation, so it can only report
    # variables whose absence did not already abort startup.
    required_vars = [
        "SUPABASE_URL",
        "SUPABASE_KEY",
        "HF_TOKEN",
        "DATAWRAPPER_ACCESS_TOKEN",
        # Read by the Vanna initialisation (used by inspiration mode).
        "HF_TOKEN_VANNA",
        "SUPABASE_CONNECTION",
    ]
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    if missing_vars:
        print(f"โ ๏ธ Warning: Missing environment variables: {', '.join(missing_vars)}")
        print("Please set these in your .env file or as environment variables")
        if "DATAWRAPPER_ACCESS_TOKEN" in missing_vars:
            print("Note: DATAWRAPPER_ACCESS_TOKEN is required for chart generation mode")
        if {"HF_TOKEN_VANNA", "SUPABASE_CONNECTION"} & set(missing_vars):
            print("Note: HF_TOKEN_VANNA and SUPABASE_CONNECTION are required for inspiration mode")

    # Launch the app
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_api=True
    )