File size: 10,218 Bytes
62d6afe 73bcabb 03fdb99 2955021 05f1e1c 3bc4b76 03fdb99 2955021 05f1e1c 03fdb99 2955021 05f1e1c 73bcabb 62d6afe 2955021 62d6afe 73bcabb 62d6afe 73bcabb 62d6afe 73bcabb 62d6afe 73bcabb 62d6afe 03fdb99 05f1e1c 03fdb99 2955021 73bcabb 03fdb99 05f1e1c 03fdb99 05f1e1c 03fdb99 05f1e1c 03fdb99 05f1e1c 03fdb99 62d6afe 2955021 62d6afe 2955021 62d6afe 2955021 62d6afe 2955021 62d6afe 2955021 62d6afe 2955021 62d6afe 2955021 62d6afe 2955021 62d6afe 73bcabb 62d6afe 73bcabb 03fdb99 dd74596 03fdb99 2955021 03fdb99 62d6afe 73bcabb 62d6afe 73bcabb 62d6afe 73bcabb 2955021 62d6afe 2955021 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 |
import functools
import logging
import os
import re
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objs as go
import requests
import torch
from chronos import ChronosPipeline
from clickhouse_driver import Client
from prophet import Prophet
# Database passwords: read from Colab's secret store when running in Colab,
# otherwise from environment variables (a missing variable raises KeyError).
try:
    from google.colab import userdata
    PG_PASSWORD = userdata.get('FASHION_PG_PASS')
    CH_PASSWORD = userdata.get('FASHION_CH_PASS')
except Exception:
    # Not in Colab (ImportError) or the Colab secret lookup failed. A bare
    # `except:` would also swallow KeyboardInterrupt/SystemExit.
    PG_PASSWORD = os.environ['FASHION_PG_PASS']
    CH_PASSWORD = os.environ['FASHION_CH_PASS']
# Keep only WARNING and above from the prophet and cmdstanpy loggers.
for _noisy_logger in ("prophet", "cmdstanpy"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)

# Lower-case Russian month name -> zero-padded month number ("01".."12").
russian_months = {
    month_name: f"{month_number:02d}"
    for month_number, month_name in enumerate(
        ("январь", "февраль", "март", "апрель", "май", "июнь",
         "июль", "август", "сентябрь", "октябрь", "ноябрь", "декабрь"),
        start=1,
    )
}
def read_and_process_file(file):
    """Parse an uploaded Google Trends / Yandex Wordstat style CSV export.

    The layout is sniffed from the file itself:

    * period type is ``"Week"`` when any of the first three lines contains
      "неделя" or "week" (case-insensitive), otherwise ``"Month"``;
    * a blank second line marks the Google layout (two preamble lines before
      the CSV header); anything else is read as the Yandex layout
      (``;``-separated, only the first and third columns are read).

    Args:
        file: The uploaded file. Either an object exposing the path as
            ``.name`` (e.g. a tempfile wrapper) or the path itself as a str.

    Returns:
        Tuple ``(data, period_type, period_col)``. ``data`` is indexed by the
        parsed datetime column ``period_col``; its first remaining column holds
        the series values converted to float.

    Raises:
        ValueError: if the file has fewer than two lines or a value cannot be
            converted to float.
    """
    # Depending on the Gradio version, uploads arrive as a tempfile wrapper
    # (path in `.name`) or as a plain path string.
    path = getattr(file, 'name', file)

    # Read once; utf-8 matches the default pd.read_csv uses below.
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    if len(lines) < 2:
        raise ValueError("Uploaded file is too short to be a time-series export")

    header = ''.join(lines[:3]).lower()
    period_type = "Week" if any(word in header for word in ("неделя", "week")) else "Month"

    if lines[1].strip() == '':
        # Google layout: skip the two preamble lines.
        data = pd.read_csv(path, skiprows=2)
    else:
        # Yandex layout.
        data = pd.read_csv(path, sep=';', usecols=[0, 2])
        date_col = data.columns[0]
        if period_type == "Month":
            # "<month name> <yyyy>" -> "yyyy-MM-01", e.g. "январь 2023" -> "2023-01-01".
            data[date_col] = data[date_col].apply(
                lambda x: re.sub(
                    r'(\w+)\s(\d{4})',
                    lambda m: f'{m.group(2)}-{russian_months[m.group(1).lower()]}',
                    x,
                ) + '-01'
            )
        else:
            data[date_col] = pd.to_datetime(data[date_col], format="%d.%m.%Y")

    period_col, value_col = data.columns[0], data.columns[1]
    # Normalise the value column: "<1" becomes 0, spaces are dropped
    # (presumably thousands separators) and a decimal comma becomes a dot.
    # Assigning the whole column (not via .iloc) gives it a real float dtype.
    data[value_col] = (
        data[value_col].astype(str)
        .str.replace('<1', '0', regex=False)
        .str.replace(' ', '', regex=False)
        .str.replace(',', '.', regex=False)
        .astype(float)
    )

    data[period_col] = pd.to_datetime(data[period_col])
    data.set_index(period_col, inplace=True)
    return data, period_type, period_col
