omniverse1 commited on
Commit
6bf1eb6
·
verified ·
1 Parent(s): c5edd45

Update Gradio app with multiple files

Browse files
Files changed (2) hide show
  1. model_handler.py +48 -20
  2. requirements.txt +3 -2
model_handler.py CHANGED
@@ -1,6 +1,15 @@
1
  import numpy as np
2
  import torch
3
- from chronos import BaseChronosPipeline
 
 
 
 
 
 
 
 
 
4
 
5
  class ModelHandler:
6
  def __init__(self):
@@ -11,6 +20,10 @@ class ModelHandler:
11
 
12
  def load_model(self):
13
  """Load Chronos-2 model using the official BaseChronosPipeline"""
 
 
 
 
14
  try:
15
  print(f"Loading {self.model_name} on {self.device}...")
16
 
@@ -27,30 +40,18 @@ class ModelHandler:
27
 
28
  def predict(self, data, horizon=10):
29
  """Generate predictions using Chronos-2 or fallback."""
 
 
 
 
30
  try:
31
  if data is None or not isinstance(data, dict) or 'original' not in data or len(data['original']) < 20:
32
- return np.array([0] * horizon)
33
-
34
- if self.pipeline is None:
35
- # --- Fallback Logic ---
36
- values = data['original']
37
- recent_trend = np.polyfit(range(len(values[-20:])), values[-20:], 1)[0]
38
-
39
- predictions = []
40
- last_value = values[-1]
41
-
42
- for i in range(horizon):
43
- next_value = last_value + recent_trend * (i + 1)
44
- noise = np.random.normal(0, data.get('std', 1.0) * 0.1)
45
- predictions.append(next_value + noise)
46
-
47
- return np.array(predictions)
48
 
49
  # --- Chronos-2 Inference ---
50
  predictions_samples = self.pipeline.predict(
51
  data['original'],
52
  prediction_length=horizon,
53
- # FIX UTAMA: Menghapus 'n_samples'
54
  )
55
 
56
  # Mengambil nilai rata-rata (mean) dari semua sampel atau single trajectory
@@ -64,5 +65,32 @@ class ModelHandler:
64
  return mean_predictions
65
 
66
  except Exception as e:
67
- print(f"Prediction error: {e}")
68
- return np.array([0] * horizon)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import torch
3
+ import warnings
4
+
5
+ # Make chronos import optional
6
+ try:
7
+ from chronos import BaseChronosPipeline
8
+ CHRONOS_AVAILABLE = True
9
+ except ImportError:
10
+ warnings.warn("Chronos-forecasting not available. Using fallback predictions.")
11
+ CHRONOS_AVAILABLE = False
12
+ BaseChronosPipeline = None
13
 
14
  class ModelHandler:
15
  def __init__(self):
 
20
 
21
  def load_model(self):
22
  """Load Chronos-2 model using the official BaseChronosPipeline"""
23
+ if not CHRONOS_AVAILABLE:
24
+ print("Chronos-forecasting not installed. Using fallback prediction method.")
25
+ return
26
+
27
  try:
28
  print(f"Loading {self.model_name} on {self.device}...")
29
 
 
40
 
41
  def predict(self, data, horizon=10):
42
  """Generate predictions using Chronos-2 or fallback."""
43
+ if not CHRONOS_AVAILABLE or self.pipeline is None:
44
+ # Fallback to simple trend-based prediction
45
+ return self._fallback_predict(data, horizon)
46
+
47
  try:
48
  if data is None or not isinstance(data, dict) or 'original' not in data or len(data['original']) < 20:
49
+ return self._fallback_predict(data, horizon)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  # --- Chronos-2 Inference ---
52
  predictions_samples = self.pipeline.predict(
53
  data['original'],
54
  prediction_length=horizon,
 
55
  )
56
 
57
  # Mengambil nilai rata-rata (mean) dari semua sampel atau single trajectory
 
65
  return mean_predictions
66
 
67
  except Exception as e:
68
+ print(f"Prediction error with Chronos: {e}. Using fallback.")
69
+ return self._fallback_predict(data, horizon)
70
+
71
+ def _fallback_predict(self, data, horizon=10):
72
+ """Fallback prediction method when Chronos is unavailable"""
73
+ try:
74
+ if data is None or not isinstance(data, dict) or 'original' not in data:
75
+ # Return zero predictions if no data
76
+ return np.zeros(horizon)
77
+
78
+ values = data['original']
79
+ if len(values) < 5:
80
+ return np.zeros(horizon)
81
+
82
+ # Simple trend extrapolation
83
+ recent_trend = np.polyfit(range(len(values[-20:])), values[-20:], 1)[0]
84
+ predictions = []
85
+ last_value = values[-1]
86
+
87
+ for i in range(horizon):
88
+ next_value = last_value + recent_trend * (i + 1)
89
+ noise = np.random.normal(0, data.get('std', 1.0) * 0.1)
90
+ predictions.append(next_value + noise)
91
+
92
+ return np.array(predictions)
93
+
94
+ except Exception as e:
95
+ print(f"Fallback prediction error: {e}")
96
+ return np.zeros(horizon)
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- === requirements.txt ===
2
  pandas
3
  plotly
4
  numpy
@@ -18,4 +17,6 @@ tokenizers
18
  yfinance
19
  scipy
20
  joblib
21
- chronos-forecasting
 
 
 
 
1
  pandas
2
  plotly
3
  numpy
 
17
  yfinance
18
  scipy
19
  joblib
20
+ chronos-forecasting
21
+ safetensors
22
+ huggingface-hub