Winnie / web.py
lewiswu1209's picture
add qa skills
c1e6869
raw
history blame contribute delete
No virus
5.18 kB
import argparse
import os
import random
import re
import secrets
import string
from datetime import timedelta

import requests
from flask import Flask, session, request, jsonify, render_template
from transformers.models.bert.tokenization_bert import BertTokenizer

from bot.chatbot import ChatBot
from bot.config import special_token_list
app = Flask(__name__)
# The signing key is regenerated on every process start, so session cookies
# issued before a restart no longer validate.
app.config.update(
    SECRET_KEY=os.urandom(74),
    PERMANENT_SESSION_LIFETIME=timedelta(days=7),
)

# Assigned by main() before the server starts; None until then.
tokenizer: BertTokenizer = None
# Server-side chat history: session hash -> list of token-id lists
# (one entry per user message / bot reply).
history_matrix: dict = {}
def move_history_from_session_to_global_memory() -> None:
    """Copy the session's chat history into the server-side ``history_matrix``.

    Does nothing unless the session has both a session hash and a non-empty
    history.
    """
    session_hash = session.get("session_hash")
    # Use .get(): a session can hold a hash without any "history" entry yet
    # (e.g. after a chat request that carried no text), and indexing it
    # directly would raise KeyError.
    history = session.get("history")
    if session_hash and history:
        history_matrix[session_hash] = history
def move_history_from_global_memory_to_session() -> None:
    """Load this session's history from ``history_matrix`` into the session.

    No-op when the session has no hash. If the hash is not present in
    ``history_matrix``, the session's history is set to None.
    """
    session_hash = session.get("session_hash")
    if not session_hash:
        return
    session["history"] = history_matrix.get(session_hash)
def set_args() -> argparse.Namespace:
    """Parse the server's command-line options.

    Returns:
        Namespace with ``vocab_path`` (default None) and ``model_path``
        (default "lewiswu1209/Winnie").
    """
    # (flag, default, help text); help texts: "choose vocabulary",
    # "dialogue model path".
    option_specs = (
        ("--vocab_path", None, "选择词库"),
        ("--model_path", "lewiswu1209/Winnie", "对话模型路径"),
    )
    cli = argparse.ArgumentParser()
    for flag, fallback, description in option_specs:
        cli.add_argument(flag, default=fallback, type=str, required=False, help=description)
    return cli.parse_args()
@app.route("/chitchat/history", methods=["GET"])
def get_history_list() -> str:
    """Return the session's chat history as a JSON array of strings.

    The session history is first refreshed from ``history_matrix``. Each
    stored entry is a list of token ids; it is converted back to tokens,
    "##" prefixes are stripped, and the tokens are joined with no separator.
    """
    move_history_from_global_memory_to_session()
    stored = session.get("history")
    utterances = [] if stored is None else stored

    def detokenize(token_ids) -> str:
        pieces = tokenizer.convert_ids_to_tokens(token_ids)
        return "".join(piece[2:] if piece.startswith("##") else piece for piece in pieces)

    return jsonify([detokenize(token_ids) for token_ids in utterances])
_PREDICT_URL = "https://hf.space/embed/lewiswu1209/Winnie/+/api/predict/"
# (connect, read) timeouts in seconds. Without a timeout, requests can block
# this worker indefinitely if the remote Space stops answering.
_PREDICT_TIMEOUT = (10, 120)
# Role-setting commands such as "<NAME>=Vicky" are forwarded to the model but
# not recorded in the chat history.
_ROLE_TAG_PATTERN = re.compile(r"^<.+>=.+$")
_SESSION_HASH_ALPHABET = string.ascii_lowercase + string.digits
_SESSION_HASH_LENGTH = 11
_HELP_INFO_LIST = ["输入任意文字,Winnie会和你对话",
                   "输入ERASE MEMORY,Winnie会清空记忆",
                   "输入\"<TAG>=<VALUE>\",Winnie会记录你的角色信息",
                   "例如:<NAME>=Vicky,Winnie会修改自己的名字",
                   "可以修改的角色信息有:",
                   "<NAME>, <GENDER>, <YEAROFBIRTH>, <MONTHOFBIRTH>, <DAYOFBIRTH>, <ZODIAC>, <AGE>",
                   "输入“上联:XXXXXXX”,Winnie会和你对对联",
                   "输入“写诗:XXXXXXX”,Winnie会以XXXXXXX为开头写诗",
                   "以\"请问\"开头并以问号结尾,Winnie会回答该问题"
                   ]


