omniverse1 commited on
Commit
1a10128
·
verified ·
1 Parent(s): 61b4b05

Update model_handler.py

Browse files
Files changed (1) hide show
  1. model_handler.py +5 -8
model_handler.py CHANGED
@@ -1,15 +1,13 @@
1
  import numpy as np
2
  import torch
3
- # Menggunakan ChronosPipeline untuk pemuatan dan inferensi yang efisien
4
- # PENTING: Class ini hanya tersedia jika 'chronos-forecasting' terinstal
5
- from chronos import ChronosPipeline
6
 
7
  class ModelHandler:
8
  def __init__(self):
9
- # Mengganti model lama dengan Chronos-2
10
  self.model_name = "amazon/chronos-2"
11
  self.pipeline = None
12
- # Penentuan device: "cuda" jika ada GPU, jika tidak "cpu"
13
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
14
  self.load_model()
15
 
@@ -18,7 +16,7 @@ class ModelHandler:
18
  try:
19
  print(f"Loading {self.model_name} on {self.device}...")
20
 
21
- # Ini adalah fix-nya: ChronosPipeline.from_pretrained menangani semua konfigurasi
22
  self.pipeline = ChronosPipeline.from_pretrained(
23
  self.model_name,
24
  device_map=self.device,
@@ -31,7 +29,7 @@ class ModelHandler:
31
  self.pipeline = None
32
 
33
  def predict(self, data, horizon=10):
34
- """Generate predictions using Chronos-2 or fallback. 'data' must be the dict from data_processor.prepare_for_chronos."""
35
  try:
36
  # Memastikan data valid
37
  if data is None or not isinstance(data, dict) or 'original' not in data or len(data['original']) < 20:
@@ -59,7 +57,6 @@ class ModelHandler:
59
  num_samples=20
60
  )
61
 
62
- # Mengambil nilai rata-rata (mean) dari semua sampel untuk plot garis tunggal
63
  mean_predictions = np.mean(predictions_samples, axis=0)
64
 
65
  return mean_predictions
 
1
  import numpy as np
2
  import torch
3
+ # PENTING: Class ini adalah satu-satunya cara yang benar untuk memuat Chronos-2
4
+ # Memerlukan instalasi: git+https://github.com/amazon-science/chronos-forecasting.git
5
+ from chronos import Chronos2Pipeline
6
 
7
  class ModelHandler:
8
  def __init__(self):
 
9
  self.model_name = "amazon/chronos-2"
10
  self.pipeline = None
 
11
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
12
  self.load_model()
13
 
 
16
  try:
17
  print(f"Loading {self.model_name} on {self.device}...")
18
 
19
+ # FIX UTAMA: Pemuatan otomatis oleh pipeline
20
  self.pipeline = ChronosPipeline.from_pretrained(
21
  self.model_name,
22
  device_map=self.device,
 
29
  self.pipeline = None
30
 
31
  def predict(self, data, horizon=10):
32
+ """Generate predictions using Chronos-2 or fallback."""
33
  try:
34
  # Memastikan data valid
35
  if data is None or not isinstance(data, dict) or 'original' not in data or len(data['original']) < 20:
 
57
  num_samples=20
58
  )
59
 
 
60
  mean_predictions = np.mean(predictions_samples, axis=0)
61
 
62
  return mean_predictions