Add application file
Browse files- DB/SQL.db +0 -0
- app.py +164 -0
- local_database.py +61 -0
- logs/server.log +0 -0
- prompt.py +103 -0
- requirements.txt +8 -0
- utility/__init__.py +0 -0
- utility/__pycache__/__init__.cpython-39.pyc +0 -0
- utility/__pycache__/constant.cpython-39.pyc +0 -0
- utility/__pycache__/db_tools.cpython-39.pyc +0 -0
- utility/__pycache__/loggers.cpython-39.pyc +0 -0
- utility/__pycache__/utils.cpython-39.pyc +0 -0
- utility/constant.py +8 -0
- utility/db_tools.py +157 -0
- utility/loggers.py +60 -0
- utility/utils.py +46 -0
DB/SQL.db
ADDED
Binary file (28.7 kB). View file
|
|
app.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import torch
|
4 |
+
from transformers import AutoModel, AutoTokenizer
|
5 |
+
import gradio as gr
|
6 |
+
import mdtex2html
|
7 |
+
from transformers import AutoTokenizer, AutoModel
|
8 |
+
from utility.utils import config_dict
|
9 |
+
from utility.loggers import logger
|
10 |
+
from sentence_transformers import util
|
11 |
+
from local_database import db_operate
|
12 |
+
from prompt import table_schema, embedder,corpus_embeddings, corpus,In_context_prompt, query_template
|
13 |
+
|
14 |
+
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b-int8", trust_remote_code=True)
|
15 |
+
model = AutoModel.from_pretrained("THUDM/chatglm-6b-int8",trust_remote_code=True).float()
|
16 |
+
model = model.eval()
|
17 |
+
|
18 |
+
|
19 |
+
"""Override Chatbot.postprocess"""
|
20 |
+
|
21 |
+
def postprocess(self, y):
|
22 |
+
if y is None:
|
23 |
+
return []
|
24 |
+
for i, (message, response) in enumerate(y):
|
25 |
+
y[i] = (
|
26 |
+
None if message is None else mdtex2html.convert((message)),
|
27 |
+
None if response is None else mdtex2html.convert(response),
|
28 |
+
)
|
29 |
+
return y
|
30 |
+
|
31 |
+
gr.Chatbot.postprocess = postprocess
|
32 |
+
|
33 |
+
def parse_text(text):
|
34 |
+
"""copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
|
35 |
+
lines = text.split("\n")
|
36 |
+
lines = [line for line in lines if line != ""]
|
37 |
+
count = 0
|
38 |
+
for i, line in enumerate(lines):
|
39 |
+
if "```" in line:
|
40 |
+
count += 1
|
41 |
+
items = line.split('`')
|
42 |
+
if count % 2 == 1:
|
43 |
+
lines[i] = f'<pre><code class="language-{items[-1]}">'
|
44 |
+
else:
|
45 |
+
lines[i] = f'<br></code></pre>'
|
46 |
+
else:
|
47 |
+
if i > 0:
|
48 |
+
if count % 2 == 1:
|
49 |
+
line = line.replace("`", "\`")
|
50 |
+
line = line.replace("<", "<")
|
51 |
+
line = line.replace(">", ">")
|
52 |
+
line = line.replace(" ", " ")
|
53 |
+
line = line.replace("*", "*")
|
54 |
+
line = line.replace("_", "_")
|
55 |
+
line = line.replace("-", "-")
|
56 |
+
line = line.replace(".", ".")
|
57 |
+
line = line.replace("!", "!")
|
58 |
+
line = line.replace("(", "(")
|
59 |
+
line = line.replace(")", ")")
|
60 |
+
line = line.replace("$", "$")
|
61 |
+
lines[i] = "<br>"+line
|
62 |
+
text = "".join(lines)
|
63 |
+
return text
|
64 |
+
|
65 |
+
|
66 |
+
def obtain_sql(response):
|
67 |
+
response = re.split("```|\n\n", response)
|
68 |
+
for text in response:
|
69 |
+
if "SELECT" in text:
|
70 |
+
response = text
|
71 |
+
break
|
72 |
+
else:
|
73 |
+
response = response[0]
|
74 |
+
response = response.replace("\n", " ").replace("``", "").replace("`", "").strip()
|
75 |
+
response = re.sub(' +',' ', response)
|
76 |
+
return response
|
77 |
+
|
78 |
+
|
79 |
+
def predict(input, chatbot, history):
|
80 |
+
max_length = 2048
|
81 |
+
top_p = 0.7
|
82 |
+
temperature = 0.2
|
83 |
+
top_k = 3
|
84 |
+
dboperate = db_operate(config_dict['db_path'])
|
85 |
+
logger.info(f"query:{input}")
|
86 |
+
chatbot_prompt = """
|
87 |
+
你是一个文本转SQL的生成器,你的主要目标是尽可能的协助用户将输入的文本转换为正确的SQL语句。
|
88 |
+
上下文开始
|
89 |
+
生成的表名和表字段均来自以下表:
|
90 |
+
"""
|
91 |
+
query_embedding = embedder.encode(input, convert_to_tensor=True) # 与6张表的表名和输入的问题进行相似度计算
|
92 |
+
cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
|
93 |
+
top_results = torch.topk(cos_scores, k=top_k) # 拿到topk=3的表名
|
94 |
+
# 组合Prompt
|
95 |
+
table_nums = 0
|
96 |
+
for score, idx in zip(top_results[0], top_results[1]):
|
97 |
+
# 阈值过滤
|
98 |
+
if score > 0.45:
|
99 |
+
table_nums += 1
|
100 |
+
chatbot_prompt += table_schema[corpus[idx]]
|
101 |
+
chatbot_prompt += "上下文结束\n"
|
102 |
+
# In-Context Learning
|
103 |
+
if table_nums >= 2 and not history: # 如果表名大于等于2个,且没有历史记录,就加上In-Context Learning
|
104 |
+
chatbot_prompt += In_context_prompt
|
105 |
+
# 加上查询模板
|
106 |
+
chatbot_prompt += query_template
|
107 |
+
query = chatbot_prompt.replace("<user_input>", input)
|
108 |
+
chatbot.append((parse_text(input), ""))
|
109 |
+
# 流式输出
|
110 |
+
# for response, history in model.stream_chat(tokenizer, query, history, max_length=max_length, top_p=top_p,
|
111 |
+
# temperature=temperature):
|
112 |
+
# chatbot[-1] = (parse_text(input), parse_text(response))
|
113 |
+
response, history = model.chat(tokenizer, query, history=history, max_length=max_length, top_p=top_p,temperature=temperature)
|
114 |
+
chatbot[-1] = (parse_text(input), parse_text(response))
|
115 |
+
# chatbot[-1] = (chatbot[-1][0], chatbot[-1][1])
|
116 |
+
# 获取结果中的SQL语句
|
117 |
+
response = obtain_sql(response)
|
118 |
+
# 查询结果
|
119 |
+
if "SELECT" in response:
|
120 |
+
try:
|
121 |
+
sql_stauts = "sql语句执行成功,结果如下:"
|
122 |
+
sql_result = dboperate.query_data(response)
|
123 |
+
sql_result = str(sql_result)
|
124 |
+
except Exception as e:
|
125 |
+
sql_stauts = "sql语句执行失败"
|
126 |
+
sql_result = str(e)
|
127 |
+
chatbot[-1] = (chatbot[-1][0],
|
128 |
+
chatbot[-1][1] + "\n\n"+ "===================="+"\n\n" + sql_stauts + "\n\n" + sql_result)
|
129 |
+
return chatbot, history
|
130 |
+
|
131 |
+
|
132 |
+
def reset_user_input():
|
133 |
+
return gr.update(value='')
|
134 |
+
|
135 |
+
|
136 |
+
def reset_state():
|
137 |
+
return [], []
|
138 |
+
|
139 |
+
with gr.Blocks() as demo:
|
140 |
+
gr.HTML("""<h1 align="center">🤖ChatSQL</h1>""")
|
141 |
+
|
142 |
+
chatbot = gr.Chatbot()
|
143 |
+
with gr.Row():
|
144 |
+
with gr.Column(scale=4):
|
145 |
+
with gr.Column(scale=12):
|
146 |
+
user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
|
147 |
+
container=False)
|
148 |
+
with gr.Column(min_width=32, scale=1):
|
149 |
+
submitBtn = gr.Button("Submit", variant="primary")
|
150 |
+
with gr.Column(scale=1):
|
151 |
+
emptyBtn = gr.Button("Clear History")
|
152 |
+
# max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
|
153 |
+
# top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
|
154 |
+
# temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
|
155 |
+
|
156 |
+
history = gr.State([])
|
157 |
+
|
158 |
+
submitBtn.click(predict, [user_input, chatbot, history], [chatbot, history],
|
159 |
+
show_progress=True)
|
160 |
+
submitBtn.click(reset_user_input, [], [user_input])
|
161 |
+
|
162 |
+
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
|
163 |
+
|
164 |
+
demo.queue().launch(share=False, inbrowser=True)
|
local_database.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Author: Liushu
|
2 |
+
# Date: 2023-04-24
|
3 |
+
# 这里的只是用来校验生成SQL的正确性
|
4 |
+
import sqlite3
|
5 |
+
import numpy as np
|
6 |
+
from utility.utils import config_dict as DB_CONFIG
|
7 |
+
|
8 |
+
|
9 |
+
class db_operate(object):
|
10 |
+
|
11 |
+
def __init__(self, db_path):
|
12 |
+
self.conn = sqlite3.connect(db_path)
|
13 |
+
self.cursor = self.conn.cursor()
|
14 |
+
|
15 |
+
def create_table(self, SQL):
|
16 |
+
self.cursor.execute(SQL)
|
17 |
+
self.conn.commit()
|
18 |
+
|
19 |
+
def insert_data(self, SQL, data):
|
20 |
+
self.cursor.executemany(SQL, data)
|
21 |
+
self.conn.commit()
|
22 |
+
|
23 |
+
def update_data(self, table_name, data, condition):
|
24 |
+
self.cursor.execute(f"UPDATE {table_name} SET {data} WHERE {condition}")
|
25 |
+
self.conn.commit()
|
26 |
+
|
27 |
+
def delete_data(self, table_name, condition):
|
28 |
+
self.cursor.execute(f"DELETE FROM {table_name} WHERE {condition}")
|
29 |
+
self.conn.commit()
|
30 |
+
|
31 |
+
def query_data(self, SQL_statement):
|
32 |
+
self.cursor.execute(SQL_statement)
|
33 |
+
return self.cursor.fetchall()
|
34 |
+
|
35 |
+
|
36 |
+
if __name__ == "__main__":
|
37 |
+
db_path = DB_CONFIG["db_path"]
|
38 |
+
TABLE = DB_CONFIG["TABLE"]
|
39 |
+
dboperate = db_operate(db_path)
|
40 |
+
print("建表中...")
|
41 |
+
# 建表
|
42 |
+
for table_name in TABLE.keys():
|
43 |
+
table_field_sql = f"create table if not exists {table_name} ("
|
44 |
+
for idx, filed in enumerate(TABLE[table_name]["field"].items()):
|
45 |
+
if idx == len(TABLE[table_name]["field"].items()) - 1:
|
46 |
+
table_field_sql += filed[0] + f" {filed[1][1]})"
|
47 |
+
else:
|
48 |
+
table_field_sql += filed[0] + f" {filed[1][1]},"
|
49 |
+
# 建表语句
|
50 |
+
dboperate.create_table(table_field_sql)
|
51 |
+
|
52 |
+
print("插入数据中...")
|
53 |
+
# 插入数据
|
54 |
+
TABLE_Values = DB_CONFIG["TABLE_Values"]
|
55 |
+
for table_name in TABLE.keys():
|
56 |
+
table_field_sql = f"INSERT INTO {table_name} VALUES ("
|
57 |
+
field_len = len(list(TABLE_Values[table_name]["field"].keys()))
|
58 |
+
insert_slot = ",".join(['?' for i in range(field_len)])
|
59 |
+
table_field_sql += insert_slot + ")"
|
60 |
+
table_val = np.array(list(TABLE_Values[table_name]["field"].values())).T.tolist() # 转置+list
|
61 |
+
dboperate.insert_data(table_field_sql, table_val)
|
logs/server.log
ADDED
File without changes
|
prompt.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Text2SQL机器人·Prompt
|
3 |
+
"""
|
4 |
+
import numpy as np
|
5 |
+
from sentence_transformers import SentenceTransformer, util
|
6 |
+
from utility.utils import config_dict as DB_CONFIG
|
7 |
+
from local_database import db_operate
|
8 |
+
|
9 |
+
embedder = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
|
10 |
+
|
11 |
+
# 每张表的表含义和对应的表结构
|
12 |
+
# table_schema = {"货物销售表,主要存储货物名称、净收益率、损失率、环比增长率、销售量、货物品类、货物仓库、销售负责人名字、销售部门、销售负责人联系方式":
|
13 |
+
# """表1: cargo
|
14 |
+
# 字段1:cargo_id(货物编号),字段2:cargo_name(货物名称),字段3:year(年份),字段4:net_yield(净收益率%),字段5:loss_rate(损失率%),字段6:month_on_month_growth_rate(环比增长率%),字段7:sales_volume(销售量),字段8:cargo_price(货物单价),字段9:cargo_category(货物品类),字段10:source_cargo(货物来源),字段11:storage_warehouse(货物仓库),字段12:sales_person_name(销售责任人名字),字段13:sales_person_id(销售负责人id),字段14:sales_department(销售部门),字段15:sales_person_numbers(销售负责人联系方式)
|
15 |
+
# """,
|
16 |
+
# "人员信息表,主要存储销售人员人名、入职时间、当前业绩、职位等级":
|
17 |
+
# """表2: sales
|
18 |
+
# 字段1:sales_person_id(销售人员id),字段2:sales_person_name(人员名称),字段3:sales_person_level(人员等级),字段4:sales_person_work_date(入职时间),字段5:sales_person_leader_id(人员主管id),字段6:sales_person_leader_name(人员主管名字),字段7:sales_person_number(人员电话),字段8:sales_person_achievement_year(人员业绩年份),字段9:sales_person_achievement(人员业绩),字段10:sales_person_department(人员部门名称),字段11:sales_person_department_id(人员部门id)
|
19 |
+
# """,
|
20 |
+
# "货物信息表,主要存储货物名称、货物来源地、购买价格、货物大类、货物品类、供应商名称和供应商负责人":
|
21 |
+
# """表3: cargo_info
|
22 |
+
# 字段1:cargo_info_id(货物信息id),字段2:cargo_id(货物id),字段3:cargo_name(货物名称),字段4:origin_cargo(货物产地),字段5:cargo_purchase_price(货物购买价),字段6:cargo_type(货物大类),字段7:cargo_category(货物品类),字段8:cargo_supply_company(供货公司),字段9:cargo_num(货物数量),字段10:cargo_supply_aftermarket_person(货物售后负责人),字段11:cargo_supply_aftermarket_person_number(货物售后负责人联系电话),字段12:cargo_supply_market_person(货物公司销售负责人),字段13:cargo_supply_market_person_number(货物公司负责人联系电话)
|
23 |
+
# """,
|
24 |
+
# "部门表,主要存储部门名称、部门职责、部门主管":
|
25 |
+
# """表4: depart_list
|
26 |
+
# 字段1:department_id(部门id),字段2:department_name(部门名称),字段3:department_duty(部门职责),字段4:department_lead_name(部门负责人名字),字段5:department_lead_name_id(部门负责人id),字段6:department_person_nums(部门人数)
|
27 |
+
# """,
|
28 |
+
# "购买信息表,主要存储购买公司名称、购买公司街道、购买负责人、购买货物名称和数量以及类型":
|
29 |
+
# """表5: purchase_info
|
30 |
+
# 字段1:purchase_company_id(货物购买方id),字段2:purchase_company_name(货物购买方名称),字段3:purchase_company_address(货物购买方地址),字段4:purchase_company_person_name(货物购买方负责人人名),字段5:purchase_company_person_numbers(货物购买方负责人联系方式),字段6:purchase_company_person_level(货物购买方负责人职位),字段7:purchase_cargo_name_id(购买货物id),字段8:purchase_cargo_name(购买货物名称),字段9:purchase_cargo_nums(购买货物数量),字段10:purchase_cargo_type(购买货物大类),字段11:purchase_cargo_category(购买货物品类)
|
31 |
+
# """,
|
32 |
+
# "供应商信息表,主要存储供应商名称、供应商地址、供应商货物名称、入供应商目录日期":
|
33 |
+
# """表6: supply_company
|
34 |
+
# 字段1:supply_company_id(供货公司id),字段2:supply_company_name(供货公司名称),字段3:supply_company_address(供货公司地址),字段4:supply_company_product_id(供货公司货物id),字段5:supply_company_product_name(供货公司货物名称),字段6:supply_company_date(供货公司入名录时间)
|
35 |
+
# """}
|
36 |
+
|
37 |
+
|
38 |
+
# corpus = list(table_schema.keys())
|
39 |
+
|
40 |
+
# corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
chatbot_prompt = """
|
45 |
+
你是一个文本转SQL的生成器,你的主要目标是尽可能的协助用户,将输入的文本转换为正确的SQL语句。
|
46 |
+
上下文开始
|
47 |
+
表名和表字段来自以下表:
|
48 |
+
表1: cargo
|
49 |
+
字段1:cargo_id(货物编号),字段2:cargo_name(货物名称),字段3:year(年份),字段4:net_yield(净收益率%),字段5:loss_rate(损失率%),字段6:month_on_month_growth_rate(环比增长率%),字段7:sales_volume(销售量),字段8:cargo_price(货物单价),字段9:cargo_category(货物品类),字段10:source_cargo(货物来源),字段11:storage_warehouse(货物仓库),字段12:sales_person_name(销售责任人名字),字段13:sales_person_id(销售负责人id),字段14:sales_department(销售部门),字段15:sales_person_numbers(销售负责人联系方式)
|
50 |
+
表2: sales
|
51 |
+
字段1:sales_person_id(销售负责人id),字段2:sales_person_name(人员名称),字段3:sales_person_level(人员等级),字段4:sales_person_work_date(入职时间),字段5:sales_person_leader_id(人员主管id),字段6:sales_person_leader_name(人员主管名字),字段7:sales_person_number(人员电话),字段8:sales_person_achievement_year(人员业绩年份),字段9:sales_person_achievement(人员业绩),字段10:sales_person_department(人员部门名称),字段11:sales_person_department_id(人员部门id)
|
52 |
+
表3: cargo_info
|
53 |
+
字段1:cargo_info_id(货物信息id),字段2:cargo_id(货物id),字段3:cargo_name(货物名称),字段4:origin_cargo(货物产地),字段5:cargo_purchase_price(货物购买价),字段6:cargo_type(货物大类),字段7:cargo_category(货物品类),字段8:cargo_supply_company(供货公司),字段9:cargo_num(货物数量),字段10:cargo_supply_aftermarket_person(货物售后负责人),字段11:cargo_supply_aftermarket_person_number(货物售后负责人联系电话),字段12:cargo_supply_market_person(货物公司销售负责人),字段13:cargo_supply_market_person_number(货物公司负责人联系电话)
|
54 |
+
表4: depart_list
|
55 |
+
字段1:department_id(部门id),字段2:department_name(部门名称),字段3:department_duty(部门职责),字段4:department_lead_name(部门负责人名字),字段5:department_lead_name_id(部门负责人id),字段6:department_person_nums(部门人数)
|
56 |
+
表5: purchase_info
|
57 |
+
字段1:purchase_company_id(货物购买方id),字段2:purchase_company_name(货物购买方名称),字段3:purchase_company_address(货物购买方地址),字段4:purchase_company_person_name(货物购买方负责人人名),字段5:purchase_company_person_numbers(货物购买方负责人联系方式),字段6:purchase_company_person_level(货物购买方负责人职位),字段7:purchase_cargo_name_id(购买货物id),字段8:purchase_cargo_name(购买货物名称),字段9:purchase_cargo_nums(购买货物数量),字段10:purchase_cargo_type(购买货物大类),字段11:purchase_cargo_category(购买货物品类)
|
58 |
+
表6: supply_company
|
59 |
+
字段1:supply_company_id(供货公司id),字段2:supply_company_name(供货公司名称),字段3:supply_company_address(供货公司地址),字段4:supply_company_product_id(供货公司货物id),字段5:supply_company_product_name(供货公司货物名称),字段6:supply_company_date(供货公司入名录时间)
|
60 |
+
上下文结束
|
61 |
+
问: 请帮我查询所有的货物名称
|
62 |
+
答: SELECT cargo_name FROM cargo
|
63 |
+
问: 请帮我查询在2019年的净收益率大于10并且销售量大于100并且业绩大于1000的销售负责人名字
|
64 |
+
答: SELECT sales.sales_person_name FROM sales INNER JOIN cargo on sales.sales_person_id = cargo.sales_person_id WHERE cargo.year = 2019 AND cargo.net_yield > 10 AND cargo.sales_volume > 100 AND sales.sales_person_achievement > 1000
|
65 |
+
问: 文本转SQL: <user input>
|
66 |
+
答:
|
67 |
+
"""
|
68 |
+
|
69 |
+
# 一些学习的例子
|
70 |
+
In_context_prompt = """问: 请帮我查询所有的货物名称
|
71 |
+
答: SELECT cargo_name FROM cargo;
|
72 |
+
问: 请帮我查询在2019年的净收益率大于10并且销售量大于100并且业绩大于1000的销售负责人名字
|
73 |
+
答: SELECT sales.sales_person_name FROM sales INNER JOIN cargo on sales.sales_person_id = cargo.sales_person_id WHERE cargo.year = 2019 AND cargo.net_yield > 10 AND cargo.sales_volume > 100 AND sales.sales_person_achievement > 1000;
|
74 |
+
问: 请帮我查询购买"板材"货物的公司名称
|
75 |
+
答: SELECT purchase_company_name FROM purchase_info WHERE purchase_cargo_name = "板材";
|
76 |
+
"""
|
77 |
+
|
78 |
+
query_template = """问: <user_input>
|
79 |
+
答:
|
80 |
+
"""
|
81 |
+
|
82 |
+
# yaml解析
|
83 |
+
TABLE = DB_CONFIG["TABLE"]
|
84 |
+
table_schema = {}
|
85 |
+
|
86 |
+
for table_name in TABLE.keys():
|
87 |
+
# table描述拼接
|
88 |
+
table_info = """"""
|
89 |
+
table_info += "表名:" + table_name + "\n"
|
90 |
+
table_info += "字段:"
|
91 |
+
for idx, filed in enumerate(TABLE[table_name]["field"].items()):
|
92 |
+
if idx == len(TABLE[table_name]["field"].items()) - 1:
|
93 |
+
table_info += filed[0] + "(" + filed[1][0] + ")"
|
94 |
+
else:
|
95 |
+
table_info += filed[0] + "(" + filed[1][0] + "),"
|
96 |
+
|
97 |
+
table_schema[TABLE[table_name]["info"]] = table_info
|
98 |
+
|
99 |
+
# 获取表的描述信息
|
100 |
+
corpus = list(table_schema.keys())
|
101 |
+
|
102 |
+
# 向量化
|
103 |
+
corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
protobuf
|
2 |
+
transformers==4.27.1
|
3 |
+
cpm_kernels
|
4 |
+
torch>=1.10
|
5 |
+
gradio
|
6 |
+
mdtex2html
|
7 |
+
sentencepiece
|
8 |
+
sentence_transformers==2.2.2
|
utility/__init__.py
ADDED
File without changes
|
utility/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (139 Bytes). View file
|
|
utility/__pycache__/constant.cpython-39.pyc
ADDED
Binary file (364 Bytes). View file
|
|
utility/__pycache__/db_tools.cpython-39.pyc
ADDED
Binary file (4.28 kB). View file
|
|
utility/__pycache__/loggers.cpython-39.pyc
ADDED
Binary file (1.18 kB). View file
|
|
utility/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (1.4 kB). View file
|
|
utility/constant.py
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
@File: 存放固定的路径
|
3 |
+
"""
|
4 |
+
import os
|
5 |
+
import pandas as pd
|
6 |
+
# 根目录
|
7 |
+
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
8 |
+
configurable_file_path = os.path.join(BASE_DIR, 'configurable_file')
|
utility/db_tools.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
@Time: 2022/11/03
|
3 |
+
@Author: LiuShu
|
4 |
+
@File: 数据库操作类库
|
5 |
+
"""
|
6 |
+
import pymysql
|
7 |
+
from utility.loggers import logger
|
8 |
+
from utility.utils import config
|
9 |
+
|
10 |
+
|
11 |
+
class Cur_db(object):
|
12 |
+
def __init__(self):
|
13 |
+
self.config = config
|
14 |
+
self.db_name = self.config['database']['DB']
|
15 |
+
|
16 |
+
def pymysql_cur(self, reback=5):
|
17 |
+
""" 连接数据库 """
|
18 |
+
try:
|
19 |
+
self.conn = pymysql.connect(host=self.config['database']['HOST'], user=self.config['database']['USER'],
|
20 |
+
password=self.config['database']['PWD'], db=self.db_name,
|
21 |
+
port=int(self.config['database']['PORT']),
|
22 |
+
charset='utf8')
|
23 |
+
except Exception as e:
|
24 |
+
if reback == 0:
|
25 |
+
logger.exception('Exception occurred.')
|
26 |
+
return
|
27 |
+
else:
|
28 |
+
logger.exception('Exception occurred.')
|
29 |
+
reback -= 1
|
30 |
+
return self.pymysql_cur(reback)
|
31 |
+
|
32 |
+
def get_db_name(self):
|
33 |
+
"""
|
34 |
+
|
35 |
+
:return:
|
36 |
+
"""
|
37 |
+
return self.db_name
|
38 |
+
|
39 |
+
def select(self, sql, params, reback=2):
|
40 |
+
""" 查询单条语句,并返回查询所有的结果 """
|
41 |
+
try:
|
42 |
+
cur = self.conn.cursor()
|
43 |
+
cur.execute(sql, params)
|
44 |
+
# 单条
|
45 |
+
res = cur.fetchone()
|
46 |
+
cur.close()
|
47 |
+
if res:
|
48 |
+
return res
|
49 |
+
return
|
50 |
+
except Exception as e:
|
51 |
+
logger.exception('Exception occurred.')
|
52 |
+
if reback > 0:
|
53 |
+
reback -= 1
|
54 |
+
return self.select(sql, reback)
|
55 |
+
else:
|
56 |
+
logger.info(str('*' * 100))
|
57 |
+
return
|
58 |
+
|
59 |
+
def _select(self, sql, reback=2):
|
60 |
+
try:
|
61 |
+
cur = self.conn.cursor()
|
62 |
+
cur.execute(sql)
|
63 |
+
# 单条
|
64 |
+
res = cur.fetchone()
|
65 |
+
cur.close()
|
66 |
+
if res:
|
67 |
+
return res[0]
|
68 |
+
return
|
69 |
+
except Exception as e:
|
70 |
+
logger.exception('Exception occurred.')
|
71 |
+
if reback > 0:
|
72 |
+
reback -= 1
|
73 |
+
return self.select(sql, reback)
|
74 |
+
else:
|
75 |
+
logger.info(str('*' * 100))
|
76 |
+
return
|
77 |
+
|
78 |
+
def selectMany(self, sql, reback=2):
|
79 |
+
try:
|
80 |
+
cur = self.conn.cursor()
|
81 |
+
cur.execute(sql)
|
82 |
+
res = cur.fetchall()
|
83 |
+
cur.close()
|
84 |
+
if res:
|
85 |
+
return res
|
86 |
+
logger.info(str(sql))
|
87 |
+
return
|
88 |
+
except Exception as e:
|
89 |
+
logger.exception('Exception occurred.')
|
90 |
+
if reback > 0:
|
91 |
+
reback -= 1
|
92 |
+
return self.selectMany(sql, reback)
|
93 |
+
else:
|
94 |
+
logger.info(str('*' * 100))
|
95 |
+
return
|
96 |
+
|
97 |
+
def insert(self, sql, params):
|
98 |
+
cur = self.conn.cursor()
|
99 |
+
cur.execute(sql, params)
|
100 |
+
self.conn.commit()
|
101 |
+
return
|
102 |
+
|
103 |
+
def _insert(self, sql):
|
104 |
+
cur = self.conn.cursor()
|
105 |
+
cur.execute(sql)
|
106 |
+
self.conn.commit()
|
107 |
+
|
108 |
+
def insert_batch(self, sql, data_list):
|
109 |
+
"""
|
110 |
+
将dataframe批量入库
|
111 |
+
:param sql: 插入语句
|
112 |
+
:return:
|
113 |
+
"""
|
114 |
+
cur = self.conn.cursor()
|
115 |
+
# 开启事务
|
116 |
+
self.conn.begin()
|
117 |
+
try:
|
118 |
+
cur.executemany(sql, data_list)
|
119 |
+
self.conn.commit()
|
120 |
+
cur.close()
|
121 |
+
self.conn.close()
|
122 |
+
return True
|
123 |
+
except:
|
124 |
+
# 万一失败了,要进行回滚操作
|
125 |
+
self.conn.rollback()
|
126 |
+
cur.close()
|
127 |
+
self.conn.close()
|
128 |
+
return False
|
129 |
+
|
130 |
+
def update(self, sql, params):
|
131 |
+
cur = self.conn.cursor()
|
132 |
+
cur.execute(sql, params)
|
133 |
+
self.conn.commit()
|
134 |
+
return
|
135 |
+
|
136 |
+
def _update(self, sql):
|
137 |
+
try:
|
138 |
+
cur = self.conn.cursor()
|
139 |
+
cur.execute(sql)
|
140 |
+
self.conn.commit()
|
141 |
+
except Exception as e:
|
142 |
+
logger.exception('Exception occurred.')
|
143 |
+
|
144 |
+
def close(self):
|
145 |
+
self.conn.close()
|
146 |
+
pass
|
147 |
+
|
148 |
+
|
149 |
+
if __name__ == '__main__':
|
150 |
+
db_con = Cur_db()
|
151 |
+
logger.info(str(db_con.config['database']['HOST']))
|
152 |
+
print(str(db_con.config['database']['HOST']))
|
153 |
+
db_con.pymysql_cur()
|
154 |
+
sql = "SELECT * FROM cargo"
|
155 |
+
res = db_con.selectMany(sql)
|
156 |
+
print(str(res))
|
157 |
+
db_con.close()
|
utility/loggers.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
@Time: 2022/11/03
|
3 |
+
@Author: LiuShu
|
4 |
+
@File: loggers.py
|
5 |
+
"""
|
6 |
+
import os
|
7 |
+
from utility.constant import BASE_DIR
|
8 |
+
import logging
|
9 |
+
import logging.config
|
10 |
+
|
11 |
+
LOGGING = {
|
12 |
+
'version': 1,
|
13 |
+
'disable_existing_loggers': True,
|
14 |
+
'formatters': {
|
15 |
+
'simple': {
|
16 |
+
'format': '%(levelname)s %(message)s'
|
17 |
+
},
|
18 |
+
'standard': {
|
19 |
+
'format': '[%(asctime)s] %(filename)s-[line:%(lineno)d] %(levelname)s--%(message)s',
|
20 |
+
'datefmt': '%Y-%m-%d %H:%M:%S',
|
21 |
+
},
|
22 |
+
},
|
23 |
+
'handlers': {
|
24 |
+
'file': {
|
25 |
+
'level': 'DEBUG',
|
26 |
+
'class': 'logging.handlers.TimedRotatingFileHandler',
|
27 |
+
# TODO 文件路径修改位置
|
28 |
+
'filename': os.path.join(BASE_DIR, 'logs/server.log'),
|
29 |
+
'formatter': 'standard',
|
30 |
+
'when': 'D',
|
31 |
+
'interval': 1,
|
32 |
+
'backupCount': 7,
|
33 |
+
},
|
34 |
+
'null': {
|
35 |
+
'level': 'DEBUG',
|
36 |
+
'class': 'logging.StreamHandler',
|
37 |
+
},
|
38 |
+
},
|
39 |
+
'loggers': {
|
40 |
+
'django': {
|
41 |
+
'handlers': ['null'],
|
42 |
+
'level': 'ERROR',
|
43 |
+
'propagate': True,
|
44 |
+
},
|
45 |
+
'system': {
|
46 |
+
'handlers': ['file'],
|
47 |
+
'level': 'DEBUG',
|
48 |
+
'propagate': True,
|
49 |
+
},
|
50 |
+
}
|
51 |
+
}
|
52 |
+
|
53 |
+
|
54 |
+
def get_logger():
|
55 |
+
logging.config.dictConfig(LOGGING)
|
56 |
+
Logger = logging.getLogger("system")
|
57 |
+
return Logger
|
58 |
+
|
59 |
+
|
60 |
+
logger = get_logger()
|
utility/utils.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
@Time: 2022/12/06
|
3 |
+
@Author: LiuShu
|
4 |
+
@File: utility.py
|
5 |
+
"""
|
6 |
+
import os
|
7 |
+
import re
|
8 |
+
import sys
|
9 |
+
import shutil
|
10 |
+
import subprocess
|
11 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
12 |
+
# 读取config内容
|
13 |
+
import configparser
|
14 |
+
import yaml
|
15 |
+
import json
|
16 |
+
import requests
|
17 |
+
from utility.constant import BASE_DIR
|
18 |
+
from utility.loggers import logger
|
19 |
+
db_config_file_path = os.path.join(BASE_DIR, 'config.cfg')
|
20 |
+
logger.info(db_config_file_path)
|
21 |
+
yaml_file_path = os.path.join(BASE_DIR, 'config.yaml')
|
22 |
+
logger.info(yaml_file_path)
|
23 |
+
|
24 |
+
# config.cfg
|
25 |
+
def get_config():
|
26 |
+
config = configparser.ConfigParser()
|
27 |
+
config.read(db_config_file_path, encoding='utf8')
|
28 |
+
logger.info(str(config._sections))
|
29 |
+
return config
|
30 |
+
|
31 |
+
# 得到配置内容
|
32 |
+
config = get_config()
|
33 |
+
|
34 |
+
|
35 |
+
# yaml 文件
|
36 |
+
class ConfigParser:
|
37 |
+
@staticmethod
|
38 |
+
def load_config():
|
39 |
+
with open(yaml_file_path, 'r', encoding='utf-8') as file_stream:
|
40 |
+
config_dict = yaml.load(file_stream, Loader=yaml.Loader)
|
41 |
+
|
42 |
+
logger.info('config_dict:' + str(config_dict))
|
43 |
+
|
44 |
+
return config_dict
|
45 |
+
|
46 |
+
config_dict = ConfigParser.load_config()
|