scutcyr commited on
Commit
a3f1fa9
1 Parent(s): 2973fab

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -0
app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # install torch and tf
3
+ os.system('pip install transformers SentencePiece')
4
+ os.system('pip install torch')
5
+ # pip install streamlit-chat
6
+ os.system('pip install streamlit --upgrade')
7
+ os.system('pip install streamlit-chat')
8
+
9
+ from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoTokenizer
10
+ import torch
11
+
12
+ import streamlit as st
13
+ from streamlit_chat import message
14
+
15
+ # 修改colab笔记本设置为gpu,推理更快
16
+ device = torch.device('cpu')
17
+
18
+ def preprocess(text):
19
+ text = text.replace("\n", "\\n").replace("\t", "\\t")
20
+ return text
21
+
22
+ def postprocess(text):
23
+ return text.replace("\\n", "\n").replace("\\t", "\t")
24
+
25
+ def answer(user_history, bot_history, sample=True, top_p=1, temperature=0.7):
26
+ '''sample:是否抽样。生成任务,可以设置为True;
27
+ top_p:0-1之间,生成的内容越多样
28
+ max_new_tokens=512 lost...'''
29
+
30
+ if len(bot_history)>0:
31
+ context = "\n".join([f"病人:{user_history[i]}\n医生:{bot_history[i]}" for i in range(len(bot_history))])
32
+ input_text = context + "\n病人:" + user_history[-1] + "\n医生:"
33
+ else:
34
+ input_text = "病人:" + user_history[-1] + "\n医生:"
35
+ return "我是利用人工智能技术,结合大数据训练得到的智能医疗问答模型扁鹊,你可以向我提问。"
36
+
37
+
38
+ input_text = preprocess(input_text)
39
+ print(input_text)
40
+ encoding = tokenizer(text=input_text, truncation=True, padding=True, max_length=768, return_tensors="pt").to(device)
41
+ if not sample:
42
+ out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, num_beams=1, length_penalty=0.6)
43
+ else:
44
+ out = model.generate(**encoding, return_dict_in_generate=True, output_scores=False, max_new_tokens=512, do_sample=True, top_p=top_p, temperature=temperature, no_repeat_ngram_size=3)
45
+ out_text = tokenizer.batch_decode(out["sequences"], skip_special_tokens=True)
46
+ print('医生: '+postprocess(out_text[0]))
47
+ return postprocess(out_text[0])
48
+
49
+ st.set_page_config(
50
+ page_title="Chinese ChatBot - Demo",
51
+ page_icon=":robot:"
52
+ )
53
+
54
+ st.header("Chinese ChatBot - Demo")
55
+ st.markdown("[Github](https://github.com/scutcyr)")
56
+
57
+
58
+ @st.cache_resource
59
+ def load_model():
60
+ model = T5ForConditionalGeneration.from_pretrained("scutcyr/BianQue-1.0")
61
+ model.to(device)
62
+ print('Model Load done!')
63
+ return model
64
+
65
+ @st.cache_resource
66
+ def load_tokenizer():
67
+ tokenizer = T5Tokenizer.from_pretrained("scutcyr/BianQue-1.0")
68
+ print('Tokenizer Load done!')
69
+ return tokenizer
70
+
71
+ model = load_model()
72
+ tokenizer = load_tokenizer()
73
+
74
+ if 'generated' not in st.session_state:
75
+ st.session_state['generated'] = []
76
+
77
+ if 'past' not in st.session_state:
78
+ st.session_state['past'] = []
79
+
80
+
81
+ def get_text():
82
+ input_text = st.text_input("用户: ","你好!", key="input")
83
+ return input_text
84
+
85
+ #user_history = []
86
+ #bot_history = []
87
+ user_input = get_text()
88
+ #user_history.append(user_input)
89
+
90
+ if user_input:
91
+ st.session_state.past.append(user_input)
92
+ output = answer(st.session_state['past'],st.session_state["generated"])
93
+ st.session_state.generated.append(output)
94
+ #bot_history.append(output)
95
+
96
+ if st.session_state['generated']:
97
+
98
+ #for i in range(len(st.session_state['generated'])-1, -1, -1):
99
+ # message(st.session_state["generated"][i], key=str(i))
100
+ # message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
101
+ for i in range(len(st.session_state['generated'])):
102
+ message(st.session_state['past'][i], is_user=True, key=str(i) + '_user')
103
+ message(st.session_state["generated"][i], key=str(i))
104
+
105
+
106
+ if st.button("清理对话缓存"):
107
+ # Clear values from *all* all in-memory and on-disk data caches:
108
+ # i.e. clear values from both square and cube
109
+ st.session_state['generated'] = []
110
+ st.session_state['past'] = []