Jeril Sebastian committed on
Commit 0b7cb96
1 Parent(s): 2f156d5

some fixes
Files changed (4)
  1. .gitignore +3 -0
  2. app.py +29 -23
  3. model.pth +0 -3
  4. requirements.txt +3 -2
.gitignore ADDED
@@ -0,0 +1,3 @@
+*.pth
+.venv
+__pycache__
app.py CHANGED
@@ -4,21 +4,19 @@ import torchvision.transforms as transforms
 import torch.nn.functional as F
 from pathlib import Path
 import gradio as gr
-from PIL import Image
-import numpy as np
+from huggingface_hub import hf_hub_download
 
-
-LABELS = Path('classes.txt').read_text().splitlines()
+LABELS = Path("classes.txt").read_text().splitlines()
 num_classes = len(LABELS)
 
 model = nn.Sequential(
-    nn.Conv2d(1, 64, 3, padding='same'),
+    nn.Conv2d(1, 64, 3, padding="same"),
     nn.ReLU(),
     nn.MaxPool2d(2),
-    nn.Conv2d(64, 128, 3, padding='same'),
+    nn.Conv2d(64, 128, 3, padding="same"),
     nn.ReLU(),
     nn.MaxPool2d(2),
-    nn.Conv2d(128, 256, 3, padding='same'),
+    nn.Conv2d(128, 256, 3, padding="same"),
     nn.ReLU(),
     nn.MaxPool2d(2),
     nn.Flatten(),
@@ -27,32 +25,40 @@ model = nn.Sequential(
     nn.Linear(512, num_classes),
 )
 
-state_dict = torch.load('model.pth', map_location='cpu')
-model.load_state_dict(state_dict, strict=False)
+model_path = hf_hub_download(repo_id="jerilseb/quickdraw-small", filename="model.pth")
+state_dict = torch.load(model_path, map_location="cpu")
+model.load_state_dict(state_dict)
 model.eval()
 
-transform = transforms.Compose([
-    transforms.Resize((28, 28)),
-    transforms.ToTensor(),
-    transforms.Normalize((0.5,), (0.5,))
-])
 
+transform = transforms.Compose(
+    [
+        transforms.Resize((28, 28)),
+        transforms.ToTensor(),
+        transforms.Normalize((0.5,), (0.5,)),
+    ]
+)
 
 def predict(image):
     image = image['composite']
-    image = Image.fromarray(image).convert('L')
-    input = transform(image).unsqueeze(0)
-
+    tensor = transform(image).unsqueeze(0)
     with torch.no_grad():
-        out = model(input)
+        out = model(tensor)
 
-    print(out.shape)
     probabilities = F.softmax(out[0], dim=0)
     values, indices = torch.topk(probabilities, 5)
-    print(values, indices)
-
    return {LABELS[i]: v.item() for i, v in zip(indices, values)}
 
 
-interface = gr.Interface(predict, inputs='sketchpad', outputs='label', live=True)
-interface.launch(debug=True)
+inputs = gr.ImageEditor(
+    type="pil",
+    height=720,
+    width=720,
+    layers=False,
+    image_mode="L",
+    brush=gr.Brush(default_color="white", default_size=20),
+    sources=[],
+    label="Draw a shape",
+)
+demo = gr.Interface(predict, inputs=inputs, outputs="label", live=True)
+demo.launch(debug=True)
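
Note (not part of the commit itself): the substance of the app.py change is that the checkpoint no longer ships inside the Space repo (model.pth is deleted and *.pth is now ignored); it is fetched from the jerilseb/quickdraw-small model repo at startup via hf_hub_download. A minimal standalone sketch of that loading path, assuming network access and that the model repo stays reachable; the shape printout is only for illustration:

# Sketch of the new checkpoint-loading path (assumption: the
# jerilseb/quickdraw-small repo referenced in app.py is public and reachable).
import torch
from huggingface_hub import hf_hub_download

model_path = hf_hub_download(repo_id="jerilseb/quickdraw-small", filename="model.pth")
state_dict = torch.load(model_path, map_location="cpu")

# List the parameter names and shapes the nn.Sequential model in app.py expects to load.
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))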
model.pth DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:217879299736265e793d5c72f21acc9b8646fa926f51fb3c31c46400e6bfe32c
-size 6910492
requirements.txt CHANGED
@@ -1,3 +1,4 @@
-torch
+torch==2.3.1
 torchvision
-Pillow
+gradio
+huggingface_hub