import torch
from torch import nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from pathlib import Path
import gradio as gr
from huggingface_hub import hf_hub_download

# Class names, one per line, indexed to match the model's output logits.
LABELS = Path("classes.txt").read_text().splitlines()
num_classes = len(LABELS)

# Small CNN for 28x28 grayscale sketches: three conv/pool stages, then an MLP head.
# After three 2x2 max-pools the 28x28 input becomes 3x3x256 = 2304 features.
model = nn.Sequential(
    nn.Conv2d(1, 64, 3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(64, 128, 3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(128, 256, 3, padding="same"),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(2304, 512),
    nn.ReLU(),
    nn.Linear(512, num_classes),
)

# Download the pretrained weights from the Hugging Face Hub and load them on CPU.
model_path = hf_hub_download(repo_id="jerilseb/quickdraw-small", filename="model.pth")
state_dict = torch.load(model_path, map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Preprocessing: resize the sketch to 28x28, convert to a tensor, and normalize to [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)


def predict(image):
    # The ImageEditor passes a dict; "composite" is the flattened drawing as a PIL image.
    image = image["composite"]
    tensor = transform(image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        out = model(tensor)
    probabilities = F.softmax(out[0], dim=0)
    # Return the top-5 classes as a {label: probability} dict for the Label output.
    values, indices = torch.topk(probabilities, 5)
    return {LABELS[i]: v.item() for i, v in zip(indices, values)}


# Drawing canvas: a single grayscale layer drawn with a white brush.
inputs = gr.ImageEditor(
    type="pil",
    height=720,
    width=720,
    layers=False,
    image_mode="L",
    brush=gr.Brush(default_color="white", default_size=20),
    sources=[],
    label="Draw a shape",
)

# live=True re-runs predict on every edit instead of waiting for a submit click.
demo = gr.Interface(predict, inputs=inputs, outputs="label", live=True)
demo.launch(debug=True)