Spaces:
Runtime error
Runtime error
import random | |
import re | |
import gradio as gr | |
import torch | |
from transformers import AutoModelForCausalLM | |
from transformers import AutoTokenizer | |
from transformers import AutoModelForSeq2SeqLM | |
from transformers import AutoProcessor | |
from transformers import pipeline | |
from transformers import set_seed | |
# Device used for the image-captioning model in get_prompt_from_image().
device = "cuda" if torch.cuda.is_available() else "cpu"

# Processor and model from the "microsoft/git-base-coco" checkpoint; they
# turn an uploaded image into a caption.
big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")

# Text-generation pipeline that expands a short prompt into longer ones.
text_pipe = pipeline(task='text-generation', model='succinctly/text2image-prompt-generator')

# Chinese -> English translation model and tokenizer, switched to eval mode.
zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
zh2en_model.eval()
zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')

# English -> Chinese translation model and tokenizer, switched to eval mode.
en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
en2zh_model.eval()
en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
def translate_zh2en(text):
    """Translate (possibly mixed) Chinese text into English.

    The input is first normalised: every boundary between a CJK character
    (U+4E00..U+9FA5) and any other character becomes a comma, whitespace is
    removed except for a single whitespace character directly between two
    ASCII letters, and runs of commas are collapsed. The normalised text is
    then passed through the zh->en translation model.

    Args:
        text: The text to translate.

    Returns:
        The stripped translation. If the model answers with the literal
        "No,no,", the normalised input text is returned instead.
    """
    # Insert a newline at each CJK / non-CJK boundary, then turn every
    # newline (including ones already present in the input) into a comma.
    text = re.sub(r'([^\u4e00-\u9fa5])([\u4e00-\u9fa5])', r'\1\n\2', text)
    text = re.sub(r'([\u4e00-\u9fa5])([^\u4e00-\u9fa5])', r'\1\n\2', text)
    text = text.replace('\n', ',')
    # Drop whitespace that is not flanked by ASCII letters on both sides.
    text = re.sub(r'(?<![a-zA-Z])\s+|\s+(?![a-zA-Z])', '', text)
    text = re.sub(r',+', ',', text)

    with torch.no_grad():
        encoded = zh2en_tokenizer([text], return_tensors='pt')
        sequences = zh2en_model.generate(**encoded)
    result = zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
    result = result.strip()

    # NOTE(review): "No,no," is presumably the model's degenerate output for
    # input it cannot translate - confirm. The original test was `!=`, which
    # threw away every real translation and returned the untranslated text.
    if result == "No,no,":
        result = text
    return result
def translate_en2zh(text):
    """Translate English text into Chinese.

    Args:
        text: The English text to translate.

    Returns:
        The decoded translation of the single input string, with special
        tokens removed.
    """
    batch = en2zh_tokenizer([text], return_tensors="pt")
    with torch.no_grad():
        generated = en2zh_model.generate(**batch)
    decoded = en2zh_tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]
def test05(text): | |
return text | |
def test06(text): | |
return text | |
def text_generate(text):
    """Expand a prompt into several longer prompts, each with a Chinese translation.

    The input is translated to English, then the text-generation pipeline
    is sampled up to six times (eight sequences per attempt, maximum length
    drawn from 60..90). The first attempt that yields a non-empty result
    after filtering wins.

    Args:
        text: The seed prompt (Chinese or English).

    Returns:
        A string where each kept candidate contributes three lines: its
        Chinese translation, the English text, and a blank line. Returns
        an empty string if every attempt is filtered away.
    """
    seed = random.randint(100, 1000000)
    set_seed(seed)
    text_in_english = translate_zh2en(text)
    # A candidate must add more than four characters to the prompt to count.
    min_length = len(text_in_english) + 4

    result = ""
    for _ in range(6):
        sequences = text_pipe(
            text_in_english,
            max_length=random.randint(60, 90),
            num_return_sequences=8,
        )
        lines = []
        for sequence in sequences:
            line = sequence['generated_text'].strip()
            if line != text_in_english and len(line) > min_length:
                lines.append(translate_en2zh(line) + "\n")
                lines.append(line + "\n")
                lines.append("\n")
        result = "".join(lines)
        # Remove every space-delimited token that contains an inner dot.
        # Note that [^ ] also matches newlines.
        result = re.sub(r'[^ ]+\.[^ ]+', '', result)
        result = result.replace('<', '').replace('>', '')
        if result != '':
            break
    return result
