omniverse1 commited on
Commit
08f6049
·
verified ·
1 Parent(s): cd2ec3f

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +39 -26
utils.py CHANGED
@@ -9,7 +9,7 @@ import spaces
9
  import gc
10
  import time
11
  import random
12
- from chronos import ChronosPipeline # Menggunakan ChronosPipeline untuk Chronos-2
13
  from scipy.stats import skew, kurtosis
14
  from typing import Dict, Union, List
15
 
@@ -44,15 +44,14 @@ def load_pipeline():
44
  clear_gpu_memory()
45
  print(f"Loading Chronos model: {model_name}...")
46
 
47
- # PENTING: Optimasi untuk Chronos-2
48
  pipeline = ChronosPipeline.from_pretrained(
49
  model_name,
50
  device_map="cuda",
51
  torch_dtype=torch.float16,
52
- low_cpu_mem_usage=True,
53
- trust_remote_code=True,
54
- use_safetensors=True
55
  )
 
56
  pipeline.model = pipeline.model.eval()
57
  for param in pipeline.model.parameters():
58
  param.requires_grad = False
@@ -61,6 +60,7 @@ def load_pipeline():
61
  return pipeline
62
 
63
  except Exception as e:
 
64
  print(f"Error loading pipeline on CUDA, trying CPU: {str(e)}")
65
  try:
66
  # Fallback ke CPU
@@ -71,6 +71,8 @@ def load_pipeline():
71
  except Exception as cpu_e:
72
  raise RuntimeError(f"Failed to load model {model_name} on both CUDA and CPU: {str(cpu_e)}")
73
 
 
 
74
  def retry_yfinance_request(func, max_retries=3, initial_delay=1):
75
  """Mekanisme retry untuk permintaan yfinance dengan backoff eksponensial."""
76
  for attempt in range(max_retries):
