import logging

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchmetrics
from matplotlib.figure import Figure

from src.data.containers import BatchTimeSeriesContainer
from src.data.frequency import Frequency

logger = logging.getLogger(__name__)


def calculate_smape(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Calculate Symmetric Mean Absolute Percentage Error (SMAPE).

    Both arrays are converted to float tensors and handed to
    ``torchmetrics.SymmetricMeanAbsolutePercentageError`` in
    ``(preds, target)`` order.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values.

    Returns:
        The SMAPE value as a Python float.
    """
    pred_tensor = torch.from_numpy(y_pred).float()
    true_tensor = torch.from_numpy(y_true).float()
    return torchmetrics.SymmetricMeanAbsolutePercentageError()(pred_tensor, true_tensor).item()


def _resolve_pandas_freq(frequency: Frequency | str):
    """Return a value usable as the ``freq`` argument of ``pd.date_range``.

    Objects exposing ``to_pandas_freq`` (the project's ``Frequency``) are
    converted through that method; anything else (e.g. a plain ``str``) is
    passed through unchanged and is assumed to already be a pandas
    frequency alias.
    """
    to_pandas_freq = getattr(frequency, "to_pandas_freq", None)
    if callable(to_pandas_freq):
        return to_pandas_freq(for_date_range=True)
    return frequency


def _create_date_ranges(
    start: np.datetime64 | pd.Timestamp | None,
    frequency: Frequency | str | None,
    history_length: int,
    prediction_length: int,
) -> tuple[pd.DatetimeIndex, pd.DatetimeIndex]:
    """Create date ranges for history and future periods.

    When both ``start`` and ``frequency`` are given, the history index starts
    at ``start`` and the future index continues one frequency step after the
    last history timestamp. Otherwise a daily index ending at the current
    time is used as a fallback.

    Args:
        start: Timestamp of the first history observation, or ``None``.
        frequency: ``Frequency`` object or pandas frequency alias, or ``None``.
        history_length: Number of history timestamps to generate.
        prediction_length: Number of future timestamps to generate; values
            ``<= 0`` yield an empty future index.

    Returns:
        ``(history_dates, future_dates)``.
    """
    if start is not None and frequency is not None:
        start_timestamp = pd.Timestamp(start)
        # The signature allows a plain string frequency, which has no
        # ``to_pandas_freq`` method; resolve both forms uniformly.
        pandas_freq = _resolve_pandas_freq(frequency)
        history_dates = pd.date_range(start=start_timestamp, periods=history_length, freq=pandas_freq)

        if prediction_length > 0:
            next_timestamp = history_dates[-1] + pd.tseries.frequencies.to_offset(pandas_freq)
            future_dates = pd.date_range(start=next_timestamp, periods=prediction_length, freq=pandas_freq)
        else:
            future_dates = pd.DatetimeIndex([])
    else:
        # Fallback to default daily frequency
        history_dates = pd.date_range(end=pd.Timestamp.now(), periods=history_length, freq="D")

        if prediction_length > 0:
            future_dates = pd.date_range(
                start=history_dates[-1] + pd.Timedelta(days=1),
                periods=prediction_length,
                freq="D",
            )
        else:
            future_dates = pd.DatetimeIndex([])

    return history_dates, future_dates


def _plot_single_channel(
    ax: plt.Axes,
    channel_idx: int,
    history_dates: pd.DatetimeIndex,
    future_dates: pd.DatetimeIndex,
    history_values: np.ndarray,
    future_values: np.ndarray | None = None,
    predicted_values: np.ndarray | None = None,
    lower_bound: np.ndarray | None = None,
    upper_bound: np.ndarray | None = None,
) -> None:
    """Plot a single channel's time series data.

    Every value array is indexed as ``values[:, channel_idx]``, i.e. the
    second axis selects the channel.

    Args:
        ax: Axes to draw on.
        channel_idx: Zero-based channel (column) to plot.
        history_dates: X values for the history line.
        future_dates: X values for ground truth, prediction and band.
        history_values: History observations.
        future_values: Optional ground-truth future values.
        predicted_values: Optional predictions (labelled as the median).
        lower_bound: Optional lower edge of the uncertainty band.
        upper_bound: Optional upper edge of the uncertainty band.
    """
    # Plot history
    ax.plot(history_dates, history_values[:, channel_idx], color="black", label="History")

    # Plot ground truth future
    if future_values is not None:
        ax.plot(
            future_dates,
            future_values[:, channel_idx],
            color="blue",
            label="Ground Truth",
        )

    # Plot predictions
    if predicted_values is not None:
        ax.plot(
            future_dates,
            predicted_values[:, channel_idx],
            color="orange",
            linestyle="--",
            label="Prediction (Median)",
        )

    # Plot uncertainty band (only when both edges are available)
    if lower_bound is not None and upper_bound is not None:
        ax.fill_between(
            future_dates,
            lower_bound[:, channel_idx],
            upper_bound[:, channel_idx],
            color="orange",
            alpha=0.2,
            label="Uncertainty Band",
        )

    ax.set_title(f"Channel {channel_idx + 1}")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)


def _setup_figure(num_channels: int) -> tuple[Figure, list[plt.Axes]]:
    """Create and configure the matplotlib figure and axes.

    One subplot row is created per channel, all sharing the x axis.

    Args:
        num_channels: Number of subplot rows.

    Returns:
        ``(fig, axes)`` where ``axes`` is always a plain list with one
        ``Axes`` per channel.
    """
    # squeeze=False always yields a 2-D array of Axes, so the single-channel
    # and multi-channel cases are normalised to a list the same way.
    fig, axes_grid = plt.subplots(
        num_channels,
        1,
        figsize=(15, 3 * num_channels),
        sharex=True,
        squeeze=False,
    )
    return fig, list(axes_grid[:, 0])


def _finalize_plot(
    fig: Figure,
    axes: list[plt.Axes],
    title: str | None = None,
    smape_value: float | None = None,
    output_file: str | None = None,
    show: bool = True,
) -> None:
    """Add legend, title, and save/show the plot.

    Args:
        fig: Figure to finalize.
        axes: Axes of the figure; the legend is built from the first one.
        title: Optional figure title. The SMAPE suffix is only appended when
            a title is given.
        smape_value: Optional SMAPE value appended to the title.
        output_file: If truthy, path the figure is saved to (300 dpi).
        show: If true, display the figure; otherwise close it.
    """
    # Create legend from first axis
    handles, labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper right")

    # Set title with optional SMAPE
    if title:
        if smape_value is not None:
            title = f"{title} | SMAPE: {smape_value:.4f}"
        fig.suptitle(title, fontsize=16)

    # Adjust layout; operate on ``fig`` explicitly rather than on pyplot's
    # implicit "current figure", which may be a different figure.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95] if title else None)

    # Save and/or show
    if output_file:
        fig.savefig(output_file, dpi=300)
    if show:
        plt.show()
    else:
        plt.close(fig)


def _ensure_2d(values: np.ndarray | None) -> np.ndarray | None:
    """Promote a 1-D series to a single-channel column; pass others through."""
    if values is not None and values.ndim == 1:
        return values[:, np.newaxis]
    return values


def plot_multivariate_timeseries(
    history_values: np.ndarray,
    future_values: np.ndarray | None = None,
    predicted_values: np.ndarray | None = None,
    start: np.datetime64 | pd.Timestamp | None = None,
    frequency: Frequency | str | None = None,
    title: str | None = None,
    output_file: str | None = None,
    show: bool = True,
    lower_bound: np.ndarray | None = None,
    upper_bound: np.ndarray | None = None,
) -> Figure:
    """Plot a multivariate time series with history, future, predictions, and uncertainty bands.

    Arrays are indexed as ``(time, channel)``. One-dimensional arrays are
    treated as a single channel.

    Args:
        history_values: History observations.
        future_values: Optional ground-truth future values.
        predicted_values: Optional point predictions.
        start: Timestamp of the first history observation, or ``None``.
        frequency: ``Frequency`` object or pandas frequency alias, or ``None``.
        title: Optional figure title.
        output_file: Optional path to save the figure to.
        show: Whether to display the figure (otherwise it is closed).
        lower_bound: Optional lower edge of the uncertainty band.
        upper_bound: Optional upper edge of the uncertainty band.

    Returns:
        The created matplotlib ``Figure``.
    """
    # Accept univariate (1-D) inputs by promoting them to one column.
    history_values = _ensure_2d(history_values)
    future_values = _ensure_2d(future_values)
    predicted_values = _ensure_2d(predicted_values)
    lower_bound = _ensure_2d(lower_bound)
    upper_bound = _ensure_2d(upper_bound)

    # Calculate SMAPE if both predicted and true values are available
    smape_value = None
    if predicted_values is not None and future_values is not None:
        try:
            smape_value = calculate_smape(future_values, predicted_values)
        except Exception as e:
            # Best effort: the metric is decorative, so never let it block
            # the plot itself.
            logger.warning("Failed to calculate SMAPE: %s", e)

    # Extract dimensions
    num_channels = history_values.shape[1]
    history_length = history_values.shape[0]
    # The prediction horizon comes from the predictions when present,
    # otherwise from the ground truth, otherwise there is no future segment.
    if predicted_values is not None:
        prediction_length = predicted_values.shape[0]
    elif future_values is not None:
        prediction_length = future_values.shape[0]
    else:
        prediction_length = 0

    # Create date ranges
    history_dates, future_dates = _create_date_ranges(start, frequency, history_length, prediction_length)

    # Setup figure
    fig, axes = _setup_figure(num_channels)

    # Plot each channel
    for i in range(num_channels):
        _plot_single_channel(
            ax=axes[i],
            channel_idx=i,
            history_dates=history_dates,
            future_dates=future_dates,
            history_values=history_values,
            future_values=future_values,
            predicted_values=predicted_values,
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    # Finalize plot
    _finalize_plot(fig, axes, title, smape_value, output_file, show)

    return fig


def _extract_quantile_predictions(
    predicted_values: np.ndarray,
    model_quantiles: list[float],
) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]:
    """Extract median, lower, and upper bound predictions from quantile output.

    The last axis of ``predicted_values`` is indexed by quantile position.
    The 0.5, 0.1 and 0.9 quantiles are looked up by exact value in
    ``model_quantiles``.

    Args:
        predicted_values: Quantile predictions with quantiles on the last axis.
        model_quantiles: Quantile levels, in the order of the last axis.

    Returns:
        ``(median, lower, upper)``. If any of the three quantiles is missing
        (or its index is out of range), the middle slice of the last axis is
        returned as the median and both bounds are ``None``.
    """
    try:
        median_idx = model_quantiles.index(0.5)
        lower_idx = model_quantiles.index(0.1)
        upper_idx = model_quantiles.index(0.9)

        median_preds = predicted_values[..., median_idx]
        lower_bound = predicted_values[..., lower_idx]
        upper_bound = predicted_values[..., upper_idx]

        return median_preds, lower_bound, upper_bound
    except (ValueError, IndexError):
        logger.warning("Could not find 0.1, 0.5, 0.9 quantiles for plotting. Using median of available quantiles.")
        median_preds = predicted_values[..., predicted_values.shape[-1] // 2]
        return median_preds, None, None


def plot_from_container(
    batch: BatchTimeSeriesContainer,
    sample_idx: int,
    predicted_values: np.ndarray | None = None,
    model_quantiles: list[float] | None = None,
    title: str | None = None,
    output_file: str | None = None,
    show: bool = True,
) -> Figure:
    """Plot a single sample from a BatchTimeSeriesContainer with proper quantile handling.

    Args:
        batch: Container providing ``history_values``, ``future_values``,
            ``start`` and ``frequency``, each indexed by sample.
        sample_idx: Index of the sample to plot.
        predicted_values: Optional predictions, either for the whole batch
            or for the selected sample only (see the heuristic below).
        model_quantiles: Optional quantile levels; when given, the last axis
            of the predictions is interpreted as quantiles.
        title: Optional figure title.
        output_file: Optional path to save the figure to.
        show: Whether to display the figure (otherwise it is closed).

    Returns:
        The created matplotlib ``Figure``.
    """
    # Extract data for the specific sample
    # (``.cpu().numpy()`` implies these are torch tensors.)
    history_values = batch.history_values[sample_idx].cpu().numpy()
    future_values = batch.future_values[sample_idx].cpu().numpy()

    median_preds = None
    lower_bound = None
    upper_bound = None

    # Process predictions
    if predicted_values is not None:
        # Handle batch vs single sample predictions: anything with 3+ dims,
        # or a 2-D array with more rows than the sample's future length, is
        # treated as batched and indexed by ``sample_idx``.
        # NOTE(review): a single-sample quantile array with 3 dims would also
        # be indexed by ``sample_idx`` here — confirm callers always pass
        # batched arrays in that case.
        is_batched = predicted_values.ndim >= 3 or (
            predicted_values.ndim == 2 and predicted_values.shape[0] > future_values.shape[0]
        )
        sample_preds = predicted_values[sample_idx] if is_batched else predicted_values

        # Extract quantile information if available
        if model_quantiles:
            median_preds, lower_bound, upper_bound = _extract_quantile_predictions(sample_preds, model_quantiles)
        else:
            median_preds = sample_preds

    # Create the plot
    return plot_multivariate_timeseries(
        history_values=history_values,
        future_values=future_values,
        predicted_values=median_preds,
        start=batch.start[sample_idx],
        frequency=batch.frequency[sample_idx],
        title=title,
        output_file=output_file,
        show=show,
        lower_bound=lower_bound,
        upper_bound=upper_bound,
    )