KenjieDec commited on
Commit
5911164
1 Parent(s): 6ab882a
__init_paths.py CHANGED
@@ -14,5 +14,8 @@ this_dir = osp.dirname(__file__)
14
  path = osp.join(this_dir, 'retinaface')
15
  add_path(path)
16
 
 
 
 
17
  path = osp.join(this_dir, 'face_model')
18
  add_path(path)
 
14
  path = osp.join(this_dir, 'retinaface')
15
  add_path(path)
16
 
17
+ path = osp.join(this_dir, 'sr_model')
18
+ add_path(path)
19
+
20
  path = osp.join(this_dir, 'face_model')
21
  add_path(path)
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  os.system('wget "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth?OSSAccessKeyId=LTAI4G6bfnyW4TA4wFUXTYBe&Expires=1961116085&Signature=GlUNW6%2B8FxvxWmE9jKIZYOOciKQ%3D" -O weights/RetinaFace-R50.pth')
4
  os.system('wget "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512.pth?OSSAccessKeyId=LTAI4G6bfnyW4TA4wFUXTYBe&Expires=1961116208&Signature=hBgvVvKVSNGeXqT8glG%2Bd2t2OKc%3D" -O weights/GPEN-512.pth')
5
  os.system('wget "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Colorization-1024.pth?OSSAccessKeyId=LTAI4G6bfnyW4TA4wFUXTYBe&Expires=1961116315&Signature=9tPavW2h%2F1LhIKiXj73sTQoWqcc%3D" -O weights/GPEN-1024-Color.pth ')
 
6
 
7
  import gradio as gr
8
 
@@ -21,6 +22,7 @@ def inference(file, mode):
21
  if mode == "enhance":
22
  model = {'name':'GPEN-512', 'size':512}
23
  im = cv2.imread(file, cv2.IMREAD_COLOR)
 
24
  faceenhancer = FaceEnhancement(size=model['size'], model=model['name'], channel_multiplier=2, device='cpu')
25
  img, orig_faces, enhanced_faces = faceenhancer.process(im)
26
  cv2.imwrite(os.path.join("output.png"), img)
@@ -50,7 +52,8 @@ gr.Interface(
50
  description=description,
51
  article=article,
52
  examples=[
53
- ['sample.png']
 
54
  ],
55
  enable_queue=True
56
  ).launch()
 
3
  os.system('wget "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/RetinaFace-R50.pth?OSSAccessKeyId=LTAI4G6bfnyW4TA4wFUXTYBe&Expires=1961116085&Signature=GlUNW6%2B8FxvxWmE9jKIZYOOciKQ%3D" -O weights/RetinaFace-R50.pth')
4
  os.system('wget "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-BFR-512.pth?OSSAccessKeyId=LTAI4G6bfnyW4TA4wFUXTYBe&Expires=1961116208&Signature=hBgvVvKVSNGeXqT8glG%2Bd2t2OKc%3D" -O weights/GPEN-512.pth')
5
  os.system('wget "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/GPEN-Colorization-1024.pth?OSSAccessKeyId=LTAI4G6bfnyW4TA4wFUXTYBe&Expires=1961116315&Signature=9tPavW2h%2F1LhIKiXj73sTQoWqcc%3D" -O weights/GPEN-1024-Color.pth ')
6
+ os.system('wget "https://public-vigen-video.oss-cn-shanghai.aliyuncs.com/robin/models/realesrnet_x2.pth?OSSAccessKeyId=LTAI4G6bfnyW4TA4wFUXTYBe&Expires=1962694780&Signature=lI%2FolhA%2FyigiTRvoDIVbtMIyhjI%3D" -O weights/realesrnet_x2.pth ')
7
 
8
  import gradio as gr
9
 
 
22
  if mode == "enhance":
23
  model = {'name':'GPEN-512', 'size':512}
24
  im = cv2.imread(file, cv2.IMREAD_COLOR)
25
+ im = cv2.resize(im, (0,0), fx=2, fy=2)
26
  faceenhancer = FaceEnhancement(size=model['size'], model=model['name'], channel_multiplier=2, device='cpu')
27
  img, orig_faces, enhanced_faces = faceenhancer.process(im)
28
  cv2.imwrite(os.path.join("output.png"), img)
 
52
  description=description,
53
  article=article,
54
  examples=[
55
+ ['enhance.png'],
56
+ ['color.png']
57
  ],
58
  enable_queue=True
59
  ).launch()
face_enhancement.py CHANGED
@@ -11,12 +11,14 @@ from PIL import Image
11
  import __init_paths
