"""Gradio app serving wheat-price forecasts from a pretrained Darts TFT model.

At import time the module loads the Afghan price / covariate CSVs from
``data/``, fits the scalers used by the model, builds the past-covariate
series, and finally launches the Gradio UI via ``main()``.
"""
import functools
import logging
import os
import threading
import warnings
from typing import Any, List, Optional

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from darts import TimeSeries, concatenate
from darts.dataprocessing.transformers import Scaler
from darts.metrics import mape
from darts.models.forecasting.tft_model import TFTModel
from darts.utils.timeseries_generation import datetime_attribute_timeseries
from dateutil.relativedelta import relativedelta

warnings.filterwarnings("ignore")
logging.disable(logging.CRITICAL)

# Pretrained model and output locations.
MODEL_PATH = "Afghan_w_blacksea_allcomtrade_aug31.pt"
FORECAST_CSV = "data/afghan_wheatfcasts.csv"
HISTORICAL_CSV = "data/aghanwheat_allhistorical.csv"
# First timestamp from which historical (backtest) forecasts are produced.
BACKTEST_START = pd.Timestamp("20180131")

# ---------------------------------------------------------------------------
# Data loading and preprocessing (runs once, at import time)
# ---------------------------------------------------------------------------
df_final = pd.read_csv("data/all_afghan.csv", parse_dates=["Date"])
df_comtrade_flour = pd.read_csv("data/comtrade_flour.csv", parse_dates=["Date"])
df_comtrade_grain = pd.read_csv("data/comtrade_grain.csv", parse_dates=["Date"])

series = TimeSeries.from_dataframe(
    df_final,
    time_col="Date",
    value_cols=[
        "price",
        "usdprice",
        "wheat_grain",
        "exchange_rate",
        "common_unit_price",
        "black_sea",
    ],
)

# Split point: six months before the last date in the data. Every scaler
# below is fitted on the portion up to this point only.
six_months = df_final["Date"].max() + relativedelta(months=-6)

# Target series: common unit price of wheat.
data_series = series["common_unit_price"]
train, val = data_series.split_after(six_months)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(data_series)

# Calendar covariates: year, one-hot month, an integer index, holiday flags.
covariates = datetime_attribute_timeseries(
    series_transformed, attribute="year", one_hot=False
)
covariates = covariates.stack(
    datetime_attribute_timeseries(series_transformed, attribute="month", one_hot=True)
)
covariates = covariates.stack(
    TimeSeries.from_times_and_values(
        times=series_transformed.time_index,
        values=np.arange(len(series_transformed)),
    )
)
# NOTE(review): holidays are taken for Spain ("ES"), not Afghanistan. Kept
# as-is because the saved model presumably expects exactly these features.
covariates = covariates.add_holidays(country_code="ES")
covariates = covariates.astype(np.float32)
scaler_covs = Scaler()
cov_train, cov_val = covariates.split_after(six_months)
cov_train = scaler_covs.fit_transform(cov_train)
cov_val = scaler_covs.transform(cov_val)
covariates_transformed = scaler_covs.transform(covariates)

# Wheat grain price.
grain_series = series["wheat_grain"]
grain_scaler = Scaler()
grain_train, grain_val = grain_series.split_after(six_months)
grain_train = grain_scaler.fit_transform(grain_train)
grain_val = grain_scaler.transform(grain_val)
grain_series_scaled = grain_scaler.transform(grain_series)

# "price" column (named pakistan_* here).
pakistan_series = series["price"]
pakistan_scaler = Scaler()
pakistan_train, pakistan_val = pakistan_series.split_after(six_months)
pakistan_train = pakistan_scaler.fit_transform(pakistan_train)
pakistan_val = pakistan_scaler.transform(pakistan_val)
pakistan_series_scaled = pakistan_scaler.transform(pakistan_series)

# USD price (computed but currently not used as a covariate).
usd_series = series["usdprice"]
usd_scaler = Scaler()
usd_train, usd_val = usd_series.split_after(six_months)
usd_train = usd_scaler.fit_transform(usd_train)
usd_val = usd_scaler.transform(usd_val)
usd_series_scaled = usd_scaler.transform(usd_series)

# Exchange rate.
erate_series = series["exchange_rate"]
erate_scaler = Scaler()
erate_train, erate_val = erate_series.split_after(six_months)
erate_train_transformed = erate_scaler.fit_transform(erate_train)
erate_val_transformed = erate_scaler.transform(erate_val)
erate_series_scaled = erate_scaler.transform(erate_series)

# Black Sea price.
black_sea = series["black_sea"]
black_sea_scaler = Scaler()
black_train, black_val = black_sea.split_after(six_months)
black_train_transformed = black_sea_scaler.fit_transform(black_train)
black_val_transformed = black_sea_scaler.transform(black_val)
black_sea_series = black_sea_scaler.transform(black_sea)

