omniverse1 commited on
Commit
9927daa
·
verified ·
1 Parent(s): a90bc0e

Update model_handler.py

Browse files
Files changed (1) hide show
  1. model_handler.py +12 -13
model_handler.py CHANGED
@@ -1,7 +1,7 @@
1
  import numpy as np
2
  import torch
3
  # Menggunakan ChronosPipeline untuk pemuatan dan inferensi yang efisien
4
- from chronos import ChronosPipeline
5
 
6
  class ModelHandler:
7
  def __init__(self):
@@ -13,11 +13,11 @@ class ModelHandler:
13
  self.load_model()
14
 
15
  def load_model(self):
16
- """Load Chronos-2 model optimized for CPU/GPU"""
17
  try:
18
  print(f"Loading {self.model_name} on {self.device}...")
19
 
20
- # ChronosPipeline menangani semua proses tokenisasi dan pemuatan arsitektur
21
  self.pipeline = ChronosPipeline.from_pretrained(
22
  self.model_name,
23
  device_map=self.device,
@@ -25,20 +25,20 @@ class ModelHandler:
25
  print("Chronos-2 pipeline loaded successfully.")
26
 
27
  except Exception as e:
 
28
  print(f"Error loading Chronos-2 model: {e}")
29
  print("Using fallback prediction method")
30
  self.pipeline = None
31
 
32
  def predict(self, data, horizon=10):
33
- """Generate predictions using Chronos-2 or fallback"""
34
  try:
35
- # Menggunakan data['original'] yang merupakan harga aktual riil
36
- if data is None or len(data['original']) < 20:
37
  return np.array([0] * horizon)
38
 
39
  if self.pipeline is None:
40
- # --- Fallback Logic ---
41
- # Logic ekstrapolasi tren lama tetap dipertahankan jika model Deep Learning gagal dimuat
42
  values = data['original']
43
  recent_trend = np.polyfit(range(len(values[-20:])), values[-20:], 1)[0]
44
 
@@ -46,27 +46,26 @@ class ModelHandler:
46
  last_value = values[-1]
47
 
48
  for i in range(horizon):
 
49
  next_value = last_value + recent_trend * (i + 1)
50
- noise = np.random.normal(0, data['std'] * 0.1)
 
51
  predictions.append(next_value + noise)
52
 
53
  return np.array(predictions)
54
 
55
  # --- Chronos-2 Inference ---
56
- # Input: numpy array dari harga Close historis yang riil
57
  predictions_samples = self.pipeline.predict(
58
  data['original'],
59
  prediction_length=horizon,
60
- # Mengambil 20 sampel prediksi untuk mendapatkan prediksi probablistik
61
  num_samples=20
62
  )
63
 
64
- # Untuk chart (garis tunggal), ambil nilai rata-rata (mean) dari semua sampel.
65
  mean_predictions = np.mean(predictions_samples, axis=0)
66
 
67
  return mean_predictions
68
 
69
  except Exception as e:
70
  print(f"Prediction error: {e}")
71
- # Mengembalikan array nol jika ada error saat inferensi Chronos
72
  return np.array([0] * horizon)
 
1
  import numpy as np
2
  import torch
3
  # Menggunakan ChronosPipeline untuk pemuatan dan inferensi yang efisien
4
+ from chronos import ChronosPipeline
5
 
6
  class ModelHandler:
7
  def __init__(self):
 
13
  self.load_model()
14
 
15
  def load_model(self):
16
+ """Load Chronos-2 model using the official ChronosPipeline"""
17
  try:
18
  print(f"Loading {self.model_name} on {self.device}...")
19
 
20
+ # ChronosPipeline menangani semua proses tokenisasi dan pemuatan arsitektur dengan benar
21
  self.pipeline = ChronosPipeline.from_pretrained(
22
  self.model_name,
23
  device_map=self.device,
 
25
  print("Chronos-2 pipeline loaded successfully.")
26
 
27
  except Exception as e:
28
+ # Jika gagal, pipeline akan tetap None, dan fallback akan digunakan
29
  print(f"Error loading Chronos-2 model: {e}")
30
  print("Using fallback prediction method")
31
  self.pipeline = None
32
 
33
  def predict(self, data, horizon=10):
34
+ """Generate predictions using Chronos-2 or fallback. 'data' must be the dict from data_processor.prepare_for_chronos."""
35
  try:
36
+ # Cek data: memastikan data yang masuk adalah dictionary yang valid
37
+ if data is None or not isinstance(data, dict) or 'original' not in data or len(data['original']) < 20:
38
  return np.array([0] * horizon)
39
 
40
  if self.pipeline is None:
41
+ # --- Fallback Logic (Menggunakan data['original']) ---
 
42
  values = data['original']
43
  recent_trend = np.polyfit(range(len(values[-20:])), values[-20:], 1)[0]
44
 
 
46
  last_value = values[-1]
47
 
48
  for i in range(horizon):
49
+ # Add trend with some noise
50
  next_value = last_value + recent_trend * (i + 1)
51
+ # Use .get('std', 1.0) for safety
52
+ noise = np.random.normal(0, data.get('std', 1.0) * 0.1)
53
  predictions.append(next_value + noise)
54
 
55
  return np.array(predictions)
56
 
57
  # --- Chronos-2 Inference ---
 
58
  predictions_samples = self.pipeline.predict(
59
  data['original'],
60
  prediction_length=horizon,
 
61
  num_samples=20
62
  )
63
 
64
+ # Mengambil nilai rata-rata (mean) dari semua sampel untuk plot garis tunggal
65
  mean_predictions = np.mean(predictions_samples, axis=0)
66
 
67
  return mean_predictions
68
 
69
  except Exception as e:
70
  print(f"Prediction error: {e}")
 
71
  return np.array([0] * horizon)