ls291 committed on
Commit
dc82c71
1 Parent(s): 7e2b584

Add application file

Browse files
DB/SQL.db ADDED
Binary file (28.7 kB). View file
 
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import torch
4
+ from transformers import AutoModel, AutoTokenizer
5
+ import gradio as gr
6
+ import mdtex2html
7
+ from transformers import AutoTokenizer, AutoModel
8
+ from utility.utils import config_dict
9
+ from utility.loggers import logger
10
+ from sentence_transformers import util
11
+ from local_database import db_operate
12
+ from prompt import table_schema, embedder,corpus_embeddings, corpus,In_context_prompt, query_template
13
+
14
# Hub identifier shared by the tokenizer and the model so the two cannot drift apart.
_CHATGLM_CHECKPOINT = "THUDM/chatglm-6b-int8"

# Both objects are created once, at import time, and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(
    _CHATGLM_CHECKPOINT,
    trust_remote_code=True,
)
# .float() casts the weights to float32 (presumably so the model can run on CPU -- confirm);
# .eval() switches the module to inference mode and returns the module itself.
model = AutoModel.from_pretrained(
    _CHATGLM_CHECKPOINT,
    trust_remote_code=True,
).float().eval()
17
+
18
+
19
+ """Override Chatbot.postprocess"""
20
+
21
def postprocess(self, y):
    """Convert every chat turn to HTML with ``mdtex2html``.

    Written as a drop-in replacement for ``gr.Chatbot.postprocess``.

    Args:
        y: List of ``(message, response)`` pairs, or ``None``.

    Returns:
        ``[]`` when ``y`` is ``None``. Otherwise the same list object ``y``,
        whose entries have been replaced in place by tuples of converted
        values; ``None`` entries are passed through unconverted.
    """
    if y is None:
        return []

    def render(markup):
        # None marks "no message on this side of the turn" and must survive as None.
        return None if markup is None else mdtex2html.convert(markup)

    for position, (user_text, bot_text) in enumerate(y):
        y[position] = (render(user_text), render(bot_text))
    return y
