import torch
import torch.nn as nn
import numpy as np
from torchvision import models, transforms
import copy
import pickle
from PIL import Image
import gdown
import urllib.request
import gradio as gr

# Download the pickled class names and the fine-tuned model weights.
url = 'https://drive.google.com/uc?id=1VMLpE5ojF9fq0GtBKaqcMVWUIfJUfKbc'
path_class_names = "./class_names_restnet_catsVSdogs.pkl"
gdown.download(url, path_class_names, quiet=False)

url = 'https://drive.google.com/uc?id=1jorQB1mpPCLH097M8paxut3v5XwVlKqp'
path_model = "./model_state_restnet_catsVSdogs.pth"
gdown.download(url, path_model, quiet=False)

# Download example images for the Gradio interface.
url = "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg"
path_input = "./cat.jpg"
urllib.request.urlretrieve(url, filename=path_input)

url = "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg"
path_input = "./dog.jpg"
urllib.request.urlretrieve(url, filename=path_input)

# Validation-time preprocessing: resize, center-crop, and normalize with the
# ImageNet statistics the ResNet backbone was trained with.
data_transforms_val = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

class_names = pickle.load(open(path_class_names, "rb"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet-18 backbone with its final fully connected layer replaced by a
# head sized to the number of classes, then loaded with the fine-tuned weights.
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, len(class_names))
model_ft = model_ft.to(device)
model_ft.load_state_dict(torch.load(path_model, map_location=device))


def do_inference(img):
    img_t = data_transforms_val(img)
    batch_t = torch.unsqueeze(img_t, 0)
    model_ft.eval()
    # We don't need gradients for inference, so wrap in no_grad to save memory.
    with torch.no_grad():
        batch_t = batch_t.to(device)
        # Forward pass
        output = model_ft(batch_t)
        # Class probabilities, sorted from most to least likely
        probs = torch.nn.functional.softmax(output, dim=1)
        output = torch.argsort(probs, dim=1, descending=True).cpu().numpy()[0].astype(int)
        probs = probs.cpu().numpy()[0]
        probs = probs[output]
        labels = np.array(class_names)[output]
    return {labels[i]: round(float(probs[i]), 2) for i in range(len(labels))}


im = gr.inputs.Image(shape=(512, 512), image_mode='RGB', invert_colors=False, source="upload", type="pil")

title = "CatsVsDogs Classifier"
description = "Playground: Inference of Object Classification (Binary) using a ResNet model and the CatsVsDogs dataset. Tools: PyTorch, Gradio."
examples = [['./cat.jpg'], ['./dog.jpg']]

iface = gr.Interface(
    do_inference,
    im,
    gr.outputs.Label(num_top_classes=2),
    live=False,
    interpretation="default",
    title=title,
    description=description,
    examples=examples,
)

iface.test_launch()
iface.launch(share=True)
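
# Optional smoke test (a minimal sketch, not part of the original script):
# calls do_inference() once on the downloaded example image to confirm the
# model and preprocessing work before relying on the Gradio UI. Assumes
# ./cat.jpg was fetched successfully above; uncomment to run it manually.
#
# sample = Image.open("./cat.jpg").convert("RGB")
# print(do_inference(sample))  # e.g. {'cats': 0.98, 'dogs': 0.02}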