File size: 10,218 Bytes
62d6afe 73bcabb 03fdb99 2955021 05f1e1c 3bc4b76 03fdb99 2955021 05f1e1c 03fdb99 2955021 05f1e1c 73bcabb 62d6afe 2955021 62d6afe 73bcabb 62d6afe 73bcabb 62d6afe 73bcabb 62d6afe 73bcabb 62d6afe 03fdb99 05f1e1c 03fdb99 2955021 73bcabb 03fdb99 05f1e1c 03fdb99 05f1e1c 03fdb99 05f1e1c 03fdb99 05f1e1c 03fdb99 62d6afe 2955021 62d6afe 2955021 62d6afe 2955021 62d6afe 2955021 62d6afe 2955021 62d6afe 2955021 62d6afe 2955021 62d6afe 2955021 62d6afe 73bcabb 62d6afe 73bcabb 03fdb99 dd74596 03fdb99 2955021 03fdb99 62d6afe 73bcabb 62d6afe 73bcabb 62d6afe 73bcabb 2955021 62d6afe 2955021 |
|
import gradio as gr
import pandas as pd
from prophet import Prophet
import plotly.graph_objs as go
import re
import logging
import os
import torch
from chronos import ChronosPipeline
import numpy as np
import requests
import tempfile
from clickhouse_driver import Client
# Database credentials: read them from Google Colab's secret store when it is
# available, otherwise from environment variables.
try:
    from google.colab import userdata
    PG_PASSWORD = userdata.get('FASHION_PG_PASS')
    CH_PASSWORD = userdata.get('FASHION_CH_PASS')
except Exception:
    # Either google.colab is not importable or the secret lookup failed.
    # `Exception` (rather than a bare except) lets KeyboardInterrupt and
    # SystemExit propagate. A missing variable raises KeyError here.
    PG_PASSWORD = os.environ['FASHION_PG_PASS']
    CH_PASSWORD = os.environ['FASHION_CH_PASS']

# Only show WARNING and above from the Prophet and cmdstanpy loggers.
logging.getLogger("prophet").setLevel(logging.WARNING)
logging.getLogger("cmdstanpy").setLevel(logging.WARNING)
# Lookup table: lower-case Russian month name -> zero-padded month number
# ("январь" -> "01", ..., "декабрь" -> "12"), built from the calendar order.
russian_months = {
    month_name: f"{month_number:02d}"
    for month_number, month_name in enumerate(
        (
            "январь", "февраль", "март", "апрель",
            "май", "июнь", "июль", "август",
            "сентябрь", "октябрь", "ноябрь", "декабрь",
        ),
        start=1,
    )
}
def read_and_process_file(file):
    """Load an uploaded time-series CSV into a date-indexed DataFrame.

    Two layouts are recognised:

    * "Google" layout: the second line of the file is blank. The first two
      lines are skipped and the rest is parsed as a comma-separated table.
    * "Yandex" layout: anything else. The file is parsed as a
      semicolon-separated table keeping columns 0 and 2. Monthly periods are
      written as "<russian month> <year>" and weekly periods as "dd.mm.YYYY".

    The period is "Week" when the first three lines mention "неделя" or
    "week" (case-insensitive), otherwise "Month".

    Args:
        file: An object whose ``name`` attribute is the path of the CSV (as
            the original code expected), or the path itself.

    Returns:
        A ``(data, period_type, period_col)`` tuple: ``data`` is indexed by
        the parsed date column and its first column holds float values,
        ``period_type`` is "Week" or "Month", ``period_col`` is the name of
        the original date column.

    Raises:
        ValueError: If the file has fewer than two lines.
        KeyError: If a monthly Yandex period uses an unknown month name.
    """
    path = getattr(file, "name", file)

    # One pass over the file serves both the period sniffing and the layout
    # check. utf-8 matches the default encoding pandas uses further down.
    with open(path, "r", encoding="utf-8") as f:
        lines = f.readlines()
    if len(lines) < 2:
        raise ValueError(
            "The uploaded file must contain at least a header line and one data line"
        )

    header_text = "".join(lines[:3]).lower()
    is_weekly = any(word in header_text for word in ("неделя", "week"))
    period_type = "Week" if is_weekly else "Month"

    if lines[1].strip() == "":
        # "Google" layout
        data = pd.read_csv(path, skiprows=2)
    else:
        # "Yandex" layout
        data = pd.read_csv(path, sep=";", skiprows=0, usecols=[0, 2])
        date_col = data.columns[0]
        if period_type == "Month":
            # "<month name> <year>" -> "<year>-<MM>-01"
            data[date_col] = data[date_col].apply(
                lambda x: re.sub(
                    r'(\w+)\s(\d{4})',
                    lambda m: f'{m.group(2)}-{russian_months[m.group(1).lower()]}',
                    x,
                ) + '-01'
            )
        else:
            data[date_col] = pd.to_datetime(data[date_col], format="%d.%m.%Y")

    period_col = data.columns[0]
    value_col = data.columns[1]

    # Normalise the value column: "<1" counts as 0, spaces are dropped and a
    # decimal comma becomes a dot. Assigning by column name (not
    # `.iloc[:, 1] = ...`) replaces the column, so the result is float dtype
    # even when the parsed column was object dtype.
    data[value_col] = (
        data[value_col]
        .astype(str)
        .str.replace("<1", "0", regex=False)
        .str.replace(" ", "", regex=False)
        .str.replace(",", ".", regex=False)
        .astype(float)
    )

    # Parse the date column and use it as the index.
    data[period_col] = pd.to_datetime(data[period_col])
    data.set_index(period_col, inplace=True)
    return data, period_type, period_col
