# Hugging Face Spaces — status: Running
import random
from collections import OrderedDict

import open_clip
import torch
import gradio as gr
from PIL import Image
from datasets import load_dataset, load_from_disk

# Local training split (load_from_disk reads a dataset directory written by
# `save_to_disk`); the app samples the images it shows from here.
dataset = load_from_disk("./train")
# Fruits-30 class ids (string keys "0".."29", in order) mapped to readable names.
FRUITS30_CLASSES = OrderedDict(
    (str(class_id), fruit_name)
    for class_id, fruit_name in enumerate(
        (
            "acerolas",
            "apples",
            "apricots",
            "avocados",
            "bananas",
            "blackberries",
            "blueberries",
            "cantaloupes",
            "cherries",
            "coconuts",
            "figs",
            "grapefruits",
            "grapes",
            "guava",
            "kiwifruit",
            "lemons",
            "limes",
            "mangos",
            "olives",
            "oranges",
            "passionfruit",
            "peaches",
            "pears",
            "pineapples",
            "plums",
            "pomegranates",
            "raspberries",
            "strawberries",
            "tomatoes",
            "watermelons",
        )
    )
)
# Fruit names in class-id order; these are the fixed text prompts scored by CLIP.
labels = [*FRUITS30_CLASSES.values()]
# CLIP ViT-B-32 with the 'laion2b_s34b_b79k' pretrained weights; `preprocess`
# is the image transform used before encoding (the second return value is unused).
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32',
    pretrained='laion2b_s34b_b79k',
)
# Tokenizer for the same architecture, used to encode the candidate labels.
tokenizer = open_clip.get_tokenizer('ViT-B-32')
# model in train mode by default, impacts some models with BatchNorm or
# stochastic depth active — switch to inference behaviour.
model.eval()
def create_interface(): | |
# Store current correct labels in a mutable container | |
current_correct_labels = [] | |
def get_image(): | |
indices = random.sample(range(len(dataset)), 1) | |
selected_images = [dataset[i]['image'] for i in indices] | |
return selected_images[0] | |
def on_submit(img1,label1): | |
image = preprocess(img1).unsqueeze(0) | |
text = tokenizer(labels+[label1,"not a fruit"]) | |
with torch.no_grad(): | |
image_features = model.encode_image(image) | |
text_features = model.encode_text(text) | |
image_features /= image_features.norm(dim=-1, keepdim=True) | |
text_features /= text_features.norm(dim=-1, keepdim=True) | |
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) | |
labels1 = labels+[label1,"not a fruit"] | |
correct_label = labels1[(text_probs.argmax().item())] | |
return correct_label | |
with gr.Blocks() as demo: | |
# Create components | |
with gr.Row(): | |
img1 = gr.Image(type="pil", label="Fruit",height = 300,width = 300) | |
label1 = gr.Textbox(label="Name this fruit") | |
submit_btn = gr.Button("Submit") | |
refresh_btn = gr.Button("Refresh") | |
result = gr.Textbox(label="Answer") | |
# Update images, labels, and correct labels on refresh button click | |
refresh_btn.click( | |
fn=get_image, | |
outputs=[img1] | |
) | |
# Evaluate user input on submit button click | |
submit_btn.click( | |
fn=on_submit, | |
inputs=[img1,label1], | |
outputs=result | |
) | |
demo.launch(debug = True) | |
# Run the game only when executed as a script (e.g. `python app.py`), not on import.
if __name__ == "__main__":
    create_interface()