Spaces:
Runtime error
Runtime error
Shreyz-max
commited on
Commit
•
6672bfb
1
Parent(s):
28fb98f
Add application file
Browse files- app.py +81 -0
- label_colors.py +12 -0
- spade/LICENSE +21 -0
- spade/README.md +118 -0
- spade/Synchronized-BatchNorm-PyTorch +1 -0
- spade/__pycache__/dataset.cpython-310.pyc +0 -0
- spade/__pycache__/dataset.cpython-38.pyc +0 -0
- spade/__pycache__/generator.cpython-310.pyc +0 -0
- spade/__pycache__/generator.cpython-38.pyc +0 -0
- spade/__pycache__/model.cpython-310.pyc +0 -0
- spade/__pycache__/model.cpython-38.pyc +0 -0
- spade/__pycache__/normalizer.cpython-310.pyc +0 -0
- spade/__pycache__/normalizer.cpython-38.pyc +0 -0
- spade/dataset.py +24 -0
- spade/generator.py +113 -0
- spade/model.py +101 -0
- spade/normalizer.py +49 -0
- spade/tests/test_numeric_batchnorm.py +56 -0
- spade/tests/test_numeric_batchnorm_v2.py +62 -0
- spade/tests/test_sync_batchnorm.py +114 -0
- sync_batchnorm/__init__.py +14 -0
- sync_batchnorm/__pycache__/__init__.cpython-310.pyc +0 -0
- sync_batchnorm/__pycache__/__init__.cpython-38.pyc +0 -0
- sync_batchnorm/__pycache__/batchnorm.cpython-310.pyc +0 -0
- sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc +0 -0
- sync_batchnorm/__pycache__/comm.cpython-310.pyc +0 -0
- sync_batchnorm/__pycache__/comm.cpython-38.pyc +0 -0
- sync_batchnorm/__pycache__/replicate.cpython-310.pyc +0 -0
- sync_batchnorm/__pycache__/replicate.cpython-38.pyc +0 -0
- sync_batchnorm/batchnorm.py +412 -0
- sync_batchnorm/batchnorm_reimpl.py +74 -0
- sync_batchnorm/comm.py +137 -0
- sync_batchnorm/replicate.py +94 -0
- sync_batchnorm/unittest.py +29 -0
- test.py +63 -0
app.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os.path
|
2 |
+
|
3 |
+
import streamlit as st
|
4 |
+
from streamlit_drawable_canvas import st_canvas
|
5 |
+
import test
|
6 |
+
from PIL import Image
|
7 |
+
import gdown
|
8 |
+
|
9 |
+
|
10 |
+
st.set_page_config(layout="wide")
|
11 |
+
|
12 |
+
# Specify canvas parameters in application
|
13 |
+
drawing_object = st.sidebar.selectbox(
|
14 |
+
"Object:", ("sea", "cloud", "bush", "grass", "mountain", "sky", "snow",
|
15 |
+
"tree", "flower", "road")
|
16 |
+
)
|
17 |
+
drawing_object_dict = {"sea": "rgb(56,79,131)", "cloud": "rgb(239,239,239)",
|
18 |
+
"bush": "rgb(93,110,50)", "grass": "rgb(183,210,78)",
|
19 |
+
"mountain": "rgb(60,59,75)", "snow": "rgb(250,250,250)",
|
20 |
+
"sky": "rgb(117,158,223)", "tree": "rgb(53, 38, 19)",
|
21 |
+
"flower": "rgb(230,112,182)",
|
22 |
+
"road": "rgb(152, 126, 106)"}
|
23 |
+
|
24 |
+
stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 3)
|
25 |
+
|
26 |
+
stroke_color = drawing_object_dict[drawing_object]
|
27 |
+
|
28 |
+
|
29 |
+
col1, col2 = st.columns(2)
|
30 |
+
with col1:
|
31 |
+
# Create a canvas component with different parameters
|
32 |
+
canvas_result = st_canvas(
|
33 |
+
fill_color="rgba(255, 165, 0, 0.3)", # Fixed fill color with some opacity
|
34 |
+
stroke_width=stroke_width,
|
35 |
+
stroke_color=stroke_color,
|
36 |
+
background_color="rgb(117,158,223)",
|
37 |
+
background_image=None,
|
38 |
+
height=512,
|
39 |
+
width=512,
|
40 |
+
drawing_mode="freedraw",
|
41 |
+
point_display_radius=0,
|
42 |
+
key="canvas",
|
43 |
+
)
|
44 |
+
if canvas_result.image_data is not None:
|
45 |
+
pass
|
46 |
+
|
47 |
+
|
48 |
+
@st.cache
|
49 |
+
def download_model():
|
50 |
+
f_checkpoint = os.path.join("latest_net_G.pth")
|
51 |
+
if not os.path.exists(f_checkpoint):
|
52 |
+
with st.spinner("Downloading model... this may take awhile! \n Don't stop it!"):
|
53 |
+
url = 'https://drive.google.com/uc?id=15VSa2m2F6Ch0NpewDR7mkKAcXlMgDi5F'
|
54 |
+
output = 'latest_net_G.pth'
|
55 |
+
gdown.download(url, output, quiet=False)
|
56 |
+
|
57 |
+
|
58 |
+
if st.button('generate'):
|
59 |
+
download_model()
|
60 |
+
image = Image.fromarray(canvas_result.image_data)
|
61 |
+
s = test.semantic(image)
|
62 |
+
image = test.evaluate(s)
|
63 |
+
image = test.to_image(image)
|
64 |
+
with col2:
|
65 |
+
st.image(image, clamp=True, width=512)
|
66 |
+
|
67 |
+
|
68 |
+
st.markdown(
|
69 |
+
"""
|
70 |
+
<style>
|
71 |
+
[data-testid="stSidebar"][aria-expanded="true"] > div:first-child {
|
72 |
+
width: 120px;
|
73 |
+
}
|
74 |
+
[data-testid="stSidebar"][aria-expanded="false"] > div:first-child {
|
75 |
+
width: 500px;
|
76 |
+
margin-left: -500px;
|
77 |
+
}
|
78 |
+
</style>
|
79 |
+
""",
|
80 |
+
unsafe_allow_html=True,
|
81 |
+
)
|
label_colors.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
colorMap = [
|
2 |
+
{"color": (56, 79, 131), "id": 154, "label": "sea"},
|
3 |
+
{"color": (239, 239, 239), "id": 105, "label": "cloud"},
|
4 |
+
{"color": (93, 110, 50), "id": 96, "label": "bush"},
|
5 |
+
{"color": (183, 210, 78), "id": 123, "label": "grass"},
|
6 |
+
{"color": (60, 59, 75), "id": 134, "label": "mountain"},
|
7 |
+
{"color": (117, 158, 223), "id": 156, "label": "sky"},
|
8 |
+
{"color": (250, 250, 250), "id": 158, "label": "snow"},
|
9 |
+
{"color": (53, 38, 19), "id": 168, "label": "tree"},
|
10 |
+
{"color": (230, 112, 182), "id": 118, "label": "flower"},
|
11 |
+
{"color": (152, 126, 106), "id": 148, "label": "road"}
|
12 |
+
]
|
spade/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2018 Jiayuan MAO
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
spade/README.md
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Synchronized-BatchNorm-PyTorch
|
2 |
+
|
3 |
+
**IMPORTANT: Please read the "Implementation details and highlights" section before use.**
|
4 |
+
|
5 |
+
Synchronized Batch Normalization implementation in PyTorch.
|
6 |
+
|
7 |
+
This module differs from the built-in PyTorch BatchNorm as the mean and
|
8 |
+
standard-deviation are reduced across all devices during training.
|
9 |
+
|
10 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
11 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
12 |
+
the statistics only on that device, which accelerated the computation and
|
13 |
+
is also easy to implement, but the statistics might be inaccurate.
|
14 |
+
Instead, in this synchronized version, the statistics will be computed
|
15 |
+
over all training samples distributed on multiple devices.
|
16 |
+
|
17 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
18 |
+
as the built-in PyTorch implementation.
|
19 |
+
|
20 |
+
This module is currently only a prototype version for research usages. As mentioned below,
|
21 |
+
it has its limitations and may even suffer from some design problems. If you have any
|
22 |
+
questions or suggestions, please feel free to
|
23 |
+
[open an issue](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues) or
|
24 |
+
[submit a pull request](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues).
|
25 |
+
|
26 |
+
## Why Synchronized BatchNorm?
|
27 |
+
|
28 |
+
Although the typical implementation of BatchNorm working on multiple devices (GPUs)
|
29 |
+
is fast (with no communication overhead), it inevitably reduces the size of batch size,
|
30 |
+
which potentially degenerates the performance. This is not a significant issue in some
|
31 |
+
standard vision tasks such as ImageNet classification (as the batch size per device
|
32 |
+
is usually large enough to obtain good statistics). However, it will hurt the performance
|
33 |
+
in some tasks that the batch size is usually very small (e.g., 1 per GPU).
|
34 |
+
|
35 |
+
For example, the importance of synchronized batch normalization in object detection has been recently proved with a
|
36 |
+
an extensive analysis in the paper [MegDet: A Large Mini-Batch Object Detector](https://arxiv.org/abs/1711.07240).
|
37 |
+
|
38 |
+
## Usage
|
39 |
+
|
40 |
+
To use the Synchronized Batch Normalization, we add a data parallel replication callback. This introduces a slight
|
41 |
+
difference with typical usage of the `nn.DataParallel`.
|
42 |
+
|
43 |
+
Use it with a provided, customized data parallel wrapper:
|
44 |
+
|
45 |
+
```python
|
46 |
+
from sync_batchnorm import SynchronizedBatchNorm1d, DataParallelWithCallback
|
47 |
+
|
48 |
+
sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
49 |
+
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
50 |
+
```
|
51 |
+
|
52 |
+
Or, if you are using a customized data parallel module, you can use this library as a monkey patching.
|
53 |
+
|
54 |
+
```python
|
55 |
+
from torch.nn import DataParallel # or your customized DataParallel module
|
56 |
+
from sync_batchnorm import SynchronizedBatchNorm1d, patch_replication_callback
|
57 |
+
|
58 |
+
sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
59 |
+
sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
|
60 |
+
patch_replication_callback(sync_bn) # monkey-patching
|
61 |
+
```
|
62 |
+
|
63 |
+
You can use `convert_model` to convert your model to use Synchronized BatchNorm easily.
|
64 |
+
|
65 |
+
```python
|
66 |
+
import torch.nn as nn
|
67 |
+
from torchvision import models
|
68 |
+
from sync_batchnorm import convert_model
|
69 |
+
# m is a standard pytorch model
|
70 |
+
m = models.resnet18(True)
|
71 |
+
m = nn.DataParallel(m)
|
72 |
+
# after convert, m is using SyncBN
|
73 |
+
m = convert_model(m)
|
74 |
+
```
|
75 |
+
|
76 |
+
See also `tests/test_sync_batchnorm.py` for numeric result comparison.
|
77 |
+
|
78 |
+
## Implementation details and highlights
|
79 |
+
|
80 |
+
If you are interested in how batch statistics are reduced and broadcasted among multiple devices, please take a look
|
81 |
+
at the code with detailed comments. Here we only emphasize some highlights of the implementation:
|
82 |
+
|
83 |
+
- This implementation is in pure-python. No C++ extra extension libs.
|
84 |
+
- Easy to use as demonstrated above.
|
85 |
+
- It uses unbiased variance to update the moving average, and use `sqrt(max(var, eps))` instead of `sqrt(var + eps)`.
|
86 |
+
- The implementation requires that each module on different devices should invoke the `batchnorm` for exactly SAME
|
87 |
+
amount of times in each forward pass. For example, you can not only call `batchnorm` on GPU0 but not on GPU1. The `#i
|
88 |
+
(i = 1, 2, 3, ...)` calls of the `batchnorm` on each device will be viewed as a whole and the statistics will be reduced.
|
89 |
+
This is tricky but is a good way to handle PyTorch's dynamic computation graph. Although sounds complicated, this
|
90 |
+
will usually not be the issue for most of the models.
|
91 |
+
|
92 |
+
## Known issues
|
93 |
+
|
94 |
+
#### Runtime error on backward pass.
|
95 |
+
|
96 |
+
Due to a [PyTorch Bug](https://github.com/pytorch/pytorch/issues/3883), using old PyTorch libraries will trigger an `RuntimeError` with messages like:
|
97 |
+
|
98 |
+
```
|
99 |
+
Assertion `pos >= 0 && pos < buffer.size()` failed.
|
100 |
+
```
|
101 |
+
|
102 |
+
This has already been solved in the newest PyTorch repo, which, unfortunately, has not been pushed to the official and anaconda binary release. Thus, you are required to build the PyTorch package from the source according to the
|
103 |
+
instructions [here](https://github.com/pytorch/pytorch#from-source).
|
104 |
+
|
105 |
+
#### Numeric error.
|
106 |
+
|
107 |
+
Because this library does not fuse the normalization and statistics operations in C++ (nor CUDA), it is less
|
108 |
+
numerically stable compared to the original PyTorch implementation. Detailed analysis can be found in
|
109 |
+
`tests/test_sync_batchnorm.py`.
|
110 |
+
|
111 |
+
## Authors and License:
|
112 |
+
|
113 |
+
Copyright (c) 2018-, [Jiayuan Mao](https://vccy.xyz).
|
114 |
+
|
115 |
+
**Contributors**: [Tete Xiao](https://tetexiao.com), [DTennant](https://github.com/DTennant).
|
116 |
+
|
117 |
+
Distributed under **MIT License** (See LICENSE)
|
118 |
+
|
spade/Synchronized-BatchNorm-PyTorch
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
Subproject commit dcfae91cbc3767a3c5cd28d46ab78503a22b0fe7
|
spade/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (1.1 kB). View file
|
|
spade/__pycache__/dataset.cpython-38.pyc
ADDED
Binary file (1.1 kB). View file
|
|
spade/__pycache__/generator.cpython-310.pyc
ADDED
Binary file (3.33 kB). View file
|
|
spade/__pycache__/generator.cpython-38.pyc
ADDED
Binary file (3.31 kB). View file
|
|
spade/__pycache__/model.cpython-310.pyc
ADDED
Binary file (3.52 kB). View file
|
|
spade/__pycache__/model.cpython-38.pyc
ADDED
Binary file (3.48 kB). View file
|
|
spade/__pycache__/normalizer.cpython-310.pyc
ADDED
Binary file (1.47 kB). View file
|
|
spade/__pycache__/normalizer.cpython-38.pyc
ADDED
Binary file (1.47 kB). View file
|
|
spade/dataset.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
|
3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
def __scale_width(img, target_width, method=Image.BICUBIC):
|
10 |
+
ow, oh = img.size
|
11 |
+
if (ow == target_width):
|
12 |
+
return img
|
13 |
+
w = target_width
|
14 |
+
h = int(target_width * oh / ow)
|
15 |
+
return img.resize((w, h), method)
|
16 |
+
|
17 |
+
def get_transform(opt, method=Image.BICUBIC, normalize=True):
|
18 |
+
transform_list = []
|
19 |
+
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt['load_size'], method)))
|
20 |
+
transform_list += [transforms.ToTensor()]
|
21 |
+
if normalize:
|
22 |
+
transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
|
23 |
+
|
24 |
+
return transforms.Compose(transform_list)
|
spade/generator.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
|
3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from spade.normalizer import SPADE
|
11 |
+
|
12 |
+
class SPADEGenerator(nn.Module):
|
13 |
+
def __init__(self, opt):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
# nf: # of gen filters in first conv layer
|
17 |
+
nf = 64
|
18 |
+
|
19 |
+
self.sw, self.sh = self.compute_latent_vector_size(opt['crop_size'], opt['aspect_ratio'])
|
20 |
+
|
21 |
+
self.fc = nn.Conv2d(opt['label_nc'], 16 * nf, 3, padding=1)
|
22 |
+
|
23 |
+
self.head_0 = SPADEResnetBlock(opt, 16 * nf, 16 * nf)
|
24 |
+
|
25 |
+
self.G_middle_0 = SPADEResnetBlock(opt, 16 * nf, 16 * nf)
|
26 |
+
self.G_middle_1 = SPADEResnetBlock(opt, 16 * nf, 16 * nf)
|
27 |
+
|
28 |
+
self.up_0 = SPADEResnetBlock(opt, 16 * nf, 8 * nf)
|
29 |
+
self.up_1 = SPADEResnetBlock(opt, 8 * nf, 4 * nf)
|
30 |
+
self.up_2 = SPADEResnetBlock(opt, 4 * nf, 2 * nf)
|
31 |
+
self.up_3 = SPADEResnetBlock(opt, 2 * nf, 1 * nf)
|
32 |
+
|
33 |
+
self.conv_img = nn.Conv2d(1 * nf, 3, 3, padding=1)
|
34 |
+
|
35 |
+
self.up = nn.Upsample(scale_factor=2)
|
36 |
+
|
37 |
+
def compute_latent_vector_size(self, crop_size, aspect_ratio):
|
38 |
+
num_up_layers = 5
|
39 |
+
|
40 |
+
sw = crop_size // (2**num_up_layers)
|
41 |
+
sh = round(sw / aspect_ratio)
|
42 |
+
|
43 |
+
return sw, sh
|
44 |
+
|
45 |
+
def forward(self, seg):
|
46 |
+
# we downsample segmap and run convolution
|
47 |
+
x = F.interpolate(seg, size=(self.sh, self.sw))
|
48 |
+
x = self.fc(x)
|
49 |
+
|
50 |
+
x = self.head_0(x, seg)
|
51 |
+
|
52 |
+
x = self.up(x)
|
53 |
+
x = self.G_middle_0(x, seg)
|
54 |
+
x = self.G_middle_1(x, seg)
|
55 |
+
|
56 |
+
x = self.up(x)
|
57 |
+
x = self.up_0(x, seg)
|
58 |
+
x = self.up(x)
|
59 |
+
x = self.up_1(x, seg)
|
60 |
+
x = self.up(x)
|
61 |
+
x = self.up_2(x, seg)
|
62 |
+
x = self.up(x)
|
63 |
+
x = self.up_3(x, seg)
|
64 |
+
|
65 |
+
x = self.conv_img(F.leaky_relu(x, 2e-1))
|
66 |
+
x = torch.tanh(x)
|
67 |
+
|
68 |
+
return x
|
69 |
+
|
70 |
+
import torch.nn.utils.spectral_norm as spectral_norm
|
71 |
+
|
72 |
+
# label_nc: the #channels of the input semantic map, hence the input dim of SPADE
|
73 |
+
# label_nc: also equivalent to the # of input label classes
|
74 |
+
class SPADEResnetBlock(nn.Module):
|
75 |
+
def __init__(self, opt, fin, fout):
|
76 |
+
super().__init__()
|
77 |
+
|
78 |
+
self.learned_shortcut = (fin != fout)
|
79 |
+
fmiddle = min(fin, fout)
|
80 |
+
|
81 |
+
self.conv_0 = spectral_norm(nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1))
|
82 |
+
self.conv_1 = spectral_norm(nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1))
|
83 |
+
if self.learned_shortcut:
|
84 |
+
self.conv_s = spectral_norm(nn.Conv2d(fin, fout, kernel_size=1, bias=False))
|
85 |
+
|
86 |
+
# define normalization layers
|
87 |
+
self.norm_0 = SPADE(opt, fin)
|
88 |
+
self.norm_1 = SPADE(opt, fmiddle)
|
89 |
+
if self.learned_shortcut:
|
90 |
+
self.norm_s = SPADE(opt, fin)
|
91 |
+
|
92 |
+
# note the resnet block with SPADE also takes in |seg|,
|
93 |
+
# the semantic segmentation map as input
|
94 |
+
def forward(self, x, seg):
|
95 |
+
x_s = self.shortcut(x, seg)
|
96 |
+
|
97 |
+
dx = self.conv_0(self.relu(self.norm_0(x, seg)))
|
98 |
+
dx = self.conv_1(self.relu(self.norm_1(dx, seg)))
|
99 |
+
|
100 |
+
out = x_s + dx
|
101 |
+
return out
|
102 |
+
|
103 |
+
def shortcut(self, x, seg):
|
104 |
+
if self.learned_shortcut:
|
105 |
+
x_s = self.conv_s(self.norm_s(x, seg))
|
106 |
+
else:
|
107 |
+
x_s = x
|
108 |
+
return x_s
|
109 |
+
|
110 |
+
def relu(self, x):
|
111 |
+
return F.leaky_relu(x, 2e-1)
|
112 |
+
|
113 |
+
|
spade/model.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
|
3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
4 |
+
"""
|
5 |
+
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
from torch.nn import init
|
9 |
+
|
10 |
+
from spade.generator import SPADEGenerator
|
11 |
+
|
12 |
+
|
13 |
+
class Pix2PixModel(torch.nn.Module):
|
14 |
+
def __init__(self, opt):
|
15 |
+
super().__init__()
|
16 |
+
self.opt = opt
|
17 |
+
self.FloatTensor = torch.cuda.FloatTensor if opt['use_gpu'] \
|
18 |
+
else torch.FloatTensor
|
19 |
+
|
20 |
+
self.netG = self.initialize_networks(opt)
|
21 |
+
|
22 |
+
def forward(self, data, mode):
|
23 |
+
input_semantics, real_image = self.preprocess_input(data)
|
24 |
+
|
25 |
+
if mode == 'inference':
|
26 |
+
with torch.no_grad():
|
27 |
+
fake_image = self.generate_fake(input_semantics)
|
28 |
+
return fake_image
|
29 |
+
else:
|
30 |
+
raise ValueError("|mode| is invalid")
|
31 |
+
|
32 |
+
def preprocess_input(self, data):
|
33 |
+
data['label'] = data['label'].long()
|
34 |
+
|
35 |
+
# move to GPU and change data types
|
36 |
+
if self.opt['use_gpu']:
|
37 |
+
data['label'] = data['label'].cuda()
|
38 |
+
data['instance'] = data['instance'].cuda()
|
39 |
+
data['image'] = data['image'].cuda()
|
40 |
+
|
41 |
+
# create one-hot label map
|
42 |
+
label_map = data['label']
|
43 |
+
bs, _, h, w = label_map.size()
|
44 |
+
input_label = self.FloatTensor(bs, self.opt['label_nc'], h, w).zero_()
|
45 |
+
# one whole label map -> to one label map per class
|
46 |
+
input_semantics = input_label.scatter_(1, label_map, 1.0)
|
47 |
+
|
48 |
+
return input_semantics, data['image']
|
49 |
+
|
50 |
+
def generate_fake(self, input_semantics):
|
51 |
+
fake_image = self.netG(input_semantics)
|
52 |
+
return fake_image
|
53 |
+
|
54 |
+
def create_network(self, cls, opt):
|
55 |
+
net = cls(opt)
|
56 |
+
if self.opt['use_gpu']:
|
57 |
+
net.cuda()
|
58 |
+
|
59 |
+
gain = 0.02
|
60 |
+
|
61 |
+
def init_weights(m):
|
62 |
+
classname = m.__class__.__name__
|
63 |
+
if classname.find('BatchNorm2d') != -1:
|
64 |
+
if hasattr(m, 'weight') and m.weight is not None:
|
65 |
+
init.normal_(m.weight.data, 1.0, gain)
|
66 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
67 |
+
init.constant_(m.bias.data, 0.0)
|
68 |
+
elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
|
69 |
+
init.xavier_normal_(m.weight.data, gain=gain)
|
70 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
71 |
+
init.constant_(m.bias.data, 0.0)
|
72 |
+
|
73 |
+
# Applies fn recursively to every submodule (as returned by .children()) as well as self
|
74 |
+
net.apply(init_weights)
|
75 |
+
|
76 |
+
return net
|
77 |
+
|
78 |
+
def load_network(self, net, label, epoch, opt):
|
79 |
+
save_filename = '%s_net_%s.pth' % (epoch, label)
|
80 |
+
save_path = os.path.join( save_filename)
|
81 |
+
weights = torch.load(save_path)
|
82 |
+
net.load_state_dict(weights)
|
83 |
+
return net
|
84 |
+
|
85 |
+
def initialize_networks(self, opt):
|
86 |
+
netG = self.create_network(SPADEGenerator, opt)
|
87 |
+
|
88 |
+
if not opt['isTrain']:
|
89 |
+
netG = self.load_network(netG, 'G', opt['which_epoch'], opt)
|
90 |
+
|
91 |
+
# self.print_network(netG)
|
92 |
+
|
93 |
+
return netG
|
94 |
+
|
95 |
+
def print_network(self, net):
|
96 |
+
num_params = 0
|
97 |
+
for param in net.parameters():
|
98 |
+
num_params += param.numel()
|
99 |
+
print('Network [%s] was created. Total number of parameters: %.1f million. '
|
100 |
+
% (type(net).__name__, num_params / 1000000))
|
101 |
+
print(net)
|
spade/normalizer.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
|
3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
4 |
+
"""
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
|
10 |
+
|
11 |
+
# norm_nc: the #channels of the normalized activations, hence the output dim of SPADE
|
12 |
+
# label_nc: the #channels of the input semantic map, hence the input dim of SPADE
|
13 |
+
# label_nc: also equivalent to the # of input label classes
|
14 |
+
class SPADE(nn.Module):
|
15 |
+
def __init__(self, opt, norm_nc):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
|
19 |
+
|
20 |
+
# number of internal filters for generating scale/bias
|
21 |
+
nhidden = 128
|
22 |
+
# size of kernels
|
23 |
+
kernal_size = 3
|
24 |
+
# padding size
|
25 |
+
padding = kernal_size // 2
|
26 |
+
|
27 |
+
self.mlp_shared = nn.Sequential(
|
28 |
+
nn.Conv2d(opt['label_nc'], nhidden, kernel_size=kernal_size, padding=padding),
|
29 |
+
nn.ReLU()
|
30 |
+
)
|
31 |
+
self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=kernal_size, padding=padding)
|
32 |
+
self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=kernal_size, padding=padding)
|
33 |
+
|
34 |
+
def forward(self, x, segmap):
|
35 |
+
# Part 1. generate parameter-free normalized activations
|
36 |
+
normalized = self.param_free_norm(x)
|
37 |
+
|
38 |
+
# Part 2. produce scaling and bias conditioned on semantic map
|
39 |
+
# resize input segmentation map to match x.size() using nearest interpolation
|
40 |
+
# N, C, H, W = x.size()
|
41 |
+
segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
|
42 |
+
actv = self.mlp_shared(segmap)
|
43 |
+
gamma = self.mlp_gamma(actv)
|
44 |
+
beta = self.mlp_beta(actv)
|
45 |
+
|
46 |
+
# apply scale and bias
|
47 |
+
out = normalized * (1 + gamma) + beta
|
48 |
+
|
49 |
+
return out
|
spade/tests/test_numeric_batchnorm.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : test_numeric_batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
|
9 |
+
import unittest
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.autograd import Variable
|
14 |
+
|
15 |
+
from sync_batchnorm.unittest import TorchTestCase
|
16 |
+
|
17 |
+
|
18 |
+
def handy_var(a, unbias=True):
|
19 |
+
n = a.size(0)
|
20 |
+
asum = a.sum(dim=0)
|
21 |
+
as_sum = (a ** 2).sum(dim=0) # a square sum
|
22 |
+
sumvar = as_sum - asum * asum / n
|
23 |
+
if unbias:
|
24 |
+
return sumvar / (n - 1)
|
25 |
+
else:
|
26 |
+
return sumvar / n
|
27 |
+
|
28 |
+
|
29 |
+
class NumericTestCase(TorchTestCase):
|
30 |
+
def testNumericBatchNorm(self):
|
31 |
+
a = torch.rand(16, 10)
|
32 |
+
bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)
|
33 |
+
bn.train()
|
34 |
+
|
35 |
+
a_var1 = Variable(a, requires_grad=True)
|
36 |
+
b_var1 = bn(a_var1)
|
37 |
+
loss1 = b_var1.sum()
|
38 |
+
loss1.backward()
|
39 |
+
|
40 |
+
a_var2 = Variable(a, requires_grad=True)
|
41 |
+
a_mean2 = a_var2.mean(dim=0, keepdim=True)
|
42 |
+
a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
|
43 |
+
# a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
|
44 |
+
b_var2 = (a_var2 - a_mean2) / a_std2
|
45 |
+
loss2 = b_var2.sum()
|
46 |
+
loss2.backward()
|
47 |
+
|
48 |
+
self.assertTensorClose(bn.running_mean, a.mean(dim=0))
|
49 |
+
self.assertTensorClose(bn.running_var, handy_var(a))
|
50 |
+
self.assertTensorClose(a_var1.data, a_var2.data)
|
51 |
+
self.assertTensorClose(b_var1.data, b_var2.data)
|
52 |
+
self.assertTensorClose(a_var1.grad, a_var2.grad)
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == '__main__':
|
56 |
+
unittest.main()
|
spade/tests/test_numeric_batchnorm_v2.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# File : test_numeric_batchnorm_v2.py
|
4 |
+
# Author : Jiayuan Mao
|
5 |
+
# Email : [email protected]
|
6 |
+
# Date : 11/01/2018
|
7 |
+
#
|
8 |
+
# Distributed under terms of the MIT license.
|
9 |
+
|
10 |
+
"""
|
11 |
+
Test the numerical implementation of batch normalization.
|
12 |
+
|
13 |
+
Author: acgtyrant.
|
14 |
+
See also: https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
|
15 |
+
"""
|
16 |
+
|
17 |
+
import unittest
|
18 |
+
|
19 |
+
import torch
|
20 |
+
import torch.nn as nn
|
21 |
+
import torch.optim as optim
|
22 |
+
|
23 |
+
from sync_batchnorm.unittest import TorchTestCase
|
24 |
+
from sync_batchnorm.batchnorm_reimpl import BatchNorm2dReimpl
|
25 |
+
|
26 |
+
|
27 |
+
class NumericTestCasev2(TorchTestCase):
|
28 |
+
def testNumericBatchNorm(self):
|
29 |
+
CHANNELS = 16
|
30 |
+
batchnorm1 = nn.BatchNorm2d(CHANNELS, momentum=1)
|
31 |
+
optimizer1 = optim.SGD(batchnorm1.parameters(), lr=0.01)
|
32 |
+
|
33 |
+
batchnorm2 = BatchNorm2dReimpl(CHANNELS, momentum=1)
|
34 |
+
batchnorm2.weight.data.copy_(batchnorm1.weight.data)
|
35 |
+
batchnorm2.bias.data.copy_(batchnorm1.bias.data)
|
36 |
+
optimizer2 = optim.SGD(batchnorm2.parameters(), lr=0.01)
|
37 |
+
|
38 |
+
for _ in range(100):
|
39 |
+
input_ = torch.rand(16, CHANNELS, 16, 16)
|
40 |
+
|
41 |
+
input1 = input_.clone().requires_grad_(True)
|
42 |
+
output1 = batchnorm1(input1)
|
43 |
+
output1.sum().backward()
|
44 |
+
optimizer1.step()
|
45 |
+
|
46 |
+
input2 = input_.clone().requires_grad_(True)
|
47 |
+
output2 = batchnorm2(input2)
|
48 |
+
output2.sum().backward()
|
49 |
+
optimizer2.step()
|
50 |
+
|
51 |
+
self.assertTensorClose(input1, input2)
|
52 |
+
self.assertTensorClose(output1, output2)
|
53 |
+
self.assertTensorClose(input1.grad, input2.grad)
|
54 |
+
self.assertTensorClose(batchnorm1.weight.grad, batchnorm2.weight.grad)
|
55 |
+
self.assertTensorClose(batchnorm1.bias.grad, batchnorm2.bias.grad)
|
56 |
+
self.assertTensorClose(batchnorm1.running_mean, batchnorm2.running_mean)
|
57 |
+
self.assertTensorClose(batchnorm2.running_mean, batchnorm2.running_mean)
|
58 |
+
|
59 |
+
|
60 |
+
if __name__ == '__main__':
|
61 |
+
unittest.main()
|
62 |
+
|
spade/tests/test_sync_batchnorm.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : test_sync_batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
|
9 |
+
import unittest
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
from torch.autograd import Variable
|
14 |
+
|
15 |
+
from sync_batchnorm import set_sbn_eps_mode
|
16 |
+
from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
|
17 |
+
from sync_batchnorm.unittest import TorchTestCase
|
18 |
+
|
19 |
+
set_sbn_eps_mode('plus')
|
20 |
+
|
21 |
+
|
22 |
+
def handy_var(a, unbias=True):
|
23 |
+
n = a.size(0)
|
24 |
+
asum = a.sum(dim=0)
|
25 |
+
as_sum = (a ** 2).sum(dim=0) # a square sum
|
26 |
+
sumvar = as_sum - asum * asum / n
|
27 |
+
if unbias:
|
28 |
+
return sumvar / (n - 1)
|
29 |
+
else:
|
30 |
+
return sumvar / n
|
31 |
+
|
32 |
+
|
33 |
+
def _find_bn(module):
|
34 |
+
for m in module.modules():
|
35 |
+
if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
|
36 |
+
return m
|
37 |
+
|
38 |
+
|
39 |
+
class SyncTestCase(TorchTestCase):
|
40 |
+
def _syncParameters(self, bn1, bn2):
|
41 |
+
bn1.reset_parameters()
|
42 |
+
bn2.reset_parameters()
|
43 |
+
if bn1.affine and bn2.affine:
|
44 |
+
bn2.weight.data.copy_(bn1.weight.data)
|
45 |
+
bn2.bias.data.copy_(bn1.bias.data)
|
46 |
+
|
47 |
+
def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
|
48 |
+
"""Check the forward and backward for the customized batch normalization."""
|
49 |
+
bn1.train(mode=is_train)
|
50 |
+
bn2.train(mode=is_train)
|
51 |
+
|
52 |
+
if cuda:
|
53 |
+
input = input.cuda()
|
54 |
+
|
55 |
+
self._syncParameters(_find_bn(bn1), _find_bn(bn2))
|
56 |
+
|
57 |
+
input1 = Variable(input, requires_grad=True)
|
58 |
+
output1 = bn1(input1)
|
59 |
+
output1.sum().backward()
|
60 |
+
input2 = Variable(input, requires_grad=True)
|
61 |
+
output2 = bn2(input2)
|
62 |
+
output2.sum().backward()
|
63 |
+
|
64 |
+
self.assertTensorClose(input1.data, input2.data)
|
65 |
+
self.assertTensorClose(output1.data, output2.data)
|
66 |
+
self.assertTensorClose(input1.grad, input2.grad)
|
67 |
+
self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
|
68 |
+
self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
|
69 |
+
|
70 |
+
def testSyncBatchNormNormalTrain(self):
|
71 |
+
bn = nn.BatchNorm1d(10)
|
72 |
+
sync_bn = SynchronizedBatchNorm1d(10)
|
73 |
+
|
74 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
|
75 |
+
|
76 |
+
def testSyncBatchNormNormalEval(self):
|
77 |
+
bn = nn.BatchNorm1d(10)
|
78 |
+
sync_bn = SynchronizedBatchNorm1d(10)
|
79 |
+
|
80 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
|
81 |
+
|
82 |
+
def testSyncBatchNormSyncTrain(self):
|
83 |
+
bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
|
84 |
+
sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
85 |
+
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
86 |
+
|
87 |
+
bn.cuda()
|
88 |
+
sync_bn.cuda()
|
89 |
+
|
90 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
|
91 |
+
|
92 |
+
def testSyncBatchNormSyncEval(self):
|
93 |
+
bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
|
94 |
+
sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
95 |
+
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
96 |
+
|
97 |
+
bn.cuda()
|
98 |
+
sync_bn.cuda()
|
99 |
+
|
100 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
|
101 |
+
|
102 |
+
def testSyncBatchNorm2DSyncTrain(self):
|
103 |
+
bn = nn.BatchNorm2d(10)
|
104 |
+
sync_bn = SynchronizedBatchNorm2d(10)
|
105 |
+
sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
106 |
+
|
107 |
+
bn.cuda()
|
108 |
+
sync_bn.cuda()
|
109 |
+
|
110 |
+
self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == '__main__':
|
114 |
+
unittest.main()
|
sync_batchnorm/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : __init__.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
from .batchnorm import set_sbn_eps_mode
|
12 |
+
from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
|
13 |
+
from .batchnorm import patch_sync_batchnorm, convert_model
|
14 |
+
from .replicate import DataParallelWithCallback, patch_replication_callback
|
sync_batchnorm/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (475 Bytes). View file
|
|
sync_batchnorm/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (473 Bytes). View file
|
|
sync_batchnorm/__pycache__/batchnorm.cpython-310.pyc
ADDED
Binary file (15.2 kB). View file
|
|
sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc
ADDED
Binary file (15.3 kB). View file
|
|
sync_batchnorm/__pycache__/comm.cpython-310.pyc
ADDED
Binary file (4.84 kB). View file
|
|
sync_batchnorm/__pycache__/comm.cpython-38.pyc
ADDED
Binary file (4.8 kB). View file
|
|
sync_batchnorm/__pycache__/replicate.cpython-310.pyc
ADDED
Binary file (3.46 kB). View file
|
|
sync_batchnorm/__pycache__/replicate.cpython-38.pyc
ADDED
Binary file (3.45 kB). View file
|
|
sync_batchnorm/batchnorm.py
ADDED
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : batchnorm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import collections
|
12 |
+
import contextlib
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
|
17 |
+
from torch.nn.modules.batchnorm import _BatchNorm
|
18 |
+
|
19 |
+
try:
|
20 |
+
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
|
21 |
+
except ImportError:
|
22 |
+
ReduceAddCoalesced = Broadcast = None
|
23 |
+
|
24 |
+
try:
|
25 |
+
from jactorch.parallel.comm import SyncMaster
|
26 |
+
from jactorch.parallel.data_parallel import JacDataParallel as DataParallelWithCallback
|
27 |
+
except ImportError:
|
28 |
+
from .comm import SyncMaster
|
29 |
+
from .replicate import DataParallelWithCallback
|
30 |
+
|
31 |
+
__all__ = [
|
32 |
+
'set_sbn_eps_mode',
|
33 |
+
'SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d',
|
34 |
+
'patch_sync_batchnorm', 'convert_model'
|
35 |
+
]
|
36 |
+
|
37 |
+
|
38 |
+
SBN_EPS_MODE = 'clamp'
|
39 |
+
|
40 |
+
|
41 |
+
def set_sbn_eps_mode(mode):
|
42 |
+
global SBN_EPS_MODE
|
43 |
+
assert mode in ('clamp', 'plus')
|
44 |
+
SBN_EPS_MODE = mode
|
45 |
+
|
46 |
+
|
47 |
+
def _sum_ft(tensor):
|
48 |
+
"""sum over the first and last dimention"""
|
49 |
+
return tensor.sum(dim=0).sum(dim=-1)
|
50 |
+
|
51 |
+
|
52 |
+
def _unsqueeze_ft(tensor):
|
53 |
+
"""add new dimensions at the front and the tail"""
|
54 |
+
return tensor.unsqueeze(0).unsqueeze(-1)
|
55 |
+
|
56 |
+
|
57 |
+
_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
|
58 |
+
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
|
59 |
+
|
60 |
+
|
61 |
+
class _SynchronizedBatchNorm(_BatchNorm):
|
62 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
|
63 |
+
assert ReduceAddCoalesced is not None, 'Can not use Synchronized Batch Normalization without CUDA support.'
|
64 |
+
|
65 |
+
super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine,
|
66 |
+
track_running_stats=track_running_stats)
|
67 |
+
|
68 |
+
if not self.track_running_stats:
|
69 |
+
import warnings
|
70 |
+
warnings.warn('track_running_stats=False is not supported by the SynchronizedBatchNorm.')
|
71 |
+
|
72 |
+
self._sync_master = SyncMaster(self._data_parallel_master)
|
73 |
+
|
74 |
+
self._is_parallel = False
|
75 |
+
self._parallel_id = None
|
76 |
+
self._slave_pipe = None
|
77 |
+
|
78 |
+
def forward(self, input):
|
79 |
+
# If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
|
80 |
+
if not (self._is_parallel and self.training):
|
81 |
+
return F.batch_norm(
|
82 |
+
input, self.running_mean, self.running_var, self.weight, self.bias,
|
83 |
+
self.training, self.momentum, self.eps)
|
84 |
+
|
85 |
+
# Resize the input to (B, C, -1).
|
86 |
+
input_shape = input.size()
|
87 |
+
assert input.size(1) == self.num_features, 'Channel size mismatch: got {}, expect {}.'.format(input.size(1), self.num_features)
|
88 |
+
input = input.view(input.size(0), self.num_features, -1)
|
89 |
+
|
90 |
+
# Compute the sum and square-sum.
|
91 |
+
sum_size = input.size(0) * input.size(2)
|
92 |
+
input_sum = _sum_ft(input)
|
93 |
+
input_ssum = _sum_ft(input ** 2)
|
94 |
+
|
95 |
+
# Reduce-and-broadcast the statistics.
|
96 |
+
if self._parallel_id == 0:
|
97 |
+
mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
|
98 |
+
else:
|
99 |
+
mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
|
100 |
+
|
101 |
+
# Compute the output.
|
102 |
+
if self.affine:
|
103 |
+
# MJY:: Fuse the multiplication for speed.
|
104 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
|
105 |
+
else:
|
106 |
+
output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
|
107 |
+
|
108 |
+
# Reshape it.
|
109 |
+
return output.view(input_shape)
|
110 |
+
|
111 |
+
def __data_parallel_replicate__(self, ctx, copy_id):
|
112 |
+
self._is_parallel = True
|
113 |
+
self._parallel_id = copy_id
|
114 |
+
|
115 |
+
# parallel_id == 0 means master device.
|
116 |
+
if self._parallel_id == 0:
|
117 |
+
ctx.sync_master = self._sync_master
|
118 |
+
else:
|
119 |
+
self._slave_pipe = ctx.sync_master.register_slave(copy_id)
|
120 |
+
|
121 |
+
def _data_parallel_master(self, intermediates):
|
122 |
+
"""Reduce the sum and square-sum, compute the statistics, and broadcast it."""
|
123 |
+
|
124 |
+
# Always using same "device order" makes the ReduceAdd operation faster.
|
125 |
+
# Thanks to:: Tete Xiao (http://tetexiao.com/)
|
126 |
+
intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
|
127 |
+
|
128 |
+
to_reduce = [i[1][:2] for i in intermediates]
|
129 |
+
to_reduce = [j for i in to_reduce for j in i] # flatten
|
130 |
+
target_gpus = [i[1].sum.get_device() for i in intermediates]
|
131 |
+
|
132 |
+
sum_size = sum([i[1].sum_size for i in intermediates])
|
133 |
+
sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
|
134 |
+
mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
|
135 |
+
|
136 |
+
broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
|
137 |
+
|
138 |
+
outputs = []
|
139 |
+
for i, rec in enumerate(intermediates):
|
140 |
+
outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
|
141 |
+
|
142 |
+
return outputs
|
143 |
+
|
144 |
+
def _compute_mean_std(self, sum_, ssum, size):
|
145 |
+
"""Compute the mean and standard-deviation with sum and square-sum. This method
|
146 |
+
also maintains the moving average on the master device."""
|
147 |
+
assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
|
148 |
+
mean = sum_ / size
|
149 |
+
sumvar = ssum - sum_ * mean
|
150 |
+
unbias_var = sumvar / (size - 1)
|
151 |
+
bias_var = sumvar / size
|
152 |
+
|
153 |
+
if hasattr(torch, 'no_grad'):
|
154 |
+
with torch.no_grad():
|
155 |
+
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
156 |
+
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
157 |
+
else:
|
158 |
+
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
|
159 |
+
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
|
160 |
+
|
161 |
+
if SBN_EPS_MODE == 'clamp':
|
162 |
+
return mean, bias_var.clamp(self.eps) ** -0.5
|
163 |
+
elif SBN_EPS_MODE == 'plus':
|
164 |
+
return mean, (bias_var + self.eps) ** -0.5
|
165 |
+
else:
|
166 |
+
raise ValueError('Unknown EPS mode: {}.'.format(SBN_EPS_MODE))
|
167 |
+
|
168 |
+
|
169 |
+
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
|
170 |
+
r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
|
171 |
+
mini-batch.
|
172 |
+
|
173 |
+
.. math::
|
174 |
+
|
175 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
176 |
+
|
177 |
+
This module differs from the built-in PyTorch BatchNorm1d as the mean and
|
178 |
+
standard-deviation are reduced across all devices during training.
|
179 |
+
|
180 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
181 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
182 |
+
the statistics only on that device, which accelerated the computation and
|
183 |
+
is also easy to implement, but the statistics might be inaccurate.
|
184 |
+
Instead, in this synchronized version, the statistics will be computed
|
185 |
+
over all training samples distributed on multiple devices.
|
186 |
+
|
187 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
188 |
+
as the built-in PyTorch implementation.
|
189 |
+
|
190 |
+
The mean and standard-deviation are calculated per-dimension over
|
191 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
192 |
+
of size C (where C is the input size).
|
193 |
+
|
194 |
+
During training, this layer keeps a running estimate of its computed mean
|
195 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
196 |
+
|
197 |
+
During evaluation, this running mean/variance is used for normalization.
|
198 |
+
|
199 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
200 |
+
on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
|
201 |
+
|
202 |
+
Args:
|
203 |
+
num_features: num_features from an expected input of size
|
204 |
+
`batch_size x num_features [x width]`
|
205 |
+
eps: a value added to the denominator for numerical stability.
|
206 |
+
Default: 1e-5
|
207 |
+
momentum: the value used for the running_mean and running_var
|
208 |
+
computation. Default: 0.1
|
209 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
210 |
+
affine parameters. Default: ``True``
|
211 |
+
|
212 |
+
Shape::
|
213 |
+
- Input: :math:`(N, C)` or :math:`(N, C, L)`
|
214 |
+
- Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
|
215 |
+
|
216 |
+
Examples:
|
217 |
+
>>> # With Learnable Parameters
|
218 |
+
>>> m = SynchronizedBatchNorm1d(100)
|
219 |
+
>>> # Without Learnable Parameters
|
220 |
+
>>> m = SynchronizedBatchNorm1d(100, affine=False)
|
221 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100))
|
222 |
+
>>> output = m(input)
|
223 |
+
"""
|
224 |
+
|
225 |
+
def _check_input_dim(self, input):
|
226 |
+
if input.dim() != 2 and input.dim() != 3:
|
227 |
+
raise ValueError('expected 2D or 3D input (got {}D input)'
|
228 |
+
.format(input.dim()))
|
229 |
+
|
230 |
+
|
231 |
+
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
|
232 |
+
r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
|
233 |
+
of 3d inputs
|
234 |
+
|
235 |
+
.. math::
|
236 |
+
|
237 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
238 |
+
|
239 |
+
This module differs from the built-in PyTorch BatchNorm2d as the mean and
|
240 |
+
standard-deviation are reduced across all devices during training.
|
241 |
+
|
242 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
243 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
244 |
+
the statistics only on that device, which accelerated the computation and
|
245 |
+
is also easy to implement, but the statistics might be inaccurate.
|
246 |
+
Instead, in this synchronized version, the statistics will be computed
|
247 |
+
over all training samples distributed on multiple devices.
|
248 |
+
|
249 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
250 |
+
as the built-in PyTorch implementation.
|
251 |
+
|
252 |
+
The mean and standard-deviation are calculated per-dimension over
|
253 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
254 |
+
of size C (where C is the input size).
|
255 |
+
|
256 |
+
During training, this layer keeps a running estimate of its computed mean
|
257 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
258 |
+
|
259 |
+
During evaluation, this running mean/variance is used for normalization.
|
260 |
+
|
261 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
262 |
+
on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
|
263 |
+
|
264 |
+
Args:
|
265 |
+
num_features: num_features from an expected input of
|
266 |
+
size batch_size x num_features x height x width
|
267 |
+
eps: a value added to the denominator for numerical stability.
|
268 |
+
Default: 1e-5
|
269 |
+
momentum: the value used for the running_mean and running_var
|
270 |
+
computation. Default: 0.1
|
271 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
272 |
+
affine parameters. Default: ``True``
|
273 |
+
|
274 |
+
Shape::
|
275 |
+
- Input: :math:`(N, C, H, W)`
|
276 |
+
- Output: :math:`(N, C, H, W)` (same shape as input)
|
277 |
+
|
278 |
+
Examples:
|
279 |
+
>>> # With Learnable Parameters
|
280 |
+
>>> m = SynchronizedBatchNorm2d(100)
|
281 |
+
>>> # Without Learnable Parameters
|
282 |
+
>>> m = SynchronizedBatchNorm2d(100, affine=False)
|
283 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
|
284 |
+
>>> output = m(input)
|
285 |
+
"""
|
286 |
+
|
287 |
+
def _check_input_dim(self, input):
|
288 |
+
if input.dim() != 4:
|
289 |
+
raise ValueError('expected 4D input (got {}D input)'
|
290 |
+
.format(input.dim()))
|
291 |
+
|
292 |
+
|
293 |
+
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
|
294 |
+
r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
|
295 |
+
of 4d inputs
|
296 |
+
|
297 |
+
.. math::
|
298 |
+
|
299 |
+
y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
|
300 |
+
|
301 |
+
This module differs from the built-in PyTorch BatchNorm3d as the mean and
|
302 |
+
standard-deviation are reduced across all devices during training.
|
303 |
+
|
304 |
+
For example, when one uses `nn.DataParallel` to wrap the network during
|
305 |
+
training, PyTorch's implementation normalize the tensor on each device using
|
306 |
+
the statistics only on that device, which accelerated the computation and
|
307 |
+
is also easy to implement, but the statistics might be inaccurate.
|
308 |
+
Instead, in this synchronized version, the statistics will be computed
|
309 |
+
over all training samples distributed on multiple devices.
|
310 |
+
|
311 |
+
Note that, for one-GPU or CPU-only case, this module behaves exactly same
|
312 |
+
as the built-in PyTorch implementation.
|
313 |
+
|
314 |
+
The mean and standard-deviation are calculated per-dimension over
|
315 |
+
the mini-batches and gamma and beta are learnable parameter vectors
|
316 |
+
of size C (where C is the input size).
|
317 |
+
|
318 |
+
During training, this layer keeps a running estimate of its computed mean
|
319 |
+
and variance. The running sum is kept with a default momentum of 0.1.
|
320 |
+
|
321 |
+
During evaluation, this running mean/variance is used for normalization.
|
322 |
+
|
323 |
+
Because the BatchNorm is done over the `C` dimension, computing statistics
|
324 |
+
on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
|
325 |
+
or Spatio-temporal BatchNorm
|
326 |
+
|
327 |
+
Args:
|
328 |
+
num_features: num_features from an expected input of
|
329 |
+
size batch_size x num_features x depth x height x width
|
330 |
+
eps: a value added to the denominator for numerical stability.
|
331 |
+
Default: 1e-5
|
332 |
+
momentum: the value used for the running_mean and running_var
|
333 |
+
computation. Default: 0.1
|
334 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
335 |
+
affine parameters. Default: ``True``
|
336 |
+
|
337 |
+
Shape::
|
338 |
+
- Input: :math:`(N, C, D, H, W)`
|
339 |
+
- Output: :math:`(N, C, D, H, W)` (same shape as input)
|
340 |
+
|
341 |
+
Examples:
|
342 |
+
>>> # With Learnable Parameters
|
343 |
+
>>> m = SynchronizedBatchNorm3d(100)
|
344 |
+
>>> # Without Learnable Parameters
|
345 |
+
>>> m = SynchronizedBatchNorm3d(100, affine=False)
|
346 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
|
347 |
+
>>> output = m(input)
|
348 |
+
"""
|
349 |
+
|
350 |
+
def _check_input_dim(self, input):
|
351 |
+
if input.dim() != 5:
|
352 |
+
raise ValueError('expected 5D input (got {}D input)'
|
353 |
+
.format(input.dim()))
|
354 |
+
|
355 |
+
|
356 |
+
@contextlib.contextmanager
|
357 |
+
def patch_sync_batchnorm():
|
358 |
+
import torch.nn as nn
|
359 |
+
|
360 |
+
backup = nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d
|
361 |
+
|
362 |
+
nn.BatchNorm1d = SynchronizedBatchNorm1d
|
363 |
+
nn.BatchNorm2d = SynchronizedBatchNorm2d
|
364 |
+
nn.BatchNorm3d = SynchronizedBatchNorm3d
|
365 |
+
|
366 |
+
yield
|
367 |
+
|
368 |
+
nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d = backup
|
369 |
+
|
370 |
+
|
371 |
+
def convert_model(module):
|
372 |
+
"""Traverse the input module and its child recursively
|
373 |
+
and replace all instance of torch.nn.modules.batchnorm.BatchNorm*N*d
|
374 |
+
to SynchronizedBatchNorm*N*d
|
375 |
+
|
376 |
+
Args:
|
377 |
+
module: the input module needs to be convert to SyncBN model
|
378 |
+
|
379 |
+
Examples:
|
380 |
+
>>> import torch.nn as nn
|
381 |
+
>>> import torchvision
|
382 |
+
>>> # m is a standard pytorch model
|
383 |
+
>>> m = torchvision.models.resnet18(True)
|
384 |
+
>>> m = nn.DataParallel(m)
|
385 |
+
>>> # after convert, m is using SyncBN
|
386 |
+
>>> m = convert_model(m)
|
387 |
+
"""
|
388 |
+
if isinstance(module, torch.nn.DataParallel):
|
389 |
+
mod = module.module
|
390 |
+
mod = convert_model(mod)
|
391 |
+
mod = DataParallelWithCallback(mod, device_ids=module.device_ids)
|
392 |
+
return mod
|
393 |
+
|
394 |
+
mod = module
|
395 |
+
for pth_module, sync_module in zip([torch.nn.modules.batchnorm.BatchNorm1d,
|
396 |
+
torch.nn.modules.batchnorm.BatchNorm2d,
|
397 |
+
torch.nn.modules.batchnorm.BatchNorm3d],
|
398 |
+
[SynchronizedBatchNorm1d,
|
399 |
+
SynchronizedBatchNorm2d,
|
400 |
+
SynchronizedBatchNorm3d]):
|
401 |
+
if isinstance(module, pth_module):
|
402 |
+
mod = sync_module(module.num_features, module.eps, module.momentum, module.affine)
|
403 |
+
mod.running_mean = module.running_mean
|
404 |
+
mod.running_var = module.running_var
|
405 |
+
if module.affine:
|
406 |
+
mod.weight.data = module.weight.data.clone().detach()
|
407 |
+
mod.bias.data = module.bias.data.clone().detach()
|
408 |
+
|
409 |
+
for name, child in module.named_children():
|
410 |
+
mod.add_module(name, convert_model(child))
|
411 |
+
|
412 |
+
return mod
|
sync_batchnorm/batchnorm_reimpl.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! /usr/bin/env python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
# File : batchnorm_reimpl.py
|
4 |
+
# Author : acgtyrant
|
5 |
+
# Date : 11/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.nn.init as init
|
14 |
+
|
15 |
+
__all__ = ['BatchNorm2dReimpl']
|
16 |
+
|
17 |
+
|
18 |
+
class BatchNorm2dReimpl(nn.Module):
|
19 |
+
"""
|
20 |
+
A re-implementation of batch normalization, used for testing the numerical
|
21 |
+
stability.
|
22 |
+
|
23 |
+
Author: acgtyrant
|
24 |
+
See also:
|
25 |
+
https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
|
26 |
+
"""
|
27 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1):
|
28 |
+
super().__init__()
|
29 |
+
|
30 |
+
self.num_features = num_features
|
31 |
+
self.eps = eps
|
32 |
+
self.momentum = momentum
|
33 |
+
self.weight = nn.Parameter(torch.empty(num_features))
|
34 |
+
self.bias = nn.Parameter(torch.empty(num_features))
|
35 |
+
self.register_buffer('running_mean', torch.zeros(num_features))
|
36 |
+
self.register_buffer('running_var', torch.ones(num_features))
|
37 |
+
self.reset_parameters()
|
38 |
+
|
39 |
+
def reset_running_stats(self):
|
40 |
+
self.running_mean.zero_()
|
41 |
+
self.running_var.fill_(1)
|
42 |
+
|
43 |
+
def reset_parameters(self):
|
44 |
+
self.reset_running_stats()
|
45 |
+
init.uniform_(self.weight)
|
46 |
+
init.zeros_(self.bias)
|
47 |
+
|
48 |
+
def forward(self, input_):
|
49 |
+
batchsize, channels, height, width = input_.size()
|
50 |
+
numel = batchsize * height * width
|
51 |
+
input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
|
52 |
+
sum_ = input_.sum(1)
|
53 |
+
sum_of_square = input_.pow(2).sum(1)
|
54 |
+
mean = sum_ / numel
|
55 |
+
sumvar = sum_of_square - sum_ * mean
|
56 |
+
|
57 |
+
self.running_mean = (
|
58 |
+
(1 - self.momentum) * self.running_mean
|
59 |
+
+ self.momentum * mean.detach()
|
60 |
+
)
|
61 |
+
unbias_var = sumvar / (numel - 1)
|
62 |
+
self.running_var = (
|
63 |
+
(1 - self.momentum) * self.running_var
|
64 |
+
+ self.momentum * unbias_var.detach()
|
65 |
+
)
|
66 |
+
|
67 |
+
bias_var = sumvar / numel
|
68 |
+
inv_std = 1 / (bias_var + self.eps).pow(0.5)
|
69 |
+
output = (
|
70 |
+
(input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
|
71 |
+
self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
|
72 |
+
|
73 |
+
return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
|
74 |
+
|
sync_batchnorm/comm.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : comm.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import queue
|
12 |
+
import collections
|
13 |
+
import threading
|
14 |
+
|
15 |
+
__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
|
16 |
+
|
17 |
+
|
18 |
+
class FutureResult(object):
|
19 |
+
"""A thread-safe future implementation. Used only as one-to-one pipe."""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
self._result = None
|
23 |
+
self._lock = threading.Lock()
|
24 |
+
self._cond = threading.Condition(self._lock)
|
25 |
+
|
26 |
+
def put(self, result):
|
27 |
+
with self._lock:
|
28 |
+
assert self._result is None, 'Previous result has\'t been fetched.'
|
29 |
+
self._result = result
|
30 |
+
self._cond.notify()
|
31 |
+
|
32 |
+
def get(self):
|
33 |
+
with self._lock:
|
34 |
+
if self._result is None:
|
35 |
+
self._cond.wait()
|
36 |
+
|
37 |
+
res = self._result
|
38 |
+
self._result = None
|
39 |
+
return res
|
40 |
+
|
41 |
+
|
42 |
+
_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
|
43 |
+
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
|
44 |
+
|
45 |
+
|
46 |
+
class SlavePipe(_SlavePipeBase):
|
47 |
+
"""Pipe for master-slave communication."""
|
48 |
+
|
49 |
+
def run_slave(self, msg):
|
50 |
+
self.queue.put((self.identifier, msg))
|
51 |
+
ret = self.result.get()
|
52 |
+
self.queue.put(True)
|
53 |
+
return ret
|
54 |
+
|
55 |
+
|
56 |
+
class SyncMaster(object):
|
57 |
+
"""An abstract `SyncMaster` object.
|
58 |
+
|
59 |
+
- During the replication, as the data parallel will trigger an callback of each module, all slave devices should
|
60 |
+
call `register(id)` and obtain an `SlavePipe` to communicate with the master.
|
61 |
+
- During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
|
62 |
+
and passed to a registered callback.
|
63 |
+
- After receiving the messages, the master device should gather the information and determine to message passed
|
64 |
+
back to each slave devices.
|
65 |
+
"""
|
66 |
+
|
67 |
+
def __init__(self, master_callback):
|
68 |
+
"""
|
69 |
+
|
70 |
+
Args:
|
71 |
+
master_callback: a callback to be invoked after having collected messages from slave devices.
|
72 |
+
"""
|
73 |
+
self._master_callback = master_callback
|
74 |
+
self._queue = queue.Queue()
|
75 |
+
self._registry = collections.OrderedDict()
|
76 |
+
self._activated = False
|
77 |
+
|
78 |
+
def __getstate__(self):
|
79 |
+
return {'master_callback': self._master_callback}
|
80 |
+
|
81 |
+
def __setstate__(self, state):
|
82 |
+
self.__init__(state['master_callback'])
|
83 |
+
|
84 |
+
def register_slave(self, identifier):
|
85 |
+
"""
|
86 |
+
Register an slave device.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
identifier: an identifier, usually is the device id.
|
90 |
+
|
91 |
+
Returns: a `SlavePipe` object which can be used to communicate with the master device.
|
92 |
+
|
93 |
+
"""
|
94 |
+
if self._activated:
|
95 |
+
assert self._queue.empty(), 'Queue is not clean before next initialization.'
|
96 |
+
self._activated = False
|
97 |
+
self._registry.clear()
|
98 |
+
future = FutureResult()
|
99 |
+
self._registry[identifier] = _MasterRegistry(future)
|
100 |
+
return SlavePipe(identifier, self._queue, future)
|
101 |
+
|
102 |
+
def run_master(self, master_msg):
|
103 |
+
"""
|
104 |
+
Main entry for the master device in each forward pass.
|
105 |
+
The messages were first collected from each devices (including the master device), and then
|
106 |
+
an callback will be invoked to compute the message to be sent back to each devices
|
107 |
+
(including the master device).
|
108 |
+
|
109 |
+
Args:
|
110 |
+
master_msg: the message that the master want to send to itself. This will be placed as the first
|
111 |
+
message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
|
112 |
+
|
113 |
+
Returns: the message to be sent back to the master device.
|
114 |
+
|
115 |
+
"""
|
116 |
+
self._activated = True
|
117 |
+
|
118 |
+
intermediates = [(0, master_msg)]
|
119 |
+
for i in range(self.nr_slaves):
|
120 |
+
intermediates.append(self._queue.get())
|
121 |
+
|
122 |
+
results = self._master_callback(intermediates)
|
123 |
+
assert results[0][0] == 0, 'The first result should belongs to the master.'
|
124 |
+
|
125 |
+
for i, res in results:
|
126 |
+
if i == 0:
|
127 |
+
continue
|
128 |
+
self._registry[i].result.put(res)
|
129 |
+
|
130 |
+
for i in range(self.nr_slaves):
|
131 |
+
assert self._queue.get() is True
|
132 |
+
|
133 |
+
return results[0][1]
|
134 |
+
|
135 |
+
@property
|
136 |
+
def nr_slaves(self):
|
137 |
+
return len(self._registry)
|
sync_batchnorm/replicate.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : replicate.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import functools
|
12 |
+
|
13 |
+
from torch.nn.parallel.data_parallel import DataParallel
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
'CallbackContext',
|
17 |
+
'execute_replication_callbacks',
|
18 |
+
'DataParallelWithCallback',
|
19 |
+
'patch_replication_callback'
|
20 |
+
]
|
21 |
+
|
22 |
+
|
23 |
+
class CallbackContext(object):
|
24 |
+
pass
|
25 |
+
|
26 |
+
|
27 |
+
def execute_replication_callbacks(modules):
|
28 |
+
"""
|
29 |
+
Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
|
30 |
+
|
31 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
32 |
+
|
33 |
+
Note that, as all modules are isomorphism, we assign each sub-module with a context
|
34 |
+
(shared among multiple copies of this module on different devices).
|
35 |
+
Through this context, different copies can share some information.
|
36 |
+
|
37 |
+
We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
|
38 |
+
of any slave copies.
|
39 |
+
"""
|
40 |
+
master_copy = modules[0]
|
41 |
+
nr_modules = len(list(master_copy.modules()))
|
42 |
+
ctxs = [CallbackContext() for _ in range(nr_modules)]
|
43 |
+
|
44 |
+
for i, module in enumerate(modules):
|
45 |
+
for j, m in enumerate(module.modules()):
|
46 |
+
if hasattr(m, '__data_parallel_replicate__'):
|
47 |
+
m.__data_parallel_replicate__(ctxs[j], i)
|
48 |
+
|
49 |
+
|
50 |
+
class DataParallelWithCallback(DataParallel):
|
51 |
+
"""
|
52 |
+
Data Parallel with a replication callback.
|
53 |
+
|
54 |
+
An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
|
55 |
+
original `replicate` function.
|
56 |
+
The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
|
57 |
+
|
58 |
+
Examples:
|
59 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
60 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
61 |
+
# sync_bn.__data_parallel_replicate__ will be invoked.
|
62 |
+
"""
|
63 |
+
|
64 |
+
def replicate(self, module, device_ids):
|
65 |
+
modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
|
66 |
+
execute_replication_callbacks(modules)
|
67 |
+
return modules
|
68 |
+
|
69 |
+
|
70 |
+
def patch_replication_callback(data_parallel):
|
71 |
+
"""
|
72 |
+
Monkey-patch an existing `DataParallel` object. Add the replication callback.
|
73 |
+
Useful when you have customized `DataParallel` implementation.
|
74 |
+
|
75 |
+
Examples:
|
76 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
77 |
+
> sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
|
78 |
+
> patch_replication_callback(sync_bn)
|
79 |
+
# this is equivalent to
|
80 |
+
> sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
|
81 |
+
> sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
|
82 |
+
"""
|
83 |
+
|
84 |
+
assert isinstance(data_parallel, DataParallel)
|
85 |
+
|
86 |
+
old_replicate = data_parallel.replicate
|
87 |
+
|
88 |
+
@functools.wraps(old_replicate)
|
89 |
+
def new_replicate(module, device_ids):
|
90 |
+
modules = old_replicate(module, device_ids)
|
91 |
+
execute_replication_callbacks(modules)
|
92 |
+
return modules
|
93 |
+
|
94 |
+
data_parallel.replicate = new_replicate
|
sync_batchnorm/unittest.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# File : unittest.py
|
3 |
+
# Author : Jiayuan Mao
|
4 |
+
# Email : [email protected]
|
5 |
+
# Date : 27/01/2018
|
6 |
+
#
|
7 |
+
# This file is part of Synchronized-BatchNorm-PyTorch.
|
8 |
+
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
|
9 |
+
# Distributed under MIT License.
|
10 |
+
|
11 |
+
import unittest
|
12 |
+
import torch
|
13 |
+
|
14 |
+
|
15 |
+
class TorchTestCase(unittest.TestCase):
|
16 |
+
def assertTensorClose(self, x, y):
|
17 |
+
adiff = float((x - y).abs().max())
|
18 |
+
if (y == 0).all():
|
19 |
+
rdiff = 'NaN'
|
20 |
+
else:
|
21 |
+
rdiff = float((adiff / y).abs().max())
|
22 |
+
|
23 |
+
message = (
|
24 |
+
'Tensor close check failed\n'
|
25 |
+
'adiff={}\n'
|
26 |
+
'rdiff={}\n'
|
27 |
+
).format(adiff, rdiff)
|
28 |
+
self.assertTrue(torch.allclose(x, y, atol=1e-5, rtol=1e-3), message)
|
29 |
+
|
test.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from label_colors import colorMap
|
3 |
+
from PIL import Image
|
4 |
+
from spade.model import Pix2PixModel
|
5 |
+
from spade.dataset import get_transform
|
6 |
+
from torchvision.transforms import ToPILImage
|
7 |
+
|
8 |
+
'''colors = np.array([[56, 79, 131], [239, 239, 239],
|
9 |
+
[93, 110, 50], [183, 210, 78],
|
10 |
+
[60, 59, 75], [250, 250, 250]])'''
|
11 |
+
colors = [key['color'] for key in colorMap]
|
12 |
+
id_list = [key['id'] for key in colorMap]
|
13 |
+
|
14 |
+
|
15 |
+
def semantic(img):
|
16 |
+
print("semantic", type(img))
|
17 |
+
h, w = img.size
|
18 |
+
imrgb = img.convert("RGB")
|
19 |
+
pix = list(imrgb.getdata())
|
20 |
+
mask = [id_list[colors.index(i)] if i in colors else 156 for i in pix]
|
21 |
+
return np.array(mask).reshape(h, w)
|
22 |
+
|
23 |
+
|
24 |
+
def evaluate(labelmap):
|
25 |
+
opt = {
|
26 |
+
'label_nc': 182, # num classes in coco model
|
27 |
+
'crop_size': 512,
|
28 |
+
'load_size': 512,
|
29 |
+
'aspect_ratio': 1.0,
|
30 |
+
'isTrain': False,
|
31 |
+
'checkpoints_dir': 'app',
|
32 |
+
'which_epoch': 'latest',
|
33 |
+
'use_gpu': False
|
34 |
+
}
|
35 |
+
model = Pix2PixModel(opt)
|
36 |
+
model.eval()
|
37 |
+
image = Image.fromarray(np.array(labelmap).astype(np.uint8))
|
38 |
+
transform_label = get_transform(opt, method=Image.NEAREST, normalize=False)
|
39 |
+
# transforms.ToTensor in transform_label rescales image from [0,255] to [0.0,1.0]
|
40 |
+
# lets rescale it back to [0,255] to match our label ids
|
41 |
+
label_tensor = transform_label(image) * 255.0
|
42 |
+
label_tensor[label_tensor == 255] = opt['label_nc'] # 'unknown' is opt.label_nc
|
43 |
+
print("label_tensor:", label_tensor.shape)
|
44 |
+
|
45 |
+
# not using encoder, so creating a blank image...
|
46 |
+
transform_image = get_transform(opt)
|
47 |
+
image_tensor = transform_image(Image.new('RGB', (500, 500)))
|
48 |
+
|
49 |
+
data = {
|
50 |
+
'label': label_tensor.unsqueeze(0),
|
51 |
+
'instance': label_tensor.unsqueeze(0),
|
52 |
+
'image': image_tensor.unsqueeze(0)
|
53 |
+
}
|
54 |
+
generated = model(data, mode='inference')
|
55 |
+
print("generated_image:", generated.shape)
|
56 |
+
|
57 |
+
return generated
|
58 |
+
|
59 |
+
|
60 |
+
def to_image(generated):
|
61 |
+
to_img = ToPILImage()
|
62 |
+
normalized_img = ((generated.reshape([3, 512, 512]) + 1) / 2.0) * 255.0
|
63 |
+
return to_img(normalized_img.byte().cpu())
|