ttzzs's picture
Deploy Chronos2 Forecasting API v3.0.0 with new SOLID architecture
c40c447 verified
"""
Implementación concreta del modelo Chronos-2.
Este módulo implementa la interfaz IForecastModel usando Chronos2Pipeline,
aplicando el principio DIP (Dependency Inversion Principle).
"""
from typing import List, Dict, Any
import pandas as pd
from chronos import Chronos2Pipeline
from app.domain.interfaces.forecast_model import IForecastModel
from app.utils.logger import setup_logger
# Module-level logger, obtained from the project's setup_logger helper and
# named after this module.
logger = setup_logger(__name__)
class ChronosModel(IForecastModel):
    """
    Concrete IForecastModel implementation backed by Chronos-2.

    Because the rest of the code depends on the IForecastModel abstraction
    (Dependency Inversion Principle), this class can be replaced by another
    implementation (Prophet, ARIMA, etc.) without modifying its callers.

    Attributes:
        model_id: HuggingFace model ID.
        device_map: Device used for inference (e.g. "cpu" or "cuda").
        pipeline: The loaded Chronos2Pipeline instance.
    """

    def __init__(self, model_id: str = "amazon/chronos-2", device_map: str = "cpu"):
        """
        Load the Chronos-2 pipeline.

        Args:
            model_id: HuggingFace model ID.
            device_map: Device used for inference (e.g. "cpu" or "cuda").

        Raises:
            Exception: Whatever ``Chronos2Pipeline.from_pretrained`` raises;
                it is logged and re-raised unchanged.
        """
        self.model_id = model_id
        self.device_map = device_map
        logger.info(f"Loading Chronos model: {model_id} on {device_map}")
        try:
            self.pipeline = Chronos2Pipeline.from_pretrained(
                model_id,
                device_map=device_map
            )
        except Exception as e:
            logger.error(f"Failed to load Chronos model: {e}")
            raise
        logger.info("Chronos model loaded successfully")

    def predict(
        self,
        context_df: pd.DataFrame,
        prediction_length: int,
        quantile_levels: List[float],
        **kwargs
    ) -> pd.DataFrame:
        """
        Generate probabilistic forecasts with Chronos-2.

        Args:
            context_df: Historical data; must contain the columns ``id``,
                ``timestamp`` and ``target``.
            prediction_length: Forecast horizon (number of steps).
            quantile_levels: Quantiles to compute (e.g. [0.1, 0.5, 0.9]).
            **kwargs: Extra keyword arguments forwarded to
                ``Chronos2Pipeline.predict_df``. They must not repeat the
                arguments this method already passes (``prediction_length``,
                ``quantile_levels``, ``id_column``, ``timestamp_column``,
                ``target``); a duplicate surfaces as a ``RuntimeError``.

        Returns:
            The DataFrame returned by the pipeline, sorted by ``id`` and
            ``timestamp`` and re-indexed from 0.

        Raises:
            ValueError: If ``context_df`` is missing any required column.
            RuntimeError: If inference fails; the original exception is
                chained as ``__cause__``.
        """
        logger.debug(
            f"Predicting {prediction_length} steps with "
            f"{len(quantile_levels)} quantiles"
        )

        # Validate the DataFrame layout before spending time on inference.
        required_cols = {"id", "timestamp", "target"}
        missing_cols = required_cols.difference(context_df.columns)
        if missing_cols:
            # sorted()/list() keep the message deterministic: printing a set
            # of strings yields an order that varies between interpreter runs.
            raise ValueError(
                f"context_df debe tener columnas: {sorted(required_cols)}. "
                f"Faltan: {sorted(missing_cols)}. "
                f"Encontradas: {list(context_df.columns)}"
            )

        try:
            pred_df = self.pipeline.predict_df(
                context_df,
                prediction_length=prediction_length,
                quantile_levels=quantile_levels,
                id_column="id",
                timestamp_column="timestamp",
                target="target",
                **kwargs
            )
            # Sort for a deterministic row order, then drop the pipeline's
            # index labels so positional and label-based access agree.
            result = (
                pred_df.sort_values(["id", "timestamp"])
                .reset_index(drop=True)
            )
        except Exception as e:
            logger.error(f"Prediction failed: {e}")
            raise RuntimeError(f"Error en predicción: {e}") from e

        logger.debug(f"Prediction completed: {len(result)} rows")
        return result

    def get_model_info(self) -> Dict[str, Any]:
        """
        Return descriptive metadata about this model.

        Returns:
            Dictionary with the model type, HuggingFace ID, device,
            provider and version.
        """
        return {
            "type": "Chronos2",
            "model_id": self.model_id,
            "device": self.device_map,
            "provider": "Amazon",
            "version": "2.0"
        }

    def __repr__(self) -> str:
        return f"ChronosModel(model_id='{self.model_id}', device='{self.device_map}')"