"""Load a trained TimeSeriesModel, run inference, and plot sine-wave predictions."""

import logging
import os

import numpy as np
import torch
import yaml
from src.data.containers import BatchTimeSeriesContainer
from src.models.model import TimeSeriesModel
from src.plotting.plot_timeseries import plot_from_container

logger = logging.getLogger(__name__)


def load_model(config_path: str, model_path: str, device: torch.device) -> TimeSeriesModel:
    """Load the TimeSeriesModel from config and checkpoint."""
    with open(config_path) as f:
        config = yaml.safe_load(f)

    model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    logger.info(f"Successfully loaded TimeSeriesModel from {model_path} on {device}")
    return model


def plot_with_library(
    container: BatchTimeSeriesContainer,
    predictions_np: np.ndarray,  # [B, P, N, Q]
    model_quantiles: list[float] | None,
    output_dir: str = "outputs",
    show_plots: bool = True,
    save_plots: bool = True,
):
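    """Plot each sample in the batch via plot_from_container, optionally saving figures."""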
    os.makedirs(output_dir, exist_ok=True)
    batch_size = container.batch_size
    for i in range(batch_size):
        output_file = os.path.join(output_dir, f"sine_wave_prediction_sample_{i + 1}.png") if save_plots else None
        plot_from_container(
            batch=container,
            sample_idx=i,
            predicted_values=predictions_np,
            model_quantiles=model_quantiles,
            title=f"Sine Wave Time Series Prediction - Sample {i + 1}",
            output_file=output_file,
            show=show_plots,
        )


def run_inference_and_plot(
    model: TimeSeriesModel,
    container: BatchTimeSeriesContainer,
    output_dir: str = "outputs",
    use_bfloat16: bool = True,
) -> None:
    """Run model inference with optional bfloat16 and plot using shared utilities."""
    # bfloat16 autocast is only enabled on CUDA; otherwise run in full precision.
    device_type = "cuda" if container.history_values.device.type == "cuda" else "cpu"
    autocast_enabled = use_bfloat16 and device_type == "cuda"
    with (
        torch.no_grad(),
        torch.autocast(device_type=device_type, dtype=torch.bfloat16, enabled=autocast_enabled),
    ):
        model_output = model(container)

    # Autocast can leave predictions in bfloat16; cast back to float32 and
    # undo any input scaling so the plots are in the original units.
    preds_full = model_output["result"].to(torch.float32)
    if hasattr(model, "scaler") and "scale_statistics" in model_output:
        preds_full = model.scaler.inverse_scale(preds_full, model_output["scale_statistics"])

    preds_np = preds_full.detach().cpu().numpy()
    model_quantiles = model.quantiles if getattr(model, "loss_type", None) == "quantile" else None
    plot_with_library(
        container=container,
        predictions_np=preds_np,
        model_quantiles=model_quantiles,
        output_dir=output_dir,
        show_plots=True,
        save_plots=True,
    )
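

# Example entry point: a minimal sketch assuming hypothetical config/checkpoint
# paths ("configs/model.yaml", "checkpoints/model.pt"); adjust these to your setup.
# Constructing a BatchTimeSeriesContainer is repo-specific, so it is left as a stub.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model(
        config_path="configs/model.yaml",  # hypothetical path
        model_path="checkpoints/model.pt",  # hypothetical path
        device=device,
    )
    # container = BatchTimeSeriesContainer(...)  # build from your data pipeline
    # run_inference_and_plot(model, container, output_dir="outputs")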