File size: 11,273 Bytes
7114af0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
"""
Datawrapper Chart Generation Client

Integrates RAG pipeline with Datawrapper API for intelligent chart creation.
"""

import json
import os
from typing import Optional, Tuple
import pandas as pd

from .prompts import (
    CHART_SELECTION_SYSTEM_PROMPT,
    get_chart_selection_prompt,
    get_chart_styling_prompt
)
from .llm_client import create_llm_client
from .rag_pipeline import GraphicsDesignPipeline

# Import Datawrapper MCP handlers directly
from datawrapper_mcp.handlers.create import create_chart as mcp_create_chart
from datawrapper_mcp.handlers.publish import publish_chart as mcp_publish_chart
from datawrapper_mcp.handlers.retrieve import get_chart_info as mcp_get_chart_info


def get_data_summary(df: pd.DataFrame) -> str:
    """
    Generate a summary of the DataFrame structure and content.

    The summary contains, one item per line: the row and column counts, all
    column names, the columns grouped by dtype family (numeric, object,
    datetime; a group is listed only when it has at least one column) and a
    preview of the first three rows.

    Args:
        df: Input DataFrame

    Returns:
        String summary of data characteristics
    """

    def _join_labels(labels) -> str:
        # Column labels are not guaranteed to be strings (e.g. integer labels
        # when the data has no header row); str.join raises TypeError on
        # non-string items, so stringify each label first.
        return ", ".join(str(label) for label in labels)

    # Basic info
    summary_parts = [
        f"Rows: {len(df)}, Columns: {len(df.columns)}",
        f"Column names: {_join_labels(df.columns)}",
    ]

    # Column types: (label used in the summary, dtype selector for pandas)
    dtype_groups = (
        ("Numeric columns", "number"),
        ("Text columns", "object"),
        ("Date columns", "datetime"),
    )
    for label, selector in dtype_groups:
        matching = df.select_dtypes(include=[selector]).columns
        if len(matching) > 0:
            summary_parts.append(f"{label}: {_join_labels(matching)}")

    # Data preview (first 3 rows)
    summary_parts.append(f"\nData preview:\n{df.head(3).to_string()}")

    return "\n".join(summary_parts)


def analyze_csv_for_chart_type(
    df: pd.DataFrame,
    user_prompt: str,
    rag_pipeline: GraphicsDesignPipeline
) -> Tuple[str, str]:
    """
    Use RAG and LLM to determine the best chart type for the data.

    The data summary, the user's request and the documents retrieved from the
    RAG pipeline are formatted into the chart-selection prompt. The LLM reply
    is expected to be a JSON object with "chart_type" and "reasoning" keys,
    optionally wrapped in a markdown code fence.

    Args:
        df: Input DataFrame
        user_prompt: User's description of what they want to visualize
        rag_pipeline: RAG pipeline for retrieving best practices

    Returns:
        Tuple of (chart_type, reasoning). chart_type is always one of the
        supported types; "line" is used when the reply is missing, cannot be
        parsed, or names an unsupported type.
    """
    default_type = "line"
    valid_types = (
        "bar", "line", "area", "scatter", "column",
        "stacked_bar", "arrow", "multiple_column",
    )

    # Get data summary
    data_summary = get_data_summary(df)

    # Query RAG for chart selection best practices
    rag_query = f"chart type selection for {user_prompt}"
    relevant_docs = rag_pipeline.retrieve_documents(rag_query, k=3)
    rag_context = rag_pipeline.vectorstore.format_documents_for_context(relevant_docs)

    # Generate chart type recommendation using LLM
    chart_prompt = get_chart_selection_prompt()
    full_prompt = chart_prompt.format(
        user_prompt=user_prompt,
        data_summary=data_summary,
        rag_context=rag_context
    )

    llm_client = create_llm_client(
        model=os.getenv("LLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct"),
        temperature=0.3,  # Lower temperature for more deterministic chart selection
        max_tokens=500
    )

    response = llm_client.generate(
        prompt=full_prompt,
        system_prompt=CHART_SELECTION_SYSTEM_PROMPT
    )

    # Parse JSON response. The broad except is deliberate: any malformed reply
    # must degrade to the default chart type rather than abort the workflow.
    try:
        # Extract JSON from response (handle markdown code blocks)
        response_clean = response.strip()
        if "```json" in response_clean:
            response_clean = response_clean.split("```json")[1].split("```")[0].strip()
        elif "```" in response_clean:
            response_clean = response_clean.split("```")[1].split("```")[0].strip()

        result = json.loads(response_clean)
        if not isinstance(result, dict):
            # Valid JSON that is not an object (list, string, number, ...)
            # carries no chart_type key; report it clearly instead of relying
            # on an AttributeError from .get().
            raise ValueError(
                f"expected a JSON object, got {type(result).__name__}"
            )

        chart_type = result.get("chart_type", default_type)
        reasoning = result.get("reasoning", "")

        # Normalize before validating so that trivially different spellings
        # ("Bar", " line ", "stacked bar", "multiple-column") are accepted.
        if isinstance(chart_type, str):
            chart_type = (
                chart_type.strip().lower().replace("-", "_").replace(" ", "_")
            )

        # Validate chart type
        if chart_type not in valid_types:
            chart_type = default_type  # Default fallback

        return chart_type, reasoning
    except Exception as e:
        print(f"Error parsing chart type response: {e}")
        print(f"Response was: {response}")
        # Default to line chart
        return default_type, "Using default line chart due to parsing error"