def load_prompter():
    """Load the Promptist model and the GPT-2 tokenizer used with it.

    The tokenizer is configured to pad with its end-of-sequence token and
    to pad on the left.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
    gpt2_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    gpt2_tokenizer.padding_side = "left"
    return model, gpt2_tokenizer


# Loaded once at import time and shared by generate_prompter().
prompter_model, prompter_tokenizer = load_prompter()
def generate_prompter(text):
    """Rephrase a prompt with the Promptist model and translate the results.

    The input is translated to English, suffixed with " Rephrase:", and
    passed to the model using beam search (3 beams, 3 returned sequences,
    at most 75 new tokens).

    Args:
        text: The prompt to rephrase (Chinese or English).

    Returns:
        A string where each returned sequence contributes three lines: the
        Chinese translation of the rephrased prompt, the rephrased English
        prompt, and a blank line.
    """
    english = translate_zh2en(text)
    model_input = english.strip() + " Rephrase:"
    input_ids = prompter_tokenizer(model_input, return_tensors="pt").input_ids

    eos_id = prompter_tokenizer.eos_token_id
    generation_options = dict(
        do_sample=False,
        max_new_tokens=75,
        num_beams=3,
        num_return_sequences=3,
        eos_token_id=eos_id,
        pad_token_id=eos_id,
        length_penalty=-1.0,
    )
    outputs = prompter_model.generate(input_ids, **generation_options)
    decoded = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    pieces = []
    for candidate in decoded:
        cleaned = candidate.replace('<', '').replace('>', '')
        # Keep only what follows the first "Rephrase:" marker.
        rephrased = cleaned.split("Rephrase:", 1)[-1].strip()
        pieces.extend((translate_en2zh(rephrased) + "\n", rephrased + "\n", "\n"))
    return "".join(pieces)
def combine_text(text):
    """Run both prompt generators on the same input.

    Args:
        text: The prompt entered by the user.

    Returns:
        A tuple ``(generate_prompter(text), text_generate(text))``,
        evaluated in that order.
    """
    return generate_prompter(text), text_generate(text)
def get_prompt_from_image(input_image):
    """Caption an image and feed the caption to both prompt generators.

    Args:
        input_image: Object exposing ``convert('RGB')`` (the UI passes a
            PIL image), or ``None`` when no image was provided.

    Returns:
        A tuple ``(generate_prompter(caption), text_generate(caption))``,
        or ``("", "")`` when ``input_image`` is ``None``.
    """
    # The button can be clicked before an image is uploaded; return empty
    # outputs instead of raising AttributeError on None.
    if input_image is None:
        return "", ""

    image = input_image.convert('RGB')
    pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
    generated_ids = big_model.to(device).generate(pixel_values=pixel_values, max_length=50)
    generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    result01 = generate_prompter(generated_caption)
    result02 = text_generate(generated_caption)
    return result01, result02
# UI definition. NOTE(review): the nesting below is reconstructed from the
# statement order because the original indentation was lost - confirm
# against the deployed layout.
with gr.Blocks() as block:
    with gr.Column():
        # Main tab: text/image input, two action buttons, two result boxes.
        with gr.Tab('工作區'):
            with gr.Row():
                input_text = gr.Textbox(lines=12, label='輸入文字', placeholder='在此输入文字...')
                input_image = gr.Image(type='pil')
            with gr.Row():
                txt_prompter_btn = gr.Button('文生文')
                pic_prompter_btn = gr.Button('圖生文')
            with gr.Row():
                Textbox_1 = gr.Textbox(lines=6, label='生成方式A')
            with gr.Row():
                Textbox_2 = gr.Textbox(lines=6, label='生成方式B')

        # Test tab: one row (input, run button, output) per individual function.
        with gr.Tab('測試區'):
            with gr.Row():
                input_test01 = gr.Textbox(lines=2, label='中英翻譯', placeholder='在此输入文字...')
                test01_btn = gr.Button('執行')
                Textbox_test01 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test02 = gr.Textbox(lines=2, label='英中翻譯', placeholder='在此输入文字...')
                test02_btn = gr.Button('執行')
                Textbox_test02 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test03 = gr.Textbox(lines=2, label='SD模式', placeholder='在此输入文字...')
                test03_btn = gr.Button('執行')
                Textbox_test03 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test04 = gr.Textbox(lines=2, label='瞎掰模式', placeholder='在此输入文字...')
                test04_btn = gr.Button('執行')
                Textbox_test04 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test05 = gr.Textbox(lines=2, label='沒作用', placeholder='在此输入文字...')
                test05_btn = gr.Button('執行')
                Textbox_test05 = gr.Textbox(lines=2, label='輸出結果')
            with gr.Row():
                input_test06 = gr.Textbox(lines=2, label='沒作用', placeholder='在此输入文字...')
                test06_btn = gr.Button('執行')
                Textbox_test06 = gr.Textbox(lines=2, label='輸出結果')

    # Event wiring, registered in the same order as before.
    # Main tab: both buttons fill the two result boxes.
    txt_prompter_btn.click(fn=combine_text, inputs=input_text, outputs=[Textbox_1, Textbox_2])
    pic_prompter_btn.click(fn=get_prompt_from_image, inputs=input_image, outputs=[Textbox_1, Textbox_2])

    # Test tab: each button runs one function on its own row.
    test01_btn.click(fn=translate_zh2en, inputs=input_test01, outputs=Textbox_test01)
    test02_btn.click(fn=translate_en2zh, inputs=input_test02, outputs=Textbox_test02)
    test03_btn.click(fn=generate_prompter, inputs=input_test03, outputs=Textbox_test03)
    test04_btn.click(fn=text_generate, inputs=input_test04, outputs=Textbox_test04)
    test05_btn.click(fn=test05, inputs=input_test05, outputs=Textbox_test05)
    test06_btn.click(fn=test06, inputs=input_test06, outputs=Textbox_test06)
# Start the app with a request queue of at most 64 pending events.
# queue() already enables queuing, so the `enable_queue` launch argument was
# redundant; it is dropped because Gradio releases that removed the parameter
# reject it with a TypeError.
block.queue(max_size=64).launch(show_api=False, debug=True, share=False, server_name='0.0.0.0')