import random
from collections import OrderedDict

import gradio as gr
import open_clip
import torch
from datasets import load_from_disk

dataset = load_from_disk("./train")

FRUITS30_CLASSES = OrderedDict(
    {
        "0": "acerolas",
        "1": "apples",
        "2": "apricots",
        "3": "avocados",
        "4": "bananas",
        "5": "blackberries",
        "6": "blueberries",
        "7": "cantaloupes",
        "8": "cherries",
        "9": "coconuts",
        "10": "figs",
        "11": "grapefruits",
        "12": "grapes",
        "13": "guava",
        "14": "kiwifruit",
        "15": "lemons",
        "16": "limes",
        "17": "mangos",
        "18": "olives",
        "19": "oranges",
        "20": "passionfruit",
        "21": "peaches",
        "22": "pears",
        "23": "pineapples",
        "24": "plums",
        "25": "pomegranates",
        "26": "raspberries",
        "27": "strawberries",
        "28": "tomatoes",
        "29": "watermelons",
    }
)
labels = list(FRUITS30_CLASSES.values())

model, _, preprocess = open_clip.create_model_and_transforms(
    "ViT-B-32", pretrained="laion2b_s34b_b79k"
)
model.eval()  # model is in train mode by default; this matters for models with BatchNorm or stochastic depth active
tokenizer = open_clip.get_tokenizer("ViT-B-32")


def create_interface():
    def get_image():
        # Pick a random image from the dataset to show the player
        idx = random.randrange(len(dataset))
        return dataset[idx]["image"]

    def on_submit(img1, label1):
        # Zero-shot classification: score the image against the 30 fruit names,
        # the player's guess, and a "not a fruit" fallback
        candidates = labels + [label1, "not a fruit"]
        image = preprocess(img1).unsqueeze(0)
        text = tokenizer(candidates)

        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        # Return the highest-scoring candidate; indexing into the full candidate
        # list avoids a KeyError when the guess or the fallback scores highest
        return candidates[text_probs.argmax().item()]

    with gr.Blocks() as demo:
        # Create components
        with gr.Row():
            img1 = gr.Image(type="pil", label="Fruit", height=256, width=256)
        with gr.Row():
            label1 = gr.Textbox(label="Name this fruit")
            submit_btn = gr.Button("Submit")
            refresh_btn = gr.Button("Refresh")
        result = gr.Textbox(label="Answer")

        # Show a new image on refresh button click
        refresh_btn.click(fn=get_image, outputs=[img1])

        # Evaluate the player's guess on submit button click
        submit_btn.click(fn=on_submit, inputs=[img1, label1], outputs=result)

    demo.launch(debug=True)


# Run the game
create_interface()