import gradio as gr
import timm
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from PIL import Image
from torch.utils.mobile_optimizer import optimize_for_mobile
# Fine-tuned ViT-B/16 with a 5-class grading head
model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.head = torch.nn.Linear(in_features=model.head.in_features, out_features=5)

# Load the exported TorchScript weights (this replaces the freshly initialized model above)
path = "opt_model.pt"
model = torch.jit.load(path)
model.eval()
def transform_image(img_sample):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),                            # Resize to 224x224
        transforms.ToTensor(),                                    # Convert PIL image to tensor
        transforms.ColorJitter(contrast=0.5),                     # Contrast
        transforms.RandomAdjustSharpness(sharpness_factor=0.5),
        transforms.RandomSolarize(threshold=0.75),
        transforms.RandomAutocontrast(p=1),
    ])
    img = Image.open(img_sample).convert("RGB")  # ensure 3-channel input
    transformed_img = transform(img)
    return transformed_img
def predict(image_path):
    # Gradio passes the uploaded image as a file path (type="filepath" below)
    transformed_img = transform_image(image_path)
    img = transformed_img.unsqueeze(0)  # ToTensor already returns a tensor; just add a batch dimension
    with torch.no_grad():
        grade = torch.softmax(model(img.float()), dim=1)[0]
    category = ["None", "Mild", "Moderate", "Severe", "Proliferative"]
    output_dict = {}
    for cat, value in zip(category, grade):
        output_dict[cat] = value.item()
    return output_dict
# Resizing to 224x224 happens inside transform_image; type="filepath" lets it open the file with PIL
image = gr.Image(image_mode="RGB", type="filepath")
label = gr.Label(label="Grade")

demo = gr.Interface(
    fn=predict,
    inputs=image,
    outputs=label,
    examples=["examples/0.png", "examples/1.png", "examples/2.png", "examples/3.png", "examples/4.png"],
)
if __name__ == "__main__":
    demo.launch(debug=True)
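
# Note: optimize_for_mobile is imported above but never used in this app, which suggests
# "opt_model.pt" was produced offline by tracing the fine-tuned ViT and optimizing it for
# mobile. The sketch below is an assumption, not the author's confirmed export script;
# "finetuned_weights.pth" is a hypothetical checkpoint name.
#
#     model = timm.create_model('vit_base_patch16_224', pretrained=True)
#     model.head = torch.nn.Linear(model.head.in_features, 5)
#     model.load_state_dict(torch.load("finetuned_weights.pth"))
#     model.eval()
#     example_input = torch.rand(1, 3, 224, 224)
#     traced = torch.jit.trace(model, example_input)
#     optimized = optimize_for_mobile(traced)
#     optimized.save("opt_model.pt")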