Spaces:
Sleeping
Sleeping
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification
# Hugging Face Hub repository ids offered in the model dropdown. All of them
# live under the "0-ma" namespace; the first entry is the dropdown default.
model_names = [
    "0-ma/" + repo_suffix
    for repo_suffix in (
        "beit-geometric-shapes-base",
        "vit-geometric-shapes-tiny",
        "vit-geometric-shapes-base",
        "swin-geometric-shapes-tiny",
        "mobilenet-v2-geometric-shapes",
        "focalnet-geometric-shapes-tiny",
        "efficientnet-b2-geometric-shapes",
        "mit-b0-geometric-shapes",
        "resnet-geometric-shapes",
    )
]
# Display names for the classifier outputs, in output-index order:
# labels[i] names logit i of every model.
# NOTE(review): "Pentagone"/"Hexagone" are French spellings shown verbatim in
# the UI; presumably intentional or matching the models' training labels —
# confirm before changing.
labels = "None Circle Triangle Square Pentagone Hexagone".split()
# Example images offered under the upload box. Every path in this list is
# passed to gr.Examples as an image input, so:
#  - entries are sorted, because os.listdir returns them in arbitrary order and
#    the gallery order would otherwise change between deployments;
#  - hidden files (e.g. ".gitkeep") and subdirectories are skipped, since they
#    are not images.
# A missing directory still raises FileNotFoundError at startup.
example_dir = "./example"
example_images = [
    os.path.join(example_dir, entry)
    for entry in sorted(os.listdir(example_dir))
    if not entry.startswith(".")
    and os.path.isfile(os.path.join(example_dir, entry))
]
# Load every image processor, then every classification model, once at
# startup, keyed by Hub repo id, so predict can fetch them by dictionary
# lookup. Loading is eager: all of them are held in these module-level dicts.
feature_extractors = dict(
    zip(model_names, map(AutoImageProcessor.from_pretrained, model_names))
)
classification_models = dict(
    zip(
        model_names,
        map(AutoModelForImageClassification.from_pretrained, model_names),
    )
)
def predict(image, selected_model):
    """Classify ``image`` with the model selected in the dropdown.

    Args:
        image: The image from the ``gr.Image(type="pil")`` input, or ``None``
            when the input has been cleared.
        selected_model: A key of ``feature_extractors`` and
            ``classification_models`` (one of ``model_names``).

    Returns:
        ``None`` when there is no image. Otherwise, a dict mapping each label
        whose logit is strictly positive to that logit divided by the sum of
        all positive logits. This is a normalized score, not a softmax
        probability. The dict is empty when no logit is positive.

    Raises:
        KeyError: if ``selected_model`` was never loaded.
        IndexError: if the model emits fewer logits than there are ``labels``.
    """
    if image is None:
        return None

    feature_extractor = feature_extractors[selected_model]
    model = classification_models[selected_model]

    inputs = feature_extractor(images=[image], return_tensors="pt")
    # Inference only: don't build an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(**inputs)["logits"].detach().cpu().numpy()[0]

    # np.clip returns a new array, so ``logits`` keeps its original values.
    # (Assigning then masking in place would alias and overwrite it.)
    positive = np.clip(logits, 0, None)
    total = positive.sum()
    if total <= 0:
        # No positive logit: nothing to report, and dividing would give NaN.
        return {}

    scores = positive / total
    # Index-based access keeps the original strictness: a model with fewer
    # outputs than labels raises IndexError instead of being silently truncated.
    return {
        label: float(scores[index])
        for index, label in enumerate(labels)
        if logits[index] > 0
    }
# ---- User interface ---------------------------------------------------------
title = "Geometric Shape Classifier"
description = "Select a model and upload an image to classify geometric shapes."

# Components are created in this order: header, description, model picker,
# image input, example gallery, result label.
with gr.Blocks() as demo:
    gr.Markdown("# " + title)
    gr.Markdown(description)
    model_dropdown = gr.Dropdown(
        label="Select Model",
        choices=model_names,
        value=model_names[0],
    )
    image_input = gr.Image(type="pil")
    gr.Examples(
        label="Click on an example image to test",
        inputs=image_input,
        examples=example_images,
    )
    output = gr.Label(label="Classification Result")

    # Re-run the classifier whenever either input changes. Both events send
    # the same (image, model) pair to predict and render into the label.
    image_input.change(
        outputs=output,
        inputs=[image_input, model_dropdown],
        fn=predict,
    )
    model_dropdown.change(
        outputs=output,
        inputs=[image_input, model_dropdown],
        fn=predict,
    )

demo.launch()