12
  from retinaface.retinaface_detection import RetinaFaceDetection
13
  from face_model.face_gan import FaceGAN
 
14
  from align_faces import warp_and_crop_face, get_reference_facial_points
15
 
16
  class FaceEnhancement(object):
17
- def __init__(self, base_dir='./', size=512, out_size=None, model=None, channel_multiplier=2, narrow=1, key=None, device='cuda'):
18
  self.facedetector = RetinaFaceDetection(base_dir, device)
19
  self.facegan = FaceGAN(base_dir, size, out_size, model, channel_multiplier, narrow, key, device=device)
 
20
  self.size = size
21
  self.out_size = size if out_size==None else out_size
22
  self.threshold = 0.9
@@ -53,6 +55,16 @@ class FaceEnhancement(object):
53
  orig_faces.append(img)
54
  enhanced_faces.append(ef)
55
 
 
 
 
 
 
 
 
 
 
 
56
  facebs, landms = self.facedetector.detect(img)
57
 
58
  height, width = img.shape[:2]
@@ -89,7 +101,10 @@ class FaceEnhancement(object):
89
  full_img[np.where(mask>0)] = tmp_img[np.where(mask>0)]
90
 
91
  full_mask = full_mask[:, :, np.newaxis]
92
- img = cv2.convertScaleAbs(img*(1-full_mask) + full_img*full_mask)
 
 
 
93
 
94
  return img, orig_faces, enhanced_faces
95
 
 
11
  import __init_paths
12
  from retinaface.retinaface_detection import RetinaFaceDetection
13
  from face_model.face_gan import FaceGAN
14
+ from sr_model.real_esrnet import RealESRNet
15
  from align_faces import warp_and_crop_face, get_reference_facial_points
16
 
17
  class FaceEnhancement(object):
18
+ def __init__(self, base_dir='./', size=512, out_size=None, model=None, channel_multiplier=2, narrow=1, key=None, device='cpu'):
19
  self.facedetector = RetinaFaceDetection(base_dir, device)
20
  self.facegan = FaceGAN(base_dir, size, out_size, model, channel_multiplier, narrow, key, device=device)
21
+ self.srmodel = RealESRNet(base_dir, args.sr_model, args.sr_scale, args.tile_size, device=device)
22
  self.size = size
23
  self.out_size = size if out_size==None else out_size
24
  self.threshold = 0.9
 
55
  orig_faces.append(img)
56
  enhanced_faces.append(ef)
57
 
58
+ if self.use_sr:
59
+ ef = self.srmodel.process(ef)
60
+
61
+ return ef, orig_faces, enhanced_faces
62
+
63
+ if self.use_sr:
64
+ img_sr = self.srmodel.process(img)
65
+ if img_sr is not None:
66
+ img = cv2.resize(img, img_sr.shape[:2][::-1])
67
+
68
  facebs, landms = self.facedetector.detect(img)
69
 
70
  height, width = img.shape[:2]
 
101
  full_img[np.where(mask>0)] = tmp_img[np.where(mask>0)]
102
 
103
  full_mask = full_mask[:, :, np.newaxis]
104
+ if self.use_sr and img_sr is not None:
105
+ img = cv2.convertScaleAbs(img_sr*(1-full_mask) + full_img*full_mask)
106
+ else:
107
+ img = cv2.convertScaleAbs(img*(1-full_mask) + full_img*full_mask)
108
 
109
  return img, orig_faces, enhanced_faces
110
 