@@ -176,6 +178,10 @@ def predict_technical_indicators_future(data: pd.DataFrame, price_prediction: np
176
  """Memprediksi MACD dan Bollinger Bands di masa depan berdasarkan prediksi harga."""
177
  predictions = {}
178
 
 
 
 
 
179
  full_price_series = np.concatenate([data['Close'].values, price_prediction])
180
  full_price_series = pd.Series(full_price_series)
181
 
@@ -208,8 +214,19 @@ def predict_technical_indicators_future(data: pd.DataFrame, price_prediction: np
208
  @spaces.GPU(duration=120)
209
  def predict_prices(data, prediction_days=30):
210
  """Fungsi prediksi utama menggunakan Chronos-2 dengan enhanced covariates."""
 
 
 
 
 
 
 
 
 
 
 
211
  try:
212
- # 1. Load Model
213
  pipeline = load_pipeline()
214
 
215
  data_original = data.copy()
@@ -265,21 +282,23 @@ def predict_prices(data, prediction_days=30):
265
  'change_pct': change_pct,
266
  'q01': q01_forecast,
267
  'q09': q09_forecast,
268
- 'future_macd': future_indicators.get('MACD_Future', []),
269
- 'future_macd_signal': future_indicators.get('MACD_Signal_Future', []),
270
- 'future_bb_upper': future_indicators.get('BB_Upper_Future', []),
271
- 'future_bb_lower': future_indicators.get('BB_Lower_Future', []),
272
  'summary': f"AI Model: Amazon Chronos-2 (Enhanced Covariates: {len(all_covariates)} features)\nExpected High: {predicted_high:.2f}\nExpected Low: {predicted_low:.2f}\nExpected Change: {change_pct:.2f}%"
273
  }
274
 
275
  except Exception as e:
276
  error_message = f'Model prediction failed: {e}'
277
  print(f"Error in prediction: {e}")
278
- return {'values': [], 'dates': [], 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0, 'summary': error_message, 'q01': [], 'q09': [], 'future_macd': [], 'future_macd_signal': [], 'future_bb_upper': [], 'future_bb_lower': []}
 
279
 
280
  # Memperbarui fungsi create_prediction_chart untuk menampilkan Quantile Bands (q01, q09) dan Future BB
281
  def create_prediction_chart(data, predictions):
282
- if not len(predictions['values']) or not len(predictions['q01']):
 
283
  return go.Figure().update_layout(title="Prediction Failed: No Data Available")
284
 
285
  fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.05,
@@ -299,7 +318,7 @@ def create_prediction_chart(data, predictions):
299
  fig.add_trace(go.Scatter(x=predictions['dates'], y=predictions['values'], name='Median Forecast (Q0.5)', line=dict(color='red', width=3, dash='solid')), row=1, col=1)
300
 
301
  # Future Bollinger Bands
302
- if len(predictions['future_bb_upper']) == len(predictions['dates']):
303
  fig.add_trace(go.Scatter(x=predictions['dates'], y=predictions['future_bb_upper'], name='BB Upper (Future)', line=dict(color='green', width=1, dash='dot')), row=1, col=1)
304
  fig.add_trace(go.Scatter(x=predictions['dates'], y=predictions['future_bb_lower'], name='BB Lower (Future)', line=dict(color='green', width=1, dash='dot')), row=1, col=1)
305
 
@@ -308,14 +327,16 @@ def create_prediction_chart(data, predictions):
308
  fig.add_trace(go.Scatter(x=[last_hist_date], y=[last_hist_price], mode='markers', marker=dict(size=10, color='blue', symbol='circle'), name='Last Known Price'), row=1, col=1)
309
 
310
  # 2. MACD Forecast (Row 2)
311
- if len(predictions['future_macd']) == len(predictions['dates']):
312
 
 
 
313
  macd_hist = data['Close'].ewm(span=12).mean() - data['Close'].ewm(span=26).mean()
314
  macd_signal_hist = macd_hist.ewm(span=9).mean()
315
 
316
- macd_full = np.concatenate([macd_hist.iloc[-60:].values, predictions['future_macd']])
317
- macd_signal_full = np.concatenate([macd_signal_hist.iloc[-60:].values, predictions['future_macd_signal']])
318
- macd_dates_full = pd.to_datetime(np.concatenate([data.index[-60:].values, predictions['dates']]))
319
 
320
  fig.add_trace(go.Scatter(x=macd_dates_full, y=macd_full, name='MACD Line', line=dict(color='blue', width=2)), row=2, col=1)
321
  fig.add_trace(go.Scatter(x=macd_dates_full, y=macd_signal_full, name='Signal Line', line=dict(color='red', width=1)), row=2, col=1)
@@ -334,9 +355,8 @@ def create_prediction_chart(data, predictions):
334
 
335
  return fig
336
 
337
- # --- Fungsi lama yang harus tetap ada ---
338
  def get_indonesian_stocks():
339
- # ... (kode yang sama)
340
  return {
341
  "BBCA.JK": "Bank Central Asia", "BBRI.JK": "Bank BRI", "BBNI.JK": "Bank BNI",
342
  "BMRI.JK": "Bank Mandiri", "TLKM.JK": "Telkom Indonesia", "UNVR.JK": "Unilever Indonesia",
@@ -348,7 +368,6 @@ def get_indonesian_stocks():
348
  }
349
 
350
  def calculate_technical_indicators(data):
351
- # Disesuaikan agar dapat menambahkan RSI, MACD, Signal ke DataFrame
352
  indicators = {}
353
 
354
  def calculate_rsi(prices, period=14):
@@ -391,7 +410,6 @@ def calculate_technical_indicators(data):
391
  indicators['volume'] = {'current': data['Volume'].iloc[-1], 'avg_20': data['Volume'].rolling(20).mean().iloc[-1], 'ratio': data['Volume'].iloc[-1] / data['Volume'].rolling(20).mean().iloc[-1]}
392
 
393
  # Tambahkan kolom indikator ke DataFrame input untuk digunakan nanti (di predict_technical_indicators_future)
394
- # Catatan: Perubahan ini memodifikasi 'data' in-place.
395
  data['RSI'] = rsi_series
396
  data['MACD'] = macd
397
  data['MACD_Signal'] = signal_line
@@ -399,7 +417,6 @@ def calculate_technical_indicators(data):
399
  return indicators
400
 
401
  def generate_trading_signals(data, indicators):
402
- # ... (kode yang sama)
403
  signals = {}
404
  current_price = data['Close'].iloc[-1]
405
  buy_signals = 0
@@ -458,7 +475,6 @@ def generate_trading_signals(data, indicators):
458
  return signals
459
 
460
  def get_fundamental_data(stock):
461
- # ... (kode yang sama)
462
  try:
463
  info = stock.info
464
  history = stock.history(period="1d")
@@ -468,7 +484,6 @@ def get_fundamental_data(stock):
468
  return {'name': 'N/A', 'current_price': 0, 'market_cap': 0, 'pe_ratio': 0, 'dividend_yield': 0, 'volume': 0, 'info': 'Unable to fetch fundamental data'}
469
 
470
  def format_large_number(num):
471
- # ... (kode yang sama)
472
  if num >= 1e12:
473
  return f"{num/1e12:.2f}T"
474
  elif num >= 1e9:
@@ -481,7 +496,6 @@ def format_large_number(num):
481
  return f"{num:.2f}"
482
 
483
  def create_price_chart(data, indicators):
484
- # ... (kode yang sama)
485
  fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.05)