def generate_chart_config(
    chart_type: str,
    df: pd.DataFrame,
    user_prompt: str,
    rag_pipeline: GraphicsDesignPipeline
) -> dict:
    """
    Generate Datawrapper chart configuration using RAG and LLM.

    The data summary, chart type, user request and the documents retrieved
    from the RAG pipeline are formatted into the styling prompt. The LLM reply
    is expected to be a JSON object, optionally wrapped in a markdown code
    fence.

    Args:
        chart_type: Type of chart to create
        df: Input DataFrame
        user_prompt: User's visualization request
        rag_pipeline: RAG pipeline for retrieving design best practices

    Returns:
        Dictionary with chart configuration. It always has a non-empty
        "title"; when the reply cannot be parsed into a JSON object, a minimal
        configuration with only "title" and "source_name" is returned.
    """
    # One fallback title for both the success and the error path, so an empty
    # or missing prompt never yields an empty title or discards a valid config.
    fallback_title = user_prompt[:100] if user_prompt else "Data Visualization"

    # Get data summary
    data_summary = get_data_summary(df)

    # Query RAG for styling and design best practices
    rag_query = f"chart design best practices colors accessibility {chart_type}"
    relevant_docs = rag_pipeline.retrieve_documents(rag_query, k=3)
    rag_context = rag_pipeline.vectorstore.format_documents_for_context(relevant_docs)

    # Generate chart configuration using LLM
    styling_prompt = get_chart_styling_prompt()
    full_prompt = styling_prompt.format(
        chart_type=chart_type,
        user_prompt=user_prompt,
        data_summary=data_summary,
        rag_context=rag_context
    )

    llm_client = create_llm_client(
        model=os.getenv("LLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct"),
        temperature=0.5,
        max_tokens=800
    )

    response = llm_client.generate(
        prompt=full_prompt,
        system_prompt="You are a data visualization expert. Generate valid JSON configuration for Datawrapper charts."
    )

    # Parse JSON response. The broad except is deliberate: any malformed reply
    # must degrade to the minimal configuration rather than abort the workflow.
    try:
        # Extract JSON from response
        response_clean = response.strip()
        if "```json" in response_clean:
            response_clean = response_clean.split("```json")[1].split("```")[0].strip()
        elif "```" in response_clean:
            response_clean = response_clean.split("```")[1].split("```")[0].strip()

        config = json.loads(response_clean)
        if not isinstance(config, dict):
            # The declared return type is a dict; a JSON list/string/number is
            # not a usable configuration.
            raise ValueError(
                f"expected a JSON object, got {type(config).__name__}"
            )

        # Ensure basic required fields (a null or empty title counts as missing)
        if not config.get("title"):
            config["title"] = fallback_title  # Use prompt as fallback title

        return config
    except Exception as e:
        print(f"Error parsing chart config: {e}")
        print(f"Response was: {response}")
        # Return minimal config
        return {
            "title": fallback_title,
            "source_name": "User Data"
        }


