My-AI-Projects committed on
Commit aaa2047
1 parent: 9b88016

Add application file

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. basicsr/__init__.py +12 -0
  2. basicsr/__pycache__/__init__.cpython-39.pyc +0 -0
  3. basicsr/__pycache__/train.cpython-39.pyc +0 -0
  4. basicsr/__pycache__/version.cpython-39.pyc +0 -0
  5. basicsr/archs/__init__.py +25 -0
  6. basicsr/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  7. basicsr/archs/__pycache__/ddcolor_arch.cpython-39.pyc +0 -0
  8. basicsr/archs/__pycache__/discriminator_arch.cpython-39.pyc +0 -0
  9. basicsr/archs/__pycache__/vgg_arch.cpython-39.pyc +0 -0
  10. basicsr/archs/ddcolor_arch.py +385 -0
  11. basicsr/archs/ddcolor_arch_utils/__int__.py +0 -0
  12. basicsr/archs/ddcolor_arch_utils/__pycache__/convnext.cpython-38.pyc +0 -0
  13. basicsr/archs/ddcolor_arch_utils/__pycache__/convnext.cpython-39.pyc +0 -0
  14. basicsr/archs/ddcolor_arch_utils/__pycache__/position_encoding.cpython-38.pyc +0 -0
  15. basicsr/archs/ddcolor_arch_utils/__pycache__/position_encoding.cpython-39.pyc +0 -0
  16. basicsr/archs/ddcolor_arch_utils/__pycache__/transformer.cpython-38.pyc +0 -0
  17. basicsr/archs/ddcolor_arch_utils/__pycache__/transformer.cpython-39.pyc +0 -0
  18. basicsr/archs/ddcolor_arch_utils/__pycache__/transformer_utils.cpython-38.pyc +0 -0
  19. basicsr/archs/ddcolor_arch_utils/__pycache__/transformer_utils.cpython-39.pyc +0 -0
  20. basicsr/archs/ddcolor_arch_utils/__pycache__/unet.cpython-38.pyc +0 -0
  21. basicsr/archs/ddcolor_arch_utils/__pycache__/unet.cpython-39.pyc +0 -0
  22. basicsr/archs/ddcolor_arch_utils/convnext.py +155 -0
  23. basicsr/archs/ddcolor_arch_utils/position_encoding.py +52 -0
  24. basicsr/archs/ddcolor_arch_utils/transformer.py +368 -0
  25. basicsr/archs/ddcolor_arch_utils/transformer_utils.py +192 -0
  26. basicsr/archs/ddcolor_arch_utils/unet.py +208 -0
  27. basicsr/archs/ddcolor_arch_utils/util.py +63 -0
  28. basicsr/archs/discriminator_arch.py +28 -0
  29. basicsr/archs/vgg_arch.py +165 -0
  30. basicsr/data/__init__.py +101 -0
  31. basicsr/data/__pycache__/__init__.cpython-39.pyc +0 -0
  32. basicsr/data/__pycache__/data_sampler.cpython-39.pyc +0 -0
  33. basicsr/data/__pycache__/fmix.cpython-39.pyc +0 -0
  34. basicsr/data/__pycache__/lab_dataset.cpython-39.pyc +0 -0
  35. basicsr/data/__pycache__/prefetch_dataloader.cpython-39.pyc +0 -0
  36. basicsr/data/__pycache__/transforms.cpython-39.pyc +0 -0
  37. basicsr/data/data_sampler.py +48 -0
  38. basicsr/data/data_util.py +313 -0
  39. basicsr/data/fmix.py +206 -0
  40. basicsr/data/lab_dataset.py +159 -0
  41. basicsr/data/prefetch_dataloader.py +125 -0
  42. basicsr/data/transforms.py +192 -0
  43. basicsr/losses/__init__.py +26 -0
  44. basicsr/losses/__pycache__/__init__.cpython-39.pyc +0 -0
  45. basicsr/losses/__pycache__/loss_util.cpython-39.pyc +0 -0
  46. basicsr/losses/__pycache__/losses.cpython-39.pyc +0 -0
  47. basicsr/losses/loss_util.py +95 -0
  48. basicsr/losses/losses.py +551 -0
  49. basicsr/metrics/__init__.py +20 -0
  50. basicsr/metrics/__pycache__/__init__.cpython-39.pyc +0 -0
basicsr/__init__.py ADDED
@@ -0,0 +1,12 @@
# https://github.com/xinntao/BasicSR
# flake8: noqa
from .archs import *
from .data import *
from .losses import *
from .metrics import *
from .models import *
# from .ops import *
# from .test import *
from .train import *
from .utils import *
from .version import __gitsha__, __version__
basicsr/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (342 Bytes)

basicsr/__pycache__/train.cpython-39.pyc ADDED
Binary file (6.57 kB)

basicsr/__pycache__/version.cpython-39.pyc ADDED
Binary file (242 Bytes)
 
basicsr/archs/__init__.py ADDED
@@ -0,0 +1,25 @@
import importlib
from copy import deepcopy
from os import path as osp

from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import ARCH_REGISTRY

__all__ = ['build_network']

# automatically scan and import arch modules for registry
# scan all the files under the 'archs' folder and collect files ending with
# '_arch.py'
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
# import all the arch modules
_arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]


def build_network(opt):
    opt = deepcopy(opt)
    network_type = opt.pop('type')
    net = ARCH_REGISTRY.get(network_type)(**opt)
    logger = get_root_logger()
    logger.info(f'Network [{net.__class__.__name__}] is created.')
    return net
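Usage note: a minimal sketch of how build_network is typically invoked. The option dict below is illustrative only (its keys would normally come from a YAML config) and assumes the full basicsr package from this commit plus PyTorch are installed.

from basicsr.archs import build_network

opt = {
    'type': 'DDColor',             # registry key added by @ARCH_REGISTRY.register()
    'encoder_name': 'convnext-l',  # remaining keys are passed to the constructor as kwargs
}
net = build_network(opt)           # pops 'type', looks it up in ARCH_REGISTRY, instantiates the class
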
basicsr/archs/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.14 kB)

basicsr/archs/__pycache__/ddcolor_arch.cpython-39.pyc ADDED
Binary file (10.4 kB)

basicsr/archs/__pycache__/discriminator_arch.cpython-39.pyc ADDED
Binary file (1.37 kB)

basicsr/archs/__pycache__/vgg_arch.cpython-39.pyc ADDED
Binary file (4.88 kB)
 
basicsr/archs/ddcolor_arch.py ADDED
@@ -0,0 +1,385 @@
import torch
import torch.nn as nn

from basicsr.archs.ddcolor_arch_utils.unet import Hook, CustomPixelShuffle_ICNR, UnetBlockWide, NormType, custom_conv_layer
from basicsr.archs.ddcolor_arch_utils.convnext import ConvNeXt
from basicsr.archs.ddcolor_arch_utils.transformer_utils import SelfAttentionLayer, CrossAttentionLayer, FFNLayer, MLP
from basicsr.archs.ddcolor_arch_utils.position_encoding import PositionEmbeddingSine
from basicsr.archs.ddcolor_arch_utils.transformer import Transformer
from basicsr.utils.registry import ARCH_REGISTRY


@ARCH_REGISTRY.register()
class DDColor(nn.Module):

    def __init__(self,
                 encoder_name='convnext-l',
                 decoder_name='MultiScaleColorDecoder',
                 num_input_channels=3,
                 input_size=(256, 256),
                 nf=512,
                 num_output_channels=3,
                 last_norm='Weight',
                 do_normalize=False,
                 num_queries=256,
                 num_scales=3,
                 dec_layers=9,
                 encoder_from_pretrain=False):
        super().__init__()

        self.encoder = Encoder(encoder_name, ['norm0', 'norm1', 'norm2', 'norm3'], from_pretrain=encoder_from_pretrain)
        self.encoder.eval()
        test_input = torch.randn(1, num_input_channels, *input_size)
        self.encoder(test_input)

        self.decoder = Decoder(
            self.encoder.hooks,
            nf=nf,
            last_norm=last_norm,
            num_queries=num_queries,
            num_scales=num_scales,
            dec_layers=dec_layers,
            decoder_name=decoder_name
        )
        self.refine_net = nn.Sequential(custom_conv_layer(num_queries + 3, num_output_channels, ks=1, use_activ=False, norm_type=NormType.Spectral))

        self.do_normalize = do_normalize
        self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def normalize(self, img):
        return (img - self.mean) / self.std

    def denormalize(self, img):
        return img * self.std + self.mean

    def forward(self, x):
        if x.shape[1] == 3:
            x = self.normalize(x)

        self.encoder(x)
        out_feat = self.decoder()
        coarse_input = torch.cat([out_feat, x], dim=1)
        out = self.refine_net(coarse_input)

        if self.do_normalize:
            out = self.denormalize(out)
        return out


class Decoder(nn.Module):

    def __init__(self,
                 hooks,
                 nf=512,
                 blur=True,
                 last_norm='Weight',
                 num_queries=256,
                 num_scales=3,
                 dec_layers=9,
                 decoder_name='MultiScaleColorDecoder'):
        super().__init__()
        self.hooks = hooks
        self.nf = nf
        self.blur = blur
        self.last_norm = getattr(NormType, last_norm)
        self.decoder_name = decoder_name

        self.layers = self.make_layers()
        embed_dim = nf // 2

        self.last_shuf = CustomPixelShuffle_ICNR(embed_dim, embed_dim, blur=self.blur, norm_type=self.last_norm, scale=4)

        if self.decoder_name == 'MultiScaleColorDecoder':
            self.color_decoder = MultiScaleColorDecoder(
                in_channels=[512, 512, 256],
                num_queries=num_queries,
                num_scales=num_scales,
                dec_layers=dec_layers,
            )
        else:
            self.color_decoder = SingleColorDecoder(
                in_channels=hooks[-1].feature.shape[1],
                num_queries=num_queries,
            )


    def forward(self):
        encode_feat = self.hooks[-1].feature
        out0 = self.layers[0](encode_feat)
        out1 = self.layers[1](out0)
        out2 = self.layers[2](out1)
        out3 = self.last_shuf(out2)

        if self.decoder_name == 'MultiScaleColorDecoder':
            out = self.color_decoder([out0, out1, out2], out3)
        else:
            out = self.color_decoder(out3, encode_feat)

        return out

    def make_layers(self):
        decoder_layers = []

        e_in_c = self.hooks[-1].feature.shape[1]
        in_c = e_in_c

        out_c = self.nf
        setup_hooks = self.hooks[-2::-1]
        for layer_index, hook in enumerate(setup_hooks):
            feature_c = hook.feature.shape[1]
            if layer_index == len(setup_hooks) - 1:
                out_c = out_c // 2
            decoder_layers.append(
                UnetBlockWide(
                    in_c, feature_c, out_c, hook, blur=self.blur, self_attention=False, norm_type=NormType.Spectral))
            in_c = out_c
        return nn.Sequential(*decoder_layers)


class Encoder(nn.Module):

    def __init__(self, encoder_name, hook_names, from_pretrain, **kwargs):
        super().__init__()

        if encoder_name == 'convnext-t' or encoder_name == 'convnext':
            self.arch = ConvNeXt()
        elif encoder_name == 'convnext-s':
            self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768])
        elif encoder_name == 'convnext-b':
            self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024])
        elif encoder_name == 'convnext-l':
            self.arch = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536])
        else:
            raise NotImplementedError

        self.encoder_name = encoder_name
        self.hook_names = hook_names
        self.hooks = self.setup_hooks()

        if from_pretrain:
            self.load_pretrain_model()

    def setup_hooks(self):
        hooks = [Hook(self.arch._modules[name]) for name in self.hook_names]
        return hooks

    def forward(self, x):
        return self.arch(x)

    def load_pretrain_model(self):
        if self.encoder_name == 'convnext-t' or self.encoder_name == 'convnext':
            self.load('pretrain/convnext_tiny_22k_224.pth')
        elif self.encoder_name == 'convnext-s':
            self.load('pretrain/convnext_small_22k_224.pth')
        elif self.encoder_name == 'convnext-b':
            self.load('pretrain/convnext_base_22k_224.pth')
        elif self.encoder_name == 'convnext-l':
            self.load('pretrain/convnext_large_22k_224.pth')
        else:
            raise NotImplementedError
        print('Loaded pretrained convnext model.')

    def load(self, path):
        from basicsr.utils import get_root_logger
        logger = get_root_logger()
        if not path:
            logger.info("No checkpoint found. Initializing model from scratch")
            return
        logger.info("[Encoder] Loading from {} ...".format(path))
        checkpoint = torch.load(path, map_location=torch.device("cpu"))
        checkpoint_state_dict = checkpoint['model'] if 'model' in checkpoint.keys() else checkpoint
        incompatible = self.arch.load_state_dict(checkpoint_state_dict, strict=False)

        if incompatible.missing_keys:
            msg = "Some model parameters or buffers are not found in the checkpoint:\n"
            msg += str(incompatible.missing_keys)
            logger.warning(msg)
        if incompatible.unexpected_keys:
            msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
            msg += str(incompatible.unexpected_keys)
            logger.warning(msg)


class MultiScaleColorDecoder(nn.Module):

    def __init__(
        self,
        in_channels,
        hidden_dim=256,
        num_queries=100,
        nheads=8,
        dim_feedforward=2048,
        dec_layers=9,
        pre_norm=False,
        color_embed_dim=256,
        enforce_input_project=True,
        num_scales=3
    ):
        super().__init__()

        # positional encoding
        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        # define Transformer decoder here
        self.num_heads = nheads
        self.num_layers = dec_layers
        self.transformer_self_attention_layers = nn.ModuleList()
        self.transformer_cross_attention_layers = nn.ModuleList()
        self.transformer_ffn_layers = nn.ModuleList()

        for _ in range(self.num_layers):
            self.transformer_self_attention_layers.append(
                SelfAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_cross_attention_layers.append(
                CrossAttentionLayer(
                    d_model=hidden_dim,
                    nhead=nheads,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )
            self.transformer_ffn_layers.append(
                FFNLayer(
                    d_model=hidden_dim,
                    dim_feedforward=dim_feedforward,
                    dropout=0.0,
                    normalize_before=pre_norm,
                )
            )

        self.decoder_norm = nn.LayerNorm(hidden_dim)

        self.num_queries = num_queries
        # learnable color query features
        self.query_feat = nn.Embedding(num_queries, hidden_dim)
        # learnable color query p.e.
        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        # level embedding
        self.num_feature_levels = num_scales
        self.level_embed = nn.Embedding(self.num_feature_levels, hidden_dim)

        # input projections
        self.input_proj = nn.ModuleList()
        for i in range(self.num_feature_levels):
            if in_channels[i] != hidden_dim or enforce_input_project:
                self.input_proj.append(nn.Conv2d(in_channels[i], hidden_dim, kernel_size=1))
                nn.init.kaiming_uniform_(self.input_proj[-1].weight, a=1)
                if self.input_proj[-1].bias is not None:
                    nn.init.constant_(self.input_proj[-1].bias, 0)
            else:
                self.input_proj.append(nn.Sequential())

        # output FFNs
        self.color_embed = MLP(hidden_dim, hidden_dim, color_embed_dim, 3)

    def forward(self, x, img_features):
        # x is a list of multi-scale feature
        assert len(x) == self.num_feature_levels
        src = []
        pos = []

        for i in range(self.num_feature_levels):
            pos.append(self.pe_layer(x[i], None).flatten(2))
            src.append(self.input_proj[i](x[i]).flatten(2) + self.level_embed.weight[i][None, :, None])

            # flatten NxCxHxW to HWxNxC
            pos[-1] = pos[-1].permute(2, 0, 1)
            src[-1] = src[-1].permute(2, 0, 1)

        _, bs, _ = src[0].shape

        # QxNxC
        query_embed = self.query_embed.weight.unsqueeze(1).repeat(1, bs, 1)
        output = self.query_feat.weight.unsqueeze(1).repeat(1, bs, 1)

        for i in range(self.num_layers):
            level_index = i % self.num_feature_levels
            # attention: cross-attention first
            output = self.transformer_cross_attention_layers[i](
                output, src[level_index],
                memory_mask=None,
                memory_key_padding_mask=None,
                pos=pos[level_index], query_pos=query_embed
            )
            output = self.transformer_self_attention_layers[i](
                output, tgt_mask=None,
                tgt_key_padding_mask=None,
                query_pos=query_embed
            )
            # FFN
            output = self.transformer_ffn_layers[i](
                output
            )

        decoder_output = self.decoder_norm(output)
        decoder_output = decoder_output.transpose(0, 1)  # [N, bs, C] -> [bs, N, C]
        color_embed = self.color_embed(decoder_output)
        out = torch.einsum("bqc,bchw->bqhw", color_embed, img_features)

        return out


class SingleColorDecoder(nn.Module):

    def __init__(
        self,
        in_channels=768,
        hidden_dim=256,
        num_queries=256,  # 100
        nheads=8,
        dropout=0.1,
        dim_feedforward=2048,
        enc_layers=0,
        dec_layers=6,
        pre_norm=False,
        deep_supervision=True,
        enforce_input_project=True,
    ):

        super().__init__()

        N_steps = hidden_dim // 2
        self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True)

        transformer = Transformer(
            d_model=hidden_dim,
            dropout=dropout,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            num_encoder_layers=enc_layers,
            num_decoder_layers=dec_layers,
            normalize_before=pre_norm,
            return_intermediate_dec=deep_supervision,
        )
        self.num_queries = num_queries
        self.transformer = transformer

        self.query_embed = nn.Embedding(num_queries, hidden_dim)

        if in_channels != hidden_dim or enforce_input_project:
            self.input_proj = nn.Conv2d(in_channels, hidden_dim, kernel_size=1)
            nn.init.kaiming_uniform_(self.input_proj.weight, a=1)
            if self.input_proj.bias is not None:
                nn.init.constant_(self.input_proj.bias, 0)
        else:
            self.input_proj = nn.Sequential()


    def forward(self, img_features, encode_feat):
        pos = self.pe_layer(encode_feat)
        src = encode_feat
        mask = None
        hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos)
        color_embed = hs[-1]
        color_preds = torch.einsum('bqc,bchw->bqhw', color_embed, img_features)
        return color_preds
+
basicsr/archs/ddcolor_arch_utils/__int__.py ADDED
File without changes
basicsr/archs/ddcolor_arch_utils/__pycache__/convnext.cpython-38.pyc ADDED
Binary file (6.2 kB)

basicsr/archs/ddcolor_arch_utils/__pycache__/convnext.cpython-39.pyc ADDED
Binary file (6.12 kB)

basicsr/archs/ddcolor_arch_utils/__pycache__/position_encoding.cpython-38.pyc ADDED
Binary file (2.03 kB)

basicsr/archs/ddcolor_arch_utils/__pycache__/position_encoding.cpython-39.pyc ADDED
Binary file (2.05 kB)

basicsr/archs/ddcolor_arch_utils/__pycache__/transformer.cpython-38.pyc ADDED
Binary file (8.81 kB)

basicsr/archs/ddcolor_arch_utils/__pycache__/transformer.cpython-39.pyc ADDED
Binary file (8.77 kB)

basicsr/archs/ddcolor_arch_utils/__pycache__/transformer_utils.cpython-38.pyc ADDED
Binary file (6.57 kB)

basicsr/archs/ddcolor_arch_utils/__pycache__/transformer_utils.cpython-39.pyc ADDED
Binary file (6.57 kB)

basicsr/archs/ddcolor_arch_utils/__pycache__/unet.cpython-38.pyc ADDED
Binary file (7.37 kB)

basicsr/archs/ddcolor_arch_utils/__pycache__/unet.cpython-39.pyc ADDED
Binary file (7.39 kB)
 
basicsr/archs/ddcolor_arch_utils/convnext.py ADDED
@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath

class Block(nn.Module):
    r""" ConvNeXt Block. There are two equivalent implementations:
    (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
    (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
    We use (2) as we find it slightly faster in PyTorch

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)
        self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                  requires_grad=True) if layer_scale_init_value > 0 else None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = input + self.drop_path(x)
        return x

class ConvNeXt(nn.Module):
    r""" ConvNeXt
        A PyTorch impl of : `A ConvNet for the 2020s` -
        https://arxiv.org/pdf/2201.03545.pdf
    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
        dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
        head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
    """
    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.,
                 layer_scale_init_value=1e-6, head_init_scale=1.,
                 ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()  # stem and 3 intermediate downsampling conv layers
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
        )
        self.downsample_layers.append(stem)
        for i in range(3):
            downsample_layer = nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()  # 4 feature resolution stages, each consisting of multiple residual blocks
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=dims[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        # add norm layers for each output
        out_indices = (0, 1, 2, 3)
        for i in out_indices:
            layer = LayerNorm(dims[i], eps=1e-6, data_format="channels_first")
            # layer = nn.Identity()
            layer_name = f'norm{i}'
            self.add_module(layer_name, layer)

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)  # final norm layer
        # self.head_cls = nn.Linear(dims[-1], 4)

        self.apply(self._init_weights)
        # self.head_cls.weight.data.mul_(head_init_scale)
        # self.head_cls.bias.data.mul_(head_init_scale)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)

            # add extra norm
            norm_layer = getattr(self, f'norm{i}')
            # x = norm_layer(x)
            norm_layer(x)

        return self.norm(x.mean([-2, -1]))  # global average pooling, (N, C, H, W) -> (N, C)

    def forward(self, x):
        x = self.forward_features(x)
        # x = self.head_cls(x)
        return x

class LayerNorm(nn.Module):
    r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
    The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
    shape (batch_size, height, width, channels) while channels_first corresponds to inputs
    with shape (batch_size, channels, height, width).
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_last":  # B H W C
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":  # B C H W
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
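Note: forward() above only returns globally pooled features; the per-stage maps DDColor needs are read out as a side effect of the bare norm_layer(x) calls via forward hooks. A minimal, illustrative sketch (tiny ConvNeXt config, random input):

import torch
from basicsr.archs.ddcolor_arch_utils.convnext import ConvNeXt
from basicsr.archs.ddcolor_arch_utils.unet import Hook

net = ConvNeXt()                                         # tiny config: depths [3, 3, 9, 3], dims [96, 192, 384, 768]
hooks = [Hook(net._modules[f'norm{i}']) for i in range(4)]
_ = net(torch.rand(1, 3, 256, 256))                      # triggers the hooks during forward_features
print([h.feature.shape for h in hooks])                  # stride 4/8/16/32 maps: 64x64, 32x32, 16x16, 8x8
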
basicsr/archs/ddcolor_arch_utils/position_encoding.py ADDED
@@ -0,0 +1,52 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py
"""
Various positional encodings for the transformer.
"""
import math

import torch
from torch import nn


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask=None):
        if mask is None:
            mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
        not_mask = ~mask
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos_y = torch.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
        ).flatten(3)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
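Shape note (illustrative): the embedding depends only on the spatial size of its input and returns 2 * num_pos_feats channels.

import torch
from basicsr.archs.ddcolor_arch_utils.position_encoding import PositionEmbeddingSine

pe = PositionEmbeddingSine(num_pos_feats=128, normalize=True)
pos = pe(torch.rand(2, 512, 16, 16))    # -> (2, 256, 16, 16); the input channel count is ignored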
basicsr/archs/ddcolor_arch_utils/transformer.py ADDED
@@ -0,0 +1,368 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# Modified from: https://github.com/facebookresearch/detr/blob/master/models/transformer.py
"""
Transformer class.
Copy-paste from torch.nn.Transformer with modifications:
    * positional encodings are passed in MHattention
    * extra LN at the end of encoder is removed
    * decoder returns a stack of activations from all decoding layers
"""
import copy
from typing import List, Optional

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class Transformer(nn.Module):
    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
    ):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            decoder_norm,
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        if mask is not None:
            mask = mask.flatten(1)

        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(
            tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed
        )
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        output = src

        for layer in self.layers:
            output = layer(
                output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos
            )

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(
                output,
                memory,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
                query_pos=query_pos,
            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(
            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(
            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(
        self,
        src,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(
            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
        )[0]
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            attn_mask=memory_mask,
            key_padding_mask=memory_key_padding_mask,
        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        pos: Optional[Tensor] = None,
        query_pos: Optional[Tensor] = None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory,
                tgt_mask,
                memory_mask,
                tgt_key_padding_mask,
                memory_key_padding_mask,
                pos,
                query_pos,
            )
        return self.forward_post(
            tgt,
            memory,
            tgt_mask,
            memory_mask,
            tgt_key_padding_mask,
            memory_key_padding_mask,
            pos,
            query_pos,
        )


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
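Shape note (illustrative, roughly how SingleColorDecoder drives this class): with return_intermediate_dec=True the decoder returns one activation stack per layer; the queries here are random stand-ins for the nn.Embedding weights used in practice.

import torch
from basicsr.archs.ddcolor_arch_utils.transformer import Transformer

t = Transformer(d_model=256, nhead=8, num_encoder_layers=0, num_decoder_layers=6, return_intermediate_dec=True)
src = torch.rand(1, 256, 8, 8)          # (B, C, H, W) feature map
pos = torch.rand(1, 256, 8, 8)          # positional embedding of the same shape
queries = torch.rand(100, 256)          # (num_queries, d_model)
hs, memory = t(src, None, queries, pos)
print(hs.shape, memory.shape)           # (6, 1, 100, 256) and (1, 256, 8, 8)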
basicsr/archs/ddcolor_arch_utils/transformer_utils.py ADDED
@@ -0,0 +1,192 @@
from typing import Optional
from torch import nn, Tensor
from torch.nn import functional as F

class SelfAttentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt,
                     tgt_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(tgt, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt

    def forward_pre(self, tgt,
                    tgt_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self, tgt,
                tgt_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, tgt_mask,
                                    tgt_key_padding_mask, query_pos)
        return self.forward_post(tgt, tgt_mask,
                                 tgt_key_padding_mask, query_pos)


class CrossAttentionLayer(nn.Module):

    def __init__(self, d_model, nhead, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     memory_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)

        return tgt

    def forward_pre(self, tgt, memory,
                    memory_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout(tgt2)

        return tgt

    def forward(self, tgt, memory,
                memory_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, memory_mask,
                                    memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, memory_mask,
                                 memory_key_padding_mask, pos, query_pos)


class FFNLayer(nn.Module):

    def __init__(self, d_model, dim_feedforward=2048, dropout=0.0,
                 activation="relu", normalize_before=False):
        super().__init__()
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm = nn.LayerNorm(d_model)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt):
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout(tgt2)
        tgt = self.norm(tgt)
        return tgt

    def forward_pre(self, tgt):
        tgt2 = self.norm(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout(tgt2)
        return tgt

    def forward(self, tgt):
        if self.normalize_before:
            return self.forward_pre(tgt)
        return self.forward_post(tgt)


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(F"activation should be relu/gelu, not {activation}.")


class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
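Shape note (illustrative): the color_embed head in MultiScaleColorDecoder is exactly this MLP; it maps decoded queries to color embeddings without changing the query dimension.

import torch
from basicsr.archs.ddcolor_arch_utils.transformer_utils import MLP

mlp = MLP(input_dim=256, hidden_dim=256, output_dim=256, num_layers=3)
q = torch.rand(4, 100, 256)             # (batch, queries, hidden_dim)
print(mlp(q).shape)                     # torch.Size([4, 100, 256])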
basicsr/archs/ddcolor_arch_utils/unet.py ADDED
@@ -0,0 +1,208 @@
from enum import Enum
import torch
import torch.nn as nn
from torch.nn import functional as F
import collections


NormType = Enum('NormType', 'Batch BatchZero Weight Spectral')


class Hook:
    feature = None

    def __init__(self, module):
        self.hook = module.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        if isinstance(output, torch.Tensor):
            self.feature = output
        elif isinstance(output, collections.OrderedDict):
            self.feature = output['out']

    def remove(self):
        self.hook.remove()


class SelfAttention(nn.Module):
    "Self attention layer for nd."

    def __init__(self, n_channels: int):
        super().__init__()
        self.query = conv1d(n_channels, n_channels // 8)
        self.key = conv1d(n_channels, n_channels // 8)
        self.value = conv1d(n_channels, n_channels)
        self.gamma = nn.Parameter(torch.tensor([0.]))

    def forward(self, x):
        # Notation from https://arxiv.org/pdf/1805.08318.pdf
        size = x.size()
        x = x.view(*size[:2], -1)
        f, g, h = self.query(x), self.key(x), self.value(x)
        beta = F.softmax(torch.bmm(f.permute(0, 2, 1).contiguous(), g), dim=1)
        o = self.gamma * torch.bmm(h, beta) + x
        return o.view(*size).contiguous()


def batchnorm_2d(nf: int, norm_type: NormType = NormType.Batch):
    "A batchnorm2d layer with `nf` features initialized depending on `norm_type`."
    bn = nn.BatchNorm2d(nf)
    with torch.no_grad():
        bn.bias.fill_(1e-3)
        bn.weight.fill_(0. if norm_type == NormType.BatchZero else 1.)
    return bn


def init_default(m: nn.Module, func=nn.init.kaiming_normal_) -> None:
    "Initialize `m` weights with `func` and set `bias` to 0."
    if func:
        if hasattr(m, 'weight'): func(m.weight)
        if hasattr(m, 'bias') and hasattr(m.bias, 'data'): m.bias.data.fill_(0.)
    return m


def icnr(x, scale=2, init=nn.init.kaiming_normal_):
    "ICNR init of `x`, with `scale` and `init` function."
    ni, nf, h, w = x.shape
    ni2 = int(ni / (scale**2))
    k = init(torch.zeros([ni2, nf, h, w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale**2)
    k = k.contiguous().view([nf, ni, h, w]).transpose(0, 1)
    x.data.copy_(k)


def conv1d(ni: int, no: int, ks: int = 1, stride: int = 1, padding: int = 0, bias: bool = False):
    "Create and initialize a `nn.Conv1d` layer with spectral normalization."
    conv = nn.Conv1d(ni, no, ks, stride=stride, padding=padding, bias=bias)
    nn.init.kaiming_normal_(conv.weight)
    if bias: conv.bias.data.zero_()
    return nn.utils.spectral_norm(conv)


def custom_conv_layer(
    ni: int,
    nf: int,
    ks: int = 3,
    stride: int = 1,
    padding: int = None,
    bias: bool = None,
    is_1d: bool = False,
    norm_type=NormType.Batch,
    use_activ: bool = True,
    transpose: bool = False,
    init=nn.init.kaiming_normal_,
    self_attention: bool = False,
    extra_bn: bool = False,
):
    "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
    if padding is None:
        padding = (ks - 1) // 2 if not transpose else 0
    bn = norm_type in (NormType.Batch, NormType.BatchZero) or extra_bn == True
    if bias is None:
        bias = not bn
    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
    conv = init_default(
        conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding),
        init,
    )

    if norm_type == NormType.Weight:
        conv = nn.utils.weight_norm(conv)
    elif norm_type == NormType.Spectral:
        conv = nn.utils.spectral_norm(conv)
    layers = [conv]
    if use_activ:
        layers.append(nn.ReLU(True))
    if bn:
        layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
    if self_attention:
        layers.append(SelfAttention(nf))
    return nn.Sequential(*layers)


def conv_layer(ni: int,
               nf: int,
               ks: int = 3,
               stride: int = 1,
               padding: int = None,
               bias: bool = None,
               is_1d: bool = False,
               norm_type=NormType.Batch,
               use_activ: bool = True,
               transpose: bool = False,
               init=nn.init.kaiming_normal_,
               self_attention: bool = False):
    "Create a sequence of convolutional (`ni` to `nf`), ReLU (if `use_activ`) and batchnorm (if `bn`) layers."
    if padding is None: padding = (ks - 1) // 2 if not transpose else 0
    bn = norm_type in (NormType.Batch, NormType.BatchZero)
    if bias is None: bias = not bn
    conv_func = nn.ConvTranspose2d if transpose else nn.Conv1d if is_1d else nn.Conv2d
    conv = init_default(conv_func(ni, nf, kernel_size=ks, bias=bias, stride=stride, padding=padding), init)
    if norm_type == NormType.Weight: conv = nn.utils.weight_norm(conv)
    elif norm_type == NormType.Spectral: conv = nn.utils.spectral_norm(conv)
    layers = [conv]
    if use_activ: layers.append(nn.ReLU(True))
    if bn: layers.append((nn.BatchNorm1d if is_1d else nn.BatchNorm2d)(nf))
    if self_attention: layers.append(SelfAttention(nf))
    return nn.Sequential(*layers)


def _conv(ni: int, nf: int, ks: int = 3, stride: int = 1, **kwargs):
    return conv_layer(ni, nf, ks=ks, stride=stride, norm_type=NormType.Spectral, **kwargs)


class CustomPixelShuffle_ICNR(nn.Module):
    "Upsample by `scale` from `ni` filters to `nf` (default `ni`), using `nn.PixelShuffle`, `icnr` init, and `weight_norm`."

    def __init__(self,
                 ni: int,
                 nf: int = None,
                 scale: int = 2,
                 blur: bool = True,
                 norm_type=NormType.Spectral,
                 extra_bn=False):
        super().__init__()
        self.conv = custom_conv_layer(
            ni, nf * (scale**2), ks=1, use_activ=False, norm_type=norm_type, extra_bn=extra_bn)
        icnr(self.conv[0].weight)
        self.shuf = nn.PixelShuffle(scale)
        self.do_blur = blur
        # Blurring over (h*w) kernel
        # "Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts"
        # - https://arxiv.org/abs/1806.02658
        self.pad = nn.ReplicationPad2d((1, 0, 1, 0))
        self.blur = nn.AvgPool2d(2, stride=1)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        x = self.shuf(self.relu(self.conv(x)))
        return self.blur(self.pad(x)) if self.do_blur else x


class UnetBlockWide(nn.Module):
    "A quasi-UNet block, using `PixelShuffle_ICNR upsampling`."

    def __init__(self,
                 up_in_c: int,
                 x_in_c: int,
                 n_out: int,
                 hook,
                 blur: bool = False,
                 self_attention: bool = False,
                 norm_type=NormType.Spectral):
        super().__init__()

        self.hook = hook
        up_out = n_out
        self.shuf = CustomPixelShuffle_ICNR(up_in_c, up_out, blur=blur, norm_type=norm_type, extra_bn=True)
        self.bn = batchnorm_2d(x_in_c)
        ni = up_out + x_in_c
        self.conv = custom_conv_layer(ni, n_out, norm_type=norm_type, self_attention=self_attention, extra_bn=True)
        self.relu = nn.ReLU()

    def forward(self, up_in):
        s = self.hook.feature
        up_out = self.shuf(up_in)
        cat_x = self.relu(torch.cat([up_out, self.bn(s)], dim=1))
        return self.conv(cat_x)
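Usage note (illustrative channel and size values): the decoder is built almost entirely from custom_conv_layer and CustomPixelShuffle_ICNR, e.g.

import torch
from basicsr.archs.ddcolor_arch_utils.unet import NormType, CustomPixelShuffle_ICNR, custom_conv_layer

conv = custom_conv_layer(64, 32, ks=3, norm_type=NormType.Spectral)   # spectral-norm conv + ReLU
up = CustomPixelShuffle_ICNR(32, 32, scale=2)                         # x2 pixel-shuffle with ICNR init and blur
x = torch.rand(1, 64, 32, 32)
print(up(conv(x)).shape)                                              # torch.Size([1, 32, 64, 64])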
basicsr/archs/ddcolor_arch_utils/util.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from skimage import color
4
+
5
+
6
+ def rgb2lab(img_rgb):
7
+ img_lab = color.rgb2lab(img_rgb)
8
+ return img_lab[:, :, :1], img_lab[:, :, 1:]
9
+
10
+
11
+ def tensor_lab2rgb(labs, illuminant="D65", observer="2"):
12
+ """
13
+ Args:
14
+ lab : (B, C, H, W)
15
+ Returns:
16
+ tuple : (B, C, H, W)
17
+ """
18
+ illuminants = \
19
+ {"A": {'2': (1.098466069456375, 1, 0.3558228003436005),
20
+ '10': (1.111420406956693, 1, 0.3519978321919493)},
21
+ "D50": {'2': (0.9642119944211994, 1, 0.8251882845188288),
22
+ '10': (0.9672062750333777, 1, 0.8142801513128616)},
23
+ "D55": {'2': (0.956797052643698, 1, 0.9214805860173273),
24
+ '10': (0.9579665682254781, 1, 0.9092525159847462)},
25
+ "D65": {'2': (0.95047, 1., 1.08883), # This was: `lab_ref_white`
26
+ '10': (0.94809667673716, 1, 1.0730513595166162)},
27
+ "D75": {'2': (0.9497220898840717, 1, 1.226393520724154),
28
+ '10': (0.9441713925645873, 1, 1.2064272211720228)},
29
+ "E": {'2': (1.0, 1.0, 1.0),
30
+ '10': (1.0, 1.0, 1.0)}}
31
+ xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], [0.212671, 0.715160, 0.072169],
32
+ [0.019334, 0.119193, 0.950227]])
33
+
34
+ rgb_from_xyz = np.array([[3.240481340, -0.96925495, 0.055646640], [-1.53715152, 1.875990000, -0.20404134],
35
+ [-0.49853633, 0.041555930, 1.057311070]])
36
+ B, C, H, W = labs.shape
37
+ arrs = labs.permute((0, 2, 3, 1)).contiguous() # (B, 3, H, W) -> (B, H, W, 3)
38
+ L, a, b = arrs[:, :, :, 0:1], arrs[:, :, :, 1:2], arrs[:, :, :, 2:]
39
+ y = (L + 16.) / 116.
40
+ x = (a / 500.) + y
41
+ z = y - (b / 200.)
42
+ invalid = z.data < 0
43
+ z[invalid] = 0
44
+ xyz = torch.cat([x, y, z], dim=3)
45
+ mask = xyz.data > 0.2068966
46
+ mask_xyz = xyz.clone()
47
+ mask_xyz[mask] = torch.pow(xyz[mask], 3.0)
48
+ mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.) / 7.787
49
+ xyz_ref_white = illuminants[illuminant][observer]
50
+ for i in range(C):
51
+ mask_xyz[:, :, :, i] = mask_xyz[:, :, :, i] * xyz_ref_white[i]
52
+
53
+ rgb_trans = torch.mm(mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(xyz)).view(B, H, W, C)
54
+ rgb = rgb_trans.permute((0, 3, 1, 2)).contiguous()
55
+ mask = rgb.data > 0.0031308
56
+ mask_rgb = rgb.clone()
57
+ mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055
58
+ mask_rgb[~mask] = rgb[~mask] * 12.92
59
+ neg_mask = mask_rgb.data < 0
60
+ large_mask = mask_rgb.data > 1
61
+ mask_rgb[neg_mask] = 0
62
+ mask_rgb[large_mask] = 1
63
+ return mask_rgb
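
A quick round-trip sketch (not part of the commit) for the Lab helpers above, using placeholder data:

    import numpy as np
    import torch
    from basicsr.archs.ddcolor_arch_utils.util import rgb2lab, tensor_lab2rgb

    rgb = np.random.rand(64, 64, 3).astype(np.float32)                   # RGB image in [0, 1]
    img_l, img_ab = rgb2lab(rgb)                                         # (64, 64, 1) L and (64, 64, 2) ab
    lab = np.concatenate([img_l, img_ab], axis=2)                        # back to (64, 64, 3)
    lab_t = torch.from_numpy(lab).permute(2, 0, 1).unsqueeze(0).float()  # (1, 3, 64, 64)
    rgb_t = tensor_lab2rgb(lab_t)                                        # (1, 3, 64, 64), clipped to [0, 1]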
basicsr/archs/discriminator_arch.py ADDED
@@ -0,0 +1,28 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+ import numpy as np
5
+
6
+ from basicsr.archs.ddcolor_arch_utils.unet import _conv
7
+ from basicsr.utils.registry import ARCH_REGISTRY
8
+
9
+
10
+ @ARCH_REGISTRY.register()
11
+ class DynamicUNetDiscriminator(nn.Module):
12
+
13
+ def __init__(self, n_channels: int = 3, nf: int = 256, n_blocks: int = 3):
14
+ super().__init__()
15
+ layers = [_conv(n_channels, nf, ks=4, stride=2)]
16
+ for i in range(n_blocks):
17
+ layers += [
18
+ _conv(nf, nf, ks=3, stride=1),
19
+ _conv(nf, nf * 2, ks=4, stride=2, self_attention=(i == 0)),
20
+ ]
21
+ nf *= 2
22
+ layers += [_conv(nf, nf, ks=3, stride=1), _conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False)]
23
+ self.layers = nn.Sequential(*layers)
24
+
25
+ def forward(self, x):
26
+ out = self.layers(x)
27
+ out = out.view(out.size(0), -1)
28
+ return out
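
A shape-check sketch (not part of the commit); it assumes the unet helpers imported above build correctly:

    import torch
    from basicsr.archs.discriminator_arch import DynamicUNetDiscriminator

    net = DynamicUNetDiscriminator(n_channels=3, nf=256, n_blocks=3)
    logits = net(torch.randn(1, 3, 256, 256))
    print(logits.shape)  # (1, N): one logit per spatial location of the final conv, flattened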
basicsr/archs/vgg_arch.py ADDED
@@ -0,0 +1,165 @@
1
+ import os
2
+ import torch
3
+ from collections import OrderedDict
4
+ from torch import nn as nn
5
+ from torchvision.models import vgg as vgg
6
+
7
+ from basicsr.utils.registry import ARCH_REGISTRY
8
+
9
+ VGG_PRETRAIN_PATH = {
10
+ 'vgg19': './pretrain/vgg19-dcbb9e9d.pth',
11
+ 'vgg16_bn': './pretrain/vgg16_bn-6c64b313.pth'
12
+ }
13
+
14
+ NAMES = {
15
+ 'vgg11': [
16
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
17
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
18
+ 'pool5'
19
+ ],
20
+ 'vgg13': [
21
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
22
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
23
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
24
+ ],
25
+ 'vgg16': [
26
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
27
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
28
+ 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
29
+ 'pool5'
30
+ ],
31
+ 'vgg19': [
32
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
33
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
34
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
35
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
36
+ ]
37
+ }
38
+
39
+
40
+ def insert_bn(names):
41
+ """Insert bn layer after each conv.
42
+
43
+ Args:
44
+ names (list): The list of layer names.
45
+
46
+ Returns:
47
+ list: The list of layer names with bn layers.
48
+ """
49
+ names_bn = []
50
+ for name in names:
51
+ names_bn.append(name)
52
+ if 'conv' in name:
53
+ position = name.replace('conv', '')
54
+ names_bn.append('bn' + position)
55
+ return names_bn
56
+
57
+
58
+ @ARCH_REGISTRY.register()
59
+ class VGGFeatureExtractor(nn.Module):
60
+ """VGG network for feature extraction.
61
+
62
+ In this implementation, we allow users to choose whether to use normalization
63
+ in the input feature and the type of vgg network. Note that the pretrained
64
+ path must fit the vgg type.
65
+
66
+ Args:
67
+ layer_name_list (list[str]): Forward function returns the corresponding
68
+ features according to the layer_name_list.
69
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
70
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
71
+ use_input_norm (bool): If True, normalize the input image. Importantly,
72
+ the input feature must be in the range [0, 1]. Default: True.
73
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
74
+ Default: False.
75
+ requires_grad (bool): If true, the parameters of VGG network will be
76
+ optimized. Default: False.
77
+ remove_pooling (bool): If true, the max pooling operations in VGG net
78
+ will be removed. Default: False.
79
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
80
+ """
81
+
82
+ def __init__(self,
83
+ layer_name_list,
84
+ vgg_type='vgg19',
85
+ use_input_norm=True,
86
+ range_norm=False,
87
+ requires_grad=False,
88
+ remove_pooling=False,
89
+ pooling_stride=2):
90
+ super(VGGFeatureExtractor, self).__init__()
91
+
92
+ self.layer_name_list = layer_name_list
93
+ self.use_input_norm = use_input_norm
94
+ self.range_norm = range_norm
95
+
96
+ self.names = NAMES[vgg_type.replace('_bn', '')]
97
+ if 'bn' in vgg_type:
98
+ self.names = insert_bn(self.names)
99
+
100
+ # only borrow layers that will be used to avoid unused params
101
+ max_idx = 0
102
+ for v in layer_name_list:
103
+ idx = self.names.index(v)
104
+ if idx > max_idx:
105
+ max_idx = idx
106
+
107
+ if os.path.exists(VGG_PRETRAIN_PATH[vgg_type]):
108
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
109
+ state_dict = torch.load(VGG_PRETRAIN_PATH[vgg_type], map_location=lambda storage, loc: storage)
110
+ vgg_net.load_state_dict(state_dict)
111
+ else:
112
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
113
+
114
+ features = vgg_net.features[:max_idx + 1]
115
+
116
+ modified_net = OrderedDict()
117
+ for k, v in zip(self.names, features):
118
+ if 'pool' in k:
119
+ # if remove_pooling is true, pooling operation will be removed
120
+ if remove_pooling:
121
+ continue
122
+ else:
123
+ # in some cases, we may want to change the default stride
124
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
125
+ else:
126
+ modified_net[k] = v
127
+
128
+ self.vgg_net = nn.Sequential(modified_net)
129
+
130
+ if not requires_grad:
131
+ self.vgg_net.eval()
132
+ for param in self.parameters():
133
+ param.requires_grad = False
134
+ else:
135
+ self.vgg_net.train()
136
+ for param in self.parameters():
137
+ param.requires_grad = True
138
+
139
+ if self.use_input_norm:
140
+ # the mean is for image with range [0, 1]
141
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
142
+ # the std is for image with range [0, 1]
143
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
144
+
145
+ def forward(self, x):
146
+ """Forward function.
147
+
148
+ Args:
149
+ x (Tensor): Input tensor with shape (n, c, h, w).
150
+
151
+ Returns:
152
+ Tensor: Forward results.
153
+ """
154
+ if self.range_norm:
155
+ x = (x + 1) / 2
156
+ if self.use_input_norm:
157
+ x = (x - self.mean) / self.std
158
+
159
+ output = {}
160
+ for key, layer in self.vgg_net._modules.items():
161
+ x = layer(x)
162
+ if key in self.layer_name_list:
163
+ output[key] = x.clone()
164
+
165
+ return output
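
A usage sketch (not part of the commit) for the feature extractor above; if the local ./pretrain weights are missing, torchvision downloads the pretrained VGG instead:

    import torch
    from basicsr.archs.vgg_arch import VGGFeatureExtractor

    extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu2_1'], vgg_type='vgg19')
    feats = extractor(torch.rand(1, 3, 224, 224))  # input expected in [0, 1]
    print({k: v.shape for k, v in feats.items()})  # features at relu1_1 and relu2_1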
basicsr/data/__init__.py ADDED
@@ -0,0 +1,101 @@
1
+ import importlib
2
+ import numpy as np
3
+ import random
4
+ import torch
5
+ import torch.utils.data
6
+ from copy import deepcopy
7
+ from functools import partial
8
+ from os import path as osp
9
+
10
+ from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11
+ from basicsr.utils import get_root_logger, scandir
12
+ from basicsr.utils.dist_util import get_dist_info
13
+ from basicsr.utils.registry import DATASET_REGISTRY
14
+
15
+ __all__ = ['build_dataset', 'build_dataloader']
16
+
17
+ # automatically scan and import dataset modules for registry
18
+ # scan all the files under the data folder with '_dataset' in file names
19
+ data_folder = osp.dirname(osp.abspath(__file__))
20
+ dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21
+ # import all the dataset modules
22
+ _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
23
+
24
+
25
+ def build_dataset(dataset_opt):
26
+ """Build dataset from options.
27
+
28
+ Args:
29
+ dataset_opt (dict): Configuration for dataset. It must contain:
30
+ name (str): Dataset name.
31
+ type (str): Dataset type.
32
+ """
33
+ dataset_opt = deepcopy(dataset_opt)
34
+ dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
35
+ logger = get_root_logger()
36
+ logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
37
+ return dataset
38
+
39
+
40
+ def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
41
+ """Build dataloader.
42
+
43
+ Args:
44
+ dataset (torch.utils.data.Dataset): Dataset.
45
+ dataset_opt (dict): Dataset options. It contains the following keys:
46
+ phase (str): 'train' or 'val'.
47
+ num_worker_per_gpu (int): Number of workers for each GPU.
48
+ batch_size_per_gpu (int): Training batch size for each GPU.
49
+ num_gpu (int): Number of GPUs. Used only in the train phase.
50
+ Default: 1.
51
+ dist (bool): Whether in distributed training. Used only in the train
52
+ phase. Default: False.
53
+ sampler (torch.utils.data.sampler): Data sampler. Default: None.
54
+ seed (int | None): Seed. Default: None
55
+ """
56
+ phase = dataset_opt['phase']
57
+ rank, _ = get_dist_info()
58
+ if phase == 'train':
59
+ if dist: # distributed training
60
+ batch_size = dataset_opt['batch_size_per_gpu']
61
+ num_workers = dataset_opt['num_worker_per_gpu']
62
+ else: # non-distributed training
63
+ multiplier = 1 if num_gpu == 0 else num_gpu
64
+ batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
65
+ num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
66
+ dataloader_args = dict(
67
+ dataset=dataset,
68
+ batch_size=batch_size,
69
+ shuffle=False,
70
+ num_workers=num_workers,
71
+ sampler=sampler,
72
+ drop_last=True)
73
+ if sampler is None:
74
+ dataloader_args['shuffle'] = True
75
+ dataloader_args['worker_init_fn'] = partial(
76
+ worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
77
+ elif phase in ['val', 'test']: # validation
78
+ dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
79
+ else:
80
+ raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
81
+
82
+ dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
83
+ dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
84
+
85
+ prefetch_mode = dataset_opt.get('prefetch_mode')
86
+ if prefetch_mode == 'cpu': # CPUPrefetcher
87
+ num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
88
+ logger = get_root_logger()
89
+ logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
90
+ return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
91
+ else:
92
+ # prefetch_mode=None: Normal dataloader
93
+ # prefetch_mode='cuda': dataloader for CUDAPrefetcher
94
+ return torch.utils.data.DataLoader(**dataloader_args)
95
+
96
+
97
+ def worker_init_fn(worker_id, num_workers, rank, seed):
98
+ # Set the worker seed to num_workers * rank + worker_id + seed
99
+ worker_seed = num_workers * rank + worker_id + seed
100
+ np.random.seed(worker_seed)
101
+ random.seed(worker_seed)
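
A build sketch (not part of the commit); the paths and option values below are placeholders, only the keys come from the dataset and dataloader code above:

    from basicsr.data import build_dataset, build_dataloader

    dataset_opt = {
        'name': 'example', 'type': 'LabDataset', 'phase': 'train',
        'dataroot_gt': 'datasets/gt', 'meta_info_file': 'datasets/meta_info.txt',
        'io_backend': {'type': 'disk'}, 'gt_size': 256,
        'do_fmix': False, 'fmix_p': 0.5, 'do_cutmix': False, 'cutmix_p': 0.5,
        'batch_size_per_gpu': 8, 'num_worker_per_gpu': 4,
    }
    dataset = build_dataset(dataset_opt)
    loader = build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=0)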
basicsr/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (3.58 kB).
basicsr/data/__pycache__/data_sampler.cpython-39.pyc ADDED
Binary file (2.14 kB).
basicsr/data/__pycache__/fmix.cpython-39.pyc ADDED
Binary file (7.01 kB).
basicsr/data/__pycache__/lab_dataset.cpython-39.pyc ADDED
Binary file (4.77 kB).
basicsr/data/__pycache__/prefetch_dataloader.cpython-39.pyc ADDED
Binary file (4.38 kB).
basicsr/data/__pycache__/transforms.cpython-39.pyc ADDED
Binary file (6.4 kB).
basicsr/data/data_sampler.py ADDED
@@ -0,0 +1,48 @@
1
+ import math
2
+ import torch
3
+ from torch.utils.data.sampler import Sampler
4
+
5
+
6
+ class EnlargedSampler(Sampler):
7
+ """Sampler that restricts data loading to a subset of the dataset.
8
+
9
+ Modified from torch.utils.data.distributed.DistributedSampler
10
+ Supports enlarging the dataset for iteration-based training, saving
11
+ time when restarting the dataloader after each epoch.
12
+
13
+ Args:
14
+ dataset (torch.utils.data.Dataset): Dataset used for sampling.
15
+ num_replicas (int | None): Number of processes participating in
16
+ the training. It is usually the world_size.
17
+ rank (int | None): Rank of the current process within num_replicas.
18
+ ratio (int): Enlarging ratio. Default: 1.
19
+ """
20
+
21
+ def __init__(self, dataset, num_replicas, rank, ratio=1):
22
+ self.dataset = dataset
23
+ self.num_replicas = num_replicas
24
+ self.rank = rank
25
+ self.epoch = 0
26
+ self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
27
+ self.total_size = self.num_samples * self.num_replicas
28
+
29
+ def __iter__(self):
30
+ # deterministically shuffle based on epoch
31
+ g = torch.Generator()
32
+ g.manual_seed(self.epoch)
33
+ indices = torch.randperm(self.total_size, generator=g).tolist()
34
+
35
+ dataset_size = len(self.dataset)
36
+ indices = [v % dataset_size for v in indices]
37
+
38
+ # subsample
39
+ indices = indices[self.rank:self.total_size:self.num_replicas]
40
+ assert len(indices) == self.num_samples
41
+
42
+ return iter(indices)
43
+
44
+ def __len__(self):
45
+ return self.num_samples
46
+
47
+ def set_epoch(self, epoch):
48
+ self.epoch = epoch
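
A small sketch (not part of the commit) showing the enlarging behaviour; anything with __len__ works for the sampler:

    from basicsr.data.data_sampler import EnlargedSampler

    dummy_dataset = list(range(100))
    sampler = EnlargedSampler(dummy_dataset, num_replicas=1, rank=0, ratio=10)
    sampler.set_epoch(0)
    print(len(sampler))          # 1000: one "epoch" now covers the dataset ten times
    print(next(iter(sampler)))   # an index in [0, 100), since indices are taken modulo len(dataset)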
basicsr/data/data_util.py ADDED
@@ -0,0 +1,313 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from os import path as osp
5
+ from torch.nn import functional as F
6
+
7
+ from basicsr.data.transforms import mod_crop
8
+ from basicsr.utils import img2tensor, scandir
9
+
10
+
11
+ def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
12
+ """Read a sequence of images from a given folder path.
13
+
14
+ Args:
15
+ path (list[str] | str): List of image paths or image folder path.
16
+ require_mod_crop (bool): Require mod crop for each image.
17
+ Default: False.
18
+ scale (int): Scale factor for mod_crop. Default: 1.
19
+ return_imgname(bool): Whether return image names. Default False.
20
+
21
+ Returns:
22
+ Tensor: size (t, c, h, w), RGB, [0, 1].
23
+ list[str]: Returned image name list.
24
+ """
25
+ if isinstance(path, list):
26
+ img_paths = path
27
+ else:
28
+ img_paths = sorted(list(scandir(path, full_path=True)))
29
+ imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
30
+
31
+ if require_mod_crop:
32
+ imgs = [mod_crop(img, scale) for img in imgs]
33
+ imgs = img2tensor(imgs, bgr2rgb=True, float32=True)
34
+ imgs = torch.stack(imgs, dim=0)
35
+
36
+ if return_imgname:
37
+ imgnames = [osp.splitext(osp.basename(path))[0] for path in img_paths]
38
+ return imgs, imgnames
39
+ else:
40
+ return imgs
41
+
42
+
43
+ def generate_frame_indices(crt_idx, max_frame_num, num_frames, padding='reflection'):
44
+ """Generate an index list for reading `num_frames` frames from a sequence
45
+ of images.
46
+
47
+ Args:
48
+ crt_idx (int): Current center index.
49
+ max_frame_num (int): Max number of the sequence of images (from 1).
50
+ num_frames (int): Reading num_frames frames.
51
+ padding (str): Padding mode, one of
52
+ 'replicate' | 'reflection' | 'reflection_circle' | 'circle'
53
+ Examples: current_idx = 0, num_frames = 5
54
+ The generated frame indices under different padding mode:
55
+ replicate: [0, 0, 0, 1, 2]
56
+ reflection: [2, 1, 0, 1, 2]
57
+ reflection_circle: [4, 3, 0, 1, 2]
58
+ circle: [3, 4, 0, 1, 2]
59
+
60
+ Returns:
61
+ list[int]: A list of indices.
62
+ """
63
+ assert num_frames % 2 == 1, 'num_frames should be an odd number.'
64
+ assert padding in ('replicate', 'reflection', 'reflection_circle', 'circle'), f'Wrong padding mode: {padding}.'
65
+
66
+ max_frame_num = max_frame_num - 1 # start from 0
67
+ num_pad = num_frames // 2
68
+
69
+ indices = []
70
+ for i in range(crt_idx - num_pad, crt_idx + num_pad + 1):
71
+ if i < 0:
72
+ if padding == 'replicate':
73
+ pad_idx = 0
74
+ elif padding == 'reflection':
75
+ pad_idx = -i
76
+ elif padding == 'reflection_circle':
77
+ pad_idx = crt_idx + num_pad - i
78
+ else:
79
+ pad_idx = num_frames + i
80
+ elif i > max_frame_num:
81
+ if padding == 'replicate':
82
+ pad_idx = max_frame_num
83
+ elif padding == 'reflection':
84
+ pad_idx = max_frame_num * 2 - i
85
+ elif padding == 'reflection_circle':
86
+ pad_idx = (crt_idx - num_pad) - (i - max_frame_num)
87
+ else:
88
+ pad_idx = i - num_frames
89
+ else:
90
+ pad_idx = i
91
+ indices.append(pad_idx)
92
+ return indices
93
+
94
+
95
+ def paired_paths_from_lmdb(folders, keys):
96
+ """Generate paired paths from lmdb files.
97
+
98
+ Contents of lmdb. Taking the `lq.lmdb` for example, the file structure is:
99
+
100
+ lq.lmdb
101
+ ├── data.mdb
102
+ ├── lock.mdb
103
+ ├── meta_info.txt
104
+
105
+ The data.mdb and lock.mdb are standard lmdb files and you can refer to
106
+ https://lmdb.readthedocs.io/en/release/ for more details.
107
+
108
+ The meta_info.txt is a specified txt file to record the meta information
109
+ of our datasets. It will be automatically created when preparing
110
+ datasets by our provided dataset tools.
111
+ Each line in the txt file records
112
+ 1)image name (with extension),
113
+ 2)image shape,
114
+ 3)compression level, separated by a white space.
115
+ Example: `baboon.png (120,125,3) 1`
116
+
117
+ We use the image name without extension as the lmdb key.
118
+ Note that we use the same key for the corresponding lq and gt images.
119
+
120
+ Args:
121
+ folders (list[str]): A list of folder path. The order of list should
122
+ be [input_folder, gt_folder].
123
+ keys (list[str]): A list of keys identifying folders. The order should
124
+ be consistent with folders, e.g., ['lq', 'gt'].
125
+ Note that this key is different from lmdb keys.
126
+
127
+ Returns:
128
+ list[str]: Returned path list.
129
+ """
130
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
131
+ f'But got {len(folders)}')
132
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
133
+ input_folder, gt_folder = folders
134
+ input_key, gt_key = keys
135
+
136
+ if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
137
+ raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
138
+ f'formats. But received {input_key}: {input_folder}; '
139
+ f'{gt_key}: {gt_folder}')
140
+ # ensure that the two meta_info files are the same
141
+ with open(osp.join(input_folder, 'meta_info.txt')) as fin:
142
+ input_lmdb_keys = [line.split('.')[0] for line in fin]
143
+ with open(osp.join(gt_folder, 'meta_info.txt')) as fin:
144
+ gt_lmdb_keys = [line.split('.')[0] for line in fin]
145
+ if set(input_lmdb_keys) != set(gt_lmdb_keys):
146
+ raise ValueError(f'Keys in {input_key}_folder and {gt_key}_folder are different.')
147
+ else:
148
+ paths = []
149
+ for lmdb_key in sorted(input_lmdb_keys):
150
+ paths.append(dict([(f'{input_key}_path', lmdb_key), (f'{gt_key}_path', lmdb_key)]))
151
+ return paths
152
+
153
+
154
+ def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmpl):
155
+ """Generate paired paths from an meta information file.
156
+
157
+ Each line in the meta information file contains the image names and
158
+ image shape (usually for gt), separated by a white space.
159
+
160
+ Example of a meta information file:
161
+ ```
162
+ 0001_s001.png (480,480,3)
163
+ 0001_s002.png (480,480,3)
164
+ ```
165
+
166
+ Args:
167
+ folders (list[str]): A list of folder path. The order of list should
168
+ be [input_folder, gt_folder].
169
+ keys (list[str]): A list of keys identifying folders. The order should
170
+ be consistent with folders, e.g., ['lq', 'gt'].
171
+ meta_info_file (str): Path to the meta information file.
172
+ filename_tmpl (str): Template for each filename. Note that the
173
+ template excludes the file extension. Usually the filename_tmpl is
174
+ for files in the input folder.
175
+
176
+ Returns:
177
+ list[str]: Returned path list.
178
+ """
179
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
180
+ f'But got {len(folders)}')
181
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
182
+ input_folder, gt_folder = folders
183
+ input_key, gt_key = keys
184
+
185
+ with open(meta_info_file, 'r') as fin:
186
+ gt_names = [line.split(' ')[0] for line in fin]
187
+
188
+ paths = []
189
+ for gt_name in gt_names:
190
+ basename, ext = osp.splitext(osp.basename(gt_name))
191
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
192
+ input_path = osp.join(input_folder, input_name)
193
+ gt_path = osp.join(gt_folder, gt_name)
194
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
195
+ return paths
196
+
197
+
198
+ def paired_paths_from_folder(folders, keys, filename_tmpl):
199
+ """Generate paired paths from folders.
200
+
201
+ Args:
202
+ folders (list[str]): A list of folder path. The order of list should
203
+ be [input_folder, gt_folder].
204
+ keys (list[str]): A list of keys identifying folders. The order should
205
+ be consistent with folders, e.g., ['lq', 'gt'].
206
+ filename_tmpl (str): Template for each filename. Note that the
207
+ template excludes the file extension. Usually the filename_tmpl is
208
+ for files in the input folder.
209
+
210
+ Returns:
211
+ list[str]: Returned path list.
212
+ """
213
+ assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
214
+ f'But got {len(folders)}')
215
+ assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
216
+ input_folder, gt_folder = folders
217
+ input_key, gt_key = keys
218
+
219
+ input_paths = list(scandir(input_folder))
220
+ gt_paths = list(scandir(gt_folder))
221
+ assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
222
+ f'{len(input_paths)}, {len(gt_paths)}.')
223
+ paths = []
224
+ for gt_path in gt_paths:
225
+ basename, ext = osp.splitext(osp.basename(gt_path))
226
+ input_name = f'{filename_tmpl.format(basename)}{ext}'
227
+ input_path = osp.join(input_folder, input_name)
228
+ assert input_name in input_paths, f'{input_name} is not in {input_key}_paths.'
229
+ gt_path = osp.join(gt_folder, gt_path)
230
+ paths.append(dict([(f'{input_key}_path', input_path), (f'{gt_key}_path', gt_path)]))
231
+ return paths
232
+
233
+
234
+ def paths_from_folder(folder):
235
+ """Generate paths from folder.
236
+
237
+ Args:
238
+ folder (str): Folder path.
239
+
240
+ Returns:
241
+ list[str]: Returned path list.
242
+ """
243
+
244
+ paths = list(scandir(folder))
245
+ paths = [osp.join(folder, path) for path in paths]
246
+ return paths
247
+
248
+
249
+ def paths_from_lmdb(folder):
250
+ """Generate paths from lmdb.
251
+
252
+ Args:
253
+ folder (str): Folder path.
254
+
255
+ Returns:
256
+ list[str]: Returned path list.
257
+ """
258
+ if not folder.endswith('.lmdb'):
259
+ raise ValueError(f'Folder {folder} should be in lmdb format.')
260
+ with open(osp.join(folder, 'meta_info.txt')) as fin:
261
+ paths = [line.split('.')[0] for line in fin]
262
+ return paths
263
+
264
+
265
+ def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
266
+ """Generate Gaussian kernel used in `duf_downsample`.
267
+
268
+ Args:
269
+ kernel_size (int): Kernel size. Default: 13.
270
+ sigma (float): Sigma of the Gaussian kernel. Default: 1.6.
271
+
272
+ Returns:
273
+ np.array: The Gaussian kernel.
274
+ """
275
+ from scipy.ndimage import filters as filters
276
+ kernel = np.zeros((kernel_size, kernel_size))
277
+ # set element at the middle to one, a dirac delta
278
+ kernel[kernel_size // 2, kernel_size // 2] = 1
279
+ # gaussian-smooth the dirac, resulting in a gaussian filter
280
+ return filters.gaussian_filter(kernel, sigma)
281
+
282
+
283
+ def duf_downsample(x, kernel_size=13, scale=4):
284
+ """Downsamping with Gaussian kernel used in the DUF official code.
285
+
286
+ Args:
287
+ x (Tensor): Frames to be downsampled, with shape (b, t, c, h, w).
288
+ kernel_size (int): Kernel size. Default: 13.
289
+ scale (int): Downsampling factor. Supported scale: (2, 3, 4).
290
+ Default: 4.
291
+
292
+ Returns:
293
+ Tensor: DUF downsampled frames.
294
+ """
295
+ assert scale in (2, 3, 4), f'Only support scale (2, 3, 4), but got {scale}.'
296
+
297
+ squeeze_flag = False
298
+ if x.ndim == 4:
299
+ squeeze_flag = True
300
+ x = x.unsqueeze(0)
301
+ b, t, c, h, w = x.size()
302
+ x = x.view(-1, 1, h, w)
303
+ pad_w, pad_h = kernel_size // 2 + scale * 2, kernel_size // 2 + scale * 2
304
+ x = F.pad(x, (pad_w, pad_w, pad_h, pad_h), 'reflect')
305
+
306
+ gaussian_filter = generate_gaussian_kernel(kernel_size, 0.4 * scale)
307
+ gaussian_filter = torch.from_numpy(gaussian_filter).type_as(x).unsqueeze(0).unsqueeze(0)
308
+ x = F.conv2d(x, gaussian_filter, stride=scale)
309
+ x = x[:, :, 2:-2, 2:-2]
310
+ x = x.view(b, t, c, x.size(2), x.size(3))
311
+ if squeeze_flag:
312
+ x = x.squeeze(0)
313
+ return x
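
A sketch (not part of the commit) reproducing the padding examples from the generate_frame_indices docstring above:

    from basicsr.data.data_util import generate_frame_indices

    for mode in ('replicate', 'reflection', 'reflection_circle', 'circle'):
        print(mode, generate_frame_indices(crt_idx=0, max_frame_num=100, num_frames=5, padding=mode))
    # replicate          [0, 0, 0, 1, 2]
    # reflection         [2, 1, 0, 1, 2]
    # reflection_circle  [4, 3, 0, 1, 2]
    # circle             [3, 4, 0, 1, 2]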
basicsr/data/fmix.py ADDED
@@ -0,0 +1,206 @@
1
+ '''
2
+ Fmix paper from arxiv: https://arxiv.org/abs/2002.12047
3
+ Fmix code from github : https://github.com/ecs-vlc/FMix
4
+ '''
5
+ import math
6
+ import random
7
+ import numpy as np
8
+ from scipy.stats import beta
9
+
10
+
11
+ def fftfreqnd(h, w=None, z=None):
12
+ """ Get bin values for discrete fourier transform of size (h, w, z)
13
+ :param h: Required, first dimension size
14
+ :param w: Optional, second dimension size
15
+ :param z: Optional, third dimension size
16
+ """
17
+ fz = fx = 0
18
+ fy = np.fft.fftfreq(h)
19
+
20
+ if w is not None:
21
+ fy = np.expand_dims(fy, -1)
22
+
23
+ if w % 2 == 1:
24
+ fx = np.fft.fftfreq(w)[: w // 2 + 2]
25
+ else:
26
+ fx = np.fft.fftfreq(w)[: w // 2 + 1]
27
+
28
+ if z is not None:
29
+ fy = np.expand_dims(fy, -1)
30
+ if z % 2 == 1:
31
+ fz = np.fft.fftfreq(z)[:, None]
32
+ else:
33
+ fz = np.fft.fftfreq(z)[:, None]
34
+
35
+ return np.sqrt(fx * fx + fy * fy + fz * fz)
36
+
37
+
38
+ def get_spectrum(freqs, decay_power, ch, h, w=0, z=0):
39
+ """ Samples a fourier image with given size and frequencies decayed by decay power
40
+ :param freqs: Bin values for the discrete fourier transform
41
+ :param decay_power: Decay power for frequency decay prop 1/f**d
42
+ :param ch: Number of channels for the resulting mask
43
+ :param h: Required, first dimension size
44
+ :param w: Optional, second dimension size
45
+ :param z: Optional, third dimension size
46
+ """
47
+ scale = np.ones(1) / (np.maximum(freqs, np.array([1. / max(w, h, z)])) ** decay_power)
48
+
49
+ param_size = [ch] + list(freqs.shape) + [2]
50
+ param = np.random.randn(*param_size)
51
+
52
+ scale = np.expand_dims(scale, -1)[None, :]
53
+
54
+ return scale * param
55
+
56
+
57
+ def make_low_freq_image(decay, shape, ch=1):
58
+ """ Sample a low frequency image from fourier space
59
+ :param decay_power: Decay power for frequency decay prop 1/f**d
60
+ :param shape: Shape of desired mask, list up to 3 dims
61
+ :param ch: Number of channels for desired mask
62
+ """
63
+ freqs = fftfreqnd(*shape)
64
+ spectrum = get_spectrum(freqs, decay, ch, *shape)#.reshape((1, *shape[:-1], -1))
65
+ spectrum = spectrum[:, 0] + 1j * spectrum[:, 1]
66
+ mask = np.real(np.fft.irfftn(spectrum, shape))
67
+
68
+ if len(shape) == 1:
69
+ mask = mask[:1, :shape[0]]
70
+ if len(shape) == 2:
71
+ mask = mask[:1, :shape[0], :shape[1]]
72
+ if len(shape) == 3:
73
+ mask = mask[:1, :shape[0], :shape[1], :shape[2]]
74
+
75
+ mask = mask
76
+ mask = (mask - mask.min())
77
+ mask = mask / mask.max()
78
+ return mask
79
+
80
+
81
+ def sample_lam(alpha, reformulate=False):
82
+ """ Sample a lambda from symmetric beta distribution with given alpha
83
+ :param alpha: Alpha value for beta distribution
84
+ :param reformulate: If True, uses the reformulation of [1].
85
+ """
86
+ if reformulate:
87
+ lam = beta.rvs(alpha+1, alpha)  # rvs(a, b, loc=mean, scale=std, size=n) draws n samples from the distribution
88
+ else:
89
+ lam = beta.rvs(alpha, alpha)  # rvs(a, b, loc=mean, scale=std, size=n) draws n samples from the distribution
90
+
91
+ return lam
92
+
93
+
94
+ def binarise_mask(mask, lam, in_shape, max_soft=0.0):
95
+ """ Binarises a given low frequency image such that it has mean lambda.
96
+ :param mask: Low frequency image, usually the result of `make_low_freq_image`
97
+ :param lam: Mean value of final mask
98
+ :param in_shape: Shape of inputs
99
+ :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
100
+ :return:
101
+ """
102
+ idx = mask.reshape(-1).argsort()[::-1]
103
+ mask = mask.reshape(-1)
104
+ num = math.ceil(lam * mask.size) if random.random() > 0.5 else math.floor(lam * mask.size)
105
+
106
+ eff_soft = max_soft
107
+ if max_soft > lam or max_soft > (1-lam):
108
+ eff_soft = min(lam, 1-lam)
109
+
110
+ soft = int(mask.size * eff_soft)
111
+ num_low = num - soft
112
+ num_high = num + soft
113
+
114
+ mask[idx[:num_high]] = 1
115
+ mask[idx[num_low:]] = 0
116
+ mask[idx[num_low:num_high]] = np.linspace(1, 0, (num_high - num_low))
117
+
118
+ mask = mask.reshape((1, *in_shape))
119
+ return mask
120
+
121
+
122
+ def sample_mask(alpha, decay_power, shape, max_soft=0.0, reformulate=False):
123
+ """ Samples a mean lambda from beta distribution parametrised by alpha, creates a low frequency image and binarises
124
+ it based on this lambda
125
+ :param alpha: Alpha value for beta distribution from which to sample mean of mask
126
+ :param decay_power: Decay power for frequency decay prop 1/f**d
127
+ :param shape: Shape of desired mask, list up to 3 dims
128
+ :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
129
+ :param reformulate: If True, uses the reformulation of [1].
130
+ """
131
+ if isinstance(shape, int):
132
+ shape = (shape,)
133
+
134
+ # Choose lambda
135
+ lam = sample_lam(alpha, reformulate)
136
+
137
+ # Make mask, get mean / std
138
+ mask = make_low_freq_image(decay_power, shape)
139
+ mask = binarise_mask(mask, lam, shape, max_soft)
140
+
141
+ return lam, mask
142
+
143
+
144
+ def sample_and_apply(x, alpha, decay_power, shape, max_soft=0.0, reformulate=False):
145
+ """
146
+ :param x: Image batch on which to apply fmix of shape [b, c, shape*]
147
+ :param alpha: Alpha value for beta distribution from which to sample mean of mask
148
+ :param decay_power: Decay power for frequency decay prop 1/f**d
149
+ :param shape: Shape of desired mask, list up to 3 dims
150
+ :param max_soft: Softening value between 0 and 0.5 which smooths hard edges in the mask.
151
+ :param reformulate: If True, uses the reformulation of [1].
152
+ :return: mixed input, permutation indices, lambda value of mix,
153
+ """
154
+ lam, mask = sample_mask(alpha, decay_power, shape, max_soft, reformulate)
155
+ index = np.random.permutation(x.shape[0])
156
+
157
+ x1, x2 = x * mask, x[index] * (1-mask)
158
+ return x1+x2, index, lam
159
+
160
+
161
+ class FMixBase:
162
+ """ FMix augmentation
163
+ Args:
164
+ decay_power (float): Decay power for frequency decay prop 1/f**d
165
+ alpha (float): Alpha value for beta distribution from which to sample mean of mask
166
+ size ([int] | [int, int] | [int, int, int]): Shape of desired mask, list up to 3 dims
167
+ max_soft (float): Softening value between 0 and 0.5 which smooths hard edges in the mask.
168
+ reformulate (bool): If True, uses the reformulation of [1].
169
+ """
170
+
171
+ def __init__(self, decay_power=3, alpha=1, size=(32, 32), max_soft=0.0, reformulate=False):
172
+ super().__init__()
173
+ self.decay_power = decay_power
174
+ self.reformulate = reformulate
175
+ self.size = size
176
+ self.alpha = alpha
177
+ self.max_soft = max_soft
178
+ self.index = None
179
+ self.lam = None
180
+
181
+ def __call__(self, x):
182
+ raise NotImplementedError
183
+
184
+ def loss(self, *args, **kwargs):
185
+ raise NotImplementedError
186
+
187
+
188
+ if __name__ == '__main__':
189
+ # para = {'alpha':1.,'decay_power':3.,'shape':(10,10),'max_soft':0.0,'reformulate':False}
190
+ # lam, mask = sample_mask(**para)
191
+ # mask = mask.transpose(1, 2, 0)
192
+ # img1 = np.zeros((10, 10, 3))
193
+ # img2 = np.ones((10, 10, 3))
194
+ # img_gt = mask * img1 + (1. - mask) * img2
195
+ # import ipdb; ipdb.set_trace()
196
+
197
+ # test
198
+ import cv2
199
+ i1 = cv2.imread('output/ILSVRC2012_val_00000001.JPEG')
200
+ i2 = cv2.imread('output/ILSVRC2012_val_00000002.JPEG')
201
+ para = {'alpha':1.,'decay_power':3.,'shape':(256, 256),'max_soft':0.0,'reformulate':False}
202
+ lam, mask = sample_mask(**para)
203
+ mask = mask.transpose(1, 2, 0)
204
+ i = mask * i1 + (1. - mask) * i2
205
+ #i = i.astype(np.uint8)
206
+ cv2.imwrite('fmix.jpg', i)
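
A self-contained sketch (not part of the commit) of FMix on a random batch, instead of the file-based test above:

    import numpy as np
    from basicsr.data.fmix import sample_and_apply

    x = np.random.rand(4, 3, 64, 64)    # batch of 4 images
    mixed, perm, lam = sample_and_apply(x, alpha=1., decay_power=3., shape=(64, 64))
    print(mixed.shape, perm, lam)       # (4, 3, 64, 64), permutation indices, mixing ratio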
basicsr/data/lab_dataset.py ADDED
@@ -0,0 +1,159 @@
1
+ import cv2
2
+ import random
3
+ import time
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils import data as data
7
+
8
+ from basicsr.data.transforms import rgb2lab
9
+ from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
10
+ from basicsr.utils.registry import DATASET_REGISTRY
11
+ from basicsr.data.fmix import sample_mask
12
+
13
+
14
+ @DATASET_REGISTRY.register()
15
+ class LabDataset(data.Dataset):
16
+ """
17
+ Dataset used for Lab colorization
18
+ """
19
+
20
+ def __init__(self, opt):
21
+ super(LabDataset, self).__init__()
22
+ self.opt = opt
23
+ # file client (io backend)
24
+ self.file_client = None
25
+ self.io_backend_opt = opt['io_backend']
26
+ self.gt_folder = opt['dataroot_gt']
27
+
28
+ meta_info_file = self.opt['meta_info_file']
29
+ assert meta_info_file is not None
30
+ if not isinstance(meta_info_file, list):
31
+ meta_info_file = [meta_info_file]
32
+ self.paths = []
33
+ for meta_info in meta_info_file:
34
+ with open(meta_info, 'r') as fin:
35
+ self.paths.extend([line.strip() for line in fin])
36
+
37
+ self.min_ab, self.max_ab = -128, 128
38
+ self.interval_ab = 4
39
+ self.ab_palette = [i for i in range(self.min_ab, self.max_ab + self.interval_ab, self.interval_ab)]
40
+ # print(self.ab_palette)
41
+
42
+ self.do_fmix = opt['do_fmix']
43
+ self.fmix_params = {'alpha':1.,'decay_power':3.,'shape':(256,256),'max_soft':0.0,'reformulate':False}
44
+ self.fmix_p = opt['fmix_p']
45
+ self.do_cutmix = opt['do_cutmix']
46
+ self.cutmix_params = {'alpha':1.}
47
+ self.cutmix_p = opt['cutmix_p']
48
+
49
+
50
+ def __getitem__(self, index):
51
+ if self.file_client is None:
52
+ self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
53
+
54
+ # -------------------------------- Load gt images -------------------------------- #
55
+ # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
56
+ gt_path = self.paths[index]
57
+ gt_size = self.opt['gt_size']
58
+ # avoid errors caused by high latency in reading files
59
+ retry = 3
60
+ while retry > 0:
61
+ try:
62
+ img_bytes = self.file_client.get(gt_path, 'gt')
63
+ except Exception as e:
64
+ logger = get_root_logger()
65
+ logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
66
+ # switch to another file to read
67
+ index = random.randint(0, self.__len__() - 1)
68
+ gt_path = self.paths[index]
69
+ time.sleep(1) # sleep 1s for occasional server congestion
70
+ else:
71
+ break
72
+ finally:
73
+ retry -= 1
74
+ img_gt = imfrombytes(img_bytes, float32=True)
75
+ img_gt = cv2.resize(img_gt, (gt_size, gt_size))  # TODO: is a plain resize the best option here?
76
+
77
+ # -------------------------------- (Optional) CutMix & FMix -------------------------------- #
78
+ if self.do_fmix and np.random.uniform(0., 1., size=1)[0] > self.fmix_p:
79
+ with torch.no_grad():
80
+ lam, mask = sample_mask(**self.fmix_params)
81
+
82
+ fmix_index = random.randint(0, self.__len__() - 1)
83
+ fmix_img_path = self.paths[fmix_index]
84
+ fmix_img_bytes = self.file_client.get(fmix_img_path, 'gt')
85
+ fmix_img = imfrombytes(fmix_img_bytes, float32=True)
86
+ fmix_img = cv2.resize(fmix_img, (gt_size, gt_size))
87
+
88
+ mask = mask.transpose(1, 2, 0) # (1, 256, 256) -> # (256, 256, 1)
89
+ img_gt = mask * img_gt + (1. - mask) * fmix_img
90
+ img_gt = img_gt.astype(np.float32)
91
+
92
+ if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > self.cutmix_p:
93
+ with torch.no_grad():
94
+ cmix_index = random.randint(0, self.__len__() - 1)
95
+ cmix_img_path = self.paths[cmix_index]
96
+ cmix_img_bytes = self.file_client.get(cmix_img_path, 'gt')
97
+ cmix_img = imfrombytes(cmix_img_bytes, float32=True)
98
+ cmix_img = cv2.resize(cmix_img, (gt_size, gt_size))
99
+
100
+ lam = np.clip(np.random.beta(self.cutmix_params['alpha'], self.cutmix_params['alpha']), 0.3, 0.4)
101
+ bbx1, bby1, bbx2, bby2 = rand_bbox(cmix_img.shape[:2], lam)
102
+
103
+ img_gt[bbx1:bbx2, bby1:bby2, :] = cmix_img[bbx1:bbx2, bby1:bby2, :]  # img_gt is (h, w, c) at this point
104
+
105
+
106
+ # ----------------------------- Get gray lq, to tensor ----------------------------- #
107
+ # convert to gray
108
+ img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)
109
+ img_l, img_ab = rgb2lab(img_gt)
110
+
111
+ target_a, target_b = self.ab2int(img_ab)
112
+
113
+ # numpy to tensor
114
+ img_l, img_ab = img2tensor([img_l, img_ab], bgr2rgb=False, float32=True)
115
+ target_a, target_b = torch.LongTensor(target_a), torch.LongTensor(target_b)
116
+ return_d = {
117
+ 'lq': img_l,
118
+ 'gt': img_ab,
119
+ 'target_a': target_a,
120
+ 'target_b': target_b,
121
+ 'lq_path': gt_path,
122
+ 'gt_path': gt_path
123
+ }
124
+ return return_d
125
+
126
+ def ab2int(self, img_ab):
127
+ img_a, img_b = img_ab[:, :, 0], img_ab[:, :, 1]
128
+ int_a = (img_a - self.min_ab) / self.interval_ab
129
+ int_b = (img_b - self.min_ab) / self.interval_ab
130
+
131
+ return np.round(int_a), np.round(int_b)
132
+
133
+ def __len__(self):
134
+ return len(self.paths)
135
+
136
+
137
+ def rand_bbox(size, lam):
138
+ '''Bounding-box sampling for CutMix.
139
+ Args:
140
+ size : tuple, image size, e.g. (256, 256)
141
+ lam : float, mixing ratio
142
+ Returns:
143
+ top-left and bottom-right coordinates of the bbox
144
+ int, int, int, int
145
+ '''
146
+ W = size[0] # width of the source image
147
+ H = size[1] # height of the source image
148
+ cut_rat = np.sqrt(1. - lam) # side ratio of the bbox to cut
149
+ cut_w = int(W * cut_rat) # bbox width (np.int is removed in recent NumPy)
150
+ cut_h = int(H * cut_rat) # bbox height
151
+
152
+ cx = np.random.randint(W) # uniformly sample the bbox centre x coordinate
153
+ cy = np.random.randint(H) # uniformly sample the bbox centre y coordinate
154
+
155
+ bbx1 = np.clip(cx - cut_w // 2, 0, W) # top-left x
156
+ bby1 = np.clip(cy - cut_h // 2, 0, H) # top-left y
157
+ bbx2 = np.clip(cx + cut_w // 2, 0, W) # bottom-right x
158
+ bby2 = np.clip(cy + cut_h // 2, 0, H) # bottom-right y
159
+ return bbx1, bby1, bbx2, bby2
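
A worked example (not part of the commit) of the ab quantisation used by LabDataset.ab2int: with min_ab = -128 and interval_ab = 4 the palette has 65 bins, and an ab value v maps to round((v + 128) / 4):

    import numpy as np

    min_ab, interval_ab = -128, 4
    for v in (-128., 0., 128.):
        print(v, np.round((v - min_ab) / interval_ab))   # bins 0, 32, 64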
basicsr/data/prefetch_dataloader.py ADDED
@@ -0,0 +1,125 @@
1
+ import queue as Queue
2
+ import threading
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
+ class PrefetchGenerator(threading.Thread):
8
+ """A general prefetch generator.
9
+
10
+ Ref:
11
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12
+
13
+ Args:
14
+ generator: Python generator.
15
+ num_prefetch_queue (int): Number of prefetch queue.
16
+ """
17
+
18
+ def __init__(self, generator, num_prefetch_queue):
19
+ threading.Thread.__init__(self)
20
+ self.queue = Queue.Queue(num_prefetch_queue)
21
+ self.generator = generator
22
+ self.daemon = True
23
+ self.start()
24
+
25
+ def run(self):
26
+ for item in self.generator:
27
+ self.queue.put(item)
28
+ self.queue.put(None)
29
+
30
+ def __next__(self):
31
+ next_item = self.queue.get()
32
+ if next_item is None:
33
+ raise StopIteration
34
+ return next_item
35
+
36
+ def __iter__(self):
37
+ return self
38
+
39
+
40
+ class PrefetchDataLoader(DataLoader):
41
+ """Prefetch version of dataloader.
42
+
43
+ Ref:
44
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45
+
46
+ TODO:
47
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48
+ ddp.
49
+
50
+ Args:
51
+ num_prefetch_queue (int): Number of prefetch queue.
52
+ kwargs (dict): Other arguments for dataloader.
53
+ """
54
+
55
+ def __init__(self, num_prefetch_queue, **kwargs):
56
+ self.num_prefetch_queue = num_prefetch_queue
57
+ super(PrefetchDataLoader, self).__init__(**kwargs)
58
+
59
+ def __iter__(self):
60
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61
+
62
+
63
+ class CPUPrefetcher():
64
+ """CPU prefetcher.
65
+
66
+ Args:
67
+ loader: Dataloader.
68
+ """
69
+
70
+ def __init__(self, loader):
71
+ self.ori_loader = loader
72
+ self.loader = iter(loader)
73
+
74
+ def next(self):
75
+ try:
76
+ return next(self.loader)
77
+ except StopIteration:
78
+ return None
79
+
80
+ def reset(self):
81
+ self.loader = iter(self.ori_loader)
82
+
83
+
84
+ class CUDAPrefetcher():
85
+ """CUDA prefetcher.
86
+
87
+ Ref:
88
+ https://github.com/NVIDIA/apex/issues/304#
89
+
90
+ It may consume more GPU memory.
91
+
92
+ Args:
93
+ loader: Dataloader.
94
+ opt (dict): Options.
95
+ """
96
+
97
+ def __init__(self, loader, opt):
98
+ self.ori_loader = loader
99
+ self.loader = iter(loader)
100
+ self.opt = opt
101
+ self.stream = torch.cuda.Stream()
102
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103
+ self.preload()
104
+
105
+ def preload(self):
106
+ try:
107
+ self.batch = next(self.loader) # self.batch is a dict
108
+ except StopIteration:
109
+ self.batch = None
110
+ return None
111
+ # put tensors to gpu
112
+ with torch.cuda.stream(self.stream):
113
+ for k, v in self.batch.items():
114
+ if torch.is_tensor(v):
115
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
116
+
117
+ def next(self):
118
+ torch.cuda.current_stream().wait_stream(self.stream)
119
+ batch = self.batch
120
+ self.preload()
121
+ return batch
122
+
123
+ def reset(self):
124
+ self.loader = iter(self.ori_loader)
125
+ self.preload()
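
A training-loop sketch (not part of the commit) showing how the prefetchers replace direct iteration over a DataLoader; `loader` stands for any torch DataLoader:

    from basicsr.data.prefetch_dataloader import CPUPrefetcher

    prefetcher = CPUPrefetcher(loader)
    batch = prefetcher.next()
    while batch is not None:
        # ... run one training step on `batch` ...
        batch = prefetcher.next()
    prefetcher.reset()           # rewind for the next epoch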
basicsr/data/transforms.py ADDED
@@ -0,0 +1,192 @@
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import random
5
+ from scipy import special
6
+ from skimage import color
7
+
8
+
9
+ def mod_crop(img, scale):
10
+ """Mod crop images, used during testing.
11
+
12
+ Args:
13
+ img (ndarray): Input image.
14
+ scale (int): Scale factor.
15
+
16
+ Returns:
17
+ ndarray: Result image.
18
+ """
19
+ img = img.copy()
20
+ if img.ndim in (2, 3):
21
+ h, w = img.shape[0], img.shape[1]
22
+ h_remainder, w_remainder = h % scale, w % scale
23
+ img = img[:h - h_remainder, :w - w_remainder, ...]
24
+ else:
25
+ raise ValueError(f'Wrong img ndim: {img.ndim}.')
26
+ return img
27
+
28
+
29
+ def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
30
+ """Paired random crop. Support Numpy array and Tensor inputs.
31
+
32
+ It crops lists of lq and gt images with corresponding locations.
33
+
34
+ Args:
35
+ img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
36
+ should have the same shape. If the input is an ndarray, it will
37
+ be transformed to a list containing itself.
38
+ img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
39
+ should have the same shape. If the input is an ndarray, it will
40
+ be transformed to a list containing itself.
41
+ gt_patch_size (int): GT patch size.
42
+ scale (int): Scale factor.
43
+ gt_path (str): Path to ground-truth. Default: None.
44
+
45
+ Returns:
46
+ list[ndarray] | ndarray: GT images and LQ images. If returned results
47
+ only have one element, just return ndarray.
48
+ """
49
+
50
+ if not isinstance(img_gts, list):
51
+ img_gts = [img_gts]
52
+ if not isinstance(img_lqs, list):
53
+ img_lqs = [img_lqs]
54
+
55
+ # determine input type: Numpy array or Tensor
56
+ input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
57
+
58
+ if input_type == 'Tensor':
59
+ h_lq, w_lq = img_lqs[0].size()[-2:]
60
+ h_gt, w_gt = img_gts[0].size()[-2:]
61
+ else:
62
+ h_lq, w_lq = img_lqs[0].shape[0:2]
63
+ h_gt, w_gt = img_gts[0].shape[0:2]
64
+ lq_patch_size = gt_patch_size // scale
65
+
66
+ if h_gt != h_lq * scale or w_gt != w_lq * scale:
67
+ raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
68
+ f'multiplication of LQ ({h_lq}, {w_lq}).')
69
+ if h_lq < lq_patch_size or w_lq < lq_patch_size:
70
+ raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
71
+ f'({lq_patch_size}, {lq_patch_size}). '
72
+ f'Please remove {gt_path}.')
73
+
74
+ # randomly choose top and left coordinates for lq patch
75
+ top = random.randint(0, h_lq - lq_patch_size)
76
+ left = random.randint(0, w_lq - lq_patch_size)
77
+
78
+ # crop lq patch
79
+ if input_type == 'Tensor':
80
+ img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
81
+ else:
82
+ img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
83
+
84
+ # crop corresponding gt patch
85
+ top_gt, left_gt = int(top * scale), int(left * scale)
86
+ if input_type == 'Tensor':
87
+ img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
88
+ else:
89
+ img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
90
+ if len(img_gts) == 1:
91
+ img_gts = img_gts[0]
92
+ if len(img_lqs) == 1:
93
+ img_lqs = img_lqs[0]
94
+ return img_gts, img_lqs
95
+
96
+
97
+ def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
98
+ """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
99
+
100
+ We use vertical flip and transpose for rotation implementation.
101
+ All the images in the list use the same augmentation.
102
+
103
+ Args:
104
+ imgs (list[ndarray] | ndarray): Images to be augmented. If the input
105
+ is an ndarray, it will be transformed to a list.
106
+ hflip (bool): Horizontal flip. Default: True.
107
+ rotation (bool): Ratotation. Default: True.
108
+ flows (list[ndarray]: Flows to be augmented. If the input is an
109
+ ndarray, it will be transformed to a list.
110
+ Dimension is (h, w, 2). Default: None.
111
+ return_status (bool): Return the status of flip and rotation.
112
+ Default: False.
113
+
114
+ Returns:
115
+ list[ndarray] | ndarray: Augmented images and flows. If returned
116
+ results only have one element, just return ndarray.
117
+
118
+ """
119
+ hflip = hflip and random.random() < 0.5
120
+ vflip = rotation and random.random() < 0.5
121
+ rot90 = rotation and random.random() < 0.5
122
+
123
+ def _augment(img):
124
+ if hflip: # horizontal
125
+ cv2.flip(img, 1, img)
126
+ if vflip: # vertical
127
+ cv2.flip(img, 0, img)
128
+ if rot90:
129
+ img = img.transpose(1, 0, 2)
130
+ return img
131
+
132
+ def _augment_flow(flow):
133
+ if hflip: # horizontal
134
+ cv2.flip(flow, 1, flow)
135
+ flow[:, :, 0] *= -1
136
+ if vflip: # vertical
137
+ cv2.flip(flow, 0, flow)
138
+ flow[:, :, 1] *= -1
139
+ if rot90:
140
+ flow = flow.transpose(1, 0, 2)
141
+ flow = flow[:, :, [1, 0]]
142
+ return flow
143
+
144
+ if not isinstance(imgs, list):
145
+ imgs = [imgs]
146
+ imgs = [_augment(img) for img in imgs]
147
+ if len(imgs) == 1:
148
+ imgs = imgs[0]
149
+
150
+ if flows is not None:
151
+ if not isinstance(flows, list):
152
+ flows = [flows]
153
+ flows = [_augment_flow(flow) for flow in flows]
154
+ if len(flows) == 1:
155
+ flows = flows[0]
156
+ return imgs, flows
157
+ else:
158
+ if return_status:
159
+ return imgs, (hflip, vflip, rot90)
160
+ else:
161
+ return imgs
162
+
163
+
164
+ def img_rotate(img, angle, center=None, scale=1.0, borderMode=cv2.BORDER_CONSTANT, borderValue=0.):
165
+ """Rotate image.
166
+
167
+ Args:
168
+ img (ndarray): Image to be rotated.
169
+ angle (float): Rotation angle in degrees. Positive values mean
170
+ counter-clockwise rotation.
171
+ center (tuple[int]): Rotation center. If the center is None,
172
+ initialize it as the center of the image. Default: None.
173
+ scale (float): Isotropic scale factor. Default: 1.0.
174
+ """
175
+ (h, w) = img.shape[:2]
176
+
177
+ if center is None:
178
+ center = (w // 2, h // 2)
179
+
180
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
181
+ rotated_img = cv2.warpAffine(img, matrix, (w, h), borderMode=borderMode, borderValue=borderValue)
182
+ return rotated_img
183
+
184
+
185
+ def rgb2lab(img_rgb):
186
+ img_lab = color.rgb2lab(img_rgb)
187
+ img_l = img_lab[:, :, :1]
188
+ img_ab = img_lab[:, :, 1:]
189
+ return img_l, img_ab
190
+
191
+
192
+
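
A sketch (not part of the commit) chaining paired_random_crop and augment on placeholder arrays:

    import numpy as np
    from basicsr.data.transforms import paired_random_crop, augment

    img_gt = np.random.rand(128, 128, 3).astype(np.float32)    # GT at 4x the LQ resolution
    img_lq = np.random.rand(32, 32, 3).astype(np.float32)
    gt_patch, lq_patch = paired_random_crop(img_gt, img_lq, gt_patch_size=64, scale=4)
    gt_patch, lq_patch = augment([gt_patch, lq_patch], hflip=True, rotation=True)
    print(gt_patch.shape, lq_patch.shape)                       # (64, 64, 3) (16, 16, 3)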
basicsr/losses/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ from copy import deepcopy
2
+
3
+ from basicsr.utils import get_root_logger
4
+ from basicsr.utils.registry import LOSS_REGISTRY
5
+ from .losses import (CharbonnierLoss, GANLoss, L1Loss, MSELoss, PerceptualLoss, WeightedTVLoss, g_path_regularize,
6
+ gradient_penalty_loss, r1_penalty)
7
+
8
+ __all__ = [
9
+ 'L1Loss', 'MSELoss', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
10
+ 'r1_penalty', 'g_path_regularize'
11
+ ]
12
+
13
+
14
+ def build_loss(opt):
15
+ """Build loss from options.
16
+
17
+ Args:
18
+ opt (dict): Configuration. It must contain:
19
+ type (str): Model type.
20
+ """
21
+ opt = deepcopy(opt)
22
+ loss_type = opt.pop('type')
23
+ loss = LOSS_REGISTRY.get(loss_type)(**opt)
24
+ logger = get_root_logger()
25
+ logger.info(f'Loss [{loss.__class__.__name__}] is created.')
26
+ return loss
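
A sketch (not part of the commit) of build_loss with the L1Loss defined in losses.py:

    import torch
    from basicsr.losses import build_loss

    cri_pix = build_loss({'type': 'L1Loss', 'loss_weight': 1.0, 'reduction': 'mean'})
    loss = cri_pix(torch.rand(1, 2, 8, 8), torch.rand(1, 2, 8, 8))
    print(loss)                  # scalar tensor: mean absolute error scaled by loss_weight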
basicsr/losses/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (1.02 kB).
basicsr/losses/__pycache__/loss_util.cpython-39.pyc ADDED
Binary file (2.7 kB).
basicsr/losses/__pycache__/losses.cpython-39.pyc ADDED
Binary file (17.9 kB).
basicsr/losses/loss_util.py ADDED
1
+ import functools
2
+ from torch.nn import functional as F
3
+
4
+
5
+ def reduce_loss(loss, reduction):
6
+ """Reduce loss as specified.
7
+
8
+ Args:
9
+ loss (Tensor): Elementwise loss tensor.
10
+ reduction (str): Options are 'none', 'mean' and 'sum'.
11
+
12
+ Returns:
13
+ Tensor: Reduced loss tensor.
14
+ """
15
+ reduction_enum = F._Reduction.get_enum(reduction)
16
+ # none: 0, elementwise_mean:1, sum: 2
17
+ if reduction_enum == 0:
18
+ return loss
19
+ elif reduction_enum == 1:
20
+ return loss.mean()
21
+ else:
22
+ return loss.sum()
23
+
24
+
25
+ def weight_reduce_loss(loss, weight=None, reduction='mean'):
26
+ """Apply element-wise weight and reduce loss.
27
+
28
+ Args:
29
+ loss (Tensor): Element-wise loss.
30
+ weight (Tensor): Element-wise weights. Default: None.
31
+ reduction (str): Same as built-in losses of PyTorch. Options are
32
+ 'none', 'mean' and 'sum'. Default: 'mean'.
33
+
34
+ Returns:
35
+ Tensor: Loss values.
36
+ """
37
+ # if weight is specified, apply element-wise weight
38
+ if weight is not None:
39
+ assert weight.dim() == loss.dim()
40
+ assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
41
+ loss = loss * weight
42
+
43
+ # if weight is not specified or reduction is sum, just reduce the loss
44
+ if weight is None or reduction == 'sum':
45
+ loss = reduce_loss(loss, reduction)
46
+ # if reduction is mean, then compute mean over weight region
47
+ elif reduction == 'mean':
48
+ if weight.size(1) > 1:
49
+ weight = weight.sum()
50
+ else:
51
+ weight = weight.sum() * loss.size(1)
52
+ loss = loss.sum() / weight
53
+
54
+ return loss
55
+
56
+
57
+ def weighted_loss(loss_func):
58
+ """Create a weighted version of a given loss function.
59
+
60
+ To use this decorator, the loss function must have the signature like
61
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
62
+ element-wise loss without any reduction. This decorator will add weight
63
+ and reduction arguments to the function. The decorated function will have
64
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
65
+ **kwargs)`.
66
+
67
+ :Example:
68
+
69
+ >>> import torch
70
+ >>> @weighted_loss
71
+ >>> def l1_loss(pred, target):
72
+ >>> return (pred - target).abs()
73
+
74
+ >>> pred = torch.Tensor([0, 2, 3])
75
+ >>> target = torch.Tensor([1, 1, 1])
76
+ >>> weight = torch.Tensor([1, 0, 1])
77
+
78
+ >>> l1_loss(pred, target)
79
+ tensor(1.3333)
80
+ >>> l1_loss(pred, target, weight)
81
+ tensor(1.5000)
82
+ >>> l1_loss(pred, target, reduction='none')
83
+ tensor([1., 1., 2.])
84
+ >>> l1_loss(pred, target, weight, reduction='sum')
85
+ tensor(3.)
86
+ """
87
+
88
+ @functools.wraps(loss_func)
89
+ def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
90
+ # get element-wise loss
91
+ loss = loss_func(pred, target, **kwargs)
92
+ loss = weight_reduce_loss(loss, weight, reduction)
93
+ return loss
94
+
95
+ return wrapper
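The weighted_loss decorator above is what gives the element-wise criteria in losses.py their weight/reduction arguments. A small sketch of wrapping a new element-wise criterion the same way; the masked_charbonnier name and the mask values are illustrative only:

    import torch
    from basicsr.losses.loss_util import weighted_loss

    @weighted_loss
    def masked_charbonnier(pred, target, eps=1e-6):
        # element-wise only; weight and reduction are applied by the decorator
        return torch.sqrt((pred - target) ** 2 + eps)

    pred = torch.rand(2, 3, 8, 8)
    target = torch.rand(2, 3, 8, 8)
    # per-pixel weights with a single channel, broadcast over C in the decorator
    mask = (torch.rand(2, 1, 8, 8) > 0.5).float()
    print(masked_charbonnier(pred, target, weight=mask, reduction='mean'))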
basicsr/losses/losses.py ADDED
@@ -0,0 +1,551 @@
1
+ import math
2
+ import torch
3
+ from torch import autograd as autograd
4
+ from torch import nn as nn
5
+ from torch.nn import functional as F
6
+
7
+ from basicsr.archs.vgg_arch import VGGFeatureExtractor
8
+ from basicsr.utils.registry import LOSS_REGISTRY
9
+ from .loss_util import weighted_loss
10
+
11
+ _reduction_modes = ['none', 'mean', 'sum']
12
+
13
+
14
+ @weighted_loss
15
+ def l1_loss(pred, target):
16
+ return F.l1_loss(pred, target, reduction='none')
17
+
18
+
19
+ @weighted_loss
20
+ def mse_loss(pred, target):
21
+ return F.mse_loss(pred, target, reduction='none')
22
+
23
+
24
+ @weighted_loss
25
+ def charbonnier_loss(pred, target, eps=1e-12):
26
+ return torch.sqrt((pred - target)**2 + eps)
27
+
28
+
29
+ @LOSS_REGISTRY.register()
30
+ class L1Loss(nn.Module):
31
+ """L1 (mean absolute error, MAE) loss.
32
+
33
+ Args:
34
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
35
+ reduction (str): Specifies the reduction to apply to the output.
36
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
37
+ """
38
+
39
+ def __init__(self, loss_weight=1.0, reduction='mean'):
40
+ super(L1Loss, self).__init__()
41
+ if reduction not in ['none', 'mean', 'sum']:
42
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
43
+
44
+ self.loss_weight = loss_weight
45
+ self.reduction = reduction
46
+
47
+ def forward(self, pred, target, weight=None, **kwargs):
48
+ """
49
+ Args:
50
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
51
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
52
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
53
+ weights. Default: None.
54
+ """
55
+ return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction)
56
+
57
+
58
+ @LOSS_REGISTRY.register()
59
+ class MSELoss(nn.Module):
60
+ """MSE (L2) loss.
61
+
62
+ Args:
63
+ loss_weight (float): Loss weight for MSE loss. Default: 1.0.
64
+ reduction (str): Specifies the reduction to apply to the output.
65
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
66
+ """
67
+
68
+ def __init__(self, loss_weight=1.0, reduction='mean'):
69
+ super(MSELoss, self).__init__()
70
+ if reduction not in ['none', 'mean', 'sum']:
71
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
72
+
73
+ self.loss_weight = loss_weight
74
+ self.reduction = reduction
75
+
76
+ def forward(self, pred, target, weight=None, **kwargs):
77
+ """
78
+ Args:
79
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
80
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
81
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
82
+ weights. Default: None.
83
+ """
84
+ return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction)
85
+
86
+
87
+ @LOSS_REGISTRY.register()
88
+ class CharbonnierLoss(nn.Module):
89
+ """Charbonnier loss (one variant of Robust L1Loss, a differentiable
90
+ variant of L1Loss).
91
+
92
+ Described in "Deep Laplacian Pyramid Networks for Fast and Accurate
93
+ Super-Resolution".
94
+
95
+ Args:
96
+ loss_weight (float): Loss weight for L1 loss. Default: 1.0.
97
+ reduction (str): Specifies the reduction to apply to the output.
98
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
99
+ eps (float): A value used to control the curvature near zero.
100
+ Default: 1e-12.
101
+ """
102
+
103
+ def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12):
104
+ super(CharbonnierLoss, self).__init__()
105
+ if reduction not in ['none', 'mean', 'sum']:
106
+ raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}')
107
+
108
+ self.loss_weight = loss_weight
109
+ self.reduction = reduction
110
+ self.eps = eps
111
+
112
+ def forward(self, pred, target, weight=None, **kwargs):
113
+ """
114
+ Args:
115
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
116
+ target (Tensor): of shape (N, C, H, W). Ground truth tensor.
117
+ weight (Tensor, optional): of shape (N, C, H, W). Element-wise
118
+ weights. Default: None.
119
+ """
120
+ return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction)
121
+
122
+
123
+ @LOSS_REGISTRY.register()
124
+ class WeightedTVLoss(L1Loss):
125
+ """Weighted TV loss.
126
+
127
+ Args:
128
+ loss_weight (float): Loss weight. Default: 1.0.
129
+ """
130
+
131
+ def __init__(self, loss_weight=1.0):
132
+ super(WeightedTVLoss, self).__init__(loss_weight=loss_weight)
133
+
134
+ def forward(self, pred, weight=None):
135
+ if weight is None:
136
+ y_weight = None
137
+ x_weight = None
138
+ else:
139
+ y_weight = weight[:, :, :-1, :]
140
+ x_weight = weight[:, :, :, :-1]
141
+
142
+ y_diff = super(WeightedTVLoss, self).forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight)
143
+ x_diff = super(WeightedTVLoss, self).forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight)
144
+
145
+ loss = x_diff + y_diff
146
+
147
+ return loss
148
+
149
+
150
+ @LOSS_REGISTRY.register()
151
+ class PerceptualLoss(nn.Module):
152
+ """Perceptual loss with commonly used style loss.
153
+
154
+ Args:
155
+ layer_weights (dict): The weight for each layer of vgg feature.
156
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
157
+ feature layer (before relu5_4) will be extracted with weight
158
+ 1.0 in calculating losses.
159
+ vgg_type (str): The type of vgg network used as feature extractor.
160
+ Default: 'vgg19'.
161
+ use_input_norm (bool): If True, normalize the input image in vgg.
162
+ Default: True.
163
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
164
+ Default: False.
165
+ perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
166
+ loss will be calculated and the loss will be multiplied by the
167
+ weight. Default: 1.0.
168
+ style_weight (float): If `style_weight > 0`, the style loss will be
169
+ calculated and the loss will be multiplied by the weight.
170
+ Default: 0.
171
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
172
+ """
173
+
174
+ def __init__(self,
175
+ layer_weights,
176
+ vgg_type='vgg19',
177
+ use_input_norm=True,
178
+ range_norm=False,
179
+ perceptual_weight=1.0,
180
+ style_weight=0.,
181
+ criterion='l1'):
182
+ super(PerceptualLoss, self).__init__()
183
+ self.perceptual_weight = perceptual_weight
184
+ self.style_weight = style_weight
185
+ self.layer_weights = layer_weights
186
+ self.vgg = VGGFeatureExtractor(
187
+ layer_name_list=list(layer_weights.keys()),
188
+ vgg_type=vgg_type,
189
+ use_input_norm=use_input_norm,
190
+ range_norm=range_norm)
191
+
192
+ self.criterion_type = criterion
193
+ if self.criterion_type == 'l1':
194
+ self.criterion = torch.nn.L1Loss()
195
+ elif self.criterion_type == 'l2':
196
+ self.criterion = torch.nn.MSELoss()
197
+ elif self.criterion_type == 'fro':
198
+ self.criterion = None
199
+ else:
200
+ raise NotImplementedError(f'{criterion} criterion has not been supported.')
201
+
202
+ def forward(self, x, gt):
203
+ """Forward function.
204
+
205
+ Args:
206
+ x (Tensor): Input tensor with shape (n, c, h, w).
207
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
208
+
209
+ Returns:
210
+ Tensor: Forward results.
211
+ """
212
+ # extract vgg features
213
+ x_features = self.vgg(x)
214
+ gt_features = self.vgg(gt.detach())
215
+
216
+ # calculate perceptual loss
217
+ if self.perceptual_weight > 0:
218
+ percep_loss = 0
219
+ for k in x_features.keys():
220
+ if self.criterion_type == 'fro':
221
+ percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
222
+ else:
223
+ percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
224
+ percep_loss *= self.perceptual_weight
225
+ else:
226
+ percep_loss = None
227
+
228
+ # calculate style loss
229
+ if self.style_weight > 0:
230
+ style_loss = 0
231
+ for k in x_features.keys():
232
+ if self.criterion_type == 'fro':
233
+ style_loss += torch.norm(
234
+ self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
235
+ else:
236
+ style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
237
+ gt_features[k])) * self.layer_weights[k]
238
+ style_loss *= self.style_weight
239
+ else:
240
+ style_loss = None
241
+
242
+ return percep_loss, style_loss
243
+
244
+ def _gram_mat(self, x):
245
+ """Calculate Gram matrix.
246
+
247
+ Args:
248
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
249
+
250
+ Returns:
251
+ torch.Tensor: Gram matrix.
252
+ """
253
+ n, c, h, w = x.size()
254
+ features = x.view(n, c, w * h)
255
+ features_t = features.transpose(1, 2)
256
+ gram = features.bmm(features_t) / (c * h * w)
257
+ return gram
258
+
259
+
260
+ @LOSS_REGISTRY.register()
261
+ class GANLoss(nn.Module):
262
+ """Define GAN loss.
263
+
264
+ Args:
265
+ gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'wgan_softplus', 'hinge'.
266
+ real_label_val (float): The value for real label. Default: 1.0.
267
+ fake_label_val (float): The value for fake label. Default: 0.0.
268
+ loss_weight (float): Loss weight. Default: 1.0.
269
+ Note that loss_weight is only for generators; it is always 1.0
270
+ for discriminators.
271
+ """
272
+
273
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
274
+ super(GANLoss, self).__init__()
275
+ self.gan_type = gan_type
276
+ self.loss_weight = loss_weight
277
+ self.real_label_val = real_label_val
278
+ self.fake_label_val = fake_label_val
279
+
280
+ if self.gan_type == 'vanilla':
281
+ self.loss = nn.BCEWithLogitsLoss()
282
+ elif self.gan_type == 'lsgan':
283
+ self.loss = nn.MSELoss()
284
+ elif self.gan_type == 'wgan':
285
+ self.loss = self._wgan_loss
286
+ elif self.gan_type == 'wgan_softplus':
287
+ self.loss = self._wgan_softplus_loss
288
+ elif self.gan_type == 'hinge':
289
+ self.loss = nn.ReLU()
290
+ else:
291
+ raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
292
+
293
+ def _wgan_loss(self, input, target):
294
+ """wgan loss.
295
+
296
+ Args:
297
+ input (Tensor): Input tensor.
298
+ target (bool): Target label.
299
+
300
+ Returns:
301
+ Tensor: wgan loss.
302
+ """
303
+ return -input.mean() if target else input.mean()
304
+
305
+ def _wgan_softplus_loss(self, input, target):
306
+ """wgan loss with soft plus. softplus is a smooth approximation to the
307
+ ReLU function.
308
+
309
+ In StyleGAN2, it is called:
310
+ Logistic loss for discriminator;
311
+ Non-saturating loss for generator.
312
+
313
+ Args:
314
+ input (Tensor): Input tensor.
315
+ target (bool): Target label.
316
+
317
+ Returns:
318
+ Tensor: wgan loss.
319
+ """
320
+ return F.softplus(-input).mean() if target else F.softplus(input).mean()
321
+
322
+ def get_target_label(self, input, target_is_real):
323
+ """Get target label.
324
+
325
+ Args:
326
+ input (Tensor): Input tensor.
327
+ target_is_real (bool): Whether the target is real or fake.
328
+
329
+ Returns:
330
+ (bool | Tensor): Target tensor. Return bool for wgan, otherwise,
331
+ return Tensor.
332
+ """
333
+
334
+ if self.gan_type in ['wgan', 'wgan_softplus']:
335
+ return target_is_real
336
+ target_val = (self.real_label_val if target_is_real else self.fake_label_val)
337
+ return input.new_ones(input.size()) * target_val
338
+
339
+ def forward(self, input, target_is_real, is_disc=False):
340
+ """
341
+ Args:
342
+ input (Tensor): The input for the loss module, i.e., the network
343
+ prediction.
344
+ target_is_real (bool): Whether the target is real or fake.
345
+ is_disc (bool): Whether the loss for discriminators or not.
346
+ Default: False.
347
+
348
+ Returns:
349
+ Tensor: GAN loss value.
350
+ """
351
+ target_label = self.get_target_label(input, target_is_real)
352
+ if self.gan_type == 'hinge':
353
+ if is_disc: # for discriminators in hinge-gan
354
+ input = -input if target_is_real else input
355
+ loss = self.loss(1 + input).mean()
356
+ else: # for generators in hinge-gan
357
+ loss = -input.mean()
358
+ else: # other gan types
359
+ loss = self.loss(input, target_label)
360
+
361
+ # loss_weight is always 1.0 for discriminators
362
+ return loss if is_disc else loss * self.loss_weight
363
+
364
+
365
+ @LOSS_REGISTRY.register()
366
+ class MultiScaleGANLoss(GANLoss):
367
+ """
368
+ MultiScaleGANLoss accepts a list of predictions
369
+ """
370
+
371
+ def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
372
+ super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight)
373
+
374
+ def forward(self, input, target_is_real, is_disc=False):
375
+ """
376
+ The input is a list of tensors, or a list of (a list of tensors)
377
+ """
378
+ if isinstance(input, list):
379
+ loss = 0
380
+ for pred_i in input:
381
+ if isinstance(pred_i, list):
382
+ # Only compute GAN loss for the last layer
383
+ # in case of multiscale feature matching
384
+ pred_i = pred_i[-1]
385
+ # Safe operation: calling .mean() on a 0-dim tensor is a no-op
386
+ loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean()
387
+ loss += loss_tensor
388
+ return loss / len(input)
389
+ else:
390
+ return super().forward(input, target_is_real, is_disc)
391
+
392
+
393
+ def r1_penalty(real_pred, real_img):
394
+ """R1 regularization for discriminator. The core idea is to
395
+ penalize the gradient on real data alone: when the
396
+ generator distribution produces the true data distribution
397
+ and the discriminator is equal to 0 on the data manifold, the
398
+ gradient penalty ensures that the discriminator cannot create
399
+ a non-zero gradient orthogonal to the data manifold without
400
+ suffering a loss in the GAN game.
401
+
402
+ Ref:
403
+ Eq. 9 in Which training methods for GANs do actually converge.
404
+ """
405
+ grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
406
+ grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
407
+ return grad_penalty
408
+
409
+
410
+ def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01):
411
+ noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
412
+ grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0]
413
+ path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1))
414
+
415
+ path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length)
416
+
417
+ path_penalty = (path_lengths - path_mean).pow(2).mean()
418
+
419
+ return path_penalty, path_lengths.detach().mean(), path_mean.detach()
420
+
421
+
422
+ def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
423
+ """Calculate gradient penalty for wgan-gp.
424
+
425
+ Args:
426
+ discriminator (nn.Module): Network for the discriminator.
427
+ real_data (Tensor): Real input data.
428
+ fake_data (Tensor): Fake input data.
429
+ weight (Tensor): Weight tensor. Default: None.
430
+
431
+ Returns:
432
+ Tensor: A tensor for gradient penalty.
433
+ """
434
+
435
+ batch_size = real_data.size(0)
436
+ alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
437
+
438
+ # interpolate between real_data and fake_data
439
+ interpolates = alpha * real_data + (1. - alpha) * fake_data
440
+ interpolates = autograd.Variable(interpolates, requires_grad=True)
441
+
442
+ disc_interpolates = discriminator(interpolates)
443
+ gradients = autograd.grad(
444
+ outputs=disc_interpolates,
445
+ inputs=interpolates,
446
+ grad_outputs=torch.ones_like(disc_interpolates),
447
+ create_graph=True,
448
+ retain_graph=True,
449
+ only_inputs=True)[0]
450
+
451
+ if weight is not None:
452
+ gradients = gradients * weight
453
+
454
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
455
+ if weight is not None:
456
+ gradients_penalty /= torch.mean(weight)
457
+
458
+ return gradients_penalty
459
+
460
+
461
+ @LOSS_REGISTRY.register()
462
+ class GANFeatLoss(nn.Module):
463
+ """Define feature matching loss for gans
464
+
465
+ Args:
466
+ criterion (str): Support 'l1', 'l2', 'charbonnier'.
467
+ loss_weight (float): Loss weight. Default: 1.0.
468
+ reduction (str): Specifies the reduction to apply to the output.
469
+ Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
470
+ """
471
+
472
+ def __init__(self, criterion='l1', loss_weight=1.0, reduction='mean'):
473
+ super(GANFeatLoss, self).__init__()
474
+ if criterion == 'l1':
475
+ self.loss_op = L1Loss(loss_weight, reduction)
476
+ elif criterion == 'l2':
477
+ self.loss_op = MSELoss(loss_weight, reduction)
478
+ elif criterion == 'charbonnier':
479
+ self.loss_op = CharbonnierLoss(loss_weight, reduction)
480
+ else:
481
+ raise ValueError(f'Unsupported loss mode: {criterion}. Supported ones are: l1|l2|charbonnier')
482
+
483
+ self.loss_weight = loss_weight
484
+
485
+ def forward(self, pred_fake, pred_real):
486
+ num_d = len(pred_fake)
487
+ loss = 0
488
+ for i in range(num_d): # for each discriminator
489
+ # last output is the final prediction, exclude it
490
+ num_intermediate_outputs = len(pred_fake[i]) - 1
491
+ for j in range(num_intermediate_outputs): # for each layer output
492
+ unweighted_loss = self.loss_op(pred_fake[i][j], pred_real[i][j].detach())
493
+ loss += unweighted_loss / num_d
494
+ return loss * self.loss_weight
495
+
496
+
497
+ class sobel_loss(nn.Module):
498
+ def __init__(self, weight=1.0):
499
+ super().__init__()
500
+ kernel_x = torch.Tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
501
+ kernel_y = torch.Tensor([[-1.0, -2.0, -1.0], [0.0, 0.0, 0.0], [1.0, 2.0, 1.0]])
502
+ kernel = torch.stack([kernel_x, kernel_y])
503
+ kernel.requires_grad = False
504
+ kernel = kernel.unsqueeze(1)
505
+ self.register_buffer('sobel_kernel', kernel)
506
+ self.weight = weight
507
+
508
+ def forward(self, input_tensor, target_tensor):
509
+ b, c, h, w = input_tensor.size()
510
+ input_tensor = input_tensor.view(b * c, 1, h, w)
511
+ input_edge = F.conv2d(input_tensor, self.sobel_kernel, padding=1)
512
+ input_edge = input_edge.view(b, 2*c, h, w)
513
+
514
+ target_tensor = target_tensor.view(-1, 1, h, w)
515
+ target_edge = F.conv2d(target_tensor, self.sobel_kernel, padding=1)
516
+ target_edge = target_edge.view(b, 2*c, h, w)
517
+
518
+ return self.weight * F.l1_loss(input_edge, target_edge)
519
+
520
+
521
+ @LOSS_REGISTRY.register()
522
+ class ColorfulnessLoss(nn.Module):
523
+ """Colorfulness loss.
524
+
525
+ Args:
526
+ loss_weight (float): Loss weight for Colorfulness loss. Default: 1.0.
527
+
528
+ """
529
+
530
+ def __init__(self, loss_weight=1.0):
531
+ super(ColorfulnessLoss, self).__init__()
532
+
533
+ self.loss_weight = loss_weight
534
+
535
+ def forward(self, pred, **kwargs):
536
+ """
537
+ Args:
538
+ pred (Tensor): of shape (N, C, H, W). Predicted tensor.
539
+ """
540
+ colorfulness_loss = 0
541
+ for i in range(pred.shape[0]):
542
+ (R, G, B) = pred[i][0], pred[i][1], pred[i][2]
543
+ rg = torch.abs(R - G)
544
+ yb = torch.abs(0.5 * (R+G) - B)
545
+ (rbMean, rbStd) = (torch.mean(rg), torch.std(rg))
546
+ (ybMean, ybStd) = (torch.mean(yb), torch.std(yb))
547
+ stdRoot = torch.sqrt((rbStd ** 2) + (ybStd ** 2))
548
+ meanRoot = torch.sqrt((rbMean ** 2) + (ybMean ** 2))
549
+ colorfulness = stdRoot + (0.3 * meanRoot)
550
+ colorfulness_loss += (1 - colorfulness)
551
+ return self.loss_weight * colorfulness_loss
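Of the classes above, GANLoss has the easiest calling convention to get wrong, since the same module serves both the generator and discriminator steps. A brief sketch with random tensors standing in for discriminator logits; the shapes and the 0.1 weight are illustrative:

    import torch
    from basicsr.losses.losses import GANLoss

    cri_gan = GANLoss(gan_type='vanilla', real_label_val=1.0,
                      fake_label_val=0.0, loss_weight=0.1)

    real_pred = torch.randn(4, 1)  # stand-in for D(real)
    fake_pred = torch.randn(4, 1)  # stand-in for D(fake)

    # generator step: push fake predictions towards the real label; loss_weight applies
    l_g_gan = cri_gan(fake_pred, target_is_real=True, is_disc=False)

    # discriminator step: is_disc=True, so loss_weight is ignored (always 1.0)
    l_d_real = cri_gan(real_pred, target_is_real=True, is_disc=True)
    l_d_fake = cri_gan(fake_pred.detach(), target_is_real=False, is_disc=True)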
basicsr/metrics/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ from copy import deepcopy
2
+
3
+ from basicsr.utils.registry import METRIC_REGISTRY
4
+ from .psnr_ssim import calculate_psnr, calculate_ssim
5
+ from .colorfulness import calculate_cf
6
+
7
+ __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_cf']
8
+
9
+
10
+ def calculate_metric(data, opt):
11
+ """Calculate metric from data and options.
12
+
13
+ Args:
14
+ opt (dict): Configuration. It must contain:
15
+ type (str): Metric type.
16
+ """
17
+ opt = deepcopy(opt)
18
+ metric_type = opt.pop('type')
19
+ metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
20
+ return metric
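calculate_metric mirrors build_loss: 'type' selects the registered function, and the remaining option keys, together with the entries of data, are forwarded as keyword arguments. A small sketch with synthetic uint8 images; the keyword names img1/img2 and crop_border are assumptions based on common BasicSR versions of calculate_psnr, whose signature in psnr_ssim.py is not shown in this diff:

    import numpy as np
    from basicsr.metrics import calculate_metric

    img1 = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    img2 = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)

    # the keys of `data` must match the argument names of calculate_psnr
    data = {'img1': img1, 'img2': img2}
    opt = {'type': 'calculate_psnr', 'crop_border': 0}
    print(calculate_metric(data, opt))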
basicsr/metrics/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (776 Bytes). View file