File size: 30,207 Bytes
942ce50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dc8a59
 
942ce50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a910a1e
942ce50
 
 
 
fedc47d
 
 
942ce50
fedc47d
 
 
942ce50
 
fedc47d
 
 
942ce50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63f41a1
 
 
 
 
 
 
 
 
 
 
 
fedc47d
74a1fce
 
fedc47d
63f41a1
 
 
 
 
fedc47d
63f41a1
 
 
942ce50
 
 
b20c328
942ce50
b20c328
 
 
 
942ce50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a910a1e
942ce50
 
 
 
 
 
 
a910a1e
 
942ce50
 
 
 
 
 
 
 
 
 
 
 
 
 
a910a1e
 
942ce50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
"""
Screen 4: Trace Detail View
Shows detailed OpenTelemetry trace visualization
"""

import gradio as gr
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from datetime import datetime
import pandas as pd
from typing import Optional, Callable, Dict, Any, List
from components.thought_graph import create_thought_graph


def create_trace_detail_screen(
    trace_data: dict,
    on_back: Optional[Callable] = None,
    mcp_qa_enabled: bool = True
) -> gr.Blocks:
    """
    Create the trace detail screen UI.

    Builds a Gradio Blocks layout containing:
      * a back button (only when ``on_back`` is given),
      * a header and metadata row showing the trace ID and span count,
      * three tabs: a thought graph, an execution-timeline waterfall and a
        JSON view of the raw span details,
      * optionally, a collapsible "Ask About This Trace" Q&A panel.

    Args:
        trace_data: Trace dict; only its ``trace_id`` and ``spans`` keys are
            read here. ``spans`` may be a list, an object exposing
            ``tolist()`` (e.g. a numpy array), any other iterable, or None.
        on_back: Callback wired to the back button's click event with no
            inputs and no outputs. When falsy, no back button is rendered.
        mcp_qa_enabled: When True, render the Q&A accordion. Its answer
            handler is currently a placeholder that echoes the question.

    Returns:
        Gradio Blocks for trace detail screen
    """

    # NOTE: Gradio lays components out in the order they are created inside
    # the nested `with` contexts below, so statement order here is layout.
    with gr.Blocks() as trace_detail:
        with gr.Row():
            if on_back:
                # Only defined when on_back is truthy; the click wiring at the
                # bottom is guarded by the same condition.
                back_btn = gr.Button("⬅️ Back to Run Detail", variant="secondary", size="sm")

        gr.Markdown(f"# 🔍 Trace Detail: {trace_data.get('trace_id', 'Unknown')}")

        # Safely extract spans: normalise arrays (tolist), other iterables and
        # None into a plain list so len() and the chart builders work.
        spans = trace_data.get('spans', [])
        if hasattr(spans, 'tolist'):
            spans = spans.tolist()
        elif not isinstance(spans, list):
            spans = list(spans) if spans is not None else []

        # Trace metadata
        # NOTE(review): the two lines below have no blank line between them,
        # so Markdown may render them as a single paragraph — confirm intent.
        with gr.Row():
            gr.Markdown(f"""
            **Trace ID:** `{trace_data.get('trace_id', 'N/A')}`
            **Total Spans:** {len(spans)}
            """)

        # Tabs for different visualizations
        with gr.Tabs() as tabs:
            # Tab 1: Thought Graph (STAR FEATURE!)
            with gr.Tab("🧠 Thought Graph"):
                gr.Markdown("""
                ### Agent Reasoning Flow
                This graph visualizes how your agent thinks - showing the flow of reasoning steps,
                tool calls, and LLM interactions as a network.

                **Node Colors:**
                - 🟣 Purple: LLM reasoning steps
                - 🟠 Orange: Tool calls
                - 🔵 Blue: Chains/Agents
                - 🔴 Red: Errors
                """)

                # Create and display thought graph (figure built once, at screen construction)
                thought_graph_plot = gr.Plot(
                    value=create_thought_graph(spans, trace_data.get('trace_id', 'Unknown')),
                    label=""
                )

            # Tab 2: Execution Timeline (Waterfall)
            with gr.Tab("⏱️ Execution Timeline"):
                gr.Markdown("""
                ### Waterfall Chart
                Timeline view showing when each span executed and for how long.
                """)

                # Span visualization (Plotly horizontal-bar waterfall)
                span_viz = gr.Plot(
                    value=create_span_visualization(spans, trace_data.get('trace_id', 'Unknown')),
                    label=""
                )

            # Tab 3: Span Details
            with gr.Tab("📋 Span Details"):
                gr.Markdown("""
                ### Detailed Span Information
                Raw span data with attributes, status, and metadata.
                """)

                # Span details table (create_span_table returns a gr.JSON
                # component, which is placed in this tab as it is created)
                span_table = create_span_table(spans)

        # MCP Q&A Tool (below tabs)
        gr.Markdown("---")
        if mcp_qa_enabled:
            with gr.Accordion("🤖 Ask About This Trace", open=False):
                question_input = gr.Textbox(
                    label="Question",
                    placeholder="e.g., Why was the tool called twice? What tool did the agent use first?",
                    lines=2,
                    info="Ask questions about this trace execution, tool usage, or agent behavior"
                )
                ask_btn = gr.Button("Ask", variant="primary")
                answer_output = gr.Markdown("*Ask a question to get AI-powered insights*")

                # Wire up MCP Q&A (placeholder for now): the handler ignores the
                # trace and just echoes the question back in Markdown.
                ask_btn.click(
                    fn=lambda q: f"**Answer:** This is a placeholder. MCP integration coming soon.\n\n**Your question:** {q}",
                    inputs=[question_input],
                    outputs=[answer_output]
                )

        # Wire up events
        if on_back:
            back_btn.click(fn=on_back, inputs=[], outputs=[])

    return trace_detail


