|
|
import logging |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from gluonts.model.forecast import QuantileForecast |
|
|
|
|
|
from src.data.frequency import parse_frequency |
|
|
from src.plotting.plot_timeseries import ( |
|
|
plot_multivariate_timeseries, |
|
|
) |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def _prepare_data_for_plotting(input_data: dict, label_data: dict, max_context_length: int): |
|
|
history_values = np.asarray(input_data["target"], dtype=np.float32) |
|
|
future_values = np.asarray(label_data["target"], dtype=np.float32) |
|
|
start_period = input_data["start"] |
|
|
|
|
|
def ensure_time_first(arr: np.ndarray) -> np.ndarray: |
|
|
if arr.ndim == 1: |
|
|
return arr.reshape(-1, 1) |
|
|
elif arr.ndim == 2: |
|
|
if arr.shape[0] < arr.shape[1]: |
|
|
return arr.T |
|
|
return arr |
|
|
else: |
|
|
return arr.reshape(arr.shape[-1], -1).T |
|
|
|
|
|
history_values = ensure_time_first(history_values) |
|
|
future_values = ensure_time_first(future_values) |
|
|
|
|
|
if max_context_length is not None and history_values.shape[0] > max_context_length: |
|
|
history_values = history_values[-max_context_length:] |
|
|
|
|
|
|
|
|
start_timestamp = ( |
|
|
start_period.to_timestamp() if hasattr(start_period, "to_timestamp") else pd.Timestamp(start_period) |
|
|
) |
|
|
return history_values, future_values, start_timestamp |
|
|
|
|
|
|
|
|
def _extract_quantile_predictions(
    forecast,
) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
    """Pull a median prediction and optional 10%/90% bounds out of a forecast.

    For a ``QuantileForecast`` the 0.5 quantile is the median and the 0.1 / 0.9
    quantiles are the bounds; if the bounds raise ``KeyError`` or ``ValueError``
    they are returned as ``None``. For any other forecast object the median is
    derived from ``forecast.samples`` and no bounds are produced.

    This is deliberately best-effort: extraction failures never propagate.
    They are logged at debug level (with traceback) so the cause stays
    discoverable.

    Args:
        forecast: A ``QuantileForecast`` or an object exposing ``samples``.

    Returns:
        ``(median, lower, upper)``. Each entry is an array with at least two
        dimensions or ``None``; all three are ``None`` when nothing could be
        extracted.
    """

    def ensure_2d_time_first(arr):
        # Promote to at least 2-D; existing 2-D layouts are left untouched.
        if arr is None:
            return None
        arr = np.asarray(arr)
        if arr.ndim == 1:
            return arr.reshape(-1, 1)
        if arr.ndim == 2:
            return arr
        return arr.reshape(arr.shape[0], -1)

    if isinstance(forecast, QuantileForecast):
        try:
            median_pred = forecast.quantile(0.5)
            try:
                lower_bound = forecast.quantile(0.1)
                upper_bound = forecast.quantile(0.9)
            except (KeyError, ValueError):
                # Bounds are optional: plot the median alone.
                lower_bound = None
                upper_bound = None
            return (
                ensure_2d_time_first(median_pred),
                ensure_2d_time_first(lower_bound),
                ensure_2d_time_first(upper_bound),
            )
        except Exception:
            logger.debug(
                "Quantile extraction failed; retrying with the median only",
                exc_info=True,
            )
            try:
                return ensure_2d_time_first(forecast.quantile(0.5)), None, None
            except Exception:
                logger.debug("Median-only quantile extraction failed", exc_info=True)
                return None, None, None

    try:
        # asarray so list-like sample containers are handled like arrays.
        samples = np.asarray(forecast.samples)
        if samples.ndim == 1:
            median_pred = samples
        elif samples.ndim == 2:
            # presumably axis 0 indexes samples — TODO confirm against the
            # forecast producer. A single sample is its own median.
            median_pred = samples[0] if samples.shape[0] == 1 else np.median(samples, axis=0)
        elif samples.ndim == 3:
            median_pred = np.median(samples, axis=0)
        else:
            median_pred = samples[0] if len(samples) > 0 else samples
        return ensure_2d_time_first(median_pred), None, None
    except Exception:
        logger.debug("Sample-based median extraction failed", exc_info=True)
        return None, None, None