def get_data_from_db(query):
    """Run ``query`` on the ClickHouse server and return the rows as a DataFrame.

    The CA certificate is downloaded on every call and written to a temporary
    file, because the client is given a file path via ``ca_certs``. The file
    is deleted and the client disconnected once the query has finished,
    whether it succeeded or not.

    Args:
        query: SQL text passed unchanged to ``Client.execute``.

    Returns:
        A ``pandas.DataFrame`` whose columns are named after the result
        columns reported by the server.

    Raises:
        requests.RequestException: If the certificate download fails, times
            out, or returns an HTTP error status.
    """
    response = requests.get(
        'https://storage.yandexcloud.net/cloud-certs/RootCA.pem',
        timeout=30,  # never hang a UI request on the certificate download
    )
    response.raise_for_status()

    with tempfile.NamedTemporaryFile(suffix='.pem', delete=False) as temp_cert_file:
        temp_cert_file.write(response.content)
        cert_file_path = temp_cert_file.name

    client = None
    try:
        client = Client(host='rc1d-a93v7vf0pjfr6e2o.mdb.yandexcloud.net',
                        port=9440,
                        user='user1',
                        password=CH_PASSWORD,
                        database='db1',
                        secure=True,
                        ca_certs=cert_file_path)
        result, columns = client.execute(query, with_column_types=True)
    finally:
        # The certificate file must outlive the query, so clean up only here.
        if client is not None:
            client.disconnect()
        os.remove(cert_file_path)

    # `columns` is a list of (name, type) pairs; keep the names only.
    column_names = [col[0] for col in columns]
    return pd.DataFrame(result, columns=column_names)
def forecast_time_series(file, product_name, wb, ozon, model_choice):
    """Forecast a time series from an uploaded CSV or from the database.

    When ``file`` is None, weekly turnover (scaled by its maximum) is queried
    from the database for the selected marketplaces. Otherwise the uploaded
    file is parsed with ``read_and_process_file``. The series is then
    forecast two years ahead (24 months or 104 weeks) with the chosen model.

    Args:
        file: Uploaded file object, or None to read from the database.
        product_name: SQL predicate inserted verbatim into the WHERE clause
            (for example ``name like '%...%'``). Used only when ``file`` is None.
        wb: Include the 'wildberries' marketplace in the database query.
        ozon: Include the 'ozon' marketplace in the database query.
        model_choice: "Prophet" or "Chronos".

    Returns:
        A ``(figure, yoy_text, csv_path)`` tuple: the Plotly figure, a
        formatted year-over-year change string, and the path of a CSV
        holding the observed data joined with the forecast.

    Raises:
        gr.Error: If no marketplace is selected or the query returns no rows.
        ValueError: If ``model_choice`` is not a supported model.
    """
    if file is None:
        marketplaces = [
            name
            for name, selected in (('wildberries', wb), ('ozon', ozon))
            if selected
        ]
        if not marketplaces:
            # An empty IN-list can never match; skip the database round trip.
            raise gr.Error("Please select at least one marketplace")
        mp_filter = "', '".join(marketplaces)

        # NOTE(security): `product_name` is a raw SQL predicate supplied by
        # the UI user and is interpolated unescaped. This is only acceptable
        # if the database account is strictly read-only - TODO confirm.
        query = f"""
            select
                cast(start_date as date) as ds,
                1.0*sum(turnover) / (max(sum(turnover)) over ()) as y
            from datamart_all_1
            join week_data
            using (id_week)
            where {product_name}
            and mp in ('{mp_filter}')
            group by ds
            order by ds
        """
        print(query)
        data = get_data_from_db(query)
        print(data)

        period_type = "Week"
        period_col = "ds"
        if data.empty:
            raise gr.Error("No data found in database. Please adjust filters")
        # Assign by column name so the column is replaced, not set in place.
        data['ds'] = pd.to_datetime(data['ds'], format='%Y-%m-%d')
        data.set_index('ds', inplace=True)
    else:
        data, period_type, period_col = read_and_process_file(file)

    # `year` is the number of periods per year; the horizon is two years.
    if period_type == "Month":
        year = 12
        n_periods = 24
        freq = "MS"
    else:
        year = 52
        n_periods = year * 2
        freq = "W"

    # Both models expect a frame with a 'ds' date column and a 'y' value column.
    df = data.reset_index().rename(columns={period_col: 'ds', data.columns[0]: 'y'})

    if model_choice == "Prophet":
        forecast, yoy_change = forecast_prophet(df, n_periods, freq, year)
    elif model_choice == "Chronos":
        forecast, yoy_change = forecast_chronos(df, n_periods, freq, year)
    else:
        raise ValueError("Invalid model choice")

    # Create Plotly figure (common for both models)
    fig = create_plot(data, forecast)

    # Combine original data and forecast side by side on the date index.
    combined_df = pd.concat([data, forecast.set_index('ds')], axis=1)

    # Write into a fresh temporary directory so simultaneous requests do not
    # overwrite each other's download; the file name itself is unchanged.
    combined_file = os.path.join(tempfile.mkdtemp(), 'combined_data.csv')
    combined_df.to_csv(combined_file)

    return fig, f'Year-over-Year Change in Sum of Values: {yoy_change:.2%}', combined_file
