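# Gradio demo: the player guesses a fruit's name and CLIP reports its best-matching prompt.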
import random
from collections import OrderedDict

import gradio as gr
import open_clip
import torch
from datasets import load_from_disk

# Dataset saved locally with datasets' save_to_disk(); each row exposes an "image" column
dataset = load_from_disk("./train")
FRUITS30_CLASSES = OrderedDict(
    {
        "0": "acerolas",
        "1": "apples",
        "2": "apricots",
        "3": "avocados",
        "4": "bananas",
        "5": "blackberries",
        "6": "blueberries",
        "7": "cantaloupes",
        "8": "cherries",
        "9": "coconuts",
        "10": "figs",
        "11": "grapefruits",
        "12": "grapes",
        "13": "guava",
        "14": "kiwifruit",
        "15": "lemons",
        "16": "limes",
        "17": "mangos",
        "18": "olives",
        "19": "oranges",
        "20": "passionfruit",
        "21": "peaches",
        "22": "pears",
        "23": "pineapples",
        "24": "plums",
        "25": "pomegranates",
        "26": "raspberries",
        "27": "strawberries",
        "28": "tomatoes",
        "29": "watermelons",
    }
)
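# The class names double as text prompts for CLIP's zero-shot classification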
labels = list(FRUITS30_CLASSES.values())
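# Load pretrained CLIP ViT-B-32 weights (LAION-2B checkpoint); "preprocess" is the matching image transform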
model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
model.eval() # model in train mode by default, impacts some models with BatchNorm or stochastic depth active
tokenizer = open_clip.get_tokenizer('ViT-B-32')
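# Build the Gradio UI: "Refresh" draws a random dataset image, "Submit" scores the prompts with CLIP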
def create_interface():
    def get_image():
        # Pick one random image from the dataset to display
        idx = random.randrange(len(dataset))
        return dataset[idx]["image"]
    def on_submit(img1, label1):
        # Candidate prompts: the 30 fruit names, the player's guess, and a fallback
        all_labels = labels + [label1, "not a fruit"]
        image = preprocess(img1).unsqueeze(0)
        text = tokenizer(all_labels)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            # Normalize so the dot product below is a cosine similarity
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        # Index into all_labels (not FRUITS30_CLASSES) so the best match can also be
        # the player's guess or "not a fruit" without raising a KeyError
        correct_label = all_labels[text_probs.argmax().item()]
        return correct_label
    with gr.Blocks() as demo:
        # Create components
        with gr.Row():
            img1 = gr.Image(type="pil", label="Fruit", height=256, width=256)
        with gr.Row():
            label1 = gr.Textbox(label="Name this fruit")
            submit_btn = gr.Button("Submit")
            refresh_btn = gr.Button("Refresh")
            result = gr.Textbox(label="Answer")

        # Show a new random image on refresh button click
        refresh_btn.click(
            fn=get_image,
            outputs=[img1],
        )
        # Evaluate user input on submit button click
        submit_btn.click(
            fn=on_submit,
            inputs=[img1, label1],
            outputs=result,
        )

    demo.launch(debug=True)
# Run the game
create_interface()
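# Usage sketch (assumption: a fruits-30 style dataset was saved locally with
# Dataset.save_to_disk("./train") so that each row exposes an "image" column):
#   python app.py
# launches a local Gradio server; click "Refresh" for a fruit image, type a guess,
# then "Submit" to see which prompt CLIP scores highest.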