import gradio as gr
import re
from typing import Tuple

import torch
from transformers.utils import logging
from transformers import ViltProcessor, ViltForQuestionAnswering, T5Tokenizer, T5ForConditionalGeneration
import httpcore

# Set the SyncHTTPTransport attribute for the googletrans dependency.
# This must happen before googletrans is imported below.
setattr(httpcore, 'SyncHTTPTransport', 'AsyncHTTPProxy')
from googletrans import Translator
from googletrans import LANGCODES

# List of acceptable languages: the first word of every googletrans language
# name (e.g. "haitian" for "haitian creole"), plus two Chinese aliases.
acceptable_languages = set(name.split()[0] for name in LANGCODES)
acceptable_languages.add("mandarin")
acceptable_languages.add("cantonese")

# First word of each language name (lower-cased) -> language code.
# LANGCODES itself is keyed by the *full* name, so multi-word names such as
# "haitian creole" cannot be found from the single keyword matched in a
# question; this table is the fallback for those cases.
_FIRST_WORD_TO_CODE = {}
for _lang_name, _lang_code in LANGCODES.items():
    _FIRST_WORD_TO_CODE.setdefault(_lang_name.split()[0].lower(), _lang_code)

logging.set_verbosity_info()
logger = logging.get_logger("transformers")


def _language_alternation() -> str:
    """Return a regex alternation of all acceptable language keywords.

    Built on every call so that later changes to ``acceptable_languages`` are
    honoured. Keywords are escaped and ordered longest-first so the pattern is
    deterministic regardless of set iteration order.
    """
    ordered = sorted(acceptable_languages, key=lambda lang: (-len(lang), lang))
    return '|'.join(re.escape(lang) for lang in ordered)


# Translation
def google_translate(question: str, dest: str) -> Tuple[str, str]:
    """Translate ``question`` into the language ``dest`` with googletrans.

    Returns:
        A ``(translated_text, source_language)`` tuple taken from the
        translation result's ``text`` and ``src`` attributes.
    """
    translator = Translator()
    translation = translator.translate(question, dest=dest)
    logger.info("Translation text: %s", translation.text)
    logger.info("Translation src: %s", translation.src)
    return (translation.text, translation.src)


# Lang to lang_code mapping
def lang_code_match(accaptable_lang: str) -> str:
    """Map a lower-case language keyword to a googletrans language code.

    Args:
        accaptable_lang: A language keyword such as ``"german"`` or
            ``"mandarin"`` (lower-case).

    Returns:
        The matching language code.

    Raises:
        KeyError: If the keyword is neither a full language name nor the
            first word of one.
    """
    # Exception for chinese langs
    if accaptable_lang == 'mandarin':
        return 'zh-cn'
    if accaptable_lang in ('cantonese', 'chinese'):
        return 'zh-tw'
    # Default: exact full-name lookup first, then fall back to first-word
    # lookup so that e.g. "haitian" resolves to the code of "haitian creole"
    # instead of raising KeyError.
    try:
        return LANGCODES[accaptable_lang]
    except KeyError:
        return _FIRST_WORD_TO_CODE[accaptable_lang]


# Find destination language
def find_dest_language(sentence: str, src_lang: str) -> str:
    """Return the language code the answer should be translated into.

    The first acceptable language keyword found in ``sentence``
    (case-insensitive, whole word) decides the destination language. When no
    keyword is present, ``src_lang`` is returned unchanged.
    """
    pattern = r'\b(' + _language_alternation() + r')\b'
    match = re.search(pattern, sentence, flags=re.IGNORECASE)
    if match:
        lang_code = lang_code_match(match.group(0).lower())
        logger.info("Destination lang: %s", lang_code)
        return lang_code
    logger.info("Destination lang:%s", src_lang)
    return src_lang


# Remove destination language context
def remove_language_phrase(sentence: str) -> str:
    """Strip language mentions from ``sentence``.

    Removes every "in <language>" or bare "<language>" occurrence
    (case-insensitive) together with any whitespace and ``,;:.!?`` characters
    that directly follow it, then trims the result.
    """
    # Remove "in [acceptable_languages]" or "[acceptable_languages]" and any
    # punctuation/whitespace immediately after it
    pattern = r'(\b(in\s)?(' + _language_alternation() + r')\b)[\s,;:.!?]*'
    cleaned_sentence = re.sub(pattern, '', sentence, flags=re.IGNORECASE).strip()
    logger.info("Language Phrase Removed: %s", cleaned_sentence)
    return cleaned_sentence
# Load Vilt
vilt_processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")


def vilt_vqa(image, question):
    """Answer ``question`` about ``image`` with the ViLT VQA model.

    Args:
        image: The image, passed unchanged to the ViLT processor (the UI
            below supplies it from ``gr.Image(type="pil")``).
        question: The question text.

    Returns:
        The label with the highest logit. The ten highest-scoring labels are
        also written to the log.
    """
    inputs = vilt_processor(image, question, return_tensors="pt")
    with torch.no_grad():
        outputs = vilt_model(**inputs)
    logits = outputs.logits
    best_idx = logits.argmax(-1).item()
    answer = vilt_model.config.id2label[best_idx]
    logger.info("ViLT: %s", answer)

    # Get the top 10 scores and their indices (logged for inspection only)
    _, topk_indices = torch.topk(logits, 10, dim=-1)
    topk_answers = [vilt_model.config.id2label[i.item()] for i in topk_indices[0]]
    logger.info("ViLT top 10 answers: %s", topk_answers)

    return answer


