omniverse1 commited on
Commit
29091b3
·
verified ·
1 Parent(s): 88fcfd5

Update model_handler.py

Browse files
Files changed (1) hide show
  1. model_handler.py +5 -7
model_handler.py CHANGED
@@ -1,22 +1,21 @@
1
  import numpy as np
2
  import torch
3
- # PENTING: Mengganti ChronosPipeline dengan BaseChronosPipeline sesuai referensi terbaru
4
  from chronos import BaseChronosPipeline
5
 
6
  class ModelHandler:
7
  def __init__(self):
8
  self.model_name = "amazon/chronos-2"
9
  self.pipeline = None
10
- # Penentuan device
11
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
  self.load_model()
13
 
14
  def load_model(self):
15
- """Load Chronos-2 model using the BaseChronosPipeline"""
16
  try:
17
  print(f"Loading {self.model_name} on {self.device}...")
18
 
19
- # Perhatikan: Menggunakan BaseChronosPipeline.from_pretrained
20
  self.pipeline = BaseChronosPipeline.from_pretrained(
21
  self.model_name,
22
  device_map=self.device,
@@ -31,7 +30,6 @@ class ModelHandler:
31
  def predict(self, data, horizon=10):
32
  """Generate predictions using Chronos-2 or fallback."""
33
  try:
34
- # Memastikan data valid
35
  if data is None or not isinstance(data, dict) or 'original' not in data or len(data['original']) < 20:
36
  return np.array([0] * horizon)
37
 
@@ -51,11 +49,11 @@ class ModelHandler:
51
  return np.array(predictions)
52
 
53
  # --- Chronos-2 Inference ---
54
- # NOTE: BaseChronosPipeline.predict mengembalikan array of arrays (sampel)
55
  predictions_samples = self.pipeline.predict(
56
  data['original'],
57
  prediction_length=horizon,
58
- num_samples=20
 
59
  )
60
 
61
  # Mengambil nilai rata-rata (mean) dari semua sampel
 
1
  import numpy as np
2
  import torch
3
+ # PENTING: Class ini adalah satu-satunya cara yang benar untuk memuat Chronos-2
4
  from chronos import BaseChronosPipeline
5
 
6
  class ModelHandler:
7
  def __init__(self):
8
  self.model_name = "amazon/chronos-2"
9
  self.pipeline = None
 
10
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
11
  self.load_model()
12
 
13
  def load_model(self):
14
+ """Load Chronos-2 model using the official BaseChronosPipeline"""
15
  try:
16
  print(f"Loading {self.model_name} on {self.device}...")
17
 
18
+ # Pemuatan otomatis oleh pipeline (sudah terbukti berhasil di langkah sebelumnya)
19
  self.pipeline = BaseChronosPipeline.from_pretrained(
20
  self.model_name,
21
  device_map=self.device,
 
30
  def predict(self, data, horizon=10):
31
  """Generate predictions using Chronos-2 or fallback."""
32
  try:
 
33
  if data is None or not isinstance(data, dict) or 'original' not in data or len(data['original']) < 20:
34
  return np.array([0] * horizon)
35
 
 
49
  return np.array(predictions)
50
 
51
  # --- Chronos-2 Inference ---
 
52
  predictions_samples = self.pipeline.predict(
53
  data['original'],
54
  prediction_length=horizon,
55
+ # KOREKSI: Mengganti 'num_samples' menjadi 'n_samples'
56
+ n_samples=20
57
  )
58
 
59
  # Mengambil nilai rata-rata (mean) dari semua sampel