import gradio as gr import numpy as np from PIL import Image from transformers import DeiTFeatureExtractor, DeiTForImageClassification from hugsvision.inference.VisionClassifierInference import VisionClassifierInference from hugsvision.inference.TorchVisionClassifierInference import TorchVisionClassifierInference models_name = [ "VGG16", "DeiT", "ShuffleNetV2", "MobileNetV2", "DenseNet121", ] radio = gr.inputs.Radio(models_name, default="DenseNet121", type="value") def predict_image(image, model_name): image = Image.fromarray(np.uint8(image)).convert('RGB') model_path = "./models/" + model_name if model_name == "DeiT": model = VisionClassifierInference( feature_extractor = DeiTFeatureExtractor.from_pretrained(model_path), model = DeiTForImageClassification.from_pretrained(model_path), ) else: model = TorchVisionClassifierInference( model_path = model_path ) pred = model.predict_image(img=image, return_str=False) for key in pred.keys(): pred[key] = pred[key]/100 return pred id2label = ["akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"] samples = [["images/" + p + ".jpg"] for p in id2label] print(samples) image = gr.inputs.Image(shape=(224, 224), label="Upload Your Image Here") label = gr.outputs.Label(num_top_classes=len(id2label)) interface = gr.Interface( fn=predict_image, inputs=[image,radio], outputs=label, capture_session=True, allow_flagging=False, thumbnail="ressources/thumbnail.png", article="""
Model | Accuracy | Size |
---|---|---|
VGG16 | 38.27% | 512.0 MB |
DeiT | 71.60% | 327.0 MB |
DenseNet121 | 77.78% | 27.1 MB |
MobileNetV2 | 75.31% | 8.77 MB |
ShuffleNetV2 | 76.54% | 4.99 MB |