# Hugging Face Spaces status at time of capture: Runtime error
import gradio as gr | |
import timm | |
import torch | |
import torch.nn as nn | |
from torchvision import datasets, transforms | |
from PIL import Image | |
from torch.utils.mobile_optimizer import optimize_for_mobile | |
# Build a ResNet-50 and replace its classifier head with a 5-way linear layer
# (one output per label in the category list used by `predict`).
# pretrained=False: every parameter is overwritten by the fine-tuned
# checkpoint below (load_state_dict is strict by default), so downloading
# ImageNet weights first would only cost time and require network access.
model = timm.create_model("resnet50", pretrained=False)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=5)

path = "epoch_4_Resnet50-0.5contrast.pth"
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only host; without it torch.load tries to restore tensors onto the
# device they were saved from and fails if that device is absent.
model.load_state_dict(torch.load(path, map_location="cpu"))
# Inference only: put dropout/batch-norm layers into evaluation mode.
model.eval()
# Deterministic preprocessing pipeline, built once and reused on every call.
# NOTE: the original pipeline also applied transforms.ColorJitter(contrast=0.5).
# That transform samples a *random* contrast factor each call, which made
# predictions for the same image vary between requests. The checkpoint name
# ("0.5contrast") suggests it was a training-time augmentation, so it is not
# applied at inference.
_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),  # resize to the 224x224 model input size
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W) in [0, 1]
])


def transform_image(img_sample):
    """Convert a PIL image into a (3, 224, 224) float tensor for the model.

    Args:
        img_sample: A PIL image in any mode.

    Returns:
        A torch tensor of shape (3, 224, 224) with values in [0, 1].
    """
    # Force three channels: PNGs may be RGBA or grayscale, which would
    # otherwise yield a 4- or 1-channel tensor that the model cannot accept.
    return _TRANSFORM(img_sample.convert("RGB"))
# Display labels, in the same order as the model's five output logits.
CATEGORIES = [
    "0 - Normal",
    "1 - Mild",
    "2 - Moderate",
    "3 - Severe",
    "4 - Proliferative",
]


def predict(Image):
    """Classify one image and return per-label probabilities.

    Args:
        Image: PIL image supplied by the Gradio ``gr.Image(type="pil")`` input.
            The parameter name shadows the ``PIL.Image`` import, but it is
            kept unchanged so that existing callers keep working.

    Returns:
        dict mapping each label in ``CATEGORIES`` to its softmax probability
        (a Python float), the format ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was provided.
    """
    if Image is None:
        # Gradio passes None when the form is submitted without an image;
        # surface a readable message instead of a transform traceback.
        raise gr.Error("Please upload an image.")

    # Add the batch dimension: (3, 224, 224) -> (1, 3, 224, 224).
    batch = transform_image(Image).unsqueeze(0)
    with torch.no_grad():
        probs = torch.softmax(model(batch.float()), dim=1)[0]
    return {name: p.item() for name, p in zip(CATEGORIES, probs)}
# Sample images offered as clickable examples in the UI.
_EXAMPLE_FILES = ["0.png", "2.png", "4.png"]

# Input component: the uploaded image is handed to `predict` as a PIL image.
image = gr.Image(type="pil")
# Output component: renders the {label: probability} dict returned by `predict`.
label = gr.Label(label="Level")

# Wire the input, the prediction function and the output into one Gradio app.
demo = gr.Interface(fn=predict, inputs=image, outputs=label, examples=_EXAMPLE_FILES)
def _main():
    """Start the Gradio server for `demo` with debug output enabled."""
    demo.launch(debug=True)


if __name__ == "__main__":
    _main()