ClassCat committed
Commit 8187330
1 Parent(s): a08923c

add app.py

Files changed (1)
  1. app.py +182 -0
app.py ADDED
@@ -0,0 +1,182 @@
+
+ # common
+ import os, sys
+ import math
+ #import numpy as np
+
+ #from random import randrange
+
+ # torch
+ import torch
+ from torch import nn
+ #from torch import einsum
+
+ import torch.nn.functional as F
+
+ #from torch import optim
+ #from torch.optim import lr_scheduler
+
+ #from torch.utils.data import DataLoader
+ #from torch.utils.data.sampler import SubsetRandomSampler
+
+ # torchvision
+ import torchvision
+ from torchvision import transforms
+ #from torchvision import models
+ #from torchvision.datasets import CIFAR10, CIFAR100
+
+ # torchinfo
+ #from torchinfo import summary
+
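+ # WideBasic: a pre-activation residual block (BN -> ReLU -> Conv, twice) with
+ # dropout between the two 3x3 convolutions, following Zagoruyko & Komodakis,
+ # "Wide Residual Networks" (BMVC 2016). A 1x1 convolution on the shortcut
+ # matches shapes whenever the channel count or stride changes.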
+ # Define model
+ class WideBasic(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=1):
+         super().__init__()
+         self.residual = nn.Sequential(
+             nn.BatchNorm2d(in_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(
+                 in_channels,
+                 out_channels,
+                 kernel_size=3,
+                 stride=stride,
+                 padding=1
+             ),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True),
+             nn.Dropout(),
+             nn.Conv2d(
+                 out_channels,
+                 out_channels,
+                 kernel_size=3,
+                 stride=1,
+                 padding=1
+             )
+         )
+
+         self.shortcut = nn.Sequential()
+
+         if in_channels != out_channels or stride != 1:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_channels, out_channels, 1, stride=stride)
+             )
+
+     def forward(self, x):
+         residual = self.residual(x)
+         shortcut = self.shortcut(x)
+
+         return residual + shortcut
+
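+ # WideResNet: depth must satisfy depth = 6*n + 4 (n blocks per stage), and
+ # widen_factor k scales the three stage widths to 16k / 32k / 64k channels.
+ # conv3 and conv4 downsample with stride 2; global average pooling feeds a
+ # single linear classifier.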
+ class WideResNet(nn.Module):
+     def __init__(self, num_classes, block, depth=50, widen_factor=1):
+         super().__init__()
+
+         self.depth = depth
+         k = widen_factor
+         l = int((depth - 4) / 6)
+         self.in_channels = 16
+         self.init_conv = nn.Conv2d(3, self.in_channels, 3, 1, padding=1)
+         self.conv2 = self._make_layer(block, 16 * k, l, 1)
+         self.conv3 = self._make_layer(block, 32 * k, l, 2)
+         self.conv4 = self._make_layer(block, 64 * k, l, 2)
+         self.bn = nn.BatchNorm2d(64 * k)
+         self.relu = nn.ReLU(inplace=True)
+         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
+         self.linear = nn.Linear(64 * k, num_classes)
+
+     def forward(self, x):
+         x = self.init_conv(x)
+         x = self.conv2(x)
+         x = self.conv3(x)
+         x = self.conv4(x)
+         x = self.bn(x)
+         x = self.relu(x)
+         x = self.avg_pool(x)
+         x = x.view(x.size(0), -1)
+         x = self.linear(x)
+
+         return x
+
+     def _make_layer(self, block, out_channels, num_blocks, stride):
+         strides = [stride] + [1] * (num_blocks - 1)
+         layers = []
+         for stride in strides:
+             layers.append(block(self.in_channels, out_channels, stride))
+             self.in_channels = out_channels
+
+         return nn.Sequential(*layers)
+
+
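+ # WRN-40-10: depth=40, widen_factor=10, 10 output classes. The checkpoint is
+ # mapped onto the CPU, so no GPU is needed to serve the demo.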
+ model = WideResNet(10, WideBasic, depth=40, widen_factor=10)
+ model.load_state_dict(
+     torch.load("weights/cifar10_wide_resnet_model.pt",
+                map_location=torch.device('cpu'))
+ )
+
+ model.eval()
+
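+ # Gradio front end: gather the bundled example images and define the
+ # inference-time preprocessing pipeline.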
+ import gradio as gr
+
+ import glob
+
+ examples_dir = './examples'
+ example_files = glob.glob(os.path.join(examples_dir, '*.png'))
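+
+ # Preprocessing matches training: ToTensor scales pixels to [0, 1], then each
+ # channel is normalized with the standard CIFAR-10 mean/std statistics.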
126
+
127
+ normalize = transforms.Normalize(
128
+ mean=[0.4914, 0.4822, 0.4465],
129
+ std=[0.2470, 0.2435, 0.2616],
130
+ )
131
+
132
+ transform = transforms.Compose([
133
+ transforms.ToTensor(),
134
+ normalize,
135
+ ])
136
+
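+ # CIFAR-10 class names, in dataset label order (0 = airplane ... 9 = truck).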
+ classes = [
+     "airplane",
+     "automobile",
+     "bird",
+     "cat",
+     "deer",
+     "dog",
+     "frog",
+     "horse",
+     "ship",
+     "truck",
+ ]
+
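+ # predict: preprocess the PIL image, add a batch dimension, run the model
+ # under torch.no_grad(), and return a {class name: probability} dict in the
+ # format gr.Label expects.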
+ def predict(image):
+     tsr_image = transform(image).unsqueeze(dim=0)
+
+     model.eval()
+     with torch.no_grad():
+         pred = model(tsr_image)
+         prob = torch.nn.functional.softmax(pred[0], dim=0)
+
+     confidences = {classes[i]: float(prob[i]) for i in range(10)}
+     return confidences
+
+
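+ # A quick local sanity check might look like this (the example path is
+ # hypothetical; any 32x32 RGB image will do):
+ #
+ #   from PIL import Image
+ #   print(predict(Image.open('examples/sample01.png').convert('RGB')))
+
+ # Blocks UI: an image input, a Label showing the top-3 probabilities, an
+ # "Infer" button, and the PNGs under ./examples as one-click examples.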
+ with gr.Blocks(css=".gradio-container {background:honeydew;}", title="WideResNet - CIFAR10 Classification"
+               ) as demo:
+     gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">WideResNet - CIFAR10 Classification</div>""")
+
+     with gr.Row():
+         input_image = gr.Image(type="pil", image_mode="RGB", shape=(32, 32))
+
+         output_label = gr.Label(label="Probabilities", num_top_classes=3)
+
+     send_btn = gr.Button("Infer")
+
+     gr.Examples(example_files, inputs=input_image)
+     #gr.Examples(['examples/sample02.png', 'examples/sample04.png'], inputs=input_image2)
+
+     send_btn.click(fn=predict, inputs=input_image, outputs=output_label)
+
+ # demo.queue(concurrency_count=3)
+ demo.launch()
+
+
+ ### EOF ###