def _parse_numeric_string(text: str) -> Optional[float]:
    """Parse *text* as an int (exact for nanosecond epochs) or else a float.

    Returns None when the string is not numeric (e.g. an ISO-8601 date).
    """
    try:
        return int(text)
    except ValueError:
        pass
    try:
        return float(text)
    except ValueError:
        return None


def _get_span_timestamp(span: Dict[str, Any], field_name: str) -> Optional[float]:
    """Return a raw span timestamp, or None when the span does not carry one.

    For ``field_name='startTime'`` the keys tried, in order, are ``startTime``,
    ``starttime``, ``startTimeUnixNano``, ``start_time`` and
    ``start_time_unix_nano`` (likewise for ``endTime``). Keys whose value is
    None or a non-numeric string are skipped so a later variant can match.
    The unit of the returned number is not known yet.
    """
    variations = (
        field_name,                                              # startTime
        field_name.lower(),                                      # starttime
        field_name.replace('Time', 'TimeUnixNano'),              # startTimeUnixNano
        field_name[0].lower() + field_name[1:],                  # startTime
        field_name.replace('Time', '_time').lower(),             # start_time
        field_name.replace('Time', '_time_unix_nano').lower(),   # start_time_unix_nano
    )
    for key in variations:
        value = span.get(key)
        if value is None:
            continue
        if isinstance(value, str):
            value = _parse_numeric_string(value)
            if value is None:
                continue
        return value
    return None


def _detect_timestamp_unit(reference_ts: float) -> tuple:
    """Guess the unit of an epoch timestamp from its magnitude.

    Present-day epoch values are roughly 1.7e9 s, 1.7e12 ms, 1.7e15 us and
    1.7e18 ns; the thresholds below sit between those magnitudes. A
    non-positive reference falls back to nanoseconds (the OpenTelemetry norm).

    Returns:
        ``(unit_name, multiplier, divisor)`` such that a delta expressed in
        that unit converts to milliseconds as ``delta * multiplier / divisor``.
        Integer factors keep integer nanosecond deltas exact until the division.
    """
    if reference_ts <= 0 or reference_ts > 1e17:
        return "nanoseconds", 1, 1_000_000
    if reference_ts > 1e14:
        return "microseconds", 1, 1_000
    if reference_ts > 1e11:
        return "milliseconds", 1, 1
    return "seconds", 1_000, 1


def _safe_numeric(value: Any) -> Optional[float]:
    """Convert *value* to a finite number; return None if missing, invalid, NaN or inf."""
    import math

    if value is None:
        return None
    try:
        number = value if isinstance(value, (int, float)) else float(value)
        finite = math.isfinite(number)
    except (ValueError, TypeError, OverflowError):
        return None
    return number if finite else None