async def create_and_publish_chart(
    df: pd.DataFrame,
    user_prompt: str,
    rag_pipeline: GraphicsDesignPipeline,
    api_token: Optional[str] = None
) -> dict:
    """
    Complete workflow: analyze data, select chart type, create and publish chart.

    Publishing is best-effort: a publish failure is recorded in the result
    ("published" / "publish_message") and does not fail the workflow. Any
    other failure after the token check is caught and reported through the
    returned dictionary instead of being raised.

    Args:
        df: Input DataFrame
        user_prompt: User's visualization request
        rag_pipeline: RAG pipeline instance
        api_token: Datawrapper API token (defaults to env var)

    Returns:
        On success, a dictionary with "success": True plus "chart_id",
        "chart_type", "reasoning", "public_url", "edit_url", "published",
        "publish_message" and "title". On failure, a dictionary with
        "success": False, "error", "chart_type" (None if not yet selected)
        and "public_url": None.

    Raises:
        ValueError: If api_token is None and DATAWRAPPER_ACCESS_TOKEN is not
            set in the environment.
    """
    if api_token is None:
        api_token = os.getenv("DATAWRAPPER_ACCESS_TOKEN")
        if not api_token:
            raise ValueError("DATAWRAPPER_ACCESS_TOKEN not found in environment")
    # NOTE(review): api_token is validated above but is not passed to any of
    # the MCP handlers below; presumably they read DATAWRAPPER_ACCESS_TOKEN
    # themselves, in which case an explicitly supplied token has no effect --
    # confirm against the handlers.

    # Initialized up front so the error handlers can report them without
    # inspecting locals().
    chart_type = None
    pending_json = None  # raw text of the payload currently being parsed

    try:
        # Step 1: Analyze data and select chart type
        chart_type, reasoning = analyze_csv_for_chart_type(df, user_prompt, rag_pipeline)

        # Step 2: Generate chart configuration
        chart_config = generate_chart_config(chart_type, df, user_prompt, rag_pipeline)

        # Step 3: Convert DataFrame to list of dicts for Datawrapper
        data_list = df.to_dict('records')

        # Step 4: Create chart using MCP handler
        create_args = {
            "data": data_list,
            "chart_type": chart_type,
            "chart_config": chart_config
        }

        create_result = await mcp_create_chart(create_args)

        if not create_result:
            raise ValueError("Empty response from chart creation")

        result_text = create_result[0].text

        if not result_text or result_text.strip() == "":
            raise ValueError("Empty text in chart creation response")

        pending_json = result_text
        result_data = json.loads(result_text)

        chart_id = result_data.get("chart_id")
        if not chart_id:
            raise ValueError(f"Failed to get chart_id from creation response. Response was: {result_data}")

        # Step 5: Try to publish chart using MCP handler
        publish_success = False
        publish_message = ""
        try:
            publish_args = {"chart_id": chart_id}
            publish_result = await mcp_publish_chart(publish_args)
            publish_text = publish_result[0].text
            publish_data = json.loads(publish_text)
            publish_success = True
            publish_message = publish_data.get("message", "Published successfully")
        except Exception as publish_error:
            # Best-effort: the chart already exists, so only record the failure.
            publish_message = f"Publish failed: {str(publish_error)}"

        # Step 6: Get full chart info using MCP handler
        chart_info_args = {"chart_id": chart_id}
        chart_info_result = await mcp_get_chart_info(chart_info_args)
        chart_info_text = chart_info_result[0].text
        pending_json = chart_info_text
        chart_info = json.loads(chart_info_text)

        # Return complete info
        return {
            "success": True,
            "chart_id": chart_id,
            "chart_type": chart_type,
            "reasoning": reasoning,
            "public_url": chart_info.get("public_url"),
            "edit_url": chart_info.get("edit_url"),
            "published": publish_success,
            "publish_message": publish_message,
            "title": chart_config.get("title", "Chart")
        }

    except json.JSONDecodeError as e:
        error_msg = f"JSON parsing error: {str(e)}"
        print(f"Error in chart creation: {error_msg}")
        # Show the payload that was actually being parsed (creation response
        # or chart-info response), not always the creation response.
        print(f"Failed to parse: {pending_json if pending_json is not None else 'N/A'}")
        return {
            "success": False,
            "error": error_msg,
            "chart_type": chart_type,
            "public_url": None
        }
    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e)}"
        print(f"Error in chart creation: {error_msg}")
        import traceback
        traceback.print_exc()
        return {
            "success": False,
            "error": error_msg,
            "chart_type": chart_type,
            "public_url": None
        }


def get_iframe_html(chart_url: str, height: int = 600) -> str:
    """
    Generate iframe HTML for embedding a Datawrapper chart.

    Args:
        chart_url: Public URL of the chart. It is HTML-escaped before being
            placed in the iframe's src attribute.
        height: Height of iframe in pixels

    Returns:
        HTML string with iframe, or a placeholder <div> when chart_url is
        empty or None
    """
    # Function-scope import keeps this block self-contained.
    from html import escape

    if not chart_url:
        return "<div style='padding: 50px; text-align: center;'>No chart available</div>"

    # Escape quotes, angle brackets and ampersands so the URL cannot terminate
    # the src attribute and inject markup; browsers decode the entities back
    # to the original URL.
    safe_url = escape(chart_url, quote=True)

    return f"""
    <div style="width: 100%; height: {height}px;">
        <iframe
            src="{safe_url}"
            style="width: 100%; height: 100%; border: none;"
            frameborder="0"
            scrolling="no"
            aria-label="Chart">
        </iframe>
    </div>
    """