import logging

import numpy as np
import pandas as pd
from gluonts.model.forecast import QuantileForecast

from src.data.frequency import parse_frequency
from src.plotting.plot_timeseries import (
    plot_multivariate_timeseries,
)

logger = logging.getLogger(__name__)


def _prepare_data_for_plotting(input_data: dict, label_data: dict, max_context_length: int | None):
    """Normalize history/future targets to (time, dim) arrays and resolve the start timestamp."""
    history_values = np.asarray(input_data["target"], dtype=np.float32)
    future_values = np.asarray(label_data["target"], dtype=np.float32)
    start_period = input_data["start"]

    def ensure_time_first(arr: np.ndarray) -> np.ndarray:
        # Heuristic: assume time is the longer axis of a 2D array; 1D becomes a single column.
        if arr.ndim == 1:
            return arr.reshape(-1, 1)
        elif arr.ndim == 2:
            if arr.shape[0] < arr.shape[1]:
                return arr.T
            return arr
        else:
            # Fallback for higher-rank inputs: collapse leading axes, keep the last axis as variates.
            return arr.reshape(arr.shape[-1], -1).T

    history_values = ensure_time_first(history_values)
    future_values = ensure_time_first(future_values)

    # Keep only the most recent context window if a limit is set.
    if max_context_length is not None and history_values.shape[0] > max_context_length:
        history_values = history_values[-max_context_length:]

    # Convert Period to Timestamp if needed
    start_timestamp = (
        start_period.to_timestamp() if hasattr(start_period, "to_timestamp") else pd.Timestamp(start_period)
    )
    return history_values, future_values, start_timestamp
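
# A minimal sketch of the expected shapes, kept as a comment so nothing runs at import
# time; the data below is hypothetical:
#
#   hist, fut, start = _prepare_data_for_plotting(
#       {"target": np.random.randn(3, 48), "start": pd.Period("2021-01-01", freq="D")},
#       {"target": np.random.randn(3, 24)},
#       max_context_length=36,
#   )
#   # hist.shape == (36, 3), fut.shape == (24, 3), start == pd.Timestamp("2021-01-01")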


def _extract_quantile_predictions(
    forecast,
) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
    """Return (median, lower, upper) predictions as (time, dim) arrays, or None where unavailable."""

    def ensure_2d_time_first(arr):
        if arr is None:
            return None
        arr = np.asarray(arr)
        if arr.ndim == 1:
            return arr.reshape(-1, 1)
        elif arr.ndim == 2:
            return arr
        else:
            return arr.reshape(arr.shape[0], -1)

    if isinstance(forecast, QuantileForecast):
        try:
            median_pred = forecast.quantile(0.5)
            # The 10%/90% bounds are optional; fall back to a median-only plot if missing.
            try:
                lower_bound = forecast.quantile(0.1)
                upper_bound = forecast.quantile(0.9)
            except (KeyError, ValueError):
                lower_bound = None
                upper_bound = None
            median_pred = ensure_2d_time_first(median_pred)
            lower_bound = ensure_2d_time_first(lower_bound)
            upper_bound = ensure_2d_time_first(upper_bound)
            return median_pred, lower_bound, upper_bound
        except Exception:
            # Last resort: try to recover at least the median.
            try:
                median_pred = forecast.quantile(0.5)
                median_pred = ensure_2d_time_first(median_pred)
                return median_pred, None, None
            except Exception:
                return None, None, None
    else:
        # Sample-based forecasts: reduce the sample axis to a pointwise median.
        try:
            samples = forecast.samples
            if samples.ndim == 1:
                median_pred = samples
            elif samples.ndim == 2:
                if samples.shape[0] == 1:
                    median_pred = samples[0]
                else:
                    median_pred = np.median(samples, axis=0)
            elif samples.ndim == 3:
                median_pred = np.median(samples, axis=0)
            else:
                median_pred = samples[0] if len(samples) > 0 else samples
            median_pred = ensure_2d_time_first(median_pred)
            return median_pred, None, None
        except Exception:
            return None, None, None
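
# For QuantileForecast inputs the quantiles are read directly; for anything else only a
# ``.samples`` attribute is required. A hypothetical sketch, commented out so it does not
# execute at import time:
#
#   from types import SimpleNamespace
#   fake = SimpleNamespace(samples=np.random.randn(100, 24))  # (num_samples, prediction_length)
#   median, lower, upper = _extract_quantile_predictions(fake)
#   # median.shape == (24, 1); lower and upper are None (no quantile information available)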


def _create_plot(
    input_data: dict,
    label_data: dict,
    forecast,
    dataset_full_name: str,
    dataset_freq: str,
    max_context_length: int | None,
    title: str | None = None,
):
    """Build a single forecast figure; return it, or None if the data or plotting fails."""
    try:
        history_values, future_values, start_timestamp = _prepare_data_for_plotting(
            input_data, label_data, max_context_length
        )
        median_pred, lower_bound, upper_bound = _extract_quantile_predictions(forecast)
        if median_pred is None:
            logger.warning(f"Could not extract predictions for {dataset_full_name}")
            return None

        def ensure_compatible_shape(pred_arr, target_arr):
            # Coerce predictions to the (time, dim) shape of the ground-truth targets.
            if pred_arr is None:
                return None
            pred_arr = np.asarray(pred_arr)
            target_arr = np.asarray(target_arr)
            if pred_arr.ndim == 1:
                pred_arr = pred_arr.reshape(-1, 1)
            if target_arr.ndim == 1:
                target_arr = target_arr.reshape(-1, 1)
            if pred_arr.shape != target_arr.shape:
                if pred_arr.shape[0] == target_arr.shape[0]:
                    # Same horizon: broadcast or trim the variate axis.
                    if pred_arr.shape[1] == 1 and target_arr.shape[1] > 1:
                        pred_arr = np.broadcast_to(pred_arr, target_arr.shape)
                    elif pred_arr.shape[1] > 1 and target_arr.shape[1] == 1:
                        pred_arr = pred_arr[:, :1]
                elif pred_arr.shape[1] == target_arr.shape[1]:
                    # Same variates: trim both to the shorter horizon.
                    min_time = min(pred_arr.shape[0], target_arr.shape[0])
                    pred_arr = pred_arr[:min_time]
                else:
                    if pred_arr.T.shape == target_arr.shape:
                        pred_arr = pred_arr.T
                    elif pred_arr.size >= target_arr.shape[0]:
                        # Last resort: flatten, truncate to the horizon, then broadcast.
                        pred_arr = pred_arr.flatten()[: target_arr.shape[0]].reshape(-1, 1)
                        if target_arr.shape[1] > 1:
                            pred_arr = np.broadcast_to(pred_arr, target_arr.shape)
            return pred_arr

        median_pred = ensure_compatible_shape(median_pred, future_values)
        lower_bound = ensure_compatible_shape(lower_bound, future_values)
        upper_bound = ensure_compatible_shape(upper_bound, future_values)

        title = title or f"GIFT-Eval: {dataset_full_name}"
        frequency = parse_frequency(dataset_freq)
        fig = plot_multivariate_timeseries(
            history_values=history_values,
            future_values=future_values,
            predicted_values=median_pred,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
            start=start_timestamp,
            frequency=frequency,
            title=title,
            show=False,
        )
        return fig
    except Exception as e:
        logger.warning(f"Failed to create plot for {dataset_full_name}: {e}")
        return None


def create_plots_for_dataset(
    forecasts: list,
    test_data,
    dataset_metadata,
    max_plots: int,
    max_context_length: int,
) -> list[tuple[object, str]]:
    """Create up to ``max_plots`` forecast figures for a dataset; return (figure, filename) pairs."""
    input_data_list = list(test_data.input)
    label_data_list = list(test_data.label)
    num_plots = min(len(forecasts), max_plots)
    logger.info(f"Creating {num_plots} plots for {getattr(dataset_metadata, 'full_name', str(dataset_metadata))}")

    figures_with_names: list[tuple[object, str]] = []
    for i in range(num_plots):
        try:
            forecast = forecasts[i]
            input_data = input_data_list[i]
            label_data = label_data_list[i]
            title = (
                f"GIFT-Eval: {dataset_metadata.full_name} - Window {i + 1}/{num_plots}"
                if hasattr(dataset_metadata, "full_name")
                else f"Window {i + 1}/{num_plots}"
            )
            fig = _create_plot(
                input_data=input_data,
                label_data=label_data,
                forecast=forecast,
                dataset_full_name=getattr(dataset_metadata, "full_name", "dataset"),
                dataset_freq=getattr(dataset_metadata, "freq", "D"),
                max_context_length=max_context_length,
                title=title,
            )
            if fig is not None:
                filename = f"{getattr(dataset_metadata, 'freq', 'D')}_window_{i + 1:03d}.png"
                figures_with_names.append((fig, filename))
        except Exception as e:
            logger.warning(f"Error creating plot for window {i + 1}: {e}")
            continue
    return figures_with_names
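

# The guarded block below is a minimal smoke-test sketch, not part of the module's API.
# It fakes ``test_data``, ``dataset_metadata``, and a sample-based forecast with
# SimpleNamespace stand-ins (all names here are hypothetical); the plotting call still
# goes through the real parse_frequency / plot_multivariate_timeseries imported above,
# so those project modules and matplotlib must be importable for it to run.
if __name__ == "__main__":
    from types import SimpleNamespace

    logging.basicConfig(level=logging.INFO)
    rng = np.random.default_rng(0)

    fake_test_data = SimpleNamespace(
        input=[{"target": rng.normal(size=(2, 96)), "start": pd.Period("2021-01-01", freq="D")}],
        label=[{"target": rng.normal(size=(2, 24))}],
    )
    fake_metadata = SimpleNamespace(full_name="demo/daily", freq="D")
    # Duck-typed forecast: only a ``.samples`` attribute of shape (num_samples, time, dim) is needed.
    fake_forecasts = [SimpleNamespace(samples=rng.normal(size=(100, 24, 2)))]

    figures = create_plots_for_dataset(
        forecasts=fake_forecasts,
        test_data=fake_test_data,
        dataset_metadata=fake_metadata,
        max_plots=1,
        max_context_length=64,
    )
    logger.info(f"Created {len(figures)} figure(s): {[name for _, name in figures]}")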