def _extract_token_info(attributes: Dict[str, Any]) -> Dict[str, int]:
    """Collect prompt/completion/total token counts from span attributes.

    ``gen_ai.usage.*`` keys take precedence over the ``llm.token_count.*`` ones.
    ``llm.usage.total_tokens`` is only used when prompt and completion are not
    both available.
    """
    token_info = {}
    prompt_tokens = _safe_numeric(attributes.get('gen_ai.usage.prompt_tokens'))
    if prompt_tokens is None:
        prompt_tokens = _safe_numeric(attributes.get('llm.token_count.prompt'))
    completion_tokens = _safe_numeric(attributes.get('gen_ai.usage.completion_tokens'))
    if completion_tokens is None:
        completion_tokens = _safe_numeric(attributes.get('llm.token_count.completion'))

    if prompt_tokens is not None:
        token_info['prompt_tokens'] = int(prompt_tokens)
    if completion_tokens is not None:
        token_info['completion_tokens'] = int(completion_tokens)

    if 'prompt_tokens' in token_info and 'completion_tokens' in token_info:
        token_info['total_tokens'] = token_info['prompt_tokens'] + token_info['completion_tokens']
    else:
        total = _safe_numeric(attributes.get('llm.usage.total_tokens'))
        if total is not None:
            token_info['total_tokens'] = int(total)
    return token_info


def _extract_cost_info(attributes: Dict[str, Any]) -> Dict[str, float]:
    """Return ``{'total_cost': x}`` from ``gen_ai.usage.cost.total`` or ``llm.usage.cost``, else {}."""
    cost = _safe_numeric(attributes.get('gen_ai.usage.cost.total'))
    if cost is None:
        cost = _safe_numeric(attributes.get('llm.usage.cost'))
    return {'total_cost': cost} if cost is not None else {}


def _print_first_span_debug(first_span: Dict[str, Any], has_timing_data: bool) -> None:
    """Print the raw timestamp fields and a sample of attribute keys of the first span."""
    print("[DEBUG] First span raw data sample:")
    for field in ('startTime', 'endTime', 'startTimeUnixNano', 'endTimeUnixNano'):
        print(f"  {field} field: {first_span.get(field, 'NOT FOUND')}")
    print(f"  HAS_TIMING_DATA: {has_timing_data}")
    if 'attributes' in first_span:
        attrs = first_span['attributes']
        print(f"  Sample attributes: {list(attrs.keys())[:5] if isinstance(attrs, dict) else 'N/A'}")
        if isinstance(attrs, dict):
            cost_fields = [k for k in attrs if 'cost' in k.lower() or 'price' in k.lower()]
            if cost_fields:
                print(f"  Cost-related fields found: {cost_fields}")


def _print_processing_summary(input_count: int, processed_spans: List[Dict[str, Any]]) -> None:
    """Print counts of span kinds/statuses, the duration range and token/cost coverage."""
    from collections import Counter

    print(f"[DEBUG] Total spans in input: {input_count}")
    print(f"[DEBUG] Processed spans: {len(processed_spans)}")
    print(f"[DEBUG] Span kinds detected: {dict(Counter(s['kind'] for s in processed_spans))}")
    print(f"[DEBUG] Span statuses detected: {dict(Counter(s['status'] for s in processed_spans))}")
    durations = [s['actual_duration'] for s in processed_spans]
    if durations:
        print(f"[DEBUG] Duration range: {min(durations):.3f}ms - {max(durations):.3f}ms")
    with_tokens = sum(1 for s in processed_spans if s['tokens'])
    with_cost = sum(1 for s in processed_spans if s['cost'])
    print(f"[DEBUG] Spans with token info: {with_tokens}/{len(processed_spans)}")
    print(f"[DEBUG] Spans with cost info: {with_cost}/{len(processed_spans)}")