def _remote_predict(text: str, session_hash: str) -> str:
    """Send *text* to the hosted Winnie predict endpoint and return its first output.

    Raises:
        requests.RequestException: on connection failure, timeout, or an
            HTTP error status from the endpoint.
    """
    response = requests.post(
        url=_PREDICT_URL,
        json={"data": [text], "session_hash": session_hash},
        timeout=_PREDICT_TIMEOUT,
    )
    response.raise_for_status()
    return response.json()["data"][0]


def _new_session_hash() -> str:
    """Return a fresh session hash drawn from a cryptographically secure RNG."""
    # secrets, not random: the hash is the only key to a user's history in
    # history_matrix, and any client can adopt a hash via ?hash=, so it must
    # not be predictable.
    return "".join(
        secrets.choice(_SESSION_HASH_ALPHABET) for _ in range(_SESSION_HASH_LENGTH)
    )


@app.route("/chitchat/chat", methods=["GET"])
def talk() -> str:
    """Handle one chat turn for the caller's session.

    Query parameters:
        hash: optional session hash to adopt; its stored history (if any) is
            loaded from ``history_matrix`` into the session.
        text: the user's message. ``HELP`` (any case) returns the help lines;
            ``ERASE MEMORY`` clears the history; ``<TAG>=<VALUE>`` messages
            are sent to the model but not recorded in the history.

    Returns:
        A JSON array: the help lines for ``HELP``, ``[reply]`` for a normal
        message, or ``[""]`` when no ``text`` is given.

    Raises:
        requests.RequestException: if the remote model is unreachable, times
            out, or answers with an HTTP error status.
    """
    requested_hash = request.args.get("hash")
    if requested_hash:
        session["session_hash"] = requested_hash
        move_history_from_global_memory_to_session()
    if session.get("session_hash") is None:
        session["session_hash"] = _new_session_hash()

    input_text = request.args.get("text")
    if not input_text:
        return jsonify([""])

    if input_text.upper() == "HELP":
        return jsonify(_HELP_INFO_LIST)

    session_hash = session["session_hash"]
    history_list = session.get("history")
    erase_requested = input_text == "ERASE MEMORY"

    # An empty local history also resets the remote conversation, so both
    # sides start fresh together.
    if not history_list or erase_requested:
        history_list = []
        output_text = _remote_predict("ERASE MEMORY", session_hash)

    if not erase_requested:
        is_role_tag = _ROLE_TAG_PATTERN.match(input_text) is not None
        if not is_role_tag:
            history_list.append(tokenizer.encode(input_text, add_special_tokens=False))
        output_text = _remote_predict(input_text, session_hash)
        if not is_role_tag:
            history_list.append(tokenizer.encode(output_text, add_special_tokens=False))

    session["history"] = history_list
    history_matrix[session_hash] = history_list
    return jsonify([output_text])
@app.route("/")
def index() -> str:
    """Root endpoint: respond with a fixed plain-text greeting."""
    greeting = "Hello world!"
    return greeting
@app.route("/chitchat/hash", methods=["GET"])
def get_hash() -> str:
    """Return the caller's session hash.

    If a non-empty ``hash`` query parameter is given, it is adopted as the
    session hash and that hash's stored history is loaded into the session
    first.

    Returns:
        The session hash, or a single space when the session has none.
    """
    requested_hash = request.args.get("hash")
    if requested_hash:
        session["session_hash"] = requested_hash
        move_history_from_global_memory_to_session()
    # Named session_hash rather than `hash` to avoid shadowing the builtin.
    session_hash = session.get("session_hash")
    return session_hash if session_hash else " "
@app.route("/chitchat", methods=["GET"])
def chitchat() -> str:
    """Serve the chat page rendered from ``chat_template.html``."""
    page_template = "chat_template.html"
    return render_template(page_template)
def main() -> None:
    """Load the tokenizer chosen on the command line, then start the server.

    The development server listens on 127.0.0.1:8080.
    """
    global tokenizer
    cli_args = set_args()
    tokenizer = ChatBot.get_tokenizer(
        cli_args.model_path,
        vocab_path=cli_args.vocab_path,
        special_token_list=special_token_list,
    )
    app.run(host="127.0.0.1", port=8080)


if __name__ == "__main__":
    main()