Spaces:
Runtime error
Runtime error
File size: 1,931 Bytes
486f4bf ce0ce67 486f4bf ce0ce67 486f4bf ce0ce67 486f4bf 78cea7d 486f4bf 8d07585 486f4bf e5113d0 78cea7d 486f4bf 8d07585 b390d84 458c57f 78cea7d 5bbc196 458c57f 1651e3c 486f4bf 78cea7d 9fd6fc9 cf8abff 9fd6fc9 486f4bf 232d5c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
from transformers import IdeficsForVisionText2Text, AutoProcessor
from peft import PeftModel, PeftConfig
import gradio as gr
# Identifier of the PEFT adapter that is layered on top of the base model.
peft_model_id = "mrm8488/idefics-9b-ft-describe-diffusion-bf16-adapter"

# The adapter config names the base checkpoint it must be applied to.
config = PeftConfig.from_pretrained(peft_model_id)
base_model_id = config.base_model_name_or_path

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# The processor is loaded from the same base checkpoint as the model.
processor = AutoProcessor.from_pretrained(base_model_id)

# Load the base weights in bfloat16, wrap them with the adapter, move the
# result to the selected device and switch it to evaluation mode.
base_model = IdeficsForVisionText2Text.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
)
model = PeftModel.from_pretrained(base_model, peft_model_id).to(device)
model.eval()

# Pre-determined best prompt for this fine-tune
prompt = "Describe the following image:"

# Token budget that predict() hands to model.generate()
max_length = 64
def predict(image):
    """Generate a caption for ``image`` with the module-level model.

    Args:
        image: The value delivered by the Gradio image input, or ``None``
            when the form is submitted without an image.

    Returns:
        The decoded generation with the instruction prompt text removed,
        or an empty string when no image was supplied.
    """
    # Nothing to describe: return early instead of handing None to the
    # processor.
    if image is None:
        return ""

    # A single interleaved sample: the image followed by the instruction.
    inputs = processor([image, prompt], return_tensors="pt").to(device)

    # max_new_tokens limits only the generated continuation, whereas
    # max_length would also count the prompt tokens against the budget.
    generated_ids = model.generate(**inputs, max_new_tokens=max_length)

    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True
    )[0]

    # Strip the instruction text so that only the caption is returned.
    return generated_text.replace(f"{prompt} ", "")
# Heading and blurb passed to the Gradio interface.
title = "Midjourney-like Image Captioning with IDEFICS"
description = "Gradio Demo for generating *Midjourney* like captions (describe functionality) with **IDEFICS**"

# Example inputs: one single-element row per sample, matching the single
# image input below. NOTE(review): the paths are relative, so the files are
# presumably expected next to this script - confirm they are shipped with it.
examples = [
    ["1_sTXgMwDUW0pk-1yK4iHYFw.png"],
    ["0_6as5rHi0sgG4W2Tq.png"],
    ["zoomout_2-1440x807.jpg"],
    ["inZdRVn7eafZNvaVre2iW1a538.webp"],
    ["cute-photos-of-cats-in-grass-1593184777.jpg"],
    ["llama2-coder-logo.png"],
]

# One PIL image in, one text box out, served by predict().
io = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Upload an image", type="pil"),
    ],
    outputs=[
        gr.Textbox(label="IDEFICS Description"),
    ],
    title=title,
    description=description,
    examples=examples,
    # Flagging stays disabled. Gradio takes the string "never" for this; the
    # boolean form is deprecated. The deprecated no-op `allow_screenshot`
    # argument is dropped because newer Gradio releases reject it.
    allow_flagging="never",
)

io.launch(debug=True)