TempoPFN / src /plotting /gift_eval_utils.py
Vladyslav Moroshan
Apply ruff formatting
0a58567
import logging
import numpy as np
import pandas as pd
from gluonts.model.forecast import QuantileForecast
from src.data.frequency import parse_frequency
from src.plotting.plot_timeseries import (
plot_multivariate_timeseries,
)
logger = logging.getLogger(__name__)
def _prepare_data_for_plotting(input_data: dict, label_data: dict, max_context_length: int):
history_values = np.asarray(input_data["target"], dtype=np.float32)
future_values = np.asarray(label_data["target"], dtype=np.float32)
start_period = input_data["start"]
def ensure_time_first(arr: np.ndarray) -> np.ndarray:
if arr.ndim == 1:
return arr.reshape(-1, 1)
elif arr.ndim == 2:
if arr.shape[0] < arr.shape[1]:
return arr.T
return arr
else:
return arr.reshape(arr.shape[-1], -1).T
history_values = ensure_time_first(history_values)
future_values = ensure_time_first(future_values)
if max_context_length is not None and history_values.shape[0] > max_context_length:
history_values = history_values[-max_context_length:]
# Convert Period to Timestamp if needed
start_timestamp = (
start_period.to_timestamp() if hasattr(start_period, "to_timestamp") else pd.Timestamp(start_period)
)
return history_values, future_values, start_timestamp
def _extract_quantile_predictions(
forecast,
) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
def ensure_2d_time_first(arr):
if arr is None:
return None
arr = np.asarray(arr)
if arr.ndim == 1:
return arr.reshape(-1, 1)
elif arr.ndim == 2:
return arr
else:
return arr.reshape(arr.shape[0], -1)
if isinstance(forecast, QuantileForecast):
try:
median_pred = forecast.quantile(0.5)
try:
lower_bound = forecast.quantile(0.1)
upper_bound = forecast.quantile(0.9)
except (KeyError, ValueError):
lower_bound = None
upper_bound = None
median_pred = ensure_2d_time_first(median_pred)
lower_bound = ensure_2d_time_first(lower_bound)
upper_bound = ensure_2d_time_first(upper_bound)
return median_pred, lower_bound, upper_bound
except Exception:
try:
median_pred = forecast.quantile(0.5)
median_pred = ensure_2d_time_first(median_pred)
return median_pred, None, None
except Exception:
return None, None, None
else:
try:
samples = forecast.samples
if samples.ndim == 1:
median_pred = samples
elif samples.ndim == 2:
if samples.shape[0] == 1:
median_pred = samples[0]
else:
median_pred = np.median(samples, axis=0)
elif samples.ndim == 3:
median_pred = np.median(samples, axis=0)
else:
median_pred = samples[0] if len(samples) > 0 else samples
median_pred = ensure_2d_time_first(median_pred)
return median_pred, None, None
except Exception:
return None, None, None
def _create_plot(
input_data: dict,
label_data: dict,
forecast,
dataset_full_name: str,
dataset_freq: str,
max_context_length: int,
title: str | None = None,
):
try:
history_values, future_values, start_timestamp = _prepare_data_for_plotting(
input_data, label_data, max_context_length
)
median_pred, lower_bound, upper_bound = _extract_quantile_predictions(forecast)
if median_pred is None:
logger.warning(f"Could not extract predictions for {dataset_full_name}")
return None
def ensure_compatible_shape(pred_arr, target_arr):
if pred_arr is None:
return None
pred_arr = np.asarray(pred_arr)
target_arr = np.asarray(target_arr)
if pred_arr.ndim == 1:
pred_arr = pred_arr.reshape(-1, 1)
if target_arr.ndim == 1:
target_arr = target_arr.reshape(-1, 1)
if pred_arr.shape != target_arr.shape:
if pred_arr.shape[0] == target_arr.shape[0]:
if pred_arr.shape[1] == 1 and target_arr.shape[1] > 1:
pred_arr = np.broadcast_to(pred_arr, target_arr.shape)
elif pred_arr.shape[1] > 1 and target_arr.shape[1] == 1:
pred_arr = pred_arr[:, :1]
elif pred_arr.shape[1] == target_arr.shape[1]:
min_time = min(pred_arr.shape[0], target_arr.shape[0])
pred_arr = pred_arr[:min_time]
else:
if pred_arr.T.shape == target_arr.shape:
pred_arr = pred_arr.T
else:
if pred_arr.size >= target_arr.shape[0]:
pred_arr = pred_arr.flatten()[: target_arr.shape[0]].reshape(-1, 1)
if target_arr.shape[1] > 1:
pred_arr = np.broadcast_to(pred_arr, target_arr.shape)
return pred_arr
median_pred = ensure_compatible_shape(median_pred, future_values)
lower_bound = ensure_compatible_shape(lower_bound, future_values)
upper_bound = ensure_compatible_shape(upper_bound, future_values)
title = title or f"GIFT-Eval: {dataset_full_name}"
frequency = parse_frequency(dataset_freq)
fig = plot_multivariate_timeseries(
history_values=history_values,
future_values=future_values,
predicted_values=median_pred,
lower_bound=lower_bound,
upper_bound=upper_bound,
start=start_timestamp,
frequency=frequency,
title=title,
show=False,
)
return fig
except Exception as e:
logger.warning(f"Failed to create plot for {dataset_full_name}: {e}")
return None
def create_plots_for_dataset(
forecasts: list,
test_data,
dataset_metadata,
max_plots: int,
max_context_length: int,
) -> list[tuple[object, str]]:
input_data_list = list(test_data.input)
label_data_list = list(test_data.label)
num_plots = min(len(forecasts), max_plots)
logger.info(f"Creating {num_plots} plots for {getattr(dataset_metadata, 'full_name', str(dataset_metadata))}")
figures_with_names: list[tuple[object, str]] = []
for i in range(num_plots):
try:
forecast = forecasts[i]
input_data = input_data_list[i]
label_data = label_data_list[i]
title = (
f"GIFT-Eval: {dataset_metadata.full_name} - Window {i + 1}/{num_plots}"
if hasattr(dataset_metadata, "full_name")
else f"Window {i + 1}/{num_plots}"
)
fig = _create_plot(
input_data=input_data,
label_data=label_data,
forecast=forecast,
dataset_full_name=getattr(dataset_metadata, "full_name", "dataset"),
dataset_freq=getattr(dataset_metadata, "freq", "D"),
max_context_length=max_context_length,
title=title,
)
if fig is not None:
filename = f"{getattr(dataset_metadata, 'freq', 'D')}_window_{i + 1:03d}.png"
figures_with_names.append((fig, filename))
except Exception as e:
logger.warning(f"Error creating plot for window {i + 1}: {e}")
continue
return figures_with_names