Kartheekb7's picture
Update app.py
f02ccdd verified
import open_clip
import torch
import gradio as gr
from PIL import Image
from datasets import load_dataset
import random
from datasets import load_from_disk
# Image dataset read from the local "./train" directory (a dataset previously
# written with `save_to_disk`). The game later reads rows via `row["image"]`.
# NOTE(review): assumes a single split was saved here, not a DatasetDict — confirm.
dataset = load_from_disk(dataset_path="./train")
from collections import OrderedDict
# Class-id -> fruit-name mapping for the 30 fruit categories used by the game.
# Keys are the ids as strings "0".."29", kept in ascending id order.
FRUITS30_CLASSES = OrderedDict(
    (str(class_id), fruit_name)
    for class_id, fruit_name in enumerate(
        (
            "acerolas",
            "apples",
            "apricots",
            "avocados",
            "bananas",
            "blackberries",
            "blueberries",
            "cantaloupes",
            "cherries",
            "coconuts",
            "figs",
            "grapefruits",
            "grapes",
            "guava",
            "kiwifruit",
            "lemons",
            "limes",
            "mangos",
            "olives",
            "oranges",
            "passionfruit",
            "peaches",
            "pears",
            "pineapples",
            "plums",
            "pomegranates",
            "raspberries",
            "strawberries",
            "tomatoes",
            "watermelons",
        )
    )
)
# Zero-shot candidate strings for CLIP, in class-id order.
labels = [*FRUITS30_CLASSES.values()]
# One architecture name shared by the model and its tokenizer so they cannot drift apart.
_CLIP_ARCH = "ViT-B-32"
# create_model_and_transforms returns (model, train_transform, eval_transform);
# the training transform is discarded and the eval transform is kept as `preprocess`.
model, _, preprocess = open_clip.create_model_and_transforms(
    _CLIP_ARCH,
    pretrained="laion2b_s34b_b79k",
)
# Models are created in train mode; switch to eval so layers such as BatchNorm or
# stochastic depth behave as they should for inference.
model.eval()
tokenizer = open_clip.get_tokenizer(_CLIP_ARCH)
def create_interface():
    """Build the fruit-naming game UI and launch it.

    An image is drawn at random from ``dataset`` when the page loads and each
    time "Refresh" is pressed. The player types a guess and presses "Submit";
    the guess is added to the CLIP zero-shot candidate set alongside the known
    fruit ``labels`` and a "not a fruit" option, and whichever candidate CLIP
    scores highest for the image is shown as the answer.

    ``demo.launch(debug=True)`` keeps the calling thread blocked while the
    Gradio server runs.
    """

    def get_image():
        """Return the "image" field of one uniformly random row of ``dataset``."""
        return dataset[random.randrange(len(dataset))]["image"]

    def on_submit(img1, label1):
        """Return the candidate label CLIP ranks highest for *img1*.

        Args:
            img1: The PIL image currently in the image box, or ``None`` when
                no image has been loaded yet.
            label1: The player's guess from the textbox (may be empty).

        Returns:
            The best-matching candidate string, or a prompt asking the player
            to load an image first.
        """
        if img1 is None:
            # preprocess(None) would raise; tell the player what to do instead.
            return "Press Refresh to load an image first."

        guess = (label1 or "").strip()
        # The player's guess competes with the known fruit names and a
        # catch-all; an empty guess is left out so "" can never be the answer.
        candidates = labels + ([guess] if guess else []) + ["not a fruit"]

        image = preprocess(img1).unsqueeze(0)  # add a batch dimension of 1
        text = tokenizer(candidates)
        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            # Unit-normalise so the dot product below is cosine similarity.
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        # Batch size is 1, so the flat argmax indexes straight into `candidates`.
        return candidates[text_probs.argmax().item()]

    with gr.Blocks() as demo:
        with gr.Row():
            img1 = gr.Image(type="pil", label="Fruit", height=300, width=300)
            label1 = gr.Textbox(label="Name this fruit")
        submit_btn = gr.Button("Submit")
        refresh_btn = gr.Button("Refresh")
        result = gr.Textbox(label="Answer")

        # Draw a new random image on every "Refresh" click.
        refresh_btn.click(fn=get_image, outputs=[img1])
        # Ask CLIP for the best-matching label on every "Submit" click.
        submit_btn.click(fn=on_submit, inputs=[img1, label1], outputs=result)
        # Show an image as soon as the page opens so the first Submit has
        # something to classify.
        demo.load(fn=get_image, outputs=[img1])

    demo.launch(debug=True)
# Run the game: build the Gradio UI and start serving it.
# NOTE(review): this call also runs if the module is imported; consider an
# `if __name__ == "__main__":` guard if the module is ever imported rather than executed.
create_interface()