30
+
31
# Monkey-patch the Gradio class so every gr.Chatbot instance renders its
# messages through the mdtex2html-based postprocess defined in this module.
gr.Chatbot.postprocess = postprocess
32
+
33
# Characters neutralised inside a fenced code block so that the later
# Markdown/HTML rendering shows them literally. None of the replacement strings
# contains another key of this table, so one translate() pass yields the same
# result as replacing the characters one after another.
_CODE_BLOCK_ESCAPES = str.maketrans({
    "`": "\\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Turn Markdown-ish chat text into the HTML fragment shown in the chatbot.

    Adapted from https://github.com/GaiZhenbiao/ChuanhuChatGPT/

    Rules applied line by line (empty lines are dropped first):

    * A line containing a triple backtick toggles code mode. The opening fence
      becomes ``<pre><code class="language-X">`` where ``X`` is whatever
      follows the last backtick on that line; the closing fence becomes
      ``<br></code></pre>``.
    * The very first line is emitted untouched.
    * Every other line is prefixed with ``<br>``; while in code mode its
      special characters are escaped first.

    Args:
        text: Raw message text.

    Returns:
        The converted text as a single string with no newlines.
    """
    lines = [line for line in text.split("\n") if line != ""]
    rendered = []
    in_code = False
    for index, line in enumerate(lines):
        if "```" in line:
            in_code = not in_code
            if in_code:
                language = line.split("`")[-1]
                rendered.append(f'<pre><code class="language-{language}">')
            else:
                rendered.append("<br></code></pre>")
        elif index == 0:
            # The first line gets neither the <br> prefix nor any escaping.
            rendered.append(line)
        else:
            if in_code:
                line = line.translate(_CODE_BLOCK_ESCAPES)
            rendered.append("<br>" + line)
    return "".join(rendered)
64
+
65
+
66
def obtain_sql(response):
    """Extract a single-line SQL statement from a model reply.

    The reply is split on code fences and on blank lines. The first segment
    containing ``SELECT`` is used; when no segment does, the first segment is
    used. The chosen segment is flattened to one line, stripped of backticks
    and has runs of spaces collapsed.

    Args:
        response: Raw text returned by the chat model.

    Returns:
        The cleaned-up candidate statement (possibly not SQL at all when the
        reply contained no ``SELECT``).
    """
    segments = re.split("```|\n\n", response)
    candidate = next((seg for seg in segments if "SELECT" in seg), segments[0])
    # A block opened with a tagged fence (```sql, ```mysql, ...) leaves the
    # language tag alone on the first line of the segment. Drop it, otherwise
    # the result reads "sql SELECT ..." and cannot be executed.
    candidate = re.sub(r"\A\s*\w*sql\w*[ \t]*\n", "", candidate, flags=re.IGNORECASE)
    candidate = candidate.replace("\n", " ").replace("``", "").replace("`", "").strip()
    return re.sub(" +", " ", candidate)
77
+
78
+
79
def predict(input, chatbot, history):
    """Translate a question to SQL with the chat model and run it on SQLite.

    Steps: pick the table descriptions most similar to the question, build the
    prompt from their schemas, ask the model, extract the SQL from the reply
    and, when it contains ``SELECT``, execute it and append the outcome to the
    displayed answer.

    Args:
        input: The user's question (name kept for the Gradio callback).
        chatbot: List of ``(user_html, bot_html)`` tuples shown in the UI;
            a new turn is appended to it in place.
        history: Conversation history passed to and returned by ``model.chat``.

    Returns:
        ``(chatbot, history)`` with the new turn and the updated history.
    """
    max_length = 2048
    top_p = 0.7
    temperature = 0.2
    top_k = 3
    score_threshold = 0.45
    logger.info(f"query:{input}")
    chatbot_prompt = """
你是一个文本转SQL的生成器,你的主要目标是尽可能的协助用户将输入的文本转换为正确的SQL语句。
上下文开始
生成的表名和表字段均来自以下表:
"""
    # Score the question against every table description.
    query_embedding = embedder.encode(input, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
    # torch.topk raises when k exceeds the number of candidates, so clamp k
    # for configurations that define fewer than top_k tables.
    top_results = torch.topk(cos_scores, k=min(top_k, len(corpus)))

    # Assemble the prompt from the schemas that pass the similarity threshold.
    table_nums = 0
    for score, idx in zip(top_results[0], top_results[1]):
        if score > score_threshold:
            table_nums += 1
            chatbot_prompt += table_schema[corpus[idx]]
    chatbot_prompt += "上下文结束\n"
    # Few-shot examples are added only on the first turn of a conversation
    # and only when at least two tables were selected.
    if table_nums >= 2 and not history:
        chatbot_prompt += In_context_prompt
    chatbot_prompt += query_template
    query = chatbot_prompt.replace("<user_input>", input)

    chatbot.append((parse_text(input), ""))
    response, history = model.chat(tokenizer, query, history=history, max_length=max_length,
                                   top_p=top_p, temperature=temperature)
    chatbot[-1] = (parse_text(input), parse_text(response))

    sql = obtain_sql(response)
    if "SELECT" in sql:
        # NOTE(review): the statement is model-generated and executed as-is;
        # the substring check does not guarantee a read-only query. Consider a
        # read-only connection or stricter validation.
        # Open the connection only when there is something to run, and always
        # release it: this callback runs once per request.
        dboperate = db_operate(config_dict['db_path'])
        try:
            sql_result = str(dboperate.query_data(sql))
            sql_status = "sql语句执行成功,结果如下:"
        except Exception as e:
            # Best effort: show the error to the user instead of failing the request.
            logger.exception("SQL execution failed: %s", sql)
            sql_status = "sql语句执行失败"
            sql_result = str(e)
        finally:
            dboperate.conn.close()
        chatbot[-1] = (chatbot[-1][0],
                       chatbot[-1][1] + "\n\n" + "====================" + "\n\n" + sql_status + "\n\n" + sql_result)
    return chatbot, history
130
+
131
+
132
def reset_user_input():
    """Return a Gradio update that sets the bound component's value to ''."""
    cleared = gr.update(value='')
    return cleared
134
+
135
+
136
def reset_state():
    """Return two fresh empty lists: one for the chatbot, one for the history."""
    empty_chat = []
    empty_history = []
    return empty_chat, empty_history
138
+
139
# Build the web UI: a chat window, a multi-line input box, a submit button
# and a button that clears the conversation.
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">🤖ChatSQL</h1>""")

    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                # NOTE(review): .style() presumably needs an older Gradio release;
                # requirements.txt does not pin the version -- confirm.
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            # Generation parameters are currently hard-coded inside predict():
            # max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            # top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
            # temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    # Per-session conversation history handed to predict().
    history = gr.State([])

    # One click triggers two handlers: run the prediction, and clear the input box.
    submitBtn.click(predict, [user_input, chatbot, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

# Starts the server at import time (no __main__ guard); share=False requests no public link.
demo.queue().launch(share=False, inbrowser=True)
local_database.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: Liushu
2
+ # Date: 2023-04-24
3
+ # 这里的只是用来校验生成SQL的正确性
4
+ import sqlite3
5
+ import numpy as np
6
+ from utility.utils import config_dict as DB_CONFIG
7
+
8
+
9
class db_operate(object):
    """Minimal helper around one SQLite connection and one shared cursor.

    Instances can be used as context managers so the connection is released
    deterministically::

        with db_operate(path) as db:
            rows = db.query_data("SELECT ...")
    """

    def __init__(self, db_path):
        """Open (or create) the SQLite database at ``db_path``."""
        self.conn = sqlite3.connect(db_path)
        self.cursor = self.conn.cursor()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        """Close the underlying connection; uncommitted changes are discarded."""
        self.conn.close()

    def create_table(self, SQL):
        """Execute a DDL statement and commit."""
        self.cursor.execute(SQL)
        self.conn.commit()

    def insert_data(self, SQL, data):
        """Run a parameterised INSERT once per row of ``data`` and commit."""
        self.cursor.executemany(SQL, data)
        self.conn.commit()

    def update_data(self, table_name, data, condition):
        """Run ``UPDATE table_name SET data WHERE condition`` and commit.

        NOTE(review): all three arguments are raw SQL fragments interpolated
        into the statement; never pass untrusted input here.
        """
        self.cursor.execute(f"UPDATE {table_name} SET {data} WHERE {condition}")
        self.conn.commit()

    def delete_data(self, table_name, condition):
        """Run ``DELETE FROM table_name WHERE condition`` and commit.

        NOTE(review): arguments are raw SQL fragments; never pass untrusted
        input here.
        """
        self.cursor.execute(f"DELETE FROM {table_name} WHERE {condition}")
        self.conn.commit()

    def query_data(self, SQL_statement):
        """Execute ``SQL_statement`` and return all rows as a list of tuples."""
        self.cursor.execute(SQL_statement)
        return self.cursor.fetchall()
34
+
35
+
36
if __name__ == "__main__":
    # One-off bootstrap: create every configured table, then load its sample rows.
    # NOTE: re-running the script inserts the sample rows again.
    db_path = DB_CONFIG["db_path"]
    TABLE = DB_CONFIG["TABLE"]
    dboperate = db_operate(db_path)
    try:
        print("建表中...")
        # Create the tables. Each field entry maps name -> (label, SQL type);
        # index 1 is the SQL type used in the DDL.
        for table_name, table_spec in TABLE.items():
            column_defs = ",".join(
                f"{field_name} {field_spec[1]}"
                for field_name, field_spec in table_spec["field"].items()
            )
            dboperate.create_table(f"create table if not exists {table_name} ({column_defs})")

        print("插入数据中...")
        # Insert the data. Values are stored column-wise in the config.
        TABLE_Values = DB_CONFIG["TABLE_Values"]
        for table_name in TABLE:
            columns = list(TABLE_Values[table_name]["field"].values())
            if len({len(column) for column in columns}) > 1:
                raise ValueError(f"columns of table {table_name!r} have different lengths")
            placeholders = ",".join("?" * len(columns))
            # zip(*columns) transposes columns into rows while keeping each
            # value's own type; a NumPy array of mixed columns would coerce
            # everything to strings (and None to 'None').
            rows = list(zip(*columns))
            dboperate.insert_data(f"INSERT INTO {table_name} VALUES ({placeholders})", rows)
    finally:
        dboperate.conn.close()
logs/server.log ADDED
File without changes
prompt.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Text2SQL机器人·Prompt
3
+ """
4
+ import numpy as np
5
+ from sentence_transformers import SentenceTransformer, util
6
+ from utility.utils import config_dict as DB_CONFIG
7
+ from local_database import db_operate
8
+
9
# Sentence-embedding model, loaded once at import time. It encodes the table
# descriptions in this module and is imported by app.py to encode user questions.
embedder = SentenceTransformer('paraphrase-multilingual-MiniLM-L12-v2')
10
+
11
+ # 每张表的表含义和对应的表结构
12
+ # table_schema = {"货物销售表,主要存储货物名称、净收益率、损失率、环比增长率、销售量、货物品类、货物仓库、销售负责人名字、销售部门、销售负责人联系方式":
13
+ # """表1: cargo
14
+ # 字段1:cargo_id(货物编号),字段2:cargo_name(货物名称),字段3:year(年份),字段4:net_yield(净收益率%),字段5:loss_rate(损失率%),字段6:month_on_month_growth_rate(环比增长率%),字段7:sales_volume(销售量),字段8:cargo_price(货物单价),字段9:cargo_category(货物品类),字段10:source_cargo(货物来源),字段11:storage_warehouse(货物仓库),字段12:sales_person_name(销售责任人名字),字段13:sales_person_id(销售负责人id),字段14:sales_department(销售部门),字段15:sales_person_numbers(销售负责人联系方式)
15
+ # """,
16
+ # "人员信息表,主要存储销售人员人名、入职时间、当前业绩、职位等级":
17
+ # """表2: sales
18
+ # 字段1:sales_person_id(销售人员id),字段2:sales_person_name(人员名称),字段3:sales_person_level(人员等级),字段4:sales_person_work_date(入职时间),字段5:sales_person_leader_id(人员主管id),字段6:sales_person_leader_name(人员主管名字),字段7:sales_person_number(人员电话),字段8:sales_person_achievement_year(人员业绩年份),字段9:sales_person_achievement(人员业绩),字段10:sales_person_department(人员部门名称),字段11:sales_person_department_id(人员部门id)
19
+ # """,
20
+ # "货物信息表,主要存储货物名称、货物来源地、购买价格、货物大类、货物品类、供应商名称和供应商负责人":
21
+ # """表3: cargo_info
22
+ # 字段1:cargo_info_id(货物信息id),字段2:cargo_id(货物id),字段3:cargo_name(货物名称),字段4:origin_cargo(货物产地),字段5:cargo_purchase_price(货物购买价),字段6:cargo_type(货物大类),字段7:cargo_category(货物品类),字段8:cargo_supply_company(供货公司),字段9:cargo_num(货物数量),字段10:cargo_supply_aftermarket_person(货物售后负责人),字段11:cargo_supply_aftermarket_person_number(货物售后负责人联系电话),字段12:cargo_supply_market_person(货物公司销售负责人),字段13:cargo_supply_market_person_number(货物公司负责人联系电话)
23
+ # """,
24
+ # "部门表,主要存储部门名称、部门职责、部门主管":
25
+ # """表4: depart_list
26
+ # 字段1:department_id(部门id),字段2:department_name(部门名称),字段3:department_duty(部门职责),字段4:department_lead_name(部门负责人名字),字段5:department_lead_name_id(部门负责人id),字段6:department_person_nums(部门人数)
27
+ # """,
28
+ # "购买信息表,主要存储购买公司名称、购买公司街道、购买负责人、购买货物名称和数量以及类型":
29
+ # """表5: purchase_info
30
+ # 字段1:purchase_company_id(货物购买方id),字段2:purchase_company_name(货物购买方名称),字段3:purchase_company_address(货物购买方地址),字段4:purchase_company_person_name(货物购买方负责人人名),字段5:purchase_company_person_numbers(货物购买方负责人联系方式),字段6:purchase_company_person_level(货物购买方负责人职位),字段7:purchase_cargo_name_id(购买货物id),字段8:purchase_cargo_name(购买货物名称),字段9:purchase_cargo_nums(购买货物数量),字段10:purchase_cargo_type(购买货物大类),字段11:purchase_cargo_category(购买货物品类)
31
+ # """,
32
+ # "供应商信息表,主要存储供应商名称、供应商地址、供应商货物名称、入供应商目录日期":
33
+ # """表6: supply_company
34
+ # 字段1:supply_company_id(供货公司id),字段2:supply_company_name(供货公司名称),字段3:supply_company_address(供货公司地址),字段4:supply_company_product_id(供货公司货物id),字段5:supply_company_product_name(供货公司货物名称),字段6:supply_company_date(供货公司入名录时间)
35
+ # """}
36
+
37
+
38
+ # corpus = list(table_schema.keys())
39
+
40
+ # corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)
41
+
42
+
43
+
44
# Static prompt with all six table schemas and two worked examples inlined.
# NOTE(review): app.py builds its own prompt from table_schema and does not
# import this constant -- presumably legacy; confirm before removing.
# NOTE(review): the placeholder here is "<user input>" (with a space) while
# query_template and app.py use "<user_input>"; confirm which one is intended.
chatbot_prompt = """
你是一个文本转SQL的生成器,你的主要目标是尽可能的协助用户,将输入的文本转换为正确的SQL语句。
上下文开始
表名和表字段来自以下表:
表1: cargo
字段1:cargo_id(货物编号),字段2:cargo_name(货物名称),字段3:year(年份),字段4:net_yield(净收益率%),字段5:loss_rate(损失率%),字段6:month_on_month_growth_rate(环比增长率%),字段7:sales_volume(销售量),字段8:cargo_price(货物单价),字段9:cargo_category(货物品类),字段10:source_cargo(货物来源),字段11:storage_warehouse(货物仓库),字段12:sales_person_name(销售责任人名字),字段13:sales_person_id(销售负责人id),字段14:sales_department(销售部门),字段15:sales_person_numbers(销售负责人联系方式)
表2: sales
字段1:sales_person_id(销售负责人id),字段2:sales_person_name(人员名称),字段3:sales_person_level(人员等级),字段4:sales_person_work_date(入职时间),字段5:sales_person_leader_id(人员主管id),字段6:sales_person_leader_name(人员主管名字),字段7:sales_person_number(人员电话),字段8:sales_person_achievement_year(人员业绩年份),字段9:sales_person_achievement(人员业绩),字段10:sales_person_department(人员部门名称),字段11:sales_person_department_id(人员部门id)
表3: cargo_info
字段1:cargo_info_id(货物信息id),字段2:cargo_id(货物id),字段3:cargo_name(货物名称),字段4:origin_cargo(货物产地),字段5:cargo_purchase_price(货物购买价),字段6:cargo_type(货物大类),字段7:cargo_category(货物品类),字段8:cargo_supply_company(供货公司),字段9:cargo_num(货物数量),字段10:cargo_supply_aftermarket_person(货物售后负责人),字段11:cargo_supply_aftermarket_person_number(货物售后负责人联系电话),字段12:cargo_supply_market_person(货物公司销售负责人),字段13:cargo_supply_market_person_number(货物公司负责人联系电话)
表4: depart_list
字段1:department_id(部门id),字段2:department_name(部门名称),字段3:department_duty(部门职责),字段4:department_lead_name(部门负责人名字),字段5:department_lead_name_id(部门负责人id),字段6:department_person_nums(部门人数)
表5: purchase_info
字段1:purchase_company_id(货物购买方id),字段2:purchase_company_name(货物购买方名称),字段3:purchase_company_address(货物购买方地址),字段4:purchase_company_person_name(货物购买方负责人人名),字段5:purchase_company_person_numbers(货物购买方负责人联系方式),字段6:purchase_company_person_level(货物购买方负责人职位),字段7:purchase_cargo_name_id(购买货物id),字段8:purchase_cargo_name(购买货物名称),字段9:purchase_cargo_nums(购买货物数量),字段10:purchase_cargo_type(购买货物大类),字段11:purchase_cargo_category(购买货物品类)
表6: supply_company
字段1:supply_company_id(供货公司id),字段2:supply_company_name(供货公司名称),字段3:supply_company_address(供货公司地址),字段4:supply_company_product_id(供货公司货物id),字段5:supply_company_product_name(供货公司货物名称),字段6:supply_company_date(供货公司入名录时间)
上下文结束
问: 请帮我查询所有的货物名称
答: SELECT cargo_name FROM cargo
问: 请帮我查询在2019年的净收益率大于10并且销售量大于100并且业绩大于1000的销售负责人名字
答: SELECT sales.sales_person_name FROM sales INNER JOIN cargo on sales.sales_person_id = cargo.sales_person_id WHERE cargo.year = 2019 AND cargo.net_yield > 10 AND cargo.sales_volume > 100 AND sales.sales_person_achievement > 1000
问: 文本转SQL: <user input>
答:
"""
68
+
69
# Few-shot examples: three question/answer pairs showing the expected SQL
# output (a single-table query, a join, and a filter on a string value).
In_context_prompt = """问: 请帮我查询所有的货物名称
答: SELECT cargo_name FROM cargo;
问: 请帮我查询在2019年的净收益率大于10并且销售量大于100并且业绩大于1000的销售负责人名字
答: SELECT sales.sales_person_name FROM sales INNER JOIN cargo on sales.sales_person_id = cargo.sales_person_id WHERE cargo.year = 2019 AND cargo.net_yield > 10 AND cargo.sales_volume > 100 AND sales.sales_person_achievement > 1000;
问: 请帮我查询购买"板材"货物的公司名称
答: SELECT purchase_company_name FROM purchase_info WHERE purchase_cargo_name = "板材";
"""

# Final part of the prompt; app.py replaces "<user_input>" with the user's question.
query_template = """问: <user_input>
答:
"""
81
+
82
# Table definitions taken from the parsed YAML configuration.
TABLE = DB_CONFIG["TABLE"]

# Maps each table's "info" text to the schema snippet inserted into the prompt.
table_schema = {}
for table_name, table_spec in TABLE.items():
    # Every field is rendered as name(label); the label is element 0 of the field's value.
    rendered_fields = [
        column + "(" + column_meta[0] + ")"
        for column, column_meta in table_spec["field"].items()
    ]
    table_info = "表名:" + table_name + "\n" + "字段:" + ",".join(rendered_fields)
    table_schema[table_spec["info"]] = table_info

# The table "info" texts, in insertion order; positions line up with corpus_embeddings.
corpus = list(table_schema)

# Embed the descriptions once so each request only has to embed the question.
corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ protobuf
2
+ transformers==4.27.1
3
+ cpm_kernels
4
+ torch>=1.10
5
+ gradio
6
+ mdtex2html
7
+ sentencepiece
8
+ sentence_transformers==2.2.2
utility/__init__.py ADDED
File without changes
utility/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (139 Bytes). View file
 
utility/__pycache__/constant.cpython-39.pyc ADDED
Binary file (364 Bytes). View file
 
utility/__pycache__/db_tools.cpython-39.pyc ADDED
Binary file (4.28 kB). View file
 
utility/__pycache__/loggers.cpython-39.pyc ADDED
Binary file (1.18 kB). View file
 
utility/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.4 kB). View file
 
utility/constant.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @File: 存放固定的路径
3
+ """
4
+ import os
5
+ import pandas as pd
6
# Project root: the parent of the directory that contains this file.
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Path of the 'configurable_file' entry directly under the project root.
configurable_file_path = os.path.join(BASE_DIR, 'configurable_file')
utility/db_tools.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @Time: 2022/11/03
3
+ @Author: LiuShu
4
+ @File: 数据库操作类库
5
+ """
6
+ import pymysql
7
+ from utility.loggers import logger
8
+ from utility.utils import config
9
+
10
+
11
class Cur_db(object):
    """Helper around a single PyMySQL connection configured from config.cfg.

    Call :meth:`pymysql_cur` before any query method; it stores the
    connection on ``self.conn``.
    """

    def __init__(self):
        self.config = config
        self.db_name = self.config['database']['DB']

    def pymysql_cur(self, reback=5):
        """Connect to the database.

        Args:
            reback: Number of additional attempts after the first failure.

        Returns:
            None. On success ``self.conn`` is set; if every attempt fails the
            errors are logged and ``self.conn`` is left untouched.
        """
        db_conf = self.config['database']
        for _ in range(max(reback, 0) + 1):
            try:
                self.conn = pymysql.connect(host=db_conf['HOST'], user=db_conf['USER'],
                                            password=db_conf['PWD'], db=self.db_name,
                                            port=int(db_conf['PORT']),
                                            charset='utf8')
                return
            except Exception:
                logger.exception('Exception occurred.')
        return

    def get_db_name(self):
        """Return the configured database name."""
        return self.db_name

    def _fetch(self, execute_args, reback, fetch_all):
        """Execute a query with retries.

        Args:
            execute_args: Positional arguments for ``cursor.execute``.
            reback: Number of retries after the first failed attempt.
            fetch_all: ``True`` for ``fetchall()``, ``False`` for ``fetchone()``.

        Returns:
            ``(True, rows)`` on success, ``(False, None)`` once all attempts failed.
        """
        for _ in range(max(reback, 0) + 1):
            try:
                cur = self.conn.cursor()
                try:
                    cur.execute(*execute_args)
                    return True, (cur.fetchall() if fetch_all else cur.fetchone())
                finally:
                    # Release the cursor on the failure path as well.
                    cur.close()
            except Exception:
                logger.exception('Exception occurred.')
        logger.info(str('*' * 100))
        return False, None

    def select(self, sql, params, reback=2):
        """Run a parameterised query and return its first row, or ``None``.

        ``None`` is returned both when there is no row and when every attempt
        failed. Retries keep ``params`` and count down ``reback``.
        """
        _, row = self._fetch((sql, params), reback, fetch_all=False)
        return row if row else None

    def _select(self, sql, reback=2):
        """Run ``sql`` and return the first column of its first row, or ``None``."""
        _, row = self._fetch((sql,), reback, fetch_all=False)
        return row[0] if row else None

    def selectMany(self, sql, reback=2):
        """Run ``sql`` and return all rows, or ``None`` when empty or failed."""
        succeeded, rows = self._fetch((sql,), reback, fetch_all=True)
        if rows:
            return rows
        if succeeded:
            # The query ran but matched nothing: record which statement came back empty.
            logger.info(str(sql))
        return None

    def insert(self, sql, params):
        """Execute a parameterised INSERT and commit; errors propagate."""
        cur = self.conn.cursor()
        try:
            cur.execute(sql, params)
            self.conn.commit()
        finally:
            cur.close()

    def _insert(self, sql):
        """Execute a literal INSERT statement and commit; errors propagate."""
        cur = self.conn.cursor()
        try:
            cur.execute(sql)
            self.conn.commit()
        finally:
            cur.close()

    def insert_batch(self, sql, data_list):
        """Insert many rows in one transaction.

        NOTE: the connection is closed afterwards, on success and on failure.

        :param sql: parameterised INSERT statement
        :param data_list: sequence of parameter rows
        :return: True when committed, False when rolled back
        """
        cur = self.conn.cursor()
        # Start the transaction explicitly.
        self.conn.begin()
        try:
            cur.executemany(sql, data_list)
            self.conn.commit()
            return True
        except Exception:
            logger.exception('Exception occurred.')
            # Undo the partial batch.
            self.conn.rollback()
            return False
        finally:
            cur.close()
            self.conn.close()

    def update(self, sql, params):
        """Execute a parameterised UPDATE and commit; errors propagate."""
        cur = self.conn.cursor()
        try:
            cur.execute(sql, params)
            self.conn.commit()
        finally:
            cur.close()

    def _update(self, sql):
        """Execute a literal UPDATE and commit; failures are logged, not raised."""
        try:
            cur = self.conn.cursor()
            try:
                cur.execute(sql)
                self.conn.commit()
            finally:
                cur.close()
        except Exception:
            logger.exception('Exception occurred.')

    def close(self):
        """Close the connection."""
        self.conn.close()
147
+
148
+
149
if __name__ == '__main__':
    # Manual smoke test: show the configured host, connect, dump the cargo table.
    db_con = Cur_db()
    db_host = str(db_con.config['database']['HOST'])
    logger.info(db_host)
    print(db_host)
    db_con.pymysql_cur()
    cargo_rows = db_con.selectMany("SELECT * FROM cargo")
    print(str(cargo_rows))
    db_con.close()
utility/loggers.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @Time: 2022/11/03
3
+ @Author: LiuShu
4
+ @File: loggers.py
5
+ """
6
+ import os
7
+ from utility.constant import BASE_DIR
8
+ import logging
9
+ import logging.config
10
+
11
# dictConfig schema (version 1) for the project's logging setup.
LOGGING = {
    'version': 1,
    # Loggers that exist before dictConfig() runs and are not listed below get disabled.
    'disable_existing_loggers': True,
    'formatters': {
        'simple': {
            'format': '%(levelname)s %(message)s'
        },
        'standard': {
            'format': '[%(asctime)s] %(filename)s-[line:%(lineno)d] %(levelname)s--%(message)s',
            'datefmt': '%Y-%m-%d %H:%M:%S',
        },
    },
    'handlers': {
        # Rotates once per day ('when': 'D', 'interval': 1) and keeps 7 old files.
        'file': {
            'level': 'DEBUG',
            'class': 'logging.handlers.TimedRotatingFileHandler',
            # TODO: change the log file location here
            'filename': os.path.join(BASE_DIR, 'logs/server.log'),
            'formatter': 'standard',
            'when': 'D',
            'interval': 1,
            'backupCount': 7,
        },
        # Despite the name this is a StreamHandler, so records are written to a stream.
        'null': {
            'level': 'DEBUG',
            'class': 'logging.StreamHandler',
        },
    },
    'loggers': {
        # NOTE(review): presumably left over from a Django project -- confirm it is still needed.
        'django': {
            'handlers': ['null'],
            'level': 'ERROR',
            'propagate': True,
        },
        # The logger returned by get_logger() and used across the project.
        'system': {
            'handlers': ['file'],
            'level': 'DEBUG',
            'propagate': True,
        },
    }
}
52
+
53
+
54
def get_logger():
    """Apply the LOGGING configuration and return the ``system`` logger.

    Returns:
        The ``logging.Logger`` named ``"system"``.
    """
    # File handlers open their target while dictConfig() builds them, which
    # fails when the directory is missing (e.g. a checkout without logs/).
    for handler_conf in LOGGING['handlers'].values():
        log_file = handler_conf.get('filename')
        log_dir = os.path.dirname(log_file) if log_file else ''
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
    logging.config.dictConfig(LOGGING)
    return logging.getLogger("system")


# Shared module-level logger imported by the rest of the project.
logger = get_logger()
utility/utils.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ @Time: 2022/12/06
3
+ @Author: LiuShu
4
+ @File: utility.py
5
+ """
6
+ import os
7
+ import re
8
+ import sys
9
+ import shutil
10
+ import subprocess
11
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
12
+ # 读取config内容
13
+ import configparser
14
+ import yaml
15
+ import json
16
+ import requests
17
+ from utility.constant import BASE_DIR
18
+ from utility.loggers import logger
19
# Locations of the two configuration files, both directly under the project
# root; each path is logged at import time.
db_config_file_path = os.path.join(BASE_DIR, 'config.cfg')
logger.info(db_config_file_path)
yaml_file_path = os.path.join(BASE_DIR, 'config.yaml')
logger.info(yaml_file_path)
23
+
24
+ # config.cfg
25
def get_config():
    """Parse ``config.cfg`` and return the populated ``ConfigParser``.

    Returns:
        A ``configparser.ConfigParser``; it is empty when the file could not
        be read (a warning is logged in that case).
    """
    config = configparser.ConfigParser()
    # read() skips missing files silently and returns the list of files it parsed.
    loaded_files = config.read(db_config_file_path, encoding='utf8')
    if not loaded_files:
        logger.warning('config file could not be read: %s', db_config_file_path)
    # Log section names only: the [database] section carries the password
    # (read as config['database']['PWD']) and must not end up in the log file.
    logger.info('config sections: %s', config.sections())
    return config


# Parsed config.cfg contents, loaded once at import time.
config = get_config()
33
+
34
+
35
+ # yaml 文件
36
class ConfigParser:
    """Loader for the YAML configuration file at ``yaml_file_path``."""

    @staticmethod
    def load_config():
        """Read the YAML file and return the parsed object.

        Returns:
            Whatever ``yaml.load`` produces for the file; the rest of the
            project indexes it like a dict (e.g. ``config_dict['db_path']``).
        """
        # NOTE(review): yaml.Loader can construct arbitrary Python objects from
        # tagged YAML. That is acceptable only for a trusted local file; consider
        # yaml.safe_load if the config holds plain data only -- confirm.
        with open(yaml_file_path, 'r', encoding='utf-8') as file_stream:
            config_dict = yaml.load(file_stream, Loader=yaml.Loader)

        # The whole configuration is written to the log at INFO level.
        logger.info('config_dict:' + str(config_dict))

        return config_dict


# Parsed config.yaml contents, loaded once at import time.
config_dict = ConfigParser.load_config()