Update app.py
Browse files
app.py
CHANGED
@@ -13,28 +13,23 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
13 |
n_gpu = 0 if torch.cuda.is_available()==False else torch.cuda.device_count()
|
14 |
model.to(device)
|
15 |
|
16 |
-
early_stop_pattern = tok.eos_token
|
17 |
-
print(f'Early stop pattern = \"{early_stop_pattern}\"')
|
18 |
-
|
19 |
def generate(text = ""):
|
20 |
streamer = TextIteratorStreamer(tok, timeout=10.)
|
21 |
if len(text) == 0:
|
22 |
text = " "
|
23 |
inputs = tok([text], return_tensors="pt").to(device)
|
24 |
-
generation_kwargs = dict(inputs, streamer=streamer, repetition_penalty=2.0, do_sample=True, top_k=40, top_p=0.97, max_new_tokens=128)
|
25 |
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
26 |
thread.start()
|
27 |
generated_text = ""
|
28 |
for new_text in streamer:
|
29 |
-
yield generated_text + new_text
|
30 |
-
#print(new_text, end ="")
|
31 |
generated_text += new_text
|
32 |
if early_stop_pattern in generated_text:
|
33 |
-
generated_text = generated_text[: generated_text.find(
|
34 |
streamer.end()
|
35 |
-
#print("\n--\n")
|
36 |
yield generated_text
|
37 |
-
|
38 |
|
39 |
demo = gr.Interface(
|
40 |
title="TextIteratorStreamer + Gradio demo",
|
|
|
# Count usable GPUs; 0 means CPU-only execution.
# (Replaces the non-idiomatic `torch.cuda.is_available()==False` comparison.)
n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
# Move the model's parameters and buffers onto the selected device in place.
model.to(device)
|
15 |
|
|
|
|
|
|
|
def generate(text=""):
    """Stream model output for *text*, yielding the growing transcript.

    Runs ``model.generate`` in a background thread and consumes a
    ``TextIteratorStreamer``, yielding the accumulated text after each new
    chunk (so a Gradio UI can render it incrementally) and once more at the
    end. Generation is cut short as soon as the tokenizer's EOS token shows
    up in the transcript, and the transcript is trimmed at that point.

    Parameters:
        text: the prompt; an empty string is replaced by a single space
              because an empty prompt breaks tokenization.

    Yields:
        str: the partial (and finally the trimmed full) generated text.
    """
    streamer = TextIteratorStreamer(tok, timeout=10.)
    if len(text) == 0:
        text = " "  # avoid tokenizing an empty prompt
    inputs = tok([text], return_tensors="pt").to(device)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        repetition_penalty=2.0,
        do_sample=True,
        top_k=40,
        top_p=0.97,
        max_new_tokens=128,
        pad_token_id=model.config.eos_token_id,
        early_stopping=True,
        no_repeat_ngram_size=4,
    )
    # Generate on a worker thread so this generator can drain the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # BUG FIX: this change removed the module-level
    # `early_stop_pattern = tok.eos_token` assignment but the loop below still
    # referenced `early_stop_pattern`, raising NameError on first use. Bind the
    # EOS token locally instead. Note tok.eos_token may be None for some
    # tokenizers, hence the truthiness guard before the substring test
    # (`None in str` would raise TypeError).
    early_stop_pattern = tok.eos_token
    generated_text = ""
    for new_text in streamer:
        yield generated_text + new_text
        generated_text += new_text
        if early_stop_pattern and early_stop_pattern in generated_text:
            # Trim everything from the EOS token onward and stop the stream.
            generated_text = generated_text[: generated_text.find(early_stop_pattern)]
            streamer.end()
    yield generated_text
    return generated_text
33 |
|
34 |
demo = gr.Interface(
|
35 |
title="TextIteratorStreamer + Gradio demo",
|