486
  fig.add_trace(go.Candlestick(x=data.index, open=data['Open'], high=data['High'], low=data['Low'], close=data['Close'], name='Price'), row=1, col=1)
487
  fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_20_values'], name='SMA 20', line=dict(color='orange')), row=1, col=1)
@@ -493,7 +507,6 @@ def create_price_chart(data, indicators):
493
  return fig
494
 
495
  def create_technical_chart(data, indicators):
496
- # ... (kode yang sama)
497
  fig = make_subplots(rows=2, cols=2, subplot_titles=('Bollinger Bands', 'Volume', 'Price vs MA', 'RSI Analysis'))
498
  fig.add_trace(go.Scatter(x=data.index, y=data['Close'], name='Price', line=dict(color='black')), row=1, col=1)
499
  fig.add_trace(go.Scatter(x=data.index, y=indicators['bollinger']['upper_values'], name='Upper Band', line=dict(color='red')), row=1, col=1)
 
9
  import gc
10
  import time
11
  import random
12
+ from chronos import ChronosPipeline
13
  from scipy.stats import skew, kurtosis
14
  from typing import Dict, Union, List
15
 
 
44
  clear_gpu_memory()
45
  print(f"Loading Chronos model: {model_name}...")
46
 
47
+ # FIX 1: Menyederhanakan argumen untuk menghindari error 'input_patch_size'
48
  pipeline = ChronosPipeline.from_pretrained(
49
  model_name,
50
  device_map="cuda",
51
  torch_dtype=torch.float16,
52
+ # Menghapus argumen yang mungkin memicu error konfigurasi
 
 
53
  )
54
+
55
  pipeline.model = pipeline.model.eval()
56
  for param in pipeline.model.parameters():
57
  param.requires_grad = False
 
60
  return pipeline
61
 
62
  except Exception as e:
63
+ # Menampilkan error yang lebih spesifik
64
  print(f"Error loading pipeline on CUDA, trying CPU: {str(e)}")
65
  try:
66
  # Fallback ke CPU
 
71
  except Exception as cpu_e:
72
  raise RuntimeError(f"Failed to load model {model_name} on both CUDA and CPU: {str(cpu_e)}")
73
 
74
+ # ... (Fungsi-fungsi lain: retry_yfinance_request, fetch_enhanced_covariates, calculate_advanced_risk_metrics)
75
+
76
  def retry_yfinance_request(func, max_retries=3, initial_delay=1):
77
  """Mekanisme retry untuk permintaan yfinance dengan backoff eksponensial."""
78
  for attempt in range(max_retries):
 
178
  """Memprediksi MACD dan Bollinger Bands di masa depan berdasarkan prediksi harga."""
179
  predictions = {}
180
 
