# DR-classifier / app.py
# (Copied from a hosted file view: author "ipd", commit 5f0e5c0 "init", 1.67 kB.)
import gradio as gr
import timm
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from PIL import Image
from torch.utils.mobile_optimizer import optimize_for_mobile
# Load the trained classifier.
# The previous code built a timm "vit_base_patch16_224" with a 5-way Linear head
# and then replaced it with the contents of opt_model.pt. That made the timm
# model (and its pretrained-weights download) dead weight, so it is no longer built.
# NOTE(review): opt_model.pt is assumed to be a TorchScript archive (the original
# called `.jit.load`, and the file imports optimize_for_mobile). A TorchScript
# archive carries its own architecture and weights. If it is instead a plain
# state_dict, it must be loaded into the timm model with load_state_dict. Confirm.
path = "opt_model.pt"
# predict() keeps its input tensors on the CPU, so load the weights there as well.
# This also lets an archive saved on a GPU machine load on a CPU-only host.
model = torch.jit.load(path, map_location="cpu")
model.eval()  # inference mode (e.g. disables dropout)
# Preprocessing pipeline, built once at import time and reused on every call.
# NOTE(review): ColorJitter, RandomAdjustSharpness and RandomSolarize are random.
# The latter two fire with torchvision's default p=0.5, so the same input can
# receive different predictions on repeated calls. They look like training-time
# augmentations; confirm against the training pipeline whether inference should
# apply them at all.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W)
    transforms.ColorJitter(contrast=0.5),  # random contrast factor in [0.5, 1.5]
    transforms.RandomAdjustSharpness(sharpness_factor=0.5),
    transforms.RandomSolarize(threshold=0.75),  # threshold on the [0, 1] tensor scale
    transforms.RandomAutocontrast(p=1),  # p=1: always applied
])


def transform_image(img_sample):
    """Convert an input image into the tensor the classifier consumes.

    Args:
        img_sample: a file path or file-like object (such as the entries of the
            Gradio ``examples`` list), a PIL image, or an ``H x W x C`` numpy
            array.

    Returns:
        A 3-channel ``torch.Tensor`` of shape (3, 224, 224). No batch
        dimension is added; the caller must add one.
    """
    import numpy as np  # local import: numpy is not imported at the top of this file

    if isinstance(img_sample, Image.Image):
        img = img_sample.convert("RGB")
    elif isinstance(img_sample, np.ndarray):
        # Gradio's gr.Image passes a numpy array by default (type="numpy";
        # confirm for the installed version). Image.open cannot read an array.
        img = Image.fromarray(img_sample).convert("RGB")
    else:
        # A path or file-like object. The `with` closes the underlying file;
        # convert() returns an independent, fully loaded copy, so `img` stays
        # usable afterwards. Forcing RGB keeps the channel count at 3 even for
        # RGBA or grayscale files.
        with Image.open(img_sample) as opened:
            img = opened.convert("RGB")
    return _TRANSFORM(img)
# Output labels, in the order of the model's output logits.
_GRADES = ("None", "Mild", "Moderate", "Severe", "Proliferative")


def predict(Image):
    """Classify one image and return a probability for each grade.

    Args:
        Image: the raw input image, in any form ``transform_image`` accepts.
            NOTE: this name shadows the PIL ``Image`` module inside this
            function. It is kept unchanged for callers that pass it by keyword.

    Returns:
        A dict mapping each grade name to its softmax probability (a float),
        as expected by ``gr.Label``.
    """
    # transform_image already returns a torch.Tensor (ToTensor is in its
    # pipeline), so no numpy round-trip is needed here.
    img = transform_image(Image)
    # The model works on batches: add a leading batch dimension of size 1, then
    # take the single row of softmax output below.
    batch = img.unsqueeze(0).float()
    with torch.no_grad():
        grade = torch.softmax(model(batch), dim=1)[0]
    return {cat: value.item() for cat, value in zip(_GRADES, grade)}
# Gradio front end: one RGB image in, a per-grade probability label out.
image = gr.Image(shape=(224, 224), image_mode="RGB")
label = gr.Label(label="Grade")

# Example inputs shipped with the app: examples/0.png ... examples/4.png.
_example_paths = [f"examples/{idx}.png" for idx in range(5)]

demo = gr.Interface(
    fn=predict,
    inputs=image,
    outputs=label,
    examples=_example_paths,
)

if __name__ == "__main__":
    # debug=True keeps the process in the foreground and prints errors.
    demo.launch(debug=True)