Spaces:
Sleeping
Sleeping
Update model_handler.py
Browse files- model_handler.py +5 -4
model_handler.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
|
|
|
|
|
|
| 3 |
from chronos import ChronosPipeline
|
| 4 |
|
| 5 |
class ModelHandler:
|
| 6 |
def __init__(self):
|
|
|
|
| 7 |
self.model_name = "amazon/chronos-2"
|
| 8 |
self.pipeline = None
|
| 9 |
# Penentuan device: "cuda" jika ada GPU, jika tidak "cpu"
|
|
@@ -15,7 +18,7 @@ class ModelHandler:
|
|
| 15 |
try:
|
| 16 |
print(f"Loading {self.model_name} on {self.device}...")
|
| 17 |
|
| 18 |
-
# ChronosPipeline menangani semua
|
| 19 |
self.pipeline = ChronosPipeline.from_pretrained(
|
| 20 |
self.model_name,
|
| 21 |
device_map=self.device,
|
|
@@ -23,7 +26,6 @@ class ModelHandler:
|
|
| 23 |
print("Chronos-2 pipeline loaded successfully.")
|
| 24 |
|
| 25 |
except Exception as e:
|
| 26 |
-
# Jika gagal, pipeline akan tetap None, dan fallback akan digunakan
|
| 27 |
print(f"Error loading Chronos-2 model: {e}")
|
| 28 |
print("Using fallback prediction method")
|
| 29 |
self.pipeline = None
|
|
@@ -31,8 +33,7 @@ class ModelHandler:
|
|
| 31 |
def predict(self, data, horizon=10):
|
| 32 |
"""Generate predictions using Chronos-2 or fallback. 'data' must be the dict from data_processor.prepare_for_chronos."""
|
| 33 |
try:
|
| 34 |
-
#
|
| 35 |
-
# Fix error 'original' in fallback logic by ensuring proper data check
|
| 36 |
if data is None or not isinstance(data, dict) or 'original' not in data or len(data['original']) < 20:
|
| 37 |
return np.array([0] * horizon)
|
| 38 |
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
+
# Menggunakan ChronosPipeline untuk pemuatan dan inferensi yang efisien
|
| 4 |
+
# PENTING: Class ini hanya tersedia jika 'chronos-forecasting' terinstal
|
| 5 |
from chronos import ChronosPipeline
|
| 6 |
|
| 7 |
class ModelHandler:
|
| 8 |
def __init__(self):
|
| 9 |
+
# Mengganti model lama dengan Chronos-2
|
| 10 |
self.model_name = "amazon/chronos-2"
|
| 11 |
self.pipeline = None
|
| 12 |
# Penentuan device: "cuda" jika ada GPU, jika tidak "cpu"
|
|
|
|
| 18 |
try:
|
| 19 |
print(f"Loading {self.model_name} on {self.device}...")
|
| 20 |
|
| 21 |
+
# Ini adalah fix-nya: ChronosPipeline.from_pretrained menangani semua konfigurasi
|
| 22 |
self.pipeline = ChronosPipeline.from_pretrained(
|
| 23 |
self.model_name,
|
| 24 |
device_map=self.device,
|
|
|
|
| 26 |
print("Chronos-2 pipeline loaded successfully.")
|
| 27 |
|
| 28 |
except Exception as e:
|
|
|
|
| 29 |
print(f"Error loading Chronos-2 model: {e}")
|
| 30 |
print("Using fallback prediction method")
|
| 31 |
self.pipeline = None
|
|
|
|
| 33 |
def predict(self, data, horizon=10):
|
| 34 |
"""Generate predictions using Chronos-2 or fallback. 'data' must be the dict from data_processor.prepare_for_chronos."""
|
| 35 |
try:
|
| 36 |
+
# Memastikan data valid
|
|
|
|
| 37 |
if data is None or not isinstance(data, dict) or 'original' not in data or len(data['original']) < 20:
|
| 38 |
return np.array([0] * horizon)
|
| 39 |
|