alexander-lazarin
commited on
Commit
•
05f1e1c
1
Parent(s):
2955021
switch to clickhouse
Browse files- app.py +46 -12
- requirements.txt +2 -2
app.py
CHANGED
@@ -9,12 +9,16 @@ import os
|
|
9 |
import torch
|
10 |
from chronos import ChronosPipeline
|
11 |
import numpy as np
|
|
|
|
|
12 |
|
13 |
try:
|
14 |
from google.colab import userdata
|
15 |
PG_PASSWORD = userdata.get('FASHION_PG_PASS')
|
|
|
16 |
except:
|
17 |
PG_PASSWORD = os.environ['FASHION_PG_PASS']
|
|
|
18 |
|
19 |
logging.getLogger("prophet").setLevel(logging.WARNING)
|
20 |
logging.getLogger("cmdstanpy").setLevel(logging.WARNING)
|
@@ -65,16 +69,33 @@ def read_and_process_file(file):
|
|
65 |
return data, period_type, period_col
|
66 |
|
67 |
def get_data_from_db(query):
|
68 |
-
conn = psycopg2.connect(
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
)
|
76 |
-
|
77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
return data
|
79 |
|
80 |
def forecast_time_series(file, product_name, wb, ozon, model_choice):
|
@@ -86,11 +107,23 @@ def forecast_time_series(file, product_name, wb, ozon, model_choice):
|
|
86 |
if ozon:
|
87 |
marketplaces.append('ozon')
|
88 |
mp_filter = "', '".join(marketplaces)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
query = f"""
|
90 |
select
|
91 |
-
|
92 |
1.0*sum(turnover) / (max(sum(turnover)) over ()) as y
|
93 |
-
from
|
|
|
|
|
94 |
where {product_name}
|
95 |
and mp in ('{mp_filter}')
|
96 |
group by ds
|
@@ -98,6 +131,7 @@ def forecast_time_series(file, product_name, wb, ozon, model_choice):
|
|
98 |
"""
|
99 |
print(query)
|
100 |
data = get_data_from_db(query)
|
|
|
101 |
period_type = "Week"
|
102 |
period_col = "ds"
|
103 |
|
|
|
9 |
import torch
|
10 |
from chronos import ChronosPipeline
|
11 |
import numpy as np
|
12 |
+
import requests
|
13 |
+
import tempfile
|
14 |
|
15 |
try:
|
16 |
from google.colab import userdata
|
17 |
PG_PASSWORD = userdata.get('FASHION_PG_PASS')
|
18 |
+
CH_PASSWORD = userdata.get('FASHION_CH_PASS')
|
19 |
except:
|
20 |
PG_PASSWORD = os.environ['FASHION_PG_PASS']
|
21 |
+
CH_PASSWORD = os.environ['FASHION_CH_PASS']
|
22 |
|
23 |
logging.getLogger("prophet").setLevel(logging.WARNING)
|
24 |
logging.getLogger("cmdstanpy").setLevel(logging.WARNING)
|
|
|
69 |
return data, period_type, period_col
|
70 |
|
71 |
def get_data_from_db(query):
|
72 |
+
# conn = psycopg2.connect(
|
73 |
+
# dbname="kroyscappingdb",
|
74 |
+
# user="read_only",
|
75 |
+
# password=PG_PASSWORD,
|
76 |
+
# host="rc1d-vbh2dw5ha0gpsazk.mdb.yandexcloud.net",
|
77 |
+
# port="6432",
|
78 |
+
# sslmode="require"
|
79 |
+
# )
|
80 |
+
cert_data = requests.get('https://storage.yandexcloud.net/cloud-certs/RootCA.pem').text
|
81 |
+
|
82 |
+
with tempfile.NamedTemporaryFile(delete=False) as temp_cert_file:
|
83 |
+
temp_cert_file.write(cert_data.encode())
|
84 |
+
cert_file_path = temp_cert_file.name
|
85 |
+
|
86 |
+
client = Client(host='rc1d-a93v7vf0pjfr6e2o.mdb.yandexcloud.net',
|
87 |
+
port = 9440,
|
88 |
+
user='user1',
|
89 |
+
password=CH_PASSWORD,
|
90 |
+
database='db1',
|
91 |
+
secure=True,
|
92 |
+
ca_certs=cert_file_path)
|
93 |
+
|
94 |
+
# data = pd.read_sql_query(query, conn)
|
95 |
+
result, columns = client.execute(query, with_column_types=True)
|
96 |
+
column_names = [col[0] for col in columns]
|
97 |
+
data = pd.DataFrame(result, columns=column_names)
|
98 |
+
# conn.close()
|
99 |
return data
|
100 |
|
101 |
def forecast_time_series(file, product_name, wb, ozon, model_choice):
|
|
|
107 |
if ozon:
|
108 |
marketplaces.append('ozon')
|
109 |
mp_filter = "', '".join(marketplaces)
|
110 |
+
# query = f"""
|
111 |
+
# select
|
112 |
+
# to_char(dm.end_date, 'yyyy-mm-dd') as ds,
|
113 |
+
# 1.0*sum(turnover) / (max(sum(turnover)) over ()) as y
|
114 |
+
# from v_datamart dm
|
115 |
+
# where {product_name}
|
116 |
+
# and mp in ('{mp_filter}')
|
117 |
+
# group by ds
|
118 |
+
# order by ds
|
119 |
+
# """
|
120 |
query = f"""
|
121 |
select
|
122 |
+
cast(start_date as date) as ds,
|
123 |
1.0*sum(turnover) / (max(sum(turnover)) over ()) as y
|
124 |
+
from datamart_all_1
|
125 |
+
join week_data
|
126 |
+
using (id_week)
|
127 |
where {product_name}
|
128 |
and mp in ('{mp_filter}')
|
129 |
group by ds
|
|
|
131 |
"""
|
132 |
print(query)
|
133 |
data = get_data_from_db(query)
|
134 |
+
print(data)
|
135 |
period_type = "Week"
|
136 |
period_col = "ds"
|
137 |
|
requirements.txt
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
numpy == 1.26.4
|
2 |
gradio
|
3 |
-
pandas
|
4 |
plotly
|
5 |
prophet
|
6 |
-
|
7 |
git+https://github.com/amazon-science/chronos-forecasting.git
|
|
|
1 |
numpy == 1.26.4
|
2 |
gradio
|
3 |
+
pandas == 2.1.4
|
4 |
plotly
|
5 |
prophet
|
6 |
+
clickhouse-driver
|
7 |
git+https://github.com/amazon-science/chronos-forecasting.git
|