[bug]
Browse files- .gitignore +2 -0
- main.py +12 -5
.gitignore
CHANGED
@@ -4,3 +4,5 @@
|
|
4 |
|
5 |
**/flagged/
|
6 |
**/__pycache__/
|
|
|
|
|
|
4 |
|
5 |
**/flagged/
|
6 |
**/__pycache__/
|
7 |
+
|
8 |
+
trained_models/
|
main.py
CHANGED
@@ -3,6 +3,7 @@
|
|
3 |
import argparse
|
4 |
from collections import defaultdict
|
5 |
import os
|
|
|
6 |
|
7 |
import gradio as gr
|
8 |
from threading import Thread
|
@@ -71,7 +72,6 @@ def main():
|
|
71 |
input_ids = torch.tensor([input_ids], dtype=torch.long)
|
72 |
input_ids = input_ids.to(device)
|
73 |
|
74 |
-
output: str = ""
|
75 |
streamer = TextIteratorStreamer(tokenizer=tokenizer)
|
76 |
|
77 |
generation_kwargs = dict(
|
@@ -88,17 +88,24 @@ def main():
|
|
88 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
89 |
thread.start()
|
90 |
|
|
|
|
|
91 |
for output_ in streamer:
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
94 |
output_ = output_.replace("[SEP]", "\n")
|
95 |
output_ = output_.replace("[UNK]", "")
|
96 |
-
output_ = output_.replace(
|
97 |
|
98 |
output += output_.strip()
|
99 |
output_text_box.value += output
|
100 |
yield output
|
101 |
|
|
|
|
|
102 |
demo = gr.Interface(
|
103 |
fn=fn_stream,
|
104 |
inputs=[
|
@@ -107,7 +114,7 @@ def main():
|
|
107 |
gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
|
108 |
gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
|
109 |
gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
|
110 |
-
gr.Dropdown(choices=
|
111 |
gr.Checkbox(value=True, label="is_chat")
|
112 |
],
|
113 |
outputs=[output_text_box],
|
|
|
3 |
import argparse
|
4 |
from collections import defaultdict
|
5 |
import os
|
6 |
+
import platform
|
7 |
|
8 |
import gradio as gr
|
9 |
from threading import Thread
|
|
|
72 |
input_ids = torch.tensor([input_ids], dtype=torch.long)
|
73 |
input_ids = input_ids.to(device)
|
74 |
|
|
|
75 |
streamer = TextIteratorStreamer(tokenizer=tokenizer)
|
76 |
|
77 |
generation_kwargs = dict(
|
|
|
88 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
89 |
thread.start()
|
90 |
|
91 |
+
output: str = ""
|
92 |
+
first_answer = True
|
93 |
for output_ in streamer:
|
94 |
+
if first_answer:
|
95 |
+
first_answer = False
|
96 |
+
continue
|
97 |
+
# output_ = output_.replace(text, "")
|
98 |
+
# output_ = output_.replace("[CLS]", "")
|
99 |
output_ = output_.replace("[SEP]", "\n")
|
100 |
output_ = output_.replace("[UNK]", "")
|
101 |
+
output_ = output_.replace(" ", "")
|
102 |
|
103 |
output += output_.strip()
|
104 |
output_text_box.value += output
|
105 |
yield output
|
106 |
|
107 |
+
model_name_choices = ["trained_models/lib_service_4chan"] \
|
108 |
+
if platform.system() == "Windows" else ["qgyd2021/lib_service_4chan"]
|
109 |
demo = gr.Interface(
|
110 |
fn=fn_stream,
|
111 |
inputs=[
|
|
|
114 |
gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
|
115 |
gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
|
116 |
gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
|
117 |
+
gr.Dropdown(choices=model_name_choices, value=model_name_choices[0], label="model_name"),
|
118 |
gr.Checkbox(value=True, label="is_chat")
|
119 |
],
|
120 |
outputs=[output_text_box],
|