Shreyz-max committed
Commit 6672bfb
1 Parent(s): 28fb98f

Add application file

app.py ADDED
@@ -0,0 +1,81 @@
import os.path

import streamlit as st
from streamlit_drawable_canvas import st_canvas
import test
from PIL import Image
import gdown


st.set_page_config(layout="wide")

# Specify canvas parameters in application
drawing_object = st.sidebar.selectbox(
    "Object:", ("sea", "cloud", "bush", "grass", "mountain", "sky", "snow",
                "tree", "flower", "road")
)
drawing_object_dict = {"sea": "rgb(56,79,131)", "cloud": "rgb(239,239,239)",
                       "bush": "rgb(93,110,50)", "grass": "rgb(183,210,78)",
                       "mountain": "rgb(60,59,75)", "snow": "rgb(250,250,250)",
                       "sky": "rgb(117,158,223)", "tree": "rgb(53, 38, 19)",
                       "flower": "rgb(230,112,182)",
                       "road": "rgb(152, 126, 106)"}

stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 3)

stroke_color = drawing_object_dict[drawing_object]


col1, col2 = st.columns(2)
with col1:
    # Create a canvas component with different parameters
    canvas_result = st_canvas(
        fill_color="rgba(255, 165, 0, 0.3)",  # Fixed fill color with some opacity
        stroke_width=stroke_width,
        stroke_color=stroke_color,
        background_color="rgb(117,158,223)",
        background_image=None,
        height=512,
        width=512,
        drawing_mode="freedraw",
        point_display_radius=0,
        key="canvas",
    )
    if canvas_result.image_data is not None:
        pass


@st.cache
def download_model():
    f_checkpoint = os.path.join("latest_net_G.pth")
    if not os.path.exists(f_checkpoint):
        with st.spinner("Downloading model... this may take a while! \n Don't stop it!"):
            url = 'https://drive.google.com/uc?id=15VSa2m2F6Ch0NpewDR7mkKAcXlMgDi5F'
            output = 'latest_net_G.pth'
            gdown.download(url, output, quiet=False)


if st.button('generate'):
    download_model()
    image = Image.fromarray(canvas_result.image_data)
    s = test.semantic(image)
    image = test.evaluate(s)
    image = test.to_image(image)
    with col2:
        st.image(image, clamp=True, width=512)


st.markdown(
    """
    <style>
    [data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
        width: 120px;
    }
    [data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
        width: 500px;
        margin-left: -500px;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
label_colors.py ADDED
@@ -0,0 +1,12 @@
colorMap = [
    {"color": (56, 79, 131), "id": 154, "label": "sea"},
    {"color": (239, 239, 239), "id": 105, "label": "cloud"},
    {"color": (93, 110, 50), "id": 96, "label": "bush"},
    {"color": (183, 210, 78), "id": 123, "label": "grass"},
    {"color": (60, 59, 75), "id": 134, "label": "mountain"},
    {"color": (117, 158, 223), "id": 156, "label": "sky"},
    {"color": (250, 250, 250), "id": 158, "label": "snow"},
    {"color": (53, 38, 19), "id": 168, "label": "tree"},
    {"color": (230, 112, 182), "id": 118, "label": "flower"},
    {"color": (152, 126, 106), "id": 148, "label": "road"}
]
spade/LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2018 Jiayuan MAO

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
spade/README.md ADDED
@@ -0,0 +1,118 @@
# Synchronized-BatchNorm-PyTorch

**IMPORTANT: Please read the "Implementation details and highlights" section before use.**

Synchronized Batch Normalization implementation in PyTorch.

This module differs from the built-in PyTorch BatchNorm in that the mean and
standard-deviation are reduced across all devices during training.

For example, when one uses `nn.DataParallel` to wrap the network during
training, PyTorch's implementation normalizes the tensor on each device using
only the statistics on that device, which speeds up the computation and
is easy to implement, but the statistics might be inaccurate.
Instead, in this synchronized version, the statistics are computed
over all training samples distributed on multiple devices.

Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
as the built-in PyTorch implementation.

This module is currently only a prototype version for research use. As mentioned below,
it has its limitations and may even suffer from some design problems. If you have any
questions or suggestions, please feel free to
[open an issue](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues) or
[submit a pull request](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues).

## Why Synchronized BatchNorm?

Although the typical implementation of BatchNorm working on multiple devices (GPUs)
is fast (with no communication overhead), it inevitably reduces the per-device batch size,
which potentially degrades performance. This is not a significant issue in some
standard vision tasks such as ImageNet classification (as the batch size per device
is usually large enough to obtain good statistics). However, it hurts performance
in tasks where the batch size is usually very small (e.g., 1 per GPU).

For example, the importance of synchronized batch normalization in object detection has recently been demonstrated with
an extensive analysis in the paper [MegDet: A Large Mini-Batch Object Detector](https://arxiv.org/abs/1711.07240).

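As a rough numeric illustration (not from the original repository; the tensor sizes and the two-device split are arbitrary assumptions), the following sketch contrasts per-device statistics with statistics pooled over the whole batch, which is exactly the difference SyncBN removes:

```python
import torch

# A "global" batch of 8 samples that nn.DataParallel would split across 2 GPUs.
batch = torch.randn(8, 10)
device0, device1 = batch[:4], batch[4:]

# Plain BatchNorm: each replica normalizes with its own 4-sample statistics.
mean0, var0 = device0.mean(0), device0.var(0, unbiased=False)
mean1, var1 = device1.mean(0), device1.var(0, unbiased=False)

# Synchronized BatchNorm: statistics are reduced over all 8 samples before normalizing.
mean_sync, var_sync = batch.mean(0), batch.var(0, unbiased=False)

# With small per-device batches, the per-device estimates deviate from the pooled ones.
print((mean0 - mean_sync).abs().max(), (var1 - var_sync).abs().max())
```
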
## Usage

To use the Synchronized Batch Normalization, we add a data parallel replication callback. This introduces a slight
difference from the typical usage of `nn.DataParallel`.

Use it with the provided, customized data parallel wrapper:

```python
from sync_batchnorm import SynchronizedBatchNorm1d, DataParallelWithCallback

sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
```

Or, if you are using a customized data parallel module, you can use this library via monkey patching.

```python
from torch.nn import DataParallel  # or your customized DataParallel module
from sync_batchnorm import SynchronizedBatchNorm1d, patch_replication_callback

sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
patch_replication_callback(sync_bn)  # monkey-patching
```

You can use `convert_model` to easily convert your model to use Synchronized BatchNorm.

```python
import torch.nn as nn
from torchvision import models
from sync_batchnorm import convert_model
# m is a standard pytorch model
m = models.resnet18(True)
m = nn.DataParallel(m)
# after convert, m is using SyncBN
m = convert_model(m)
```

See also `tests/test_sync_batchnorm.py` for a numeric result comparison.

## Implementation details and highlights

If you are interested in how the batch statistics are reduced and broadcast among multiple devices, please take a look
at the code with detailed comments. Here we only emphasize some highlights of the implementation:

- This implementation is in pure Python. No extra C++ extension libraries.
- Easy to use, as demonstrated above.
- It uses unbiased variance to update the moving average, and uses `sqrt(max(var, eps))` instead of `sqrt(var + eps)`.
- The implementation requires that each module on different devices invoke `batchnorm` exactly the SAME
  number of times in each forward pass. For example, you cannot call `batchnorm` on GPU0 but not on GPU1. The `#i
  (i = 1, 2, 3, ...)` calls of `batchnorm` on each device will be viewed as a whole and the statistics will be reduced.
  This is tricky but is a good way to handle PyTorch's dynamic computation graph. Although it sounds complicated, this
  will usually not be an issue for most models.

## Known issues

#### Runtime error on backward pass.

Due to a [PyTorch Bug](https://github.com/pytorch/pytorch/issues/3883), using old PyTorch libraries will trigger a `RuntimeError` with messages like:

```
Assertion `pos >= 0 && pos < buffer.size()` failed.
```

This has already been solved in the newest PyTorch repo, which, unfortunately, has not been pushed to the official or anaconda binary release. Thus, you are required to build the PyTorch package from source following the
instructions [here](https://github.com/pytorch/pytorch#from-source).

#### Numeric error.

Because this library does not fuse the normalization and statistics operations in C++ (nor CUDA), it is less
numerically stable compared to the original PyTorch implementation. A detailed analysis can be found in
`tests/test_sync_batchnorm.py`.

## Authors and License:

Copyright (c) 2018-, [Jiayuan Mao](https://vccy.xyz).

**Contributors**: [Tete Xiao](https://tetexiao.com), [DTennant](https://github.com/DTennant).

Distributed under **MIT License** (See LICENSE)
spade/Synchronized-BatchNorm-PyTorch ADDED
@@ -0,0 +1 @@
Subproject commit dcfae91cbc3767a3c5cd28d46ab78503a22b0fe7
spade/__pycache__/dataset.cpython-310.pyc ADDED
Binary file (1.1 kB)

spade/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (1.1 kB)

spade/__pycache__/generator.cpython-310.pyc ADDED
Binary file (3.33 kB)

spade/__pycache__/generator.cpython-38.pyc ADDED
Binary file (3.31 kB)

spade/__pycache__/model.cpython-310.pyc ADDED
Binary file (3.52 kB)

spade/__pycache__/model.cpython-38.pyc ADDED
Binary file (3.48 kB)

spade/__pycache__/normalizer.cpython-310.pyc ADDED
Binary file (1.47 kB)

spade/__pycache__/normalizer.cpython-38.pyc ADDED
Binary file (1.47 kB)

spade/dataset.py ADDED
@@ -0,0 +1,24 @@
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torchvision.transforms as transforms
from PIL import Image

def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), method)

def get_transform(opt, method=Image.BICUBIC, normalize=True):
    transform_list = []
    transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt['load_size'], method)))
    transform_list += [transforms.ToTensor()]
    if normalize:
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

    return transforms.Compose(transform_list)
spade/generator.py ADDED
@@ -0,0 +1,113 @@
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from spade.normalizer import SPADE

class SPADEGenerator(nn.Module):
    def __init__(self, opt):
        super().__init__()

        # nf: # of gen filters in first conv layer
        nf = 64

        self.sw, self.sh = self.compute_latent_vector_size(opt['crop_size'], opt['aspect_ratio'])

        self.fc = nn.Conv2d(opt['label_nc'], 16 * nf, 3, padding=1)

        self.head_0 = SPADEResnetBlock(opt, 16 * nf, 16 * nf)

        self.G_middle_0 = SPADEResnetBlock(opt, 16 * nf, 16 * nf)
        self.G_middle_1 = SPADEResnetBlock(opt, 16 * nf, 16 * nf)

        self.up_0 = SPADEResnetBlock(opt, 16 * nf, 8 * nf)
        self.up_1 = SPADEResnetBlock(opt, 8 * nf, 4 * nf)
        self.up_2 = SPADEResnetBlock(opt, 4 * nf, 2 * nf)
        self.up_3 = SPADEResnetBlock(opt, 2 * nf, 1 * nf)

        self.conv_img = nn.Conv2d(1 * nf, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    def compute_latent_vector_size(self, crop_size, aspect_ratio):
        num_up_layers = 5

        sw = crop_size // (2**num_up_layers)
        sh = round(sw / aspect_ratio)

        return sw, sh

    def forward(self, seg):
        # we downsample segmap and run convolution
        x = F.interpolate(seg, size=(self.sh, self.sw))
        x = self.fc(x)

        x = self.head_0(x, seg)

        x = self.up(x)
        x = self.G_middle_0(x, seg)
        x = self.G_middle_1(x, seg)

        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)
        x = self.up(x)
        x = self.up_2(x, seg)
        x = self.up(x)
        x = self.up_3(x, seg)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        x = torch.tanh(x)

        return x

import torch.nn.utils.spectral_norm as spectral_norm

# label_nc: the #channels of the input semantic map, hence the input dim of SPADE
# label_nc: also equivalent to the # of input label classes
class SPADEResnetBlock(nn.Module):
    def __init__(self, opt, fin, fout):
        super().__init__()

        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        self.conv_0 = spectral_norm(nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1))
        self.conv_1 = spectral_norm(nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1))
        if self.learned_shortcut:
            self.conv_s = spectral_norm(nn.Conv2d(fin, fout, kernel_size=1, bias=False))

        # define normalization layers
        self.norm_0 = SPADE(opt, fin)
        self.norm_1 = SPADE(opt, fmiddle)
        if self.learned_shortcut:
            self.norm_s = SPADE(opt, fin)

    # note the resnet block with SPADE also takes in |seg|,
    # the semantic segmentation map as input
    def forward(self, x, seg):
        x_s = self.shortcut(x, seg)

        dx = self.conv_0(self.relu(self.norm_0(x, seg)))
        dx = self.conv_1(self.relu(self.norm_1(dx, seg)))

        out = x_s + dx
        return out

    def shortcut(self, x, seg):
        if self.learned_shortcut:
            x_s = self.conv_s(self.norm_s(x, seg))
        else:
            x_s = x
        return x_s

    def relu(self, x):
        return F.leaky_relu(x, 2e-1)
spade/model.py ADDED
@@ -0,0 +1,101 @@
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import os
import torch
from torch.nn import init

from spade.generator import SPADEGenerator


class Pix2PixModel(torch.nn.Module):
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.FloatTensor = torch.cuda.FloatTensor if opt['use_gpu'] \
            else torch.FloatTensor

        self.netG = self.initialize_networks(opt)

    def forward(self, data, mode):
        input_semantics, real_image = self.preprocess_input(data)

        if mode == 'inference':
            with torch.no_grad():
                fake_image = self.generate_fake(input_semantics)
            return fake_image
        else:
            raise ValueError("|mode| is invalid")

    def preprocess_input(self, data):
        data['label'] = data['label'].long()

        # move to GPU and change data types
        if self.opt['use_gpu']:
            data['label'] = data['label'].cuda()
            data['instance'] = data['instance'].cuda()
            data['image'] = data['image'].cuda()

        # create one-hot label map
        label_map = data['label']
        bs, _, h, w = label_map.size()
        input_label = self.FloatTensor(bs, self.opt['label_nc'], h, w).zero_()
        # one whole label map -> to one label map per class
        input_semantics = input_label.scatter_(1, label_map, 1.0)

        return input_semantics, data['image']

    def generate_fake(self, input_semantics):
        fake_image = self.netG(input_semantics)
        return fake_image

    def create_network(self, cls, opt):
        net = cls(opt)
        if self.opt['use_gpu']:
            net.cuda()

        gain = 0.02

        def init_weights(m):
            classname = m.__class__.__name__
            if classname.find('BatchNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    init.normal_(m.weight.data, 1.0, gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
                init.xavier_normal_(m.weight.data, gain=gain)
                if hasattr(m, 'bias') and m.bias is not None:
                    init.constant_(m.bias.data, 0.0)

        # Applies fn recursively to every submodule (as returned by .children()) as well as self
        net.apply(init_weights)

        return net

    def load_network(self, net, label, epoch, opt):
        save_filename = '%s_net_%s.pth' % (epoch, label)
        save_path = os.path.join(save_filename)
        weights = torch.load(save_path)
        net.load_state_dict(weights)
        return net

    def initialize_networks(self, opt):
        netG = self.create_network(SPADEGenerator, opt)

        if not opt['isTrain']:
            netG = self.load_network(netG, 'G', opt['which_epoch'], opt)

        # self.print_network(netG)

        return netG

    def print_network(self, net):
        num_params = 0
        for param in net.parameters():
            num_params += param.numel()
        print('Network [%s] was created. Total number of parameters: %.1f million. '
              % (type(net).__name__, num_params / 1000000))
        print(net)
spade/normalizer.py ADDED
@@ -0,0 +1,49 @@
"""
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from sync_batchnorm.batchnorm import SynchronizedBatchNorm2d

# norm_nc: the #channels of the normalized activations, hence the output dim of SPADE
# label_nc: the #channels of the input semantic map, hence the input dim of SPADE
# label_nc: also equivalent to the # of input label classes
class SPADE(nn.Module):
    def __init__(self, opt, norm_nc):
        super().__init__()

        self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)

        # number of internal filters for generating scale/bias
        nhidden = 128
        # size of kernels
        kernel_size = 3
        # padding size
        padding = kernel_size // 2

        self.mlp_shared = nn.Sequential(
            nn.Conv2d(opt['label_nc'], nhidden, kernel_size=kernel_size, padding=padding),
            nn.ReLU()
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=kernel_size, padding=padding)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=kernel_size, padding=padding)

    def forward(self, x, segmap):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        # resize input segmentation map to match x.size() using nearest interpolation
        # N, C, H, W = x.size()
        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * (1 + gamma) + beta

        return out
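Not part of the commit: a minimal sketch of how the SPADE layer above is applied. The tensor sizes and `label_nc` value are illustrative assumptions (`label_nc` matches the COCO option dict used in test.py); it assumes a PyTorch build where `torch.nn.parallel._functions` is importable, since SynchronizedBatchNorm falls back to plain `F.batch_norm` when not running under data parallel.

import torch
from spade.normalizer import SPADE

opt = {'label_nc': 182}                  # same label count as the options in test.py
norm = SPADE(opt, norm_nc=64)            # normalizes 64-channel activations

x = torch.randn(2, 64, 32, 32)           # generator activations
segmap = torch.zeros(2, 182, 64, 64)     # one-hot semantic map; any spatial size works,
segmap[:, 0] = 1.0                       # forward() resizes it to match x
out = norm(x, segmap)                    # out = normalized * (1 + gamma(segmap)) + beta(segmap)
print(out.shape)                         # torch.Size([2, 64, 32, 32])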
spade/tests/test_numeric_batchnorm.py ADDED
@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# File   : test_numeric_batchnorm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.

import unittest

import torch
import torch.nn as nn
from torch.autograd import Variable

from sync_batchnorm.unittest import TorchTestCase


def handy_var(a, unbias=True):
    n = a.size(0)
    asum = a.sum(dim=0)
    as_sum = (a ** 2).sum(dim=0)  # a square sum
    sumvar = as_sum - asum * asum / n
    if unbias:
        return sumvar / (n - 1)
    else:
        return sumvar / n


class NumericTestCase(TorchTestCase):
    def testNumericBatchNorm(self):
        a = torch.rand(16, 10)
        bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)
        bn.train()

        a_var1 = Variable(a, requires_grad=True)
        b_var1 = bn(a_var1)
        loss1 = b_var1.sum()
        loss1.backward()

        a_var2 = Variable(a, requires_grad=True)
        a_mean2 = a_var2.mean(dim=0, keepdim=True)
        a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
        # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
        b_var2 = (a_var2 - a_mean2) / a_std2
        loss2 = b_var2.sum()
        loss2.backward()

        self.assertTensorClose(bn.running_mean, a.mean(dim=0))
        self.assertTensorClose(bn.running_var, handy_var(a))
        self.assertTensorClose(a_var1.data, a_var2.data)
        self.assertTensorClose(b_var1.data, b_var2.data)
        self.assertTensorClose(a_var1.grad, a_var2.grad)


if __name__ == '__main__':
    unittest.main()
spade/tests/test_numeric_batchnorm_v2.py ADDED
@@ -0,0 +1,62 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : test_numeric_batchnorm_v2.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 11/01/2018
#
# Distributed under terms of the MIT license.

"""
Test the numerical implementation of batch normalization.

Author: acgtyrant.
See also: https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
"""

import unittest

import torch
import torch.nn as nn
import torch.optim as optim

from sync_batchnorm.unittest import TorchTestCase
from sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl


class NumericTestCasev2(TorchTestCase):
    def testNumericBatchNorm(self):
        CHANNELS = 16
        batchnorm1 = nn.BatchNorm2d(CHANNELS, momentum=1)
        optimizer1 = optim.SGD(batchnorm1.parameters(), lr=0.01)

        batchnorm2 = BatchNorm2dReimpl(CHANNELS, momentum=1)
        batchnorm2.weight.data.copy_(batchnorm1.weight.data)
        batchnorm2.bias.data.copy_(batchnorm1.bias.data)
        optimizer2 = optim.SGD(batchnorm2.parameters(), lr=0.01)

        for _ in range(100):
            input_ = torch.rand(16, CHANNELS, 16, 16)

            input1 = input_.clone().requires_grad_(True)
            output1 = batchnorm1(input1)
            output1.sum().backward()
            optimizer1.step()

            input2 = input_.clone().requires_grad_(True)
            output2 = batchnorm2(input2)
            output2.sum().backward()
            optimizer2.step()

        self.assertTensorClose(input1, input2)
        self.assertTensorClose(output1, output2)
        self.assertTensorClose(input1.grad, input2.grad)
        self.assertTensorClose(batchnorm1.weight.grad, batchnorm2.weight.grad)
        self.assertTensorClose(batchnorm1.bias.grad, batchnorm2.bias.grad)
        self.assertTensorClose(batchnorm1.running_mean, batchnorm2.running_mean)
        # compare running variances of the two implementations
        # (the original line compared batchnorm2.running_mean with itself)
        self.assertTensorClose(batchnorm1.running_var, batchnorm2.running_var)


if __name__ == '__main__':
    unittest.main()
spade/tests/test_sync_batchnorm.py ADDED
@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
# File   : test_sync_batchnorm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.

import unittest

import torch
import torch.nn as nn
from torch.autograd import Variable

from sync_batchnorm import set_sbn_eps_mode
from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
from sync_batchnorm.unittest import TorchTestCase

set_sbn_eps_mode('plus')


def handy_var(a, unbias=True):
    n = a.size(0)
    asum = a.sum(dim=0)
    as_sum = (a ** 2).sum(dim=0)  # a square sum
    sumvar = as_sum - asum * asum / n
    if unbias:
        return sumvar / (n - 1)
    else:
        return sumvar / n


def _find_bn(module):
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
            return m


class SyncTestCase(TorchTestCase):
    def _syncParameters(self, bn1, bn2):
        bn1.reset_parameters()
        bn2.reset_parameters()
        if bn1.affine and bn2.affine:
            bn2.weight.data.copy_(bn1.weight.data)
            bn2.bias.data.copy_(bn1.bias.data)

    def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
        """Check the forward and backward for the customized batch normalization."""
        bn1.train(mode=is_train)
        bn2.train(mode=is_train)

        if cuda:
            input = input.cuda()

        self._syncParameters(_find_bn(bn1), _find_bn(bn2))

        input1 = Variable(input, requires_grad=True)
        output1 = bn1(input1)
        output1.sum().backward()
        input2 = Variable(input, requires_grad=True)
        output2 = bn2(input2)
        output2.sum().backward()

        self.assertTensorClose(input1.data, input2.data)
        self.assertTensorClose(output1.data, output2.data)
        self.assertTensorClose(input1.grad, input2.grad)
        self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
        self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)

    def testSyncBatchNormNormalTrain(self):
        bn = nn.BatchNorm1d(10)
        sync_bn = SynchronizedBatchNorm1d(10)

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)

    def testSyncBatchNormNormalEval(self):
        bn = nn.BatchNorm1d(10)
        sync_bn = SynchronizedBatchNorm1d(10)

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)

    def testSyncBatchNormSyncTrain(self):
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)

    def testSyncBatchNormSyncEval(self):
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)

    def testSyncBatchNorm2DSyncTrain(self):
        bn = nn.BatchNorm2d(10)
        sync_bn = SynchronizedBatchNorm2d(10)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)


if __name__ == '__main__':
    unittest.main()
sync_batchnorm/__init__.py ADDED
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# File   : __init__.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import set_sbn_eps_mode
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .batchnorm import patch_sync_batchnorm, convert_model
from .replicate import DataParallelWithCallback, patch_replication_callback
sync_batchnorm/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (475 Bytes)

sync_batchnorm/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (473 Bytes)

sync_batchnorm/__pycache__/batchnorm.cpython-310.pyc ADDED
Binary file (15.2 kB)

sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc ADDED
Binary file (15.3 kB)

sync_batchnorm/__pycache__/comm.cpython-310.pyc ADDED
Binary file (4.84 kB)

sync_batchnorm/__pycache__/comm.cpython-38.pyc ADDED
Binary file (4.8 kB)

sync_batchnorm/__pycache__/replicate.cpython-310.pyc ADDED
Binary file (3.46 kB)

sync_batchnorm/__pycache__/replicate.cpython-38.pyc ADDED
Binary file (3.45 kB)

sync_batchnorm/batchnorm.py ADDED
@@ -0,0 +1,412 @@
# -*- coding: utf-8 -*-
# File   : batchnorm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections
import contextlib

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm

try:
    from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
except ImportError:
    ReduceAddCoalesced = Broadcast = None

try:
    from jactorch.parallel.comm import SyncMaster
    from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
except ImportError:
    from .comm import SyncMaster
    from .replicate import DataParallelWithCallback

__all__ = [
    'set_sbn_eps_mode',
    'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
    'patch_sync_batchnorm', 'convert_model'
]


SBN_EPS_MODE = 'clamp'


def set_sbn_eps_mode(mode):
    global SBN_EPS_MODE
    assert mode in ('clamp', 'plus')
    SBN_EPS_MODE = mode


def _sum_ft(tensor):
    """sum over the first and last dimension"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        assert ReduceAddCoalesced is not None, 'Cannot use Synchronized Batch Normalization without CUDA support.'

        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
                                                     track_running_stats=track_running_stats)

        if not self.track_running_stats:
            import warnings
            warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        if hasattr(torch, 'no_grad'):
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
        else:
            self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
            self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        if SBN_EPS_MODE == 'clamp':
            return mean, bias_var.clamp(self.eps) ** -0.5
        elif SBN_EPS_MODE == 'plus':
            return mean, (bias_var + self.eps) ** -0.5
        else:
            raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics on that device, which speeds up the computation and
    is easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics on that device, which speeds up the computation and
    is easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics on that device, which speeds up the computation and
    is easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape::
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))


@contextlib.contextmanager
def patch_sync_batchnorm():
    import torch.nn as nn

    backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d

    nn.BatchNorm1d = SynchronizedBatchNorm1d
    nn.BatchNorm2d = SynchronizedBatchNorm2d
    nn.BatchNorm3d = SynchronizedBatchNorm3d

    yield

    nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup


def convert_model(module):
    """Traverse the input module and its children recursively
    and replace all instances of torch.nn.modules.batchnorm.BatchNorm*N*d
    with SynchronizedBatchNorm*N*d

    Args:
        module: the input module to be converted to a SyncBN model

    Examples:
        >>> import torch.nn as nn
        >>> import torchvision
        >>> # m is a standard pytorch model
        >>> m = torchvision.models.resnet18(True)
        >>> m = nn.DataParallel(m)
        >>> # after convert, m is using SyncBN
        >>> m = convert_model(m)
    """
    if isinstance(module, torch.nn.DataParallel):
        mod = module.module
        mod = convert_model(mod)
        mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
        return mod

    mod = module
    for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
                                        torch.nn.modules.batchnorm.BatchNorm2d,
                                        torch.nn.modules.batchnorm.BatchNorm3d],
                                       [SynchronizedBatchNorm1d,
                                        SynchronizedBatchNorm2d,
                                        SynchronizedBatchNorm3d]):
        if isinstance(module, pth_module):
            mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
            mod.running_mean = module.running_mean
            mod.running_var = module.running_var
            if module.affine:
                mod.weight.data = module.weight.data.clone().detach()
                mod.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_model(child))

    return mod
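Not part of the commit: a brief sketch of the `patch_sync_batchnorm` context manager defined above. The toy model and the two-GPU `device_ids` are assumptions for illustration.

import torch.nn as nn
from sync_batchnorm import patch_sync_batchnorm, DataParallelWithCallback

# Inside the context manager, nn.BatchNorm1d/2d/3d temporarily refer to the
# synchronized classes, so any module constructed here picks up SyncBN layers.
with patch_sync_batchnorm():
    model = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.BatchNorm2d(16),   # actually a SynchronizedBatchNorm2d
        nn.ReLU(),
    )

# The replication callback is still needed for multi-GPU training (assumes 2 GPUs).
model = DataParallelWithCallback(model.cuda(), device_ids=[0, 1])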
sync_batchnorm/batchnorm_reimpl.py ADDED
@@ -0,0 +1,74 @@
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : batchnorm_reimpl.py
# Author : acgtyrant
# Date   : 11/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import torch
import torch.nn as nn
import torch.nn.init as init

__all__ = ['BatchNorm2dReimpl']


class BatchNorm2dReimpl(nn.Module):
    """
    A re-implementation of batch normalization, used for testing the numerical
    stability.

    Author: acgtyrant
    See also:
    https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.empty(num_features))
        self.bias = nn.Parameter(torch.empty(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_running_stats(self):
        self.running_mean.zero_()
        self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        init.uniform_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input_):
        batchsize, channels, height, width = input_.size()
        numel = batchsize * height * width
        input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
        sum_ = input_.sum(1)
        sum_of_square = input_.pow(2).sum(1)
        mean = sum_ / numel
        sumvar = sum_of_square - sum_ * mean

        self.running_mean = (
            (1 - self.momentum) * self.running_mean
            + self.momentum * mean.detach()
        )
        unbias_var = sumvar / (numel - 1)
        self.running_var = (
            (1 - self.momentum) * self.running_var
            + self.momentum * unbias_var.detach()
        )

        bias_var = sumvar / numel
        inv_std = 1 / (bias_var + self.eps).pow(0.5)
        output = (
            (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
            self.weight.unsqueeze(1) + self.bias.unsqueeze(1))

        return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
sync_batchnorm/comm.py ADDED
@@ -0,0 +1,137 @@
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, "Previous result hasn't been fetched."
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During the replication, as the data parallel will trigger a callback on each module, all slave devices should
      call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected
      and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message to be
      passed back to each slave device.
    """

    def __init__(self, master_callback):
        """

        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
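Not part of the commit: a minimal single-process sketch of the master/slave protocol described in the `SyncMaster` docstring above, using one slave thread and a toy sum-reduce callback (all names besides the `sync_batchnorm.comm` classes are illustrative).

import threading
from sync_batchnorm.comm import SyncMaster

def master_callback(intermediates):
    # intermediates is a list of (device_id, message); id 0 is always the master.
    total = sum(msg for _, msg in intermediates)
    return [(i, total) for i, _ in intermediates]

master = SyncMaster(master_callback)
pipe = master.register_slave(1)          # one slave "device" with id 1

result = {}
slave = threading.Thread(target=lambda: result.update(slave_ret=pipe.run_slave(2)))
slave.start()
result['master_ret'] = master.run_master(1)  # blocks until the slave has sent its message
slave.join()
print(result)                            # both sides receive the reduced value 3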
sync_batchnorm/replicate.py ADDED
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all modules are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
sync_batchnorm/unittest.py ADDED
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest
import torch


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, x, y):
        adiff = float((x - y).abs().max())
        if (y == 0).all():
            rdiff = 'NaN'
        else:
            rdiff = float((adiff / y).abs().max())

        message = (
            'Tensor close check failed\n'
            'adiff={}\n'
            'rdiff={}\n'
        ).format(adiff, rdiff)
        self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)
test.py ADDED
@@ -0,0 +1,63 @@
import numpy as np
from label_colors import colorMap
from PIL import Image
from spade.model import Pix2PixModel
from spade.dataset import get_transform
from torchvision.transforms import ToPILImage

'''colors = np.array([[56, 79, 131], [239, 239, 239],
                      [93, 110, 50], [183, 210, 78],
                      [60, 59, 75], [250, 250, 250]])'''
colors = [key['color'] for key in colorMap]
id_list = [key['id'] for key in colorMap]


def semantic(img):
    print("semantic", type(img))
    # img.size is (width, height); the drawing canvas is square (512x512), so the order is harmless here
    h, w = img.size
    imrgb = img.convert("RGB")
    pix = list(imrgb.getdata())
    mask = [id_list[colors.index(i)] if i in colors else 156 for i in pix]
    return np.array(mask).reshape(h, w)


def evaluate(labelmap):
    opt = {
        'label_nc': 182,  # num classes in coco model
        'crop_size': 512,
        'load_size': 512,
        'aspect_ratio': 1.0,
        'isTrain': False,
        'checkpoints_dir': 'app',
        'which_epoch': 'latest',
        'use_gpu': False
    }
    model = Pix2PixModel(opt)
    model.eval()
    image = Image.fromarray(np.array(labelmap).astype(np.uint8))
    transform_label = get_transform(opt, method=Image.NEAREST, normalize=False)
    # transforms.ToTensor in transform_label rescales the image from [0,255] to [0.0,1.0]
    # let's rescale it back to [0,255] to match our label ids
    label_tensor = transform_label(image) * 255.0
    label_tensor[label_tensor == 255] = opt['label_nc']  # 'unknown' is opt.label_nc
    print("label_tensor:", label_tensor.shape)

    # not using the encoder, so creating a blank image...
    transform_image = get_transform(opt)
    image_tensor = transform_image(Image.new('RGB', (500, 500)))

    data = {
        'label': label_tensor.unsqueeze(0),
        'instance': label_tensor.unsqueeze(0),
        'image': image_tensor.unsqueeze(0)
    }
    generated = model(data, mode='inference')
    print("generated_image:", generated.shape)

    return generated


def to_image(generated):
    to_img = ToPILImage()
    normalized_img = ((generated.reshape([3, 512, 512]) + 1) / 2.0) * 255.0
    return to_img(normalized_img.byte().cpu())
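Not part of the commit: a small offline driver, assuming latest_net_G.pth has already been downloaded next to this file, that exercises the same pipeline app.py runs when "generate" is pressed (the solid "sky"-colored canvas is an illustrative stand-in for a user drawing).

if __name__ == '__main__':
    canvas = Image.new("RGB", (512, 512), (117, 158, 223))  # plain "sky" canvas
    labelmap = semantic(canvas)       # RGB drawing -> per-pixel COCO label ids
    generated = evaluate(labelmap)    # SPADE inference
    to_image(generated).save("generated.png")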