udion committed on
Commit
5bd623f
1 Parent(s): f8c9ab8

BayesCap demo to EuroCrypt

LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md CHANGED
@@ -1,13 +1,26 @@
  ---
  title: BayesCap
- emoji: 👁
- colorFrom: gray
- colorTo: pink
+ emoji: 🔥
+ colorFrom: indigo
+ colorTo: purple
  sdk: gradio
- sdk_version: 3.0.24
  app_file: app.py
  pinned: false
- license: cc
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Configuration
+ `title`: _string_
+ Display title for the Space
+ `emoji`: _string_
+ Space emoji (emoji-only character allowed)
+ `colorFrom`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+ `colorTo`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+ `sdk`: _string_
+ Can be either `gradio` or `streamlit`
+ `app_file`: _string_
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+ Path is relative to the root of the repository.
+
+ `pinned`: _boolean_
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,152 @@
+ import numpy as np
+ import random
+ import matplotlib.pyplot as plt
+ from matplotlib import cm
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ from torch.utils.data import Dataset, DataLoader
+ from torchvision import transforms
+ from torchvision.transforms import functional as TF
+ from torchvision.transforms.functional import InterpolationMode as IMode
+
+ from PIL import Image
+
+ from ds import *
+ from losses import *
+ from networks_SRGAN import *
+ from utils import *
+
+ device = 'cpu'
+ if device == 'cuda':
+     dtype = torch.cuda.FloatTensor
+ else:
+     dtype = torch.FloatTensor
+
+ NetG = Generator()
+ model_parameters = filter(lambda p: True, NetG.parameters())
+ params = sum([np.prod(p.size()) for p in model_parameters])
+ print("Number of Parameters:", params)
+ NetC = BayesCap(in_channels=3, out_channels=3)
+
+ ensure_checkpoint_exists('BayesCap_SRGAN.pth')
+ NetG.load_state_dict(torch.load('BayesCap_SRGAN.pth', map_location=device))
+ NetG.to(device)
+ NetG.eval()
+
+ ensure_checkpoint_exists('BayesCap_ckpt.pth')
+ NetC.load_state_dict(torch.load('BayesCap_ckpt.pth', map_location=device))
+ NetC.to(device)
+ NetC.eval()
+
+ def tensor01_to_pil(xt):
+     r = transforms.ToPILImage(mode='RGB')(xt.squeeze())
+     return r
+
+ def image2tensor(image, range_norm: bool, half: bool) -> torch.Tensor:
+     """Convert a ``PIL.Image`` to a Tensor.
+     Args:
+         image: The image data read by ``PIL.Image``
+         range_norm (bool): Scale [0, 1] data to [-1, 1]
+         half (bool): Whether to convert torch.float32 to torch.half
+     Returns:
+         Normalized image data
+     Examples:
+         >>> image = Image.open("image.bmp")
+         >>> tensor_image = image2tensor(image, range_norm=False, half=False)
+     """
+     # torchvision's to_tensor (torch.nn.functional has no to_tensor)
+     tensor = TF.to_tensor(image)
+
+     if range_norm:
+         tensor = tensor.mul_(2.0).sub_(1.0)
+     if half:
+         tensor = tensor.half()
+
+     return tensor
+
+
+ def predict(img):
+     """
+     img: input image
+     """
+     image_size = (256, 256)
+     upscale_factor = 4
+     # lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
+     # to retain aspect ratio
+     lr_transforms = transforms.Resize(image_size[0]//upscale_factor, interpolation=IMode.BICUBIC, antialias=True)
+     # lr_transforms = transforms.Resize((128, 128), interpolation=IMode.BICUBIC, antialias=True)
+
+     img = Image.fromarray(np.array(img))
+     img = lr_transforms(img)
+     lr_tensor = image2tensor(img, range_norm=False, half=False)
+
+     xLR = lr_tensor.to(device).unsqueeze(0)
+     xLR = xLR.type(dtype)
+     # pass them through the network
+     with torch.no_grad():
+         xSR = NetG(xLR)
+         xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
+
+     a_map = (1/(xSRC_alpha[0] + 1e-5)).to('cpu').data
+     b_map = xSRC_beta[0].to('cpu').data
+     u_map = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
+
+     x_LR = tensor01_to_pil(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
+
+     x_mean = tensor01_to_pil(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
+
+     # im = Image.fromarray(np.uint8(cm.gist_earth(myarray)*255))
+
+     a_map = torch.clamp(a_map, min=0, max=0.1)
+     a_map = (a_map - a_map.min())/(a_map.max() - a_map.min())
+     x_alpha = Image.fromarray(np.uint8(cm.inferno(a_map.transpose(0,2).transpose(0,1).squeeze())*255))
+
+     b_map = torch.clamp(b_map, min=0.45, max=0.75)
+     b_map = (b_map - b_map.min())/(b_map.max() - b_map.min())
+     x_beta = Image.fromarray(np.uint8(cm.cividis(b_map.transpose(0,2).transpose(0,1).squeeze())*255))
+
+     u_map = torch.clamp(u_map, min=0, max=0.15)
+     u_map = (u_map - u_map.min())/(u_map.max() - u_map.min())
+     x_uncer = Image.fromarray(np.uint8(cm.hot(u_map.transpose(0,2).transpose(0,1).squeeze())*255))
+
+     return x_LR, x_mean, x_alpha, x_beta, x_uncer
+
+ import gradio as gr
+
+ title = "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks"
+
+ abstract = "<b>Abstract.</b> High-quality calibrated uncertainty estimates are crucial for numerous real-world applications, especially for deep learning-based deployed ML systems. While Bayesian deep learning techniques allow uncertainty estimation, training them with large-scale datasets is an expensive process that does not always yield models competitive with non-Bayesian counterparts. Moreover, many of the high-performing deep learning models that are already trained and deployed are non-Bayesian in nature and do not provide uncertainty estimates. To address these issues, we propose BayesCap that learns a Bayesian identity mapping for the frozen model, allowing uncertainty estimation. BayesCap is a memory-efficient method that can be trained on a small fraction of the original dataset, enhancing pretrained non-Bayesian computer vision models by providing calibrated uncertainty estimates for the predictions without (i) hampering the performance of the model and (ii) the need for expensive retraining of the model from scratch. The proposed method is agnostic to various architectures and tasks. We show the efficacy of our method on a wide variety of tasks with a diverse set of architectures, including image super-resolution, deblurring, inpainting, and crucial applications such as medical image translation. Moreover, we apply the derived uncertainty estimates to detect out-of-distribution samples in critical scenarios like depth estimation in autonomous driving. Code is available <a href='https://github.com/ExplainableML/BayesCap'>here</a>. <br> <br>"
+
+ method = "In this demo, we show an application of BayesCap on top of SRGAN for the task of super-resolution. BayesCap estimates the per-pixel uncertainty of a pretrained computer vision model like SRGAN (used for super-resolution). BayesCap takes the output of the pretrained model (in this case SRGAN) and predicts the per-pixel distribution parameters for the output, which can be used to quantify the per-pixel uncertainty. In our work, we model the per-pixel output as a <a href='https://en.wikipedia.org/wiki/Generalized_normal_distribution'>Generalized Gaussian distribution</a> that is parameterized by 3 parameters: the mean, the scale (alpha), and the shape (beta). As a result, our model predicts these three parameters as shown below. From these 3 parameters one can compute the uncertainty as shown in <a href='https://en.wikipedia.org/wiki/Generalized_normal_distribution'>this article</a>. <br><br>"
+
+ closing = "For more details, please find the <a href='https://arxiv.org/'>ECCV 2022 paper here</a>."
+
+ description = abstract + method + closing
+
+ article = "<p style='text-align: center'> BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks | <a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"
+
+
+ gr.Interface(
+     fn=predict,
+     inputs=gr.inputs.Image(type='pil', label="Original"),
+     outputs=[
+         gr.outputs.Image(type='pil', label="Low-resolution image (input to SRGAN)"),
+         gr.outputs.Image(type='pil', label="Super-resolved image (output of SRGAN)"),
+         gr.outputs.Image(type='pil', label="Alpha parameter map characterizing per-pixel distribution (output of BayesCap)"),
+         gr.outputs.Image(type='pil', label="Beta parameter map characterizing per-pixel distribution (output of BayesCap)"),
+         gr.outputs.Image(type='pil', label="Per-pixel uncertainty map (derived using outputs of BayesCap)")
+     ],
+     title=title,
+     description=description,
+     article=article,
+     examples=[
+         ["./demo_examples/tue.jpeg"],
+         ["./demo_examples/baby.png"],
+         ["./demo_examples/bird.png"],
+         ["./demo_examples/butterfly.png"],
+         ["./demo_examples/head.png"],
+         ["./demo_examples/woman.png"],
+     ]
+ ).launch()
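For reference, the `u_map` computed inside `predict` follows the variance of the Generalized Gaussian distribution that the demo text links to, Var = alpha^2 * Gamma(3/beta) / Gamma(1/beta), with `a_map` playing the role of the scale alpha. A minimal sketch of that formula as a standalone helper (not part of the committed files; the epsilon mirrors the one used in `predict`):

import torch

def ggd_variance(alpha: torch.Tensor, beta: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    # Per-pixel variance of a Generalized Gaussian: alpha^2 * Gamma(3/beta) / Gamma(1/beta),
    # computed via lgamma differences for numerical stability.
    log_ratio = torch.lgamma(3.0 / (beta + eps)) - torch.lgamma(1.0 / (beta + eps))
    return (alpha ** 2) * torch.exp(log_ratio)

# e.g. with the maps produced inside predict: u_map = ggd_variance(a_map, b_map)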
demo_examples/baby.png ADDED
demo_examples/bird.png ADDED
demo_examples/butterfly.png ADDED
demo_examples/head.png ADDED
demo_examples/tue.jpeg ADDED
demo_examples/woman.png ADDED
ds.py ADDED
@@ -0,0 +1,485 @@
+ from __future__ import absolute_import, division, print_function
+
+ import random
+ import copy
+ import io
+ import os
+ import numpy as np
+ from PIL import Image
+ import skimage.transform
+ from collections import Counter
+
+
+ import torch
+ import torch.utils.data as data
+ from torch import Tensor
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from torchvision.transforms.functional import InterpolationMode as IMode
+
+ import utils
+
+ class ImgDset(Dataset):
+     """Custom dataset that prepares low/high resolution image pairs in advance.
+
+     Args:
+         dataroot (str): Training data set directory
+         image_size (int): High resolution image size
+         upscale_factor (int): Image magnification factor
+         mode (str): Data set loading mode; the training split uses data augmentation,
+             the validation split does not
+     """
+
+     def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
+         super(ImgDset, self).__init__()
+         self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]
+
+         if mode == "train":
+             self.hr_transforms = transforms.Compose([
+                 transforms.RandomCrop(image_size),
+                 transforms.RandomRotation(90),
+                 transforms.RandomHorizontalFlip(0.5),
+             ])
+         else:
+             self.hr_transforms = transforms.Resize(image_size)
+
+         self.lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
+
+     def __getitem__(self, batch_index: int) -> [Tensor, Tensor]:
+         # Read a batch of image data
+         image = Image.open(self.filenames[batch_index])
+
+         # Transform image
+         hr_image = self.hr_transforms(image)
+         lr_image = self.lr_transforms(hr_image)
+
+         # Convert image data into Tensor stream format (PyTorch).
+         # Note: The range of input and output is between [0, 1]
+         lr_tensor = utils.image2tensor(lr_image, range_norm=False, half=False)
+         hr_tensor = utils.image2tensor(hr_image, range_norm=False, half=False)
+
+         return lr_tensor, hr_tensor
+
+     def __len__(self) -> int:
+         return len(self.filenames)
+
+
+ class PairedImages_w_nameList(Dataset):
+     '''
+     Can act as a supervised or un-supervised dataset based on the file lists.
+     '''
+     def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
+         self.flist1 = flist1
+         self.flist2 = flist2
+         self.transform1 = transform1
+         self.transform2 = transform2
+         self.do_aug = do_aug
+     def __getitem__(self, index):
+         impath1 = self.flist1[index]
+         img1 = Image.open(impath1).convert('RGB')
+         impath2 = self.flist2[index]
+         img2 = Image.open(impath2).convert('RGB')
+
+         img1 = utils.image2tensor(img1, range_norm=False, half=False)
+         img2 = utils.image2tensor(img2, range_norm=False, half=False)
+
+         if self.transform1 is not None:
+             img1 = self.transform1(img1)
+         if self.transform2 is not None:
+             img2 = self.transform2(img2)
+
+         return img1, img2
+     def __len__(self):
+         return len(self.flist1)
+
+ class PairedImages_w_nameList_npy(Dataset):
+     '''
+     Can act as a supervised or un-supervised dataset based on the file lists.
+     '''
+     def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
+         self.flist1 = flist1
+         self.flist2 = flist2
+         self.transform1 = transform1
+         self.transform2 = transform2
+         self.do_aug = do_aug
+     def __getitem__(self, index):
+         impath1 = self.flist1[index]
+         img1 = np.load(impath1)
+         impath2 = self.flist2[index]
+         img2 = np.load(impath2)
+
+         if self.transform1 is not None:
+             img1 = self.transform1(img1)
+         if self.transform2 is not None:
+             img2 = self.transform2(img2)
+
+         return img1, img2
+     def __len__(self):
+         return len(self.flist1)
+
+ # def call_paired():
+ #     root1='./GOPRO_3840FPS_AVG_3-21/train/blur/'
+ #     root2='./GOPRO_3840FPS_AVG_3-21/train/sharp/'
+
+ #     flist1=glob.glob(root1+'/*/*.png')
+ #     flist2=glob.glob(root2+'/*/*.png')
+
+ #     dset = PairedImages_w_nameList(root1,root2,flist1,flist2)
+
+ #### KITTI depth
+
+ def load_velodyne_points(filename):
+     """Load 3D point cloud from KITTI file format
+     (adapted from https://github.com/hunse/kitti)
+     """
+     points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4)
+     points[:, 3] = 1.0  # homogeneous
+     return points
+
+
+ def read_calib_file(path):
+     """Read KITTI calibration file
+     (from https://github.com/hunse/kitti)
+     """
+     float_chars = set("0123456789.e+- ")
+     data = {}
+     with open(path, 'r') as f:
+         for line in f.readlines():
+             key, value = line.split(':', 1)
+             value = value.strip()
+             data[key] = value
+             if float_chars.issuperset(value):
+                 # try to cast to float array
+                 try:
+                     data[key] = np.array(list(map(float, value.split(' '))))
+                 except ValueError:
+                     # casting error: data[key] already eq. value, so pass
+                     pass
+
+     return data
+
+
+ def sub2ind(matrixSize, rowSub, colSub):
+     """Convert row, col matrix subscripts to linear indices
+     """
+     m, n = matrixSize
+     return rowSub * (n-1) + colSub - 1
+
+
+ def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
+     """Generate a depth map from velodyne data
+     """
+     # load calibration files
+     cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
+     velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
+     velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
+     velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))
+
+     # get image shape
+     im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)
+
+     # compute projection matrix velodyne->image plane
+     R_cam2rect = np.eye(4)
+     R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
+     P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4)
+     P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)
+
+     # load velodyne points and remove all behind image plane (approximation)
+     # each row of the velodyne data is forward, left, up, reflectance
+     velo = load_velodyne_points(velo_filename)
+     velo = velo[velo[:, 0] >= 0, :]
+
+     # project the points to the camera
+     velo_pts_im = np.dot(P_velo2im, velo.T).T
+     velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]
+
+     if vel_depth:
+         velo_pts_im[:, 2] = velo[:, 0]
+
+     # check if in bounds
+     # use minus 1 to get the exact same value as KITTI matlab code
+     velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
+     velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
+     val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
+     val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
+     velo_pts_im = velo_pts_im[val_inds, :]
+
+     # project to image
+     depth = np.zeros((im_shape[:2]))
+     depth[velo_pts_im[:, 1].astype(np.int), velo_pts_im[:, 0].astype(np.int)] = velo_pts_im[:, 2]
+
+     # find the duplicate points and choose the closest depth
+     inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
+     dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
+     for dd in dupe_inds:
+         pts = np.where(inds == dd)[0]
+         x_loc = int(velo_pts_im[pts[0], 0])
+         y_loc = int(velo_pts_im[pts[0], 1])
+         depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
+     depth[depth < 0] = 0
+
+     return depth
+
+ def pil_loader(path):
+     # open path as file to avoid ResourceWarning
+     # (https://github.com/python-pillow/Pillow/issues/835)
+     with open(path, 'rb') as f:
+         with Image.open(f) as img:
+             return img.convert('RGB')
+
+
+ class MonoDataset(data.Dataset):
+     """Superclass for monocular dataloaders
+
+     Args:
+         data_path
+         filenames
+         height
+         width
+         frame_idxs
+         num_scales
+         is_train
+         img_ext
+     """
+     def __init__(self,
+                  data_path,
+                  filenames,
+                  height,
+                  width,
+                  frame_idxs,
+                  num_scales,
+                  is_train=False,
+                  img_ext='.jpg'):
+         super(MonoDataset, self).__init__()
+
+         self.data_path = data_path
+         self.filenames = filenames
+         self.height = height
+         self.width = width
+         self.num_scales = num_scales
+         self.interp = Image.ANTIALIAS
+
+         self.frame_idxs = frame_idxs
+
+         self.is_train = is_train
+         self.img_ext = img_ext
+
+         self.loader = pil_loader
+         self.to_tensor = transforms.ToTensor()
+
+         # We need to specify augmentations differently in newer versions of torchvision.
+         # We first try the newer tuple version; if this fails we fall back to scalars
+         try:
+             self.brightness = (0.8, 1.2)
+             self.contrast = (0.8, 1.2)
+             self.saturation = (0.8, 1.2)
+             self.hue = (-0.1, 0.1)
+             transforms.ColorJitter.get_params(
+                 self.brightness, self.contrast, self.saturation, self.hue)
+         except TypeError:
+             self.brightness = 0.2
+             self.contrast = 0.2
+             self.saturation = 0.2
+             self.hue = 0.1
+
+         self.resize = {}
+         for i in range(self.num_scales):
+             s = 2 ** i
+             self.resize[i] = transforms.Resize((self.height // s, self.width // s),
+                                                interpolation=self.interp)
+
+         self.load_depth = self.check_depth()
+
+     def preprocess(self, inputs, color_aug):
+         """Resize colour images to the required scales and augment if required
+
+         We create the color_aug object in advance and apply the same augmentation to all
+         images in this item. This ensures that all images input to the pose network receive the
+         same augmentation.
+         """
+         for k in list(inputs):
+             frame = inputs[k]
+             if "color" in k:
+                 n, im, i = k
+                 for i in range(self.num_scales):
+                     inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])
+
+         for k in list(inputs):
+             f = inputs[k]
+             if "color" in k:
+                 n, im, i = k
+                 inputs[(n, im, i)] = self.to_tensor(f)
+                 inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))
+
+     def __len__(self):
+         return len(self.filenames)
+
+     def __getitem__(self, index):
+         """Returns a single training item from the dataset as a dictionary.
+
+         Values correspond to torch tensors.
+         Keys in the dictionary are either strings or tuples:
+
+             ("color", <frame_id>, <scale>) for raw colour images,
+             ("color_aug", <frame_id>, <scale>) for augmented colour images,
+             ("K", scale) or ("inv_K", scale) for camera intrinsics,
+             "stereo_T" for camera extrinsics, and
+             "depth_gt" for ground truth depth maps.
+
+         <frame_id> is either:
+             an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index', or
+             "s" for the opposite image in the stereo pair.
+
+         <scale> is an integer representing the scale of the image relative to the fullsize image:
+             -1  images at native resolution as loaded from disk
+              0  images resized to (self.width,      self.height     )
+              1  images resized to (self.width // 2, self.height // 2)
+              2  images resized to (self.width // 4, self.height // 4)
+              3  images resized to (self.width // 8, self.height // 8)
+         """
+         inputs = {}
+
+         do_color_aug = self.is_train and random.random() > 0.5
+         do_flip = self.is_train and random.random() > 0.5
+
+         line = self.filenames[index].split()
+         folder = line[0]
+
+         if len(line) == 3:
+             frame_index = int(line[1])
+         else:
+             frame_index = 0
+
+         if len(line) == 3:
+             side = line[2]
+         else:
+             side = None
+
+         for i in self.frame_idxs:
+             if i == "s":
+                 other_side = {"r": "l", "l": "r"}[side]
+                 inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
+             else:
+                 inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)
+
+         # adjusting intrinsics to match each scale in the pyramid
+         for scale in range(self.num_scales):
+             K = self.K.copy()
+
+             K[0, :] *= self.width // (2 ** scale)
+             K[1, :] *= self.height // (2 ** scale)
+
+             inv_K = np.linalg.pinv(K)
+
+             inputs[("K", scale)] = torch.from_numpy(K)
+             inputs[("inv_K", scale)] = torch.from_numpy(inv_K)
+
+         if do_color_aug:
+             color_aug = transforms.ColorJitter.get_params(
+                 self.brightness, self.contrast, self.saturation, self.hue)
+         else:
+             color_aug = (lambda x: x)
+
+         self.preprocess(inputs, color_aug)
+
+         for i in self.frame_idxs:
+             del inputs[("color", i, -1)]
+             del inputs[("color_aug", i, -1)]
+
+         if self.load_depth:
+             depth_gt = self.get_depth(folder, frame_index, side, do_flip)
+             inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
+             inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))
+
+         if "s" in self.frame_idxs:
+             stereo_T = np.eye(4, dtype=np.float32)
+             baseline_sign = -1 if do_flip else 1
+             side_sign = -1 if side == "l" else 1
+             stereo_T[0, 3] = side_sign * baseline_sign * 0.1
+
+             inputs["stereo_T"] = torch.from_numpy(stereo_T)
+
+         return inputs
+
+     def get_color(self, folder, frame_index, side, do_flip):
+         raise NotImplementedError
+
+     def check_depth(self):
+         raise NotImplementedError
+
+     def get_depth(self, folder, frame_index, side, do_flip):
+         raise NotImplementedError
+
+ class KITTIDataset(MonoDataset):
+     """Superclass for different types of KITTI dataset loaders
+     """
+     def __init__(self, *args, **kwargs):
+         super(KITTIDataset, self).__init__(*args, **kwargs)
+
+         # NOTE: Make sure your intrinsics matrix is *normalized* by the original image size.
+         # To normalize you need to scale the first row by 1 / image_width and the second row
+         # by 1 / image_height. Monodepth2 assumes a principal point to be exactly centered.
+         # If your principal point is far from the center you might need to disable the horizontal
+         # flip augmentation.
+         self.K = np.array([[0.58, 0, 0.5, 0],
+                            [0, 1.92, 0.5, 0],
+                            [0, 0, 1, 0],
+                            [0, 0, 0, 1]], dtype=np.float32)
+
+         self.full_res_shape = (1242, 375)
+         self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}
+
+     def check_depth(self):
+         line = self.filenames[0].split()
+         scene_name = line[0]
+         frame_index = int(line[1])
+
+         velo_filename = os.path.join(
+             self.data_path,
+             scene_name,
+             "velodyne_points/data/{:010d}.bin".format(int(frame_index)))
+
+         return os.path.isfile(velo_filename)
+
+     def get_color(self, folder, frame_index, side, do_flip):
+         color = self.loader(self.get_image_path(folder, frame_index, side))
+
+         if do_flip:
+             color = color.transpose(Image.FLIP_LEFT_RIGHT)
+
+         return color
+
+
+ class KITTIDepthDataset(KITTIDataset):
+     """KITTI dataset which uses the updated ground truth depth maps
+     """
+     def __init__(self, *args, **kwargs):
+         super(KITTIDepthDataset, self).__init__(*args, **kwargs)
+
+     def get_image_path(self, folder, frame_index, side):
+         f_str = "{:010d}{}".format(frame_index, self.img_ext)
+         image_path = os.path.join(
+             self.data_path,
+             folder,
+             "image_0{}/data".format(self.side_map[side]),
+             f_str)
+         return image_path
+
+     def get_depth(self, folder, frame_index, side, do_flip):
+         f_str = "{:010d}.png".format(frame_index)
+         depth_path = os.path.join(
+             self.data_path,
+             folder,
+             "proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
+             f_str)
+
+         depth_gt = Image.open(depth_path)
+         depth_gt = depth_gt.resize(self.full_res_shape, Image.NEAREST)
+         depth_gt = np.array(depth_gt).astype(np.float32) / 256
+
+         if do_flip:
+             depth_gt = np.fliplr(depth_gt)
+
+         return depth_gt
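The `ImgDset` class above pairs each high-resolution crop with a bicubically downsampled low-resolution version. A small usage sketch (the directory path, crop size, and batch size here are hypothetical, not taken from the commit; it also relies on the repo's `utils.image2tensor`, which this diff does not show):

from torch.utils.data import DataLoader
from ds import ImgDset

# Returns (lr_tensor, hr_tensor) pairs with values in [0, 1].
train_set = ImgDset(dataroot='path/to/hr_images', image_size=(96, 96), upscale_factor=4, mode='train')
train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
lr_batch, hr_batch = next(iter(train_loader))  # e.g. (16, 3, 24, 24) and (16, 3, 96, 96)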
losses.py ADDED
@@ -0,0 +1,131 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ from torch import Tensor
+
+ class ContentLoss(nn.Module):
+     """Constructs a content loss function based on the VGG19 network.
+     Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image.
+
+     Paper reference list:
+         - `Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
+         - `ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
+         - `Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
+     """
+
+     def __init__(self) -> None:
+         super(ContentLoss, self).__init__()
+         # Load the VGG19 model trained on the ImageNet dataset.
+         vgg19 = models.vgg19(pretrained=True).eval()
+         # Extract the thirty-sixth layer output in the VGG19 model as the content loss.
+         self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
+         # Freeze model parameters.
+         for parameters in self.feature_extractor.parameters():
+             parameters.requires_grad = False
+
+         # The preprocessing method of the input data. This is the VGG model preprocessing method of the ImageNet dataset.
+         self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+         self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+     def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
+         # Standardized operations
+         sr = sr.sub(self.mean).div(self.std)
+         hr = hr.sub(self.mean).div(self.std)
+
+         # Find the feature map difference between the two images
+         loss = F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))
+
+         return loss
+
+
+ class GenGaussLoss(nn.Module):
+     def __init__(
+         self, reduction='mean',
+         alpha_eps = 1e-4, beta_eps=1e-4,
+         resi_min = 1e-4, resi_max=1e3
+     ) -> None:
+         super(GenGaussLoss, self).__init__()
+         self.reduction = reduction
+         self.alpha_eps = alpha_eps
+         self.beta_eps = beta_eps
+         self.resi_min = resi_min
+         self.resi_max = resi_max
+
+     def forward(
+         self,
+         mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
+     ):
+         one_over_alpha1 = one_over_alpha + self.alpha_eps
+         beta1 = beta + self.beta_eps
+
+         resi = torch.abs(mean - target)
+         # resi = torch.pow(resi*one_over_alpha1, beta1).clamp(min=self.resi_min, max=self.resi_max)
+         resi = (resi*one_over_alpha1*beta1).clamp(min=self.resi_min, max=self.resi_max)
+         ## check if resi has nans
+         if torch.sum(resi != resi) > 0:
+             print('resi has nans!!')
+             return None
+
+         log_one_over_alpha = torch.log(one_over_alpha1)
+         log_beta = torch.log(beta1)
+         lgamma_beta = torch.lgamma(torch.pow(beta1, -1))
+
+         if torch.sum(log_one_over_alpha != log_one_over_alpha) > 0:
+             print('log_one_over_alpha has nan')
+         if torch.sum(lgamma_beta != lgamma_beta) > 0:
+             print('lgamma_beta has nan')
+         if torch.sum(log_beta != log_beta) > 0:
+             print('log_beta has nan')
+
+         l = resi - log_one_over_alpha + lgamma_beta - log_beta
+
+         if self.reduction == 'mean':
+             return l.mean()
+         elif self.reduction == 'sum':
+             return l.sum()
+         else:
+             print('Reduction not supported')
+             return None
+
+ class TempCombLoss(nn.Module):
+     def __init__(
+         self, reduction='mean',
+         alpha_eps = 1e-4, beta_eps=1e-4,
+         resi_min = 1e-4, resi_max=1e3
+     ) -> None:
+         super(TempCombLoss, self).__init__()
+         self.reduction = reduction
+         self.alpha_eps = alpha_eps
+         self.beta_eps = beta_eps
+         self.resi_min = resi_min
+         self.resi_max = resi_max
+
+         self.L_GenGauss = GenGaussLoss(
+             reduction=self.reduction,
+             alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
+             resi_min=self.resi_min, resi_max=self.resi_max
+         )
+         self.L_l1 = nn.L1Loss(reduction=self.reduction)
+
+     def forward(
+         self,
+         mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
+         T1: float, T2: float
+     ):
+         l1 = self.L_l1(mean, target)
+         l2 = self.L_GenGauss(mean, one_over_alpha, beta, target)
+         l = T1*l1 + T2*l2
+
+         return l
+
+
+ # x1 = torch.randn(4,3,32,32)
+ # x2 = torch.rand(4,3,32,32)
+ # x3 = torch.rand(4,3,32,32)
+ # x4 = torch.randn(4,3,32,32)
+
+ # L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
+ # L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
+ # print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))
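A quick shape and usage check for the two uncertainty losses above, mirroring the commented-out example at the end of losses.py (the tensors are random dummies for illustration only):

import torch
from losses import GenGaussLoss, TempCombLoss

mu = torch.randn(4, 3, 32, 32)              # predicted mean
one_over_alpha = torch.rand(4, 3, 32, 32)   # predicted inverse scale (positive)
beta = torch.rand(4, 3, 32, 32)             # predicted shape (positive)
target = torch.randn(4, 3, 32, 32)

L_gg = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
L_comb = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
print(L_gg(mu, one_over_alpha, beta, target))               # scalar loss (mean over pixels)
print(L_comb(mu, one_over_alpha, beta, target, 1e0, 1e-2))  # T1*L1 + T2*GenGaussLoss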
networks_SRGAN.py ADDED
@@ -0,0 +1,347 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.models as models
+ from torch import Tensor
+
+ # __all__ = [
+ #     "ResidualConvBlock",
+ #     "Discriminator", "Generator",
+ # ]
+
+
+ class ResidualConvBlock(nn.Module):
+     """Implements residual conv function.
+
+     Args:
+         channels (int): Number of channels in the input image.
+     """
+
+     def __init__(self, channels: int) -> None:
+         super(ResidualConvBlock, self).__init__()
+         self.rcb = nn.Sequential(
+             nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
+             nn.BatchNorm2d(channels),
+             nn.PReLU(),
+             nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
+             nn.BatchNorm2d(channels),
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         identity = x
+
+         out = self.rcb(x)
+         out = torch.add(out, identity)
+
+         return out
+
+
+ class Discriminator(nn.Module):
+     def __init__(self) -> None:
+         super(Discriminator, self).__init__()
+         self.features = nn.Sequential(
+             # input size. (3) x 96 x 96
+             nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=False),
+             nn.LeakyReLU(0.2, True),
+             # state size. (64) x 48 x 48
+             nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
+             nn.BatchNorm2d(64),
+             nn.LeakyReLU(0.2, True),
+             nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
+             nn.BatchNorm2d(128),
+             nn.LeakyReLU(0.2, True),
+             # state size. (128) x 24 x 24
+             nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
+             nn.BatchNorm2d(128),
+             nn.LeakyReLU(0.2, True),
+             nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
+             nn.BatchNorm2d(256),
+             nn.LeakyReLU(0.2, True),
+             # state size. (256) x 12 x 12
+             nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
+             nn.BatchNorm2d(256),
+             nn.LeakyReLU(0.2, True),
+             nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
+             nn.BatchNorm2d(512),
+             nn.LeakyReLU(0.2, True),
+             # state size. (512) x 6 x 6
+             nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
+             nn.BatchNorm2d(512),
+             nn.LeakyReLU(0.2, True),
+         )
+
+         self.classifier = nn.Sequential(
+             nn.Linear(512 * 6 * 6, 1024),
+             nn.LeakyReLU(0.2, True),
+             nn.Linear(1024, 1),
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         out = self.features(x)
+         out = torch.flatten(out, 1)
+         out = self.classifier(out)
+
+         return out
+
+
+ class Generator(nn.Module):
+     def __init__(self) -> None:
+         super(Generator, self).__init__()
+         # First conv layer.
+         self.conv_block1 = nn.Sequential(
+             nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)),
+             nn.PReLU(),
+         )
+
+         # Features trunk blocks.
+         trunk = []
+         for _ in range(16):
+             trunk.append(ResidualConvBlock(64))
+         self.trunk = nn.Sequential(*trunk)
+
+         # Second conv layer.
+         self.conv_block2 = nn.Sequential(
+             nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
+             nn.BatchNorm2d(64),
+         )
+
+         # Upscale conv block.
+         self.upsampling = nn.Sequential(
+             nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
+             nn.PixelShuffle(2),
+             nn.PReLU(),
+             nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
+             nn.PixelShuffle(2),
+             nn.PReLU(),
+         )
+
+         # Output layer.
+         self.conv_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4))
+
+         # Initialize neural network weights.
+         self._initialize_weights()
+
+     def forward(self, x: Tensor, dop=None) -> Tensor:
+         if not dop:
+             return self._forward_impl(x)
+         else:
+             return self._forward_w_dop_impl(x, dop)
+
+     # Support torch.script function.
+     def _forward_impl(self, x: Tensor) -> Tensor:
+         out1 = self.conv_block1(x)
+         out = self.trunk(out1)
+         out2 = self.conv_block2(out)
+         out = torch.add(out1, out2)
+         out = self.upsampling(out)
+         out = self.conv_block3(out)
+
+         return out
+
+     def _forward_w_dop_impl(self, x: Tensor, dop) -> Tensor:
+         out1 = self.conv_block1(x)
+         out = self.trunk(out1)
+         out2 = F.dropout2d(self.conv_block2(out), p=dop)
+         out = torch.add(out1, out2)
+         out = self.upsampling(out)
+         out = self.conv_block3(out)
+
+         return out
+
+     def _initialize_weights(self) -> None:
+         for module in self.modules():
+             if isinstance(module, nn.Conv2d):
+                 nn.init.kaiming_normal_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nn.init.constant_(module.weight, 1)
+
+
+ #### BayesCap
+ class BayesCap(nn.Module):
+     def __init__(self, in_channels=3, out_channels=3) -> None:
+         super(BayesCap, self).__init__()
+         # First conv layer.
+         self.conv_block1 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels, 64,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.PReLU(),
+         )
+
+         # Features trunk blocks.
+         trunk = []
+         for _ in range(16):
+             trunk.append(ResidualConvBlock(64))
+         self.trunk = nn.Sequential(*trunk)
+
+         # Second conv layer.
+         self.conv_block2 = nn.Sequential(
+             nn.Conv2d(
+                 64, 64,
+                 kernel_size=3, stride=1, padding=1, bias=False
+             ),
+             nn.BatchNorm2d(64),
+         )
+
+         # Output layer.
+         self.conv_block3_mu = nn.Conv2d(
+             64, out_channels=out_channels,
+             kernel_size=9, stride=1, padding=4
+         )
+         self.conv_block3_alpha = nn.Sequential(
+             nn.Conv2d(
+                 64, 64,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.PReLU(),
+             nn.Conv2d(
+                 64, 64,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.PReLU(),
+             nn.Conv2d(
+                 64, 1,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.ReLU(),
+         )
+         self.conv_block3_beta = nn.Sequential(
+             nn.Conv2d(
+                 64, 64,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.PReLU(),
+             nn.Conv2d(
+                 64, 64,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.PReLU(),
+             nn.Conv2d(
+                 64, 1,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.ReLU(),
+         )
+
+         # Initialize neural network weights.
+         self._initialize_weights()
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self._forward_impl(x)
+
+     # Support torch.script function.
+     def _forward_impl(self, x: Tensor) -> Tensor:
+         out1 = self.conv_block1(x)
+         out = self.trunk(out1)
+         out2 = self.conv_block2(out)
+         out = out1 + out2
+         out_mu = self.conv_block3_mu(out)
+         out_alpha = self.conv_block3_alpha(out)
+         out_beta = self.conv_block3_beta(out)
+         return out_mu, out_alpha, out_beta
+
+     def _initialize_weights(self) -> None:
+         for module in self.modules():
+             if isinstance(module, nn.Conv2d):
+                 nn.init.kaiming_normal_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nn.init.constant_(module.weight, 1)
+
+
+ class BayesCap_noID(nn.Module):
+     def __init__(self, in_channels=3, out_channels=3) -> None:
+         super(BayesCap_noID, self).__init__()
+         # First conv layer.
+         self.conv_block1 = nn.Sequential(
+             nn.Conv2d(
+                 in_channels, 64,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.PReLU(),
+         )
+
+         # Features trunk blocks.
+         trunk = []
+         for _ in range(16):
+             trunk.append(ResidualConvBlock(64))
+         self.trunk = nn.Sequential(*trunk)
+
+         # Second conv layer.
+         self.conv_block2 = nn.Sequential(
+             nn.Conv2d(
+                 64, 64,
+                 kernel_size=3, stride=1, padding=1, bias=False
+             ),
+             nn.BatchNorm2d(64),
+         )
+
+         # Output layer.
+         # self.conv_block3_mu = nn.Conv2d(
+         #     64, out_channels=out_channels,
+         #     kernel_size=9, stride=1, padding=4
+         # )
+         self.conv_block3_alpha = nn.Sequential(
+             nn.Conv2d(
+                 64, 64,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.PReLU(),
+             nn.Conv2d(
+                 64, 64,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.PReLU(),
+             nn.Conv2d(
+                 64, 1,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.ReLU(),
+         )
+         self.conv_block3_beta = nn.Sequential(
+             nn.Conv2d(
+                 64, 64,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.PReLU(),
+             nn.Conv2d(
+                 64, 64,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.PReLU(),
+             nn.Conv2d(
+                 64, 1,
+                 kernel_size=9, stride=1, padding=4
+             ),
+             nn.ReLU(),
+         )
+
+         # Initialize neural network weights.
+         self._initialize_weights()
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self._forward_impl(x)
+
+     # Support torch.script function.
+     def _forward_impl(self, x: Tensor) -> Tensor:
+         out1 = self.conv_block1(x)
+         out = self.trunk(out1)
+         out2 = self.conv_block2(out)
+         out = out1 + out2
+         # out_mu = self.conv_block3_mu(out)
+         out_alpha = self.conv_block3_alpha(out)
+         out_beta = self.conv_block3_beta(out)
+         return out_alpha, out_beta
+
+     def _initialize_weights(self) -> None:
+         for module in self.modules():
+             if isinstance(module, nn.Conv2d):
+                 nn.init.kaiming_normal_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+             elif isinstance(module, nn.BatchNorm2d):
+                 nn.init.constant_(module.weight, 1)
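A minimal shape sketch of how app.py wires the frozen SRGAN generator and the BayesCap head together (dummy input; the spatial size is chosen arbitrarily for illustration):

import torch
from networks_SRGAN import Generator, BayesCap

netG = Generator().eval()                              # frozen SRGAN generator (4x upscaling via two PixelShuffle(2) stages)
netC = BayesCap(in_channels=3, out_channels=3).eval()  # identity-mapping cap applied to the SRGAN output

x_lr = torch.rand(1, 3, 64, 64)                        # low-resolution input in [0, 1]
with torch.no_grad():
    x_sr = netG(x_lr)                                  # (1, 3, 256, 256) super-resolved image
    mu, alpha, beta = netC(x_sr)                       # (1, 3, 256, 256), (1, 1, 256, 256), (1, 1, 256, 256)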
networks_T1toT2.py ADDED
@@ -0,0 +1,477 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import functools
5
+
6
+ ### components
7
+ class ResConv(nn.Module):
8
+ """
9
+ Residual convolutional block, where
10
+ convolutional block consists: (convolution => [BN] => ReLU) * 3
11
+ residual connection adds the input to the output
12
+ """
13
+ def __init__(self, in_channels, out_channels, mid_channels=None):
14
+ super().__init__()
15
+ if not mid_channels:
16
+ mid_channels = out_channels
17
+ self.double_conv = nn.Sequential(
18
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
19
+ nn.BatchNorm2d(mid_channels),
20
+ nn.ReLU(inplace=True),
21
+ nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
22
+ nn.BatchNorm2d(mid_channels),
23
+ nn.ReLU(inplace=True),
24
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
25
+ nn.BatchNorm2d(out_channels),
26
+ nn.ReLU(inplace=True)
27
+ )
28
+ self.double_conv1 = nn.Sequential(
29
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
30
+ nn.BatchNorm2d(out_channels),
31
+ nn.ReLU(inplace=True),
32
+ )
33
+ def forward(self, x):
34
+ x_in = self.double_conv1(x)
35
+ x1 = self.double_conv(x)
36
+ return self.double_conv(x) + x_in
37
+
38
+ class Down(nn.Module):
39
+ """Downscaling with maxpool then Resconv"""
40
+ def __init__(self, in_channels, out_channels):
41
+ super().__init__()
42
+ self.maxpool_conv = nn.Sequential(
43
+ nn.MaxPool2d(2),
44
+ ResConv(in_channels, out_channels)
45
+ )
46
+ def forward(self, x):
47
+ return self.maxpool_conv(x)
48
+
49
+ class Up(nn.Module):
50
+ """Upscaling then double conv"""
51
+ def __init__(self, in_channels, out_channels, bilinear=True):
52
+ super().__init__()
53
+ # if bilinear, use the normal convolutions to reduce the number of channels
54
+ if bilinear:
55
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
56
+ self.conv = ResConv(in_channels, out_channels, in_channels // 2)
57
+ else:
58
+ self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
59
+ self.conv = ResConv(in_channels, out_channels)
60
+ def forward(self, x1, x2):
61
+ x1 = self.up(x1)
62
+ # input is CHW
63
+ diffY = x2.size()[2] - x1.size()[2]
64
+ diffX = x2.size()[3] - x1.size()[3]
65
+ x1 = F.pad(
66
+ x1,
67
+ [
68
+ diffX // 2, diffX - diffX // 2,
69
+ diffY // 2, diffY - diffY // 2
70
+ ]
71
+ )
72
+ # if you have padding issues, see
73
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
74
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
75
+ x = torch.cat([x2, x1], dim=1)
76
+ return self.conv(x)
77
+
78
+ class OutConv(nn.Module):
79
+ def __init__(self, in_channels, out_channels):
80
+ super(OutConv, self).__init__()
81
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
82
+ def forward(self, x):
83
+ # return F.relu(self.conv(x))
84
+ return self.conv(x)
85
+
86
+ ##### The composite networks
87
+ class UNet(nn.Module):
88
+ def __init__(self, n_channels, out_channels, bilinear=True):
89
+ super(UNet, self).__init__()
90
+ self.n_channels = n_channels
91
+ self.out_channels = out_channels
92
+ self.bilinear = bilinear
93
+ ####
94
+ self.inc = ResConv(n_channels, 64)
95
+ self.down1 = Down(64, 128)
96
+ self.down2 = Down(128, 256)
97
+ self.down3 = Down(256, 512)
98
+ factor = 2 if bilinear else 1
99
+ self.down4 = Down(512, 1024 // factor)
100
+ self.up1 = Up(1024, 512 // factor, bilinear)
101
+ self.up2 = Up(512, 256 // factor, bilinear)
102
+ self.up3 = Up(256, 128 // factor, bilinear)
103
+ self.up4 = Up(128, 64, bilinear)
104
+ self.outc = OutConv(64, out_channels)
105
+ def forward(self, x):
106
+ x1 = self.inc(x)
107
+ x2 = self.down1(x1)
108
+ x3 = self.down2(x2)
109
+ x4 = self.down3(x3)
110
+ x5 = self.down4(x4)
111
+ x = self.up1(x5, x4)
112
+ x = self.up2(x, x3)
113
+ x = self.up3(x, x2)
114
+ x = self.up4(x, x1)
115
+ y = self.outc(x)
116
+ return y
117
+
118
+ class CasUNet(nn.Module):
119
+ def __init__(self, n_unet, io_channels, bilinear=True):
120
+ super(CasUNet, self).__init__()
121
+ self.n_unet = n_unet
122
+ self.io_channels = io_channels
123
+ self.bilinear = bilinear
124
+ ####
125
+ self.unet_list = nn.ModuleList()
126
+ for i in range(self.n_unet):
127
+ self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
128
+ def forward(self, x, dop=None):
129
+ y = x
130
+ for i in range(self.n_unet):
131
+ if i==0:
132
+ if dop is not None:
133
+ y = F.dropout2d(self.unet_list[i](y), p=dop)
134
+ else:
135
+ y = self.unet_list[i](y)
136
+ else:
137
+ y = self.unet_list[i](y+x)
138
+ return y
139
+
140
+ class CasUNet_2head(nn.Module):
141
+ def __init__(self, n_unet, io_channels, bilinear=True):
142
+ super(CasUNet_2head, self).__init__()
143
+ self.n_unet = n_unet
144
+ self.io_channels = io_channels
145
+ self.bilinear = bilinear
146
+ ####
147
+ self.unet_list = nn.ModuleList()
148
+ for i in range(self.n_unet):
149
+ if i != self.n_unet-1:
150
+ self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
151
+ else:
152
+ self.unet_list.append(UNet_2head(self.io_channels, self.io_channels, self.bilinear))
153
+ def forward(self, x):
154
+ y = x
155
+ for i in range(self.n_unet):
156
+ if i==0:
157
+ y = self.unet_list[i](y)
158
+ else:
159
+ y = self.unet_list[i](y+x)
160
+ y_mean, y_sigma = y[0], y[1]
161
+ return y_mean, y_sigma
162
+
163
+ class CasUNet_3head(nn.Module):
164
+ def __init__(self, n_unet, io_channels, bilinear=True):
165
+ super(CasUNet_3head, self).__init__()
166
+ self.n_unet = n_unet
167
+ self.io_channels = io_channels
168
+ self.bilinear = bilinear
169
+ ####
170
+ self.unet_list = nn.ModuleList()
171
+ for i in range(self.n_unet):
172
+ if i != self.n_unet-1:
173
+ self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
174
+ else:
175
+ self.unet_list.append(UNet_3head(self.io_channels, self.io_channels, self.bilinear))
176
+ def forward(self, x):
177
+ y = x
178
+ for i in range(self.n_unet):
179
+ if i==0:
180
+ y = self.unet_list[i](y)
181
+ else:
182
+ y = self.unet_list[i](y+x)
183
+ y_mean, y_alpha, y_beta = y[0], y[1], y[2]
184
+ return y_mean, y_alpha, y_beta
185
+
186
+ class UNet_2head(nn.Module):
187
+ def __init__(self, n_channels, out_channels, bilinear=True):
188
+ super(UNet_2head, self).__init__()
189
+ self.n_channels = n_channels
190
+ self.out_channels = out_channels
191
+ self.bilinear = bilinear
192
+ ####
193
+ self.inc = ResConv(n_channels, 64)
194
+ self.down1 = Down(64, 128)
195
+ self.down2 = Down(128, 256)
196
+ self.down3 = Down(256, 512)
197
+ factor = 2 if bilinear else 1
198
+ self.down4 = Down(512, 1024 // factor)
199
+ self.up1 = Up(1024, 512 // factor, bilinear)
200
+ self.up2 = Up(512, 256 // factor, bilinear)
201
+ self.up3 = Up(256, 128 // factor, bilinear)
202
+ self.up4 = Up(128, 64, bilinear)
203
+ #per pixel multiple channels may exist
204
+ self.out_mean = OutConv(64, out_channels)
205
+ #variance will always be a single number for a pixel
206
+ self.out_var = nn.Sequential(
207
+ OutConv(64, 128),
208
+ OutConv(128, 1),
209
+ )
210
+ def forward(self, x):
211
+ x1 = self.inc(x)
212
+ x2 = self.down1(x1)
213
+ x3 = self.down2(x2)
214
+ x4 = self.down3(x3)
215
+ x5 = self.down4(x4)
216
+ x = self.up1(x5, x4)
217
+ x = self.up2(x, x3)
218
+ x = self.up3(x, x2)
219
+ x = self.up4(x, x1)
220
+ y_mean, y_var = self.out_mean(x), self.out_var(x)
221
+ return y_mean, y_var
222
+
223
+ class UNet_3head(nn.Module):
224
+ def __init__(self, n_channels, out_channels, bilinear=True):
225
+ super(UNet_3head, self).__init__()
226
+ self.n_channels = n_channels
227
+ self.out_channels = out_channels
228
+ self.bilinear = bilinear
229
+ ####
230
+ self.inc = ResConv(n_channels, 64)
231
+ self.down1 = Down(64, 128)
232
+ self.down2 = Down(128, 256)
233
+ self.down3 = Down(256, 512)
234
+ factor = 2 if bilinear else 1
235
+ self.down4 = Down(512, 1024 // factor)
236
+ self.up1 = Up(1024, 512 // factor, bilinear)
237
+ self.up2 = Up(512, 256 // factor, bilinear)
238
+ self.up3 = Up(256, 128 // factor, bilinear)
239
+ self.up4 = Up(128, 64, bilinear)
240
+ #per pixel multiple channels may exist
241
+ self.out_mean = OutConv(64, out_channels)
242
+ #variance will always be a single number for a pixel
243
+ self.out_alpha = nn.Sequential(
244
+ OutConv(64, 128),
245
+ OutConv(128, 1),
246
+ nn.ReLU()
247
+ )
248
+ self.out_beta = nn.Sequential(
249
+ OutConv(64, 128),
250
+ OutConv(128, 1),
251
+ nn.ReLU()
252
+ )
253
+ def forward(self, x):
254
+ x1 = self.inc(x)
255
+ x2 = self.down1(x1)
256
+ x3 = self.down2(x2)
257
+ x4 = self.down3(x3)
258
+ x5 = self.down4(x4)
259
+ x = self.up1(x5, x4)
260
+ x = self.up2(x, x3)
261
+ x = self.up3(x, x2)
262
+ x = self.up4(x, x1)
263
+ y_mean, y_alpha, y_beta = self.out_mean(x), \
264
+ self.out_alpha(x), self.out_beta(x)
265
+ return y_mean, y_alpha, y_beta
266
+
267
+ class ResidualBlock(nn.Module):
268
+ def __init__(self, in_features):
269
+ super(ResidualBlock, self).__init__()
270
+ conv_block = [
271
+ nn.ReflectionPad2d(1),
272
+ nn.Conv2d(in_features, in_features, 3),
273
+ nn.InstanceNorm2d(in_features),
274
+ nn.ReLU(inplace=True),
275
+ nn.ReflectionPad2d(1),
276
+ nn.Conv2d(in_features, in_features, 3),
277
+ nn.InstanceNorm2d(in_features)
278
+ ]
279
+ self.conv_block = nn.Sequential(*conv_block)
280
+ def forward(self, x):
281
+ return x + self.conv_block(x)
282
+
283
+ class Generator(nn.Module):
284
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9):
285
+ super(Generator, self).__init__()
286
+ # Initial convolution block
287
+ model = [
288
+ nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7),
289
+ nn.InstanceNorm2d(64), nn.ReLU(inplace=True)
290
+ ]
291
+ # Downsampling
292
+ in_features = 64
293
+ out_features = in_features*2
294
+ for _ in range(2):
295
+ model += [
296
+ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
297
+ nn.InstanceNorm2d(out_features),
298
+ nn.ReLU(inplace=True)
299
+ ]
300
+ in_features = out_features
301
+ out_features = in_features*2
302
+ # Residual blocks
303
+ for _ in range(n_residual_blocks):
304
+ model += [ResidualBlock(in_features)]
305
+ # Upsampling
306
+ out_features = in_features//2
307
+ for _ in range(2):
308
+ model += [
309
+ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
310
+ nn.InstanceNorm2d(out_features),
311
+ nn.ReLU(inplace=True)
312
+ ]
313
+ in_features = out_features
314
+ out_features = in_features//2
315
+ # Output layer
316
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()]
317
+ self.model = nn.Sequential(*model)
318
+ def forward(self, x):
319
+ return self.model(x)
320
+
321
+
322
+ class ResnetGenerator(nn.Module):
323
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
324
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
325
+ """
326
+
327
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
328
+ """Construct a Resnet-based generator
329
+ Parameters:
330
+ input_nc (int) -- the number of channels in input images
331
+ output_nc (int) -- the number of channels in output images
332
+ ngf (int) -- the number of filters in the last conv layer
333
+ norm_layer -- normalization layer
334
+ use_dropout (bool) -- if use dropout layers
335
+ n_blocks (int) -- the number of ResNet blocks
336
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
337
+ """
338
+ assert(n_blocks >= 0)
339
+ super(ResnetGenerator, self).__init__()
340
+ if type(norm_layer) == functools.partial:
341
+ use_bias = norm_layer.func == nn.InstanceNorm2d
342
+ else:
343
+ use_bias = norm_layer == nn.InstanceNorm2d
344
+
345
+ model = [nn.ReflectionPad2d(3),
346
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
347
+ norm_layer(ngf),
348
+ nn.ReLU(True)]
349
+
350
+ n_downsampling = 2
351
+ for i in range(n_downsampling): # add downsampling layers
352
+ mult = 2 ** i
353
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
354
+ norm_layer(ngf * mult * 2),
355
+ nn.ReLU(True)]
356
+
357
+ mult = 2 ** n_downsampling
358
+ for i in range(n_blocks): # add ResNet blocks
359
+
360
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
361
+
362
+ for i in range(n_downsampling): # add upsampling layers
363
+ mult = 2 ** (n_downsampling - i)
364
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
365
+ kernel_size=3, stride=2,
366
+ padding=1, output_padding=1,
367
+ bias=use_bias),
368
+ norm_layer(int(ngf * mult / 2)),
369
+ nn.ReLU(True)]
370
+ model += [nn.ReflectionPad2d(3)]
371
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
372
+ model += [nn.Tanh()]
373
+
374
+ self.model = nn.Sequential(*model)
375
+
376
+ def forward(self, input):
377
+ """Standard forward"""
378
+ return self.model(input)
379
+
380
+
381
+ class ResnetBlock(nn.Module):
382
+ """Define a Resnet block"""
383
+
384
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
385
+ """Initialize the Resnet block
386
+ A resnet block is a conv block with skip connections
387
+ We construct a conv block with build_conv_block function,
388
+ and implement skip connections in <forward> function.
389
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
390
+ """
391
+ super(ResnetBlock, self).__init__()
392
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
393
+
394
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
395
+ """Construct a convolutional block.
396
+ Parameters:
397
+ dim (int) -- the number of channels in the conv layer.
398
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
399
+ norm_layer -- normalization layer
400
+ use_dropout (bool) -- if use dropout layers.
401
+ use_bias (bool) -- if the conv layer uses bias or not
402
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
403
+ """
404
+ conv_block = []
405
+ p = 0
406
+ if padding_type == 'reflect':
407
+ conv_block += [nn.ReflectionPad2d(1)]
408
+ elif padding_type == 'replicate':
409
+ conv_block += [nn.ReplicationPad2d(1)]
410
+ elif padding_type == 'zero':
411
+ p = 1
412
+ else:
413
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
414
+
415
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
416
+ if use_dropout:
417
+ conv_block += [nn.Dropout(0.5)]
418
+
419
+ p = 0
420
+ if padding_type == 'reflect':
421
+ conv_block += [nn.ReflectionPad2d(1)]
422
+ elif padding_type == 'replicate':
423
+ conv_block += [nn.ReplicationPad2d(1)]
424
+ elif padding_type == 'zero':
425
+ p = 1
426
+ else:
427
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
428
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
429
+
430
+ return nn.Sequential(*conv_block)
431
+
432
+ def forward(self, x):
433
+ """Forward function (with skip connections)"""
434
+ out = x + self.conv_block(x) # add skip connections
435
+ return out
436
+
437
+ ### discriminator
438
+ class NLayerDiscriminator(nn.Module):
439
+ """Defines a PatchGAN discriminator"""
440
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
441
+ """Construct a PatchGAN discriminator
442
+ Parameters:
443
+ input_nc (int) -- the number of channels in input images
444
+ ndf (int) -- the number of filters in the last conv layer
445
+ n_layers (int) -- the number of conv layers in the discriminator
446
+ norm_layer -- normalization layer
447
+ """
448
+ super(NLayerDiscriminator, self).__init__()
449
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
450
+ use_bias = norm_layer.func == nn.InstanceNorm2d
451
+ else:
452
+ use_bias = norm_layer == nn.InstanceNorm2d
453
+ kw = 4
454
+ padw = 1
455
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
456
+ nf_mult = 1
457
+ nf_mult_prev = 1
458
+ for n in range(1, n_layers): # gradually increase the number of filters
459
+ nf_mult_prev = nf_mult
460
+ nf_mult = min(2 ** n, 8)
461
+ sequence += [
462
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
463
+ norm_layer(ndf * nf_mult),
464
+ nn.LeakyReLU(0.2, True)
465
+ ]
466
+ nf_mult_prev = nf_mult
467
+ nf_mult = min(2 ** n_layers, 8)
468
+ sequence += [
469
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
470
+ norm_layer(ndf * nf_mult),
471
+ nn.LeakyReLU(0.2, True)
472
+ ]
473
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
474
+ self.model = nn.Sequential(*sequence)
475
+ def forward(self, input):
476
+ """Standard forward."""
477
+ return self.model(input)
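With that last block in place, networks_T1toT2.py exposes the cascaded U-Nets used for the T1-to-T2 translation task; CasUNet_3head chains plain UNets and ends in a UNet_3head, so its forward pass returns a mean image plus single-channel α and β maps. A quick shape check, as a sketch (single-channel input assumed, spatial size divisible by 16):

import torch
from networks_T1toT2 import CasUNet_3head

net = CasUNet_3head(n_unet=1, io_channels=1).eval()
x = torch.randn(1, 1, 256, 256)       # dummy T1-like slice
with torch.no_grad():
    y_mean, y_alpha, y_beta = net(x)
print(y_mean.shape, y_alpha.shape, y_beta.shape)   # all (1, 1, 256, 256)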
requirements.txt ADDED
@@ -0,0 +1,158 @@
1
+ aiohttp==3.8.1
2
+ aiosignal==1.2.0
3
+ albumentations
4
+ analytics-python==1.4.0
5
+ anyio==3.6.1
6
+ argon2-cffi==21.3.0
7
+ argon2-cffi-bindings==21.2.0
8
+ asttokens==2.0.5
9
+ async-timeout==4.0.2
10
+ attrs==21.4.0
11
+ Babel==2.10.1
12
+ backcall==0.2.0
13
+ backoff==1.10.0
14
+ bcrypt==3.2.2
15
+ beautifulsoup4==4.11.1
16
+ bleach==5.0.0
17
+ brotlipy==0.7.0
18
+ certifi
19
+ cffi
20
+ charset-normalizer
21
+ click==8.1.3
22
+ cloudpickle
23
+ cryptography
24
+ cycler==0.11.0
25
+ cytoolz==0.11.2
26
+ dask
27
+ debugpy==1.6.0
28
+ decorator==5.1.1
29
+ defusedxml==0.7.1
30
+ entrypoints==0.4
31
+ executing==0.8.3
32
+ fastapi==0.78.0
33
+ fastjsonschema==2.15.3
34
+ ffmpy==0.3.0
35
+ filelock==3.7.1
36
+ fire==0.4.0
37
+ fonttools==4.33.3
38
+ frozenlist==1.3.0
39
+ fsspec
40
+ ftfy==6.1.1
41
+ gdown==4.5.1
42
+ gradio==3.0.24
43
+ h11==0.12.0
44
+ httpcore==0.15.0
45
+ httpx==0.23.0
46
+ idna
47
+ imagecodecs
48
+ imageio
49
+ ipykernel==6.13.0
50
+ ipython==8.4.0
51
+ ipython-genutils==0.2.0
52
+ jedi==0.18.1
53
+ Jinja2==3.1.2
54
+ joblib
55
+ json5==0.9.8
56
+ jsonschema==4.6.0
57
+ jupyter-client==7.3.1
58
+ jupyter-core==4.10.0
59
+ jupyter-server==1.17.0
60
+ jupyterlab==3.4.2
61
+ jupyterlab-pygments==0.2.2
62
+ jupyterlab-server==2.14.0
63
+ kiwisolver==1.4.2
64
+ kornia==0.6.5
65
+ linkify-it-py==1.0.3
66
+ locket
67
+ markdown-it-py==2.1.0
68
+ MarkupSafe==2.1.1
69
+ matplotlib==3.5.2
70
+ matplotlib-inline==0.1.3
71
+ mdit-py-plugins==0.3.0
72
+ mdurl==0.1.1
73
+ mistune==0.8.4
74
+ mkl-fft==1.3.1
75
+ mkl-random
76
+ mkl-service==2.4.0
77
+ mltk==0.0.5
78
+ monotonic==1.6
79
+ multidict==6.0.2
80
+ munch==2.5.0
81
+ nbclassic==0.3.7
82
+ nbclient==0.6.4
83
+ nbconvert==6.5.0
84
+ nbformat==5.4.0
85
+ nest-asyncio==1.5.5
86
+ networkx
87
+ nltk==3.7
88
+ notebook==6.4.11
89
+ notebook-shim==0.1.0
90
+ ntk==1.1.3
91
+ numpy
92
+ opencv-python==4.6.0.66
93
+ orjson==3.7.7
94
+ packaging
95
+ pandas==1.4.2
96
+ pandocfilters==1.5.0
97
+ paramiko==2.11.0
98
+ parso==0.8.3
99
+ partd
100
+ pexpect==4.8.0
101
+ pickleshare==0.7.5
102
+ Pillow==9.0.1
103
+ prometheus-client==0.14.1
104
+ prompt-toolkit==3.0.29
105
+ psutil==5.9.1
106
+ ptyprocess==0.7.0
107
+ pure-eval==0.2.2
108
+ pycocotools==2.0.4
109
+ pycparser
110
+ pycryptodome==3.15.0
111
+ pydantic==1.9.1
112
+ pydub==0.25.1
113
+ Pygments==2.12.0
114
+ PyNaCl==1.5.0
115
+ pyOpenSSL
116
+ pyparsing
117
+ pyrsistent==0.18.1
118
+ PySocks
119
+ python-dateutil==2.8.2
120
+ python-multipart==0.0.5
121
+ pytz==2022.1
122
+ PyWavelets
123
+ PyYAML
124
+ pyzmq==23.1.0
125
+ qudida
126
+ regex==2022.6.2
127
+ requests
128
+ rfc3986==1.5.0
129
+ scikit-image
130
+ scikit-learn
131
+ scipy
132
+ seaborn==0.11.2
133
+ Send2Trash==1.8.0
134
+ six
135
+ sniffio==1.2.0
136
+ soupsieve==2.3.2.post1
137
+ stack-data==0.2.0
138
+ starlette==0.19.1
139
+ termcolor==1.1.0
140
+ terminado==0.15.0
141
+ threadpoolctl
142
+ tifffile
143
+ tinycss2==1.1.1
144
+ toolz
145
+ torch==1.11.0
146
+ torchaudio==0.11.0
147
+ torchvision==0.12.0
148
+ tornado==6.1
149
+ tqdm==4.64.0
150
+ traitlets==5.2.2.post1
151
+ typing_extensions
152
+ uc-micro-py==1.0.1
153
+ urllib3
154
+ uvicorn==0.18.2
155
+ wcwidth==0.2.5
156
+ webencodings==0.5.1
157
+ websocket-client==1.3.2
158
+ yarl==1.7.2
src/.gitkeep ADDED
File without changes
src/README.md ADDED
@@ -0,0 +1,26 @@
1
+ ---
2
+ title: BayesCap
3
+ emoji: 🔥
4
+ colorFrom: indigo
5
+ colorTo: purple
6
+ sdk: gradio
7
+ app_file: app.py
8
+ pinned: false
9
+ ---
10
+ # Configuration
11
+ `title`: _string_
12
+ Display title for the Space
13
+ `emoji`: _string_
14
+ Space emoji (emoji-only character allowed)
15
+ `colorFrom`: _string_
16
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
17
+ `colorTo`: _string_
18
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
19
+ `sdk`: _string_
20
+ Can be either `gradio` or `streamlit`
21
+ `app_file`: _string_
22
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
23
+ Path is relative to the root of the repository.
24
+
25
+ `pinned`: _boolean_
26
+ Whether the Space stays on top of your list.
src/__pycache__/ds.cpython-310.pyc ADDED
Binary file (14.6 kB)
src/__pycache__/losses.cpython-310.pyc ADDED
Binary file (4.17 kB)
src/__pycache__/networks_SRGAN.cpython-310.pyc ADDED
Binary file (6.99 kB)
src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (34 kB)
src/app.py ADDED
@@ -0,0 +1,115 @@
1
+ import numpy as np
2
+ import random
3
+ import matplotlib.pyplot as plt
4
+ from matplotlib import cm
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.models as models
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torchvision import transforms
12
+ from torchvision.transforms.functional import InterpolationMode as IMode
13
+
14
+ from PIL import Image
15
+
16
+ from ds import *
17
+ from losses import *
18
+ from networks_SRGAN import *
19
+ import utils  # predict() calls utils.image2tensor below, so the module itself must be imported
+ from utils import *
20
+
21
+
22
+ NetG = Generator()
23
+ model_parameters = filter(lambda p: True, NetG.parameters())
24
+ params = sum([np.prod(p.size()) for p in model_parameters])
25
+ print("Number of Parameters:",params)
26
+ NetC = BayesCap(in_channels=3, out_channels=3)
27
+
28
+
29
+ NetG = Generator()
30
+ NetG.load_state_dict(torch.load('../ckpt/srgan-ImageNet-bc347d67.pth', map_location='cuda:0'))
31
+ NetG.to('cuda')
32
+ NetG.eval()
33
+
34
+ NetC = BayesCap(in_channels=3, out_channels=3)
35
+ NetC.load_state_dict(torch.load('../ckpt/BayesCap_SRGAN_best.pth', map_location='cuda:0'))
36
+ NetC.to('cuda')
37
+ NetC.eval()
38
+
39
+ def tensor01_to_pil(xt):
40
+ r = transforms.ToPILImage(mode='RGB')(xt.squeeze())
41
+ return r
42
+
43
+
44
+ def predict(img):
45
+ """
46
+     img: input image (PIL image or array); it is downscaled 4x and then super-resolved with uncertainty maps
47
+ """
48
+ image_size = (256,256)
49
+ upscale_factor = 4
50
+ lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
51
+ # lr_transforms = transforms.Resize((128, 128), interpolation=IMode.BICUBIC, antialias=True)
52
+
53
+ img = Image.fromarray(np.array(img))
54
+ img = lr_transforms(img)
55
+ lr_tensor = utils.image2tensor(img, range_norm=False, half=False)
56
+
57
+ device = 'cuda'
58
+ dtype = torch.cuda.FloatTensor
59
+ xLR = lr_tensor.to(device).unsqueeze(0)
60
+ xLR = xLR.type(dtype)
61
+ # pass them through the network
62
+ with torch.no_grad():
63
+ xSR = NetG(xLR)
64
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
65
+
66
+ a_map = (1/(xSRC_alpha[0] + 1e-5)).to('cpu').data
67
+ b_map = xSRC_beta[0].to('cpu').data
68
+ u_map = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
69
+
70
+
71
+ x_LR = tensor01_to_pil(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
72
+
73
+ x_mean = tensor01_to_pil(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
74
+
75
+ #im = Image.fromarray(np.uint8(cm.gist_earth(myarray)*255))
76
+
77
+ a_map = torch.clamp(a_map, min=0, max=0.1)
78
+ a_map = (a_map - a_map.min())/(a_map.max() - a_map.min())
79
+ x_alpha = Image.fromarray(np.uint8(cm.inferno(a_map.transpose(0,2).transpose(0,1).squeeze())*255))
80
+
81
+ b_map = torch.clamp(b_map, min=0.45, max=0.75)
82
+ b_map = (b_map - b_map.min())/(b_map.max() - b_map.min())
83
+ x_beta = Image.fromarray(np.uint8(cm.cividis(b_map.transpose(0,2).transpose(0,1).squeeze())*255))
84
+
85
+ u_map = torch.clamp(u_map, min=0, max=0.15)
86
+ u_map = (u_map - u_map.min())/(u_map.max() - u_map.min())
87
+ x_uncer = Image.fromarray(np.uint8(cm.hot(u_map.transpose(0,2).transpose(0,1).squeeze())*255))
88
+
89
+ return x_LR, x_mean, x_alpha, x_beta, x_uncer
90
+
91
+ import gradio as gr
92
+
93
+ title = "BayesCap"
94
+ description = "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks (ECCV 2022)"
95
+ article = "<p style='text-align: center'>BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks | <a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"
96
+
97
+
98
+ gr.Interface(
99
+ fn=predict,
100
+ inputs=gr.inputs.Image(type='pil', label="Orignal"),
101
+ outputs=[
102
+ gr.outputs.Image(type='pil', label="Low-res"),
103
+ gr.outputs.Image(type='pil', label="Super-res"),
104
+ gr.outputs.Image(type='pil', label="Alpha"),
105
+ gr.outputs.Image(type='pil', label="Beta"),
106
+ gr.outputs.Image(type='pil', label="Uncertainty")
107
+ ],
108
+ title=title,
109
+ description=description,
110
+ article=article,
111
+ examples=[
112
+ ["../demo_examples/baby.png"],
113
+ ["../demo_examples/bird.png"]
114
+ ]
115
+ ).launch(share=True)
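For reference, the uncertainty map built in predict follows the variance of a generalized Gaussian with scale α and shape β, Var = α²·Γ(3/β)/Γ(1/β); since the α head predicts 1/α, the code inverts it first and evaluates the gamma ratio through torch.lgamma. A standalone version of that step (same small epsilons as above), as a sketch:

import torch

def gen_gauss_variance(one_over_alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
    # the network predicts 1/alpha, so recover alpha first (epsilon avoids division by zero)
    alpha = 1.0 / (one_over_alpha + 1e-5)
    # Gamma(3/beta) / Gamma(1/beta), computed in log-space for numerical stability
    gamma_ratio = torch.exp(torch.lgamma(3.0 / (beta + 1e-2)) - torch.lgamma(1.0 / (beta + 1e-2)))
    return (alpha ** 2) * gamma_ratio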
src/ds.py ADDED
@@ -0,0 +1,485 @@
1
+ from __future__ import absolute_import, division, print_function
2
+
3
+ import random
4
+ import copy
5
+ import io
6
+ import os
7
+ import numpy as np
8
+ from PIL import Image
9
+ import skimage.transform
10
+ from collections import Counter
+ from typing import Tuple
11
+
12
+
13
+ import torch
14
+ import torch.utils.data as data
15
+ from torch import Tensor
16
+ from torch.utils.data import Dataset
17
+ from torchvision import transforms
18
+ from torchvision.transforms.functional import InterpolationMode as IMode
19
+
20
+ import utils
21
+
22
+ class ImgDset(Dataset):
23
+ """Customize the data set loading function and prepare low/high resolution image data in advance.
24
+
25
+ Args:
26
+         dataroot (str): Path to the directory of high-resolution training images
+         image_size (tuple): High-resolution image size as (height, width)
+         upscale_factor (int): Super-resolution upscaling factor
+         mode (str): "train" applies data augmentation; any other value only resizes (validation)
31
+
32
+ """
33
+
34
+ def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
35
+ super(ImgDset, self).__init__()
36
+ self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]
37
+
38
+ if mode == "train":
39
+ self.hr_transforms = transforms.Compose([
40
+ transforms.RandomCrop(image_size),
41
+ transforms.RandomRotation(90),
42
+ transforms.RandomHorizontalFlip(0.5),
43
+ ])
44
+ else:
45
+ self.hr_transforms = transforms.Resize(image_size)
46
+
47
+ self.lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
48
+
49
+     def __getitem__(self, batch_index: int) -> Tuple[Tensor, Tensor]:
50
+ # Read a batch of image data
51
+ image = Image.open(self.filenames[batch_index])
52
+
53
+ # Transform image
54
+ hr_image = self.hr_transforms(image)
55
+ lr_image = self.lr_transforms(hr_image)
56
+
57
+ # Convert image data into Tensor stream format (PyTorch).
58
+ # Note: The range of input and output is between [0, 1]
59
+ lr_tensor = utils.image2tensor(lr_image, range_norm=False, half=False)
60
+ hr_tensor = utils.image2tensor(hr_image, range_norm=False, half=False)
61
+
62
+ return lr_tensor, hr_tensor
63
+
64
+ def __len__(self) -> int:
65
+ return len(self.filenames)
66
+
67
+
68
+ class PairedImages_w_nameList(Dataset):
69
+ '''
70
+ can act as supervised or un-supervised based on flists
71
+ '''
72
+ def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
73
+ self.flist1 = flist1
74
+ self.flist2 = flist2
75
+ self.transform1 = transform1
76
+ self.transform2 = transform2
77
+ self.do_aug = do_aug
78
+ def __getitem__(self, index):
79
+ impath1 = self.flist1[index]
80
+ img1 = Image.open(impath1).convert('RGB')
81
+ impath2 = self.flist2[index]
82
+ img2 = Image.open(impath2).convert('RGB')
83
+
84
+ img1 = utils.image2tensor(img1, range_norm=False, half=False)
85
+ img2 = utils.image2tensor(img2, range_norm=False, half=False)
86
+
87
+ if self.transform1 is not None:
88
+ img1 = self.transform1(img1)
89
+ if self.transform2 is not None:
90
+ img2 = self.transform2(img2)
91
+
92
+ return img1, img2
93
+ def __len__(self):
94
+ return len(self.flist1)
95
+
96
+ class PairedImages_w_nameList_npy(Dataset):
97
+ '''
98
+ can act as supervised or un-supervised based on flists
99
+ '''
100
+ def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
101
+ self.flist1 = flist1
102
+ self.flist2 = flist2
103
+ self.transform1 = transform1
104
+ self.transform2 = transform2
105
+ self.do_aug = do_aug
106
+ def __getitem__(self, index):
107
+ impath1 = self.flist1[index]
108
+ img1 = np.load(impath1)
109
+ impath2 = self.flist2[index]
110
+ img2 = np.load(impath2)
111
+
112
+ if self.transform1 is not None:
113
+ img1 = self.transform1(img1)
114
+ if self.transform2 is not None:
115
+ img2 = self.transform2(img2)
116
+
117
+ return img1, img2
118
+ def __len__(self):
119
+ return len(self.flist1)
120
+
121
+ # def call_paired():
122
+ # root1='./GOPRO_3840FPS_AVG_3-21/train/blur/'
123
+ # root2='./GOPRO_3840FPS_AVG_3-21/train/sharp/'
124
+
125
+ # flist1=glob.glob(root1+'/*/*.png')
126
+ # flist2=glob.glob(root2+'/*/*.png')
127
+
128
+ # dset = PairedImages_w_nameList(root1,root2,flist1,flist2)
129
+
130
+ #### KITTI depth
131
+
132
+ def load_velodyne_points(filename):
133
+ """Load 3D point cloud from KITTI file format
134
+ (adapted from https://github.com/hunse/kitti)
135
+ """
136
+ points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4)
137
+ points[:, 3] = 1.0 # homogeneous
138
+ return points
139
+
140
+
141
+ def read_calib_file(path):
142
+ """Read KITTI calibration file
143
+ (from https://github.com/hunse/kitti)
144
+ """
145
+ float_chars = set("0123456789.e+- ")
146
+ data = {}
147
+ with open(path, 'r') as f:
148
+ for line in f.readlines():
149
+ key, value = line.split(':', 1)
150
+ value = value.strip()
151
+ data[key] = value
152
+ if float_chars.issuperset(value):
153
+ # try to cast to float array
154
+ try:
155
+ data[key] = np.array(list(map(float, value.split(' '))))
156
+ except ValueError:
157
+ # casting error: data[key] already eq. value, so pass
158
+ pass
159
+
160
+ return data
161
+
162
+
163
+ def sub2ind(matrixSize, rowSub, colSub):
164
+ """Convert row, col matrix subscripts to linear indices
165
+ """
166
+ m, n = matrixSize
167
+ return rowSub * (n-1) + colSub - 1
168
+
169
+
170
+ def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
171
+ """Generate a depth map from velodyne data
172
+ """
173
+ # load calibration files
174
+ cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
175
+ velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
176
+ velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
177
+ velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))
178
+
179
+ # get image shape
180
+ im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)
181
+
182
+ # compute projection matrix velodyne->image plane
183
+ R_cam2rect = np.eye(4)
184
+ R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
185
+ P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4)
186
+ P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)
187
+
188
+ # load velodyne points and remove all behind image plane (approximation)
189
+ # each row of the velodyne data is forward, left, up, reflectance
190
+ velo = load_velodyne_points(velo_filename)
191
+ velo = velo[velo[:, 0] >= 0, :]
192
+
193
+ # project the points to the camera
194
+ velo_pts_im = np.dot(P_velo2im, velo.T).T
195
+ velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]
196
+
197
+ if vel_depth:
198
+ velo_pts_im[:, 2] = velo[:, 0]
199
+
200
+ # check if in bounds
201
+ # use minus 1 to get the exact same value as KITTI matlab code
202
+ velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
203
+ velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
204
+ val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
205
+ val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
206
+ velo_pts_im = velo_pts_im[val_inds, :]
207
+
208
+ # project to image
209
+ depth = np.zeros((im_shape[:2]))
210
+     depth[velo_pts_im[:, 1].astype(np.int32), velo_pts_im[:, 0].astype(np.int32)] = velo_pts_im[:, 2]
211
+
212
+ # find the duplicate points and choose the closest depth
213
+ inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
214
+ dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
215
+ for dd in dupe_inds:
216
+ pts = np.where(inds == dd)[0]
217
+ x_loc = int(velo_pts_im[pts[0], 0])
218
+ y_loc = int(velo_pts_im[pts[0], 1])
219
+ depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
220
+ depth[depth < 0] = 0
221
+
222
+ return depth
223
+
224
+ def pil_loader(path):
225
+ # open path as file to avoid ResourceWarning
226
+ # (https://github.com/python-pillow/Pillow/issues/835)
227
+ with open(path, 'rb') as f:
228
+ with Image.open(f) as img:
229
+ return img.convert('RGB')
230
+
231
+
232
+ class MonoDataset(data.Dataset):
233
+ """Superclass for monocular dataloaders
234
+
235
+ Args:
236
+ data_path
237
+ filenames
238
+ height
239
+ width
240
+ frame_idxs
241
+ num_scales
242
+ is_train
243
+ img_ext
244
+ """
245
+ def __init__(self,
246
+ data_path,
247
+ filenames,
248
+ height,
249
+ width,
250
+ frame_idxs,
251
+ num_scales,
252
+ is_train=False,
253
+ img_ext='.jpg'):
254
+ super(MonoDataset, self).__init__()
255
+
256
+ self.data_path = data_path
257
+ self.filenames = filenames
258
+ self.height = height
259
+ self.width = width
260
+ self.num_scales = num_scales
261
+ self.interp = Image.ANTIALIAS
262
+
263
+ self.frame_idxs = frame_idxs
264
+
265
+ self.is_train = is_train
266
+ self.img_ext = img_ext
267
+
268
+ self.loader = pil_loader
269
+ self.to_tensor = transforms.ToTensor()
270
+
271
+ # We need to specify augmentations differently in newer versions of torchvision.
272
+ # We first try the newer tuple version; if this fails we fall back to scalars
273
+ try:
274
+ self.brightness = (0.8, 1.2)
275
+ self.contrast = (0.8, 1.2)
276
+ self.saturation = (0.8, 1.2)
277
+ self.hue = (-0.1, 0.1)
278
+ transforms.ColorJitter.get_params(
279
+ self.brightness, self.contrast, self.saturation, self.hue)
280
+ except TypeError:
281
+ self.brightness = 0.2
282
+ self.contrast = 0.2
283
+ self.saturation = 0.2
284
+ self.hue = 0.1
285
+
286
+ self.resize = {}
287
+ for i in range(self.num_scales):
288
+ s = 2 ** i
289
+ self.resize[i] = transforms.Resize((self.height // s, self.width // s),
290
+ interpolation=self.interp)
291
+
292
+ self.load_depth = self.check_depth()
293
+
294
+ def preprocess(self, inputs, color_aug):
295
+ """Resize colour images to the required scales and augment if required
296
+
297
+ We create the color_aug object in advance and apply the same augmentation to all
298
+ images in this item. This ensures that all images input to the pose network receive the
299
+ same augmentation.
300
+ """
301
+ for k in list(inputs):
302
+ frame = inputs[k]
303
+ if "color" in k:
304
+ n, im, i = k
305
+ for i in range(self.num_scales):
306
+ inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])
307
+
308
+ for k in list(inputs):
309
+ f = inputs[k]
310
+ if "color" in k:
311
+ n, im, i = k
312
+ inputs[(n, im, i)] = self.to_tensor(f)
313
+ inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))
314
+
315
+ def __len__(self):
316
+ return len(self.filenames)
317
+
318
+ def __getitem__(self, index):
319
+ """Returns a single training item from the dataset as a dictionary.
320
+
321
+ Values correspond to torch tensors.
322
+ Keys in the dictionary are either strings or tuples:
323
+
324
+ ("color", <frame_id>, <scale>) for raw colour images,
325
+ ("color_aug", <frame_id>, <scale>) for augmented colour images,
326
+ ("K", scale) or ("inv_K", scale) for camera intrinsics,
327
+ "stereo_T" for camera extrinsics, and
328
+ "depth_gt" for ground truth depth maps.
329
+
330
+ <frame_id> is either:
331
+ an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
332
+ or
333
+ "s" for the opposite image in the stereo pair.
334
+
335
+ <scale> is an integer representing the scale of the image relative to the fullsize image:
336
+ -1 images at native resolution as loaded from disk
337
+ 0 images resized to (self.width, self.height )
338
+ 1 images resized to (self.width // 2, self.height // 2)
339
+ 2 images resized to (self.width // 4, self.height // 4)
340
+ 3 images resized to (self.width // 8, self.height // 8)
341
+ """
342
+ inputs = {}
343
+
344
+ do_color_aug = self.is_train and random.random() > 0.5
345
+ do_flip = self.is_train and random.random() > 0.5
346
+
347
+ line = self.filenames[index].split()
348
+ folder = line[0]
349
+
350
+ if len(line) == 3:
351
+ frame_index = int(line[1])
352
+ else:
353
+ frame_index = 0
354
+
355
+ if len(line) == 3:
356
+ side = line[2]
357
+ else:
358
+ side = None
359
+
360
+ for i in self.frame_idxs:
361
+ if i == "s":
362
+ other_side = {"r": "l", "l": "r"}[side]
363
+ inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
364
+ else:
365
+ inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)
366
+
367
+ # adjusting intrinsics to match each scale in the pyramid
368
+ for scale in range(self.num_scales):
369
+ K = self.K.copy()
370
+
371
+ K[0, :] *= self.width // (2 ** scale)
372
+ K[1, :] *= self.height // (2 ** scale)
373
+
374
+ inv_K = np.linalg.pinv(K)
375
+
376
+ inputs[("K", scale)] = torch.from_numpy(K)
377
+ inputs[("inv_K", scale)] = torch.from_numpy(inv_K)
378
+
379
+ if do_color_aug:
380
+ color_aug = transforms.ColorJitter.get_params(
381
+ self.brightness, self.contrast, self.saturation, self.hue)
382
+ else:
383
+ color_aug = (lambda x: x)
384
+
385
+ self.preprocess(inputs, color_aug)
386
+
387
+ for i in self.frame_idxs:
388
+ del inputs[("color", i, -1)]
389
+ del inputs[("color_aug", i, -1)]
390
+
391
+ if self.load_depth:
392
+ depth_gt = self.get_depth(folder, frame_index, side, do_flip)
393
+ inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
394
+ inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))
395
+
396
+ if "s" in self.frame_idxs:
397
+ stereo_T = np.eye(4, dtype=np.float32)
398
+ baseline_sign = -1 if do_flip else 1
399
+ side_sign = -1 if side == "l" else 1
400
+ stereo_T[0, 3] = side_sign * baseline_sign * 0.1
401
+
402
+ inputs["stereo_T"] = torch.from_numpy(stereo_T)
403
+
404
+ return inputs
405
+
406
+ def get_color(self, folder, frame_index, side, do_flip):
407
+ raise NotImplementedError
408
+
409
+ def check_depth(self):
410
+ raise NotImplementedError
411
+
412
+ def get_depth(self, folder, frame_index, side, do_flip):
413
+ raise NotImplementedError
414
+
415
+ class KITTIDataset(MonoDataset):
416
+ """Superclass for different types of KITTI dataset loaders
417
+ """
418
+ def __init__(self, *args, **kwargs):
419
+ super(KITTIDataset, self).__init__(*args, **kwargs)
420
+
421
+ # NOTE: Make sure your intrinsics matrix is *normalized* by the original image size.
422
+ # To normalize you need to scale the first row by 1 / image_width and the second row
423
+ # by 1 / image_height. Monodepth2 assumes a principal point to be exactly centered.
424
+ # If your principal point is far from the center you might need to disable the horizontal
425
+ # flip augmentation.
426
+ self.K = np.array([[0.58, 0, 0.5, 0],
427
+ [0, 1.92, 0.5, 0],
428
+ [0, 0, 1, 0],
429
+ [0, 0, 0, 1]], dtype=np.float32)
430
+
431
+ self.full_res_shape = (1242, 375)
432
+ self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}
433
+
434
+ def check_depth(self):
435
+ line = self.filenames[0].split()
436
+ scene_name = line[0]
437
+ frame_index = int(line[1])
438
+
439
+ velo_filename = os.path.join(
440
+ self.data_path,
441
+ scene_name,
442
+ "velodyne_points/data/{:010d}.bin".format(int(frame_index)))
443
+
444
+ return os.path.isfile(velo_filename)
445
+
446
+ def get_color(self, folder, frame_index, side, do_flip):
447
+ color = self.loader(self.get_image_path(folder, frame_index, side))
448
+
449
+ if do_flip:
450
+ color = color.transpose(Image.FLIP_LEFT_RIGHT)
451
+
452
+ return color
453
+
454
+
455
+ class KITTIDepthDataset(KITTIDataset):
456
+ """KITTI dataset which uses the updated ground truth depth maps
457
+ """
458
+ def __init__(self, *args, **kwargs):
459
+ super(KITTIDepthDataset, self).__init__(*args, **kwargs)
460
+
461
+ def get_image_path(self, folder, frame_index, side):
462
+ f_str = "{:010d}{}".format(frame_index, self.img_ext)
463
+ image_path = os.path.join(
464
+ self.data_path,
465
+ folder,
466
+ "image_0{}/data".format(self.side_map[side]),
467
+ f_str)
468
+ return image_path
469
+
470
+ def get_depth(self, folder, frame_index, side, do_flip):
471
+ f_str = "{:010d}.png".format(frame_index)
472
+ depth_path = os.path.join(
473
+ self.data_path,
474
+ folder,
475
+ "proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
476
+ f_str)
477
+
478
+ depth_gt = Image.open(depth_path)
479
+ depth_gt = depth_gt.resize(self.full_res_shape, Image.NEAREST)
480
+ depth_gt = np.array(depth_gt).astype(np.float32) / 256
481
+
482
+ if do_flip:
483
+ depth_gt = np.fliplr(depth_gt)
484
+
485
+ return depth_gt
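The super-resolution loader at the top of ds.py, ImgDset, takes image_size as a (height, width) pair (it is indexed inside lr_transforms) and yields (lr_tensor, hr_tensor) pairs scaled to [0, 1] via utils.image2tensor. A minimal usage sketch (the dataroot path is hypothetical; RGB inputs assumed):

from torch.utils.data import DataLoader
from ds import ImgDset

train_set = ImgDset(dataroot='path/to/hr_images', image_size=(256, 256),
                    upscale_factor=4, mode='train')
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=2)

for lr, hr in train_loader:
    # lr: (8, 3, 64, 64), hr: (8, 3, 256, 256), both in [0, 1]
    break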
src/flagged/Alpha/0.png ADDED
src/flagged/Beta/0.png ADDED
src/flagged/Low-res/0.png ADDED
src/flagged/Orignal/0.png ADDED
src/flagged/Super-res/0.png ADDED
src/flagged/Uncertainty/0.png ADDED
src/flagged/log.csv ADDED
@@ -0,0 +1,2 @@
1
+ 'Orignal','Low-res','Super-res','Alpha','Beta','Uncertainty','flag','username','timestamp'
2
+ 'Orignal/0.png','Low-res/0.png','Super-res/0.png','Alpha/0.png','Beta/0.png','Uncertainty/0.png','','','2022-07-09 14:01:12.964411'
src/losses.py ADDED
@@ -0,0 +1,131 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+ from torch import Tensor
6
+
7
+ class ContentLoss(nn.Module):
8
+ """Constructs a content loss function based on the VGG19 network.
9
+     High-level feature maps from the later VGG layers focus more on the texture content of the image.
10
+
11
+ Paper reference list:
12
+ -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
13
+ -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
14
+ -`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
15
+
16
+ """
17
+
18
+ def __init__(self) -> None:
19
+ super(ContentLoss, self).__init__()
20
+ # Load the VGG19 model trained on the ImageNet dataset.
21
+ vgg19 = models.vgg19(pretrained=True).eval()
22
+ # Extract the thirty-sixth layer output in the VGG19 model as the content loss.
23
+ self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
24
+ # Freeze model parameters.
25
+ for parameters in self.feature_extractor.parameters():
26
+ parameters.requires_grad = False
27
+
28
+ # The preprocessing method of the input data. This is the VGG model preprocessing method of the ImageNet dataset.
29
+ self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
30
+ self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
31
+
32
+ def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
33
+ # Standardized operations
34
+ sr = sr.sub(self.mean).div(self.std)
35
+ hr = hr.sub(self.mean).div(self.std)
36
+
37
+ # Find the feature map difference between the two images
38
+ loss = F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))
39
+
40
+ return loss
41
+
42
+
43
+ class GenGaussLoss(nn.Module):
44
+ def __init__(
45
+ self, reduction='mean',
46
+ alpha_eps = 1e-4, beta_eps=1e-4,
47
+ resi_min = 1e-4, resi_max=1e3
48
+ ) -> None:
49
+ super(GenGaussLoss, self).__init__()
50
+ self.reduction = reduction
51
+ self.alpha_eps = alpha_eps
52
+ self.beta_eps = beta_eps
53
+ self.resi_min = resi_min
54
+ self.resi_max = resi_max
55
+
56
+ def forward(
57
+ self,
58
+ mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
59
+ ):
60
+ one_over_alpha1 = one_over_alpha + self.alpha_eps
61
+ beta1 = beta + self.beta_eps
62
+
63
+ resi = torch.abs(mean - target)
64
+ # resi = torch.pow(resi*one_over_alpha1, beta1).clamp(min=self.resi_min, max=self.resi_max)
65
+ resi = (resi*one_over_alpha1*beta1).clamp(min=self.resi_min, max=self.resi_max)
66
+ ## check if resi has nans
67
+ if torch.sum(resi != resi) > 0:
68
+ print('resi has nans!!')
69
+ return None
70
+
71
+ log_one_over_alpha = torch.log(one_over_alpha1)
72
+ log_beta = torch.log(beta1)
73
+ lgamma_beta = torch.lgamma(torch.pow(beta1, -1))
74
+
75
+ if torch.sum(log_one_over_alpha != log_one_over_alpha) > 0:
76
+ print('log_one_over_alpha has nan')
77
+ if torch.sum(lgamma_beta != lgamma_beta) > 0:
78
+ print('lgamma_beta has nan')
79
+ if torch.sum(log_beta != log_beta) > 0:
80
+ print('log_beta has nan')
81
+
82
+ l = resi - log_one_over_alpha + lgamma_beta - log_beta
83
+
84
+ if self.reduction == 'mean':
85
+ return l.mean()
86
+ elif self.reduction == 'sum':
87
+ return l.sum()
88
+ else:
89
+ print('Reduction not supported')
90
+ return None
91
+
92
+ class TempCombLoss(nn.Module):
93
+ def __init__(
94
+ self, reduction='mean',
95
+ alpha_eps = 1e-4, beta_eps=1e-4,
96
+ resi_min = 1e-4, resi_max=1e3
97
+ ) -> None:
98
+ super(TempCombLoss, self).__init__()
99
+ self.reduction = reduction
100
+ self.alpha_eps = alpha_eps
101
+ self.beta_eps = beta_eps
102
+ self.resi_min = resi_min
103
+ self.resi_max = resi_max
104
+
105
+ self.L_GenGauss = GenGaussLoss(
106
+ reduction=self.reduction,
107
+ alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
108
+ resi_min=self.resi_min, resi_max=self.resi_max
109
+ )
110
+ self.L_l1 = nn.L1Loss(reduction=self.reduction)
111
+
112
+ def forward(
113
+ self,
114
+ mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
115
+ T1: float, T2: float
116
+ ):
117
+ l1 = self.L_l1(mean, target)
118
+ l2 = self.L_GenGauss(mean, one_over_alpha, beta, target)
119
+ l = T1*l1 + T2*l2
120
+
121
+ return l
122
+
123
+
+ if __name__ == '__main__':
+     # quick smoke test for the losses on random tensors
+     x1 = torch.randn(4, 3, 32, 32)
+     x2 = torch.rand(4, 3, 32, 32)
+     x3 = torch.rand(4, 3, 32, 32)
+     x4 = torch.randn(4, 3, 32, 32)
+
+     L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
+     L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
+     print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))
src/networks_SRGAN.py ADDED
@@ -0,0 +1,347 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+ from torch import Tensor
6
+
7
+ # __all__ = [
8
+ # "ResidualConvBlock",
9
+ # "Discriminator", "Generator",
10
+ # ]
11
+
12
+
13
+ class ResidualConvBlock(nn.Module):
14
+ """Implements residual conv function.
15
+
16
+ Args:
17
+ channels (int): Number of channels in the input image.
18
+ """
19
+
20
+ def __init__(self, channels: int) -> None:
21
+ super(ResidualConvBlock, self).__init__()
22
+ self.rcb = nn.Sequential(
23
+ nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
24
+ nn.BatchNorm2d(channels),
25
+ nn.PReLU(),
26
+ nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
27
+ nn.BatchNorm2d(channels),
28
+ )
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ identity = x
32
+
33
+ out = self.rcb(x)
34
+ out = torch.add(out, identity)
35
+
36
+ return out
37
+
38
+
39
+ class Discriminator(nn.Module):
40
+ def __init__(self) -> None:
41
+ super(Discriminator, self).__init__()
42
+ self.features = nn.Sequential(
43
+ # input size. (3) x 96 x 96
44
+ nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=False),
45
+ nn.LeakyReLU(0.2, True),
46
+ # state size. (64) x 48 x 48
47
+ nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
48
+ nn.BatchNorm2d(64),
49
+ nn.LeakyReLU(0.2, True),
50
+ nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
51
+ nn.BatchNorm2d(128),
52
+ nn.LeakyReLU(0.2, True),
53
+ # state size. (128) x 24 x 24
54
+ nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
55
+ nn.BatchNorm2d(128),
56
+ nn.LeakyReLU(0.2, True),
57
+ nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
58
+ nn.BatchNorm2d(256),
59
+ nn.LeakyReLU(0.2, True),
60
+ # state size. (256) x 12 x 12
61
+ nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
62
+ nn.BatchNorm2d(256),
63
+ nn.LeakyReLU(0.2, True),
64
+ nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
65
+ nn.BatchNorm2d(512),
66
+ nn.LeakyReLU(0.2, True),
67
+ # state size. (512) x 6 x 6
68
+ nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
69
+ nn.BatchNorm2d(512),
70
+ nn.LeakyReLU(0.2, True),
71
+ )
72
+
73
+ self.classifier = nn.Sequential(
74
+ nn.Linear(512 * 6 * 6, 1024),
75
+ nn.LeakyReLU(0.2, True),
76
+ nn.Linear(1024, 1),
77
+ )
78
+
79
+ def forward(self, x: Tensor) -> Tensor:
80
+ out = self.features(x)
81
+ out = torch.flatten(out, 1)
82
+ out = self.classifier(out)
83
+
84
+ return out
85
+
86
+
87
+ class Generator(nn.Module):
88
+ def __init__(self) -> None:
89
+ super(Generator, self).__init__()
90
+ # First conv layer.
91
+ self.conv_block1 = nn.Sequential(
92
+ nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)),
93
+ nn.PReLU(),
94
+ )
95
+
96
+ # Features trunk blocks.
97
+ trunk = []
98
+ for _ in range(16):
99
+ trunk.append(ResidualConvBlock(64))
100
+ self.trunk = nn.Sequential(*trunk)
101
+
102
+ # Second conv layer.
103
+ self.conv_block2 = nn.Sequential(
104
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
105
+ nn.BatchNorm2d(64),
106
+ )
107
+
108
+ # Upscale conv block.
109
+ self.upsampling = nn.Sequential(
110
+ nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
111
+ nn.PixelShuffle(2),
112
+ nn.PReLU(),
113
+ nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
114
+ nn.PixelShuffle(2),
115
+ nn.PReLU(),
116
+ )
117
+
118
+ # Output layer.
119
+ self.conv_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4))
120
+
121
+ # Initialize neural network weights.
122
+ self._initialize_weights()
123
+
124
+ def forward(self, x: Tensor, dop=None) -> Tensor:
125
+ if not dop:
126
+ return self._forward_impl(x)
127
+ else:
128
+ return self._forward_w_dop_impl(x, dop)
129
+
130
+ # Support torch.script function.
131
+ def _forward_impl(self, x: Tensor) -> Tensor:
132
+ out1 = self.conv_block1(x)
133
+ out = self.trunk(out1)
134
+ out2 = self.conv_block2(out)
135
+ out = torch.add(out1, out2)
136
+ out = self.upsampling(out)
137
+ out = self.conv_block3(out)
138
+
139
+ return out
140
+
141
+ def _forward_w_dop_impl(self, x: Tensor, dop) -> Tensor:
142
+ out1 = self.conv_block1(x)
143
+ out = self.trunk(out1)
144
+ out2 = F.dropout2d(self.conv_block2(out), p=dop)
145
+ out = torch.add(out1, out2)
146
+ out = self.upsampling(out)
147
+ out = self.conv_block3(out)
148
+
149
+ return out
150
+
151
+ def _initialize_weights(self) -> None:
152
+ for module in self.modules():
153
+ if isinstance(module, nn.Conv2d):
154
+ nn.init.kaiming_normal_(module.weight)
155
+ if module.bias is not None:
156
+ nn.init.constant_(module.bias, 0)
157
+ elif isinstance(module, nn.BatchNorm2d):
158
+ nn.init.constant_(module.weight, 1)
159
+
160
+
161
+ #### BayesCap
162
+ class BayesCap(nn.Module):
163
+ def __init__(self, in_channels=3, out_channels=3) -> None:
164
+ super(BayesCap, self).__init__()
165
+ # First conv layer.
166
+ self.conv_block1 = nn.Sequential(
167
+ nn.Conv2d(
168
+ in_channels, 64,
169
+ kernel_size=9, stride=1, padding=4
170
+ ),
171
+ nn.PReLU(),
172
+ )
173
+
174
+ # Features trunk blocks.
175
+ trunk = []
176
+ for _ in range(16):
177
+ trunk.append(ResidualConvBlock(64))
178
+ self.trunk = nn.Sequential(*trunk)
179
+
180
+ # Second conv layer.
181
+ self.conv_block2 = nn.Sequential(
182
+ nn.Conv2d(
183
+ 64, 64,
184
+ kernel_size=3, stride=1, padding=1, bias=False
185
+ ),
186
+ nn.BatchNorm2d(64),
187
+ )
188
+
189
+ # Output layer.
190
+ self.conv_block3_mu = nn.Conv2d(
191
+ 64, out_channels=out_channels,
192
+ kernel_size=9, stride=1, padding=4
193
+ )
194
+ self.conv_block3_alpha = nn.Sequential(
195
+ nn.Conv2d(
196
+ 64, 64,
197
+ kernel_size=9, stride=1, padding=4
198
+ ),
199
+ nn.PReLU(),
200
+ nn.Conv2d(
201
+ 64, 64,
202
+ kernel_size=9, stride=1, padding=4
203
+ ),
204
+ nn.PReLU(),
205
+ nn.Conv2d(
206
+ 64, 1,
207
+ kernel_size=9, stride=1, padding=4
208
+ ),
209
+ nn.ReLU(),
210
+ )
211
+ self.conv_block3_beta = nn.Sequential(
212
+ nn.Conv2d(
213
+ 64, 64,
214
+ kernel_size=9, stride=1, padding=4
215
+ ),
216
+ nn.PReLU(),
217
+ nn.Conv2d(
218
+ 64, 64,
219
+ kernel_size=9, stride=1, padding=4
220
+ ),
221
+ nn.PReLU(),
222
+ nn.Conv2d(
223
+ 64, 1,
224
+ kernel_size=9, stride=1, padding=4
225
+ ),
226
+ nn.ReLU(),
227
+ )
228
+
229
+ # Initialize neural network weights.
230
+ self._initialize_weights()
231
+
232
+ def forward(self, x: Tensor) -> Tensor:
233
+ return self._forward_impl(x)
234
+
235
+ # Support torch.script function.
236
+ def _forward_impl(self, x: Tensor) -> Tensor:
237
+ out1 = self.conv_block1(x)
238
+ out = self.trunk(out1)
239
+ out2 = self.conv_block2(out)
240
+ out = out1 + out2
241
+ out_mu = self.conv_block3_mu(out)
242
+ out_alpha = self.conv_block3_alpha(out)
243
+ out_beta = self.conv_block3_beta(out)
244
+ return out_mu, out_alpha, out_beta
245
+
246
+ def _initialize_weights(self) -> None:
247
+ for module in self.modules():
248
+ if isinstance(module, nn.Conv2d):
249
+ nn.init.kaiming_normal_(module.weight)
250
+ if module.bias is not None:
251
+ nn.init.constant_(module.bias, 0)
252
+ elif isinstance(module, nn.BatchNorm2d):
253
+ nn.init.constant_(module.weight, 1)
254
+
255
+
256
+ class BayesCap_noID(nn.Module):
257
+ def __init__(self, in_channels=3, out_channels=3) -> None:
258
+ super(BayesCap_noID, self).__init__()
259
+ # First conv layer.
260
+ self.conv_block1 = nn.Sequential(
261
+ nn.Conv2d(
262
+ in_channels, 64,
263
+ kernel_size=9, stride=1, padding=4
264
+ ),
265
+ nn.PReLU(),
266
+ )
267
+
268
+ # Features trunk blocks.
269
+ trunk = []
270
+ for _ in range(16):
271
+ trunk.append(ResidualConvBlock(64))
272
+ self.trunk = nn.Sequential(*trunk)
273
+
274
+ # Second conv layer.
275
+ self.conv_block2 = nn.Sequential(
276
+ nn.Conv2d(
277
+ 64, 64,
278
+ kernel_size=3, stride=1, padding=1, bias=False
279
+ ),
280
+ nn.BatchNorm2d(64),
281
+ )
282
+
283
+ # Output layer.
284
+ # self.conv_block3_mu = nn.Conv2d(
285
+ # 64, out_channels=out_channels,
286
+ # kernel_size=9, stride=1, padding=4
287
+ # )
288
+ self.conv_block3_alpha = nn.Sequential(
289
+ nn.Conv2d(
290
+ 64, 64,
291
+ kernel_size=9, stride=1, padding=4
292
+ ),
293
+ nn.PReLU(),
294
+ nn.Conv2d(
295
+ 64, 64,
296
+ kernel_size=9, stride=1, padding=4
297
+ ),
298
+ nn.PReLU(),
299
+ nn.Conv2d(
300
+ 64, 1,
301
+ kernel_size=9, stride=1, padding=4
302
+ ),
303
+ nn.ReLU(),
304
+ )
305
+ self.conv_block3_beta = nn.Sequential(
306
+ nn.Conv2d(
307
+ 64, 64,
308
+ kernel_size=9, stride=1, padding=4
309
+ ),
310
+ nn.PReLU(),
311
+ nn.Conv2d(
312
+ 64, 64,
313
+ kernel_size=9, stride=1, padding=4
314
+ ),
315
+ nn.PReLU(),
316
+ nn.Conv2d(
317
+ 64, 1,
318
+ kernel_size=9, stride=1, padding=4
319
+ ),
320
+ nn.ReLU(),
321
+ )
322
+
323
+ # Initialize neural network weights.
324
+ self._initialize_weights()
325
+
326
+ def forward(self, x: Tensor) -> Tensor:
327
+ return self._forward_impl(x)
328
+
329
+ # Support torch.script function.
330
+ def _forward_impl(self, x: Tensor) -> Tensor:
331
+ out1 = self.conv_block1(x)
332
+ out = self.trunk(out1)
333
+ out2 = self.conv_block2(out)
334
+ out = out1 + out2
335
+ # out_mu = self.conv_block3_mu(out)
336
+ out_alpha = self.conv_block3_alpha(out)
337
+ out_beta = self.conv_block3_beta(out)
338
+ return out_alpha, out_beta
339
+
340
+ def _initialize_weights(self) -> None:
341
+ for module in self.modules():
342
+ if isinstance(module, nn.Conv2d):
343
+ nn.init.kaiming_normal_(module.weight)
344
+ if module.bias is not None:
345
+ nn.init.constant_(module.bias, 0)
346
+ elif isinstance(module, nn.BatchNorm2d):
347
+ nn.init.constant_(module.weight, 1)
src/networks_T1toT2.py ADDED
@@ -0,0 +1,477 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import functools
5
+
6
+ ### components
7
+ class ResConv(nn.Module):
8
+ """
9
+ Residual convolutional block, where
10
+ convolutional block consists: (convolution => [BN] => ReLU) * 3
11
+ residual connection adds the input to the output
12
+ """
13
+ def __init__(self, in_channels, out_channels, mid_channels=None):
14
+ super().__init__()
15
+ if not mid_channels:
16
+ mid_channels = out_channels
17
+ self.double_conv = nn.Sequential(
18
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
19
+ nn.BatchNorm2d(mid_channels),
20
+ nn.ReLU(inplace=True),
21
+ nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
22
+ nn.BatchNorm2d(mid_channels),
23
+ nn.ReLU(inplace=True),
24
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
25
+ nn.BatchNorm2d(out_channels),
26
+ nn.ReLU(inplace=True)
27
+ )
28
+ self.double_conv1 = nn.Sequential(
29
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
30
+ nn.BatchNorm2d(out_channels),
31
+ nn.ReLU(inplace=True),
32
+ )
33
+ def forward(self, x):
34
+ x_in = self.double_conv1(x)
35
+ x1 = self.double_conv(x)
36
+ return x1 + x_in
37
+
38
+ class Down(nn.Module):
39
+ """Downscaling with maxpool then Resconv"""
40
+ def __init__(self, in_channels, out_channels):
41
+ super().__init__()
42
+ self.maxpool_conv = nn.Sequential(
43
+ nn.MaxPool2d(2),
44
+ ResConv(in_channels, out_channels)
45
+ )
46
+ def forward(self, x):
47
+ return self.maxpool_conv(x)
48
+
49
+ class Up(nn.Module):
50
+ """Upscaling then double conv"""
51
+ def __init__(self, in_channels, out_channels, bilinear=True):
52
+ super().__init__()
53
+ # if bilinear, use the normal convolutions to reduce the number of channels
54
+ if bilinear:
55
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
56
+ self.conv = ResConv(in_channels, out_channels, in_channels // 2)
57
+ else:
58
+ self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
59
+ self.conv = ResConv(in_channels, out_channels)
60
+ def forward(self, x1, x2):
61
+ x1 = self.up(x1)
62
+ # input is CHW
63
+ diffY = x2.size()[2] - x1.size()[2]
64
+ diffX = x2.size()[3] - x1.size()[3]
65
+ x1 = F.pad(
66
+ x1,
67
+ [
68
+ diffX // 2, diffX - diffX // 2,
69
+ diffY // 2, diffY - diffY // 2
70
+ ]
71
+ )
72
+ # if you have padding issues, see
73
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
74
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
75
+ x = torch.cat([x2, x1], dim=1)
76
+ return self.conv(x)
77
+
78
+ class OutConv(nn.Module):
79
+ def __init__(self, in_channels, out_channels):
80
+ super(OutConv, self).__init__()
81
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
82
+ def forward(self, x):
83
+ # return F.relu(self.conv(x))
84
+ return self.conv(x)
85
+
86
+ ##### The composite networks
87
+ class UNet(nn.Module):
88
+ def __init__(self, n_channels, out_channels, bilinear=True):
89
+ super(UNet, self).__init__()
90
+ self.n_channels = n_channels
91
+ self.out_channels = out_channels
92
+ self.bilinear = bilinear
93
+ ####
94
+ self.inc = ResConv(n_channels, 64)
95
+ self.down1 = Down(64, 128)
96
+ self.down2 = Down(128, 256)
97
+ self.down3 = Down(256, 512)
98
+ factor = 2 if bilinear else 1
99
+ self.down4 = Down(512, 1024 // factor)
100
+ self.up1 = Up(1024, 512 // factor, bilinear)
101
+ self.up2 = Up(512, 256 // factor, bilinear)
102
+ self.up3 = Up(256, 128 // factor, bilinear)
103
+ self.up4 = Up(128, 64, bilinear)
104
+ self.outc = OutConv(64, out_channels)
105
+ def forward(self, x):
106
+ x1 = self.inc(x)
107
+ x2 = self.down1(x1)
108
+ x3 = self.down2(x2)
109
+ x4 = self.down3(x3)
110
+ x5 = self.down4(x4)
111
+ x = self.up1(x5, x4)
112
+ x = self.up2(x, x3)
113
+ x = self.up3(x, x2)
114
+ x = self.up4(x, x1)
115
+ y = self.outc(x)
116
+ return y
117
+
118
+ class CasUNet(nn.Module):
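+ # Cascade of n_unet UNets: the first UNet maps the input directly, and every subsequent
+ # UNet refines the running estimate from (previous output + original input).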
119
+ def __init__(self, n_unet, io_channels, bilinear=True):
120
+ super(CasUNet, self).__init__()
121
+ self.n_unet = n_unet
122
+ self.io_channels = io_channels
123
+ self.bilinear = bilinear
124
+ ####
125
+ self.unet_list = nn.ModuleList()
126
+ for i in range(self.n_unet):
127
+ self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
128
+ def forward(self, x, dop=None):
129
+ y = x
130
+ for i in range(self.n_unet):
131
+ if i==0:
132
+ if dop is not None:
133
+ y = F.dropout2d(self.unet_list[i](y), p=dop)
134
+ else:
135
+ y = self.unet_list[i](y)
136
+ else:
137
+ y = self.unet_list[i](y+x)
138
+ return y
139
+
140
+ class CasUNet_2head(nn.Module):
141
+ def __init__(self, n_unet, io_channels, bilinear=True):
142
+ super(CasUNet_2head, self).__init__()
143
+ self.n_unet = n_unet
144
+ self.io_channels = io_channels
145
+ self.bilinear = bilinear
146
+ ####
147
+ self.unet_list = nn.ModuleList()
148
+ for i in range(self.n_unet):
149
+ if i != self.n_unet-1:
150
+ self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
151
+ else:
152
+ self.unet_list.append(UNet_2head(self.io_channels, self.io_channels, self.bilinear))
153
+ def forward(self, x):
154
+ y = x
155
+ for i in range(self.n_unet):
156
+ if i==0:
157
+ y = self.unet_list[i](y)
158
+ else:
159
+ y = self.unet_list[i](y+x)
160
+ y_mean, y_sigma = y[0], y[1]
161
+ return y_mean, y_sigma
162
+
163
+ class CasUNet_3head(nn.Module):
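+ # Same cascade as CasUNet, but the final UNet is a UNet_3head returning the mean
+ # prediction together with per-pixel alpha and beta maps.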
164
+ def __init__(self, n_unet, io_channels, bilinear=True):
165
+ super(CasUNet_3head, self).__init__()
166
+ self.n_unet = n_unet
167
+ self.io_channels = io_channels
168
+ self.bilinear = bilinear
169
+ ####
170
+ self.unet_list = nn.ModuleList()
171
+ for i in range(self.n_unet):
172
+ if i != self.n_unet-1:
173
+ self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
174
+ else:
175
+ self.unet_list.append(UNet_3head(self.io_channels, self.io_channels, self.bilinear))
176
+ def forward(self, x):
177
+ y = x
178
+ for i in range(self.n_unet):
179
+ if i==0:
180
+ y = self.unet_list[i](y)
181
+ else:
182
+ y = self.unet_list[i](y+x)
183
+ y_mean, y_alpha, y_beta = y[0], y[1], y[2]
184
+ return y_mean, y_alpha, y_beta
185
+
186
+ class UNet_2head(nn.Module):
187
+ def __init__(self, n_channels, out_channels, bilinear=True):
188
+ super(UNet_2head, self).__init__()
189
+ self.n_channels = n_channels
190
+ self.out_channels = out_channels
191
+ self.bilinear = bilinear
192
+ ####
193
+ self.inc = ResConv(n_channels, 64)
194
+ self.down1 = Down(64, 128)
195
+ self.down2 = Down(128, 256)
196
+ self.down3 = Down(256, 512)
197
+ factor = 2 if bilinear else 1
198
+ self.down4 = Down(512, 1024 // factor)
199
+ self.up1 = Up(1024, 512 // factor, bilinear)
200
+ self.up2 = Up(512, 256 // factor, bilinear)
201
+ self.up3 = Up(256, 128 // factor, bilinear)
202
+ self.up4 = Up(128, 64, bilinear)
203
+ #per pixel multiple channels may exist
204
+ self.out_mean = OutConv(64, out_channels)
205
+ #variance will always be a single number for a pixel
206
+ self.out_var = nn.Sequential(
207
+ OutConv(64, 128),
208
+ OutConv(128, 1),
209
+ )
210
+ def forward(self, x):
211
+ x1 = self.inc(x)
212
+ x2 = self.down1(x1)
213
+ x3 = self.down2(x2)
214
+ x4 = self.down3(x3)
215
+ x5 = self.down4(x4)
216
+ x = self.up1(x5, x4)
217
+ x = self.up2(x, x3)
218
+ x = self.up3(x, x2)
219
+ x = self.up4(x, x1)
220
+ y_mean, y_var = self.out_mean(x), self.out_var(x)
221
+ return y_mean, y_var
222
+
223
+ class UNet_3head(nn.Module):
224
+ def __init__(self, n_channels, out_channels, bilinear=True):
225
+ super(UNet_3head, self).__init__()
226
+ self.n_channels = n_channels
227
+ self.out_channels = out_channels
228
+ self.bilinear = bilinear
229
+ ####
230
+ self.inc = ResConv(n_channels, 64)
231
+ self.down1 = Down(64, 128)
232
+ self.down2 = Down(128, 256)
233
+ self.down3 = Down(256, 512)
234
+ factor = 2 if bilinear else 1
235
+ self.down4 = Down(512, 1024 // factor)
236
+ self.up1 = Up(1024, 512 // factor, bilinear)
237
+ self.up2 = Up(512, 256 // factor, bilinear)
238
+ self.up3 = Up(256, 128 // factor, bilinear)
239
+ self.up4 = Up(128, 64, bilinear)
240
+ #per pixel multiple channels may exist
241
+ self.out_mean = OutConv(64, out_channels)
242
+ #variance will always be a single number for a pixel
243
+ self.out_alpha = nn.Sequential(
244
+ OutConv(64, 128),
245
+ OutConv(128, 1),
246
+ nn.ReLU()
247
+ )
248
+ self.out_beta = nn.Sequential(
249
+ OutConv(64, 128),
250
+ OutConv(128, 1),
251
+ nn.ReLU()
252
+ )
253
+ def forward(self, x):
254
+ x1 = self.inc(x)
255
+ x2 = self.down1(x1)
256
+ x3 = self.down2(x2)
257
+ x4 = self.down3(x3)
258
+ x5 = self.down4(x4)
259
+ x = self.up1(x5, x4)
260
+ x = self.up2(x, x3)
261
+ x = self.up3(x, x2)
262
+ x = self.up4(x, x1)
263
+ y_mean, y_alpha, y_beta = self.out_mean(x), \
264
+ self.out_alpha(x), self.out_beta(x)
265
+ return y_mean, y_alpha, y_beta
266
+
267
+ class ResidualBlock(nn.Module):
268
+ def __init__(self, in_features):
269
+ super(ResidualBlock, self).__init__()
270
+ conv_block = [
271
+ nn.ReflectionPad2d(1),
272
+ nn.Conv2d(in_features, in_features, 3),
273
+ nn.InstanceNorm2d(in_features),
274
+ nn.ReLU(inplace=True),
275
+ nn.ReflectionPad2d(1),
276
+ nn.Conv2d(in_features, in_features, 3),
277
+ nn.InstanceNorm2d(in_features)
278
+ ]
279
+ self.conv_block = nn.Sequential(*conv_block)
280
+ def forward(self, x):
281
+ return x + self.conv_block(x)
282
+
283
+ class Generator(nn.Module):
284
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9):
285
+ super(Generator, self).__init__()
286
+ # Initial convolution block
287
+ model = [
288
+ nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7),
289
+ nn.InstanceNorm2d(64), nn.ReLU(inplace=True)
290
+ ]
291
+ # Downsampling
292
+ in_features = 64
293
+ out_features = in_features*2
294
+ for _ in range(2):
295
+ model += [
296
+ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
297
+ nn.InstanceNorm2d(out_features),
298
+ nn.ReLU(inplace=True)
299
+ ]
300
+ in_features = out_features
301
+ out_features = in_features*2
302
+ # Residual blocks
303
+ for _ in range(n_residual_blocks):
304
+ model += [ResidualBlock(in_features)]
305
+ # Upsampling
306
+ out_features = in_features//2
307
+ for _ in range(2):
308
+ model += [
309
+ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
310
+ nn.InstanceNorm2d(out_features),
311
+ nn.ReLU(inplace=True)
312
+ ]
313
+ in_features = out_features
314
+ out_features = in_features//2
315
+ # Output layer
316
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()]
317
+ self.model = nn.Sequential(*model)
318
+ def forward(self, x):
319
+ return self.model(x)
320
+
321
+
322
+ class ResnetGenerator(nn.Module):
323
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
324
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
325
+ """
326
+
327
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
328
+ """Construct a Resnet-based generator
329
+ Parameters:
330
+ input_nc (int) -- the number of channels in input images
331
+ output_nc (int) -- the number of channels in output images
332
+ ngf (int) -- the number of filters in the last conv layer
333
+ norm_layer -- normalization layer
334
+ use_dropout (bool) -- if use dropout layers
335
+ n_blocks (int) -- the number of ResNet blocks
336
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
337
+ """
338
+ assert(n_blocks >= 0)
339
+ super(ResnetGenerator, self).__init__()
340
+ if type(norm_layer) == functools.partial:
341
+ use_bias = norm_layer.func == nn.InstanceNorm2d
342
+ else:
343
+ use_bias = norm_layer == nn.InstanceNorm2d
344
+
345
+ model = [nn.ReflectionPad2d(3),
346
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
347
+ norm_layer(ngf),
348
+ nn.ReLU(True)]
349
+
350
+ n_downsampling = 2
351
+ for i in range(n_downsampling): # add downsampling layers
352
+ mult = 2 ** i
353
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
354
+ norm_layer(ngf * mult * 2),
355
+ nn.ReLU(True)]
356
+
357
+ mult = 2 ** n_downsampling
358
+ for i in range(n_blocks): # add ResNet blocks
359
+
360
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
361
+
362
+ for i in range(n_downsampling): # add upsampling layers
363
+ mult = 2 ** (n_downsampling - i)
364
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
365
+ kernel_size=3, stride=2,
366
+ padding=1, output_padding=1,
367
+ bias=use_bias),
368
+ norm_layer(int(ngf * mult / 2)),
369
+ nn.ReLU(True)]
370
+ model += [nn.ReflectionPad2d(3)]
371
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
372
+ model += [nn.Tanh()]
373
+
374
+ self.model = nn.Sequential(*model)
375
+
376
+ def forward(self, input):
377
+ """Standard forward"""
378
+ return self.model(input)
379
+
380
+
381
+ class ResnetBlock(nn.Module):
382
+ """Define a Resnet block"""
383
+
384
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
385
+ """Initialize the Resnet block
386
+ A resnet block is a conv block with skip connections
387
+ We construct a conv block with build_conv_block function,
388
+ and implement skip connections in <forward> function.
389
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
390
+ """
391
+ super(ResnetBlock, self).__init__()
392
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
393
+
394
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
395
+ """Construct a convolutional block.
396
+ Parameters:
397
+ dim (int) -- the number of channels in the conv layer.
398
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
399
+ norm_layer -- normalization layer
400
+ use_dropout (bool) -- if use dropout layers.
401
+ use_bias (bool) -- if the conv layer uses bias or not
402
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
403
+ """
404
+ conv_block = []
405
+ p = 0
406
+ if padding_type == 'reflect':
407
+ conv_block += [nn.ReflectionPad2d(1)]
408
+ elif padding_type == 'replicate':
409
+ conv_block += [nn.ReplicationPad2d(1)]
410
+ elif padding_type == 'zero':
411
+ p = 1
412
+ else:
413
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
414
+
415
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
416
+ if use_dropout:
417
+ conv_block += [nn.Dropout(0.5)]
418
+
419
+ p = 0
420
+ if padding_type == 'reflect':
421
+ conv_block += [nn.ReflectionPad2d(1)]
422
+ elif padding_type == 'replicate':
423
+ conv_block += [nn.ReplicationPad2d(1)]
424
+ elif padding_type == 'zero':
425
+ p = 1
426
+ else:
427
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
428
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
429
+
430
+ return nn.Sequential(*conv_block)
431
+
432
+ def forward(self, x):
433
+ """Forward function (with skip connections)"""
434
+ out = x + self.conv_block(x) # add skip connections
435
+ return out
436
+
437
+ ### discriminator
438
+ class NLayerDiscriminator(nn.Module):
439
+ """Defines a PatchGAN discriminator"""
440
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
441
+ """Construct a PatchGAN discriminator
442
+ Parameters:
443
+ input_nc (int) -- the number of channels in input images
444
+ ndf (int) -- the number of filters in the last conv layer
445
+ n_layers (int) -- the number of conv layers in the discriminator
446
+ norm_layer -- normalization layer
447
+ """
448
+ super(NLayerDiscriminator, self).__init__()
449
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
450
+ use_bias = norm_layer.func == nn.InstanceNorm2d
451
+ else:
452
+ use_bias = norm_layer == nn.InstanceNorm2d
453
+ kw = 4
454
+ padw = 1
455
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
456
+ nf_mult = 1
457
+ nf_mult_prev = 1
458
+ for n in range(1, n_layers): # gradually increase the number of filters
459
+ nf_mult_prev = nf_mult
460
+ nf_mult = min(2 ** n, 8)
461
+ sequence += [
462
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
463
+ norm_layer(ndf * nf_mult),
464
+ nn.LeakyReLU(0.2, True)
465
+ ]
466
+ nf_mult_prev = nf_mult
467
+ nf_mult = min(2 ** n_layers, 8)
468
+ sequence += [
469
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
470
+ norm_layer(ndf * nf_mult),
471
+ nn.LeakyReLU(0.2, True)
472
+ ]
473
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
474
+ self.model = nn.Sequential(*sequence)
475
+ def forward(self, input):
476
+ """Standard forward."""
477
+ return self.model(input)
src/utils.py ADDED
@@ -0,0 +1,1273 @@
1
+ import random
2
+ from typing import Any, Optional
3
+ import numpy as np
4
+ import os
5
+ import cv2
6
+ from glob import glob
7
+ from PIL import Image, ImageDraw
8
+ from tqdm import tqdm
9
+ import kornia
10
+ import matplotlib.pyplot as plt
11
+ import seaborn as sns
12
+ import albumentations as albu
13
+ import functools
14
+ import math
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch import Tensor
19
+ import torchvision as tv
20
+ import torchvision.models as models
21
+ from torchvision import transforms
22
+ from torchvision.transforms import functional as F
23
+ from losses import TempCombLoss
24
+
25
+ ########### DeblurGAN function
26
+ def get_norm_layer(norm_type='instance'):
27
+ if norm_type == 'batch':
28
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
29
+ elif norm_type == 'instance':
30
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
31
+ else:
32
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
33
+ return norm_layer
34
+
35
+ def _array_to_batch(x):
36
+ x = np.transpose(x, (2, 0, 1))
37
+ x = np.expand_dims(x, 0)
38
+ return torch.from_numpy(x)
39
+
40
+ def get_normalize():
41
+ normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
42
+ normalize = albu.Compose([normalize], additional_targets={'target': 'image'})
43
+
44
+ def process(a, b):
45
+ r = normalize(image=a, target=b)
46
+ return r['image'], r['target']
47
+
48
+ return process
49
+
50
+ def preprocess(x: np.ndarray, mask: Optional[np.ndarray]):
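+ # Normalize the image, build a binary mask (all ones if none is given), and zero-pad both
+ # so that height and width become multiples of the 32-pixel block size assumed by the
+ # DeblurGAN-style generator; returns the padded batch tensors plus the original h, w.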
51
+ x, _ = get_normalize()(x, x)
52
+ if mask is None:
53
+ mask = np.ones_like(x, dtype=np.float32)
54
+ else:
55
+ mask = np.round(mask.astype('float32') / 255)
56
+
57
+ h, w, _ = x.shape
58
+ block_size = 32
59
+ min_height = (h // block_size + 1) * block_size
60
+ min_width = (w // block_size + 1) * block_size
61
+
62
+ pad_params = {'mode': 'constant',
63
+ 'constant_values': 0,
64
+ 'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
65
+ }
66
+ x = np.pad(x, **pad_params)
67
+ mask = np.pad(mask, **pad_params)
68
+
69
+ return map(_array_to_batch, (x, mask)), h, w
70
+
71
+ def postprocess(x: torch.Tensor) -> np.ndarray:
72
+ x, = x
73
+ x = x.detach().cpu().float().numpy()
74
+ x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
75
+ return x.astype('uint8')
76
+
77
+ def sorted_glob(pattern):
78
+ return sorted(glob(pattern))
79
+ ###########
80
+
81
+ def normalize(image: np.ndarray) -> np.ndarray:
82
+ """Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
83
+ Args:
84
+ image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
85
+ Returns:
86
+ Normalized image data. Data range [0, 1].
87
+ """
88
+ return image.astype(np.float64) / 255.0
89
+
90
+
91
+ def unnormalize(image: np.ndarray) -> np.ndarray:
92
+ """Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
93
+ Args:
94
+ image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
95
+ Returns:
96
+ Denormalized image data. Data range [0, 255].
97
+ """
98
+ return image.astype(np.float64) * 255.0
99
+
100
+
101
+ def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
102
+ """Convert ``PIL.Image`` to Tensor.
103
+ Args:
104
+ image (np.ndarray): The image data read by ``PIL.Image``
105
+ range_norm (bool): Scale [0, 1] data to between [-1, 1]
106
+ half (bool): Whether to cast the tensor from torch.float32 to torch.half.
107
+ Returns:
108
+ Normalized image data
109
+ Examples:
110
+ >>> image = Image.open("image.bmp")
111
+ >>> tensor_image = image2tensor(image, range_norm=False, half=False)
112
+ """
113
+ tensor = F.to_tensor(image)
114
+
115
+ if range_norm:
116
+ tensor = tensor.mul_(2.0).sub_(1.0)
117
+ if half:
118
+ tensor = tensor.half()
119
+
120
+ return tensor
121
+
122
+
123
+ def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
124
+ """Converts ``torch.Tensor`` to ``PIL.Image``.
125
+ Args:
126
+ tensor (torch.Tensor): The image that needs to be converted to ``PIL.Image``
127
+ range_norm (bool): Scale [-1, 1] data to between [0, 1]
128
+ half (bool): Whether to cast the tensor from torch.float32 to torch.half.
129
+ Returns:
130
+ Convert image data to support PIL library
131
+ Examples:
132
+ >>> tensor = torch.randn([1, 3, 128, 128])
133
+ >>> image = tensor2image(tensor, range_norm=False, half=False)
134
+ """
135
+ if range_norm:
136
+ tensor = tensor.add_(1.0).div_(2.0)
137
+ if half:
138
+ tensor = tensor.half()
139
+
140
+ image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")
141
+
142
+ return image
143
+
144
+
145
+ def convert_rgb_to_y(image: Any) -> Any:
146
+ """Convert RGB image or tensor image data to YCbCr(Y) format.
147
+ Args:
148
+ image: RGB image data read by ``PIL.Image``.
149
+ Returns:
150
+ Y image array data.
151
+ """
152
+ if type(image) == np.ndarray:
153
+ return 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
154
+ elif type(image) == torch.Tensor:
155
+ if len(image.shape) == 4:
156
+ image = image.squeeze_(0)
157
+ return 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
158
+ else:
159
+ raise Exception("Unknown Type", type(image))
160
+
161
+
162
+ def convert_rgb_to_ycbcr(image: Any) -> Any:
163
+ """Convert RGB image or tensor image data to YCbCr format.
164
+ Args:
165
+ image: RGB image data read by ``PIL.Image``.
166
+ Returns:
167
+ YCbCr image array data.
168
+ """
169
+ if type(image) == np.ndarray:
170
+ y = 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
171
+ cb = 128. + (-37.945 * image[:, :, 0] - 74.494 * image[:, :, 1] + 112.439 * image[:, :, 2]) / 256.
172
+ cr = 128. + (112.439 * image[:, :, 0] - 94.154 * image[:, :, 1] - 18.285 * image[:, :, 2]) / 256.
173
+ return np.array([y, cb, cr]).transpose([1, 2, 0])
174
+ elif type(image) == torch.Tensor:
175
+ if len(image.shape) == 4:
176
+ image = image.squeeze(0)
177
+ y = 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
178
+ cb = 128. + (-37.945 * image[0, :, :] - 74.494 * image[1, :, :] + 112.439 * image[2, :, :]) / 256.
179
+ cr = 128. + (112.439 * image[0, :, :] - 94.154 * image[1, :, :] - 18.285 * image[2, :, :]) / 256.
180
+ return torch.cat([y, cb, cr], 0).permute(1, 2, 0)
181
+ else:
182
+ raise Exception("Unknown Type", type(image))
183
+
184
+
185
+ def convert_ycbcr_to_rgb(image: Any) -> Any:
186
+ """Convert YCbCr format image to RGB format.
187
+ Args:
188
+ image: YCbCr image data read by ``PIL.Image``.
189
+ Returns:
190
+ RGB image array data.
191
+ """
192
+ if type(image) == np.ndarray:
193
+ r = 298.082 * image[:, :, 0] / 256. + 408.583 * image[:, :, 2] / 256. - 222.921
194
+ g = 298.082 * image[:, :, 0] / 256. - 100.291 * image[:, :, 1] / 256. - 208.120 * image[:, :, 2] / 256. + 135.576
195
+ b = 298.082 * image[:, :, 0] / 256. + 516.412 * image[:, :, 1] / 256. - 276.836
196
+ return np.array([r, g, b]).transpose([1, 2, 0])
197
+ elif type(image) == torch.Tensor:
198
+ if len(image.shape) == 4:
199
+ image = image.squeeze(0)
200
+ r = 298.082 * image[0, :, :] / 256. + 408.583 * image[2, :, :] / 256. - 222.921
201
+ g = 298.082 * image[0, :, :] / 256. - 100.291 * image[1, :, :] / 256. - 208.120 * image[2, :, :] / 256. + 135.576
202
+ b = 298.082 * image[0, :, :] / 256. + 516.412 * image[1, :, :] / 256. - 276.836
203
+ return torch.cat([r, g, b], 0).permute(1, 2, 0)
204
+ else:
205
+ raise Exception("Unknown Type", type(image))
206
+
207
+
208
+ def center_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
209
+ """Crop the center area of the ``PIL.Image`` image.
210
+ Args:
211
+ lr: Low-resolution image data read by ``PIL.Image``.
212
+ hr: High-resolution image data read by ``PIL.Image``.
213
+ image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
214
+ upscale_factor (int): magnification factor.
215
+ Returns:
216
+ Center-cropped low-resolution images and high-resolution images.
217
+ """
218
+ w, h = hr.size
219
+
220
+ left = (w - image_size) // 2
221
+ top = (h - image_size) // 2
222
+ right = left + image_size
223
+ bottom = top + image_size
224
+
225
+ lr = lr.crop((left // upscale_factor,
226
+ top // upscale_factor,
227
+ right // upscale_factor,
228
+ bottom // upscale_factor))
229
+ hr = hr.crop((left, top, right, bottom))
230
+
231
+ return lr, hr
232
+
233
+
234
+ def random_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
235
+ """Randomly crop the specified area from the ``PIL.Image`` image.
236
+ Args:
237
+ lr: Low-resolution image data read by ``PIL.Image``.
238
+ hr: High-resolution image data read by ``PIL.Image``.
239
+ image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
240
+ upscale_factor (int): magnification factor.
241
+ Returns:
242
+ Randomly cropped low-resolution images and high-resolution images.
243
+ """
244
+ w, h = hr.size
245
+ left = torch.randint(0, w - image_size + 1, size=(1,)).item()
246
+ top = torch.randint(0, h - image_size + 1, size=(1,)).item()
247
+ right = left + image_size
248
+ bottom = top + image_size
249
+
250
+ lr = lr.crop((left // upscale_factor,
251
+ top // upscale_factor,
252
+ right // upscale_factor,
253
+ bottom // upscale_factor))
254
+ hr = hr.crop((left, top, right, bottom))
255
+
256
+ return lr, hr
257
+
258
+
259
+ def random_rotate(lr: Any, hr: Any, angle: int) -> [Any, Any]:
260
+ """Randomly rotate the ``PIL.Image`` image.
261
+ Args:
262
+ lr: Low-resolution image data read by ``PIL.Image``.
263
+ hr: High-resolution image data read by ``PIL.Image``.
264
+ angle (int): rotation angle, clockwise and counterclockwise rotation.
265
+ Returns:
266
+ Randomly rotated low-resolution images and high-resolution images.
267
+ """
268
+ angle = random.choice((+angle, -angle))
269
+ lr = F.rotate(lr, angle)
270
+ hr = F.rotate(hr, angle)
271
+
272
+ return lr, hr
273
+
274
+
275
+ def random_horizontally_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
276
+ """Randomly flip the ``PIL.Image`` image horizontally.
277
+ Args:
278
+ lr: Low-resolution image data read by ``PIL.Image``.
279
+ hr: High-resolution image data read by ``PIL.Image``.
280
+ p (optional, float): flip probability. (Default: 0.5)
281
+ Returns:
282
+ Low-resolution image and high-resolution image after random horizontal flip.
283
+ """
284
+ if torch.rand(1).item() < p:
285
+ lr = F.hflip(lr)
286
+ hr = F.hflip(hr)
287
+
288
+ return lr, hr
289
+
290
+
291
+ def random_vertically_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
292
+ """Randomly flip the ``PIL.Image`` image vertically.
293
+ Args:
294
+ lr: Low-resolution image data read by ``PIL.Image``.
295
+ hr: High-resolution image data read by ``PIL.Image``.
296
+ p (optional, float): flip probability. (Default: 0.5)
297
+ Returns:
298
+ Vertically flipped low-resolution and high-resolution images.
299
+ """
300
+ if torch.rand(1).item() < p:
301
+ lr = F.vflip(lr)
302
+ hr = F.vflip(hr)
303
+
304
+ return lr, hr
305
+
306
+
307
+ def random_adjust_brightness(lr: Any, hr: Any) -> [Any, Any]:
308
+ """Randomly adjust the brightness of the ``PIL.Image`` image.
309
+ Args:
310
+ lr: Low-resolution image data read by ``PIL.Image``.
311
+ hr: High-resolution image data read by ``PIL.Image``.
312
+ Returns:
313
+ Low-resolution image and high-resolution image with randomly adjusted brightness.
314
+ """
315
+ # Randomly adjust the brightness gain range.
316
+ factor = random.uniform(0.5, 2)
317
+ lr = F.adjust_brightness(lr, factor)
318
+ hr = F.adjust_brightness(hr, factor)
319
+
320
+ return lr, hr
321
+
322
+
323
+ def random_adjust_contrast(lr: Any, hr: Any) -> [Any, Any]:
324
+ """Randomly adjust the contrast of the ``PIL.Image`` image.
325
+ Args:
326
+ lr: Low-resolution image data read by ``PIL.Image``.
327
+ hr: High-resolution image data read by ``PIL.Image``.
328
+ Returns:
329
+ Low-resolution image and high-resolution image with randomly adjusted contrast.
330
+ """
331
+ # Randomly adjust the contrast gain range.
332
+ factor = random.uniform(0.5, 2)
333
+ lr = F.adjust_contrast(lr, factor)
334
+ hr = F.adjust_contrast(hr, factor)
335
+
336
+ return lr, hr
337
+
338
+ #### metrics to compute -- assumes single images, i.e., tensor of 3 dims
339
+ def img_mae(x1, x2):
340
+ m = torch.abs(x1-x2).mean()
341
+ return m
342
+
343
+ def img_mse(x1, x2):
344
+ m = torch.pow(torch.abs(x1-x2),2).mean()
345
+ return m
346
+
347
+ def img_psnr(x1, x2):
348
+ m = kornia.metrics.psnr(x1, x2, 1)
349
+ return m
350
+
351
+ def img_ssim(x1, x2):
352
+ m = kornia.metrics.ssim(x1.unsqueeze(0), x2.unsqueeze(0), 5)
353
+ m = m.mean()
354
+ return m
355
+
356
+ def show_SR_w_uncer(xLR, xHR, xSR, xSRvar, elim=(0,0.01), ulim=(0,0.15)):
357
+ '''
358
+ xLR/SR/HR: 3xHxW
359
+ xSRvar: 1xHxW
360
+ '''
361
+ plt.figure(figsize=(30,10))
362
+
363
+ plt.subplot(1,5,1)
364
+ plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
365
+ plt.axis('off')
366
+
367
+ plt.subplot(1,5,2)
368
+ plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
369
+ plt.axis('off')
370
+
371
+ plt.subplot(1,5,3)
372
+ plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
373
+ plt.axis('off')
374
+
375
+ plt.subplot(1,5,4)
376
+ error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
377
+ print('error', error_map.min(), error_map.max())
378
+ plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
379
+ plt.clim(elim[0], elim[1])
380
+ plt.axis('off')
381
+
382
+ plt.subplot(1,5,5)
383
+ print('uncer', xSRvar.min(), xSRvar.max())
384
+ plt.imshow(xSRvar.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
385
+ plt.clim(ulim[0], ulim[1])
386
+ plt.axis('off')
387
+
388
+ plt.subplots_adjust(wspace=0, hspace=0)
389
+ plt.show()
390
+
391
+ def show_SR_w_err(xLR, xHR, xSR, elim=(0,0.01), task=None, xMask=None):
392
+ '''
393
+ xLR/SR/HR: 3xHxW
394
+ '''
395
+ plt.figure(figsize=(30,10))
396
+
397
+ if task != 'm':
398
+ plt.subplot(1,4,1)
399
+ plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
400
+ plt.axis('off')
401
+
402
+ plt.subplot(1,4,2)
403
+ plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
404
+ plt.axis('off')
405
+
406
+ plt.subplot(1,4,3)
407
+ plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
408
+ plt.axis('off')
409
+ else:
410
+ plt.subplot(1,4,1)
411
+ plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
412
+ plt.clim(0,0.9)
413
+ plt.axis('off')
414
+
415
+ plt.subplot(1,4,2)
416
+ plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
417
+ plt.clim(0,0.9)
418
+ plt.axis('off')
419
+
420
+ plt.subplot(1,4,3)
421
+ plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
422
+ plt.clim(0,0.9)
423
+ plt.axis('off')
424
+
425
+ plt.subplot(1,4,4)
426
+ if task == 'inpainting':
427
+ error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)*xMask.to('cpu').data
428
+ else:
429
+ error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
430
+ print('error', error_map.min(), error_map.max())
431
+ plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
432
+ plt.clim(elim[0], elim[1])
433
+ plt.axis('off')
434
+
435
+ plt.subplots_adjust(wspace=0, hspace=0)
436
+ plt.show()
437
+
438
+ def show_uncer4(xSRvar1, xSRvar2, xSRvar3, xSRvar4, ulim=(0,0.15)):
439
+ '''
440
+ xSRvar: 1xHxW
441
+ '''
442
+ plt.figure(figsize=(30,10))
443
+
444
+ plt.subplot(1,4,1)
445
+ print('uncer', xSRvar1.min(), xSRvar1.max())
446
+ plt.imshow(xSRvar1.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
447
+ plt.clim(ulim[0], ulim[1])
448
+ plt.axis('off')
449
+
450
+ plt.subplot(1,4,2)
451
+ print('uncer', xSRvar2.min(), xSRvar2.max())
452
+ plt.imshow(xSRvar2.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
453
+ plt.clim(ulim[0], ulim[1])
454
+ plt.axis('off')
455
+
456
+ plt.subplot(1,4,3)
457
+ print('uncer', xSRvar3.min(), xSRvar3.max())
458
+ plt.imshow(xSRvar3.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
459
+ plt.clim(ulim[0], ulim[1])
460
+ plt.axis('off')
461
+
462
+ plt.subplot(1,4,4)
463
+ print('uncer', xSRvar4.min(), xSRvar4.max())
464
+ plt.imshow(xSRvar4.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
465
+ plt.clim(ulim[0], ulim[1])
466
+ plt.axis('off')
467
+
468
+ plt.subplots_adjust(wspace=0, hspace=0)
469
+ plt.show()
470
+
471
+ def get_UCE(list_err, list_yout_var, num_bins=100):
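+ # Uncertainty Calibration Error: bin the per-pixel errors into num_bins equal-width bins,
+ # compare each bin's mean error with its mean predicted variance, and sum the absolute
+ # differences weighted by the fraction of points falling in the bin.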
472
+ err_min = np.min(list_err)
473
+ err_max = np.max(list_err)
474
+ err_len = (err_max-err_min)/num_bins
475
+ num_points = len(list_err)
476
+
477
+ bin_stats = {}
478
+ for i in range(num_bins):
479
+ bin_stats[i] = {
480
+ 'start_idx': err_min + i*err_len,
481
+ 'end_idx': err_min + (i+1)*err_len,
482
+ 'num_points': 0,
483
+ 'mean_err': 0,
484
+ 'mean_var': 0,
485
+ }
486
+
487
+ for e,v in zip(list_err, list_yout_var):
488
+ for i in range(num_bins):
489
+ if e>=bin_stats[i]['start_idx'] and e<bin_stats[i]['end_idx']:
490
+ bin_stats[i]['num_points'] += 1
491
+ bin_stats[i]['mean_err'] += e
492
+ bin_stats[i]['mean_var'] += v
493
+
494
+ uce = 0
495
+ eps = 1e-8
496
+ for i in range(num_bins):
497
+ bin_stats[i]['mean_err'] /= bin_stats[i]['num_points'] + eps
498
+ bin_stats[i]['mean_var'] /= bin_stats[i]['num_points'] + eps
499
+ bin_stats[i]['uce_bin'] = (bin_stats[i]['num_points']/num_points) \
500
+ *(np.abs(bin_stats[i]['mean_err'] - bin_stats[i]['mean_var']))
501
+ uce += bin_stats[i]['uce_bin']
502
+
503
+ list_x, list_y = [], []
504
+ for i in range(num_bins):
505
+ if bin_stats[i]['num_points']>0:
506
+ list_x.append(bin_stats[i]['mean_err'])
507
+ list_y.append(bin_stats[i]['mean_var'])
508
+
509
+ # sns.set_style('darkgrid')
510
+ # sns.scatterplot(x=list_x, y=list_y)
511
+ # sns.regplot(x=list_x, y=list_y, order=1)
512
+ # plt.xlabel('MSE', fontsize=34)
513
+ # plt.ylabel('Uncertainty', fontsize=34)
514
+ # plt.plot(list_x, list_x, color='r')
515
+ # plt.xlim(np.min(list_x), np.max(list_x))
516
+ # plt.ylim(np.min(list_err), np.max(list_x))
517
+ # plt.show()
518
+
519
+ return bin_stats, uce
520
+
521
+ ##################### training BayesCap
522
+ def train_BayesCap(
523
+ NetC,
524
+ NetG,
525
+ train_loader,
526
+ eval_loader,
527
+ Cri = TempCombLoss(),
528
+ device='cuda',
529
+ dtype=torch.cuda.FloatTensor,
530
+ init_lr=1e-4,
531
+ num_epochs=100,
532
+ eval_every=1,
533
+ ckpt_path='../ckpt/BayesCap',
534
+ T1=1e0,
535
+ T2=5e-2,
536
+ task=None,
537
+ ):
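+ # Train the BayesCap network NetC on top of the frozen generator NetG: NetG's output is
+ # fed to NetC, which predicts a reconstruction (mu) together with per-pixel (alpha, beta)
+ # parameters, optimized with TempCombLoss against NetG's own output for the 'depth' task
+ # and against the ground truth otherwise; the best checkpoint (by eval score) is kept.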
538
+ NetC.to(device)
539
+ NetC.train()
540
+ NetG.to(device)
541
+ NetG.eval()
542
+ optimizer = torch.optim.Adam(list(NetC.parameters()), lr=init_lr)
543
+ optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
544
+
545
+ score = -1e8
546
+ all_loss = []
547
+ for eph in range(num_epochs):
548
+ eph_loss = 0
549
+ with tqdm(train_loader, unit='batch') as tepoch:
550
+ for (idx, batch) in enumerate(tepoch):
551
+ if idx>2000:
552
+ break
553
+ tepoch.set_description('Epoch {}'.format(eph))
554
+ ##
555
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
556
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
557
+ if task == 'inpainting':
558
+ xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
559
+ xMask = xMask.to(device).type(dtype)
560
+ # pass them through the network
561
+ with torch.no_grad():
562
+ if task == 'inpainting':
563
+ _, xSR1 = NetG(xLR, xMask)
564
+ elif task == 'depth':
565
+ xSR1 = NetG(xLR)[("disp", 0)]
566
+ else:
567
+ xSR1 = NetG(xLR)
568
+ # with torch.autograd.set_detect_anomaly(True):
569
+ xSR = xSR1.clone()
570
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
571
+ # print(xSRC_alpha)
572
+ optimizer.zero_grad()
573
+ if task == 'depth':
574
+ loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xSR, T1=T1, T2=T2)
575
+ else:
576
+ loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xHR, T1=T1, T2=T2)
577
+ # print(loss)
578
+ loss.backward()
579
+ optimizer.step()
580
+ ##
581
+ eph_loss += loss.item()
582
+ tepoch.set_postfix(loss=loss.item())
583
+ eph_loss /= len(train_loader)
584
+ all_loss.append(eph_loss)
585
+ print('Avg. loss: {}'.format(eph_loss))
586
+ # evaluate and save the models
587
+ torch.save(NetC.state_dict(), ckpt_path+'_last.pth')
588
+ if eph%eval_every == 0:
589
+ curr_score = eval_BayesCap(
590
+ NetC,
591
+ NetG,
592
+ eval_loader,
593
+ device=device,
594
+ dtype=dtype,
595
+ task=task,
596
+ )
597
+ print('current score: {} | Last best score: {}'.format(curr_score, score))
598
+ if curr_score >= score:
599
+ score = curr_score
600
+ torch.save(NetC.state_dict(), ckpt_path+'_best.pth')
601
+ optim_scheduler.step()
602
+
603
+ #### get different uncertainty maps
604
+ def get_uncer_BayesCap(
605
+ NetC,
606
+ NetG,
607
+ xin,
608
+ task=None,
609
+ xMask=None,
610
+ ):
611
+ with torch.no_grad():
612
+ if task == 'inpainting':
613
+ _, xSR = NetG(xin, xMask)
614
+ else:
615
+ xSR = NetG(xin)
616
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
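+ # The network's alpha output parameterizes an inverse scale, so invert it to get the
+ # generalized Gaussian scale a and use beta as the shape b; the per-pixel variance is
+ # then a^2 * Gamma(3/b) / Gamma(1/b), evaluated via lgamma/exp for numerical stability.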
617
+ a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
618
+ b_map = xSRC_beta.to('cpu').data
619
+ xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
620
+
621
+ return xSRvar
622
+
623
+ def get_uncer_TTDAp(
624
+ NetG,
625
+ xin,
626
+ p_mag=0.05,
627
+ num_runs=50,
628
+ task=None,
629
+ xMask=None,
630
+ ):
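+ # Test-time data augmentation baseline: perturb the input with Gaussian noise of magnitude
+ # p_mag * max(input) for num_runs forward passes and take the per-pixel variance of the
+ # outputs as the uncertainty map.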
631
+ list_xSR = []
632
+ with torch.no_grad():
633
+ for z in range(num_runs):
634
+ if task == 'inpainting':
635
+ _, xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin), xMask)
636
+ else:
637
+ xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin))
638
+ list_xSR.append(xSRz)
639
+ xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
640
+ xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
641
+ return xSRvar
642
+
643
+ def get_uncer_DO(
644
+ NetG,
645
+ xin,
646
+ dop=0.2,
647
+ num_runs=50,
648
+ task=None,
649
+ xMask=None,
650
+ ):
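+ # Monte-Carlo dropout baseline: run num_runs stochastic forward passes with dropout rate
+ # dop and take the per-pixel variance of the outputs as the uncertainty map.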
651
+ list_xSR = []
652
+ with torch.no_grad():
653
+ for z in range(num_runs):
654
+ if task == 'inpainting':
655
+ _, xSRz = NetG(xin, xMask, dop=dop)
656
+ else:
657
+ xSRz = NetG(xin, dop=dop)
658
+ list_xSR.append(xSRz)
659
+ xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
660
+ xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
661
+ return xSRvar
662
+
663
+ ################### Different eval functions
664
+
665
+ def eval_BayesCap(
666
+ NetC,
667
+ NetG,
668
+ eval_loader,
669
+ device='cuda',
670
+ dtype=torch.cuda.FloatTensor,
671
+ task=None,
672
+ xMask=None,
673
+ ):
674
+ NetC.to(device)
675
+ NetC.eval()
676
+ NetG.to(device)
677
+ NetG.eval()
678
+
679
+ mean_ssim = 0
680
+ mean_psnr = 0
681
+ mean_mse = 0
682
+ mean_mae = 0
683
+ num_imgs = 0
684
+ list_error = []
685
+ list_var = []
686
+ with tqdm(eval_loader, unit='batch') as tepoch:
687
+ for (idx, batch) in enumerate(tepoch):
688
+ tepoch.set_description('Validating ...')
689
+ ##
690
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
691
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
692
+ if task == 'inpainting':
693
+ if xMask is None:
694
+ xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
695
+ xMask = xMask.to(device).type(dtype)
696
+ else:
697
+ xMask = xMask.to(device).type(dtype)
698
+ # pass them through the network
699
+ with torch.no_grad():
700
+ if task == 'inpainting':
701
+ _, xSR = NetG(xLR, xMask)
702
+ elif task == 'depth':
703
+ xSR = NetG(xLR)[("disp", 0)]
704
+ else:
705
+ xSR = NetG(xLR)
706
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
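+ # Same conversion as in get_uncer_BayesCap: per-pixel variance a^2 * Gamma(3/b) / Gamma(1/b).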
707
+ a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
708
+ b_map = xSRC_beta.to('cpu').data
709
+ xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
710
+ n_batch = xSRC_mu.shape[0]
711
+ if task == 'depth':
712
+ xHR = xSR
713
+ for j in range(n_batch):
714
+ num_imgs += 1
715
+ mean_ssim += img_ssim(xSRC_mu[j], xHR[j])
716
+ mean_psnr += img_psnr(xSRC_mu[j], xHR[j])
717
+ mean_mse += img_mse(xSRC_mu[j], xHR[j])
718
+ mean_mae += img_mae(xSRC_mu[j], xHR[j])
719
+
720
+ show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
721
+
722
+ error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
723
+ var_map = xSRvar[j].to('cpu').data.reshape(-1)
724
+ list_error.extend(list(error_map.numpy()))
725
+ list_var.extend(list(var_map.numpy()))
726
+ ##
727
+ mean_ssim /= num_imgs
728
+ mean_psnr /= num_imgs
729
+ mean_mse /= num_imgs
730
+ mean_mae /= num_imgs
731
+ print(
732
+ 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
733
+ (
734
+ mean_ssim, mean_psnr, mean_mse, mean_mae
735
+ )
736
+ )
737
+ # print(len(list_error), len(list_var))
738
+ # print('UCE: ', get_UCE(list_error[::10], list_var[::10], num_bins=500)[1])
739
+ # print('C.Coeff: ', np.corrcoef(np.array(list_error[::10]), np.array(list_var[::10])))
740
+ return mean_ssim
741
+
742
+ def eval_TTDA_p(
743
+ NetG,
744
+ eval_loader,
745
+ device='cuda',
746
+ dtype=torch.cuda.FloatTensor,
747
+ p_mag=0.05,
748
+ num_runs=50,
749
+ task = None,
750
+ xMask = None,
751
+ ):
752
+ NetG.to(device)
753
+ NetG.eval()
754
+
755
+ mean_ssim = 0
756
+ mean_psnr = 0
757
+ mean_mse = 0
758
+ mean_mae = 0
759
+ num_imgs = 0
760
+ with tqdm(eval_loader, unit='batch') as tepoch:
761
+ for (idx, batch) in enumerate(tepoch):
762
+ tepoch.set_description('Validating ...')
763
+ ##
764
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
765
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
766
+ # pass them through the network
767
+ list_xSR = []
768
+ with torch.no_grad():
769
+ if task=='inpainting':
770
+ _, xSR = NetG(xLR, xMask)
771
+ else:
772
+ xSR = NetG(xLR)
773
+ for z in range(num_runs):
774
+ xSRz = NetG(xLR+p_mag*xLR.max()*torch.randn_like(xLR))
775
+ list_xSR.append(xSRz)
776
+ xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
777
+ xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
778
+ n_batch = xSR.shape[0]
779
+ for j in range(n_batch):
780
+ num_imgs += 1
781
+ mean_ssim += img_ssim(xSR[j], xHR[j])
782
+ mean_psnr += img_psnr(xSR[j], xHR[j])
783
+ mean_mse += img_mse(xSR[j], xHR[j])
784
+ mean_mae += img_mae(xSR[j], xHR[j])
785
+
786
+ show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
787
+
788
+ mean_ssim /= num_imgs
789
+ mean_psnr /= num_imgs
790
+ mean_mse /= num_imgs
791
+ mean_mae /= num_imgs
792
+ print(
793
+ 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
794
+ (
795
+ mean_ssim, mean_psnr, mean_mse, mean_mae
796
+ )
797
+ )
798
+
799
+ return mean_ssim
800
+
801
+ def eval_DO(
802
+ NetG,
803
+ eval_loader,
804
+ device='cuda',
805
+ dtype=torch.cuda.FloatTensor,
806
+ dop=0.2,
807
+ num_runs=50,
808
+ task=None,
809
+ xMask=None,
810
+ ):
811
+ NetG.to(device)
812
+ NetG.eval()
813
+
814
+ mean_ssim = 0
815
+ mean_psnr = 0
816
+ mean_mse = 0
817
+ mean_mae = 0
818
+ num_imgs = 0
819
+ with tqdm(eval_loader, unit='batch') as tepoch:
820
+ for (idx, batch) in enumerate(tepoch):
821
+ tepoch.set_description('Validating ...')
822
+ ##
823
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
824
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
825
+ # pass them through the network
826
+ list_xSR = []
827
+ with torch.no_grad():
828
+ if task == 'inpainting':
829
+ _, xSR = NetG(xLR, xMask)
830
+ else:
831
+ xSR = NetG(xLR)
832
+ for z in range(num_runs):
833
+ xSRz = NetG(xLR, dop=dop)
834
+ list_xSR.append(xSRz)
835
+ xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
836
+ xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
837
+ n_batch = xSR.shape[0]
838
+ for j in range(n_batch):
839
+ num_imgs += 1
840
+ mean_ssim += img_ssim(xSR[j], xHR[j])
841
+ mean_psnr += img_psnr(xSR[j], xHR[j])
842
+ mean_mse += img_mse(xSR[j], xHR[j])
843
+ mean_mae += img_mae(xSR[j], xHR[j])
844
+
845
+ show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
846
+ ##
847
+ mean_ssim /= num_imgs
848
+ mean_psnr /= num_imgs
849
+ mean_mse /= num_imgs
850
+ mean_mae /= num_imgs
851
+ print(
852
+ 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
853
+ (
854
+ mean_ssim, mean_psnr, mean_mse, mean_mae
855
+ )
856
+ )
857
+
858
+ return mean_ssim
859
+
860
+
861
+ ############### compare all function
862
+ def compare_all(
863
+ NetC,
864
+ NetG,
865
+ eval_loader,
866
+ p_mag = 0.05,
867
+ dop = 0.2,
868
+ num_runs = 100,
869
+ device='cuda',
870
+ dtype=torch.cuda.FloatTensor,
871
+ task=None,
872
+ ):
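+ # Qualitative comparison: for each batch, show the reconstruction error map together with
+ # the uncertainty maps produced by test-time augmentation (TTDAp), MC dropout (DO), and
+ # BayesCap for the chosen task.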
873
+ NetC.to(device)
874
+ NetC.eval()
875
+ NetG.to(device)
876
+ NetG.eval()
877
+
878
+ with tqdm(eval_loader, unit='batch') as tepoch:
879
+ for (idx, batch) in enumerate(tepoch):
880
+ tepoch.set_description('Comparing ...')
881
+ ##
882
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
883
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
884
+ if task == 'inpainting':
885
+ xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
886
+ xMask = xMask.to(device).type(dtype)
887
+ # pass them through the network
888
+ with torch.no_grad():
889
+ if task == 'inpainting':
890
+ _, xSR = NetG(xLR, xMask)
891
+ else:
892
+ xSR = NetG(xLR)
893
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
894
+
895
+ if task == 'inpainting':
896
+ xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs, task='inpainting', xMask=xMask)
897
+ xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs, task='inpainting', xMask=xMask)
898
+ xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR, task='inpainting', xMask=xMask)
899
+ else:
900
+ xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs)
901
+ xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs)
902
+ xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR)
903
+
904
+ print('bdg', xSRvar1.shape, xSRvar2.shape, xSRvar3.shape)
905
+
906
+ n_batch = xSR.shape[0]
907
+ for j in range(n_batch):
908
+ if task=='s':
909
+ show_SR_w_err(xLR[j], xHR[j], xSR[j])
910
+ show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
911
+ show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j])
912
+ if task=='d':
913
+ show_SR_w_err(xLR[j], xHR[j], 0.5*xSR[j]+0.5*xHR[j])
914
+ show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
915
+ show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
916
+ if task=='inpainting':
917
+ show_SR_w_err(xLR[j]*(1-xMask[j]), xHR[j], xSR[j], elim=(0,0.25), task='inpainting', xMask=xMask[j])
918
+ show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.45), torch.pow(xSRvar1[j], 0.4))
919
+ show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
920
+ if task=='m':
921
+ show_SR_w_err(xLR[j], xHR[j], xSR[j], elim=(0,0.04), task='m')
922
+ show_uncer4(0.4*xSRvar1[j]+0.6*xSRvar2[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42), ulim=(0.02,0.15))
923
+ show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j], ulim=(0.02,0.15))
924
+
925
+
926
+ ################# Degrading Identity
927
+ def degrage_BayesCap_p(
928
+ NetC,
929
+ NetG,
930
+ eval_loader,
931
+ device='cuda',
932
+ dtype=torch.cuda.FloatTensor,
933
+ num_runs=50,
934
+ ):
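+ # Ablation that degrades the identity mapping: feed NetC increasingly noisy copies of
+ # NetG's output (noise magnitude p_mag) and track how the SSIM/PSNR of the BayesCap
+ # reconstruction and the UCE of its uncertainty estimates change as p_mag grows.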
935
+ NetC.to(device)
936
+ NetC.eval()
937
+ NetG.to(device)
938
+ NetG.eval()
939
+
940
+ p_mag_list = [0, 0.05, 0.1, 0.15, 0.2]
941
+ list_s = []
942
+ list_p = []
943
+ list_u1 = []
944
+ list_u2 = []
945
+ list_c = []
946
+ for p_mag in p_mag_list:
947
+ mean_ssim = 0
948
+ mean_psnr = 0
949
+ mean_mse = 0
950
+ mean_mae = 0
951
+ num_imgs = 0
952
+ list_error = []
953
+ list_error2 = []
954
+ list_var = []
955
+
956
+ with tqdm(eval_loader, unit='batch') as tepoch:
957
+ for (idx, batch) in enumerate(tepoch):
958
+ tepoch.set_description('Validating ...')
959
+ ##
960
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
961
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
962
+ # pass them through the network
963
+ with torch.no_grad():
964
+ xSR = NetG(xLR)
965
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR + p_mag*xSR.max()*torch.randn_like(xSR))
966
+ a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
967
+ b_map = xSRC_beta.to('cpu').data
968
+ xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
969
+ n_batch = xSRC_mu.shape[0]
970
+ for j in range(n_batch):
971
+ num_imgs += 1
972
+ mean_ssim += img_ssim(xSRC_mu[j], xSR[j])
973
+ mean_psnr += img_psnr(xSRC_mu[j], xSR[j])
974
+ mean_mse += img_mse(xSRC_mu[j], xSR[j])
975
+ mean_mae += img_mae(xSRC_mu[j], xSR[j])
976
+
977
+ error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
978
+ error_map2 = torch.mean(torch.pow(torch.abs(xSRC_mu[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
979
+ var_map = xSRvar[j].to('cpu').data.reshape(-1)
980
+ list_error.extend(list(error_map.numpy()))
981
+ list_error2.extend(list(error_map2.numpy()))
982
+ list_var.extend(list(var_map.numpy()))
983
+ ##
984
+ mean_ssim /= num_imgs
985
+ mean_psnr /= num_imgs
986
+ mean_mse /= num_imgs
987
+ mean_mae /= num_imgs
988
+ print(
989
+ 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
990
+ (
991
+ mean_ssim, mean_psnr, mean_mse, mean_mae
992
+ )
993
+ )
994
+ uce1 = get_UCE(list_error[::100], list_var[::100], num_bins=200)[1]
995
+ uce2 = get_UCE(list_error2[::100], list_var[::100], num_bins=200)[1]
996
+ print('UCE1: ', uce1)
997
+ print('UCE2: ', uce2)
998
+ list_s.append(mean_ssim.item())
999
+ list_p.append(mean_psnr.item())
1000
+ list_u1.append(uce1)
1001
+ list_u2.append(uce2)
1002
+
1003
+ plt.plot(list_s)
1004
+ plt.show()
1005
+ plt.plot(list_p)
1006
+ plt.show()
1007
+
1008
+ plt.plot(list_u1, label='wrt SR output')
1009
+ plt.plot(list_u2, label='wrt BayesCap output')
1010
+ plt.legend()
1011
+ plt.show()
1012
+
1013
+ sns.set_style('darkgrid')
1014
+ fig,ax = plt.subplots()
1015
+ # make a plot
1016
+ ax.plot(p_mag_list, list_s, color="red", marker="o")
1017
+ # set x-axis label
1018
+ ax.set_xlabel("Reducing faithfulness of BayesCap Reconstruction",fontsize=10)
1019
+ # set y-axis label
1020
+ ax.set_ylabel("SSIM btwn BayesCap and SRGAN outputs", color="red",fontsize=10)
1021
+
1022
+ # twin object for two different y-axis on the sample plot
1023
+ ax2=ax.twinx()
1024
+ # make a plot with different y-axis using second axis object
1025
+ ax2.plot(p_mag_list, list_u1, color="blue", marker="o", label='UCE wrt to error btwn SRGAN output and GT')
1026
+ ax2.plot(p_mag_list, list_u2, color="orange", marker="o", label='UCE wrt to error btwn BayesCap output and GT')
1027
+ ax2.set_ylabel("UCE", color="green", fontsize=10)
1028
+ plt.legend(fontsize=10)
1029
+ plt.tight_layout()
1030
+ plt.show()
1031
+
1032
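A minimal usage sketch for the function above (illustrative, not part of this commit). `build_srgan_generator` and `build_bayescap` are hypothetical stand-ins for however the repo constructs the two networks; only `degrage_BayesCap_p` and the checkpoint name come from this codebase.

    import torch
    from torch.utils.data import DataLoader

    NetG = build_srgan_generator()                           # hypothetical: frozen SRGAN generator
    NetC = build_bayescap()                                  # hypothetical: BayesCap error network
    NetC.load_state_dict(torch.load('BayesCap_SRGAN.pth'))   # checkpoint name from utils.py below
    eval_loader = DataLoader(eval_set, batch_size=2)         # eval_set yields (LR, HR) pairs (assumed)
    degrage_BayesCap_p(NetC, NetG, eval_loader, device='cuda', dtype=torch.cuda.FloatTensor)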
+ ################# DeepFill_v2
+
+ # ----------------------------------------
+ # PATH processing
+ # ----------------------------------------
+ def text_readlines(filename):
+     # Try to read a txt file and return its lines as a list. Return [] if the file cannot be opened.
+     try:
+         file = open(filename, 'r')
+     except IOError:
+         error = []
+         return error
+     content = file.readlines()
+     # Strip the trailing newline from every line
+     for i in range(len(content)):
+         content[i] = content[i][:-1]
+     file.close()
+     return content
+
+ def savetxt(name, loss_log):
+     np_loss_log = np.array(loss_log)
+     np.savetxt(name, np_loss_log)
+
+ def get_files(path):
+     # walk a folder and return the complete path of every file
+     ret = []
+     for root, dirs, files in os.walk(path):
+         for filespath in files:
+             ret.append(os.path.join(root, filespath))
+     return ret
+
+ def get_names(path):
+     # walk a folder and return only the file names
+     ret = []
+     for root, dirs, files in os.walk(path):
+         for filespath in files:
+             ret.append(filespath)
+     return ret
+
+ def text_save(content, filename, mode = 'a'):
+     # save a list to a txt file, one item per line
+     file = open(filename, mode)
+     for i in range(len(content)):
+         file.write(str(content[i]) + '\n')
+     file.close()
+
+ def check_path(path):
+     if not os.path.exists(path):
+         os.makedirs(path)
+
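A small usage sketch for the path helpers above (illustrative, not part of this commit); the file and folder names are invented for the example.

    check_path('./logs')                               # create ./logs if it does not exist
    text_save([0.91, 0.87], './logs/val_ssim.txt')     # append one value per line
    scores = text_readlines('./logs/val_ssim.txt')     # -> ['0.91', '0.87'] (as strings)
    image_paths = get_files('./data/val')              # full path of every file under ./data/val
    image_names = get_names('./data/val')              # just the file names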
+ # ----------------------------------------
+ # Validation and Sample at training
+ # ----------------------------------------
+ def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255):
+     # Save the images one-by-one
+     for i in range(len(img_list)):
+         img = img_list[i]
+         # Recover normalization: * 255 because the last layer is sigmoid activated
+         img = img * 255
+         # Work on img_copy so the data of img is not destroyed
+         img_copy = img.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
+         img_copy = np.clip(img_copy, 0, pixel_max_cnt)
+         img_copy = img_copy.astype(np.uint8)
+         img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
+         # Save to the target path
+         save_img_name = sample_name + '_' + name_list[i] + '.jpg'
+         save_img_path = os.path.join(sample_folder, save_img_name)
+         cv2.imwrite(save_img_path, img_copy)
+
+ def psnr(pred, target, pixel_max_cnt = 255):
+     mse = torch.mul(target - pred, target - pred)
+     rmse_avg = (torch.mean(mse).item()) ** 0.5
+     p = 20 * np.log10(pixel_max_cnt / rmse_avg)
+     return p
+
+ def grey_psnr(pred, target, pixel_max_cnt = 255):
+     pred = torch.sum(pred, dim = 0)
+     target = torch.sum(target, dim = 0)
+     mse = torch.mul(target - pred, target - pred)
+     rmse_avg = (torch.mean(mse).item()) ** 0.5
+     p = 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
+     return p
+
+ def ssim(pred, target):
+     pred = pred.clone().data.permute(0, 2, 3, 1).cpu().numpy()
+     target = target.clone().data.permute(0, 2, 3, 1).cpu().numpy()
+     target = target[0]
+     pred = pred[0]
+     # note: on skimage >= 0.18 this function lives at skimage.metrics.structural_similarity
+     ssim = skimage.measure.compare_ssim(target, pred, multichannel = True)
+     return ssim
+
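A short sketch of how the helpers above can log one validation batch (illustrative, not part of this commit); `fake` and `gt` are random stand-ins for tensors shaped [1, 3, H, W] in [0, 1].

    fake = torch.rand(1, 3, 256, 256)   # stand-in for a generator output
    gt = torch.rand(1, 3, 256, 256)     # stand-in for the ground truth
    check_path('./samples')
    save_sample_png('./samples', 'epoch_001', [fake, gt], ['fake', 'gt'])
    print('PSNR:', psnr(fake * 255, gt * 255), '| SSIM:', ssim(fake, gt))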
+ ## for contextual attention
+
+ def extract_image_patches(images, ksizes, strides, rates, padding='same'):
+     """
+     Extract patches from images and put them in the C output dimension.
+     :param images: A 4-D tensor with shape [batch, channels, in_rows, in_cols].
+     :param ksizes: [ksize_rows, ksize_cols], the size of the sliding window for
+                    each dimension of images.
+     :param strides: [stride_rows, stride_cols].
+     :param rates: [dilation_rows, dilation_cols].
+     :param padding: 'same' (zero-pad so the sliding window covers every position) or 'valid'.
+     :return: A tensor of shape [N, C*k*k, L], where L is the number of extracted patches.
+     """
+     assert len(images.size()) == 4
+     assert padding in ['same', 'valid']
+     batch_size, channel, height, width = images.size()
+
+     if padding == 'same':
+         images = same_padding(images, ksizes, strides, rates)
+     elif padding == 'valid':
+         pass
+     else:
+         raise NotImplementedError('Unsupported padding type: {}. '
+                                   'Only "same" or "valid" are supported.'.format(padding))
+
+     unfold = torch.nn.Unfold(kernel_size=ksizes,
+                              dilation=rates,
+                              padding=0,
+                              stride=strides)
+     patches = unfold(images)
+     return patches  # [N, C*k*k, L], L is the total number of such blocks
+
+ def same_padding(images, ksizes, strides, rates):
+     assert len(images.size()) == 4
+     batch_size, channel, rows, cols = images.size()
+     out_rows = (rows + strides[0] - 1) // strides[0]
+     out_cols = (cols + strides[1] - 1) // strides[1]
+     effective_k_row = (ksizes[0] - 1) * rates[0] + 1
+     effective_k_col = (ksizes[1] - 1) * rates[1] + 1
+     padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
+     padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
+     # Pad the input
+     padding_top = int(padding_rows / 2.)
+     padding_left = int(padding_cols / 2.)
+     padding_bottom = padding_rows - padding_top
+     padding_right = padding_cols - padding_left
+     paddings = (padding_left, padding_right, padding_top, padding_bottom)
+     images = torch.nn.ZeroPad2d(paddings)(images)
+     return images
+
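A quick sketch of the patch-extraction helper above (illustrative, not part of this commit); with 'same' padding the number of patches L equals ceil(H/stride) * ceil(W/stride).

    x = torch.randn(2, 3, 64, 64)
    patches = extract_image_patches(x, ksizes=[3, 3], strides=[2, 2], rates=[1, 1], padding='same')
    print(patches.shape)   # torch.Size([2, 27, 1024]): C*k*k = 3*3*3, L = 32*32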
+ def reduce_mean(x, axis=None, keepdim=False):
+     # axis is expected to be a list/tuple of dims; None (or empty) reduces over all dims
+     if not axis:
+         axis = range(len(x.shape))
+     for i in sorted(axis, reverse=True):
+         x = torch.mean(x, dim=i, keepdim=keepdim)
+     return x
+
+
+ def reduce_std(x, axis=None, keepdim=False):
+     if not axis:
+         axis = range(len(x.shape))
+     for i in sorted(axis, reverse=True):
+         x = torch.std(x, dim=i, keepdim=keepdim)
+     return x
+
+
+ def reduce_sum(x, axis=None, keepdim=False):
+     if not axis:
+         axis = range(len(x.shape))
+     for i in sorted(axis, reverse=True):
+         x = torch.sum(x, dim=i, keepdim=keepdim)
+     return x
+
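A tiny usage sketch for the reduce helpers above (illustrative, not part of this commit): reducing over the spatial dimensions only, keeping batch and channel.

    feat = torch.randn(4, 64, 32, 32)
    mu = reduce_mean(feat, axis=[2, 3], keepdim=True)   # -> [4, 64, 1, 1]
    sd = reduce_std(feat, axis=[2, 3], keepdim=True)    # -> [4, 64, 1, 1]
    normed = (feat - mu) / (sd + 1e-8)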
+ def random_mask(num_batch=1, mask_shape=(256,256)):
+     list_mask = []
+     for _ in range(num_batch):
+         # rectangle mask
+         image_height = mask_shape[0]
+         image_width = mask_shape[1]
+         max_delta_height = image_height//8
+         max_delta_width = image_width//8
+         height = image_height//4
+         width = image_width//4
+         max_t = image_height - height
+         max_l = image_width - width
+         t = random.randint(0, max_t)
+         l = random.randint(0, max_l)
+         # bbox = (t, l, height, width)
+         h = random.randint(0, max_delta_height//2)
+         w = random.randint(0, max_delta_width//2)
+         mask = torch.zeros((1, 1, image_height, image_width))
+         mask[:, :, t+h:t+height-h, l+w:l+width-w] = 1
+         rect_mask = mask
+
+         # brush mask
+         min_num_vertex = 4
+         max_num_vertex = 12
+         mean_angle = 2 * math.pi / 5
+         angle_range = 2 * math.pi / 15
+         min_width = 12
+         max_width = 40
+         H, W = image_height, image_width
+         average_radius = math.sqrt(H*H+W*W) / 8
+         mask = Image.new('L', (W, H), 0)
+
+         for _ in range(np.random.randint(1, 4)):
+             num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
+             angle_min = mean_angle - np.random.uniform(0, angle_range)
+             angle_max = mean_angle + np.random.uniform(0, angle_range)
+             angles = []
+             vertex = []
+             for i in range(num_vertex):
+                 if i % 2 == 0:
+                     angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
+                 else:
+                     angles.append(np.random.uniform(angle_min, angle_max))
+
+             w, h = mask.size  # PIL size is (width, height)
+             vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
+             for i in range(num_vertex):
+                 r = np.clip(
+                     np.random.normal(loc=average_radius, scale=average_radius//2),
+                     0, 2*average_radius)
+                 new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
+                 new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
+                 vertex.append((int(new_x), int(new_y)))
+
+             draw = ImageDraw.Draw(mask)
+             width = int(np.random.uniform(min_width, max_width))
+             draw.line(vertex, fill=255, width=width)
+             for v in vertex:
+                 draw.ellipse((v[0] - width//2,
+                               v[1] - width//2,
+                               v[0] + width//2,
+                               v[1] + width//2),
+                              fill=255)
+
+         if np.random.normal() > 0:
+             mask = mask.transpose(Image.FLIP_LEFT_RIGHT)  # transpose returns a new image
+         if np.random.normal() > 0:
+             mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
+
+         mask = transforms.ToTensor()(mask)
+         mask = mask.reshape((1, 1, H, W))
+         brush_mask = mask
+
+         mask = torch.cat([rect_mask, brush_mask], dim=1).max(dim=1, keepdim=True)[0]
+         list_mask.append(mask)
+     mask = torch.cat(list_mask, dim=0)
+     return mask
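A brief usage sketch of the mask generator above (illustrative, not part of this commit): sample a batch of free-form inpainting masks (union of a rectangle and brush strokes, 1 marking the hole) and apply them to images.

    masks = random_mask(num_batch=4, mask_shape=(256, 256))   # -> [4, 1, 256, 256]
    imgs = torch.rand(4, 3, 256, 256)
    masked_imgs = imgs * (1 - masks)                          # zero out the masked pixels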
utils.py ADDED
@@ -0,0 +1,117 @@
+ import random
+ from typing import Any, Optional
+ import numpy as np
+ import os
+ import cv2
+ from glob import glob
+ from PIL import Image, ImageDraw
+ from tqdm import tqdm
+ import kornia
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import albumentations as albu
+ import functools
+ import math
+
+ import torch
+ import torch.nn as nn
+ from torch import Tensor
+ import torchvision as tv
+ import torchvision.models as models
+ from torchvision import transforms
+ from torchvision.transforms import functional as F
+ from losses import TempCombLoss
+
+
+ ######## for loading checkpoint from googledrive
+ google_drive_paths = {
+     "BayesCap_SRGAN.pth": "https://drive.google.com/uc?id=1d_5j1f8-vN79htZTfRUqP1ddHZIYsNvL",
+     "BayesCap_ckpt.pth": "https://drive.google.com/uc?id=1Vg1r6gKgQ1J3M51n6BeKXYS8auT9NhA9",
+ }
+
+ def ensure_checkpoint_exists(model_weights_filename):
+     if not os.path.isfile(model_weights_filename) and (
+         model_weights_filename in google_drive_paths
+     ):
+         gdrive_url = google_drive_paths[model_weights_filename]
+         try:
+             from gdown import download as drive_download
+
+             drive_download(gdrive_url, model_weights_filename, quiet=False)
+         except ModuleNotFoundError:
+             print(
+                 "gdown module not found.",
+                 "pip3 install gdown or manually download the checkpoint file:",
+                 gdrive_url
+             )
+
+     if not os.path.isfile(model_weights_filename) and (
+         model_weights_filename not in google_drive_paths
+     ):
+         print(
+             model_weights_filename,
+             " not found, you may need to manually download the model weights."
+         )
+
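A minimal sketch of how the checkpoint helper above is meant to be used (illustrative, not part of this commit); `NetC` stands in for the BayesCap network built elsewhere in the repo.

    ensure_checkpoint_exists("BayesCap_SRGAN.pth")           # downloads via gdown on first use
    state = torch.load("BayesCap_SRGAN.pth", map_location="cpu")
    # NetC.load_state_dict(state)                            # hypothetical: apply to the BayesCap model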
+ def normalize(image: np.ndarray) -> np.ndarray:
+     """Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
+     Args:
+         image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
+     Returns:
+         Normalized image data. Data range [0, 1].
+     """
+     return image.astype(np.float64) / 255.0
+
+
+ def unnormalize(image: np.ndarray) -> np.ndarray:
+     """Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
+     Args:
+         image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
+     Returns:
+         Denormalized image data. Data range [0, 255].
+     """
+     return image.astype(np.float64) * 255.0
+
+
+ def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
+     """Convert a ``PIL.Image`` / ``np.ndarray`` image to a Tensor.
+     Args:
+         image (np.ndarray): The image data read by ``PIL.Image``.
+         range_norm (bool): Scale [0, 1] data to [-1, 1].
+         half (bool): Whether to cast the tensor from torch.float32 to torch.half.
+     Returns:
+         Normalized image data as a tensor.
+     Examples:
+         >>> image = Image.open("image.bmp")
+         >>> tensor_image = image2tensor(image, range_norm=False, half=False)
+     """
+     tensor = F.to_tensor(image)
+
+     if range_norm:
+         tensor = tensor.mul_(2.0).sub_(1.0)
+     if half:
+         tensor = tensor.half()
+
+     return tensor
+
+
+ def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
+     """Convert a ``torch.Tensor`` to an image that PIL/OpenCV can consume.
+     Note: the in-place ops below modify the input tensor.
+     Args:
+         tensor (torch.Tensor): The image tensor to convert, shaped [1, C, H, W].
+         range_norm (bool): Scale [-1, 1] data to [0, 1].
+         half (bool): Whether to cast the tensor from torch.float32 to torch.half.
+     Returns:
+         A uint8 ``np.ndarray`` of shape (H, W, C).
+     Examples:
+         >>> tensor = torch.randn([1, 3, 128, 128])
+         >>> image = tensor2image(tensor, range_norm=False, half=False)
+     """
+     if range_norm:
+         tensor = tensor.add_(1.0).div_(2.0)
+     if half:
+         tensor = tensor.half()
+
+     image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")
+
+     return image
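A round-trip sketch using the two converters above (illustrative, not part of this commit); the image path is made up for the example.

    img = Image.open("example.png").convert("RGB")                     # hypothetical input image
    t = image2tensor(img, range_norm=False, half=False)                # [C, H, W] float tensor in [0, 1]
    arr = tensor2image(t.unsqueeze(0), range_norm=False, half=False)   # back to uint8 (H, W, C)
    Image.fromarray(arr).save("example_roundtrip.png")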