File size: 3,172 Bytes
b8561a2
 
 
 
 
 
1b92e75
82db71c
b8561a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa3f38a
 
 
b8561a2
 
 
 
 
c20c779
 
b8561a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0bdd7f
b8561a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0bdd7f
b8561a2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import open_clip
import torch
import gradio as gr
from PIL import Image
from datasets import load_dataset
import random
from datasets import load_from_disk
# Training split read from a local directory (presumably one written earlier
# with `Dataset.save_to_disk` — confirm).
dataset = load_from_disk("./train")
from collections import OrderedDict

# Class index (as a string, "0".."29") -> fruit name. Keys come from each
# name's position in the tuple below, so the tuple order is significant.
FRUITS30_CLASSES = OrderedDict(
    (str(position), fruit)
    for position, fruit in enumerate(
        (
            "acerolas",
            "apples",
            "apricots",
            "avocados",
            "bananas",
            "blackberries",
            "blueberries",
            "cantaloupes",
            "cherries",
            "coconuts",
            "figs",
            "grapefruits",
            "grapes",
            "guava",
            "kiwifruit",
            "lemons",
            "limes",
            "mangos",
            "olives",
            "oranges",
            "passionfruit",
            "peaches",
            "pears",
            "pineapples",
            "plums",
            "pomegranates",
            "raspberries",
            "strawberries",
            "tomatoes",
            "watermelons",
        )
    )
)
# Fruit names in class-index order; these are the CLIP text candidates.
labels = [*FRUITS30_CLASSES.values()]

# ViT-B-32 CLIP model with the 'laion2b_s34b_b79k' pretrained tag, plus the
# matching inference-time image transform (the middle return value is unused).
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32', pretrained='laion2b_s34b_b79k'
)
# Switch out of the default training mode; this matters for models that use
# BatchNorm or stochastic depth.
model.eval()
tokenizer = open_clip.get_tokenizer('ViT-B-32')

def create_interface():
    """Build and launch the Gradio "name this fruit" game.

    Refresh shows the ``'image'`` of a random row of ``dataset``. Submit asks
    the CLIP ``model`` which text best matches the shown image. The candidate
    texts are the 30 fruit ``labels``, the user's typed guess (if any) and
    ``"not a fruit"``. The winning candidate is shown in the Answer box.
    """

    def get_image():
        """Return the ``'image'`` field of one uniformly random dataset row."""
        return dataset[random.randrange(len(dataset))]['image']

    def on_submit(img1, label1):
        """Return the candidate text with the highest CLIP score for ``img1``.

        Args:
            img1: PIL image from the image component, or ``None`` if none is
                loaded yet.
            label1: The user's typed guess (may be empty).

        Returns:
            The best-matching candidate string, or a prompt to load an image.
        """
        if img1 is None:
            # Submit pressed before any image was loaded; preprocess(None)
            # would raise.
            return "Please load an image first."

        guess = (label1 or "").strip()
        # Only add a non-empty guess, so "" can never be the answer.
        candidates = labels + ([guess] if guess else []) + ["not a fruit"]

        image = preprocess(img1).unsqueeze(0)
        text = tokenizer(candidates)

        with torch.no_grad(), torch.cuda.amp.autocast():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
        # Index into the same list that was tokenized. FRUITS30_CLASSES only
        # has keys "0"-"29", so mapping through it raised KeyError whenever
        # the guess or "not a fruit" won.
        return candidates[text_probs.argmax().item()]

    with gr.Blocks() as demo:
        # Create components
        with gr.Row():
            img1 = gr.Image(type="pil", label="Fruit", height=256, width=256)
        with gr.Row():
            label1 = gr.Textbox(label="Name this fruit")

        submit_btn = gr.Button("Submit")
        refresh_btn = gr.Button("Refresh")

        result = gr.Textbox(label="Answer")

        # Show a new random image on refresh button click
        refresh_btn.click(
            fn=get_image,
            outputs=[img1],
        )
        # Score the image against the candidates on submit button click
        submit_btn.click(
            fn=on_submit,
            inputs=[img1, label1],
            outputs=result,
        )

    demo.launch(debug=True)

# Run the game: build the Gradio UI and launch it. This executes at import
# time because there is no `if __name__ == "__main__":` guard.
# NOTE(review): add a __main__ guard if this module is ever imported elsewhere.
create_interface()