File size: 7,828 Bytes
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
0a58567
c4b87d2
 
 
0a58567
c4b87d2
0a58567
c4b87d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a58567
c4b87d2
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import logging

import numpy as np
import pandas as pd
from gluonts.model.forecast import QuantileForecast

from src.data.frequency import parse_frequency
from src.plotting.plot_timeseries import (
    plot_multivariate_timeseries,
)

# Module-level logger, named after this module, shared by all helpers below.
logger = logging.getLogger(__name__)


def _prepare_data_for_plotting(input_data: dict, label_data: dict, max_context_length: int):
    history_values = np.asarray(input_data["target"], dtype=np.float32)
    future_values = np.asarray(label_data["target"], dtype=np.float32)
    start_period = input_data["start"]

    def ensure_time_first(arr: np.ndarray) -> np.ndarray:
        if arr.ndim == 1:
            return arr.reshape(-1, 1)
        elif arr.ndim == 2:
            if arr.shape[0] < arr.shape[1]:
                return arr.T
            return arr
        else:
            return arr.reshape(arr.shape[-1], -1).T

    history_values = ensure_time_first(history_values)
    future_values = ensure_time_first(future_values)

    if max_context_length is not None and history_values.shape[0] > max_context_length:
        history_values = history_values[-max_context_length:]

    # Convert Period to Timestamp if needed
    start_timestamp = (
        start_period.to_timestamp() if hasattr(start_period, "to_timestamp") else pd.Timestamp(start_period)
    )
    return history_values, future_values, start_timestamp


def _extract_quantile_predictions(
    forecast,
) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
    """Pull a median prediction and, when available, 0.1/0.9 bounds from a forecast.

    For a ``QuantileForecast`` the values come from ``forecast.quantile``.
    Any other object is expected to expose a ``samples`` array, from which only
    a median is derived (bounds are returned as ``None``).

    This is best-effort: failures never propagate. They are logged at debug
    level with the traceback, and the affected outputs are returned as ``None``.

    Args:
        forecast: A ``QuantileForecast`` or an object with a ``samples``
            attribute supporting ``ndim``/``shape``/indexing.

    Returns:
        Tuple ``(median, lower, upper)``. Each element is a 2-D array or
        ``None``. ``median`` is ``None`` only when nothing could be extracted,
        in which case the bounds are ``None`` too.
    """

    def ensure_2d_time_first(arr):
        if arr is None:
            return None
        arr = np.asarray(arr)
        if arr.ndim == 1:
            return arr.reshape(-1, 1)
        if arr.ndim == 2:
            return arr
        # Keep axis 0 and flatten everything after it.
        return arr.reshape(arr.shape[0], -1)

    if isinstance(forecast, QuantileForecast):
        try:
            median_pred = ensure_2d_time_first(forecast.quantile(0.5))
        except Exception:
            logger.debug("Could not read the 0.5 quantile from the forecast", exc_info=True)
            return None, None, None

        # The bounds are optional: if either one is unavailable, drop both so
        # the caller never receives a one-sided interval.
        try:
            lower_bound = ensure_2d_time_first(forecast.quantile(0.1))
            upper_bound = ensure_2d_time_first(forecast.quantile(0.9))
        except Exception:
            logger.debug("Could not read the 0.1/0.9 quantiles from the forecast", exc_info=True)
            lower_bound, upper_bound = None, None
        return median_pred, lower_bound, upper_bound

    try:
        samples = forecast.samples
        if samples.ndim == 1:
            # Already a single trajectory.
            median_pred = samples
        elif samples.ndim == 2 and samples.shape[0] == 1:
            median_pred = samples[0]
        elif samples.ndim in (2, 3):
            # presumably axis 0 indexes the samples -- TODO confirm against the
            # forecast type produced by the caller
            median_pred = np.median(samples, axis=0)
        else:
            median_pred = samples[0] if len(samples) > 0 else samples
        return ensure_2d_time_first(median_pred), None, None
    except Exception:
        logger.debug("Could not derive a median from forecast samples", exc_info=True)
        return None, None, None