def forecast_prophet(df, n_periods, freq, year):
    """Fit Prophet on ``df`` and forecast ``n_periods`` steps ahead.

    Args:
        df: Frame with a 'ds' date column and a 'y' value column.
        n_periods: Number of future periods to predict.
        freq: Pandas frequency string passed to ``make_future_dataframe``.
        year: Number of periods that make up one year.

    Returns:
        A ``(forecast, yoy_change)`` tuple: ``forecast`` is Prophet's
        prediction frame for the history plus the future periods, and
        ``yoy_change`` is the relative change between the sum of the first
        ``year`` forecast values and the sum of the last ``year`` observed
        values.
    """
    model = Prophet()
    model.fit(df)
    future = model.make_future_dataframe(periods=n_periods, freq=freq)
    forecast = model.predict(future)

    # The future rows are the last `n_periods` rows of the forecast. Positive
    # offsets are used because a slice such as iloc[-n:-n + year] is empty
    # whenever n == year (its stop becomes 0).
    horizon_start = len(forecast) - n_periods
    sum_last_year_original = df['y'].iloc[-year:].sum()
    sum_first_year_forecast = forecast['yhat'].iloc[horizon_start:horizon_start + year].sum()
    yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original
    return forecast, yoy_change
def forecast_chronos(df, n_periods, freq, year):
    """Forecast ``n_periods`` steps ahead with the Chronos T5-mini model.

    Args:
        df: Frame with a 'ds' date column and a 'y' value column.
        n_periods: Number of future periods to predict.
        freq: Pandas frequency string used to build the forecast dates.
        year: Number of periods that make up one year.

    Returns:
        A ``(forecast, yoy_change)`` tuple: ``forecast`` has columns 'ds',
        'yhat', 'yhat_lower' and 'yhat_upper' (the 0.5, 0.1 and 0.9
        quantiles of the drawn samples), and ``yoy_change`` is the relative
        change between the sum of the first ``year`` median forecast values
        and the sum of the last ``year`` observed values.

    Raises:
        ValueError: If 'y' holds non-numeric values or cannot be converted
            to float32.
    """
    # Validate the input first so bad data fails before the model is loaded.
    if not pd.api.types.is_numeric_dtype(df['y']):
        non_numeric = df[pd.to_numeric(df['y'], errors='coerce').isna()]
        if not non_numeric.empty:
            error_message = f"Non-numeric values found in 'y' column. First few problematic rows:\n{non_numeric.head().to_string()}"
            raise ValueError(error_message)

    try:
        y_values = df['y'].values.astype(np.float32)
    except ValueError as e:
        raise ValueError(f"Unable to convert 'y' column to float32: {str(e)}") from e

    pipeline = ChronosPipeline.from_pretrained(
        "amazon/chronos-t5-mini",
        device_map="cpu",
        torch_dtype=torch.bfloat16,
    )

    chronos_forecast = pipeline.predict(
        context=torch.tensor(y_values),
        prediction_length=n_periods,
        num_samples=20,
        limit_prediction_length=False
    )

    # Build the future dates as every frequency point strictly after the last
    # observation. Dropping the first generated date unconditionally would
    # skip a whole period whenever the last observation is not on the
    # frequency's anchor (e.g. Monday dates with freq="W").
    last_observed = pd.Timestamp(df['ds'].iloc[-1])
    candidates = pd.date_range(start=last_observed, periods=n_periods + 1, freq=freq)
    forecast_index = candidates[candidates > last_observed][:n_periods]

    # Quantiles are taken along axis 0 of the first series' output
    # (presumably the sample axis - TODO confirm against the Chronos docs).
    low, median, high = np.quantile(chronos_forecast[0].numpy(), [0.1, 0.5, 0.9], axis=0)

    forecast = pd.DataFrame({
        'ds': forecast_index,
        'yhat': median,
        'yhat_lower': low,
        'yhat_upper': high
    })

    sum_last_year_original = df['y'].iloc[-year:].sum()
    sum_first_year_forecast = median[:year].sum()
    yoy_change = (sum_first_year_forecast - sum_last_year_original) / sum_last_year_original
    return forecast, yoy_change
