|
|
import logging |
|
|
import os |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import yaml |
|
|
from src.data.containers import BatchTimeSeriesContainer |
|
|
from src.models.model import TimeSeriesModel |
|
|
from src.plotting.plot_timeseries import plot_from_container |
|
|
|
|
|
# Module-level logger named after this module; load_model reports successful checkpoint loads through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
def load_model(config_path: str, model_path: str, device: torch.device) -> TimeSeriesModel:
    """Build a TimeSeriesModel from a YAML config and load its checkpoint weights.

    Args:
        config_path: Path to a YAML file whose top-level ``"TimeSeriesModel"``
            mapping is passed as keyword arguments to ``TimeSeriesModel``.
        model_path: Path to a ``torch.load``-compatible checkpoint that contains
            a ``"model_state_dict"`` entry.
        device: Device the model is moved to; checkpoint tensors are mapped onto
            it via ``map_location``.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.

    Raises:
        KeyError: If the config has no ``"TimeSeriesModel"`` section or the
            checkpoint has no ``"model_state_dict"`` entry.
    """
    # Explicit encoding so parsing does not depend on the platform's locale.
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)

    # NOTE(security): torch.load is pickle-based. With weights_only=False (the
    # default before PyTorch 2.6) it can execute arbitrary code, so only load
    # checkpoints from trusted sources. From 2.6 the default is weights_only=True,
    # which rejects non-allowlisted objects — confirm this checkpoint's contents
    # before pinning either setting explicitly.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info("Successfully loaded TimeSeriesModel from %s on %s", model_path, device)
    return model
|
|
|
|
|
|
|
|
def plot_with_library(
    container: BatchTimeSeriesContainer,
    predictions_np: np.ndarray,
    model_quantiles: list[float] | None,
    output_dir: str = "outputs",
    show_plots: bool = True,
    save_plots: bool = True,
) -> None:
    """Draw one prediction figure per sample in ``container``.

    Each sample is rendered with ``plot_from_container``, which receives the
    full prediction array plus the sample index to select.

    Args:
        container: Batch to plot; ``container.batch_size`` figures are drawn.
        predictions_np: Predictions for the whole batch, passed unchanged to
            every ``plot_from_container`` call (presumably indexed by sample
            along the leading axis — confirm against the plotting helper).
        model_quantiles: Quantile levels forwarded to the plotting helper;
            may be ``None``.
        output_dir: Directory that receives
            ``sine_wave_prediction_sample_<k>.png`` files (``k`` is 1-based).
        show_plots: Forwarded as ``show`` to the plotting helper.
        save_plots: When False, no output file is passed to the helper and
            ``output_dir`` is not created.
    """
    # Only touch the filesystem when figures will actually be written.
    if save_plots:
        os.makedirs(output_dir, exist_ok=True)

    for idx in range(container.batch_size):
        sample_no = idx + 1  # 1-based numbering for file names and titles
        output_file = (
            os.path.join(output_dir, f"sine_wave_prediction_sample_{sample_no}.png")
            if save_plots
            else None
        )
        plot_from_container(
            batch=container,
            sample_idx=idx,
            predicted_values=predictions_np,
            model_quantiles=model_quantiles,
            title=f"Sine Wave Time Series Prediction - Sample {sample_no}",
            output_file=output_file,
            show=show_plots,
        )
|
|
|
|
|
|
|
|
def run_inference_and_plot(
    model: TimeSeriesModel,
    container: BatchTimeSeriesContainer,
    output_dir: str = "outputs",
    use_bfloat16: bool = True,
    show_plots: bool = True,
    save_plots: bool = True,
) -> None:
    """Run model inference with optional bfloat16 autocast and plot the predictions.

    Args:
        model: Model called as ``model(container)``; its output mapping must
            contain ``"result"`` and may contain ``"scale_statistics"``.
        container: Input batch; the device of ``container.history_values``
            selects the autocast device type.
        output_dir: Directory for saved figures, forwarded to ``plot_with_library``.
        use_bfloat16: Request bfloat16 autocast. It only takes effect when
            ``history_values`` is on a CUDA device.
        show_plots: Forwarded to ``plot_with_library``. The default keeps the
            previous always-show behavior.
        save_plots: Forwarded to ``plot_with_library``. The default keeps the
            previous always-save behavior.
    """
    device_type = "cuda" if container.history_values.device.type == "cuda" else "cpu"
    # bfloat16 autocast is only enabled on CUDA; CPU runs stay in full precision.
    autocast_enabled = use_bfloat16 and device_type == "cuda"

    with (
        torch.no_grad(),
        torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled),
    ):
        model_output = model(container)

    # Upcast before post-processing: NumPy has no bfloat16 dtype, so .numpy()
    # on a bfloat16 tensor would fail.
    preds_full = model_output["result"].to(torch.float32)
    if hasattr(model, "scaler") and "scale_statistics" in model_output:
        # Presumably maps predictions back to the original data scale — the
        # exact transform lives in model.scaler.inverse_scale.
        preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])

    preds_np = preds_full.detach().cpu().numpy()

    # Quantile levels are only forwarded for quantile-loss models.
    model_quantiles = model.quantiles if getattr(model, "loss_type", None) == "quantile" else None

    plot_with_library(
        container=container,
        predictions_np=preds_np,
        model_quantiles=model_quantiles,
        output_dir=output_dir,
        show_plots=show_plots,
        save_plots=save_plots,
    )
|
|
|