tekkonetes brezende committed on
Commit
de9358f
0 Parent(s):

Duplicate from cloudqi/MultisourceChat

Browse files

Co-authored-by: Bruno Rezende <[email protected]>

Files changed (4) hide show
  1. .gitattributes +34 -0
  2. README.md +13 -0
  3. app.py +321 -0
  4. requirements.txt +2 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: CQI ChatBot Multisource Test
3
+ emoji: 📊
4
+ colorFrom: blue
5
+ colorTo: orange
6
+ sdk: gradio
7
+ sdk_version: 3.20.1
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: cloudqi/MultisourceChat
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+
5
+ from text_generation import Client, InferenceAPIClient
6
+
7
# Persona pre-prompt injected ahead of the conversation history when the
# togethercomputer/GPT-NeoXT-Chat-Base-20B ("OpenChat") model is selected
# (see get_usernames).  NOTE(review): this is runtime prompt text sent to the
# model verbatim — do not reflow or "fix" its wording/typos.
openchat_preprompt = (
    "\n<human>: Hi!\n<bot>: Hi!\nMy name is Kleber Assistant, model version is 0.1, part of an open-source kit for "
    "fine-tuning new specialists bots! I was created by Together, LAION, and Ontocord.ai and the open-source "
    "community, this model is from CloudQI using this base. I am not human, not evil and not alive, and thus have no thoughts and feelings, "
    "but I am programmed to be helpful, polite, honest, and friendly.\n"
)
13
+
14
+
15
def get_client(model: str):
    """Return a streaming text-generation client for *model*.

    The two self-hosted models are reached through dedicated endpoints whose
    URLs come from environment variables; every other model id is routed
    through the Hugging Face Inference API (optionally authenticated with
    ``HF_TOKEN``).
    """
    dedicated_endpoints = {
        "Rallio67/joi2_20Be_instruct_alpha": "JOI_API_URL",
        "togethercomputer/GPT-NeoXT-Chat-Base-20B": "OPENCHAT_API_URL",
    }
    env_var = dedicated_endpoints.get(model)
    if env_var is not None:
        return Client(os.getenv(env_var))
    return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
21
+
22
+
23
def get_usernames(model: str):
    """Return the chat-serialization strings for *model*.

    Returns:
        (str, str, str, str): pre-prompt, user prefix, bot prefix, separator.
    """
    # OpenChat is handled separately so the module-level pre-prompt constant
    # is only evaluated when that model is actually selected.
    if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        return openchat_preprompt, "<human>: ", "<bot>: ", "\n"

    special_formats = {
        "OpenAssistant/oasst-sft-1-pythia-12b": (
            "", "<|prompter|>", "<|assistant|>", "<|endoftext|>"
        ),
        "Rallio67/joi2_20Be_instruct_alpha": ("", "User: ", "Joi: ", "\n\n"),
    }
    # Generic "User:/Assistant:" framing for every other model.
    return special_formats.get(model, ("", "User: ", "Assistant: ", "\n"))
