kuldeep0204 commited on
Commit
4d919ad
·
verified ·
1 Parent(s): 7f645b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -6
app.py CHANGED
@@ -1,9 +1,98 @@
 
 
1
  import torch
 
 
 
2
 
3
- MODEL_PATH = "model/model.pt"
4
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- print(f"Loading model from: {MODEL_PATH}")
7
- model = torch.load(MODEL_PATH, map_location=device, weights_only=False)
8
- model.to(device)
9
- model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
  import torch
4
+ import pandas as pd
5
+ from chronos import ChronosPipeline
6
+ from io import StringIO
7
 
8
+ # --- Model Loading ---
9
+ # This part is outside the function so it only runs once when the app starts
10
+ try:
11
+ model_name = "amazon/chronos-t5-small"
12
+ pipeline = ChronosPipeline.from_pretrained(
13
+ model_name,
14
+ device_map="cpu", # Force CPU usage for free tier
15
+ torch_dtype=torch.float32,
16
+ )
17
+ print(f"Loaded model: {model_name}")
18
+ except Exception as e:
19
+ # A fallback in case the model fails to load
20
+ print(f"Error loading model: {e}")
21
+ pipeline = None
22
 
23
+ # --- Prediction Function ---
24
+ def forecast_time_series(csv_file, prediction_length):
25
+ """
26
+ Takes a CSV file, extracts the last column (time series), and forecasts.
27
+ """
28
+ if pipeline is None:
29
+ return "Model failed to load. Please check logs/dependencies."
30
+
31
+ try:
32
+ # Read the CSV file content from the Gradio InputFile
33
+ content = csv_file.read().decode('utf-8')
34
+ df = pd.read_csv(StringIO(content))
35
+
36
+ # Assume the time series data is in the last column
37
+ # and has no missing values
38
+ historical_data = df.iloc[:, -1].values
39
+
40
+ if len(historical_data) < 50:
41
+ return "Please upload a time series with at least 50 historical points for a good forecast."
42
+
43
+ # Convert historical data to the required format
44
+ historical_series = torch.tensor(historical_data, dtype=torch.float32)
45
+
46
+ # Generate the forecast
47
+ forecast_samples = pipeline.predict(
48
+ historical_series,
49
+ prediction_length=int(prediction_length),
50
+ num_samples=20, # Number of probabilistic paths to generate
51
+ )
52
+
53
+ # Calculate the median for the central prediction line
54
+ median_forecast = np.quantile(forecast_samples.numpy(), 0.5, axis=0)
55
+
56
+ # Prepare the output data for plotting
57
+ historical_index = np.arange(len(historical_data))
58
+ forecast_index = np.arange(len(historical_data), len(historical_data) + int(prediction_length))
59
+
60
+ # Create a single plot with both historical and forecast data
61
+ plot_data = {
62
+ "Historical": list(historical_data),
63
+ "Forecast": list(median_forecast),
64
+ }
65
+
66
+ return {
67
+ "Historical": (historical_index, historical_data),
68
+ "Forecast": (forecast_index, median_forecast)
69
+ }
70
+
71
+ except Exception as e:
72
+ return f"An error occurred: {e}"
73
+
74
+ # --- Gradio Interface Setup ---
75
+ # Define the example input file structure (for user convenience)
76
+ example_data = [
77
+ [
78
+ 'date,value\n2025-01-01,10.0\n2025-01-02,11.5\n...\n2025-03-20,15.2',
79
+ 7
80
+ ] # A sample input isn't a file, so it can't be added directly here.
81
+ # Users will need to upload a CSV file manually.
82
+ ]
83
+
84
+
85
+ gr_plot = gr.Plot(label="Time Series Forecast (Historical + Predicted Median)")
86
+
87
+ gr.Interface(
88
+ fn=forecast_time_series,
89
+ inputs=[
90
+ gr.File(label="Upload a CSV file (Time series must be in the last column)"),
91
+ gr.Slider(minimum=7, maximum=30, step=1, value=14, label="Number of Future Steps (Days) to Predict"),
92
+ ],
93
+ outputs=gr_plot,
94
+ title="Chronos Time Series Forecasting Demo on Hugging Face",
95
+ description="Upload a CSV file containing a single historical time series. This demo uses the Chronos-T5-Small Foundation Model to generate a 14-day (default) forecast.",
96
+ examples=None,
97
+ live=False,
98
+ ).launch()