|
|
|
|
|
import gradio as gr |
|
import torch |
|
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer |
|
from threading import Thread |
|
import os |
|
|
|
# Path handed to the pipeline for both the model and the tokenizer.
model_id = './model'

# Use the GPU when torch reports one, otherwise fall back to the CPU.
CUDA_AVAILABLE = torch.cuda.is_available()
device = torch.device("cuda") if CUDA_AVAILABLE else torch.device("cpu")

# NOTE(review): load_in_8bit is passed as a plain pipeline keyword rather than
# through model_kwargs -- confirm it actually influences how the model loads.
generator = pipeline(
    'text-generation',
    model=model_id,
    tokenizer=model_id,
    load_in_8bit=True,
    device=device,
)

# Streaming output is cut at the first occurrence of this pattern.
early_stop_pattern = "\n\n\n"
print(f'Early stop pattern = "{early_stop_pattern}"')

# The model and tokenizer are used directly for streaming generation.
model, tok = generator.model, generator.tokenizer

# The tokenizer's end-of-sequence string, also treated as a stop marker.
stop_token = tok.eos_token
print(f'stop_token = "{stop_token}"')
|
|
|
def generate(text=""):
    """Stream generated text for the prompt *text* as a growing string.

    The first value yielded is a fixed placeholder message. After that, each
    yielded value is the accumulation of every chunk received from the
    streamer so far. Streaming stops as soon as the accumulated text contains
    ``early_stop_pattern`` or ``stop_token``; the text is then cut just before
    the earliest such marker and yielded one last time.

    Args:
        text: Prompt to continue. An empty (or ``None``) prompt is replaced
            by a single newline.

    Yields:
        str: The placeholder, then the accumulated text after each chunk.

    Returns:
        str: The final (possibly truncated) text, carried by StopIteration.
    """
    print("Create streamer")
    # Initial placeholder, yielded before any model output is available.
    yield "[ืื ื ืืืชืื ื ืืชืฉืืื]"
    streamer = TextIteratorStreamer(tok, timeout=5.)
    if not text:
        text = "\n"

    inputs = tok([text], return_tensors="pt").to(device)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        repetition_penalty=2.5,
        do_sample=True,
        top_k=40,
        top_p=0.2,
        temperature=0.4,
        num_beams=1,
        max_new_tokens=128,
        pad_token_id=model.config.eos_token_id,
        early_stopping=True,
        no_repeat_ngram_size=4,
    )
    # Generation runs in a worker thread so this generator can consume the
    # streamer while tokens are still being produced.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Ignore unset/empty markers: a None marker would raise on the substring
    # test and an empty one would match every chunk.
    stop_markers = [marker for marker in (early_stop_pattern, stop_token) if marker]

    generated_text = ""
    for new_text in streamer:
        print(new_text, end="")
        generated_text += new_text

        # str.find returns -1 when the marker is absent; slicing with -1
        # would silently drop the last character, so keep only real hits.
        hits = [generated_text.find(marker) for marker in stop_markers]
        hits = [pos for pos in hits if pos != -1]
        if hits:
            # Cut at the earliest marker so nothing after it leaks out.
            generated_text = generated_text[:min(hits)]
            # NOTE(review): this only closes the stream on the consumer
            # side; the generation thread presumably keeps running until it
            # reaches max_new_tokens -- consider a StoppingCriteria.
            streamer.end()
            print("\n--\n")
            yield generated_text
            return generated_text

        # Only text known to be free of stop markers reaches the UI.
        yield generated_text

    return generated_text
|
|
|
# Gradio UI: one right-to-left input box feeding `generate`, and one
# right-to-left output box that receives the streamed text.
demo = gr.Interface(
    fn=generate,
    inputs=gr.Textbox(
        label="ืืชืื ืืื ืืช ืืืงืกื ืฉืืื ืื ืืฉืืืจื ืจืืง",
        text_align='right',
        rtl=True,
        elem_id="input_text",
    ),
    outputs=gr.Textbox(
        type="text",
        label="ืคื ืืืคืืข ืืืงืกื ืฉืืืืืื ืืืืื",
        text_align='right',
        rtl=True,
        elem_id="output_text",
    ),
    examples=[
        'ืืฉื ืืืคืืข ืืื',
        'ืงืืื ืฉืืคื ืืช',
        'ืคืขื ืืืช ืืคื ื ืฉื ืื ืจืืืช',
        'ืืืจื ืคืืืจ ืืืื ืืืื ื ืืื',
        'ืืื ืืคืจืชื ืืช ืื ืืืื ืืืงืก ืืฉ',
    ],
    title="Hebrew text generator: Science Fiction and Fantasy (GPT-Neo)",
    css="#output_text{direction: rtl} #input_text{direction: rtl}",
    allow_flagging="never",
    cache_examples=False,
)

# Enable the request queue before starting the app.
demo.queue()
demo.launch()
|
|