def create_plot(data, forecast):
    """Build the Plotly figure showing the observed series and its forecast.

    Args:
        data: Date-indexed frame; its first column is drawn as 'Observed'.
        forecast: Frame with 'ds', 'yhat', 'yhat_lower' and 'yhat_upper'.

    Returns:
        A ``plotly.graph_objs.Figure`` with four line traces; the area
        between the lower and upper bounds is filled.
    """
    forecast_dates = forecast['ds']
    band_line = dict(color='pink')

    # Order matters: 'tonexty' on the upper bound fills down to the trace
    # added immediately before it, i.e. the lower bound.
    traces = (
        go.Scatter(x=data.index, y=data.iloc[:, 0], mode='lines', name='Observed'),
        go.Scatter(x=forecast_dates, y=forecast['yhat'], mode='lines', name='Forecast', line=dict(color='red')),
        go.Scatter(x=forecast_dates, y=forecast['yhat_lower'], fill=None, mode='lines', line=band_line, name='Lower CI'),
        go.Scatter(x=forecast_dates, y=forecast['yhat_upper'], fill='tonexty', mode='lines', line=band_line, name='Upper CI'),
    )

    fig = go.Figure()
    for trace in traces:
        fig.add_trace(trace)

    layout_options = {
        'title': 'Observed Time Series and Forecast with Confidence Intervals',
        'xaxis_title': 'Date',
        'yaxis_title': 'Values',
        'legend': dict(orientation='h', yanchor='bottom', y=1.02, xanchor='right', x=1),
        'hovermode': 'x unified',
    }
    fig.update_layout(**layout_options)
    return fig
# Create Gradio interface using Blocks: one row per control, top to bottom.
with gr.Blocks(theme=gr.themes.Monochrome()) as interface:
    gr.Markdown("# Time Series Forecasting")
    gr.Markdown("Upload a CSV file with a time series to forecast the next 2 years and see the YoY % change. Download the combined original and forecast data.")

    # Inputs
    with gr.Row():
        file_input = gr.File(label="Upload Time Series CSV")
    with gr.Row():
        wb_checkbox = gr.Checkbox(label="Wildberries", value=True)
        ozon_checkbox = gr.Checkbox(label="Ozon", value=True)
    with gr.Row():
        # The text is used as a SQL predicate by forecast_time_series.
        product_name_input = gr.Textbox(label="Product Name Filter", value="name like '%пуховик%'")
    with gr.Row():
        model_choice = gr.Radio(["Prophet", "Chronos"], label="Choose Model", value="Prophet")
    with gr.Row():
        compute_button = gr.Button("Compute")

    # Outputs
    with gr.Row():
        plot_output = gr.Plot(label="Time Series + Forecast Chart")
    with gr.Row():
        yoy_output = gr.Text(label="YoY % Change")
    with gr.Row():
        csv_output = gr.File(label="Download Combined Data CSV")

    # The input order must match the parameter order of forecast_time_series.
    compute_button.click(
        forecast_time_series,
        inputs=[file_input, product_name_input, wb_checkbox, ozon_checkbox, model_choice],
        outputs=[plot_output, yoy_output, csv_output]
    )

# Launch the interface
interface.launch(debug=True)