import os

import gradio as gr
import torch
from PIL import Image  # PIL handles the image objects passed in by Gradio
from transformers import Blip2Processor, Blip2ForConditionalGeneration

EXAMPLES_DIR = 'examples'
DEFAULT_PROMPT = ""

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Half precision is only safe on GPU; fall back to float32 on CPU.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Load the BLIP-2 model; device_map="auto" lets accelerate place the weights,
# so no explicit model.to(device) call is needed (and it would conflict with
# the automatic placement).
model = Blip2ForConditionalGeneration.from_pretrained(
    'Salesforce/blip2-flan-t5-xl',
    device_map="auto",
    torch_dtype=dtype,
)
model.eval()

# Initialize the matching processor (tokenizer + image preprocessor)
processor = Blip2Processor.from_pretrained('Salesforce/blip2-flan-t5-xl')

# Set up some example images, each paired with the default prompt
examples = []
if os.path.isdir(EXAMPLES_DIR):
    for file in os.listdir(EXAMPLES_DIR):
        path = os.path.join(EXAMPLES_DIR, file)
        examples.append([path, DEFAULT_PROMPT])


def predict_caption(image, prompt):
    assert isinstance(prompt, str)
    # Convert the PIL image and prompt to tensors; cast the pixel values to
    # the model's dtype so float16 weights receive float16 inputs.
    inputs = processor(images=image, text=prompt, return_tensors="pt").to(device, dtype)
    # Generate the caption without tracking gradients; max_new_tokens bounds
    # the generated text itself, independent of the prompt length.
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, max_new_tokens=50)
    caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return caption.strip()


iface = gr.Interface(
    fn=predict_caption,
    inputs=[gr.Image(type="pil"), gr.Textbox(value=DEFAULT_PROMPT, label="Prompt")],
    examples=examples,
    outputs="text",
)
iface.launch(debug=True)
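
# A minimal smoke test, kept as comments so the script's behavior is unchanged.
# It calls predict_caption directly, bypassing the UI; the path examples/cat.jpg
# is a hypothetical placeholder for any local image:
#
#   img = Image.open("examples/cat.jpg").convert("RGB")
#   print(predict_caption(img, DEFAULT_PROMPT))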