ttzzs committed on
Commit
19c8775
verified
1 Parent(s): 9fb828f

Fix: Use local model loading (HF Inference API doesn't support Chronos)

Browse files
Files changed (2) hide show
  1. app/main.py +529 -281
  2. requirements.txt +2 -2
app/main.py CHANGED
@@ -1,395 +1,643 @@
1
  import os
2
  from typing import List, Dict, Optional
3
- import json
4
 
5
  import numpy as np
6
  import pandas as pd
7
  from fastapi import FastAPI, HTTPException
8
  from fastapi.middleware.cors import CORSMiddleware
9
  from pydantic import BaseModel, Field
10
- from huggingface_hub import InferenceClient
 
11
 
12
 
13
  # =========================
14
- # Configuración
15
  # =========================
16
 
17
- HF_TOKEN = os.getenv("HF_TOKEN")
18
- MODEL_ID = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-t5-large")
19
 
20
  app = FastAPI(
21
- title="Chronos-2 Forecasting API (HF Inference)",
22
  description=(
23
- "API de pron贸sticos usando Chronos-2 via Hugging Face Inference API. "
24
- "Compatible con Excel Add-in."
25
  ),
26
  version="1.0.0",
27
  )
28
 
29
- # Configurar CORS
30
  app.add_middleware(
31
  CORSMiddleware,
32
- allow_origins=["*"], # En producci贸n, especificar dominios permitidos
33
  allow_credentials=True,
34
  allow_methods=["*"],
35
  allow_headers=["*"],
36
  )
37
 
38
- # Cliente de HF Inference
39
- if not HF_TOKEN:
40
- print("鈿狅笍 WARNING: HF_TOKEN no configurado. La API puede no funcionar correctamente.")
41
- print(" Configura HF_TOKEN en las variables de entorno del Space.")
42
- client = None
43
- else:
44
- client = InferenceClient(token=HF_TOKEN)
45
 
46
 
47
  # =========================
48
- # Modelos Pydantic
49
  # =========================
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  class UnivariateSeries(BaseModel):
52
  values: List[float]
53
 
54
 
55
- class ForecastUnivariateRequest(BaseModel):
56
- series: UnivariateSeries
57
- prediction_length: int = Field(7, description="N煤mero de pasos a predecir")
58
- quantile_levels: Optional[List[float]] = Field(
59
- default=[0.1, 0.5, 0.9],
60
- description="Cuantiles para intervalos de confianza"
 
 
 
 
 
 
 
 
 
61
  )
62
- freq: str = Field("D", description="Frecuencia temporal (D, W, M, etc.)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
 
65
  class ForecastUnivariateResponse(BaseModel):
66
  timestamps: List[str]
67
  median: List[float]
68
- quantiles: Dict[str, List[float]]
69
 
70
 
71
- class AnomalyDetectionRequest(BaseModel):
72
- context: UnivariateSeries
73
- recent_observed: List[float]
74
- prediction_length: int = 7
75
- quantile_low: float = 0.05
76
- quantile_high: float = 0.95
 
 
 
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- class AnomalyPoint(BaseModel):
80
- index: int
81
- value: float
82
- predicted_median: float
83
- lower: float
84
- upper: float
85
- is_anomaly: bool
 
86
 
 
 
 
87
 
88
- class AnomalyDetectionResponse(BaseModel):
89
- anomalies: List[AnomalyPoint]
 
 
 
90
 
 
 
 
 
 
91
 
92
- class BacktestRequest(BaseModel):
93
- series: UnivariateSeries
94
- prediction_length: int = 7
95
- test_length: int = 28
96
 
 
 
 
97
 
98
- class BacktestMetrics(BaseModel):
99
- mae: float
100
- mape: float
101
- rmse: float
102
 
103
 
104
- class BacktestResponse(BaseModel):
105
- metrics: BacktestMetrics
106
- forecast_median: List[float]
107
- forecast_timestamps: List[str]
108
- actuals: List[float]
109
 
110
 
111
- # =========================
112
- # Función auxiliar para llamar a HF Inference
113
- # =========================
114
 
115
- def call_chronos_inference(series: List[float], prediction_length: int) -> Dict:
 
116
  """
117
- Llama a la API de Hugging Face Inference para Chronos.
118
- Retorna un diccionario con las predicciones.
119
  """
120
- if client is None:
121
- raise HTTPException(
122
- status_code=503,
123
- detail="HF_TOKEN no configurado. Contacta al administrador del servicio."
124
- )
125
-
126
- try:
127
- # Intentar usando el endpoint espec铆fico de time series
128
- import requests
129
-
130
- url = f"https://router.huggingface.co/hf-inference/models/{MODEL_ID}"
131
- headers = {"Authorization": f"Bearer {HF_TOKEN}"}
132
-
133
- payload = {
134
- "inputs": series,
135
- "parameters": {
136
- "prediction_length": prediction_length,
137
- "num_samples": 100 # Para obtener cuantiles
138
- }
139
- }
140
-
141
- response = requests.post(url, headers=headers, json=payload, timeout=60)
142
-
143
- if response.status_code == 503:
144
- raise HTTPException(
145
- status_code=503,
146
- detail="El modelo est谩 cargando. Por favor, intenta de nuevo en 30-60 segundos."
147
  )
148
- elif response.status_code != 200:
149
- raise HTTPException(
150
- status_code=response.status_code,
151
- detail=f"Error de la API de HuggingFace: {response.text}"
 
 
 
 
 
 
152
  )
153
-
154
- result = response.json()
155
- return result
156
-
157
- except requests.exceptions.Timeout:
158
- raise HTTPException(
159
- status_code=504,
160
- detail="Timeout al comunicarse con HuggingFace API. El modelo puede estar cargando."
161
  )
162
- except Exception as e:
163
- raise HTTPException(
164
- status_code=500,
165
- detail=f"Error inesperado: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  )
167
 
 
168
 
169
- def process_chronos_output(raw_output: Dict, prediction_length: int) -> Dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  """
171
- Procesa la salida de Chronos para extraer mediana y cuantiles.
 
172
  """
173
- # La API de Chronos puede devolver diferentes formatos
174
- # Intentamos adaptarnos a ellos
175
-
176
- if isinstance(raw_output, list):
177
- # Si es una lista de valores, asumimos que es la predicci贸n media
178
- median = raw_output[:prediction_length]
179
- return {
180
- "median": median,
181
- "quantiles": {
182
- "0.1": median, # Sin cuantiles, usar median
183
- "0.5": median,
184
- "0.9": median
185
- }
186
- }
187
-
188
- # Si tiene estructura m谩s compleja, intentar extraer
189
- if "forecast" in raw_output:
190
- forecast = raw_output["forecast"]
191
- if "median" in forecast:
192
- median = forecast["median"][:prediction_length]
193
- else:
194
- median = forecast.get("mean", [0] * prediction_length)[:prediction_length]
195
-
196
- quantiles = forecast.get("quantiles", {})
197
- return {
198
- "median": median,
199
- "quantiles": quantiles
200
  }
201
-
202
- # Formato por defecto
203
- return {
204
- "median": [0] * prediction_length,
205
- "quantiles": {
206
- "0.1": [0] * prediction_length,
207
- "0.5": [0] * prediction_length,
208
- "0.9": [0] * prediction_length
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  }
210
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
 
213
  # =========================
214
- # Endpoints
215
  # =========================
216
 
217
- @app.get("/")
218
- def root():
219
- """Informaci贸n b谩sica de la API"""
220
- return {
221
- "name": "Chronos-2 Forecasting API",
222
- "version": "1.0.0",
223
- "model": MODEL_ID,
224
- "status": "running",
225
- "docs": "/docs",
226
- "health": "/health"
227
- }
228
 
229
 
230
- @app.get("/health")
231
- def health():
232
- """Health check del servicio"""
233
- return {
234
- "status": "ok" if HF_TOKEN else "warning",
235
- "model_id": MODEL_ID,
236
- "hf_token_configured": HF_TOKEN is not None,
237
- "message": "Ready" if HF_TOKEN else "HF_TOKEN not configured"
238
- }
239
 
240
 
241
- @app.post("/forecast_univariate", response_model=ForecastUnivariateResponse)
242
- def forecast_univariate(req: ForecastUnivariateRequest):
 
 
 
 
 
 
 
 
 
243
  """
244
- Pron贸stico para una serie temporal univariada.
245
-
246
- Compatible con el Excel Add-in.
247
  """
248
- values = req.series.values
249
- n = len(values)
250
-
251
- if n == 0:
252
- raise HTTPException(status_code=400, detail="La serie no puede estar vac铆a.")
253
-
254
- if n < 3:
255
- raise HTTPException(
256
- status_code=400,
257
- detail="La serie debe tener al menos 3 puntos hist贸ricos."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
  )
259
-
260
- # Llamar a la API de HuggingFace
261
- raw_output = call_chronos_inference(values, req.prediction_length)
262
-
263
- # Procesar la salida
264
- processed = process_chronos_output(raw_output, req.prediction_length)
265
-
266
- # Generar timestamps
267
- timestamps = [f"t+{i+1}" for i in range(req.prediction_length)]
268
-
269
- return ForecastUnivariateResponse(
270
- timestamps=timestamps,
271
- median=processed["median"],
272
- quantiles=processed["quantiles"]
273
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
 
276
  @app.post("/detect_anomalies", response_model=AnomalyDetectionResponse)
277
  def detect_anomalies(req: AnomalyDetectionRequest):
278
  """
279
- Detecta anomal铆as comparando valores observados con predicciones.
 
280
  """
281
  n_hist = len(req.context.values)
282
-
283
  if n_hist == 0:
284
- raise HTTPException(status_code=400, detail="El contexto no puede estar vac铆o.")
285
-
286
  if len(req.recent_observed) != req.prediction_length:
287
  raise HTTPException(
288
  status_code=400,
289
- detail="recent_observed debe tener la misma longitud que prediction_length."
290
  )
291
-
292
- # Hacer predicci贸n
293
- raw_output = call_chronos_inference(req.context.values, req.prediction_length)
294
- processed = process_chronos_output(raw_output, req.prediction_length)
295
-
296
- # Comparar con valores observados
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  anomalies: List[AnomalyPoint] = []
298
-
299
- median = processed["median"]
300
- # Intentar obtener cuantiles o usar aproximaciones
301
- q_low = processed["quantiles"].get(str(req.quantile_low), median)
302
- q_high = processed["quantiles"].get(str(req.quantile_high), median)
303
-
304
- for i, obs in enumerate(req.recent_observed):
305
- if i < len(median):
306
- lower = q_low[i] if i < len(q_low) else median[i] * 0.8
307
- upper = q_high[i] if i < len(q_high) else median[i] * 1.2
308
- predicted = median[i]
309
- is_anom = (obs < lower) or (obs > upper)
310
-
311
- anomalies.append(
312
- AnomalyPoint(
313
- index=i,
314
- value=obs,
315
- predicted_median=predicted,
316
- lower=lower,
317
- upper=upper,
318
- is_anomaly=is_anom,
319
- )
320
  )
321
-
 
322
  return AnomalyDetectionResponse(anomalies=anomalies)
323
 
324
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
325
  @app.post("/backtest_simple", response_model=BacktestResponse)
326
  def backtest_simple(req: BacktestRequest):
327
  """
328
- Backtesting simple: divide la serie en train/test y eval煤a m茅tricas.
 
329
  """
330
  values = np.array(req.series.values, dtype=float)
331
  n = len(values)
332
-
333
  if n <= req.test_length:
334
  raise HTTPException(
335
  status_code=400,
336
- detail="La serie debe ser m谩s larga que test_length."
337
  )
338
-
339
- # Dividir en train/test
340
- train = values[: n - req.test_length].tolist()
341
- test = values[n - req.test_length :].tolist()
342
-
343
- # Hacer predicci贸n
344
- raw_output = call_chronos_inference(train, req.test_length)
345
- processed = process_chronos_output(raw_output, req.test_length)
346
-
347
- forecast = np.array(processed["median"], dtype=float)
348
- test_arr = np.array(test, dtype=float)
349
-
350
- # Calcular m茅tricas
351
- mae = float(np.mean(np.abs(test_arr - forecast)))
352
- rmse = float(np.sqrt(np.mean((test_arr - forecast) ** 2)))
353
-
354
- eps = 1e-8
355
- mape = float(np.mean(np.abs((test_arr - forecast) / (test_arr + eps)))) * 100.0
356
-
357
- timestamps = [f"test_t{i+1}" for i in range(req.test_length)]
358
-
359
- metrics = BacktestMetrics(mae=mae, mape=mape, rmse=rmse)
360
-
361
- return BacktestResponse(
362
- metrics=metrics,
363
- forecast_median=forecast.tolist(),
364
- forecast_timestamps=timestamps,
365
- actuals=test,
366
  )
367
 
 
 
 
 
 
 
 
 
368
 
369
- # =========================
370
- # Endpoints simplificados para testing
371
- # =========================
372
 
373
- @app.post("/simple_forecast")
374
- def simple_forecast(series: List[float], prediction_length: int = 7):
375
- """
376
- Endpoint simplificado para testing r谩pido.
377
- """
378
- if not series:
379
- raise HTTPException(status_code=400, detail="Serie vac铆a")
380
-
381
- raw_output = call_chronos_inference(series, prediction_length)
382
- processed = process_chronos_output(raw_output, prediction_length)
383
-
384
- return {
385
- "input_series": series,
386
- "prediction_length": prediction_length,
387
- "forecast": processed["median"],
388
- "model": MODEL_ID
389
- }
390
 
 
391
 
392
- if __name__ == "__main__":
393
- import uvicorn
394
- port = int(os.getenv("PORT", 7860))
395
- uvicorn.run(app, host="0.0.0.0", port=port)
 
 
 
1
  import os
2
  from typing import List, Dict, Optional
 
3
 
4
  import numpy as np
5
  import pandas as pd
6
  from fastapi import FastAPI, HTTPException
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from pydantic import BaseModel, Field
9
+
10
+ from chronos import Chronos2Pipeline
11
 
12
 
13
  # =========================
14
+ # Configuración del modelo
15
  # =========================
16
 
17
MODEL_ID = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
DEVICE_MAP = os.getenv("DEVICE_MAP", "cpu")  # "cpu" or "cuda"

# Browser origins allowed by CORS, read from the CORS_ALLOW_ORIGINS environment
# variable as a comma-separated list. The default keeps the two local
# development origins that were previously hard-coded, so existing deployments
# behave the same unless the variable is set.
_DEFAULT_CORS_ORIGINS = "https://localhost:3001,https://localhost:3000"
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", _DEFAULT_CORS_ORIGINS).split(",")
    if origin.strip()
]

app = FastAPI(
    title="Chronos-2 Universal Forecasting API",
    description=(
        "Servidor local (Docker) para pronósticos con Chronos-2: univariante, "
        "multivariante, covariables, escenarios, anomalías y backtesting."
    ),
    version="1.0.0",
)

# CORS configuration for the Excel add-in (and any other configured origin).
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# The model is loaded once, when the module is imported, and shared by every
# endpoint below.
pipeline = Chronos2Pipeline.from_pretrained(MODEL_ID, device_map=DEVICE_MAP)
 
 
 
 
 
40
 
41
 
42
  # =========================
43
+ # Modelos Pydantic comunes
44
  # =========================
45
 
46
class BaseForecastConfig(BaseModel):
    """Forecast settings shared by the forecasting request models."""

    # ge=1: a zero or negative horizon is rejected during request validation
    # instead of being passed on to the model.
    prediction_length: int = Field(
        7,
        ge=1,
        description="Horizonte de predicción (número de pasos futuros)",
    )
    quantile_levels: List[float] = Field(
        default_factory=lambda: [0.1, 0.5, 0.9],
        description="Cuantiles para el pronóstico probabilístico",
    )
    start_timestamp: Optional[str] = Field(
        default=None,
        description=(
            "Fecha/hora inicial del histórico (formato ISO). "
            "Si no se especifica, se usan índices enteros."
        ),
    )
    freq: str = Field(
        "D",
        description="Frecuencia temporal (p.ej. 'D' diario, 'H' horario, 'W' semanal...).",
    )


class UnivariateSeries(BaseModel):
    """A single series given as a list of observations."""

    values: List[float]
69
 
70
 
71
class MultiSeriesItem(BaseModel):
    """One named series inside a multi-series forecast request."""

    # Identifier used as the series "id" when the request is turned into a frame.
    series_id: str
    # Observations of the series; timestamps are assigned to them by position.
    values: List[float]
74
+
75
+
76
class CovariatePoint(BaseModel):
    """One time step, used both for historical context and for future covariates."""

    # Optional when integer indices are used instead of real timestamps.
    timestamp: Optional[str] = None
    # Series identifier; endpoints fall back to "series_0" when it is missing.
    id: Optional[str] = None
    # Observed value of the target variable (historical rows only).
    target: Optional[float] = None
    covariates: Dict[str, float] = Field(
        default_factory=dict,
        description="Nombre -> valor de cada covariable dinámica.",
    )
87
+
88
+
89
+ # =========================
90
+ # 1) Healthcheck
91
+ # =========================
92
+
93
@app.get("/health")
def health():
    """Report the service status together with the configured model and device."""
    status_report = dict(
        status="ok",
        model_id=MODEL_ID,
        device_map=DEVICE_MAP,
    )
    return status_report
103
+
104
+
105
+ # =========================
106
+ # 2) Pronóstico univariante
107
+ # =========================
108
+
109
class ForecastUnivariateRequest(BaseForecastConfig):
    """Request body for ``/forecast_univariate``: shared settings plus one series."""

    series: UnivariateSeries


class ForecastUnivariateResponse(BaseModel):
    """Forecast returned for a single series."""

    timestamps: List[str]
    median: List[float]
    quantiles: Dict[str, List[float]]  # "0.1" -> [..], "0.9" -> [..]


@app.post("/forecast_univariate", response_model=ForecastUnivariateResponse)
def forecast_univariate(req: ForecastUnivariateRequest):
    """Forecast a single series without covariates.

    The series is turned into a long-format frame (id, timestamp, target) and
    passed to ``pipeline.predict_df``. ``median`` in the response is filled
    from the ``predictions`` column of the result.

    Raises:
        HTTPException: 400 when the series is empty, or when
            ``start_timestamp`` / ``freq`` cannot be turned into a date range.
    """
    values = req.series.values
    n = len(values)
    if n == 0:
        raise HTTPException(status_code=400, detail="La serie no puede estar vacía.")

    if req.start_timestamp:
        # A malformed date or frequency is a client error, not a server error.
        try:
            timestamps = pd.date_range(
                start=pd.to_datetime(req.start_timestamp),
                periods=n,
                freq=req.freq,
            )
        except (ValueError, TypeError) as exc:
            raise HTTPException(
                status_code=400,
                detail=f"start_timestamp o freq no válidos: {exc}",
            ) from exc
    else:
        # NOTE(review): assumes predict_df accepts integer timestamps — confirm.
        timestamps = pd.RangeIndex(start=0, stop=n, step=1)

    context_df = pd.DataFrame(
        {
            "id": ["series_0"] * n,
            "timestamp": timestamps,
            "target": values,
        }
    )

    pred_df = pipeline.predict_df(
        context_df,
        prediction_length=req.prediction_length,
        quantile_levels=req.quantile_levels,
        id_column="id",
        timestamp_column="timestamp",
        target="target",
    ).sort_values("timestamp")

    quantiles_dict: Dict[str, List[float]] = {}
    for q in req.quantile_levels:
        # The 3-significant-digit label is tried first (previous behaviour);
        # str(q) is the fallback so that levels such as 0.12345 are not
        # silently dropped when the column is named after the full value.
        column = next(
            (name for name in (f"{q:.3g}", str(q)) if name in pred_df.columns),
            None,
        )
        if column is not None:
            quantiles_dict[column] = pred_df[column].astype(float).tolist()

    return ForecastUnivariateResponse(
        timestamps=pred_df["timestamp"].astype(str).tolist(),
        median=pred_df["predictions"].astype(float).tolist(),
        quantiles=quantiles_dict,
    )
172
 
 
 
 
 
173
 
174
+ # =========================
175
+ # 3) Multi-serie (multi-id)
176
+ # =========================
177
 
178
class ForecastMultiSeriesRequest(BaseForecastConfig):
    """Request body for ``/forecast_multi_id``: shared settings plus several series."""

    series_list: List[MultiSeriesItem]


class SeriesForecast(BaseModel):
    """Forecast for one series of a multi-series request."""

    series_id: str
    timestamps: List[str]
    median: List[float]
    quantiles: Dict[str, List[float]]


class ForecastMultiSeriesResponse(BaseModel):
    """Forecasts for every non-empty series of the request."""

    forecasts: List[SeriesForecast]


@app.post("/forecast_multi_id", response_model=ForecastMultiSeriesResponse)
def forecast_multi_id(req: ForecastMultiSeriesRequest):
    """Forecast several independent series (e.g. several SKUs or stores) in one call.

    Empty series are skipped. The remaining ones are stacked into a single
    long-format frame and forecast together; results are grouped by id with
    ``DataFrame.groupby``, so they come back ordered by series id.

    Raises:
        HTTPException: 400 when no series is sent, all series are empty, a
            ``series_id`` is repeated, or ``start_timestamp`` / ``freq`` are
            invalid.
    """
    if not req.series_list:
        raise HTTPException(status_code=400, detail="Debes enviar al menos una serie.")

    def build_index(length: int):
        """Return the timestamp index for a series of the given length."""
        if not req.start_timestamp:
            return pd.RangeIndex(start=0, stop=length, step=1)
        try:
            return pd.date_range(
                start=pd.to_datetime(req.start_timestamp),
                periods=length,
                freq=req.freq,
            )
        except (ValueError, TypeError) as exc:
            raise HTTPException(
                status_code=400,
                detail=f"start_timestamp o freq no válidos: {exc}",
            ) from exc

    frames = []
    seen_ids = set()
    for item in req.series_list:
        n = len(item.values)
        if n == 0:
            continue
        # Two series sharing an id would be concatenated into one series with
        # repeated timestamps, so the request is rejected instead.
        if item.series_id in seen_ids:
            raise HTTPException(
                status_code=400,
                detail=f"series_id duplicado: {item.series_id}",
            )
        seen_ids.add(item.series_id)
        frames.append(
            pd.DataFrame(
                {
                    "id": [item.series_id] * n,
                    "timestamp": build_index(n),
                    "target": item.values,
                }
            )
        )

    if not frames:
        raise HTTPException(status_code=400, detail="Todas las series están vacías.")

    context_df = pd.concat(frames, ignore_index=True)

    pred_df = pipeline.predict_df(
        context_df,
        prediction_length=req.prediction_length,
        quantile_levels=req.quantile_levels,
        id_column="id",
        timestamp_column="timestamp",
        target="target",
    )

    # Resolve each requested level to a result column once. The 3-significant-
    # digit label is tried first (previous behaviour) and str(q) is the
    # fallback, so levels with more digits are not silently dropped.
    quantile_columns = []
    for q in req.quantile_levels:
        column = next(
            (name for name in (f"{q:.3g}", str(q)) if name in pred_df.columns),
            None,
        )
        if column is not None:
            quantile_columns.append(column)

    forecasts: List[SeriesForecast] = []
    for series_id, group in pred_df.groupby("id"):
        group = group.sort_values("timestamp")
        forecasts.append(
            SeriesForecast(
                series_id=series_id,
                timestamps=group["timestamp"].astype(str).tolist(),
                median=group["predictions"].astype(float).tolist(),
                quantiles={
                    column: group[column].astype(float).tolist()
                    for column in quantile_columns
                },
            )
        )

    return ForecastMultiSeriesResponse(forecasts=forecasts)
260
 
261
+
262
+ # =========================
263
+ # 4) Pronóstico con covariables
264
+ # =========================
265
+
266
# Column names the endpoints create themselves; user-supplied covariate or
# target names must not overwrite them.
_RESERVED_COLUMNS = frozenset({"id", "timestamp"})


def _parse_timestamp_column(column: pd.Series, label: str) -> pd.Series:
    """Parse a column of timestamp strings, turning bad input into a 400."""
    try:
        return pd.to_datetime(column)
    except (ValueError, TypeError) as exc:
        raise HTTPException(
            status_code=400,
            detail=f"Timestamps no válidos en {label}: {exc}",
        ) from exc


def _assign_context_timestamps(context_df: pd.DataFrame) -> bool:
    """Normalise the context ``timestamp`` column in place.

    If any timestamp is missing, all of them are replaced by integer
    positions counted per series id. Otherwise they are parsed as dates.

    Returns:
        True when real dates are in use, False when integer positions are.
    """
    if context_df["timestamp"].isna().any():
        context_df["timestamp"] = context_df.groupby("id").cumcount()
        return False
    context_df["timestamp"] = _parse_timestamp_column(context_df["timestamp"], "context")
    return True


def _covariate_points_frame(points, *, include_target: bool) -> pd.DataFrame:
    """Turn ``CovariatePoint`` items into a long-format frame.

    With ``include_target=True`` the frame carries a ``target`` column and
    points whose target is ``None`` are skipped. The result has no columns
    when no row is produced.
    """
    reserved = _RESERVED_COLUMNS | {"target"}
    rows = []
    for point in points:
        if include_target and point.target is None:
            continue
        clash = sorted(reserved.intersection(point.covariates))
        if clash:
            raise HTTPException(
                status_code=400,
                detail=f"Nombres de covariables reservados: {clash}",
            )
        row = {"id": point.id or "series_0", "timestamp": point.timestamp}
        if include_target:
            row["target"] = point.target
        row.update(point.covariates)
        rows.append(row)
    return pd.DataFrame(rows)


def _build_future_frame(points, context_df: pd.DataFrame, context_has_dates: bool):
    """Build the future-covariates frame aligned with the context timestamps.

    Returns ``None`` when there are no future points. When future timestamps
    are missing and the context uses integer positions, each series continues
    counting from its own last context position.
    """
    if not points:
        return None

    future_df = _covariate_points_frame(points, include_target=False)

    if not future_df["timestamp"].isna().any():
        if not context_has_dates:
            raise HTTPException(
                status_code=400,
                detail=(
                    "Las covariables futuras tienen timestamp pero el contexto "
                    "usa índices enteros."
                ),
            )
        future_df["timestamp"] = _parse_timestamp_column(future_df["timestamp"], "future")
        return future_df

    if context_has_dates:
        # Integer positions cannot be derived from date-like context timestamps.
        raise HTTPException(
            status_code=400,
            detail=(
                "Las covariables futuras necesitan timestamp cuando el contexto "
                "usa fechas."
            ),
        )

    next_step = context_df.groupby("id")["timestamp"].max() + 1
    start = future_df["id"].map(next_step)
    if start.isna().any():
        raise HTTPException(
            status_code=400,
            detail="Las covariables futuras contienen ids que no están en el contexto.",
        )
    future_df["timestamp"] = start.astype(int) + future_df.groupby("id").cumcount()
    return future_df


def _frame_to_string_records(frame: pd.DataFrame) -> List[Dict[str, str]]:
    """Sort by (id, timestamp) and serialise every cell of every row with ``str``."""
    ordered = frame.sort_values(["id", "timestamp"])
    return [{str(k): str(v) for k, v in row.items()} for _, row in ordered.iterrows()]


def _covariate_context_frame(points) -> pd.DataFrame:
    """Build the validated context frame shared by the covariate endpoints."""
    if not points:
        raise HTTPException(status_code=400, detail="El contexto no puede estar vacío.")
    context_df = _covariate_points_frame(points, include_target=True)
    if context_df.empty:
        raise HTTPException(
            status_code=400,
            detail="El contexto no contiene ningún punto con target.",
        )
    return context_df


class ForecastWithCovariatesRequest(BaseForecastConfig):
    """Request body for ``/forecast_with_covariates``."""

    context: List[CovariatePoint]
    future: Optional[List[CovariatePoint]] = None


class ForecastWithCovariatesResponse(BaseModel):
    """Rows of the prediction frame, with every column serialised as a string."""

    pred_df: List[Dict[str, str]]


@app.post("/forecast_with_covariates", response_model=ForecastWithCovariatesResponse)
def forecast_with_covariates(req: ForecastWithCovariatesRequest):
    """Forecast using covariates in the history (context) and, optionally, the future.

    Raises:
        HTTPException: 400 when the context is empty or has no point with a
            target, when covariate names collide with reserved columns, or
            when timestamps are invalid or inconsistent between context and
            future.
    """
    context_df = _covariate_context_frame(req.context)
    context_has_dates = _assign_context_timestamps(context_df)
    future_df = _build_future_frame(req.future, context_df, context_has_dates)

    pred_df = pipeline.predict_df(
        context_df,
        future_df=future_df,
        prediction_length=req.prediction_length,
        quantile_levels=req.quantile_levels,
        id_column="id",
        timestamp_column="timestamp",
        target="target",
    )

    return ForecastWithCovariatesResponse(pred_df=_frame_to_string_records(pred_df))


# =========================
# 5) Multivariate (several targets)
# =========================

class MultivariateContextPoint(BaseModel):
    """One historical time step carrying several target values."""

    timestamp: Optional[str] = None
    id: Optional[str] = None
    targets: Dict[str, float]  # e.g. {"demand": 100, "returns": 5}
    covariates: Dict[str, float] = Field(default_factory=dict)


class ForecastMultivariateRequest(BaseForecastConfig):
    """Request body for ``/forecast_multivariate``."""

    context: List[MultivariateContextPoint]
    target_columns: List[str]  # names of the target columns


class ForecastMultivariateResponse(BaseModel):
    """Rows of the prediction frame, with every column serialised as a string."""

    pred_df: List[Dict[str, str]]


@app.post("/forecast_multivariate", response_model=ForecastMultivariateResponse)
def forecast_multivariate(req: ForecastMultivariateRequest):
    """Forecast several target columns at once (e.g. demand and returns).

    Raises:
        HTTPException: 400 when the context or ``target_columns`` is empty,
            when a target or covariate name collides with a reserved column
            or with each other, when a requested target column never appears
            in the context, or when timestamps are invalid.
    """
    if not req.context:
        raise HTTPException(status_code=400, detail="El contexto no puede estar vacío.")
    if not req.target_columns:
        raise HTTPException(status_code=400, detail="Debes indicar columnas objetivo.")

    rows = []
    for point in req.context:
        clash = sorted(
            (_RESERVED_COLUMNS & (point.targets.keys() | point.covariates.keys()))
            | (point.targets.keys() & point.covariates.keys())
        )
        if clash:
            raise HTTPException(
                status_code=400,
                detail=f"Nombres de columnas en conflicto: {clash}",
            )
        row = {"id": point.id or "series_0", "timestamp": point.timestamp}
        row.update(point.targets)
        row.update(point.covariates)
        rows.append(row)

    context_df = pd.DataFrame(rows)

    missing = [name for name in req.target_columns if name not in context_df.columns]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Columnas objetivo ausentes en el contexto: {missing}",
        )

    _assign_context_timestamps(context_df)

    pred_df = pipeline.predict_df(
        context_df,
        prediction_length=req.prediction_length,
        quantile_levels=req.quantile_levels,
        id_column="id",
        timestamp_column="timestamp",
        target=req.target_columns,
    )

    return ForecastMultivariateResponse(pred_df=_frame_to_string_records(pred_df))


# =========================
# 6) Scenarios (what-if)
# =========================

class ScenarioDefinition(BaseModel):
    """A named set of future covariate values to evaluate."""

    name: str
    future_covariates: List[CovariatePoint]


class ScenarioForecast(BaseModel):
    """Prediction rows obtained for one scenario."""

    name: str
    pred_df: List[Dict[str, str]]


class ForecastScenariosRequest(BaseForecastConfig):
    """Request body for ``/forecast_scenarios``."""

    context: List[CovariatePoint]
    scenarios: List[ScenarioDefinition]


class ForecastScenariosResponse(BaseModel):
    """One forecast per requested scenario, in request order."""

    scenarios: List[ScenarioForecast]


@app.post("/forecast_scenarios", response_model=ForecastScenariosResponse)
def forecast_scenarios(req: ForecastScenariosRequest):
    """Evaluate several what-if scenarios that differ only in their future covariates.

    The context frame is built once and reused; each scenario gets its own
    future frame and its own call to ``pipeline.predict_df``. A scenario with
    no future covariates is forecast without a future frame.

    Raises:
        HTTPException: 400 when the context is empty or has no point with a
            target, when no scenario is given, when covariate names collide
            with reserved columns, or when timestamps are invalid or
            inconsistent between context and a scenario.
    """
    if not req.context:
        raise HTTPException(status_code=400, detail="El contexto no puede estar vacío.")
    if not req.scenarios:
        raise HTTPException(status_code=400, detail="Debes definir al menos un escenario.")

    context_df = _covariate_context_frame(req.context)
    context_has_dates = _assign_context_timestamps(context_df)

    results: List[ScenarioForecast] = []
    for scen in req.scenarios:
        future_df = _build_future_frame(
            scen.future_covariates, context_df, context_has_dates
        )
        pred_df = pipeline.predict_df(
            context_df,
            future_df=future_df,
            prediction_length=req.prediction_length,
            quantile_levels=req.quantile_levels,
            id_column="id",
            timestamp_column="timestamp",
            target="target",
        )
        results.append(
            ScenarioForecast(name=scen.name, pred_df=_frame_to_string_records(pred_df))
        )

    return ForecastScenariosResponse(scenarios=results)
486
+
487
+
488
+ # =========================
489
+ # 7) Detección de anomalías
490
+ # =========================
491
+
492
class AnomalyDetectionRequest(BaseModel):
    """Request body for ``/detect_anomalies``."""

    context: UnivariateSeries
    recent_observed: List[float]
    prediction_length: int = 7
    quantile_low: float = 0.05
    quantile_high: float = 0.95


class AnomalyPoint(BaseModel):
    """Verdict for one recently observed value."""

    index: int
    value: float
    predicted_median: float
    lower: float
    upper: float
    is_anomaly: bool


class AnomalyDetectionResponse(BaseModel):
    """One ``AnomalyPoint`` per recently observed value."""

    anomalies: List[AnomalyPoint]


@app.post("/detect_anomalies", response_model=AnomalyDetectionResponse)
def detect_anomalies(req: AnomalyDetectionRequest):
    """Flag recent observations that fall outside the forecast interval.

    The context is forecast ``prediction_length`` steps ahead; an observation
    is an anomaly when it lies below the ``quantile_low`` forecast or above
    the ``quantile_high`` forecast for the same step. ``predicted_median`` is
    filled from the ``predictions`` column of the forecast.

    Raises:
        HTTPException: 400 when the context is empty, ``prediction_length`` is
            below 1, ``recent_observed`` has a different length, or the
            quantiles do not satisfy 0 < low < high < 1; 500 when the
            forecast lacks a requested quantile column.
    """
    n_hist = len(req.context.values)
    if n_hist == 0:
        raise HTTPException(
            status_code=400, detail="La serie histórica no puede estar vacía."
        )
    if req.prediction_length < 1:
        raise HTTPException(
            status_code=400, detail="prediction_length debe ser al menos 1."
        )
    if len(req.recent_observed) != req.prediction_length:
        raise HTTPException(
            status_code=400,
            detail="recent_observed debe tener la misma longitud que prediction_length.",
        )
    if not 0.0 < req.quantile_low < req.quantile_high < 1.0:
        raise HTTPException(
            status_code=400,
            detail="Se requiere 0 < quantile_low < quantile_high < 1.",
        )

    context_df = pd.DataFrame(
        {
            "id": ["series_0"] * n_hist,
            "timestamp": pd.RangeIndex(start=0, stop=n_hist, step=1),
            "target": req.context.values,
        }
    )

    # The set removes duplicates when a bound coincides with 0.5.
    quantiles = sorted({req.quantile_low, 0.5, req.quantile_high})
    pred_df = pipeline.predict_df(
        context_df,
        prediction_length=req.prediction_length,
        quantile_levels=quantiles,
        id_column="id",
        timestamp_column="timestamp",
        target="target",
    ).sort_values("timestamp")

    def find_column(level: float) -> str:
        """Return the forecast column holding the given quantile level."""
        # 3-significant-digit label first (previous behaviour), then str(level).
        for name in (f"{level:.3g}", str(level)):
            if name in pred_df.columns:
                return name
        raise HTTPException(
            status_code=500,
            detail=f"El pronóstico no incluye el cuantil {level}.",
        )

    lower_bounds = pred_df[find_column(req.quantile_low)].to_numpy(dtype=float)
    upper_bounds = pred_df[find_column(req.quantile_high)].to_numpy(dtype=float)
    medians = pred_df["predictions"].to_numpy(dtype=float)

    anomalies: List[AnomalyPoint] = [
        AnomalyPoint(
            index=i,
            value=obs,
            predicted_median=float(median),
            lower=float(lower),
            upper=float(upper),
            is_anomaly=bool(obs < lower or obs > upper),
        )
        for i, (obs, lower, median, upper) in enumerate(
            zip(req.recent_observed, lower_bounds, medians, upper_bounds)
        )
    ]

    return AnomalyDetectionResponse(anomalies=anomalies)
567
 
568
 
569
+ # =========================
570
+ # 8) Backtest simple
571
+ # =========================
572
+
573
class BacktestRequest(BaseModel):
    """Request body for ``/backtest_simple``."""

    series: UnivariateSeries
    # Accepted for compatibility; the backtest horizon is ``test_length``.
    prediction_length: int = 7
    test_length: int = 28


class BacktestMetrics(BaseModel):
    """Error metrics of the backtest forecast against the held-out values."""

    mae: float
    mape: float
    # Mean pinball (quantile) loss at the 0.5 quantile; it is not normalised
    # by the magnitude of the actual values.
    wql: float


class BacktestResponse(BaseModel):
    """Metrics plus the forecast and the actual values they were computed on."""

    metrics: BacktestMetrics
    forecast_median: List[float]
    forecast_timestamps: List[str]
    actuals: List[float]


def _backtest_point_metrics(actual: np.ndarray, forecast: np.ndarray) -> BacktestMetrics:
    """Compute MAE, MAPE (in percent) and the 0.5-quantile pinball loss."""
    errors = actual - forecast
    mae = float(np.mean(np.abs(errors)))
    # The denominator uses |actual| with a floor, so negative or zero actual
    # values can never produce a division by zero.
    denominator = np.maximum(np.abs(actual), 1e-8)
    mape = float(np.mean(np.abs(errors) / denominator)) * 100.0
    tau = 0.5
    wql = float(np.mean(np.maximum(tau * errors, (tau - 1) * errors)))
    return BacktestMetrics(mae=mae, mape=mape, wql=wql)


@app.post("/backtest_simple", response_model=BacktestResponse)
def backtest_simple(req: BacktestRequest):
    """Hold out the last ``test_length`` values, forecast them and score the forecast.

    The forecast used for scoring is the ``predictions`` column returned by
    ``pipeline.predict_df``.

    Raises:
        HTTPException: 400 when ``test_length`` is below 1 or the series is
            not longer than ``test_length``; 500 when the forecast length
            does not match the held-out window.
    """
    if req.test_length < 1:
        # A zero or negative window would leave nothing to score.
        raise HTTPException(status_code=400, detail="test_length debe ser al menos 1.")

    values = np.array(req.series.values, dtype=float)
    n = len(values)
    if n <= req.test_length:
        raise HTTPException(
            status_code=400,
            detail="La serie debe ser más larga que test_length.",
        )

    split = n - req.test_length
    train, test = values[:split], values[split:]

    context_df = pd.DataFrame(
        {
            "id": ["series_0"] * len(train),
            "timestamp": pd.RangeIndex(start=0, stop=len(train), step=1),
            "target": train.tolist(),
        }
    )

    pred_df = pipeline.predict_df(
        context_df,
        prediction_length=req.test_length,
        quantile_levels=[0.5],
        id_column="id",
        timestamp_column="timestamp",
        target="target",
    ).sort_values("timestamp")

    forecast = pred_df["predictions"].to_numpy(dtype=float)
    if forecast.shape != test.shape:
        raise HTTPException(
            status_code=500,
            detail="El pronóstico no tiene la misma longitud que el tramo de test.",
        )

    return BacktestResponse(
        metrics=_backtest_point_metrics(test, forecast),
        forecast_median=forecast.tolist(),
        forecast_timestamps=pred_df["timestamp"].astype(str).tolist(),
        actuals=test.tolist(),
    )
requirements.txt CHANGED
@@ -2,7 +2,7 @@ fastapi>=0.104.0
2
  uvicorn[standard]>=0.24.0
3
  pandas>=2.0.0
4
  numpy>=1.24.0
5
- huggingface_hub>=0.20.0
6
  pydantic>=2.0.0
7
  python-dotenv>=1.0.0
8
- requests>=2.31.0
 
 
2
  uvicorn[standard]>=0.24.0
3
  pandas>=2.0.0
4
  numpy>=1.24.0
 
5
  pydantic>=2.0.0
6
  python-dotenv>=1.0.0
7
+ chronos-forecasting>=2.0.0
8
+ torch>=2.0.0