""" Datawrapper Chart Generation Client Integrates RAG pipeline with Datawrapper API for intelligent chart creation. """ import json import os from typing import Optional, Tuple import pandas as pd from .prompts import ( CHART_SELECTION_SYSTEM_PROMPT, get_chart_selection_prompt, get_chart_styling_prompt ) from .llm_client import create_llm_client from .rag_pipeline import GraphicsDesignPipeline # Import Datawrapper MCP handlers directly from datawrapper_mcp.handlers.create import create_chart as mcp_create_chart from datawrapper_mcp.handlers.publish import publish_chart as mcp_publish_chart from datawrapper_mcp.handlers.retrieve import get_chart_info as mcp_get_chart_info def get_data_summary(df: pd.DataFrame) -> str: """ Generate a summary of the DataFrame structure and content. Args: df: Input DataFrame Returns: String summary of data characteristics """ summary_parts = [] # Basic info summary_parts.append(f"Rows: {len(df)}, Columns: {len(df.columns)}") summary_parts.append(f"Column names: {', '.join(df.columns.tolist())}") # Column types numeric_cols = df.select_dtypes(include=['number']).columns.tolist() text_cols = df.select_dtypes(include=['object']).columns.tolist() date_cols = df.select_dtypes(include=['datetime']).columns.tolist() if numeric_cols: summary_parts.append(f"Numeric columns: {', '.join(numeric_cols)}") if text_cols: summary_parts.append(f"Text columns: {', '.join(text_cols)}") if date_cols: summary_parts.append(f"Date columns: {', '.join(date_cols)}") # Data preview (first 3 rows) summary_parts.append(f"\nData preview:\n{df.head(3).to_string()}") return "\n".join(summary_parts) def analyze_csv_for_chart_type( df: pd.DataFrame, user_prompt: str, rag_pipeline: GraphicsDesignPipeline ) -> Tuple[str, str]: """ Use RAG and LLM to determine the best chart type for the data. Args: df: Input DataFrame user_prompt: User's description of what they want to visualize rag_pipeline: RAG pipeline for retrieving best practices Returns: Tuple of (chart_type, reasoning) """ # Get data summary data_summary = get_data_summary(df) # Query RAG for chart selection best practices rag_query = f"chart type selection for {user_prompt}" relevant_docs = rag_pipeline.retrieve_documents(rag_query, k=3) rag_context = rag_pipeline.vectorstore.format_documents_for_context(relevant_docs) # Generate chart type recommendation using LLM chart_prompt = get_chart_selection_prompt() full_prompt = chart_prompt.format( user_prompt=user_prompt, data_summary=data_summary, rag_context=rag_context ) llm_client = create_llm_client( model=os.getenv("LLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct"), temperature=0.3, # Lower temperature for more deterministic chart selection max_tokens=500 ) response = llm_client.generate( prompt=full_prompt, system_prompt=CHART_SELECTION_SYSTEM_PROMPT ) # Parse JSON response try: # Extract JSON from response (handle markdown code blocks) response_clean = response.strip() if "```json" in response_clean: response_clean = response_clean.split("```json")[1].split("```")[0].strip() elif "```" in response_clean: response_clean = response_clean.split("```")[1].split("```")[0].strip() result = json.loads(response_clean) chart_type = result.get("chart_type", "line") reasoning = result.get("reasoning", "") # Validate chart type valid_types = ["bar", "line", "area", "scatter", "column", "stacked_bar", "arrow", "multiple_column"] if chart_type not in valid_types: chart_type = "line" # Default fallback return chart_type, reasoning except Exception as e: print(f"Error parsing chart type response: {e}") print(f"Response was: {response}") # Default to line chart return "line", "Using default line chart due to parsing error" def generate_chart_config( chart_type: str, df: pd.DataFrame, user_prompt: str, rag_pipeline: GraphicsDesignPipeline ) -> dict: """ Generate Datawrapper chart configuration using RAG and LLM. Args: chart_type: Type of chart to create df: Input DataFrame user_prompt: User's visualization request rag_pipeline: RAG pipeline for retrieving design best practices Returns: Dictionary with chart configuration """ # Get data summary data_summary = get_data_summary(df) # Query RAG for styling and design best practices rag_query = f"chart design best practices colors accessibility {chart_type}" relevant_docs = rag_pipeline.retrieve_documents(rag_query, k=3) rag_context = rag_pipeline.vectorstore.format_documents_for_context(relevant_docs) # Generate chart configuration using LLM styling_prompt = get_chart_styling_prompt() full_prompt = styling_prompt.format( chart_type=chart_type, user_prompt=user_prompt, data_summary=data_summary, rag_context=rag_context ) llm_client = create_llm_client( model=os.getenv("LLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct"), temperature=0.5, max_tokens=800 ) response = llm_client.generate( prompt=full_prompt, system_prompt="You are a data visualization expert. Generate valid JSON configuration for Datawrapper charts." ) # Parse JSON response try: # Extract JSON from response response_clean = response.strip() if "```json" in response_clean: response_clean = response_clean.split("```json")[1].split("```")[0].strip() elif "```" in response_clean: response_clean = response_clean.split("```")[1].split("```")[0].strip() config = json.loads(response_clean) # Ensure basic required fields if "title" not in config: config["title"] = user_prompt[:100] # Use prompt as fallback title return config except Exception as e: print(f"Error parsing chart config: {e}") print(f"Response was: {response}") # Return minimal config return { "title": user_prompt[:100] if user_prompt else "Data Visualization", "source_name": "User Data" } async def create_and_publish_chart( df: pd.DataFrame, user_prompt: str, rag_pipeline: GraphicsDesignPipeline, api_token: Optional[str] = None ) -> dict: """ Complete workflow: analyze data, select chart type, create and publish chart. Args: df: Input DataFrame user_prompt: User's visualization request rag_pipeline: RAG pipeline instance api_token: Datawrapper API token (defaults to env var) Returns: Dictionary with chart info including iframe URL """ if api_token is None: api_token = os.getenv("DATAWRAPPER_ACCESS_TOKEN") if not api_token: raise ValueError("DATAWRAPPER_ACCESS_TOKEN not found in environment") try: # Step 1: Analyze data and select chart type chart_type, reasoning = analyze_csv_for_chart_type(df, user_prompt, rag_pipeline) # Step 2: Generate chart configuration chart_config = generate_chart_config(chart_type, df, user_prompt, rag_pipeline) # Step 3: Convert DataFrame to list of dicts for Datawrapper data_list = df.to_dict('records') # Step 4: Create chart using MCP handler create_args = { "data": data_list, "chart_type": chart_type, "chart_config": chart_config } create_result = await mcp_create_chart(create_args) if not create_result or len(create_result) == 0: raise ValueError("Empty response from chart creation") result_text = create_result[0].text if not result_text or result_text.strip() == "": raise ValueError("Empty text in chart creation response") result_data = json.loads(result_text) chart_id = result_data.get("chart_id") if not chart_id: raise ValueError(f"Failed to get chart_id from creation response. Response was: {result_data}") # Step 5: Try to publish chart using MCP handler publish_success = False publish_message = "" try: publish_args = {"chart_id": chart_id} publish_result = await mcp_publish_chart(publish_args) publish_text = publish_result[0].text publish_data = json.loads(publish_text) publish_success = True publish_message = publish_data.get("message", "Published successfully") except Exception as publish_error: publish_message = f"Publish failed: {str(publish_error)}" # Step 6: Get full chart info using MCP handler chart_info_args = {"chart_id": chart_id} chart_info_result = await mcp_get_chart_info(chart_info_args) chart_info_text = chart_info_result[0].text chart_info = json.loads(chart_info_text) # Return complete info return { "success": True, "chart_id": chart_id, "chart_type": chart_type, "reasoning": reasoning, "public_url": chart_info.get("public_url"), "edit_url": chart_info.get("edit_url"), "published": publish_success, "publish_message": publish_message, "title": chart_config.get("title", "Chart") } except json.JSONDecodeError as e: error_msg = f"JSON parsing error: {str(e)}" print(f"Error in chart creation: {error_msg}") print(f"Failed to parse: {result_text if 'result_text' in locals() else 'N/A'}") return { "success": False, "error": error_msg, "chart_type": chart_type if 'chart_type' in locals() else None, "public_url": None } except Exception as e: error_msg = f"{type(e).__name__}: {str(e)}" print(f"Error in chart creation: {error_msg}") import traceback traceback.print_exc() return { "success": False, "error": error_msg, "chart_type": chart_type if 'chart_type' in locals() else None, "public_url": None } def get_iframe_html(chart_url: str, height: int = 600) -> str: """ Generate iframe HTML for embedding a Datawrapper chart. Args: chart_url: Public URL of the chart height: Height of iframe in pixels Returns: HTML string with iframe """ if not chart_url: return "
No chart available
" return f"""
"""