# TempoPFN / examples/utils.py
# Vladyslav Moroshan
# Apply ruff formatting
# 0a58567
import logging
import os
import numpy as np
import torch
import yaml
from src.data.containers import BatchTimeSeriesContainer
from src.models.model import TimeSeriesModel
from src.plotting.plot_timeseries import plot_from_container
logger = logging.getLogger(__name__)
def load_model(config_path: str, model_path: str, device: torch.device) -> TimeSeriesModel:
    """Load the TimeSeriesModel from a YAML config and a checkpoint file.

    Args:
        config_path: Path to a YAML file; its top-level ``TimeSeriesModel``
            entry is unpacked as keyword arguments of the model constructor.
        model_path: Path to a checkpoint readable by ``torch.load``; its
            ``model_state_dict`` entry is loaded into the model.
        device: Device the model is moved to and onto which the checkpoint
            tensors are mapped.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.

    Raises:
        KeyError: If the parsed config has no ``TimeSeriesModel`` entry or the
            checkpoint has no ``model_state_dict`` entry.
    """
    # Read the config as UTF-8 explicitly so parsing does not depend on the
    # platform's locale-default encoding.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)

    # NOTE(security): torch.load relies on pickle. Depending on the installed
    # PyTorch version's `weights_only` default, loading a checkpoint can execute
    # arbitrary code embedded in the file, so only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    logger.info(f"Successfully loaded TimeSeriesModel from {model_path} on {device}")
    return model
def plot_with_library(
    container: BatchTimeSeriesContainer,
    predictions_np: np.ndarray,  # [B, P, N, Q]
    model_quantiles: list[float] | None,
    output_dir: str = "outputs",
    show_plots: bool = True,
    save_plots: bool = True,
) -> None:
    """Plot every sample of a batch via ``plot_from_container``.

    One call to ``plot_from_container`` is made per sample index in
    ``range(container.batch_size)``; the full prediction array is passed each
    time together with the sample index.

    Args:
        container: Batch whose samples are plotted.
        predictions_np: Model predictions for the whole batch (annotated in
            the signature as ``[B, P, N, Q]``).
        model_quantiles: Quantile levels forwarded to the plotting function,
            or ``None``.
        output_dir: Directory that receives the PNG files when saving.
        show_plots: Forwarded as ``show`` to the plotting function.
        save_plots: When true, each figure is written to
            ``sine_wave_prediction_sample_<k>.png`` (1-based ``k``) inside
            ``output_dir``; when false, no output path is passed.
    """
    # Only touch the filesystem when figures are actually going to be written;
    # otherwise a display-only call would leave an empty directory behind.
    if save_plots:
        os.makedirs(output_dir, exist_ok=True)

    for i in range(container.batch_size):
        output_file = (
            os.path.join(output_dir, f"sine_wave_prediction_sample_{i + 1}.png")
            if save_plots
            else None
        )
        plot_from_container(
            batch=container,
            sample_idx=i,
            predicted_values=predictions_np,
            model_quantiles=model_quantiles,
            title=f"Sine Wave Time Series Prediction - Sample {i + 1}",
            output_file=output_file,
            show=show_plots,
        )
def run_inference_and_plot(
    model: TimeSeriesModel,
    container: BatchTimeSeriesContainer,
    output_dir: str = "outputs",
    use_bfloat16: bool = True,
) -> None:
    """Run a forward pass on ``container`` and plot the predictions.

    The forward pass runs without gradient tracking. bfloat16 autocast is
    enabled only when ``use_bfloat16`` is set and the container's history
    tensor lives on a CUDA device. The ``"result"`` entry of the model output
    is cast to float32, inverse-scaled when the model exposes a ``scaler`` and
    the output carries ``"scale_statistics"``, converted to NumPy, and handed
    to ``plot_with_library`` with showing and saving both enabled.

    Args:
        model: Model to evaluate; called directly on ``container``.
        container: Input batch; also passed on to the plotting helper.
        output_dir: Directory forwarded to the plotting helper.
        use_bfloat16: Request bfloat16 autocast (honoured on CUDA only).
    """
    on_cuda = container.history_values.device.type == "cuda"
    device_type = "cuda" if on_cuda else "cpu"
    autocast_enabled = use_bfloat16 and on_cuda

    with torch.no_grad():
        with torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled):
            model_output = model(container)

    # Bring predictions back to full precision before any post-processing.
    preds_full = model_output["result"].to(torch.float32)

    # Undo the model's input scaling when both the scaler and its statistics exist.
    if hasattr(model, "scaler") and "scale_statistics" in model_output:
        scale_stats = model_output["scale_statistics"]
        preds_full = model.scaler.inverse_scale(preds_full, scale_stats)

    preds_np = preds_full.detach().cpu().numpy()

    # Quantile levels are only forwarded for models whose loss_type is "quantile".
    if getattr(model, "loss_type", None) == "quantile":
        model_quantiles = model.quantiles
    else:
        model_quantiles = None

    plot_with_library(
        container=container,
        predictions_np=preds_np,
        model_quantiles=model_quantiles,
        output_dir=output_dir,
        show_plots=True,
        save_plots=True,
    )