comtrade_flour_series = TimeSeries.from_dataframe(df_comtrade_flour, time_col="Date")
comtrade_grain_series = TimeSeries.from_dataframe(df_comtrade_grain, time_col="Date")

# Past covariates passed to the model for backtesting. The component order
# must match what the saved model was trained with.
# NOTE(review): the two comtrade series are not scaled, unlike the others —
# confirm this matches the training pipeline.
my_multivariate_series = concatenate(
    [
        grain_series_scaled,
        pakistan_series_scaled,
        erate_series_scaled,
        black_sea_series,
        comtrade_flour_series,
        comtrade_grain_series,
        covariates_transformed,
    ],
    axis=1,
)

# Training-window covariates. Not used by the app below.
# NOTE(review): erate_train is the *unscaled* split, and this list lacks the
# black-sea and comtrade components present in my_multivariate_series.
multivariate_series_train = concatenate(
    [
        grain_train,
        pakistan_train,
        erate_train,
        cov_train,
    ],
    axis=1,
)


# ---------------------------------------------------------------------------
# Flagging
# ---------------------------------------------------------------------------
class FlaggingHandler(gr.FlaggingCallback):
    """Flagging callback that prints each flagged sample and appends it to a
    CSV file through Gradio's built-in ``CSVLogger``."""

    def __init__(self):
        self._csv_logger = gr.CSVLogger()

    def setup(self, components: List[gr.components.Component], flagging_dir: str):
        """Prepare the underlying CSV logger.

        Parameters:
            components: Components whose values will be flagged, in the same
                order as the ``flag_data`` later passed to :meth:`flag`.
            flagging_dir: Directory in which the flagging CSV file is stored.
        """
        self.components = components
        self._csv_logger.setup(components=components, flagging_dir=flagging_dir)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        username: Optional[str] = None,
    ) -> int:
        """Log one flagged sample.

        Parameters:
            flag_data: Values of the flagged components, in the order given
                to :meth:`setup`.
            flag_option (optional): Label of the flag button that was used.
            username (optional): Accepted for interface compatibility; not
                written to the CSV.
        Returns:
            (int) The count returned by the underlying ``CSVLogger.flag``.
        """
        for item in flag_data:
            print(f"Flagging: {item}")
        if flag_option:
            print(f"Flag option: {flag_option}")
        return self._csv_logger.flag(flag_data=flag_data, flag_option=flag_option)


# ---------------------------------------------------------------------------
# Forecasting
# ---------------------------------------------------------------------------
# The model is loaded once and shared between requests, so calls into it are
# serialized to avoid concurrent predict/backtest runs on the same instance.
_MODEL_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_model() -> TFTModel:
    """Load the pretrained TFT model onto the CPU (cached after first call)."""
    return TFTModel.load(MODEL_PATH, map_location=torch.device("cpu"))


def _parse_horizon(period_: Optional[str]) -> int:
    """Return the leading integer of a horizon label such as ``"3 months"``.

    Raises:
        ValueError: if no horizon was selected.
    """
    if not period_:
        raise ValueError("Please select a forecast horizon.")
    return int(period_.split()[0])


def _forecast_table(preds: TimeSeries) -> pd.DataFrame:
    """Tabulate point forecasts with a fixed +/-10% band around them.

    Returns a frame with columns Date, Lower_Limit, Wheat_Forecast, Upper_Limit.
    """
    df_ = preds.pd_dataframe().rename(columns={"common_unit_price": "Wheat_Forecast"})
    df_90 = (preds * 0.9).pd_dataframe().rename(
        columns={"common_unit_price": "Lower_Limit"}
    )
    df_110 = (preds * 1.1).pd_dataframe().rename(
        columns={"common_unit_price": "Upper_Limit"}
    )
    merged_df = pd.merge(df_90, df_, on=["Date"]).merge(df_110, on=["Date"])
    return merged_df.reset_index()


def _backtest_tables(backtest_series_: TimeSeries):
    """Build (actuals, backtest) frames in original price units.

    The actuals cover the same trailing span as the backtest so both frames
    line up on Date.
    """
    actual = series_transformed[-len(backtest_series_):]
    df_series = pd.DataFrame(
        data={
            "Date": actual.time_index,
            "actual_prices": transformer.inverse_transform(actual).values().ravel(),
        }
    )
    df_backtest = pd.DataFrame(
        data={
            "Date": backtest_series_.time_index,
            "historical_forecasts": transformer.inverse_transform(backtest_series_)
            .values()
            .ravel(),
        }
    )
    return df_series, df_backtest


