"""Gradio demo: classify a hand-drawn Thai digit (๐-๙) with a small MLP."""
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
import gradio as gr

# NOTE(review): predict() below does its own grayscale/resize and does not use
# this pipeline. The two differ in operation order and interpolation, so do not
# swap one for the other without re-checking accuracy.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.Grayscale(),
    transforms.ToTensor(),
])

labels = ["๐ (ศูนย์)", "๑ (หนึ่ง)", "๒ (สอง)", "๓ (สาม)", "๔ (สี่)",
          "๕ (ห้า)", "๖ (หก)", "๗ (เจ็ด)", "๘ (แปด)", "๙ (เก้า)"]
LABELS = dict(enumerate(labels))  # class index -> display label


class DropoutThaiDigit(nn.Module):
    """Fully connected classifier for 28x28 single-channel digit images.

    Layer sizes are 784 -> 392 -> 196 -> 98 -> 10. Each hidden layer is
    followed by ReLU and dropout (p=0.1). The attribute names fc1..fc4 and
    dropout must stay as-is: they form the keys of the saved state_dict.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 392)
        self.fc2 = nn.Linear(392, 196)
        self.fc3 = nn.Linear(196, 98)
        self.fc4 = nn.Linear(98, 10)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Return raw logits of shape (N, 10); the input is flattened to (N, 784)."""
        x = x.view(-1, 28 * 28)
        # The three hidden stages are identical: linear -> ReLU -> dropout.
        for layer in (self.fc1, self.fc2, self.fc3):
            x = self.dropout(F.relu(layer(x)))
        return self.fc4(x)


model = DropoutThaiDigit()
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only host.
# weights_only=True restricts unpickling to tensors/plain containers, which is
# all a state_dict needs.
model.load_state_dict(
    torch.load("thai_digit_net.pth", map_location="cpu", weights_only=True)
)
model.eval()  # put the Dropout layers into inference mode


def predict(img):
    """Classify a sketch and return the top 5 predictions.

    Args:
        img: value from the Sketchpad component, read as a dict whose
            "composite" entry holds the drawing as an image array. A missing
            value, or a missing "composite", is treated as "no input".

    Returns:
        A dict {label: confidence, ...} with softmax probabilities for the
        5 most likely classes. Returns {"No input sketch": 0.0} when there
        is no input or the canvas is entirely zero.
    """
    # Guard first: without this, a missing value fell through to
    # Image.fromarray(None) (or img.get on None) and raised.
    if not img or img.get("composite") is None:
        return {"No input sketch": 0.0}

    img_data = np.asarray(img["composite"])
    if img_data.sum() == 0:  # nothing drawn yet
        return {"No input sketch": 0.0}

    img_gray = Image.fromarray(img_data).convert('L').resize((28, 28))
    img_tensor = transforms.ToTensor()(img_gray).unsqueeze(0)  # (1, 1, 28, 28)

    with torch.no_grad():
        probs = model(img_tensor).softmax(dim=1).squeeze()

    top_probs, top_indices = torch.topk(probs, 5)
    return {
        LABELS[i]: float(p)
        for i, p in zip(top_indices.tolist(), top_probs.tolist())
    }


# Client-side script: reload the page with ?__theme=dark unless it is already set.
js_func = """
function refresh() {
    const url = new URL(window.location);
    if (url.searchParams.get('__theme') !== 'dark') {
        url.searchParams.set('__theme', 'dark');
        window.location.href = url.href;
    }
}
"""

with gr.Blocks(js=js_func) as demo:
    gr.Interface(
        fn=predict,
        inputs=gr.Sketchpad(
            label="Draw Here",
            brush=gr.Brush(default_size=14, default_color="#FFFFFF", colors=["#FFFFFF"]),
            image_mode="L",
            layers=False,
            eraser=None,
            width=400,
            height=350,
        ),
        outputs=gr.Label(label="Guess"),
        title="Thai Digit Handwritten Classification",
        description="ทดลองวาดภาพตัวอักษรเลขไทยลงใน Sketchpad ด้านล่างเพื่อทำนายผลตัวเลข ตั้งแต่ ๐ (ศูนย์) ๑ (หนึ่ง) ๒ (สอง) ๓ (สาม) ๔ (สี่) ๕ (ห้า) ๖ (หก) ๗ (เจ็ด) ๘ (แปด) จนถึง ๙ (เก้า)",
        live=True,
    )

if __name__ == "__main__":
    demo.launch()