cryptogold-backend / model_handler.py
omniverse1's picture
Update Gradio app with multiple files
f27fbb4 verified
raw
history blame
2.35 kB
import numpy as np
import torch
from chronos import BaseChronosPipeline
class ModelHandler:
    """Univariate forecaster backed by a Chronos pipeline, with a trend fallback.

    The model is loaded once at construction. If loading fails, or if model
    inference raises at prediction time, ``predict`` degrades to a simple
    linear-trend extrapolation instead of failing.
    """

    # Minimum number of observations in data['original'] before forecasting.
    MIN_HISTORY = 20
    # Number of trailing observations used to fit the fallback linear trend.
    TREND_WINDOW = 20
    # Fallback noise standard deviation, as a fraction of data['std'].
    NOISE_FRACTION = 0.1

    def __init__(self, model_name="amazon/chronos-2"):
        """Pick a device and eagerly load ``model_name``.

        Args:
            model_name: model id passed to ``BaseChronosPipeline.from_pretrained``.
        """
        self.model_name = model_name
        self.pipeline = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model()

    def load_model(self):
        """Load the pipeline; on any failure leave ``self.pipeline`` as None.

        A None pipeline makes ``predict`` use the trend fallback, so a missing
        model does not prevent the handler from being constructed.
        """
        try:
            print(f"Loading {self.model_name} on {self.device}...")
            self.pipeline = BaseChronosPipeline.from_pretrained(
                self.model_name,
                device_map=self.device,
            )
            print("Chronos-2 pipeline loaded successfully.")
        except Exception as e:
            print(f"Error loading Chronos-2 model: {e}")
            print("Using fallback prediction method")
            self.pipeline = None

    def predict(self, data, horizon=10):
        """Forecast the next ``horizon`` values of ``data['original']``.

        Args:
            data: dict with key ``'original'`` (a 1-D numeric sequence such as
                a list, NumPy array or pandas Series) and optionally ``'std'``,
                which scales the noise added by the fallback (default 1.0).
            horizon: number of future steps to forecast.

        Returns:
            1-D float ``np.ndarray`` of length ``horizon``. It is all zeros when
            ``data`` is not a dict, lacks ``'original'``, has fewer than
            ``MIN_HISTORY`` points, or an unexpected error occurs.
        """
        try:
            values = self._history(data)
            if values is None:
                return self._zeros(horizon)
            if self.pipeline is not None:
                try:
                    return self._chronos_predict(values, horizon)
                except Exception as e:
                    # A model failure should not become a flat-zero forecast;
                    # degrade to the same heuristic used when no model loaded.
                    print(f"Chronos-2 prediction error: {e}")
                    print("Using fallback prediction method")
            return self._trend_predict(values, horizon, data.get("std", 1.0))
        except Exception as e:
            print(f"Prediction error: {e}")
            return self._zeros(horizon)

    @staticmethod
    def _zeros(horizon):
        """Placeholder forecast of ``horizon`` zeros (empty for horizon <= 0)."""
        return np.zeros(max(horizon, 0))

    def _history(self, data):
        """Return ``data['original']`` as a float array, or None if unusable."""
        if not isinstance(data, dict) or "original" not in data:
            return None
        # Converting to an ndarray makes [-1] / slicing positional even for a
        # pandas Series, whose integer [] lookup is label-based.
        values = np.asarray(data["original"], dtype=float)
        if len(values) < self.MIN_HISTORY:
            return None
        return values

    def _trend_predict(self, values, horizon, std):
        """Extrapolate the linear trend of the last ``TREND_WINDOW`` points.

        Gaussian noise with standard deviation ``std * NOISE_FRACTION`` is
        added to every step, so output is non-deterministic unless ``std`` is 0
        or NumPy's global random state is seeded.
        """
        window = values[-self.TREND_WINDOW:]
        slope = np.polyfit(np.arange(len(window)), window, 1)[0]
        steps = np.arange(1, horizon + 1)
        noise = np.random.normal(0, std * self.NOISE_FRACTION, size=horizon)
        return values[-1] + slope * steps + noise

    def _chronos_predict(self, values, horizon):
        """Return the pipeline's mean forecast for ``values`` as a float array."""
        context = torch.as_tensor(values, dtype=torch.float32)
        # predict_quantiles returns (quantiles, mean) on every Chronos pipeline
        # variant, whereas predict()'s keywords and output layout (samples vs.
        # quantiles) differ between versions. A list of 1-D tensors is treated
        # as a batch of independent series; here the batch has one series.
        _, mean = self.pipeline.predict_quantiles(
            [context], prediction_length=horizon, quantile_levels=[0.5]
        )
        # Chronos-2 returns one tensor per input series; older pipelines
        # return a single batched tensor.
        if isinstance(mean, (list, tuple)):
            mean = mean[0]
        if isinstance(mean, torch.Tensor):
            # .float() first: .numpy() cannot convert bfloat16 tensors.
            mean = mean.detach().float().cpu().numpy()
        return np.asarray(mean, dtype=float).reshape(-1)[:horizon]