35
+
36
+
37
def predict(
    model: str,
    inputs: str,
    typical_p: float,
    top_p: float,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    watermark: bool,
    chatbot,
    history,
):
    """Stream a chat completion for *inputs* and yield updated UI state.

    Args:
        model: model id selected in the radio control.
        inputs: the user's new utterance.
        typical_p/top_p/temperature/top_k/repetition_penalty/watermark:
            sampling parameters forwarded to ``generate_stream``.
        chatbot: list of (user, bot) display pairs from the gr.Chatbot.
        history: flat list [user, bot, user, bot, ...] kept in gr.State.

    Yields:
        (list[tuple[str, str]], list[str]): chatbot pairs and the flat
        history, updated after every streamed token.
    """
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)

    history.append(inputs)

    # Re-serialize the prior turns, ensuring each side carries its role
    # prefix exactly once.
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data
        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()

    if model == "OpenAssistant/oasst-sft-1-pythia-12b":
        # The OASST endpoint only accepts typical_p sampling.
        iterator = client.generate_stream(
            total_inputs,
            typical_p=typical_p,
            truncate=1000,
            watermark=watermark,
            max_new_tokens=500,
        )
    else:
        iterator = client.generate_stream(
            total_inputs,
            top_p=top_p if top_p < 1.0 else None,
            top_k=top_k,
            truncate=1000,
            repetition_penalty=repetition_penalty,
            watermark=watermark,
            temperature=temperature,
            max_new_tokens=500,
            stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
        )

    partial_words = ""
    response_started = False  # True once the bot's slot exists in `history`
    for response in iterator:
        if response.token.special:
            continue

        partial_words = partial_words + response.token.text
        # BUGFIX: the original used str.rstrip(prefix), which strips a
        # trailing *character set* rather than the exact suffix, so it could
        # eat legitimate trailing characters of the reply (e.g. any of
        # "U", "s", "e", "r", ":").  Strip the role markers as suffixes.
        partial_words = _remove_suffix(partial_words, user_name.rstrip())
        partial_words = _remove_suffix(partial_words, assistant_name.rstrip())

        if not response_started:
            # BUGFIX: the original keyed this on `i == 0`; a leading special
            # token consumed index 0, so the bot text then overwrote the
            # user's message via history[-1].
            history.append(" " + partial_words)
            response_started = True
        elif response.token.text not in user_name:
            # NOTE(review): substring test inherited from the original —
            # presumably meant to skip tokens that form the user prefix;
            # verify against the models' tokenization.
            history[-1] = partial_words

        chat = [
            (history[j].strip(), history[j + 1].strip())
            for j in range(0, len(history) - 1, 2)
        ]
        yield chat, history


def _remove_suffix(text: str, suffix: str) -> str:
    """Remove *suffix* from the end of *text* if present (no-op otherwise)."""
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text
113
+
114
+
115
def reset_textbox():
    """Clear the message textbox (wired to submit/click events below)."""
    return gr.update(value="")
117
+
118
+
119
def radio_on_change(
    value: str,
    disclaimer,
    typical_p,
    top_p,
    top_k,
    temperature,
    repetition_penalty,
    watermark,
):
    """Show/hide and re-default the sampling controls for the chosen model.

    Returns the Gradio update payloads in the same order as the change
    event's ``outputs`` list: (disclaimer, typical_p, top_p, top_k,
    temperature, repetition_penalty, watermark).
    """
    if value == "OpenAssistant/oasst-sft-1-pythia-12b":
        # The OASST endpoint only takes typical_p; hide everything else.
        return (
            disclaimer.update(visible=False),
            typical_p.update(value=0.2, visible=True),
            top_p.update(visible=False),
            top_k.update(visible=False),
            temperature.update(visible=False),
            repetition_penalty.update(visible=False),
            watermark.update(False),
        )
    if value == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        # OpenChat defaults, plus its feedback-app disclaimer.
        return (
            disclaimer.update(visible=True),
            typical_p.update(visible=False),
            top_p.update(value=0.25, visible=True),
            top_k.update(value=50, visible=True),
            temperature.update(value=0.6, visible=True),
            repetition_penalty.update(value=1.01, visible=True),
            watermark.update(False),
        )
    # Generic Inference-API models.
    return (
        disclaimer.update(visible=False),
        typical_p.update(visible=False),
        top_p.update(value=0.95, visible=True),
        top_k.update(value=4, visible=True),
        temperature.update(value=0.5, visible=True),
        repetition_penalty.update(value=1.03, visible=True),
        watermark.update(True),
    )
162
+
163
+
164
# ---------------------------------------------------------------------------
# Gradio UI wiring (runs at import time, as is conventional for Spaces apps).
# ---------------------------------------------------------------------------
title = """<h2 align="center">MultiSource ChatBot</h2><h3 align="center"> CloudQI Test Interface </h3>"""
description = """Os modelos de linguagem podem ser condicionados a agir como agentes de diálogo por meio de um prompt de conversação que normalmente assume a forma:

```
User: <utterance>
Assistant: <utterance>
User: <utterance>
Assistant: <utterance>
...
```
"""