def _build_figure(df_series, df_backtest, merged_df, error: float) -> go.Figure:
    """Plot backtest, actuals and the forecast band with a date range selector."""
    fig = go.Figure()
    traces = [
        (df_backtest.Date, df_backtest.historical_forecasts, "historical forecasts"),
        (df_series.Date, df_series.actual_prices, "actual prices"),
        (merged_df.Date, merged_df.Upper_Limit, "Upper limit"),
        (merged_df.Date, merged_df.Lower_Limit, "Lower limit"),
        (merged_df.Date, merged_df.Wheat_Forecast, " Wheat Forecast"),
    ]
    for x, y, name in traces:
        fig.add_trace(go.Scatter(x=list(x), y=list(y), name=name))

    fig.update_layout(
        title_text=f"\n Mean Absolute Percentage Error {error:.2f}%",
        xaxis=dict(
            rangeselector=dict(
                buttons=[
                    dict(count=1, label="1m", step="month", stepmode="backward"),
                    # "backward" = the last 6 months, matching the "6m" label
                    # and the "1m" button ("todate" snaps to a period start).
                    dict(count=6, label="6m", step="month", stepmode="backward"),
                    dict(count=1, label="YTD", step="year", stepmode="todate"),
                ]
            ),
            rangeslider=dict(visible=True),
            type="date",
        ),
    )
    return fig


def get_forecast(period_: str, pred_model: str):
    """Forecast wheat prices for the selected horizon and backtest the model.

    Parameters:
        period_: Horizon label from the UI, e.g. "3 months"; its leading
            integer is the number of time steps to predict.
        pred_model: Selected commodity label. Currently unused — only the
            wheat model exists.
    Returns:
        (merged_df, fig): the forecast table (Date, Lower_Limit,
        Wheat_Forecast, Upper_Limit) and a Plotly backtesting figure.
    Side effects:
        Writes FORECAST_CSV and HISTORICAL_CSV.
    """
    period = _parse_horizon(period_)
    afgh_model = _load_model()

    with _MODEL_LOCK:
        pred_series = afgh_model.predict(n=period, num_samples=1)
    merged_df = _forecast_table(transformer.inverse_transform(pred_series))
    merged_df.to_csv(FORECAST_CSV, index=False)

    with _MODEL_LOCK:
        backtest_series_ = afgh_model.historical_forecasts(
            series_transformed,
            past_covariates=my_multivariate_series,
            start=BACKTEST_START,
            forecast_horizon=period,
            retrain=False,
            verbose=False,
        )
    df_series, df_backtest = _backtest_tables(backtest_series_)
    df_wheat_output = pd.merge(
        df_series,
        df_backtest[["Date", "historical_forecasts"]],
        on=["Date"],
        how="left",
    )
    df_wheat_output.to_csv(HISTORICAL_CSV, index=False)

    error = mape(
        transformer.inverse_transform(series_transformed),
        transformer.inverse_transform(backtest_series_),
    )
    fig = _build_figure(df_series, df_backtest, merged_df, error)
    return merged_df, fig


# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
def main():
    """Build the Gradio Blocks UI and launch it."""
    flagging_handler = FlaggingHandler()

    with gr.Blocks() as iface:
        gr.Markdown(
            "**Timeseries Forecasting model Temporal Fusion Transformer(TFT) "
            "built on Darts library**."
        )
        commodity = gr.Radio(["Wheat Price Forecasting"], label="Commodity to Forecast")
        period = gr.Radio(["3 months", "6 months"], label="Forecast horizon")
        with gr.Row():
            btn = gr.Button("Forecast.")
        feedback = gr.Textbox(label="Give feedback")
        data_points = gr.Textbox(
            label="Forecast values. Lower and upper values include a 10% error rate"
        )
        plt = gr.Plot(label="Backtesting plot, from 2018")

        btn.click(get_forecast, inputs=[period, commodity], outputs=[data_points, plt])

        with gr.Row():
            btn_incorrect = gr.Button("Flag as incorrect")
            btn_other = gr.Button("Flag as other")

        # The same list is used for the CSV header (setup) and the flagged
        # values (click inputs) so the columns stay aligned.
        flag_components = [commodity, data_points, period, feedback]
        flagging_handler.setup(components=flag_components, flagging_dir="data/flagged")
        btn_incorrect.click(
            lambda *args: flagging_handler.flag(flag_data=args, flag_option="Incorrect"),
            flag_components,
            None,
            preprocess=False,
        )
        btn_other.click(
            lambda *args: flagging_handler.flag(flag_data=args, flag_option="Other"),
            flag_components,
            None,
            preprocess=False,
        )

    iface.launch(debug=True, inline=False)


main()