|
|
|
|
|
|
|
|
def _create_plot(
    input_data: dict,
    label_data: dict,
    forecast,
    dataset_full_name: str,
    dataset_freq: str,
    max_context_length: int,
    title: str | None = None,
):
    """Build the figure for a single forecast window.

    Prepares history/future arrays, extracts the median and optional bounds
    from ``forecast``, reshapes the predictions towards the shape of the
    future values, and passes everything to ``plot_multivariate_timeseries``
    with ``show=False``.

    Args:
        input_data: History window, forwarded to ``_prepare_data_for_plotting``.
        label_data: Ground-truth window, forwarded likewise.
        forecast: Forecast object handed to ``_extract_quantile_predictions``.
        dataset_full_name: Name used in log messages and the default title.
        dataset_freq: Frequency string handed to ``parse_frequency``.
        max_context_length: History cap forwarded to the data preparation step.
        title: Plot title; defaults to ``"GIFT-Eval: <dataset_full_name>"``.

    Returns:
        Whatever ``plot_multivariate_timeseries`` returns, or ``None`` when no
        median could be extracted or any step raised (a warning is logged).
    """

    def match_target_shape(prediction, target):
        # Best-effort reshaping of `prediction` towards `target`'s shape; the
        # prediction is returned as-is when no rule below applies.
        if prediction is None:
            return None
        pred = np.asarray(prediction)
        ref = np.asarray(target)
        if pred.ndim == 1:
            pred = pred.reshape(-1, 1)
        if ref.ndim == 1:
            ref = ref.reshape(-1, 1)
        if pred.shape == ref.shape:
            return pred

        pred_steps, ref_steps = pred.shape[0], ref.shape[0]
        pred_cols, ref_cols = pred.shape[1], ref.shape[1]

        if pred_steps == ref_steps:
            # Same length along axis 0: only the column count differs.
            if pred_cols == 1 and ref_cols > 1:
                return np.broadcast_to(pred, ref.shape)
            if pred_cols > 1 and ref_cols == 1:
                return pred[:, :1]
            return pred

        if pred_cols == ref_cols:
            # Same columns, different length: trim to the shorter of the two.
            return pred[: min(pred_steps, ref_steps)]

        if pred.T.shape == ref.shape:
            return pred.T

        if pred.size >= ref_steps:
            # Last resort: take the first `ref_steps` flattened values as one
            # column and repeat it across the target's columns if needed.
            column = pred.flatten()[:ref_steps].reshape(-1, 1)
            if ref_cols > 1:
                return np.broadcast_to(column, ref.shape)
            return column
        return pred

    try:
        history, future, start_ts = _prepare_data_for_plotting(input_data, label_data, max_context_length)
        point_forecast, band_low, band_high = _extract_quantile_predictions(forecast)
        if point_forecast is None:
            logger.warning(f"Could not extract predictions for {dataset_full_name}")
            return None

        point_forecast, band_low, band_high = [
            match_target_shape(candidate, future) for candidate in (point_forecast, band_low, band_high)
        ]

        plot_title = title or f"GIFT-Eval: {dataset_full_name}"
        plot_frequency = parse_frequency(dataset_freq)
        return plot_multivariate_timeseries(
            history_values=history,
            future_values=future,
            predicted_values=point_forecast,
            lower_bound=band_low,
            upper_bound=band_high,
            start=start_ts,
            frequency=plot_frequency,
            title=plot_title,
            show=False,
        )
    except Exception as e:
        # Plotting is best-effort: one bad window must not abort the caller.
        logger.warning(f"Failed to create plot for {dataset_full_name}: {e}")
        return None
|
|
|
|
|
|
|
|
def create_plots_for_dataset(
    forecasts: list,
    test_data,
    dataset_metadata,
    max_plots: int,
    max_context_length: int,
) -> list[tuple[object, str]]:
    """Create up to ``max_plots`` forecast figures for one dataset.

    Forecasts are paired positionally with the windows yielded by
    ``test_data.input`` and ``test_data.label``. Only as many windows as will
    be plotted are pulled from those iterables.

    Args:
        forecasts: Forecast objects, one per evaluation window.
        test_data: Object exposing iterable ``input`` and ``label`` attributes.
        dataset_metadata: Object optionally exposing ``full_name`` (used in
            titles and logs) and ``freq`` (used for parsing and filenames,
            default ``"D"``).
        max_plots: Upper bound on the number of figures; values ``<= 0``
            produce no figures.
        max_context_length: History cap forwarded to ``_create_plot``.

    Returns:
        ``(figure, filename)`` pairs for every window whose plot succeeded.
        Filenames follow ``"<freq>_window_<NNN>.png"`` with a 1-based index.
    """
    from itertools import islice

    has_full_name = hasattr(dataset_metadata, "full_name")
    dataset_full_name = dataset_metadata.full_name if has_full_name else "dataset"
    display_name = dataset_metadata.full_name if has_full_name else str(dataset_metadata)
    dataset_freq = getattr(dataset_metadata, "freq", "D")

    # Clamp at zero: islice rejects negative stop values.
    requested = max(0, min(len(forecasts), max_plots))
    # Pull only the windows that can be plotted instead of materialising the
    # whole test set.
    input_data_list = list(islice(test_data.input, requested))
    label_data_list = list(islice(test_data.label, requested))
    num_plots = min(requested, len(input_data_list), len(label_data_list))
    if num_plots < requested:
        logger.warning(
            "Only %d of %d requested windows have input and label data for %s",
            num_plots,
            requested,
            display_name,
        )
    logger.info("Creating %s plots for %s", num_plots, display_name)

    figures_with_names: list[tuple[object, str]] = []
    windows = zip(forecasts, input_data_list, label_data_list)
    for window, (forecast, input_data, label_data) in enumerate(windows, start=1):
        try:
            if has_full_name:
                title = f"GIFT-Eval: {dataset_full_name} - Window {window}/{num_plots}"
            else:
                title = f"Window {window}/{num_plots}"
            fig = _create_plot(
                input_data=input_data,
                label_data=label_data,
                forecast=forecast,
                dataset_full_name=dataset_full_name,
                dataset_freq=dataset_freq,
                max_context_length=max_context_length,
                title=title,
            )
            if fig is not None:
                figures_with_names.append((fig, f"{dataset_freq}_window_{window:03d}.png"))
        except Exception as e:
            # Best-effort: a failing window is reported and skipped.
            logger.warning("Error creating plot for window %s: %s", window, e)
    return figures_with_names
|
|
|