def process_trace_data(spans: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Process trace spans for waterfall visualization.

    Normalises raw OpenTelemetry/OpenInference span dicts into rows whose times
    are in milliseconds relative to the earliest known span start. The
    timestamp unit (s/ms/us/ns) is inferred from the magnitude of that earliest
    start. A numeric ``duration_ms`` field, when present, takes precedence over
    the duration computed from start/end. Diagnostic ``[DEBUG]`` lines are
    printed to stdout.

    Args:
        spans: Raw span dicts. Lists, objects exposing ``tolist()`` (e.g.
            numpy arrays), other iterables and None are accepted.

    Returns:
        One dict per input span, in input order, with keys ``span_id``,
        ``parent_id``, ``name``, ``kind``, ``start_time`` (ms, relative),
        ``duration`` (ms, at least 0.1 so the bar stays visible),
        ``actual_duration`` (ms), ``end_time`` (ms, relative), ``attributes``,
        ``status``, ``tokens`` and ``cost``. A span missing a start or end
        timestamp (and without ``duration_ms``) gets a duration of 0.
    """
    if hasattr(spans, 'tolist'):
        spans = spans.tolist()
    elif not isinstance(spans, list):
        spans = list(spans) if spans is not None else []

    if not spans:
        return []

    # Missing timestamps are None (not 0) so they cannot drag min_start down
    # or turn a missing end time into a huge negative duration.
    start_times = [_get_span_timestamp(span, 'startTime') for span in spans]
    known_starts = [t for t in start_times if t is not None]
    min_start = min(known_starts) if known_starts else 0
    has_timing_data = any(t > 0 for t in known_starts)

    _print_first_span_debug(spans[0], has_timing_data)

    unit_name, multiplier, divisor = _detect_timestamp_unit(min_start)
    if min_start > 0:
        print(f"[DEBUG] Auto-detected timestamp unit: {unit_name} (min_start={min_start}, divisor={divisor / multiplier})")

    def to_ms(delta):
        return delta * multiplier / divisor

    processed_spans = []
    for idx, (span, start_time) in enumerate(zip(spans, start_times)):
        end_time = _get_span_timestamp(span, 'endTime')

        if has_timing_data and start_time is not None:
            relative_start = to_ms(start_time - min_start)
        else:
            relative_start = 0

        explicit_duration = _safe_numeric(span.get('duration_ms'))
        if explicit_duration is not None:
            actual_duration = float(explicit_duration)
            duration_source = 'duration_ms'
        elif start_time is not None and end_time is not None:
            actual_duration = to_ms(end_time - start_time)
            duration_source = 'calculated'
        else:
            actual_duration = 0.0
            duration_source = 'missing'

        if idx < 3:
            print(f"[DEBUG] Span {idx}: start={start_time}, end={end_time}, duration={actual_duration:.3f}ms ({duration_source})")

        span_id = span.get('spanId') or span.get('span_id') or span.get('spanID') or f'span_{idx}'
        parent_id = span.get('parentSpanId') or span.get('parent_span_id') or span.get('parentSpanID')

        # OpenInference kinds (CHAIN, TOOL, LLM, AGENT, ...) override the OTel kind.
        span_kind = span.get('kind', 'INTERNAL')
        attributes = span.get('attributes', {})
        token_info = {}
        cost_info = {}
        if isinstance(attributes, dict):
            openinference_kind = attributes.get('openinference.span.kind')
            if openinference_kind:
                span_kind = str(openinference_kind).upper()

            token_info = _extract_token_info(attributes)
            cost_info = _extract_cost_info(attributes)

            if idx < 2 and span_kind == 'LLM':
                print(f"[DEBUG] LLM Span {idx} cost extraction:")
                print(f"  gen_ai.usage.cost.total: {attributes.get('gen_ai.usage.cost.total', 'NOT FOUND')}")
                print(f"  llm.usage.cost: {attributes.get('llm.usage.cost', 'NOT FOUND')}")
                print(f"  cost_info: {cost_info}")

        # `status` may be absent, None or a dict carrying a 'code'.
        status = span.get('status')
        status_code = status.get('code', 'UNKNOWN') if isinstance(status, dict) else 'UNKNOWN'

        processed_spans.append({
            'span_id': span_id,
            'parent_id': parent_id,
            'name': span.get('name', 'Unknown'),
            'kind': span_kind,
            'start_time': relative_start,
            'duration': max(actual_duration, 0.1),  # Bar width; minimum keeps it visible
            'actual_duration': actual_duration,      # Shown in the tooltip
            'end_time': relative_start + actual_duration,
            'attributes': attributes,
            'status': status_code,
            'tokens': token_info,
            'cost': cost_info
        })

    _print_processing_summary(len(spans), processed_spans)
    return processed_spans


def create_span_visualization(spans: List[Dict[str, Any]], trace_id: str = "Unknown") -> go.Figure:
    """Create an interactive Plotly waterfall visualization of spans.

    Args:
        spans: Raw span dicts, in any form accepted by ``process_trace_data``
            (list, array-like with ``tolist()``, other iterable, or None).
        trace_id: Identifier shown in the chart title.

    Returns:
        A horizontal bar chart with one bar per span. Each bar starts at the
        span's relative start time (ms) and is coloured red for ERROR status,
        otherwise by span kind. Spans are sorted by start time and the
        earliest is drawn at the top. When there are no spans, the figure
        contains only a "No spans to display" annotation.
    """
    processed_spans = process_trace_data(spans)

    # `spans` may be None or an unsized iterable (process_trace_data accepts
    # both), so only call len() when it is defined.
    received = len(spans) if hasattr(spans, '__len__') else 'unknown'
    print(f"[DEBUG] create_span_visualization - Received {received} spans")
    print(f"[DEBUG] create_span_visualization - Processed {len(processed_spans)} spans")

    if not processed_spans:
        fig = go.Figure()
        fig.add_annotation(
            text="No spans to display",
            xref="paper", yref="paper",
            x=0.5, y=0.5, xanchor='center', yanchor='middle',
            showarrow=False,
            font=dict(size=20)
        )
        return fig

    processed_spans.sort(key=lambda x: x['start_time'])

    # Span names repeat, and a categorical axis would merge equal labels, so
    # suffix each with its position to keep one row per span.
    for idx, span in enumerate(processed_spans):
        span['display_name'] = f"{span['name']} [{idx}]"

    # Colours per span kind (OpenTelemetry and OpenInference kinds).
    kind_colors = {
        'SERVER': '#2E8B57',     # Sea Green
        'CLIENT': '#4169E1',     # Royal Blue
        'LLM': '#9B59B6',        # Purple for LLM calls
        'TOOL': '#E67E22',       # Orange for Tool calls
        'CHAIN': '#3498DB',      # Light Blue for Chains
        'AGENT': '#1ABC9C',      # Turquoise for Agents
        'RETRIEVER': '#F39C12',  # Yellow-Orange for Retrievers
        'EMBEDDING': '#8E44AD',  # Dark Purple for Embeddings
    }
    error_color = '#DC143C'      # Crimson, only for ERROR status
    default_color = '#4682B4'    # Steel Blue for INTERNAL/unknown

    colors = []
    color_map = {}  # First colour seen per kind, for the debug print
    for span in processed_spans:
        kind = span['kind']
        color = error_color if span['status'] == 'ERROR' else kind_colors.get(kind, default_color)
        colors.append(color)
        color_map.setdefault(kind, color)

    print(f"[DEBUG] Color assignments: {color_map}")

    customdata = []
    for span in processed_spans:
        tokens = span['tokens']
        token_str = ""
        if 'total_tokens' in tokens:
            token_str = f"<br>Tokens: {tokens['total_tokens']}"
            if 'prompt_tokens' in tokens and 'completion_tokens' in tokens:
                token_str += f" (prompt: {tokens['prompt_tokens']}, completion: {tokens['completion_tokens']})"
        elif 'prompt_tokens' in tokens or 'completion_tokens' in tokens:
            parts = []
            if 'prompt_tokens' in tokens:
                parts.append(f"prompt: {tokens['prompt_tokens']}")
            if 'completion_tokens' in tokens:
                parts.append(f"completion: {tokens['completion_tokens']}")
            token_str = f"<br>Tokens: {', '.join(parts)}"

        cost_str = ""
        if span['cost'] and 'total_cost' in span['cost']:
            cost_str = f"<br>Cost: ${span['cost']['total_cost']:.6f}"

        customdata.append([
            span['name'],
            span['kind'],
            span['span_id'],
            span['end_time'],
            span['actual_duration'],  # Real duration, not the min-width display value
            token_str,
            cost_str
        ])

    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=[span['display_name'] for span in processed_spans],
        x=[span['duration'] for span in processed_spans],  # Display duration (min 0.1ms)
        base=[span['start_time'] for span in processed_spans],
        orientation='h',
        marker_color=colors,
        hovertemplate=(
            "<b>%{customdata[0]}</b><br>" +
            "Type: %{customdata[1]}<br>" +
            "Span ID: %{customdata[2]}<br>" +
            "Duration: %{customdata[4]:.3f} ms<br>" +
            "Start: %{base:.2f} ms<br>" +
            "End: %{customdata[3]:.2f} ms" +
            "%{customdata[5]}" +  # Token info (already formatted)
            "%{customdata[6]}" +  # Cost info (already formatted)
            "<extra></extra>"
        ),
        customdata=customdata,
        name="Spans"
    ))

    fig.update_layout(
        title={
            'text': f"OpenTelemetry Trace: {trace_id}",
            'x': 0.5,
            'xanchor': 'center'
        },
        xaxis_title="Time (milliseconds)",
        # Plotly draws the first category at the bottom of a y-axis; reverse it
        # so the earliest span sits at the top, as in a conventional waterfall.
        yaxis=dict(title="Spans", autorange="reversed"),
        showlegend=False,
        height=400 + len(processed_spans) * 30,  # Dynamic height based on span count
        bargap=0.2,
        hovermode='closest'
    )

    return fig


def create_span_table(spans: List[Dict[str, Any]]) -> gr.JSON:
    """Create detailed span information display.

    Args:
        spans: Raw span dicts. Lists, objects exposing ``tolist()`` (e.g.
            numpy arrays), other iterables and None are accepted.

    Returns:
        A ``gr.JSON`` component whose value is a list with one dict per span,
        holding ``Span ID``, ``Parent`` (``'root'`` when absent), ``Name``,
        ``Kind`` (the raw span kind), ``Duration (ms)`` (rounded to 2
        decimals), ``Attributes`` and ``Status``.
    """
    import math

    if hasattr(spans, 'tolist'):
        spans = spans.tolist()
    elif not isinstance(spans, list):
        spans = list(spans) if spans is not None else []

    def parse_number(value):
        """Return value unchanged unless it is a string; parse strings as int, then float, else None."""
        if not isinstance(value, str):
            return value
        try:
            return int(value)
        except ValueError:
            pass
        try:
            return float(value)
        except ValueError:
            return None

    def get_timestamp(span, field_name):
        """Return the first usable timestamp among the known key spellings, or None."""
        variations = (
            field_name,
            field_name.lower(),
            field_name.replace('Time', 'TimeUnixNano'),
            field_name[0].lower() + field_name[1:],
            field_name.replace('Time', '_time').lower(),
            field_name.replace('Time', '_time_unix_nano').lower(),
        )
        for key in variations:
            value = span.get(key)
            if value is None:
                continue
            value = parse_number(value)
            if value is not None:
                return value
        return None

    def finite_float(value):
        """Return value as a finite float, or None if missing/invalid/NaN/inf."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        return number if math.isfinite(number) else None

    start_times = [get_timestamp(span, 'startTime') for span in spans]
    known_starts = [t for t in start_times if t is not None]
    reference = min(known_starts) if known_starts else 0

    # Infer the timestamp unit from the earliest start (present-day epochs are
    # ~1.7e9 s, 1.7e12 ms, 1.7e15 us, 1.7e18 ns); default to nanoseconds.
    # Durations in ms are then delta * multiplier / divisor.
    if reference <= 0 or reference > 1e17:
        multiplier, divisor = 1, 1_000_000   # nanoseconds
    elif reference > 1e14:
        multiplier, divisor = 1, 1_000       # microseconds
    elif reference > 1e11:
        multiplier, divisor = 1, 1           # milliseconds
    else:
        multiplier, divisor = 1_000, 1       # seconds

    simplified_spans = []
    for span, start_time in zip(spans, start_times):
        # An explicit duration_ms wins; otherwise derive it from start/end.
        duration_ms = finite_float(span.get('duration_ms'))
        if duration_ms is None:
            end_time = get_timestamp(span, 'endTime')
            if start_time is not None and end_time is not None:
                duration_ms = (end_time - start_time) * multiplier / divisor
            else:
                duration_ms = 0

        span_id = span.get('spanId') or span.get('span_id') or span.get('spanID') or 'N/A'
        parent_id = span.get('parentSpanId') or span.get('parent_span_id') or span.get('parentSpanID') or 'root'

        # `status` may be absent, None or a dict carrying a 'code'.
        status = span.get('status')
        status_code = status.get('code', 'UNKNOWN') if isinstance(status, dict) else 'UNKNOWN'

        simplified_spans.append({
            "Span ID": span_id,
            "Parent": parent_id,
            "Name": span.get('name', 'N/A'),
            "Kind": span.get('kind', 'N/A'),
            "Duration (ms)": round(duration_ms, 2),
            "Attributes": span.get('attributes', {}),
            "Status": status_code
        })

    return gr.JSON(value=simplified_spans, label="Span Details")


# GPU Metrics Visualization Functions

def extract_metrics_data(metrics_df):
    """
    Extract and prepare GPU metrics data for visualization

    Returns a copy of the input with ``timestamp`` converted to datetime and
    rows sorted by it, with a fresh 0..n-1 index. When there is no
    ``timestamp`` column, the copy is returned in its original order.

    Args:
        metrics_df: DataFrame with flat metrics structure (from HuggingFace dataset)
                   Expected columns: timestamp, gpu_utilization_percent, gpu_memory_used_mib,
                                   gpu_temperature_celsius, gpu_power_watts, co2_emissions_gco2e

    Returns:
        DataFrame ready for visualization (an empty DataFrame for None/empty input)
    """
    if metrics_df is None or metrics_df.empty:
        return pd.DataFrame()

    # Work on a copy so the caller's DataFrame is never mutated.
    df = metrics_df.copy()

    if 'timestamp' not in df.columns:
        # Nothing to order by; sorting on a missing column would raise KeyError.
        return df

    if not pd.api.types.is_datetime64_any_dtype(df['timestamp']):
        df['timestamp'] = pd.to_datetime(df['timestamp'])

    return df.sort_values('timestamp').reset_index(drop=True)


def create_gpu_summary_cards(df):
    """
    Create summary cards for GPU metrics

    Args:
        df: DataFrame with flat metrics structure (columns: gpu_utilization_percent, etc.)

    Returns:
        HTML string with summary cards
    """
    if df is None or df.empty:
        return "<div style='padding: 20px; text-align: center;'>⚠️ No GPU metrics available (expected for API models)</div>"

    # Debug: Print DataFrame info
    print(f"[DEBUG create_gpu_summary_cards] DataFrame shape: {df.shape}")
    print(f"[DEBUG create_gpu_summary_cards] DataFrame columns: {list(df.columns)}")
    if not df.empty:
        print(f"[DEBUG create_gpu_summary_cards] First row sample: {df.iloc[0].to_dict()}")
        print(f"[DEBUG create_gpu_summary_cards] Last row sample: {df.iloc[-1].to_dict()}")

    # Use aggregate statistics (average/max) instead of just last row
    # This is more representative of overall GPU performance
    utilization = df['gpu_utilization_percent'].mean() if 'gpu_utilization_percent' in df.columns else 0
    memory_used = df['gpu_memory_used_mib'].max() if 'gpu_memory_used_mib' in df.columns else 0
    temperature = df['gpu_temperature_celsius'].max() if 'gpu_temperature_celsius' in df.columns else 0

    # CO2 emissions - use max value (cumulative total)
    co2_emissions = df['co2_emissions_gco2e'].max() if 'co2_emissions_gco2e' in df.columns else 0

    power = df['gpu_power_watts'].mean() if 'gpu_power_watts' in df.columns else 0

    # Get GPU name from first row (it's constant across all rows)
    gpu_name = df['gpu_name'].iloc[0] if 'gpu_name' in df.columns and not df.empty else 'Unknown GPU'

    print(f"[DEBUG create_gpu_summary_cards] Aggregated values - util: {utilization:.2f}, mem: {memory_used:.2f}, temp: {temperature:.2f}, co2: {co2_emissions:.4f}, gpu_name: {gpu_name}")

    # Get memory total from max value if available
    memory_total = df['gpu_memory_total_mib'].max() if 'gpu_memory_total_mib' in df.columns else 0
    memory_percent = (memory_used / memory_total * 100) if memory_total > 0 else 0

    cards_html = f"""
    <div style="display: grid; grid-template-columns: repeat(5, 1fr); gap: 15px; margin: 20px 0;">
        <div style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; border-radius: 10px; color: white; text-align: center;">
            <h3 style="margin: 0 0 10px 0; font-size: 1em;">GPU Name</h3>
            <h2 style="margin: 0; font-size: 1.2em;">{gpu_name}</h2>
        </div>
        <div style="background: linear-gradient(135deg, #fa709a 0%, #fee140 100%); padding: 20px; border-radius: 10px; color: white; text-align: center;">
            <h3 style="margin: 0 0 10px 0; font-size: 1em;">GPU Utilization</h3>
            <h2 style="margin: 0; font-size: 2em;">{utilization:.1f}%</h2>
        </div>
        <div style="background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%); padding: 20px; border-radius: 10px; color: white; text-align: center;">
            <h3 style="margin: 0 0 10px 0; font-size: 1em;">GPU Memory</h3>
            <h2 style="margin: 0; font-size: 2em;">{memory_used:.0f} MiB</h2>
            <p style="margin: 5px 0 0 0; font-size: 0.8em; opacity: 0.9;">{memory_percent:.1f}% of {memory_total:.0f} MiB</p>
        </div>
        <div style="background: linear-gradient(135deg, #4facfe 0%, #00f2fe 100%); padding: 20px; border-radius: 10px; color: white; text-align: center;">
            <h3 style="margin: 0 0 10px 0; font-size: 1em;">GPU Temperature</h3>
            <h2 style="margin: 0; font-size: 2em;">{temperature:.0f}°C</h2>
        </div>
        <div style="background: linear-gradient(135deg, #43e97b 0%, #38f9d7 100%); padding: 20px; border-radius: 10px; color: white; text-align: center;">
            <h3 style="margin: 0 0 10px 0; font-size: 1em;">CO2 Emissions</h3>
            <h2 style="margin: 0; font-size: 2em;">{co2_emissions:.4f} g</h2>
            <p style="margin: 5px 0 0 0; font-size: 0.8em; opacity: 0.9;">Power: {power:.1f} W</p>
        </div>
    </div>
    """

    return cards_html


def create_gpu_metrics_dashboard(metrics_df):
    """
    Create a combined dashboard with GPU metric charts

    Lays out a 3x2 grid of line charts (utilization, memory, temperature,
    power, CO2, power cost); any metric whose column is missing leaves its
    panel empty. The peak reported total memory is drawn as a dashed line on
    the memory panel. Points are plotted against ``timestamp`` when present,
    otherwise against row order.

    Args:
        metrics_df: DataFrame with flat metrics structure (from HuggingFace dataset)

    Returns:
        Plotly figure with GPU metrics time series, or a figure holding only a
        "No GPU metrics available" message when there is no data.
    """
    no_data_message = "No GPU metrics available (expected for API models)"

    def message_figure(text):
        """Return an otherwise empty figure showing *text* centred in the plot area."""
        fig = go.Figure()
        fig.add_annotation(
            text=text,
            xref="paper", yref="paper",
            x=0.5, y=0.5, xanchor='center', yanchor='middle',
            showarrow=False,
            font=dict(size=16)
        )
        return fig

    if metrics_df is None or metrics_df.empty:
        return message_figure(no_data_message)

    df = extract_metrics_data(metrics_df)
    if df.empty:
        return message_figure(no_data_message)

    # Plot against time when available; fall back to sample order otherwise.
    has_timestamp = 'timestamp' in df.columns
    x_values = df['timestamp'] if has_timestamp else df.index

    fig = make_subplots(
        rows=3, cols=2,
        subplot_titles=[
            'GPU Utilization (%)',
            'GPU Memory (MiB)',
            'GPU Temperature (°C)',
            'GPU Power (W)',
            'CO2 Emissions (g)',
            'Power Cost (USD)'
        ],
        vertical_spacing=0.10,
        horizontal_spacing=0.12,
        specs=[[{}, {}], [{}, {}], [{}, {}]]
    )

    colors = ['#667eea', '#f093fb', '#4facfe', '#FFE66D', '#43e97b', '#FF6B6B']

    # (column, title, subplot row, subplot col, colour)
    metrics_config = [
        ('gpu_utilization_percent', 'GPU Utilization (%)', 1, 1, colors[0]),
        ('gpu_memory_used_mib', 'GPU Memory (MiB)', 1, 2, colors[1]),
        ('gpu_temperature_celsius', 'GPU Temperature (°C)', 2, 1, colors[2]),
        ('gpu_power_watts', 'GPU Power (W)', 2, 2, colors[3]),
        ('co2_emissions_gco2e', 'CO2 Emissions (g)', 3, 1, colors[4]),
        ('power_cost_usd', 'Power Cost (USD)', 3, 2, colors[5]),
    ]

    for col_name, title, row, col, color in metrics_config:
        if col_name not in df.columns:
            continue
        fig.add_trace(
            go.Scatter(
                x=x_values,
                y=df[col_name],
                mode='lines+markers',
                name=title,
                line=dict(color=color, width=3),
                marker=dict(size=6, color=color),
                hovertemplate=(
                    f"<b>{title}</b><br>" +
                    "Time: %{x}<br>" +
                    "Value: %{y:.2f}<br>" +
                    "<extra></extra>"
                )
            ),
            row=row, col=col
        )

    # Reference line for total memory; uses the peak reported value (as the
    # summary cards do) and is skipped when no valid value exists.
    if 'gpu_memory_total_mib' in df.columns:
        total_memory = df['gpu_memory_total_mib'].max()
        if pd.notna(total_memory):
            fig.add_hline(
                y=total_memory,
                line_dash="dash",
                line_color="gray",
                annotation_text=f"Total: {total_memory:.0f} MiB",
                annotation_position="right",
                row=1, col=2
            )

    fig.update_layout(
        title_text="GPU Metrics Over Time",
        height=900,
        template="plotly_white",
        showlegend=False,
        hovermode='x unified'
    )

    # A clock tick format only makes sense on a time axis.
    if has_timestamp:
        fig.update_xaxes(tickformat='%H:%M:%S')

    return fig