def get_data_from_db(query):
    """Run ``query`` on the managed ClickHouse database and return a DataFrame.

    The Yandex Cloud root CA certificate is downloaded for the TLS connection
    and stored in a temporary file only for the duration of the call. The
    temporary file is removed and the client disconnected before returning,
    including on error.

    Args:
        query: SQL text to execute.

    Returns:
        pandas.DataFrame with one column per result column, named as the
        server reports them.

    Raises:
        requests.RequestException: if the certificate download fails or
            returns an HTTP error status.
    """
    response = requests.get('https://storage.yandexcloud.net/cloud-certs/RootCA.pem', timeout=30)
    # Fail here with a clear HTTP error instead of handing an error page to TLS.
    response.raise_for_status()
    with tempfile.NamedTemporaryFile(delete=False, suffix='.pem') as temp_cert_file:
        temp_cert_file.write(response.content)
        cert_file_path = temp_cert_file.name
    try:
        client = Client(host='rc1d-a93v7vf0pjfr6e2o.mdb.yandexcloud.net',
                        port=9440,
                        user='user1',
                        password=CH_PASSWORD,
                        database='db1',
                        secure=True,
                        ca_certs=cert_file_path)
        try:
            result, columns = client.execute(query, with_column_types=True)
        finally:
            client.disconnect()
    finally:
        # Previously one certificate file was left behind per query.
        os.remove(cert_file_path)
    # with_column_types=True yields (name, type) pairs; keep the names.
    column_names = [col[0] for col in columns]
    return pd.DataFrame(result, columns=column_names)
def forecast_time_series(file, product_name, wb, ozon, model_choice):
    """Load a series, forecast two years ahead and export history + forecast.

    Without an uploaded file, the series comes from ClickHouse (see
    ``_load_series_from_db``). Otherwise the uploaded CSV is parsed by
    ``read_and_process_file``.

    Args:
        file: Uploaded CSV, or ``None`` to query the database.
        product_name: Raw SQL boolean expression inserted into the WHERE clause.
        wb: Include rows with ``mp = 'wildberries'``.
        ozon: Include rows with ``mp = 'ozon'``.
        model_choice: ``"Prophet"`` or ``"Chronos"``.

    Returns:
        ``(figure, yoy_text, csv_path)`` for the Gradio outputs.

    Raises:
        gr.Error: if no marketplace is selected or the query returns no rows.
        ValueError: for an unknown ``model_choice``.
    """
    if file is None:
        data, period_type, period_col = _load_series_from_db(product_name, wb, ozon)
    else:
        data, period_type, period_col = read_and_process_file(file)

    # `year` = periods per year; the horizon is always two years.
    if period_type == "Month":
        year = 12
        n_periods = 24
        freq = "MS"
    else:
        year = 52
        n_periods = year * 2
        freq = "W"

    df = data.reset_index().rename(columns={period_col: 'ds', data.columns[0]: 'y'})
    if model_choice == "Prophet":
        forecast, yoy_change = forecast_prophet(df, n_periods, freq, year)
    elif model_choice == "Chronos":
        forecast, yoy_change = forecast_chronos(df, n_periods, freq, year)
    else:
        raise ValueError("Invalid model choice")

    # Create Plotly figure (common for both models)
    fig = create_plot(data, forecast)
    # Combine original data and forecast, aligned on date.
    combined_df = pd.concat([data, forecast.set_index('ds')], axis=1)
    # A fresh directory per request: a shared fixed path in the working
    # directory lets concurrent users overwrite or download each other's
    # results. The download keeps the familiar file name.
    combined_file = os.path.join(tempfile.mkdtemp(), 'combined_data.csv')
    combined_df.to_csv(combined_file)
    return fig, f'Year-over-Year Change in Sum of Values: {yoy_change:.2%}', combined_file


