File size: 3,561 Bytes
3b41a3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Paths of the bundled demo images shown in the "Sample images" gallery
# (expects ./sample_images/0.jpg ... ./sample_images/4.jpg to exist).
sample_images = [f"./sample_images/{i}.jpg" for i in range(5)]
# Path of the image currently selected for classification; set by the
# gallery select handler and read by predict(). None until a selection is made.
prediction_image = None

# Load CLIP once at import time (downloads weights on first run).
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


def upload_file(files):
    """Collect the filesystem paths of the uploaded files.

    Args:
        files: Uploaded file objects, each exposing a ``.name`` path attribute.

    Returns:
        list: The paths, in upload order, for display in the upload gallery.
    """
    paths = []
    for uploaded in files:
        paths.append(uploaded.name)
    return paths


def read_image(path):
    """Open the image at *path* and return it as a PIL Image."""
    return Image.open(path)


def set_prediction_image(evt: gr.SelectData, gallery):
    """Record which gallery image was clicked as the image to classify.

    Stores the selected image's file path in the module-level
    ``prediction_image`` global, which ``predict`` later reads.

    Args:
        evt: Gradio select event carrying the clicked item's index.
        gallery: Current gallery value the selection happened in.
    """
    global prediction_image
    selected = gallery[evt.index]
    # Gallery item layout differs across gradio versions: either a dict
    # with a "name" key, or a sequence whose first element is that dict.
    if isinstance(selected, dict):
        prediction_image = selected["name"]
    else:
        prediction_image = selected[0]["name"]


def predict(text):
    """Zero-shot classify the currently selected image against text labels.

    Args:
        text: Comma-separated class labels, e.g. ``"dog, cat"``.

    Returns:
        dict: Maps the output Label component to a ``gr.update`` whose value
        is ``{label: probability}`` over the provided labels.

    Raises:
        gr.Error: If no image has been selected yet.
    """
    # Drop empty entries so trailing/duplicate commas ("dog,cat,") don't
    # produce an empty-string class that gets tokenized and scored.
    text_classes = [part.strip() for part in text.split(",") if part.strip()]

    # prediction_image is only set once the user clicks a gallery image;
    # fail with a user-facing message instead of Image.open(None) crashing.
    if prediction_image is None:
        raise gr.Error("Please select or upload an image first.")

    image = read_image(prediction_image)

    inputs = clip_processor(
        text=text_classes,
        images=image,
        return_tensors="pt",
        padding=True,
    )
    outputs = clip_model(**inputs)
    logits_per_image = (
        outputs.logits_per_image
    )  # this is the image-text similarity score, shape (1, n_classes)
    probs = logits_per_image.softmax(dim=1)[0]
    results = {text_class: prob.item() for text_class, prob in zip(text_classes, probs)}
    return {output: gr.update(value=results)}


# Assemble the Gradio UI: image pickers on the left, results on the right.
with gr.Blocks() as app:
    gr.Markdown("## ERA Session19 - Zero Shot Classification with CLIP")
    # Fixed typo: "Please an image" -> "Please upload an image".
    gr.Markdown(
        "Please upload an image or select one of the sample images. Type some classification labels separated by comma. For ex: dog, cat"
    )
    with gr.Row():
        with gr.Column():
            with gr.Box():
                # User-uploaded images: gallery populated by the upload button.
                with gr.Group():
                    upload_gallery = gr.Gallery(
                        value=None,
                        label="Uploaded images",
                        show_label=False,
                        elem_id="gallery_upload",
                        columns=5,
                        rows=2,
                        height="auto",
                        object_fit="contain",
                    )
                    upload_button = gr.UploadButton(
                        "Click to Upload images",
                        file_types=["image"],
                        file_count="multiple",
                    )
                    upload_button.upload(upload_file, upload_button, upload_gallery)

                # Bundled sample images the user can pick from instead.
                with gr.Group():
                    sample_gallery = gr.Gallery(
                        value=sample_images,
                        label="Sample images",
                        show_label=False,
                        elem_id="gallery_sample",
                        columns=3,
                        rows=2,
                        height="auto",
                        object_fit="contain",
                    )

                # Clicking either gallery records the chosen image for predict().
                upload_gallery.select(set_prediction_image, inputs=[upload_gallery])
                sample_gallery.select(set_prediction_image, inputs=[sample_gallery])
            with gr.Box():
                input_text = gr.TextArea(
                    label="Classification Text",
                    placeholder="Please enter comma separated text",
                    interactive=True,
                )

            submit_btn = gr.Button(value="Submit")
        with gr.Column():
            with gr.Box():
                output = gr.Label(value=None, label="Classification Results")

            submit_btn.click(predict, inputs=[input_text], outputs=[output])


app.launch()