Spaces:
Sleeping
Sleeping
Update model_handler.py
Browse files- model_handler.py +2 -5
model_handler.py
CHANGED
|
@@ -1,11 +1,9 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
-
# Menggunakan ChronosPipeline untuk pemuatan dan inferensi yang efisien
|
| 4 |
from chronos import ChronosPipeline
|
| 5 |
|
| 6 |
class ModelHandler:
|
| 7 |
def __init__(self):
|
| 8 |
-
# Mengganti model lama dengan Chronos-2 yang lebih canggih
|
| 9 |
self.model_name = "amazon/chronos-2"
|
| 10 |
self.pipeline = None
|
| 11 |
# Penentuan device: "cuda" jika ada GPU, jika tidak "cpu"
|
|
@@ -34,11 +32,12 @@ class ModelHandler:
|
|
| 34 |
"""Generate predictions using Chronos-2 or fallback. 'data' must be the dict from data_processor.prepare_for_chronos."""
|
| 35 |
try:
|
| 36 |
# Cek data: memastikan data yang masuk adalah dictionary yang valid
|
|
|
|
| 37 |
if data is None or not isinstance(data, dict) or 'original' not in data or len(data['original']) < 20:
|
| 38 |
return np.array([0] * horizon)
|
| 39 |
|
| 40 |
if self.pipeline is None:
|
| 41 |
-
# --- Fallback Logic
|
| 42 |
values = data['original']
|
| 43 |
recent_trend = np.polyfit(range(len(values[-20:])), values[-20:], 1)[0]
|
| 44 |
|
|
@@ -46,9 +45,7 @@ class ModelHandler:
|
|
| 46 |
last_value = values[-1]
|
| 47 |
|
| 48 |
for i in range(horizon):
|
| 49 |
-
# Add trend with some noise
|
| 50 |
next_value = last_value + recent_trend * (i + 1)
|
| 51 |
-
# Use .get('std', 1.0) for safety
|
| 52 |
noise = np.random.normal(0, data.get('std', 1.0) * 0.1)
|
| 53 |
predictions.append(next_value + noise)
|
| 54 |
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
|
|
|
| 3 |
from chronos import ChronosPipeline
|
| 4 |
|
| 5 |
class ModelHandler:
|
| 6 |
def __init__(self):
|
|
|
|
| 7 |
self.model_name = "amazon/chronos-2"
|
| 8 |
self.pipeline = None
|
| 9 |
# Penentuan device: "cuda" jika ada GPU, jika tidak "cpu"
|
|
|
|
| 32 |
"""Generate predictions using Chronos-2 or fallback. 'data' must be the dict from data_processor.prepare_for_chronos."""
|
| 33 |
try:
|
| 34 |
# Cek data: memastikan data yang masuk adalah dictionary yang valid
|
| 35 |
+
# Fix error 'original' in fallback logic by ensuring proper data check
|
| 36 |
if data is None or not isinstance(data, dict) or 'original' not in data or len(data['original']) < 20:
|
| 37 |
return np.array([0] * horizon)
|
| 38 |
|
| 39 |
if self.pipeline is None:
|
| 40 |
+
# --- Fallback Logic ---
|
| 41 |
values = data['original']
|
| 42 |
recent_trend = np.polyfit(range(len(values[-20:])), values[-20:], 1)[0]
|
| 43 |
|
|
|
|
| 45 |
last_value = values[-1]
|
| 46 |
|
| 47 |
for i in range(horizon):
|
|
|
|
| 48 |
next_value = last_value + recent_trend * (i + 1)
|
|
|
|
| 49 |
noise = np.random.normal(0, data.get('std', 1.0) * 0.1)
|
| 50 |
predictions.append(next_value + noise)
|
| 51 |
|