def _create_plot(
    input_data: dict,
    label_data: dict,
    forecast,
    dataset_full_name: str,
    dataset_freq: str,
    max_context_length: int,
    title: str | None = None,
):
    """Build one forecast-vs-truth figure for a single evaluation window.

    The history/future arrays come from ``_prepare_data_for_plotting``, and the
    median and optional bounds from ``_extract_quantile_predictions``. Each
    prediction array is then nudged towards the shape of the future values
    before being handed to ``plot_multivariate_timeseries``.

    Args:
        input_data: Window input passed through to ``_prepare_data_for_plotting``.
        label_data: Window label passed through to ``_prepare_data_for_plotting``.
        forecast: Forecast object passed to ``_extract_quantile_predictions``.
        dataset_full_name: Name used in log messages and the default title.
        dataset_freq: Frequency string passed to ``parse_frequency``.
        max_context_length: Forwarded to ``_prepare_data_for_plotting``.
        title: Plot title; defaults to ``"GIFT-Eval: <dataset_full_name>"``.

    Returns:
        Whatever ``plot_multivariate_timeseries`` returns, or ``None`` when no
        median could be extracted or any step raised (a warning is logged).
    """

    def fit_to_target(prediction, target):
        # Best-effort reshaping of `prediction` towards `target.shape`; the
        # result is not guaranteed to match in every case.
        if prediction is None:
            return None
        prediction = np.asarray(prediction)
        target = np.asarray(target)
        if prediction.ndim == 1:
            prediction = prediction.reshape(-1, 1)
        if target.ndim == 1:
            target = target.reshape(-1, 1)
        if prediction.shape == target.shape:
            return prediction

        # Same number of rows: only the column count can be reconciled.
        if prediction.shape[0] == target.shape[0]:
            pred_cols = prediction.shape[1]
            target_cols = target.shape[1]
            if pred_cols == 1 and target_cols > 1:
                return np.broadcast_to(prediction, target.shape)
            if pred_cols > 1 and target_cols == 1:
                return prediction[:, :1]
            return prediction

        # Same number of columns: cut the prediction to the shorter length.
        if prediction.shape[1] == target.shape[1]:
            shared_rows = min(prediction.shape[0], target.shape[0])
            return prediction[:shared_rows]

        # Nothing lines up directly: try a transpose first.
        if prediction.T.shape == target.shape:
            return prediction.T

        # Last resort: take the leading values as a single column, repeated
        # across the target's columns when it has more than one.
        horizon = target.shape[0]
        if prediction.size < horizon:
            return prediction
        column = prediction.flatten()[:horizon].reshape(-1, 1)
        if target.shape[1] > 1:
            return np.broadcast_to(column, target.shape)
        return column

    try:
        prepared = _prepare_data_for_plotting(input_data, label_data, max_context_length)
        history_values, future_values, start_timestamp = prepared

        median_pred, lower_bound, upper_bound = _extract_quantile_predictions(forecast)
        if median_pred is None:
            logger.warning(f"Could not extract predictions for {dataset_full_name}")
            return None

        median_pred = fit_to_target(median_pred, future_values)
        lower_bound = fit_to_target(lower_bound, future_values)
        upper_bound = fit_to_target(upper_bound, future_values)

        if not title:
            title = f"GIFT-Eval: {dataset_full_name}"
        frequency = parse_frequency(dataset_freq)
        return plot_multivariate_timeseries(
            history_values=history_values,
            future_values=future_values,
            predicted_values=median_pred,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            start=start_timestamp,
            frequency=frequency,
            title=title,
            show=False,
        )
    except Exception as e:
        logger.warning(f"Failed to create plot for {dataset_full_name}: {e}")
        return None


def create_plots_for_dataset(
    forecasts: list,
    test_data,
    dataset_metadata,
    max_plots: int,
    max_context_length: int,
) -> list[tuple[object, str]]:
    """Create up to ``max_plots`` figures for the leading windows of a dataset.

    Forecasts are paired positionally with ``test_data.input`` and
    ``test_data.label``. Both are iterated lazily, so only as many windows as
    will be plotted are pulled from the test split.

    Args:
        forecasts: Sequence of forecasts, one per evaluation window.
        test_data: Object exposing iterable ``input`` and ``label`` attributes.
        dataset_metadata: Object optionally exposing ``full_name`` (used in
            titles and logs; falls back to ``"dataset"``) and ``freq`` (used
            for frequency parsing and file names; falls back to ``"D"``).
        max_plots: Upper bound on the number of windows to plot.
        max_context_length: Forwarded to ``_create_plot``.

    Returns:
        List of ``(figure, filename)`` pairs for the windows that produced a
        figure. Windows that fail are logged and skipped.
    """
    # Metadata lookups do not change between windows; resolve them once.
    has_full_name = hasattr(dataset_metadata, "full_name")
    full_name = getattr(dataset_metadata, "full_name", "dataset")
    freq = getattr(dataset_metadata, "freq", "D")

    # Clamp at zero so a negative max_plots cannot turn into a tail slice below.
    num_plots = max(0, min(len(forecasts), max_plots))
    logger.info(f"Creating {num_plots} plots for {getattr(dataset_metadata, 'full_name', str(dataset_metadata))}")

    figures_with_names: list[tuple[object, str]] = []
    if num_plots == 0:
        return figures_with_names

    # zip stops once the (sliced) forecasts run out, so at most num_plots
    # windows are drawn from the test split.
    windows = zip(forecasts[:num_plots], test_data.input, test_data.label)
    windows_seen = 0
    for window_number, (forecast, input_data, label_data) in enumerate(windows, start=1):
        windows_seen = window_number
        if has_full_name:
            title = f"GIFT-Eval: {full_name} - Window {window_number}/{num_plots}"
        else:
            title = f"Window {window_number}/{num_plots}"
        try:
            fig = _create_plot(
                input_data=input_data,
                label_data=label_data,
                forecast=forecast,
                dataset_full_name=full_name,
                dataset_freq=freq,
                max_context_length=max_context_length,
                title=title,
            )
        except Exception as e:
            logger.warning(f"Error creating plot for window {window_number}: {e}")
            continue
        if fig is not None:
            figures_with_names.append((fig, f"{freq}_window_{window_number:03d}.png"))

    if windows_seen < num_plots:
        # The test split had fewer windows than forecasts requested for plotting.
        logger.warning(f"Only {windows_seen} of {num_plots} requested windows were available in test_data")
    return figures_with_names