"""
Concrete implementation of the Chronos-2 model.

This module implements the IForecastModel interface on top of
Chronos2Pipeline, following the Dependency Inversion Principle (DIP).
"""

from typing import List, Dict, Any

import pandas as pd
from chronos import Chronos2Pipeline

from app.domain.interfaces.forecast_model import IForecastModel
from app.utils.logger import setup_logger

logger = setup_logger(__name__)


class ChronosModel(IForecastModel):
    """
    Concrete IForecastModel implementation backed by Chronos-2.

    Because callers depend on the IForecastModel interface, this class can be
    swapped for another implementation (Prophet, ARIMA, etc.) without touching
    the rest of the code base.

    Attributes:
        model_id: Model identifier on HuggingFace.
        device_map: Device used for inference (cpu/cuda).
        pipeline: The loaded Chronos2Pipeline instance.
    """

    def __init__(self, model_id: str = "amazon/chronos-2", device_map: str = "cpu"):
        """
        Load the Chronos-2 pipeline.

        Args:
            model_id: Model identifier on HuggingFace.
            device_map: Device used for inference (cpu/cuda).

        Raises:
            Exception: Whatever ``Chronos2Pipeline.from_pretrained`` raises is
                logged and re-raised unchanged.
        """
        self.model_id = model_id
        self.device_map = device_map

        logger.info(f"Loading Chronos model: {model_id} on {device_map}")
        try:
            self.pipeline = Chronos2Pipeline.from_pretrained(
                model_id,
                device_map=device_map
            )
            logger.info("Chronos model loaded successfully")
        except Exception as e:
            # Log for diagnostics, then let the original error propagate.
            logger.error(f"Failed to load Chronos model: {e}")
            raise

    def predict(
        self,
        context_df: pd.DataFrame,
        prediction_length: int,
        quantile_levels: List[float],
        **kwargs
    ) -> pd.DataFrame:
        """
        Generate probabilistic forecasts with Chronos-2.

        Args:
            context_df: DataFrame with columns [id, timestamp, target].
            prediction_length: Forecast horizon.
            quantile_levels: Quantiles to compute (e.g. [0.1, 0.5, 0.9]).
            **kwargs: Extra keyword arguments forwarded to the pipeline's
                ``predict_df``.

        Returns:
            The DataFrame returned by the pipeline, sorted by ``id`` and
            ``timestamp``.

        Raises:
            ValueError: If ``context_df`` lacks any of the required columns.
            RuntimeError: If inference (or sorting the result) fails; the
                original exception is chained as the cause.
        """
        logger.debug(
            f"Predicting {prediction_length} steps with "
            f"{len(quantile_levels)} quantiles"
        )

        # Validate the DataFrame layout. Ordered sequences (not sets) are used
        # in the message so that its text is identical on every run.
        required_cols = ("id", "timestamp", "target")
        missing = [col for col in required_cols if col not in context_df.columns]
        if missing:
            raise ValueError(
                f"context_df debe tener columnas: {list(required_cols)}. "
                f"Encontradas: {list(context_df.columns)}"
            )

        try:
            # Run the forecast.
            pred_df = self.pipeline.predict_df(
                context_df,
                prediction_length=prediction_length,
                quantile_levels=quantile_levels,
                id_column="id",
                timestamp_column="timestamp",
                target="target",
                **kwargs
            )

            # Return rows in a stable (id, timestamp) order.
            result = pred_df.sort_values(["id", "timestamp"])

            logger.debug(f"Prediction completed: {len(result)} rows")
            return result
        except Exception as e:
            logger.error(f"Prediction failed: {e}")
            raise RuntimeError(f"Error en predicción: {e}") from e

    def get_model_info(self) -> Dict[str, Any]:
        """
        Return descriptive information about the model.

        Returns:
            Dictionary with the model type, id, device, provider and version.
        """
        return {
            "type": "Chronos2",
            "model_id": self.model_id,
            "device": self.device_map,
            "provider": "Amazon",
            "version": "2.0"
        }

    def __repr__(self) -> str:
        return f"ChronosModel(model_id='{self.model_id}', device='{self.device_map}')"