File size: 5,179 Bytes
f3c6b77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c1e6869
f3c6b77
 
 
 
 
 
c1e6869
 
f3c6b77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153

import os
import random
import re
import requests
import argparse
import string

from datetime import timedelta
from flask import Flask, session, request, jsonify, render_template

from transformers.models.bert.tokenization_bert import BertTokenizer

from bot.chatbot import ChatBot
from bot.config import special_token_list

# Flask application object; every route in this module is registered on it.
app = Flask(__name__)
# NOTE(review): the key is drawn fresh from os.urandom on every process start,
# so it differs between restarts (and between separately started processes).
# Presumably cookies signed by an earlier process stop validating — confirm
# that this is acceptable for the deployment.
app.config["SECRET_KEY"] = os.urandom(74)
# NOTE(review): nothing in this module sets session.permanent, so confirm that
# this lifetime actually takes effect.
app.config["PERMANENT_SESSION_LIFETIME"] = timedelta(days=7)

# Shared tokenizer instance; remains None until main() assigns it.
tokenizer:BertTokenizer = None

# In-process store keyed by session hash; each value is that conversation's
# history as stored by the chat route (a list of encoded utterances).
history_matrix:dict = {}

def move_history_from_session_to_global_memory() -> None:
    """Copy the current session's history into the module-level store.

    Nothing is copied unless the session has both a truthy ``session_hash``
    and a truthy ``history``. The history is looked up with ``.get`` so that
    a session which has a hash but has never stored any history is treated
    as "nothing to copy" instead of raising ``KeyError``.
    """
    global history_matrix

    session_hash = session.get("session_hash")
    history = session.get("history")
    if session_hash and history:
        history_matrix[session_hash] = history

def move_history_from_global_memory_to_session() -> None:
    """Load the stored history for the session's hash into the session.

    Does nothing when the session has no truthy ``session_hash``. Otherwise
    ``session["history"]`` is overwritten with whatever the module-level
    store holds for that hash, which is ``None`` when the hash is unknown.
    """
    global history_matrix

    if not session.get("session_hash"):
        return

    current_hash = session.get("session_hash")
    session["history"] = history_matrix.get(current_hash)

def set_args() -> argparse.Namespace:
    """Parse the process's command-line options.

    Returns:
        A namespace with ``vocab_path`` (default ``None``) and
        ``model_path`` (default ``"lewiswu1209/Winnie"``), both strings
        when supplied.
    """
    cli = argparse.ArgumentParser()

    # (flag, default, help text) for every supported option.
    option_table = (
        ("--vocab_path", None, "选择词库"),
        ("--model_path", "lewiswu1209/Winnie", "对话模型路径"),
    )
    for flag, fallback, description in option_table:
        cli.add_argument(flag, default=fallback, type=str, required=False, help=description)

    parsed = cli.parse_args()
    return parsed

@app.route("/chitchat/history", methods = ["GET"])
def get_history_list() -> str:
    """Return the session's conversation history as a JSON list of strings.

    The history is first refreshed from the module-level store. Each stored
    entry is a sequence of token ids; it is converted back to tokens, a
    leading ``##`` is dropped from any token that carries one, and the
    tokens are concatenated without separators.
    """
    global tokenizer

    move_history_from_global_memory_to_session()

    stored_turns = session.get("history")
    stored_turns = [] if stored_turns is None else stored_turns

    readable_turns = [
        "".join(
            piece[2:] if piece.startswith("##") else piece
            for piece in tokenizer.convert_ids_to_tokens(turn_ids)
        )
        for turn_ids in stored_turns
    ]

    return jsonify(readable_turns)

def _ask_remote_bot(text: str, session_hash: str) -> str:
    """Send *text* to the hosted prediction endpoint and return its reply.

    Args:
        text: The message (or command such as ``ERASE MEMORY``) to send.
        session_hash: Identifier forwarded to the endpoint with the request.

    Returns:
        The first element of the ``data`` list in the JSON response.

    Raises:
        requests.RequestException: On connection failure, timeout, or a
            non-success HTTP status.
    """
    response = requests.post(
        url='https://hf.space/embed/lewiswu1209/Winnie/+/api/predict/',
        json={"data": [text], "session_hash": session_hash},
        # Without a timeout a stalled upstream would block this worker forever.
        timeout=60,
    )
    # Surface HTTP errors directly instead of failing later on a bad body.
    response.raise_for_status()
    return response.json()["data"][0]


