import torch
import requests
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from lavis.common.gradcam import getAttMap
from lavis.models import load_model_and_preprocess
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM
import gradio as gr


def prepare_data(image, question):
    # Preprocess the raw PIL image and question into the sample dict expected by the model.
    image = vis_processors["eval"](image).unsqueeze(0).to(device)
    question = txt_processors["eval"](question)
    samples = {"image": image, "text_input": [question]}
    return samples


def gradcam_attention(image, question):
    # Run image-text matching to obtain GradCAM maps and overlay them on a resized copy of the image.
    dst_w = 720
    samples = prepare_data(image, question)
    samples = model.forward_itm(samples=samples)

    w, h = image.size
    scaling_factor = dst_w / w
    resized_img = image.resize((int(w * scaling_factor), int(h * scaling_factor)))
    norm_img = np.float32(resized_img) / 255
    gradcam = samples['gradcams'].reshape(24, 24)

    avg_gradcam = getAttMap(norm_img, gradcam, blur=True)
    return (avg_gradcam * 255).astype(np.uint8)


def generate_cap(image, question, cap_number):
    # Generate question-guided captions from the most relevant image patches.
    samples = prepare_data(image, question)
    samples = model.forward_itm(samples=samples)
    samples = model.forward_cap(samples=samples, num_captions=cap_number, num_patches=5)
    print('Examples of question-guided captions: ')
    return pd.DataFrame({'Caption': samples['captions'][0][:cap_number]})


def postprocess(text):
    # Truncate each decoded answer at the first period or newline and lowercase it;
    # with a single decoded sequence this returns that answer as a plain string.
    for ans in text:
        for j, w in enumerate(ans):
            if w == '.' or w == '\n':
                ans = ans[:j].lower()
                break
    return ans


def generate_answer(image, question):
    # Full Img2Prompt pipeline: GradCAM -> question-guided captions -> synthetic QA pairs
    # -> textual prompt -> answer generated by the frozen LLM.
    samples = prepare_data(image, question)
    samples = model.forward_itm(samples=samples)
    samples = model.forward_cap(samples=samples, num_captions=5, num_patches=5)
    samples = model.forward_qa_generation(samples)
    Img2Prompt = model.prompts_construction(samples)

    Img2Prompt_input = tokenizer(Img2Prompt, padding='longest', truncation=True, return_tensors="pt").to(device)

    outputs = llm_model.generate(input_ids=Img2Prompt_input.input_ids,
                                 attention_mask=Img2Prompt_input.attention_mask,
                                 max_length=20 + len(Img2Prompt_input.input_ids[0]),
                                 return_dict_in_generate=True,
                                 output_scores=True)
    pred_answer = tokenizer.batch_decode(outputs.sequences[:, len(Img2Prompt_input.input_ids[0]):])
    pred_answer = postprocess(pred_answer)
    print(pred_answer, type(pred_answer))
    return pred_answer


# setup device to use
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


def load_model(model_selection):
    model = AutoModelForCausalLM.from_pretrained(model_selection)
    tokenizer = AutoTokenizer.from_pretrained(model_selection, use_fast=False)
    return model, tokenizer


# Choose LLM to use
# weights for OPT-6.7B/OPT-13B/OPT-30B/OPT-66B will download automatically
print("Loading Large Language Model (LLM)...")
llm_model, tokenizer = load_model('facebook/opt-350m')  # small checkpoint for the demo; OPT-6.7B is ~13 GB in FP16
llm_model.to(device)

model, vis_processors, txt_processors = load_model_and_preprocess(
    name="img2prompt_vqa", model_type="base", is_eval=True, device=device
)

# ---- Gradio Layout -----
title = "From Images to Textual Prompts: Zero-shot VQA with Frozen Large Language Models"
df_init = pd.DataFrame(columns=['Caption'])
raw_image = gr.Image(label='Input image', type="pil")
question = gr.Textbox(label="Input question", lines=1, interactive=True)
demo = gr.Blocks(title=title)
demo.encrypt = False
cap_df = gr.DataFrame(value=df_init, label="Caption dataframe", row_count=(0, "dynamic"),
                      max_rows=20, wrap=True, overflow_row_behaviour='paginate')

with demo:
    with gr.Row():
        gr.Markdown(f'''