# Load FLAN-T5
t5_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
t5_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto")


def flan_t5_complete_sentence(question, answer):
    """Turn a short VQA ``answer`` into a complete sentence with FLAN-T5.

    Args:
        question: The question that was asked.
        answer: The short answer to that question.

    Returns:
        The first decoded sequence generated by the model (at most 50 tokens).
    """
    input_text = f"A question: {question} An answer: {answer}. Based on these, answer the question with a complete sentence without extra information."
    logger.info("T5 input: %s", input_text)

    # The model is loaded with device_map="auto", so it may not live on the
    # CPU; move the inputs to the model's device before generating.
    inputs = t5_tokenizer(input_text, return_tensors="pt").to(t5_model.device)
    outputs = t5_model.generate(**inputs, max_length=50)

    result_sentence = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    logger.info("T5 output: %s", result_sentence)
    return result_sentence


# Main function
def vqa_main(image, question):
    """Run the cross-lingual VQA pipeline for one image/question pair.

    Steps: translate the question to English, detect the requested answer
    language (falling back to the question's own language), strip the language
    phrase, answer with ViLT, expand the answer with FLAN-T5, and translate
    the sentence into the destination language.

    Raises:
        ValueError: If no image was provided or the question is blank.
    """
    # Fail early with a readable message instead of an obscure error from the
    # image processor or the translator.
    if image is None:
        raise ValueError("Please provide an image.")
    if not question or not question.strip():
        raise ValueError("Please provide a question.")

    en_question, question_src_lang = google_translate(question, dest='en')
    dest_lang = find_dest_language(en_question, question_src_lang)
    cleaned_question = remove_language_phrase(en_question)
    vqa_answer = vilt_vqa(image, cleaned_question)
    llm_answer = flan_t5_complete_sentence(cleaned_question, vqa_answer)
    final_answer, _ = google_translate(llm_answer, dest=dest_lang)
    logger.info("Final Answer: %s", final_answer)
    return final_answer


# Home page text
title = "Interactive demo: Cross-Lingual VQA"
description = """
Upload an image, type a question, click 'submit', or click one of the examples to load them.
Note: This web demo is running on a CPU thus, may take a few minutes for completing output at times. For better performance, please consider migrating to your own space and upgrading to a GPU runtime.
"""
article = """
Supported 107 Languages: Afrikaans, Albanian, Amharic, Arabic, Armenian, Azerbaijani, Basque, Belarusian, Bengali, Bosnian, Bulgarian, Catalan, Cebuano, Chichewa, Chinese (Simplified), Chinese (Traditional), Corsican, Croatian, Czech, Danish, Dutch, English, Esperanto, Estonian, Filipino, Finnish, French, Frisian, Galician, Georgian, German, Greek, Gujarati, Haitian Creole, Hausa, Hawaiian, Hebrew, Hindi, Hmong, Hungarian, Icelandic, Igbo, Indonesian, Irish, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Kurdish (Kurmanji), Kyrgyz, Lao, Latin, Latvian, Lithuanian, Luxembourgish, Macedonian, Malagasy, Malay, Malayalam, Maltese, Maori, Marathi, Mongolian, Myanmar (Burmese), Nepali, Norwegian, Odia, Pashto, Persian, Polish, Portuguese, Punjabi, Romanian, Russian, Samoan, Scots Gaelic, Serbian, Sesotho, Shona, Sindhi, Sinhala, Slovak, Slovenian, Somali, Spanish, Sundanese, Swahili, Swedish, Tajik, Tamil, Telugu, Thai, Turkish, Ukrainian, Urdu, Uyghur, Uzbek, Vietnamese, Welsh, Xhosa, Yiddish, Yoruba, Zulu
"""

# Load example images
torch.hub.download_url_to_file('http://farm3.staticflickr.com/2710/4520550856_7a9f9ea59d_z.jpg', 'apple.jpg')
torch.hub.download_url_to_file('http://images.cocodataset.org/val2017/000000039769.jpg', 'cats.jpg')
torch.hub.download_url_to_file('https://media.istockphoto.com/id/1174602891/photo/two-monkeys-mom-and-cub-eat-bananas.jpg?s=612x612&w=0&k=20&c=r7VXi9d1wHhyq3iAk9D2Z3yTZiOJMlLNtjdVRBEjG7g=', 'monkey.jpg')

# Define home page variables
image = gr.Image(type="pil")
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Predicted answer")
examples = [
    ["apple.jpg", "Qu'est-ce que j'ai dans la main en anglais?"],
    ["cats.jpg", "How many cats are here?"],
    ["monkey.jpg", "In Korean, what are these animals called?"],
    ["apple.jpg", "What color is this? Answer in Uyghur."],
    ["cats.jpg", "What are the cats doing in German?"],
    ["monkey.jpg", "Maymunlar ne yiyor, Çince cevap ver."],
]

# The `answer` textbox is used as the output component so that its
# "Predicted answer" label is actually shown.
demo = gr.Interface(
    fn=vqa_main,
    inputs=[image, question],
    outputs=answer,
    examples=examples,
    title=title,
    description=description,
    article=article,
)
demo.launch(debug=True, show_error=True)