@app.route("/chitchat/chat", methods = ["GET"])
def talk() -> str:
    """Handle one chat turn and return the reply as a JSON list.

    Query parameters:
        hash: Optional session hash to adopt; when given, the history stored
            for that hash is loaded into the session.
        text: The user's input. When missing or empty the response is
            ``[""]``.

    Behaviour by input:
        * ``HELP`` (case-insensitive) returns the list of usage hints.
        * ``ERASE MEMORY`` clears the history and forwards the command to
          the remote bot.
        * Text of the form ``<TAG>=<VALUE>`` is forwarded but not recorded
          in the history.
        * Anything else is forwarded, and both the input and the reply are
          tokenized and appended to the history.

    When the session has no history yet, ``ERASE MEMORY`` is sent to the
    remote bot before the actual input.
    """
    # Local import keeps this block self-contained; secrets is stdlib.
    import secrets

    global tokenizer
    global history_matrix

    requested_hash = request.args.get("hash")
    if requested_hash:
        session["session_hash"] = requested_hash
        move_history_from_global_memory_to_session()

    if session.get("session_hash") is None:
        # The hash is what lets a client retrieve a conversation's history,
        # so draw it from a CSPRNG rather than the predictable random module.
        alphabet = string.ascii_lowercase + string.digits
        session["session_hash"] = "".join( secrets.choice(alphabet) for _ in range(11) )

    input_text = request.args.get("text")
    if not input_text:
        return jsonify([""])

    if input_text.upper()=="HELP":
        help_info_list = ["输入任意文字,Winnie会和你对话",
        "输入ERASE MEMORY,Winnie会清空记忆",
        "输入\"<TAG>=<VALUE>\",Winnie会记录你的角色信息",
        "例如:<NAME>=Vicky,Winnie会修改自己的名字",
        "可以修改的角色信息有:",
        "<NAME>, <GENDER>, <YEAROFBIRTH>, <MONTHOFBIRTH>, <DAYOFBIRTH>, <ZODIAC>, <AGE>",
        "输入“上联:XXXXXXX”,Winnie会和你对对联",
        "输入“写诗:XXXXXXX”,Winnie会以XXXXXXX为开头写诗",
        "以\"请问\"开头并以问号结尾,Winnie会回答该问题"
        ]
        return jsonify(help_info_list)

    session_hash = session["session_hash"]
    history_list = session.get("history")
    erase_requested = input_text == "ERASE MEMORY"

    if not history_list or erase_requested:
        # Fresh or explicitly cleared conversation: reset the remote side too.
        history_list = []
        output_text = _ask_remote_bot("ERASE MEMORY", session_hash)

    if not erase_requested:
        # "<TAG>=<VALUE>" inputs set persona fields and are kept out of history.
        is_persona_update = re.match( r"^<.+>=.+$", input_text ) is not None
        if not is_persona_update:
            history_list.append( tokenizer.encode(input_text, add_special_tokens=False) )
        output_text = _ask_remote_bot(input_text, session_hash)
        if not is_persona_update:
            history_list.append( tokenizer.encode(output_text, add_special_tokens=False) )

    session["history"] = history_list
    history_matrix[session_hash] = history_list
    return jsonify([output_text])

@app.route("/")
def index() -> str:
    """Serve a fixed greeting at the site root."""
    greeting = "Hello world!"
    return greeting

@app.route("/chitchat/hash", methods = ["GET"])
def get_hash() -> str:
    """Return the session hash, optionally adopting one from the query string.

    When a truthy ``hash`` query parameter is present it becomes the
    session's hash and the history stored for it is loaded into the session.
    The response body is the session hash, or a single space when the
    session has none.
    """
    requested_hash = request.args.get("hash")
    if requested_hash:
        session["session_hash"] = requested_hash
        move_history_from_global_memory_to_session()

    current_hash = session.get("session_hash")
    return current_hash if current_hash else " "

@app.route( "/chitchat", methods = ["GET"] )
def chitchat() -> str:
    """Render the chat page from the ``chat_template.html`` template."""
    page = render_template( "chat_template.html" )
    return page

def main() -> None:
    """Build the shared tokenizer from the CLI options and start the server.

    The tokenizer is stored in the module-level ``tokenizer`` variable, and
    the Flask app is then run on 127.0.0.1, port 8080.
    """
    global tokenizer

    cli_args = set_args()

    tokenizer_options = {
        "vocab_path": cli_args.vocab_path,
        "special_token_list": special_token_list,
    }
    tokenizer = ChatBot.get_tokenizer(cli_args.model_path, **tokenizer_options)

    app.run(host="127.0.0.1", port=8080)

# Call main() only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()