def _load_series_from_db(product_name, wb, ozon):
    """Query the series from ClickHouse; return ``(data, "Week", "ds")``.

    Turnover is summed per ``start_date`` (from ``week_data``) and divided by
    the largest of those sums, so the peak period equals 1.0.
    """
    marketplaces = []
    if wb:
        marketplaces.append('wildberries')
    if ozon:
        marketplaces.append('ozon')
    if not marketplaces:
        # Otherwise the query would filter on `mp in ('')` and come back empty.
        raise gr.Error("Select at least one marketplace")
    mp_filter = "', '".join(marketplaces)
    # NOTE(review): product_name is spliced into the SQL verbatim. The UI
    # exposes it as a free-form filter expression, so anyone using the app
    # can run arbitrary queries with this account's rights.
    query = f"""
        select
            cast(start_date as date) as ds,
            1.0*sum(turnover) / (max(sum(turnover)) over ()) as y
        from datamart_all_1
        join week_data
        using (id_week)
        where {product_name}
        and mp in ('{mp_filter}')
        group by ds
        order by ds
    """
    print(query)
    data = get_data_from_db(query)
    print(data)
    if data.empty:
        raise gr.Error("No data found in database. Please adjust filters")
    # Assign the whole column so it gets a real datetime dtype.
    data['ds'] = pd.to_datetime(data['ds'])
    data.set_index('ds', inplace=True)
    return data, "Week", "ds"
def forecast_prophet(df, n_periods, freq, year):
    """Fit Prophet on ``df`` and forecast ``n_periods`` steps past its end.

    Args:
        df: History with ``ds`` (dates) and ``y`` (values) columns.
        n_periods: Number of future periods to forecast.
        freq: pandas frequency alias for the future dates (e.g. "W", "MS").
        year: Number of periods that make up one year at ``freq``.

    Returns:
        ``(forecast, yoy_change)``. ``forecast`` is Prophet's prediction frame
        over the history plus the horizon (includes ``ds``, ``yhat``,
        ``yhat_lower``, ``yhat_upper``). ``yoy_change`` is the relative
        change from the sum of the last ``year`` observed values to the sum of
        the first ``year`` forecast values.
    """
    model = Prophet()
    model.fit(df)
    future = model.make_future_dataframe(periods=n_periods, freq=freq)
    forecast = model.predict(future)
    sum_last_year_original = df['y'].iloc[-year:].sum()
    # The last n_periods rows are the horizon; sum the first `year` of them.
    # A slice like iloc[-n_periods:-n_periods + year] is empty whenever
    # n_periods == year (its stop index becomes 0).
    sum_first_year_forecast = forecast['yhat'].tail(n_periods).head(year).sum()
    yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original
    return forecast, yoy_change
def forecast_chronos(df, n_periods, freq, year):
    """Forecast ``n_periods`` steps past the end of ``df`` with Chronos T5-mini.

    20 sample paths are drawn. Their 10th/50th/90th percentiles per step
    become ``yhat_lower``/``yhat``/``yhat_upper``.

    Args:
        df: History with ``ds`` (dates) and ``y`` (values) columns.
        n_periods: Number of future periods to forecast.
        freq: pandas frequency alias for the future dates (e.g. "W", "MS").
        year: Number of periods that make up one year at ``freq``.

    Returns:
        ``(forecast, yoy_change)``. ``forecast`` holds only the future
        horizon. ``yoy_change`` is the relative change from the sum of the last
        ``year`` observed values to the sum of the first ``year`` medians.

    Raises:
        ValueError: if ``y`` holds values that cannot be read as numbers.
    """
    # Validate before touching the model.
    if not pd.api.types.is_numeric_dtype(df['y']):
        non_numeric = df[pd.to_numeric(df['y'], errors='coerce').isna()]
        if not non_numeric.empty:
            error_message = f"Non-numeric values found in 'y' column. First few problematic rows:\n{non_numeric.head().to_string()}"
            raise ValueError(error_message)
    try:
        y_values = df['y'].values.astype(np.float32)
    except ValueError as e:
        raise ValueError(f"Unable to convert 'y' column to float32: {str(e)}") from e

    pipeline = _load_chronos_pipeline()
    chronos_forecast = pipeline.predict(
        context=torch.tensor(y_values),
        prediction_length=n_periods,
        num_samples=20,
        limit_prediction_length=False
    )

    # date_range rolls `start` forward to the first anchor of `freq`. When the
    # last observation is not on an anchor (e.g. a Monday with "W", i.e.
    # weekly-on-Sunday), the first generated date is already in the future,
    # and blindly dropping element 0 would skip a period. Keep dates strictly
    # after the history instead (the same rule Prophet's make_future_dataframe
    # applies).
    last_date = pd.Timestamp(df['ds'].iloc[-1])
    candidate_dates = pd.date_range(start=last_date, periods=n_periods + 1, freq=freq)
    forecast_index = candidate_dates[candidate_dates > last_date][:n_periods]

    # Quantiles across samples (axis 0) for each forecast step.
    low, median, high = np.quantile(chronos_forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)
    forecast = pd.DataFrame({
        'ds': forecast_index,
        'yhat': median,
        'yhat_lower': low,
        'yhat_upper': high
    })
    sum_last_year_original = df['y'].iloc[-year:].sum()
    sum_first_year_forecast = median[:year].sum()
    yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original
    return forecast, yoy_change