sr_model/arch_util.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn as nn
4
+ from torch.nn import functional as F
5
+ from torch.nn import init as init
6
+ from torch.nn.modules.batchnorm import _BatchNorm
7
+
8
+ @torch.no_grad()
9
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
10
+ """Initialize network weights.
11
+
12
+ Args:
13
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
14
+ scale (float): Scale initialized weights, especially for residual
15
+ blocks. Default: 1.
16
+ bias_fill (float): The value to fill bias. Default: 0
17
+ kwargs (dict): Other arguments for initialization function.
18
+ """
19
+ if not isinstance(module_list, list):
20
+ module_list = [module_list]
21
+ for module in module_list:
22
+ for m in module.modules():
23
+ if isinstance(m, nn.Conv2d):
24
+ init.kaiming_normal_(m.weight, **kwargs)
25
+ m.weight.data *= scale
26
+ if m.bias is not None:
27
+ m.bias.data.fill_(bias_fill)
28
+ elif isinstance(m, nn.Linear):
29
+ init.kaiming_normal_(m.weight, **kwargs)
30
+ m.weight.data *= scale
31
+ if m.bias is not None:
32
+ m.bias.data.fill_(bias_fill)
33
+ elif isinstance(m, _BatchNorm):
34
+ init.constant_(m.weight, 1)
35
+ if m.bias is not None:
36
+ m.bias.data.fill_(bias_fill)
37
+
38
+
39
+ def make_layer(basic_block, num_basic_block, **kwarg):
40
+ """Make layers by stacking the same blocks.
41
+
42
+ Args:
43
+ basic_block (nn.module): nn.module class for basic block.
44
+ num_basic_block (int): number of blocks.
45
+
46
+ Returns:
47
+ nn.Sequential: Stacked blocks in nn.Sequential.
48
+ """
49
+ layers = []
50
+ for _ in range(num_basic_block):
51
+ layers.append(basic_block(**kwarg))
52
+ return nn.Sequential(*layers)
53
+
54
+
55
+ class ResidualBlockNoBN(nn.Module):
56
+ """Residual block without BN.
57
+
58
+ It has a style of:
59
+ ---Conv-ReLU-Conv-+-
60
+ |________________|
61
+
62
+ Args:
63
+ num_feat (int): Channel number of intermediate features.
64
+ Default: 64.
65
+ res_scale (float): Residual scale. Default: 1.
66
+ pytorch_init (bool): If set to True, use pytorch default init,
67
+ otherwise, use default_init_weights. Default: False.
68
+ """
69
+
70
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
71
+ super(ResidualBlockNoBN, self).__init__()
72
+ self.res_scale = res_scale
73
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
74
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
75
+ self.relu = nn.ReLU(inplace=True)
76
+
77
+ if not pytorch_init:
78
+ default_init_weights([self.conv1, self.conv2], 0.1)
79
+
80
+ def forward(self, x):
81
+ identity = x
82
+ out = self.conv2(self.relu(self.conv1(x)))
83
+ return identity + out * self.res_scale
84
+
85
+
86
+ class Upsample(nn.Sequential):
87
+ """Upsample module.
88
+
89
+ Args:
90
+ scale (int): Scale factor. Supported scales: 2^n and 3.
91
+ num_feat (int): Channel number of intermediate features.
92
+ """
93
+
94
+ def __init__(self, scale, num_feat):
95
+ m = []
96
+ if (scale & (scale - 1)) == 0: # scale = 2^n
97
+ for _ in range(int(math.log(scale, 2))):
98
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
99
+ m.append(nn.PixelShuffle(2))
100
+ elif scale == 3:
101
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
102
+ m.append(nn.PixelShuffle(3))
103
+ else:
104
+ raise ValueError(f'scale {scale} is not supported. '
105
+ 'Supported scales: 2^n and 3.')
106
+ super(Upsample, self).__init__(*m)
107
+
108
+ # TODO: may write a cpp file
109
+ def pixel_unshuffle(x, scale):
110
+ """ Pixel unshuffle.
111
+
112
+ Args:
113
+ x (Tensor): Input feature with shape (b, c, hh, hw).
114
+ scale (int): Downsample ratio.
115
+
116
+ Returns:
117
+ Tensor: the pixel unshuffled feature.
118
+ """
119
+ b, c, hh, hw = x.size()
120
+ out_channel = c * (scale**2)
121
+ assert hh % scale == 0 and hw % scale == 0
122
+ h = hh // scale
123
+ w = hw // scale
124
+ x_view = x.view(b, c, h, scale, w, scale)
125
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
sr_model/real_esrnet.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import torch
4
+ import numpy as np
5
+ from rrdbnet_arch import RRDBNet
6
+ from torch.nn import functional as F
7
+
8
+ class RealESRNet(object):
9
+ def __init__(self, base_dir='./', model=None, scale=2, tile_size=0, tile_pad=10, device='cuda'):
10
+ self.base_dir = base_dir
11
+ self.scale = scale
12
+ self.tile_size = tile_size
13
+ self.tile_pad = tile_pad
14
+ self.device = device
15
+ self.load_srmodel(base_dir, model)
16
+
17
+ def load_srmodel(self, base_dir, model):
18
+ self.srmodel = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=32, num_block=23, num_grow_ch=32, scale=self.scale)
19
+ if model is None:
20
+ loadnet = torch.load(os.path.join(self.base_dir, 'weights', 'realesrnet_x%d.pth'%self.scale))
21
+ else:
22
+ loadnet = torch.load(os.path.join(self.base_dir, 'weights', model+'_x%d.pth'%self.scale))
23
+ #print(loadnet['params_ema'].keys)
24
+ self.srmodel.load_state_dict(loadnet['params_ema'], strict=True)
25
+ self.srmodel.eval()
26
+ self.srmodel = self.srmodel.to(self.device)
27
+
28
+ def tile_process(self, img):
29
+ """It will first crop input images to tiles, and then process each tile.
30
+ Finally, all the processed tiles are merged into one images.
31
+
32
+ Modified from: https://github.com/ata4/esrgan-launcher
33
+ """
34
+ batch, channel, height, width = img.shape
35
+ output_height = height * self.scale
36
+ output_width = width * self.scale
37
+ output_shape = (batch, channel, output_height, output_width)
38
+
39
+ # start with black image
40
+ output = img.new_zeros(output_shape)
41
+ tiles_x = math.ceil(width / self.tile_size)
42
+ tiles_y = math.ceil(height / self.tile_size)
43
+
44
+ # loop over all tiles
45
+ for y in range(tiles_y):
46
+ for x in range(tiles_x):
47
+ # extract tile from input image
48
+ ofs_x = x * self.tile_size
49
+ ofs_y = y * self.tile_size
50
+ # input tile area on total image
51
+ input_start_x = ofs_x
52
+ input_end_x = min(ofs_x + self.tile_size, width)
53
+ input_start_y = ofs_y
54
+ input_end_y = min(ofs_y + self.tile_size, height)
55
+
56
+ # input tile area on total image with padding
57
+ input_start_x_pad = max(input_start_x - self.tile_pad, 0)
58
+ input_end_x_pad = min(input_end_x + self.tile_pad, width)
59
+ input_start_y_pad = max(input_start_y - self.tile_pad, 0)
60
+ input_end_y_pad = min(input_end_y + self.tile_pad, height)
61
+
62
+ # input tile dimensions
63
+ input_tile_width = input_end_x - input_start_x
64
+ input_tile_height = input_end_y - input_start_y
65
+ tile_idx = y * tiles_x + x + 1
66
+ input_tile = img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad]
67
+
68
+ # upscale tile
69
+ try:
70
+ with torch.no_grad():
71
+ output_tile = self.srmodel(input_tile)
72
+ except RuntimeError as error:
73
+ print('Error', error)
74
+ return None
75
+ if tile_idx%10==0: print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
76
+
77
+ # output tile area on total image
78
+ output_start_x = input_start_x * self.scale
79
+ output_end_x = input_end_x * self.scale
80
+ output_start_y = input_start_y * self.scale
81
+ output_end_y = input_end_y * self.scale
82
+
83
+ # output tile area without padding
84
+ output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
85
+ output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
86
+ output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
87
+ output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
88
+
89
+ # put tile into output image
90
+ output[:, :, output_start_y:output_end_y,
91
+ output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile,
92
+ output_start_x_tile:output_end_x_tile]
93
+ return output
94
+
95
+ def process(self, img):
96
+ img = img.astype(np.float32) / 255.
97
+ img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
98
+ img = img.unsqueeze(0).to(self.device)
99
+
100
+ if self.scale == 2:
101
+ mod_scale = 2
102
+ elif self.scale == 1:
103
+ mod_scale = 4
104
+ else:
105
+ mod_scale = None
106
+ if mod_scale is not None:
107
+ h_pad, w_pad = 0, 0
108
+ _, _, h, w = img.size()
109
+ if (h % mod_scale != 0):
110
+ h_pad = (mod_scale - h % mod_scale)
111
+ if (w % mod_scale != 0):
112
+ w_pad = (mod_scale - w % mod_scale)
113
+ img = F.pad(img, (0, w_pad, 0, h_pad), 'reflect')
114
+
115
+ try:
116
+ with torch.no_grad():
117
+ if self.tile_size > 0:
118
+ output = self.tile_process(img)
119
+ else:
120
+ output = self.srmodel(img)
121
+ del img
122
+ # remove extra pad
123
+ if mod_scale is not None:
124
+ _, _, h, w = output.size()
125
+ output = output[:, :, 0:h - h_pad, 0:w - w_pad]
126
+ output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
127
+ output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
128
+ output = (output * 255.0).round().astype(np.uint8)
129
+
130
+ return output
131
+ except Exception as e:
132
+ print('sr failed:', e)
133
+ return None
sr_model/rrdbnet_arch.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ from torch.nn import functional as F
4
+
5
+ from arch_util import default_init_weights, make_layer, pixel_unshuffle
6
+
7
+
8
+ class ResidualDenseBlock(nn.Module):
9
+ """Residual Dense Block.
10
+
11
+ Used in RRDB block in ESRGAN.
12
+
13
+ Args:
14
+ num_feat (int): Channel number of intermediate features.
15
+ num_grow_ch (int): Channels for each growth.
16
+ """
17
+
18
+ def __init__(self, num_feat=64, num_grow_ch=32):
19
+ super(ResidualDenseBlock, self).__init__()
20
+ self.conv1 = nn.Conv2d(num_feat, num_grow_ch, 3, 1, 1)
21
+ self.conv2 = nn.Conv2d(num_feat + num_grow_ch, num_grow_ch, 3, 1, 1)
22
+ self.conv3 = nn.Conv2d(num_feat + 2 * num_grow_ch, num_grow_ch, 3, 1, 1)
23
+ self.conv4 = nn.Conv2d(num_feat + 3 * num_grow_ch, num_grow_ch, 3, 1, 1)
24
+ self.conv5 = nn.Conv2d(num_feat + 4 * num_grow_ch, num_feat, 3, 1, 1)
25
+
26
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
27
+
28
+ # initialization
29
+ default_init_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
30
+
31
+ def forward(self, x):
32
+ x1 = self.lrelu(self.conv1(x))
33
+ x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
34
+ x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
35
+ x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
36
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
37
+ # Emperically, we use 0.2 to scale the residual for better performance
38
+ return x5 * 0.2 + x
39
+
40
+
41
+ class RRDB(nn.Module):
42
+ """Residual in Residual Dense Block.
43
+
44
+ Used in RRDB-Net in ESRGAN.
45
+
46
+ Args:
47
+ num_feat (int): Channel number of intermediate features.
48
+ num_grow_ch (int): Channels for each growth.
49
+ """
50
+
51
+ def __init__(self, num_feat, num_grow_ch=32):
52
+ super(RRDB, self).__init__()
53
+ self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
54
+ self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
55
+ self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)
56
+
57
+ def forward(self, x):
58
+ out = self.rdb1(x)
59
+ out = self.rdb2(out)
60
+ out = self.rdb3(out)
61
+ # Emperically, we use 0.2 to scale the residual for better performance
62
+ return out * 0.2 + x
63
+
64
+ class RRDBNet(nn.Module):
65
+ """Networks consisting of Residual in Residual Dense Block, which is used
66
+ in ESRGAN.
67
+
68
+ ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.
69
+
70
+ We extend ESRGAN for scale x2 and scale x1.
71
+ Note: This is one option for scale 1, scale 2 in RRDBNet.
72
+ We first employ the pixel-unshuffle (an inverse operation of pixelshuffle to reduce the spatial size
73
+ and enlarge the channel size before feeding inputs into the main ESRGAN architecture.
74
+
75
+ Args:
76
+ num_in_ch (int): Channel number of inputs.
77
+ num_out_ch (int): Channel number of outputs.
78
+ num_feat (int): Channel number of intermediate features.
79
+ Default: 64
80
+ num_block (int): Block number in the trunk network. Defaults: 23
81
+ num_grow_ch (int): Channels for each growth. Default: 32.
82
+ """
83
+
84
+ def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, num_grow_ch=32):
85
+ super(RRDBNet, self).__init__()
86
+ self.scale = scale
87
+ if scale == 2:
88
+ num_in_ch = num_in_ch * 4
89
+ elif scale == 1:
90
+ num_in_ch = num_in_ch * 16
91
+ self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
92
+ self.body = make_layer(RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch)
93
+ self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
94
+ # upsample
95
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
96
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
97
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
98
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
99
+
100
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
101
+
102
+ def forward(self, x):
103
+ if self.scale == 2:
104
+ feat = pixel_unshuffle(x, scale=2)
105
+ elif self.scale == 1:
106
+ feat = pixel_unshuffle(x, scale=4)
107
+ else:
108
+ feat = x
109
+ feat = self.conv_first(feat)
110
+ body_feat = self.conv_body(self.body(feat))
111
+ feat = feat + body_feat
112
+ # upsample
113
+ feat = self.lrelu(self.conv_up1(F.interpolate(feat, scale_factor=2, mode='nearest')))
114
+ feat = self.lrelu(self.conv_up2(F.interpolate(feat, scale_factor=2, mode='nearest')))
115
+ out = self.conv_last(self.lrelu(self.conv_hr(feat)))
116
+ return out