import gradio as gr
import timm
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from PIL import Image
from torch.utils.mobile_optimizer import optimize_for_mobile

# Display labels, index-aligned with the classifier's output logits.
CATEGORIES = [
    "0 - Normal",
    "1 - Mild",
    "2 - Moderate",
    "3 - Severe",
    "4 - Proliferative",
]

path = "epoch_4_Resnet50-0.5contrast.pth"

# ResNet-50 backbone with a replaced classification head (one output per label).
# pretrained=False: load_state_dict below (strict by default) overwrites every
# weight, so fetching the ImageNet weights first is wasted work and needs network.
model = timm.create_model("resnet50", pretrained=False)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=len(CATEGORIES))
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
model.load_state_dict(torch.load(path, map_location="cpu"))
model.eval()

# Deterministic inference preprocessing, built once instead of on every call.
# The former ColorJitter(contrast=0.5) is a random training-time augmentation;
# applying it here made repeated predictions on the same image disagree.
# NOTE(review): no mean/std normalization is applied (the original did not
# apply one either) — confirm this matches the training pipeline.
_INFERENCE_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),  # fixed input size expected by predict()
    transforms.ToTensor(),          # PIL -> float tensor (C, H, W)
])


def transform_image(img_sample):
    """Convert a PIL image into a (3, 224, 224) float tensor for the model.

    The image is converted to RGB first so that grayscale, RGBA or palette
    uploads still yield the 3 channels the network expects.
    """
    return _INFERENCE_TRANSFORM(img_sample.convert("RGB"))


def predict(Image):
    """Return the softmax probability of each label for one image.

    Args:
        Image: PIL image supplied by the Gradio input component. (The name
            shadows the ``PIL.Image`` module inside this function; it is kept
            unchanged for compatibility with keyword callers.)

    Returns:
        dict mapping each entry of ``CATEGORIES`` to a float probability —
        the format ``gr.Label`` renders.
    """
    # Preprocess exactly once and add the batch dimension -> (1, 3, 224, 224).
    batch = transform_image(Image).unsqueeze(0)
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)[0]
    return dict(zip(CATEGORIES, probs.tolist()))


image = gr.Image(type="pil")
label = gr.Label(label="Level")

demo = gr.Interface(
    fn=predict,
    inputs=image,
    outputs=label,
    examples=["0.png", "2.png", "4.png"],
)

if __name__ == "__main__":
    demo.launch(debug=True)