File size: 5,250 Bytes
c40c447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
Factory para crear modelos de forecasting.

Este m贸dulo implementa el patr贸n Factory aplicando OCP
(Open/Closed Principle) - abierto para extensi贸n, cerrado para modificaci贸n.
"""

from typing import Dict, Type, List

from app.domain.interfaces.forecast_model import IForecastModel
from app.infrastructure.ml.chronos_model import ChronosModel
from app.utils.logger import setup_logger

logger = setup_logger(__name__)


class ModelFactory:
    """
    Factory para crear modelos de forecasting.
    
    Permite agregar nuevos modelos sin modificar c贸digo existente,
    aplicando el principio OCP (Open/Closed Principle).
    
    Ejemplo de uso:
        >>> model = ModelFactory.create("chronos2", model_id="amazon/chronos-2")
        >>> # Futuro: model = ModelFactory.create("prophet", ...)
    """
    
    # Registro de modelos disponibles
    _models: Dict[str, Type[IForecastModel]] = {
        "chronos2": ChronosModel,
        # Futuro: Agregar sin modificar c贸digo existente
        # "prophet": ProphetModel,
        # "arima": ARIMAModel,
        # "custom": CustomModel,
    }
    
    @classmethod
    def create(
        cls,
        model_type: str,
        **kwargs
    ) -> IForecastModel:
        """
        Crea una instancia de modelo de forecasting.
        
        Args:
            model_type: Tipo de modelo ("chronos2", "prophet", etc.)
            **kwargs: Par谩metros espec铆ficos del modelo
            
        Returns:
            Instancia de IForecastModel
            
        Raises:
            ValueError: Si el tipo de modelo no existe
            
        Example:
            >>> model = ModelFactory.create(
            ...     "chronos2",
            ...     model_id="amazon/chronos-2",
            ...     device_map="cpu"
            ... )
        """
        if model_type not in cls._models:
            available = ", ".join(cls._models.keys())
            raise ValueError(
                f"Unknown model type: '{model_type}'. "
                f"Available: {available}"
            )
        
        model_class = cls._models[model_type]
        logger.info(f"Creating model: {model_type}")
        
        try:
            instance = model_class(**kwargs)
            logger.info(f"Model created: {instance}")
            return instance
        except Exception as e:
            logger.error(f"Failed to create model '{model_type}': {e}")
            raise
    
    @classmethod
    def register_model(
        cls,
        name: str,
        model_class: Type[IForecastModel]
    ) -> None:
        """
        Registra un nuevo tipo de modelo (OCP - extensi贸n).
        
        Permite agregar nuevos modelos din谩micamente sin modificar
        el c贸digo de la factory.
        
        Args:
            name: Nombre del modelo
            model_class: Clase que implementa IForecastModel
            
        Raises:
            TypeError: Si model_class no implementa IForecastModel
            ValueError: Si el nombre ya est谩 registrado
            
        Example:
            >>> class MyCustomModel(IForecastModel):
            ...     pass
            >>> ModelFactory.register_model("custom", MyCustomModel)
        """
        # Validar que implementa la interfaz
        if not issubclass(model_class, IForecastModel):
            raise TypeError(
                f"{model_class.__name__} debe implementar IForecastModel"
            )
        
        # Validar que no est茅 duplicado
        if name in cls._models:
            raise ValueError(
                f"Model '{name}' ya est谩 registrado. "
                f"Use un nombre diferente o llame a unregister_model primero."
            )
        
        cls._models[name] = model_class
        logger.info(f"Registered new model: {name} -> {model_class.__name__}")
    
    @classmethod
    def unregister_model(cls, name: str) -> None:
        """
        Elimina un modelo del registro.
        
        Args:
            name: Nombre del modelo a eliminar
            
        Raises:
            ValueError: Si el modelo no existe
        """
        if name not in cls._models:
            raise ValueError(f"Model '{name}' no est谩 registrado")
        
        del cls._models[name]
        logger.info(f"Unregistered model: {name}")
    
    @classmethod
    def list_available_models(cls) -> List[str]:
        """
        Lista todos los modelos disponibles.
        
        Returns:
            Lista de nombres de modelos
        """
        return list(cls._models.keys())
    
    @classmethod
    def get_model_info(cls, model_type: str) -> Dict[str, str]:
        """
        Obtiene informaci贸n sobre un tipo de modelo.
        
        Args:
            model_type: Nombre del tipo de modelo
            
        Returns:
            Diccionario con informaci贸n del modelo
            
        Raises:
            ValueError: Si el modelo no existe
        """
        if model_type not in cls._models:
            raise ValueError(f"Model '{model_type}' no est谩 registrado")
        
        model_class = cls._models[model_type]
        return {
            "name": model_type,
            "class": model_class.__name__,
            "module": model_class.__module__
        }