# Hugging Face Space: ERA Session19 — Zero-Shot Classification with CLIP (Gradio app)
import gradio as gr
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor
# Paths of the bundled demo images shown in the "Sample images" gallery.
sample_images = [f"./sample_images/{i}.jpg" for i in range(5)]

# Path of the image most recently clicked in either gallery; written by
# set_prediction_image() and read by predict(). None until a selection is made.
prediction_image = None

# Load the CLIP model and its processor once at startup so every prediction
# reuses the same weights instead of reloading them per request.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
def upload_file(files):
    """Return the on-disk paths of files received from the UploadButton.

    Args:
        files: list of temp-file wrappers provided by Gradio; each exposes
            the uploaded file's path via its ``name`` attribute.

    Returns:
        list[str]: file paths, used to populate the upload gallery.
    """
    return [file.name for file in files]
def read_image(path):
    """Open the image at *path* and return it as a PIL ``Image``.

    PIL opens the file lazily; pixel data is only read when the processor
    consumes the image in predict().
    """
    return Image.open(path)
def set_prediction_image(evt: gr.SelectData, gallery):
    """Record which gallery thumbnail was clicked so predict() can use it.

    Args:
        evt: Gradio select event; ``evt.index`` is the position of the
            clicked thumbnail within the gallery.
        gallery: current value of the gallery component the user clicked.

    Side effect: sets the module-level ``prediction_image`` to the selected
    image's file path.
    """
    global prediction_image
    item = gallery[evt.index]
    # Gallery items arrive either as a dict ({"name": path, ...}) or as a
    # (dict, caption) pair — NOTE(review): this matches older Gradio gallery
    # payloads; confirm against the installed Gradio version.
    if isinstance(item, dict):
        prediction_image = item["name"]
    else:
        prediction_image = item[0]["name"]
def predict(text):
    """Zero-shot classify the currently selected image against user labels.

    Args:
        text: comma-separated class labels, e.g. ``"dog, cat"``.

    Returns:
        dict mapping the ``output`` Label component to a ``gr.update`` whose
        value maps each label to its softmax probability.
    """
    text_classes = [label.strip() for label in text.split(",")]
    image = read_image(prediction_image)
    inputs = clip_processor(
        text=text_classes,
        images=image,
        return_tensors="pt",
        padding=True,
    )
    # Inference only — disable autograd to avoid building a gradient graph.
    with torch.no_grad():
        outputs = clip_model(**inputs)
    # logits_per_image holds the image-text similarity scores, one row per image.
    probs = outputs.logits_per_image.softmax(dim=1)[0]
    results = {label: prob.item() for label, prob in zip(text_classes, probs)}
    return {output: gr.update(value=results)}
# Build the UI: two galleries (uploads and samples) on the left, a text area
# for comma-separated labels, and a Label component for results on the right.
with gr.Blocks() as app:
    gr.Markdown("## ERA Session19 - Zero Shot Classification with CLIP")
    gr.Markdown(
        "Please upload an image or select one of the sample images. Type some classification labels separated by comma. For ex: dog, cat"
    )
    with gr.Row():
        with gr.Column():
            with gr.Box():
                with gr.Group():
                    upload_gallery = gr.Gallery(
                        value=None,
                        label="Uploaded images",
                        show_label=False,
                        elem_id="gallery_upload",
                        columns=5,
                        rows=2,
                        height="auto",
                        object_fit="contain",
                    )
                    upload_button = gr.UploadButton(
                        "Click to Upload images",
                        file_types=["image"],
                        file_count="multiple",
                    )
                    upload_button.upload(upload_file, upload_button, upload_gallery)

                with gr.Group():
                    sample_gallery = gr.Gallery(
                        value=sample_images,
                        label="Sample images",
                        show_label=False,
                        elem_id="gallery_sample",
                        columns=3,
                        rows=2,
                        height="auto",
                        object_fit="contain",
                    )

                # Clicking a thumbnail in either gallery records it as the
                # image to classify.
                upload_gallery.select(set_prediction_image, inputs=[upload_gallery])
                sample_gallery.select(set_prediction_image, inputs=[sample_gallery])

            with gr.Box():
                input_text = gr.TextArea(
                    label="Classification Text",
                    placeholder="Please enter comma separated text",
                    interactive=True,
                )
                submit_btn = gr.Button(value="Submit")

        with gr.Column():
            with gr.Box():
                output = gr.Label(value=None, label="Classification Results")

    submit_btn.click(predict, inputs=[input_text], outputs=[output])

app.launch()