Spaces:
Runtime error
Runtime error
import gradio as gr | |
import tempfile | |
from transformers import MT5ForConditionalGeneration, MT5Tokenizer,ViltProcessor, ViltForQuestionAnswering, AutoTokenizer | |
import torch | |
from PIL import Image | |
# Place every model on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Persian-to-English (fa -> en) translation model
fa_en_translation_tokenizer = MT5Tokenizer.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1") | |
fa_en_translation_model = MT5ForConditionalGeneration.from_pretrained("SeyedAli/Persian-to-English-Translation-mT5-V1").to(device) | |
def run_fa_en_transaltion_model(input_string, **generator_args):
    """Translate Persian text to English with the fa -> en mT5 model.

    Args:
        input_string: The Persian source text.
        **generator_args: Extra keyword arguments forwarded unchanged to
            ``fa_en_translation_model.generate`` (e.g. ``num_beams``).

    Returns:
        A list of decoded English strings, one per sequence returned by
        ``generate``, with special tokens stripped.
    """
    # The model may have been moved to the GPU (see `device`), so the token ids
    # must live on the same device; otherwise `generate` raises a
    # device-mismatch error whenever CUDA is available.
    input_ids = fa_en_translation_tokenizer.encode(
        input_string, return_tensors="pt"
    ).to(fa_en_translation_model.device)
    generated_ids = fa_en_translation_model.generate(input_ids, **generator_args)
    return fa_en_translation_tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True
    )
# English-to-Persian (en -> fa) translation model
# Hugging Face checkpoint used for English -> Persian translation.
_EN_FA_CHECKPOINT = "SeyedAli/English-to-Persian-Translation-mT5-V1"
en_fa_translation_tokenizer = MT5Tokenizer.from_pretrained(_EN_FA_CHECKPOINT)
en_fa_translation_model = MT5ForConditionalGeneration.from_pretrained(
    _EN_FA_CHECKPOINT
).to(device)
def run_en_fa_transaltion_model(input_string, **generator_args):
    """Translate English text to Persian with the en -> fa mT5 model.

    Args:
        input_string: The English source text.
        **generator_args: Extra keyword arguments forwarded unchanged to
            ``en_fa_translation_model.generate`` (e.g. ``num_beams``).

    Returns:
        A list of decoded Persian strings, one per sequence returned by
        ``generate``, with special tokens stripped.
    """
    # The model may have been moved to the GPU (see `device`), so the token ids
    # must live on the same device; otherwise `generate` raises a
    # device-mismatch error whenever CUDA is available.
    input_ids = en_fa_translation_tokenizer.encode(
        input_string, return_tensors="pt"
    ).to(en_fa_translation_model.device)
    generated_ids = en_fa_translation_model.generate(input_ids, **generator_args)
    return en_fa_translation_tokenizer.batch_decode(
        generated_ids, skip_special_tokens=True
    )
# Visual question answering: ViLT fine-tuned on VQA, with its matching processor.
_VQA_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
VQA_processor = ViltProcessor.from_pretrained(_VQA_CHECKPOINT)
VQA_model = ViltForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT).to(device)
# Right-to-left settings shared by both Persian text areas.
_persian_textarea_kwargs = {"text_align": "right", "rtl": True, "type": "text"}

image_input = gr.Image(label="عکس ورودی")  # label: "input image"
text_input = gr.TextArea(label="سوال فارسی", **_persian_textarea_kwargs)  # label: "Persian question"
text_output = gr.TextArea(label="پاسخ", **_persian_textarea_kwargs)  # label: "answer"
def VQA(image, text):
    """Answer a Persian question about an image.

    The question is translated to English, answered by the ViLT VQA model,
    and the top-scoring English answer label is translated back to Persian.

    Args:
        image: The image from the ``gr.Image`` input. It is presumably a NumPy
            array, since the component is created without ``type=``. It is
            ``None`` when nothing was uploaded.
        text: The Persian question.

    Returns:
        The Persian translation of the highest-scoring answer label.

    Raises:
        gr.Error: If no image or no (non-blank) question was provided.
    """
    if image is None or not (text and text.strip()):
        raise gr.Error("Please provide both an image and a question.")

    # Build the PIL image straight from the array. There is no need to write
    # and re-read a temporary PNG, which also fails on Windows, where an open
    # NamedTemporaryFile cannot be reopened by name. RGB conversion
    # normalises grayscale/RGBA inputs for the ViLT processor.
    pil_image = Image.fromarray(image).convert("RGB")

    # The translation helper returns a list of strings; ViLT needs one question.
    english_question = run_fa_en_transaltion_model(text)[0]

    # Inputs must be on the same device as the model (possibly the GPU).
    encoding = VQA_processor(
        pil_image, english_question, return_tensors="pt"
    ).to(VQA_model.device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = VQA_model(**encoding).logits

    idx = logits.argmax(-1).item()
    return run_en_fa_transaltion_model(VQA_model.config.id2label[idx])[0]
# Wire the VQA pipeline into a Gradio UI and start serving it without a
# public share link.
iface = gr.Interface(
    fn=VQA,
    inputs=[image_input, text_input],
    outputs=text_output,
)
iface.launch(share=False)