# ERA-SESSION19 / app.py
# Author: ravi.naik
# Added CLIP gradio web interface (commit 3b41a3f)
import gradio as gr
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

# Paths of the bundled demo images shown in the "Sample images" gallery.
sample_images = [f"./sample_images/{i}.jpg" for i in range(5)]
# Path of the image the user last selected in either gallery; set by
# set_prediction_image() and read by predict().  NOTE(review): module-level
# mutable state is shared across all concurrent users of the app.
prediction_image = None
# Pretrained CLIP model + matching processor used for zero-shot classification.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def upload_file(files):
    """Return the temp-file paths of the uploaded files for gallery display."""
    paths = []
    for uploaded in files:
        paths.append(uploaded.name)
    return paths
def read_image(path):
    """Open the image at *path* and return it as a PIL Image (lazily loaded)."""
    return Image.open(path)
def set_prediction_image(evt: gr.SelectData, gallery):
    """Remember the file path of the gallery image the user clicked.

    Stores the path in the module-level ``prediction_image`` so that
    ``predict`` can load it later.
    """
    global prediction_image
    selected = gallery[evt.index]
    # Gallery entries can be plain dicts or (dict, caption)-style pairs
    # depending on how the gallery was populated — handle both shapes.
    if isinstance(selected, dict):
        prediction_image = selected["name"]
    else:
        prediction_image = selected[0]["name"]
def predict(text):
    """Zero-shot classify the currently selected image.

    *text* is a comma-separated list of candidate labels.  Returns a gradio
    update mapping each label to its softmax probability for the image stored
    in the module-level ``prediction_image``.
    """
    labels = [label.strip() for label in text.split(",")]
    image = read_image(prediction_image)
    model_inputs = clip_processor(
        text=labels,
        images=image,
        return_tensors="pt",
        padding=True,
    )
    model_outputs = clip_model(**model_inputs)
    # logits_per_image holds the image-text similarity scores; softmax over
    # the label axis turns them into per-label probabilities.
    probabilities = model_outputs.logits_per_image.softmax(dim=1)[0]
    results = dict(zip(labels, (p.item() for p in probabilities)))
    return {output: gr.update(value=results)}
# Build the web UI: two galleries (uploads + samples) feeding a shared
# selection, a text box for candidate labels, and a Label for the results.
with gr.Blocks() as app:
    gr.Markdown("## ERA Session19 - Zero Shot Classification with CLIP")
    # Fixed broken instruction text ("Please an image ..." -> "Please upload an image ...").
    gr.Markdown(
        "Please upload an image or select one of the sample images. Type some classification labels separated by comma. For ex: dog, cat"
    )
    with gr.Row():
        with gr.Column():
            with gr.Box():
                with gr.Group():
                    # Gallery populated by the upload button below.
                    upload_gallery = gr.Gallery(
                        value=None,
                        label="Uploaded images",
                        show_label=False,
                        elem_id="gallery_upload",
                        columns=5,
                        rows=2,
                        height="auto",
                        object_fit="contain",
                    )
                    upload_button = gr.UploadButton(
                        "Click to Upload images",
                        file_types=["image"],
                        file_count="multiple",
                    )
                    upload_button.upload(upload_file, upload_button, upload_gallery)

                with gr.Group():
                    # Gallery of bundled demo images.
                    sample_gallery = gr.Gallery(
                        value=sample_images,
                        label="Sample images",
                        show_label=False,
                        elem_id="gallery_sample",
                        columns=3,
                        rows=2,
                        height="auto",
                        object_fit="contain",
                    )

                # Clicking either gallery records the chosen image path
                # (module-level state read later by predict()).
                upload_gallery.select(set_prediction_image, inputs=[upload_gallery])
                sample_gallery.select(set_prediction_image, inputs=[sample_gallery])
            with gr.Box():
                input_text = gr.TextArea(
                    label="Classification Text",
                    placeholder="Please enter comma separated text",
                    interactive=True,
                )
                submit_btn = gr.Button(value="Submit")
        with gr.Column():
            with gr.Box():
                # Classification probabilities rendered as a label widget.
                output = gr.Label(value=None, label="Classification Results")

    submit_btn.click(predict, inputs=[input_text], outputs=[output])

app.launch()