openchat_disclaimer = """
<div align="center">Checkout the official <a href=https://huggingface.co/spaces/togethercomputer/OpenChatKit>OpenChatKit feedback app</a> for the full experience.</div>
"""

with gr.Blocks(
    css="""#col_container {margin-left: auto; margin-right: auto;}
                #chatbot {height: 520px; overflow: auto;}"""
) as demo:
    gr.HTML(title)
    with gr.Column(elem_id="col_container"):
        # Model picker; the first three entries are served by dedicated
        # endpoints, the rest go through the HF Inference API (get_client).
        model = gr.Radio(
            value="OpenAssistant/oasst-sft-1-pythia-12b",
            choices=[
                "OpenAssistant/oasst-sft-1-pythia-12b",
                "togethercomputer/GPT-NeoXT-Chat-Base-20B",
                "Rallio67/joi2_20Be_instruct_alpha",
                "google/flan-t5-xxl",
                "google/flan-ul2",
                "bigscience/bloom",
                "bigscience/bloomz",
                "EleutherAI/gpt-neox-20b",
            ],
            label="Model",
            interactive=True,
        )

        chatbot = gr.Chatbot(elem_id="chatbot")
        inputs = gr.Textbox(
            placeholder="Olá!", label="Insira seu texto e aperte Enter"
        )
        disclaimer = gr.Markdown(openchat_disclaimer, visible=False)
        state = gr.State([])  # flat [user, bot, user, bot, ...] history
        b1 = gr.Button()

        with gr.Accordion("Parameters", open=False):
            typical_p = gr.Slider(
                minimum=0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                interactive=True,
                label="Typical P mass",
            )
            top_p = gr.Slider(
                minimum=0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
                visible=False,
            )
            temperature = gr.Slider(
                minimum=0,
                maximum=5.0,
                value=0.6,
                step=0.1,
                interactive=True,
                label="Temperature",
                visible=False,
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
                visible=False,
            )
            repetition_penalty = gr.Slider(
                minimum=0.1,
                maximum=3.0,
                value=1.03,
                step=0.01,
                interactive=True,
                label="Repetition Penalty",
                visible=False,
            )
            watermark = gr.Checkbox(value=False, label="Text watermarking")
            # BUGFIX: these used the deprecated `gr.inputs` namespace while
            # the rest of the file uses top-level components; `gr.inputs`
            # is removed in later gradio versions.
            # NOTE(review): these three fields are never read anywhere in
            # this file — presumably placeholders; confirm before removing.
            hf_token_input = gr.Textbox(label="HF Token")
            joi_api_url_input = gr.Textbox(label="JOI API URL")
            openchat_api_url_input = gr.Textbox(label="OPENCHAT API URL")

    # Re-configure the parameter controls whenever the model changes.
    model.change(
        lambda value: radio_on_change(
            value,
            disclaimer,
            typical_p,
            top_p,
            top_k,
            temperature,
            repetition_penalty,
            watermark,
        ),
        inputs=model,
        outputs=[
            disclaimer,
            typical_p,
            top_p,
            top_k,
            temperature,
            repetition_penalty,
            watermark,
        ],
    )

    # Enter in the textbox and the button both trigger a streamed reply.
    inputs.submit(
        predict,
        [
            model,
            inputs,
            typical_p,
            top_p,
            temperature,
            top_k,
            repetition_penalty,
            watermark,
            chatbot,
            state,
        ],
        [chatbot, state],
    )
    b1.click(
        predict,
        [
            model,
            inputs,
            typical_p,
            top_p,
            temperature,
            top_k,
            repetition_penalty,
            watermark,
            chatbot,
            state,
        ],
        [chatbot, state],
    )
    # Clear the textbox after either trigger.
    b1.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])

    gr.Markdown(description)
    demo.queue(concurrency_count=16).launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ text-generation==0.3.0
2
+ gradio==3.20.1