181
+ # Pastikan price_prediction tidak kosong sebelum diolah
182
+ if price_prediction.size == 0:
183
+ return {"MACD_Future": np.array([]), "MACD_Signal_Future": np.array([]), "BB_Upper_Future": np.array([]), "BB_Lower_Future": np.array([])}
184
+
185
  full_price_series = np.concatenate([data['Close'].values, price_prediction])
186
  full_price_series = pd.Series(full_price_series)
187
 
 
214
  @spaces.GPU(duration=120)
215
  def predict_prices(data, prediction_days=30):
216
  """Fungsi prediksi utama menggunakan Chronos-2 dengan enhanced covariates."""
217
+
218
+ # Default return structure for errors (FIX 2: Menggunakan np.array([]) yang aman)
219
+ empty_result = {
220
+ 'values': np.array([]), 'dates': pd.Series([], dtype='datetime64[ns]'),
221
+ 'high_30d': 0, 'low_30d': 0, 'mean_30d': 0, 'change_pct': 0,
222
+ 'q01': np.array([]), 'q09': np.array([]),
223
+ 'future_macd': np.array([]), 'future_macd_signal': np.array([]),
224
+ 'future_bb_upper': np.array([]), 'future_bb_lower': np.array([]),
225
+ 'summary': 'Prediction failed due to model or data error.'
226
+ }
227
+
228
  try:
229
+ # 1. Load Model (Akan memanggil load_pipeline yang sudah diperbaiki)
230
  pipeline = load_pipeline()
231
 
232
  data_original = data.copy()
 
282
  'change_pct': change_pct,
283
  'q01': q01_forecast,
284
  'q09': q09_forecast,
285
+ 'future_macd': future_indicators.get('MACD_Future', np.array([])),
286
+ 'future_macd_signal': future_indicators.get('MACD_Signal_Future', np.array([])),
287
+ 'future_bb_upper': future_indicators.get('BB_Upper_Future', np.array([])),
288
+ 'future_bb_lower': future_indicators.get('BB_Lower_Future', np.array([])),
289
  'summary': f"AI Model: Amazon Chronos-2 (Enhanced Covariates: {len(all_covariates)} features)\nExpected High: {predicted_high:.2f}\nExpected Low: {predicted_low:.2f}\nExpected Change: {change_pct:.2f}%"
290
  }
291
 
292
  except Exception as e:
293
  error_message = f'Model prediction failed: {e}'
294
  print(f"Error in prediction: {e}")
295
+ empty_result['summary'] = error_message
296
+ return empty_result
297
 
298
  # Memperbarui fungsi create_prediction_chart untuk menampilkan Quantile Bands (q01, q09) dan Future BB
299
  def create_prediction_chart(data, predictions):
300
+ # Cek yang lebih aman untuk array kosong
301
+ if not predictions['values'].size or not predictions['q01'].size:
302
  return go.Figure().update_layout(title="Prediction Failed: No Data Available")
303
 
304
  fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.05,
 
318
  fig.add_trace(go.Scatter(x=predictions['dates'], y=predictions['values'], name='Median Forecast (Q0.5)', line=dict(color='red', width=3, dash='solid')), row=1, col=1)
319
 
320
  # Future Bollinger Bands
321
+ if predictions['future_bb_upper'].size == predictions['dates'].size:
322
  fig.add_trace(go.Scatter(x=predictions['dates'], y=predictions['future_bb_upper'], name='BB Upper (Future)', line=dict(color='green', width=1, dash='dot')), row=1, col=1)
323
  fig.add_trace(go.Scatter(x=predictions['dates'], y=predictions['future_bb_lower'], name='BB Lower (Future)', line=dict(color='green', width=1, dash='dot')), row=1, col=1)
324
 
 
327
  fig.add_trace(go.Scatter(x=[last_hist_date], y=[last_hist_price], mode='markers', marker=dict(size=10, color='blue', symbol='circle'), name='Last Known Price'), row=1, col=1)
328
 
329
  # 2. MACD Forecast (Row 2)
330
+ if predictions['future_macd'].size == predictions['dates'].size:
331
 
332
+ # Perluas data historis MACD untuk charting yang lebih baik
333
+ lookback_period = 60
334
  macd_hist = data['Close'].ewm(span=12).mean() - data['Close'].ewm(span=26).mean()
