Spaces:
Running
Running
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -157,6 +157,7 @@ def format_large_number(num):
|
|
| 157 |
@spaces.GPU(duration=120)
|
| 158 |
def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
| 159 |
try:
|
|
|
|
| 160 |
pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2", device_map="auto")
|
| 161 |
|
| 162 |
# Chronos-2 with Covariate: Menggunakan Close (target) dan Volume (covariate)
|
|
@@ -164,16 +165,21 @@ def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
|
| 164 |
context_df.columns = ['timestamp', 'target', 'volume']
|
| 165 |
context_df['id'] = 'stock_price'
|
| 166 |
|
| 167 |
-
# Fix Error: Could not infer frequency
|
| 168 |
context_df['timestamp'] = pd.to_datetime(context_df['timestamp'])
|
| 169 |
-
context_df = context_df.set_index('timestamp').asfreq('D')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
# Pastikan kolom sesuai urutan Chronos-2: timestamp, target, covariate(s), id
|
| 172 |
context_df['id'] = 'stock_price'
|
| 173 |
context_df = context_df[['timestamp', 'target', 'volume', 'id']]
|
| 174 |
|
| 175 |
with torch.no_grad():
|
| 176 |
-
# Menggunakan multiple quantile levels untuk Probabilistic Forecast (90% CI)
|
| 177 |
pred_df = pipeline.predict_df(
|
| 178 |
context_df,
|
| 179 |
prediction_length=prediction_days,
|
|
@@ -183,12 +189,13 @@ def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
|
| 183 |
quantile_levels=[0.1, 0.5, 0.9]
|
| 184 |
)
|
| 185 |
|
| 186 |
-
# --- FIX UTAMA: Pengecekan kolom hasil prediksi ---
|
| 187 |
required_cols = ['target_0.1', 'target_0.5', 'target_0.9']
|
| 188 |
if pred_df.empty or not all(col in pred_df.columns for col in required_cols):
|
| 189 |
-
#
|
| 190 |
-
|
| 191 |
-
|
|
|
|
| 192 |
|
| 193 |
# Ekstraksi hasil prediksi kuantil
|
| 194 |
q05_forecast = pred_df['target_0.5'].values.astype(np.float32)
|
|
@@ -198,13 +205,11 @@ def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
|
| 198 |
|
| 199 |
last_price = data['Close'].iloc[-1]
|
| 200 |
|
| 201 |
-
# Statistik prediksi dihitung dari median (Q0.5)
|
| 202 |
predicted_high = float(np.max(q05_forecast))
|
| 203 |
predicted_low = float(np.min(q05_forecast))
|
| 204 |
predicted_mean = float(np.mean(q05_forecast))
|
| 205 |
change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 else 0
|
| 206 |
|
| 207 |
-
# Mengembalikan semua kuantil untuk chart
|
| 208 |
return {
|
| 209 |
'values': q05_forecast,
|
| 210 |
'dates': predicted_dates,
|
|
|
|
| 157 |
@spaces.GPU(duration=120)
|
| 158 |
def predict_prices(data, model=None, tokenizer=None, prediction_days=30):
|
| 159 |
try:
|
| 160 |
+
# Panggil pipeline di sini untuk memastikan instance baru tiap run (mencegah error memori/state)
|
| 161 |
pipeline = Chronos2Pipeline.from_pretrained("amazon/chronos-2", device_map="auto")
|
| 162 |
|
| 163 |
# Chronos-2 with Covariate: Menggunakan Close (target) dan Volume (covariate)
|
|
|
|
| 165 |
context_df.columns = ['timestamp', 'target', 'volume']
|
| 166 |
context_df['id'] = 'stock_price'
|
| 167 |
|
| 168 |
+
# Fix Error: Could not infer frequency & FIX VOLUME COVARIATE IMPUTATION
|
| 169 |
context_df['timestamp'] = pd.to_datetime(context_df['timestamp'])
|
| 170 |
+
context_df = context_df.set_index('timestamp').asfreq('D')
|
| 171 |
+
|
| 172 |
+
# IMPUTATION FIX: Target ffill, Covariate (Volume) fillna(0)
|
| 173 |
+
context_df['target'] = context_df['target'].fillna(method='ffill')
|
| 174 |
+
context_df['volume'] = context_df['volume'].fillna(0)
|
| 175 |
+
|
| 176 |
+
context_df = context_df.reset_index()
|
| 177 |
|
| 178 |
# Pastikan kolom sesuai urutan Chronos-2: timestamp, target, covariate(s), id
|
| 179 |
context_df['id'] = 'stock_price'
|
| 180 |
context_df = context_df[['timestamp', 'target', 'volume', 'id']]
|
| 181 |
|
| 182 |
with torch.no_grad():
|
|
|
|
| 183 |
pred_df = pipeline.predict_df(
|
| 184 |
context_df,
|
| 185 |
prediction_length=prediction_days,
|
|
|
|
| 189 |
quantile_levels=[0.1, 0.5, 0.9]
|
| 190 |
)
|
| 191 |
|
| 192 |
+
# --- FIX UTAMA: Pengecekan kolom hasil prediksi yang lebih ketat ---
|
| 193 |
required_cols = ['target_0.1', 'target_0.5', 'target_0.9']
|
| 194 |
if pred_df.empty or not all(col in pred_df.columns for col in required_cols):
|
| 195 |
+
# Jika gagal, pastikan kita tahu errornya dan melempar Runtime yang akan ditangkap di luar
|
| 196 |
+
missing = [col for col in required_cols if col not in pred_df.columns]
|
| 197 |
+
raise RuntimeError(f"Prediction failed. Result DataFrame is empty or incomplete. Missing: {missing}")
|
| 198 |
+
# ------------------------------------------------------------------
|
| 199 |
|
| 200 |
# Ekstraksi hasil prediksi kuantil
|
| 201 |
q05_forecast = pred_df['target_0.5'].values.astype(np.float32)
|
|
|
|
| 205 |
|
| 206 |
last_price = data['Close'].iloc[-1]
|
| 207 |
|
|
|
|
| 208 |
predicted_high = float(np.max(q05_forecast))
|
| 209 |
predicted_low = float(np.min(q05_forecast))
|
| 210 |
predicted_mean = float(np.mean(q05_forecast))
|
| 211 |
change_pct = ((predicted_mean - last_price) / last_price) * 100 if last_price != 0 else 0
|
| 212 |
|
|
|
|
| 213 |
return {
|
| 214 |
'values': q05_forecast,
|
| 215 |
'dates': predicted_dates,
|