Rounak28's picture
jyfgyjsgf
f20e7b8
raw
history blame
853 Bytes
import gradio as gr
import torch
from torch.nn import Softmax
from torchvision.io import read_image
from torchvision.transforms.functional import rgb_to_grayscale
from model import NeuralNet
# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the network on the chosen device and load the trained weights.
# `map_location` remaps tensors that were saved from a CUDA device, so a
# GPU-trained checkpoint still loads on a CPU-only host.
model = NeuralNet().to(device)
model.load_state_dict(torch.load("models/model.pth", map_location=device))
# Put the network in inference mode so layers such as dropout / batch-norm
# (if NeuralNet uses any) do not behave as they would during training.
model.eval()

# Human-readable class names; `predict` pairs entry i with output index i.
labels = [
    "zero",
    "one",
    "two",
    "three",
    "four",
    "five",
    "six",
    "seven",
    "eight",
    "nine",
]
def predict(image):
    """Classify one image file and return per-class probabilities.

    Args:
        image: Path to the image file. Gradio passes a path here because the
            input component is declared with ``type="filepath"``.

    Returns:
        dict mapping each name in ``labels`` to its softmax probability.
    """
    img = read_image(image)
    # rgb_to_grayscale only accepts 1- or 3-channel tensors; drop a trailing
    # alpha channel (gray+alpha or RGBA) so such uploads don't crash.
    if img.shape[0] in (2, 4):
        img = img[:-1]
    img = img.to(device)
    # Grayscale, then scale uint8 [0, 255] to float [0, 1].
    img = rgb_to_grayscale(img) / 255

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        # NOTE(review): the (1, H, W) tensor is passed without an explicit
        # batch dimension and `[0]` takes the first output row; this assumes
        # NeuralNet treats the leading dim as the batch — confirm.
        logits = model(img)[0]

    probs = torch.softmax(logits, dim=0).tolist()
    return dict(zip(labels, probs))
# Wire `predict` into a Gradio web UI. The image input hands the upload to
# `predict` as a file path (type="filepath"). The Label output displays the
# returned label -> probability mapping, listing every class.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(
        type="filepath",
        label="Upload",
        # Gradio 3.x `shape` option; presumably resizes uploads to 28x28 to
        # match the model's input size — confirm against NeuralNet.
        shape=(28, 28),
    ),
    outputs=gr.Label(num_top_classes=len(labels)),
)
# Start the web app.
iface.launch()