@functools.lru_cache(maxsize=1)
def _load_chronos_pipeline():
    """Load the Chronos T5-mini pipeline (CPU, bfloat16) once per process.

    Earlier every forecast request reloaded the model weights.
    """
    return ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-mini",
        device_map="cpu",
        torch_dtype=torch.bfloat16,
    )
def create_plot(data, forecast):
    """Build the chart of the observed series, the forecast and its band.

    Args:
        data: Observed series. Its index is the x axis, its first column the y.
        forecast: Frame with ``ds``, ``yhat``, ``yhat_lower`` and ``yhat_upper``.

    Returns:
        A plotly ``Figure`` with four line traces and a horizontal legend.
    """
    horizon = forecast['ds']
    pink = {'color': 'pink'}
    traces = [
        go.Scatter(x=data.index, y=data.iloc[:, 0], mode='lines', name='Observed'),
        go.Scatter(x=horizon, y=forecast['yhat'], mode='lines', name='Forecast',
                   line={'color': 'red'}),
        # Order matters: fill='tonexty' on the upper bound shades down to the
        # trace added immediately before it, i.e. the lower bound.
        go.Scatter(x=horizon, y=forecast['yhat_lower'], fill=None, mode='lines',
                   line=pink, name='Lower CI'),
        go.Scatter(x=horizon, y=forecast['yhat_upper'], fill='tonexty', mode='lines',
                   line=pink, name='Upper CI'),
    ]

    figure = go.Figure()
    for trace in traces:
        figure.add_trace(trace)

    layout = {
        'title': 'Observed Time Series and Forecast with Confidence Intervals',
        'xaxis_title': 'Date',
        'yaxis_title': 'Values',
        'legend': {'orientation': 'h', 'yanchor': 'bottom', 'y': 1.02,
                   'xanchor': 'right', 'x': 1},
        'hovermode': 'x unified',
    }
    figure.update_layout(**layout)
    return figure
# Gradio UI: inputs (CSV upload, marketplace checkboxes, SQL product filter,
# model choice), a Compute button, and the chart / YoY text / CSV outputs.
# Component creation order inside each `with` defines the on-screen layout.
with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
    gr.Markdown("# Time Series Forecasting")
    gr.Markdown("Upload a CSV file with a time series to forecast the next 2 years and see the YoY % change. Download the combined original and forecast data.")
    with gr.Row():
        file_input = gr.File(label="Upload Time Series CSV")
    with gr.Row():
        wb_checkbox = gr.Checkbox(label="Wildberries", value=True)
        ozon_checkbox = gr.Checkbox(label="Ozon", value=True)
    with gr.Row():
        # Raw SQL condition inserted into the database query's WHERE clause.
        product_name_input = gr.Textbox(label="Product Name Filter", value="name like '%пуховик%'")
    with gr.Row():
        model_choice = gr.Radio(["Prophet", "Chronos"], label="Choose Model", value="Prophet")
    with gr.Row():
        compute_button = gr.Button("Compute")
    with gr.Row():
        plot_output = gr.Plot(label="Time Series + Forecast Chart")
    with gr.Row():
        yoy_output = gr.Text(label="YoY % Change")
    with gr.Row():
        csv_output = gr.File(label="Download Combined Data CSV")
    # Input order must match forecast_time_series(file, product_name, wb, ozon, model_choice).
    compute_button.click(
        forecast_time_series,
        inputs=[file_input, product_name_input, wb_checkbox, ozon_checkbox, model_choice],
        outputs=[plot_output, yoy_output, csv_output]
    )

# Launch only when run as a script (as Spaces/Colab do), not when imported,
# e.g. by tests or Gradio's reload mode.
if __name__ == "__main__":
    interface.launch(debug=True)