335
  macd_signal_hist = macd_hist.ewm(span=9).mean()
336
 
337
+ macd_full = np.concatenate([macd_hist.iloc[-lookback_period:].values, predictions['future_macd']])
338
+ macd_signal_full = np.concatenate([macd_signal_hist.iloc[-lookback_period:].values, predictions['future_macd_signal']])
339
+ macd_dates_full = pd.to_datetime(np.concatenate([data.index[-lookback_period:].values, predictions['dates']]))
340
 
341
  fig.add_trace(go.Scatter(x=macd_dates_full, y=macd_full, name='MACD Line', line=dict(color='blue', width=2)), row=2, col=1)
342
  fig.add_trace(go.Scatter(x=macd_dates_full, y=macd_signal_full, name='Signal Line', line=dict(color='red', width=1)), row=2, col=1)
 
355
 
356
  return fig
357
 
358
+ # ... (Fungsi-fungsi lama lainnya seperti get_indonesian_stocks, calculate_technical_indicators, dll. tetap sama)
359
  def get_indonesian_stocks():
 
360
  return {
361
  "BBCA.JK": "Bank Central Asia", "BBRI.JK": "Bank BRI", "BBNI.JK": "Bank BNI",
362
  "BMRI.JK": "Bank Mandiri", "TLKM.JK": "Telkom Indonesia", "UNVR.JK": "Unilever Indonesia",
 
368
  }
369
 
370
  def calculate_technical_indicators(data):
 
371
  indicators = {}
372
 
373
  def calculate_rsi(prices, period=14):
 
410
  indicators['volume'] = {'current': data['Volume'].iloc[-1], 'avg_20': data['Volume'].rolling(20).mean().iloc[-1], 'ratio': data['Volume'].iloc[-1] / data['Volume'].rolling(20).mean().iloc[-1]}
411
 
412
  # Tambahkan kolom indikator ke DataFrame input untuk digunakan nanti (di predict_technical_indicators_future)
 
413
  data['RSI'] = rsi_series
414
  data['MACD'] = macd
415
  data['MACD_Signal'] = signal_line
 
417
  return indicators
418
 
419
  def generate_trading_signals(data, indicators):
 
420
  signals = {}
421
  current_price = data['Close'].iloc[-1]
422
  buy_signals = 0
 
475
  return signals
476
 
477
  def get_fundamental_data(stock):
 
478
  try:
479
  info = stock.info
480
  history = stock.history(period="1d")
 
484
  return {'name': 'N/A', 'current_price': 0, 'market_cap': 0, 'pe_ratio': 0, 'dividend_yield': 0, 'volume': 0, 'info': 'Unable to fetch fundamental data'}
485
 
486
  def format_large_number(num):
 
487
  if num >= 1e12:
488
  return f"{num/1e12:.2f}T"
489
  elif num >= 1e9:
 
496
  return f"{num:.2f}"
497
 
498
  def create_price_chart(data, indicators):
 
499
  fig = make_subplots(rows=3, cols=1, shared_xaxes=True, vertical_spacing=0.05)
500
  fig.add_trace(go.Candlestick(x=data.index, open=data['Open'], high=data['High'], low=data['Low'], close=data['Close'], name='Price'), row=1, col=1)
501
  fig.add_trace(go.Scatter(x=data.index, y=indicators['moving_averages']['sma_20_values'], name='SMA 20', line=dict(color='orange')), row=1, col=1)
 
507
  return fig
508
 
509
  def create_technical_chart(data, indicators):
 
510
  fig = make_subplots(rows=2, cols=2, subplot_titles=('Bollinger Bands', 'Volume', 'Price vs MA', 'RSI Analysis'))
511
  fig.add_trace(go.Scatter(x=data.index, y=data['Close'], name='Price', line=dict(color='black')), row=1, col=1)
512
  fig.add_trace(go.Scatter(x=data.index, y=indicators['bollinger']['upper_values'], name='Upper Band', line=dict(color='red')), row=1, col=1)