# cryptogold-prime / model_handler.py
# Author: omniverse1
# Commit 6bf1eb6 (verified): "Update Gradio app with multiple files"
import numpy as np
import torch
import warnings
# Chronos is an optional dependency: when it cannot be imported, the module
# still loads and ModelHandler serves heuristic (fallback) forecasts instead.
try:
    from chronos import BaseChronosPipeline
except ImportError:
    warnings.warn("Chronos-forecasting not available. Using fallback predictions.")
    CHRONOS_AVAILABLE = False
    BaseChronosPipeline = None
else:
    CHRONOS_AVAILABLE = True
class ModelHandler:
    """Point forecaster backed by the ``amazon/chronos-2`` model.

    When the ``chronos`` package is missing or the model fails to load,
    ``self.pipeline`` stays ``None``. Every call to :meth:`predict` is then
    served by the trend-based :meth:`_fallback_predict` instead.
    """

    def __init__(self):
        self.model_name = "amazon/chronos-2"
        # Set by load_model(); remains None when Chronos cannot be used.
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model()

    def load_model(self):
        """Load Chronos-2 model using the official BaseChronosPipeline.

        Loading is best-effort. On any error the message is printed and
        ``self.pipeline`` is reset to ``None`` so predictions use the fallback.
        """
        if not CHRONOS_AVAILABLE:
            print("Chronos-forecasting not installed. Using fallback prediction method.")
            return
        try:
            print(f"Loading {self.model_name} on {self.device}...")
            self.pipeline = BaseChronosPipeline.from_pretrained(
                self.model_name,
                device_map=self.device,
            )
            print("Chronos-2 pipeline loaded successfully.")
        except Exception as e:
            # Deliberately broad: a failed download or device problem must not
            # crash the app, it only downgrades to the fallback forecaster.
            print(f"Error loading Chronos-2 model: {e}")
            print("Using fallback prediction method")
            self.pipeline = None

    def predict(self, data, horizon=10):
        """Generate predictions using Chronos-2 or fallback.

        Args:
            data: dict holding the observed history under ``'original'``. An
                optional ``'std'`` entry is used only by the fallback's noise
                term.
            horizon: number of future steps to forecast.

        Returns:
            1-D ``np.ndarray`` point forecast (the fallback returns zeros on
            unusable input). Histories shorter than 20 points, or any Chronos
            error, are routed to :meth:`_fallback_predict`.
        """
        if not CHRONOS_AVAILABLE or self.pipeline is None:
            # Fallback to simple trend-based prediction
            return self._fallback_predict(data, horizon)
        try:
            if data is None or not isinstance(data, dict) or 'original' not in data or len(data['original']) < 20:
                return self._fallback_predict(data, horizon)
            # --- Chronos-2 Inference ---
            # Chronos pipelines take torch tensors (v1/Bolt assert on this), so a
            # plain list/ndarray history is converted first. It is wrapped in a
            # list as a single series.
            # NOTE(review): the list-of-1-D-tensors input form is assumed to be
            # accepted by every Chronos pipeline variant; confirm when upgrading.
            context = torch.as_tensor(np.asarray(data['original'], dtype=np.float32))
            raw_forecast = self.pipeline.predict(
                [context],
                prediction_length=horizon,
            )
            return self._reduce_forecast(raw_forecast)
        except Exception as e:
            print(f"Prediction error with Chronos: {e}. Using fallback.")
            return self._fallback_predict(data, horizon)

    @staticmethod
    def _reduce_forecast(raw_forecast):
        """Collapse a raw Chronos forecast into a 1-D numpy point forecast.

        ``raw_forecast`` may be a tensor, an array, or a list/tuple of them
        (one entry per submitted series; only the first is used here). The
        forecast horizon is assumed to be the LAST axis. Every leading axis
        (batch/variates and samples or quantiles) is averaged away, which gives
        one value per future step.

        NOTE(review): the "horizon is the last axis" layout is assumed from the
        chronos-forecasting pipelines; verify if the library changes its output.
        """
        if isinstance(raw_forecast, (list, tuple)):
            raw_forecast = raw_forecast[0]
        if isinstance(raw_forecast, torch.Tensor):
            # .float() first because numpy has no bfloat16 dtype.
            raw_forecast = raw_forecast.detach().float().cpu().numpy()
        forecast = np.asarray(raw_forecast, dtype=float)
        return forecast.reshape(-1, forecast.shape[-1]).mean(axis=0)

    def _fallback_predict(self, data, horizon=10):
        """Fallback prediction method when Chronos is unavailable.

        Fits a straight line to the last (up to) 20 observations and
        extrapolates its slope from the last observed value. Each step gets
        Gaussian noise with standard deviation ``0.1 * data.get('std', 1.0)``.

        Returns:
            ``np.ndarray`` of ``horizon`` predictions. Returns zeros when
            ``data`` is missing or malformed, has fewer than 5 points, or the
            computation raises.
        """
        try:
            if data is None or not isinstance(data, dict) or 'original' not in data:
                # Return zero predictions if no data
                return np.zeros(horizon)
            # Positional float array: on a pandas Series, ``values[-1]`` would be
            # a label lookup (KeyError), silently producing zero forecasts.
            values = np.asarray(data['original'], dtype=float)
            if len(values) < 5:
                return np.zeros(horizon)
            # Simple trend extrapolation over the most recent window
            window = values[-20:]
            recent_trend = np.polyfit(np.arange(len(window)), window, 1)[0]
            last_value = values[-1]
            noise_scale = data.get('std', 1.0) * 0.1
            predictions = [
                last_value + recent_trend * step + np.random.normal(0, noise_scale)
                for step in range(1, horizon + 1)
            ]
            return np.array(predictions)
        except Exception as e:
            print(f"Fallback prediction error: {e}")
            return np.zeros(horizon)