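# Hosting-page metadata captured together with this file (kept as comments so the module parses):
# BFM / app.py
# Mohammad Javad Darvishi
# 'add contact link'
# b62199f
# raw
# history blame contribute delete
# No virus
# 3.27 kB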
import streamlit as st
import pandas as pd
import torch
from chronos import ChronosPipeline
import matplotlib.pyplot as plt
import numpy as np
# Build the Chronos forecasting pipeline (cached by Streamlit).
@st.cache_resource
def load_pipeline():
    """Return a ChronosPipeline for ``amazon/chronos-t5-small`` running on CPU.

    ``st.cache_resource`` memoizes the returned object, so the model is only
    constructed on the first call; later reruns reuse the same instance.
    """
    return ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-small",
        device_map="cpu",  # Change to CPU
        torch_dtype=torch.float32,  # Use float32 for CPU
    )
# Obtain the (cached) forecasting model once per script run.
pipeline = load_pipeline()

# ---- Page header ------------------------------------------------------------
st.title("Time Series Forecasting Demo with Deep Learning models")
st.write("This demo uses the ChronosPipeline model for time series forecasting.")

# ---- Input ------------------------------------------------------------------
# Example series used to pre-fill the text area (144 values; it appears to be
# the classic monthly AirPassengers series). The surrounding newlines of the
# literal are stripped before it is shown.
default_data = """
112, 118, 132, 129, 121, 135, 148, 148, 136, 119, 104, 118, 115, 126, 141, 135, 125, 149, 170, 170, 158,
133, 114, 140, 145, 150, 178, 163, 172, 178, 199, 199, 184, 162, 146, 166, 171, 180, 193, 181, 183, 218,
230, 242, 209, 191, 172, 194, 196, 196, 236, 235, 229, 243, 264, 272, 237, 211, 180, 201, 204, 188, 235,
227, 234, 264, 302, 293, 259, 229, 203, 229, 242, 233, 267, 269, 270, 315, 364, 347, 312, 274, 237, 278,
284, 277, 317, 313, 318, 374, 413, 405, 355, 306, 271, 306, 315, 301, 356, 348, 355, 422, 465, 467, 404,
347, 305, 336, 340, 318, 362, 348, 363, 435, 491, 505, 404, 359, 310, 337, 360, 342, 406, 396, 420, 472,
548, 559, 463, 407, 362, 405, 417, 391, 419, 461, 472, 535, 622, 606, 508, 461, 390, 432
"""

# Free-text box the user edits; its current contents are parsed further down.
user_input = st.text_area(
    label="Enter time series data (comma-separated values):",
    value=default_data.strip(),
)
# Convert user input into a list of numbers
def process_input(input_str: str):
    """Parse comma-separated numbers from the text area into a list of floats.

    Whitespace around each field (including the newlines in the pre-filled
    example) is ignored. Empty fields, such as those produced by a trailing or
    doubled comma, are skipped instead of being treated as invalid numbers.
    A blank input therefore yields ``[]``.

    Args:
        input_str: Raw text entered by the user.

    Returns:
        The parsed values as floats, in input order.

    Raises:
        ValueError: If any non-empty field is not a valid float literal.
    """
    fields = (field.strip() for field in input_str.split(","))
    # float("") would raise; skip empty fields so "1, 2, 3," and "" are accepted.
    return [float(field) for field in fields if field]
# Parse the text area; on failure show an error and skip the forecast below.
try:
    time_series_data = process_input(user_input)
except ValueError:
    st.error("Please make sure all values are numbers, separated by commas.")
    time_series_data = []  # Set empty data on error to prevent further processing

# Select the number of months for forecasting
prediction_length = st.slider("Select Forecast Horizon (Months)", min_value=1, max_value=64, value=12)

# If data is valid, perform the forecast
if time_series_data:
    # The history is passed to the pipeline as a 1-D float32 tensor.
    context = torch.tensor(time_series_data, dtype=torch.float32)

    # Draw 20 sample paths over the requested horizon.
    forecast = pipeline.predict(
        context=context,
        prediction_length=prediction_length,
        num_samples=20,
    )

    # Forecast steps are placed immediately after the last historical index.
    forecast_index = range(len(time_series_data), len(time_series_data) + prediction_length)
    # Assumes forecast is (series, samples, horizon) as returned by Chronos;
    # [0] selects the single input series and axis=0 reduces over the samples.
    low, median, high = np.quantile(forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

    # Draw on an explicit Figure rather than pyplot's global "current figure".
    # The global figure is shared by every session in the server process, and
    # creating one per rerun without closing it leaks figures.
    fig, ax = plt.subplots(figsize=(8, 4))
    try:
        ax.plot(time_series_data, color="royalblue", label="Historical data")
        ax.plot(forecast_index, median, color="tomato", label="Median forecast")
        ax.fill_between(forecast_index, low, high, color="tomato", alpha=0.3, label="80% prediction interval")
        ax.legend()
        ax.grid()
        # Show the plot in the Streamlit app
        st.pyplot(fig)
    finally:
        # Release the figure even if rendering is interrupted (e.g. by a rerun).
        plt.close(fig)

# Note for comments, feedback, or questions
st.write("### Notes")
st.write("For comments, feedback, or any questions, please reach out to me on [LinkedIn](https://www.linkedin.com/in/mjdarvishi/).")