Spaces:
Running
Running
| import streamlit as st | |
| import pandas as pd | |
| import torch | |
| from chronos import Chronos2Pipeline | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
# Load the Chronos-2 pipeline once and reuse it across Streamlit reruns.
@st.cache_resource
def load_pipeline():
    """Build the Chronos-2 forecasting pipeline configured for CPU.

    Streamlit re-executes the script on every widget interaction. The
    ``st.cache_resource`` decorator keeps one pipeline instance alive so the
    pretrained weights are not loaded again on each rerun.

    Returns:
        The ``Chronos2Pipeline`` created from the "amazon/chronos-2"
        checkpoint.
    """
    return Chronos2Pipeline.from_pretrained(
        "amazon/chronos-2",
        device_map="cpu",  # run on CPU
        dtype=torch.float32,  # float32 for CPU execution
    )


pipeline = load_pipeline()
# Page header: title followed by a one-line description of the demo.
st.title("Time Series Forecasting Demo with Chronos-2")
st.write(
    "This demo uses **Chronos-2**, Amazon's state-of-the-art pretrained "
    "model for zero-shot time series forecasting."
)

# Sample series (comma-separated) shown in the text box until the user edits it.
default_data = """
112, 118, 132, 129, 121, 135, 148, 148, 136, 119, 104, 118, 115, 126, 141, 135, 125, 149, 170, 170, 158,
133, 114, 140, 145, 150, 178, 163, 172, 178, 199, 199, 184, 162, 146, 166, 171, 180, 193, 181, 183, 218,
230, 242, 209, 191, 172, 194, 196, 196, 236, 235, 229, 243, 264, 272, 237, 211, 180, 201, 204, 188, 235,
227, 234, 264, 302, 293, 259, 229, 203, 229, 242, 233, 267, 269, 270, 315, 364, 347, 312, 274, 237, 278,
284, 277, 317, 313, 318, 374, 413, 405, 355, 306, 271, 306, 315, 301, 356, 348, 355, 422, 465, 467, 404,
347, 305, 336, 340, 318, 362, 348, 363, 435, 491, 505, 404, 359, 310, 337, 360, 342, 406, 396, 420, 472,
548, 559, 463, 407, 362, 405, 417, 391, 419, 461, 472, 535, 622, 606, 508, 461, 390, 432
"""

# Text box for the user's own series, pre-filled with the sample above
# (surrounding blank lines removed).
user_input = st.text_area(
    label="Enter time series data (comma-separated values):",
    value=default_data.strip(),
)
# Convert user input into a list of numbers
def process_input(input_str):
    """Parse comma-separated text into a list of floats.

    Whitespace (including newlines) around each value is ignored, and stray
    commas or whitespace at the very start or end of the text are dropped,
    so a pasted series that ends in a trailing comma still parses.

    Args:
        input_str: Raw text containing numbers separated by commas.

    Returns:
        The parsed values as floats, in input order.

    Raises:
        ValueError: If the text contains no values, or if any entry between
            two commas is blank or not a valid number.
    """
    # Trim separators only at the ends. An empty slot in the middle is still
    # rejected so a dropped value cannot silently shift the series.
    trimmed = input_str.strip(", \t\r\n\f\v")
    return [float(token.strip()) for token in trimmed.split(",")]
# Parse the text box contents. The series stays empty when parsing fails,
# which makes the forecast section below skip itself.
time_series_data = []
try:
    time_series_data = process_input(user_input)
except ValueError:
    st.error("Please make sure all values are numbers, separated by commas.")

# Forecast horizon picked by the user: 1 to 64 steps, 12 by default.
prediction_length = st.slider(
    label="Select Forecast Horizon (Months)",
    min_value=1,
    max_value=64,
    value=12,
)
# If data is valid, perform the forecast
if time_series_data:
    history_length = len(time_series_data)

    # Build the long-format frame passed to predict_df. The user supplies
    # values only, so month-end timestamps starting in 2020-01 are generated
    # as placeholders and every row gets the same series id.
    context_df = pd.DataFrame({
        "timestamp": pd.date_range(start="2020-01-01", periods=history_length, freq="ME"),
        "target": time_series_data,
        "id": "series_1",
    })

    # Make the forecast using the Chronos-2 DataFrame API.
    pred_df = pipeline.predict_df(
        context_df,
        prediction_length=prediction_length,
        quantile_levels=[0.1, 0.5, 0.9],
        id_column="id",
        timestamp_column="timestamp",
        target="target",
    )

    # Forecast points are plotted on integer positions that continue right
    # after the last historical observation.
    forecast_index = range(history_length, history_length + prediction_length)
    median = pred_df["predictions"].values
    low = pred_df["0.1"].values
    high = pred_df["0.9"].values

    # Draw on an explicit Figure/Axes instead of the global pyplot state:
    # Streamlit serves sessions from multiple threads and the script reruns
    # on every interaction, so global figures are neither thread-safe nor
    # ever released.
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(time_series_data, color="royalblue", label="Historical data")
    ax.plot(forecast_index, median, color="tomato", label="Median forecast")
    ax.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
    ax.legend()
    ax.grid()

    # Show the plot in the Streamlit app, then free the figure so figures do
    # not accumulate across reruns.
    st.pyplot(fig)
    plt.close(fig)
# Footer: a "Notes" heading followed by a contact link for feedback.
st.write("### Notes")
contact_note = (
    "For comments, feedback, or any questions, please reach out to me on "
    "[LinkedIn](https://www.linkedin.com/in/javadbayazi/)."
)
st.write(contact_note)