Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
from transformers import DeiTFeatureExtractor, DeiTForImageClassification | |
from hugsvision.inference.VisionClassifierInference import VisionClassifierInference | |
from hugsvision.inference.TorchVisionClassifierInference import TorchVisionClassifierInference | |
models_name = [ | |
"VGG16", | |
"DeiT", | |
"ShuffleNetV2", | |
"MobileNetV2", | |
"DenseNet121", | |
] | |
radio = gr.inputs.Radio(models_name, default="DenseNet121", type="value") | |
def predict_image(image, model_name): | |
image = Image.fromarray(np.uint8(image)).convert('RGB') | |
model_path = "./models/" + model_name | |
if model_name == "DeiT": | |
model = VisionClassifierInference( | |
feature_extractor = DeiTFeatureExtractor.from_pretrained(model_path), | |
model = DeiTForImageClassification.from_pretrained(model_path), | |
) | |
else: | |
model = TorchVisionClassifierInference( | |
model_path = model_path | |
) | |
pred = model.predict_image(img=image, return_str=False) | |
for key in pred.keys(): | |
pred[key] = pred[key]/100 | |
return pred | |
id2label = ["akiec", "bcc", "bkl", "df", "mel", "nv", "vasc"] | |
samples = [["images/" + p + ".jpg"] for p in id2label] | |
print(samples) | |
image = gr.inputs.Image(shape=(224, 224), label="Upload Your Image Here") | |
label = gr.outputs.Label(num_top_classes=len(id2label)) | |
interface = gr.Interface( | |
fn=predict_image, | |
inputs=[image,radio], | |
outputs=label, | |
capture_session=True, | |
allow_flagging=False, | |
thumbnail="ressources/thumbnail.png", | |
article=""" | |
<html style="color: white;"> | |
<style type="text/css"> | |
.tg {border-collapse:collapse;border-spacing:0;} | |
.tg td{border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px; | |
overflow:hidden;padding:10px 5px;word-break:normal;} | |
.tg th{border-color:black;border-style:solid;border-width:1px;font-family:Arial, sans-serif;font-size:14px; | |
font-weight:normal;overflow:hidden;padding:10px 5px;word-break:normal;} | |
.tg .tg-v0zy{background-color:#efefef;color:#000000;font-weight:bold;text-align:center;vertical-align:top} | |
.tg .tg-4jb6{background-color:#ffffff;color:#333333;text-align:center;vertical-align:top} | |
</style> | |
<table class="tg"> | |
<thead> | |
<tr> | |
<th class="tg-v0zy">Model</th> | |
<th class="tg-v0zy">Accuracy</th> | |
<th class="tg-v0zy">Size</th> | |
</tr> | |
</thead> | |
<tbody> | |
<tr> | |
<td class="tg-4jb6">VGG16</td> | |
<td class="tg-4jb6">38.27%</td> | |
<td class="tg-4jb6">512.0 MB</td> | |
</tr> | |
<tr> | |
<td class="tg-4jb6">DeiT</td> | |
<td class="tg-4jb6">71.60%</td> | |
<td class="tg-4jb6">327.0 MB</td> | |
</tr> | |
<tr> | |
<td class="tg-4jb6">DenseNet121</td> | |
<td class="tg-4jb6">77.78%</td> | |
<td class="tg-4jb6">27.1 MB</td> | |
</tr> | |
<tr> | |
<td class="tg-4jb6">MobileNetV2</td> | |
<td class="tg-4jb6">75.31%</td> | |
<td class="tg-4jb6">8.77 MB</td> | |
</tr> | |
<tr> | |
<td class="tg-4jb6">ShuffleNetV2</td> | |
<td class="tg-4jb6">76.54%</td> | |
<td class="tg-4jb6">4.99 MB</td> | |
</tr> | |
</tbody> | |
</table> | |
</html> | |
""", | |
theme="darkhuggingface", | |
title="HAM10000: Training and using a TorchVision Image Classifier in 5 min to identify skin cancer", | |
description="A fast and easy tutorial to train a TorchVision Image Classifier that can help dermatologist in their identification procedures Melanoma cases with HugsVision and HAM10000 dataset.", | |
allow_screenshot=True, | |
show_tips=False, | |
encrypt=True, | |
examples=samples, | |
) | |
interface.launch() |