# Hugging Face Spaces status: Runtime error
import gradio as gr | |
import torch | |
import clip | |
from PIL import Image | |
print("Getting device...")
# Use the GPU when PyTorch reports one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print("Loading model...")
# clip.load hands back the network together with the image transform that
# process() applies to incoming pictures before inference.
model, preprocess = clip.load("ViT-B/32", device=device)
print("Loaded model.")
def process(image, prompt):
    """Score how well each newline-separated prompt matches ``image`` with CLIP.

    Args:
        image: The picture from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submitted without an image.
        prompt: Candidate captions, one per line. Surrounding whitespace and
            blank lines are ignored. A caption repeated several times is scored
            only once.

    Returns:
        A dict mapping each caption to its softmax probability across all
        captions, as plain Python floats. The dict is empty when there is no
        image or no non-blank caption.
    """
    print("Inferring...")
    if image is None:
        return {}

    # splitlines() also strips the "\r" of browser-submitted "\r\n" endings.
    # dict.fromkeys de-duplicates while keeping order. Duplicates would
    # otherwise share the softmax mass, but only one key would survive in the
    # returned dict, so the displayed scores would not sum to 1.
    prompts = list(
        dict.fromkeys(
            line.strip() for line in (prompt or "").splitlines() if line.strip()
        )
    )
    if not prompts:
        return {}

    pixels = preprocess(image).unsqueeze(0).to(device)
    text = clip.tokenize(prompts).to(device)
    with torch.no_grad():
        logits_per_image, _ = model(pixels, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    # numpy float32 is not JSON-serializable, and the Label output expects
    # plain floats, so convert each score explicitly.
    return {caption: float(score) for caption, score in zip(prompts, probs[0])}
# Wire process() into a simple web UI with two inputs. The first is an image,
# delivered to process() as a PIL object. The second is a multi-line textbox
# holding one candidate prompt per line. The result is rendered by a Label
# output component.
iface = gr.Interface(
    inputs=[
        gr.Image(type="pil"),
        gr.Textbox(label="Prompts (newline-separated)", lines=5),
    ],
    outputs=gr.Label(),
    fn=process,
)
iface.launch()