import json from collections import OrderedDict from math import exp from .Common import * # +++++++++++++++++++++++++++++++++++++ # FP16 Training # ------------------------------------- # Modified from Nvidia/Apex # https://github.com/NVIDIA/apex/blob/master/apex/fp16_utils/fp16util.py class tofp16(nn.Module): def __init__(self): super(tofp16, self).__init__() def forward(self, input): if input.is_cuda: return input.half() else: # PyTorch 1.0 doesn't support fp16 in CPU return input.float() def BN_convert_float(module): if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): module.float() for child in module.children(): BN_convert_float(child) return module def network_to_half(network): return nn.Sequential(tofp16(), BN_convert_float(network.half())) # warnings.simplefilter('ignore') # +++++++++++++++++++++++++++++++++++++ # DCSCN # ------------------------------------- class DCSCN(BaseModule): # https://github.com/jiny2001/dcscn-super-resolution def __init__( self, color_channel=3, up_scale=2, feature_layers=12, first_feature_filters=196, last_feature_filters=48, reconstruction_filters=128, up_sampler_filters=32, ): super(DCSCN, self).__init__() self.total_feature_channels = 0 self.total_reconstruct_filters = 0 self.upscale = up_scale self.act_fn = nn.SELU(inplace=False) self.feature_block = self.make_feature_extraction_block( color_channel, feature_layers, first_feature_filters, last_feature_filters ) self.reconstruction_block = self.make_reconstruction_block( reconstruction_filters ) self.up_sampler = self.make_upsampler(up_sampler_filters, color_channel) self.selu_init_params() def selu_init_params(self): for i in self.modules(): if isinstance(i, nn.Conv2d): i.weight.data.normal_(0.0, 1.0 / sqrt(i.weight.numel())) if i.bias is not None: i.bias.data.fill_(0) def conv_block(self, in_channel, out_channel, kernel_size): m = OrderedDict( [ # ("Padding", nn.ReplicationPad2d((kernel_size - 1) // 2)), ( "Conv2d", nn.Conv2d( in_channel, out_channel, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, ), ), ("Activation", self.act_fn), ] ) return nn.Sequential(m) def make_feature_extraction_block( self, color_channel, num_layers, first_filters, last_filters ): # input layer feature_block = [ ("Feature 1", self.conv_block(color_channel, first_filters, 3)) ] # exponential decay # rest layers alpha_rate = log(first_filters / last_filters) / (num_layers - 1) filter_nums = [ round(first_filters * exp(-alpha_rate * i)) for i in range(num_layers) ] self.total_feature_channels = sum(filter_nums) layer_filters = [ [filter_nums[i], filter_nums[i + 1], 3] for i in range(num_layers - 1) ] feature_block.extend( [ ("Feature {}".format(index + 2), self.conv_block(*x)) for index, x in enumerate(layer_filters) ] ) return nn.Sequential(OrderedDict(feature_block)) def make_reconstruction_block(self, num_filters): B1 = self.conv_block(self.total_feature_channels, num_filters // 2, 1) B2 = self.conv_block(num_filters // 2, num_filters, 3) m = OrderedDict( [ ("A", self.conv_block(self.total_feature_channels, num_filters, 1)), ("B", nn.Sequential(*[B1, B2])), ] ) self.total_reconstruct_filters = num_filters * 2 return nn.Sequential(m) def make_upsampler(self, out_channel, color_channel): out = out_channel * self.upscale**2 m = OrderedDict( [ ( "Conv2d_block", self.conv_block(self.total_reconstruct_filters, out, kernel_size=3), ), ("PixelShuffle", nn.PixelShuffle(self.upscale)), ( "Conv2d", nn.Conv2d( out_channel, color_channel, kernel_size=3, padding=1, bias=False ), ), ] ) return nn.Sequential(m) def forward(self, x): # residual learning lr, lr_up = x feature = [] for layer in self.feature_block.children(): lr = layer(lr) feature.append(lr) feature = torch.cat(feature, dim=1) reconstruction = [ layer(feature) for layer in self.reconstruction_block.children() ] reconstruction = torch.cat(reconstruction, dim=1) lr = self.up_sampler(reconstruction) return lr + lr_up # +++++++++++++++++++++++++++++++++++++ # CARN # ------------------------------------- class CARN_Block(BaseModule): def __init__( self, channels, kernel_size=3, padding=1, dilation=1, groups=1, activation=nn.SELU(), repeat=3, SEBlock=False, conv=nn.Conv2d, single_conv_size=1, single_conv_group=1, ): super(CARN_Block, self).__init__() m = [] for i in range(repeat): m.append( ResidualFixBlock( channels, channels, kernel_size=kernel_size, padding=padding, dilation=dilation, groups=groups, activation=activation, conv=conv, ) ) if SEBlock: m.append(SpatialChannelSqueezeExcitation(channels, reduction=channels)) self.blocks = nn.Sequential(*m) self.singles = nn.Sequential( *[ ConvBlock( channels * (i + 2), channels, kernel_size=single_conv_size, padding=(single_conv_size - 1) // 2, groups=single_conv_group, activation=activation, conv=conv, ) for i in range(repeat) ] ) def forward(self, x): c0 = x for block, single in zip(self.blocks, self.singles): b = block(x) c0 = c = torch.cat([c0, b], dim=1) x = single(c) return x class CARN(BaseModule): # Fast, Accurate, and Lightweight Super-Resolution with Cascading Residual Network # https://github.com/nmhkahn/CARN-pytorch def __init__( self, color_channels=3, mid_channels=64, scale=2, activation=nn.SELU(), num_blocks=3, conv=nn.Conv2d, ): super(CARN, self).__init__() self.color_channels = color_channels self.mid_channels = mid_channels self.scale = scale self.entry_block = ConvBlock( color_channels, mid_channels, kernel_size=3, padding=1, activation=activation, conv=conv, ) self.blocks = nn.Sequential( *[ CARN_Block( mid_channels, kernel_size=3, padding=1, activation=activation, conv=conv, single_conv_size=1, single_conv_group=1, ) for _ in range(num_blocks) ] ) self.singles = nn.Sequential( *[ ConvBlock( mid_channels * (i + 2), mid_channels, kernel_size=1, padding=0, activation=activation, conv=conv, ) for i in range(num_blocks) ] ) self.upsampler = UpSampleBlock( mid_channels, scale=scale, activation=activation, conv=conv ) self.exit_conv = conv(mid_channels, color_channels, kernel_size=3, padding=1) def forward(self, x): x = self.entry_block(x) c0 = x for block, single in zip(self.blocks, self.singles): b = block(x) c0 = c = torch.cat([c0, b], dim=1) x = single(c) x = self.upsampler(x) out = self.exit_conv(x) return out class CARN_V2(CARN): def __init__( self, color_channels=3, mid_channels=64, scale=2, activation=nn.LeakyReLU(0.1), SEBlock=True, conv=nn.Conv2d, atrous=(1, 1, 1), repeat_blocks=3, single_conv_size=3, single_conv_group=1, ): super(CARN_V2, self).__init__( color_channels=color_channels, mid_channels=mid_channels, scale=scale, activation=activation, conv=conv, ) num_blocks = len(atrous) m = [] for i in range(num_blocks): m.append( CARN_Block( mid_channels, kernel_size=3, padding=1, dilation=1, activation=activation, SEBlock=SEBlock, conv=conv, repeat=repeat_blocks, single_conv_size=single_conv_size, single_conv_group=single_conv_group, ) ) self.blocks = nn.Sequential(*m) self.singles = nn.Sequential( *[ ConvBlock( mid_channels * (i + 2), mid_channels, kernel_size=single_conv_size, padding=(single_conv_size - 1) // 2, groups=single_conv_group, activation=activation, conv=conv, ) for i in range(num_blocks) ] ) def forward(self, x): x = self.entry_block(x) c0 = x res = x for block, single in zip(self.blocks, self.singles): b = block(x) c0 = c = torch.cat([c0, b], dim=1) x = single(c) x = x + res x = self.upsampler(x) out = self.exit_conv(x) return out # +++++++++++++++++++++++++++++++++++++ # original Waifu2x model # ------------------------------------- class UpConv_7(BaseModule): # https://github.com/nagadomi/waifu2x/blob/3c46906cb78895dbd5a25c3705994a1b2e873199/lib/srcnn.lua#L311 def __init__(self): super(UpConv_7, self).__init__() self.act_fn = nn.LeakyReLU(0.1, inplace=False) self.offset = 7 # because of 0 padding from torch.nn import ZeroPad2d self.pad = ZeroPad2d(self.offset) m = [ nn.Conv2d(3, 16, 3, 1, 0), self.act_fn, nn.Conv2d(16, 32, 3, 1, 0), self.act_fn, nn.Conv2d(32, 64, 3, 1, 0), self.act_fn, nn.Conv2d(64, 128, 3, 1, 0), self.act_fn, nn.Conv2d(128, 128, 3, 1, 0), self.act_fn, nn.Conv2d(128, 256, 3, 1, 0), self.act_fn, # in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding= nn.ConvTranspose2d(256, 3, kernel_size=4, stride=2, padding=3, bias=False), ] self.Sequential = nn.Sequential(*m) def load_pre_train_weights(self, json_file): with open(json_file) as f: weights = json.load(f) box = [] for i in weights: box.append(i["weight"]) box.append(i["bias"]) own_state = self.state_dict() for index, (name, param) in enumerate(own_state.items()): own_state[name].copy_(torch.FloatTensor(box[index])) def forward(self, x): x = self.pad(x) return self.Sequential.forward(x) class Vgg_7(UpConv_7): def __init__(self): super(Vgg_7, self).__init__() self.act_fn = nn.LeakyReLU(0.1, inplace=False) self.offset = 7 m = [ nn.Conv2d(3, 32, 3, 1, 0), self.act_fn, nn.Conv2d(32, 32, 3, 1, 0), self.act_fn, nn.Conv2d(32, 64, 3, 1, 0), self.act_fn, nn.Conv2d(64, 64, 3, 1, 0), self.act_fn, nn.Conv2d(64, 128, 3, 1, 0), self.act_fn, nn.Conv2d(128, 128, 3, 1, 0), self.act_fn, nn.Conv2d(128, 3, 3, 1, 0), ] self.Sequential = nn.Sequential(*m)