# afghan_forecast / app.py
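# Gradio app that forecasts Afghan wheat prices with a Temporal Fusion
# Transformer (TFT) model built on the Darts library, and plots backtested
# historical forecasts alongside actual prices.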
import os
import gradio as gr
import torch
from darts import TimeSeries, concatenate
from darts.dataprocessing.transformers import Scaler
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from darts.models.forecasting.tft_model import TFTModel
from darts.metrics import mape
from dateutil.relativedelta import relativedelta
import warnings
warnings.filterwarnings("ignore")
import logging
logging.disable(logging.CRITICAL)
import pandas as pd
import numpy as np
from typing import Any, List, Optional
import plotly.graph_objects as go
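# Load the Afghan market price panel and the Comtrade wheat flour / wheat grain series.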
df_final = pd.read_csv('data/all_afghan.csv', parse_dates=['Date'])
df_comtrade_flour = pd.read_csv('data/comtrade_flour.csv', parse_dates=['Date'])
df_comtrade_grain = pd.read_csv('data/comtrade_grain.csv', parse_dates=['Date'])
series = TimeSeries.from_dataframe(
    df_final,
    time_col='Date',
    value_cols=['price', 'usdprice', 'wheat_grain', 'exchange_rate', 'common_unit_price', 'black_sea'],
)
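# Hold out the final six months of the target series ('common_unit_price') for
# validation; the scaler is fit on the training portion only.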
six_months = df_final['Date'].max() + relativedelta(months=-6)
data_series = series['common_unit_price']
train, val = data_series.split_after(six_months)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(data_series)
# create year, month and integer index covariate series
covariates = datetime_attribute_timeseries(series_transformed, attribute="year", one_hot=False)
covariates = covariates.stack(
    datetime_attribute_timeseries(series_transformed, attribute="month", one_hot=True)
)
covariates = covariates.stack(
    TimeSeries.from_times_and_values(
        times=series_transformed.time_index,
        values=np.arange(len(series_transformed)),
    )
)
covariates = covariates.add_holidays(country_code="ES")
covariates = covariates.astype(np.float32)
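# Scale the calendar covariates with their own Scaler, fit on the training window.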
scaler_covs = Scaler()
cov_train, cov_val = covariates.split_after(six_months)
cov_train = scaler_covs.fit_transform(cov_train)
cov_val = scaler_covs.transform(cov_val)
covariates_transformed = scaler_covs.transform(covariates)
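# Scale each component series with its own Darts Scaler (a MinMaxScaler by
# default), splitting at the same six-month cutoff: wheat grain, the 'price'
# column (treated here as the Pakistan price), USD price, exchange rate, and
# the Black Sea price.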
grain_series = series['wheat_grain']
grain_scaler = Scaler()
grain_train, grain_val = grain_series.split_after(six_months)
grain_train = grain_scaler.fit_transform(grain_train)
grain_val = grain_scaler.transform(grain_val)
grain_series_scaled = grain_scaler.transform(grain_series)
pakistan_series = series["price"]
pakistan_scaler = Scaler()
pakistan_train, pakistan_val = pakistan_series.split_after(six_months)
pakistan_train = pakistan_scaler.fit_transform(pakistan_train)
pakistan_val = pakistan_scaler.transform(pakistan_val)
pakistan_series_scaled = pakistan_scaler.transform(pakistan_series)
usd_series = series['usdprice']
usd_scaler = Scaler()
usd_train, usd_val = usd_series.split_after(six_months)
usd_train = usd_scaler.fit_transform(usd_train)
usd_val = usd_scaler.transform(usd_val)
usd_series_scaled = usd_scaler.transform(usd_series)
erate_series = series['exchange_rate']
erate_scaler = Scaler()
erate_train, erate_val = erate_series.split_after(six_months)
erate_train_transformed = erate_scaler.fit_transform(erate_train)
erate_val_transformed = erate_scaler.transform(erate_val)
erate_series_scaled = erate_scaler.transform(erate_series)
black_sea = series['black_sea']
black_sea_scaler = Scaler()
black_train, black_val = black_sea.split_after(six_months)
black_train_transformed = black_sea_scaler.fit_transform(black_train)
black_val_transformed = black_sea_scaler.transform(black_val)
black_sea_series = black_sea_scaler.transform(black_sea)
comtrade_flour_series = TimeSeries.from_dataframe(df_comtrade_flour, time_col="Date")
comtrade_grain_series = TimeSeries.from_dataframe(df_comtrade_grain, time_col="Date")
my_multivariate_series = concatenate(
    [
        grain_series_scaled,
        pakistan_series_scaled,
        # usd_series_scaled,
        erate_series_scaled,
        black_sea_series,
        comtrade_flour_series,
        comtrade_grain_series,
        covariates_transformed,
    ],
    axis=1,
)
multivariate_series_train = concatenate(
    [
        grain_train,
        pakistan_train,
        # usd_train,
        erate_train_transformed,  # scaled exchange-rate series, consistent with the other scaled components
        # russian_train_transformed,
        # black_train_transformed,
        cov_train,
    ],
    axis=1,
)
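# Custom flagging callback: wraps gr.CSVLogger so flagged inputs (plus the
# chosen flag option) are appended to a CSV file under the flagging directory.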
class FlaggingHandler(gr.FlaggingCallback):
    def __init__(self):
        self._csv_logger = gr.CSVLogger()

    def setup(self, components: List[gr.components.Component], flagging_dir: str):
        """Called by Gradio at the beginning of the `Interface.launch()` method.

        Parameters:
            components: Set of components that will provide flagged data.
            flagging_dir: A string, typically containing the path to the directory
                where the flagging file should be stored (provided as an argument
                to Interface.__init__()).
        """
        self.components = components
        self._csv_logger.setup(components=components, flagging_dir=flagging_dir)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        # flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """Called by Gradio whenever one of the <flag> buttons is clicked.

        Parameters:
            flag_data: The data to be flagged.
            flag_option (optional): If flagging_options are provided, the flag option that was chosen.
            flag_index (optional): The index of the sample that is being flagged.
            username (optional): The username of the user flagging the data, if logged in.
        Returns:
            (int) The total number of samples that have been flagged.
        """
        for item in flag_data:
            print(f"Flagging: {item}")
        if flag_option:
            print(f"Flag option: {flag_option}")
        # if flag_index:
        #     print(f"Flag index: {flag_index}")
        flagged_count = self._csv_logger.flag(
            flag_data=flag_data,
            flag_option=flag_option,
            # flag_index=flag_index,
            # username=username,
        )
        return flagged_count
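# Produce the forecast for the requested horizon: load the saved TFT model,
# predict `period` months ahead, back-transform to original price units,
# attach a +/-10% band, backtest from 2018, and return the forecast table
# plus a Plotly figure. Example call (values mirror the Radio options below):
#   merged_df, fig = get_forecast("3 months", "Wheat Price Forecasting")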
def get_forecast(period_: str, pred_model: str):
    # Parse the forecast horizon ("3 months" / "6 months") down to its leading digit.
    period = int(period_[0])
    # Load the pre-trained TFT model on CPU.
    afgh_model = TFTModel.load("Afghan_w_blacksea_allcomtrade_aug31.pt", map_location=torch.device('cpu'))
    # Forecast the next `period` months and map back to original price units.
    pred_series = afgh_model.predict(n=period, num_samples=1)
    preds = transformer.inverse_transform(pred_series)
    df_ = preds.pd_dataframe()
    df_.rename(columns={'common_unit_price': 'Wheat_Forecast'}, inplace=True)
    # Build a +/-10% band around the point forecast (90% and 110% of the forecast values).
    forecast_90 = preds * 0.9
    forecast_110 = preds * 1.1
    df_90 = forecast_90.pd_dataframe()
    df_90.rename(columns={'common_unit_price': 'Lower_Limit'}, inplace=True)
    df_110 = forecast_110.pd_dataframe()
    df_110.rename(columns={'common_unit_price': 'Upper_Limit'}, inplace=True)
    merged_df = pd.merge(df_90, df_, on=['Date']).merge(df_110, on=['Date'])
    merged_df = merged_df.reset_index()
    merged_df.to_csv('data/afghan_wheatfcasts.csv', index=False)
start=pd.Timestamp("20180131")
backtest_series_ = afgh_model.historical_forecasts(
series_transformed,
past_covariates=my_multivariate_series,
start=start,
forecast_horizon=period,
retrain=False,
verbose=False,
)
series_time = series_transformed[-len(backtest_series_):].time_index
series_vals = (transformer.inverse_transform(series_transformed[-len(backtest_series_):])).values()
df_series = pd.DataFrame(data={'Date': series_time, 'actual_prices': series_vals.ravel() })
vals = (transformer.inverse_transform(backtest_series_)).values()
df_backtest = pd.DataFrame(data={'Date': backtest_series_.time_index, 'historical_forecasts': vals.ravel() })
# df_backtest_wheat = pd.DataFrame(data={'Date': backtest_series_.time_index, 'historical_wheat_forecasts': vals.ravel() })
df_wheat_output = pd.merge(df_series,df_backtest[['Date',"historical_forecasts"]],on=['Date'],how='left')
df_wheat_output.to_csv('data/aghanwheat_allhistorical.csv',index=False)
    # Create figure
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=list(df_backtest.Date),
            y=list(df_backtest.historical_forecasts),
            name='historical forecasts',
        ))
    fig.add_trace(
        go.Scatter(
            x=list(df_series.Date),
            y=list(df_series.actual_prices),
            name="actual prices",
        ))
    fig.add_trace(go.Scatter(
        x=list(merged_df.Date),
        y=list(merged_df.Upper_Limit),
        name="Upper limit",
    ))
    fig.add_trace(go.Scatter(
        x=list(merged_df.Date),
        y=list(merged_df.Lower_Limit),
        name="Lower limit",
    ))
    fig.add_trace(go.Scatter(
        x=list(merged_df.Date),
        y=list(merged_df.Wheat_Forecast),
        name="Wheat Forecast",
    ))
    # Set title: backtest MAPE between actual prices and historical forecasts
    fig.update_layout(
        title_text=f"Mean Absolute Percentage Error {mape(transformer.inverse_transform(series_transformed), transformer.inverse_transform(backtest_series_)):.2f}%"
    )
    # Add range slider
    fig.update_layout(
        xaxis=dict(
            rangeselector=dict(
                buttons=list([
                    dict(count=1, label="1m", step="month", stepmode="backward"),
                    dict(count=6, label="6m", step="month", stepmode="todate"),
                    dict(count=1, label="YTD", step="year", stepmode="todate"),
                    # dict(count=1, label="1y", step="year", stepmode="backward"),
                    # dict(step="all")
                ])
            ),
            rangeslider=dict(visible=True),
            type="date",
        )
    )
    return merged_df, fig
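# Gradio Blocks UI: commodity and horizon selectors, a Forecast button wired to
# get_forecast, and two flag buttons that log feedback via FlaggingHandler.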
def main():
    flagging_handler = FlaggingHandler()
    with gr.Blocks() as iface:
        gr.Markdown(
            """
            **Time series forecasting with a Temporal Fusion Transformer (TFT) model built on the Darts library.**
            """)
        commodity = gr.Radio(["Wheat Price Forecasting"], label="Commodity to Forecast")
        period = gr.Radio(["3 months", "6 months"], label="Forecast horizon")
        with gr.Row():
            btn = gr.Button("Forecast")
            feedback = gr.Textbox(label="Give feedback")
        gr.CSVLogger()
        data_points = gr.Textbox(label="Forecast values. Lower and upper limits are the forecast -/+ 10%.")
        plt = gr.Plot(label="Backtesting plot, from 2018")
        btn.click(
            get_forecast,
            inputs=[period, commodity],
            outputs=[data_points, plt],
        )
        with gr.Row():
            btn_incorrect = gr.Button("Flag as incorrect")
            btn_other = gr.Button("Flag as other")
        flagging_handler.setup(
            components=[commodity, period],
            flagging_dir="data/flagged",
        )
        btn_incorrect.click(
            lambda *args: flagging_handler.flag(flag_data=args, flag_option="Incorrect"),
            [commodity, data_points, period, feedback],
            None,
            preprocess=False,
        )
        btn_other.click(
            lambda *args: flagging_handler.flag(flag_data=args, flag_option="Other"),
            [commodity, data_points, period, feedback],
            None,
            preprocess=False,
        )
    iface.launch(debug=True, inline=False)

if __name__ == "__main__":
    main()