Kimosh commited on
Commit
2d3ad00
1 Parent(s): 7991a21

first app file,model and pre-reqs

Browse files
Afghan_w_blacksea_allcomtrade_jun06.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a072d516af02292b59bb28671a3895661ac2bbf5c24e5a95047055d86f5ec258
3
+ size 155563
Afghan_w_blacksea_allcomtrade_jun06.pt.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c47af6f9f678a61390ee7ba9882752e6ce2855a229b05039a0b546e345a4fc9e
3
+ size 18414743
app.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ import torch
4
+ from darts import TimeSeries, concatenate
5
+ from darts.dataprocessing.transformers import Scaler
6
+ from darts.utils.timeseries_generation import datetime_attribute_timeseries
7
+ from darts.models.forecasting.tft_model import TFTModel
8
+ from darts.metrics import mape
9
+
10
+ from dateutil.relativedelta import relativedelta
11
+
12
+ import warnings
13
+ warnings.filterwarnings("ignore")
14
+ import logging
15
+
16
+ logging.disable(logging.CRITICAL)
17
+
18
+ import pandas as pd
19
+ import numpy as np
20
+
21
+ from typing import Any, List, Optional
22
+ import plotly.graph_objects as go
23
+
24
+ df_final = pd.read_csv('data/finalwheat_maize-forecasting.csv',parse_dates=['Date'])
25
+ series = TimeSeries.from_dataframe(df_final,
26
+ time_col='Date',
27
+ value_cols=['wheat_price', 'Hard red winter', 'Dollar value', 'CPI_value','maize-price','crude oil price']
28
+ )
29
+
30
+ six_months = df_final['Date'].max() + relativedelta(months=-6)
31
+ data_series = series['wheat_price']
32
+ train, val = data_series.split_after(six_months)
33
+ transformer = Scaler()
34
+ train_transformed = transformer.fit_transform(train)
35
+ val_transformed = transformer.transform(val)
36
+ series_transformed = transformer.transform(data_series)
37
+ # create year, month and integer index covariate series
38
+ covariates = datetime_attribute_timeseries(series_transformed, attribute="year", one_hot=False)
39
+ covariates = covariates.stack(
40
+ datetime_attribute_timeseries(series_transformed, attribute="month", one_hot=True)
41
+ )
42
+ covariates = covariates.stack(
43
+ TimeSeries.from_times_and_values(
44
+ times=series_transformed.time_index,
45
+ values=np.arange(len(series_transformed)),
46
+ # columns=["linear_increase"],
47
+ )
48
+ )
49
+ covariates = covariates.add_holidays(country_code="ES")
50
+
51
+ covariates = covariates.astype(np.float32)
52
+
53
+ scaler_covs = Scaler()
54
+ cov_train, cov_val = covariates.split_after(six_months)
55
+ cov_train = scaler_covs.fit_transform(cov_train)
56
+ cov_val = scaler_covs.transform(cov_val)
57
+ covariates_transformed = scaler_covs.transform(covariates)
58
+
59
+ dxy_series = series['Dollar value']
60
+ dxy_scaler = Scaler()
61
+ dxy_train, dxy_val = dxy_series.split_after(six_months)
62
+ dxy_train = dxy_scaler.fit_transform(dxy_train)
63
+ dxy_val = dxy_scaler.transform(dxy_val)
64
+ dxy_series_scaled = dxy_scaler.transform(dxy_series)
65
+
66
+ hard_series = series["Hard red winter"]
67
+ hard_scaler = Scaler()
68
+ hard_train, hard_val = hard_series.split_after(six_months)
69
+ hard_train = hard_scaler.fit_transform(hard_train)
70
+ hard_val = hard_scaler.transform(hard_val)
71
+ hard_series_scaled = hard_scaler.transform(hard_series)
72
+
73
+ cpi_series = series['CPI_value']
74
+ cpi_scaler = Scaler()
75
+ cpi_train, cpi_val = cpi_series.split_after(six_months)
76
+ cpi_train = cpi_scaler.fit_transform(cpi_train)
77
+ cpi_val = cpi_scaler.transform(cpi_val)
78
+ cpi_series_scaled = cpi_scaler.transform(cpi_series)
79
+
80
+ maize_series = series['maize-price']
81
+ maize_scaler = Scaler()
82
+ maize_train, maize_val = maize_series.split_after(six_months)
83
+ maize_train_transformed = maize_scaler.fit_transform(maize_train)
84
+ maize_val_transformed = maize_scaler.transform(maize_val)
85
+ maize_series_scaled = maize_scaler.transform(maize_series)
86
+
87
+ crude_series = series['crude oil price']
88
+ crude_scaler = Scaler()
89
+ crude_train, crude_val = crude_series.split_after(six_months)
90
+ crude_train = crude_scaler.fit_transform(crude_train)
91
+ crude_val = crude_scaler.transform(crude_val)
92
+ crude_series_scaled = crude_scaler.transform(crude_series)
93
+
94
+ from darts import concatenate
95
+ my_multivariate_series = concatenate(
96
+ [
97
+ dxy_series_scaled,
98
+ cpi_series_scaled,
99
+ hard_series_scaled,
100
+ crude_series_scaled,
101
+ covariates_transformed,
102
+ ],
103
+ axis=1)
104
+ multivariate_series_train = concatenate(
105
+ [
106
+ dxy_train,
107
+ cpi_train,
108
+ hard_train,
109
+ crude_train,
110
+ cov_train,
111
+ ],
112
+ axis=1)
113
+
114
+
115
+
116
+ class FlaggingHandler(gr.FlaggingCallback):
117
+ def __init__(self):
118
+ self._csv_logger = gr.CSVLogger()
119
+
120
+ def setup(self, components: List[gr.components.Component], flagging_dir: str):
121
+ """Called by Gradio at the beginning of the `Interface.launch()` method.
122
+ Parameters:
123
+ components: Set of components that will provide flagged data.
124
+ flagging_dir: A string, typically containing the path to the directory where
125
+ the flagging file should be storied (provided as an argument to Interface.__init__()).
126
+ """
127
+ self.components = components
128
+ self._csv_logger.setup(components=components, flagging_dir=flagging_dir)
129
+
130
+ def flag(
131
+ self,
132
+ flag_data: List[Any],
133
+ flag_option: Optional[str] = None,
134
+ # flag_index: Optional[int] = None,
135
+ username: Optional[str] = None,
136
+ ) -> int:
137
+ """Called by Gradio whenver one of the <flag> buttons is clicked.
138
+ Parameters:
139
+ interface: The Interface object that is being used to launch the flagging interface.
140
+ flag_data: The data to be flagged.
141
+ flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
142
+ flag_index (optional): The index of the sample that is being flagged.
143
+ username (optional): The username of the user that is flagging the data, if logged in.
144
+ Returns:
145
+ (int) The total number of samples that have been flagged.
146
+ """
147
+ for item in flag_data:
148
+ print(f"Flagging: {item}")
149
+ if flag_option:
150
+ print(f"Flag option: {flag_option}")
151
+
152
+ # if flag_index:
153
+ # print(f"Flag index: {flag_index}")
154
+
155
+ flagged_count = self._csv_logger.flag(
156
+ flag_data=flag_data,
157
+ flag_option=flag_option,
158
+ # flag_index=flag_index,
159
+ # username=username,
160
+ )
161
+ return flagged_count
162
+
163
+
164
+ def get_forecast(period_: str, pred_model: str):
165
+ # Let the prediction service do its magic.
166
+ period = int(period_[0])
167
+ afgh_model = TFTModel.load("Afghan_w_blacksea_allcomtrade_jun06.pt",map_location=torch.device('cpu'))
168
+
169
+ ### afgh model###
170
+ pred_series = afgh_model.predict(n=period,num_samples=1)
171
+ preds = transformer.inverse_transform(pred_series)
172
+ # creating a Dataframe
173
+ df_= preds.pd_dataframe()
174
+ df_.rename(columns={'common_unit_price': 'Wheat_Forecast'},inplace=True)
175
+
176
+ # error intervals:
177
+ # Calculate the 90% and 110% forecast values
178
+ forecast_90 = preds * 0.9
179
+ forecast_110 = preds * 1.1
180
+ df_90 = forecast_90.pd_dataframe()
181
+ df_90.rename(columns={'common_unit_price': 'Lower_Limit'},inplace=True)
182
+
183
+ df_110 = forecast_110.pd_dataframe()
184
+ df_110.rename(columns={'common_unit_price': 'Upper_Limit'},inplace=True)
185
+ merged_df = pd.merge(df_90,df_, on=['Date']).merge(df_110, on=['Date'])
186
+ merged_df = merged_df.reset_index()
187
+
188
+
189
+ start=pd.Timestamp("20180131")
190
+
191
+ backtest_series_ = afgh_model.historical_forecasts(
192
+ series_transformed,
193
+ past_covariates=my_multivariate_series,
194
+ start=start,
195
+ forecast_horizon=period,
196
+ retrain=False,
197
+ verbose=False,
198
+ )
199
+ series_time = series_transformed[-len(backtest_series_):].time_index
200
+ series_vals = (transformer.inverse_transform(series_transformed[-len(backtest_series_):])).values()
201
+ df_series = pd.DataFrame(data={'date': series_time, 'actual_prices': series_vals.ravel() })
202
+ vals = (transformer.inverse_transform(backtest_series_)).values()
203
+ df_backtest = pd.DataFrame(data={'date': backtest_series_.time_index, 'historical_forecasts': vals.ravel() })
204
+
205
+
206
+ # Create figure
207
+ fig = go.Figure()
208
+
209
+ fig.add_trace(
210
+ go.Scatter(
211
+ x=list(df_backtest.date),
212
+ y=list(df_backtest.historical_forecasts),
213
+ name='historical forecasts'
214
+ # x=list(df.Date), y=list(df.High)
215
+ ))
216
+
217
+ fig.add_trace(
218
+ go.Scatter(
219
+ x=list(df_series.date),
220
+ y=list(df_series.actual_prices),
221
+ name="actual prices",
222
+ ))
223
+
224
+ fig.add_trace(go.Scatter(
225
+ x = list(merged_df.Date),
226
+ y=list(merged_df.Upper_Limit),
227
+ name="Upper limit"
228
+ ))
229
+
230
+ fig.add_trace(go.Scatter(
231
+ x = list(merged_df.Date),
232
+ y=list(merged_df.Lower_Limit),
233
+ name="Lower limit"
234
+ ))
235
+ fig.add_trace(go.Scatter(
236
+ x = list(merged_df.Date),
237
+ y=list(merged_df.Wheat_Forecast),
238
+ name=" Wheat Forecast"
239
+ ))
240
+
241
+ # Set title
242
+ fig.update_layout(
243
+ title_text=f"\n Mean Absolute Percentage Error {mape(transformer.inverse_transform(series_transformed), transformer.inverse_transform(backtest_series_)):.2f}%"
244
+ )
245
+
246
+ # Add range slider
247
+ fig.update_layout(
248
+ xaxis=dict(
249
+ rangeselector=dict(
250
+ buttons=list([
251
+ dict(count=1,
252
+ label="1m",
253
+ step="month",
254
+ stepmode="backward"),
255
+ dict(count=6,
256
+ label="6m",
257
+ step="month",
258
+ stepmode="todate"),
259
+ dict(count=1,
260
+ label="YTD",
261
+ step="year",
262
+ stepmode="todate"),
263
+ # dict(count=1,
264
+ # label="1y",
265
+ # step="year",
266
+ # stepmode="backward"),
267
+ # dict(step="all")
268
+ ])
269
+ ),
270
+ rangeslider=dict(
271
+ visible=True
272
+ ),
273
+ type="date"
274
+ )
275
+ )
276
+
277
+ return merged_df,fig
278
+
279
+ def main():
280
+ flagging_handler = FlaggingHandler()
281
+
282
+ # example_url = "" # noqa: E501
283
+ with gr.Blocks() as iface:
284
+ gr.Markdown(
285
+ """
286
+ **Timeseries Forecasting model Temporal Fusion Transformer(TFT) built on Darts library**.
287
+ """)
288
+ commodity = gr.Radio(["Wheat Price Forecasting","Maize Price Forecasting"],label="Commodity to Forecast")
289
+ period = gr.Radio(['3 months',"6 months"],label="Forecast horizon")
290
+
291
+ # with gr.Row():
292
+ # lib = gr.Dropdown(["pandas", "scikit-learn", "torch", "prophet"], label="Library", value="torch")
293
+ # time = gr.Dropdown(["3 months", "6 months",], label="Downloads over the last...", value="6 months")
294
+
295
+ with gr.Row():
296
+ btn = gr.Button("Forecast.")
297
+ feedback = gr.Textbox(label="Give feedback")
298
+ gr.CSVLogger()
299
+
300
+
301
+
302
+ data_points = gr.Textbox(label=f"Forecast values. Lower and upper values include a 10% error rate")
303
+ plt = gr.Plot(label="Backtesting plot, from 2018").style()
304
+
305
+
306
+ btn.click(
307
+ get_forecast,
308
+ inputs=[period,commodity],
309
+ outputs = [data_points,plt]
310
+ )
311
+ with gr.Row():
312
+ btn_incorrect = gr.Button("Flag as incorrect")
313
+ btn_other = gr.Button("Flag as other")
314
+ flagging_handler.setup(
315
+ components=[commodity, period],
316
+ flagging_dir="data/flagged",
317
+ )
318
+
319
+ with gr.Row():
320
+ current_wheat = gr.Image('wheat_prices.png')
321
+ current_maize = gr.Image('maize_prices.png')
322
+ btn_incorrect.click(
323
+ lambda *args: flagging_handler.flag(
324
+ flag_data=args, flag_option="Incorrect"
325
+ ),
326
+ [commodity, data_points, period,feedback],
327
+ None,
328
+ preprocess=False,
329
+ )
330
+ btn_other.click(
331
+ lambda *args: flagging_handler.flag(flag_data=args, flag_option="Other"),
332
+ [commodity, data_points, period,feedback],
333
+ None,
334
+ preprocess=False,
335
+ )
336
+
337
+ iface.launch(debug=True, inline=False)
338
+
339
+ main()
pre-requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ darts==0.24.0
2
+ gradio==3.28.3
3
+ numpy==1.23.5
4
+ pandas==1.5.3
5
+ plotly==5.13.1
6
+ python_dateutil==2.8.2
7
+ torch==2.0.0
8
+ lightning==2.0.2