sczhou commited on
Commit
a4138cd
2 Parent(s): d26b29f a580cd2
CodeFormer/.gitignore → .gitignore RENAMED
@@ -5,9 +5,9 @@ version.py
5
 
6
  # ignored files with suffix
7
  *.html
8
- # *.png
9
- # *.jpeg
10
- # *.jpg
11
  *.pt
12
  *.gif
13
  *.pth
@@ -122,7 +122,8 @@ venv.bak/
122
  .mypy_cache/
123
 
124
  # project
125
- results/
 
126
  dlib/
127
  *.pth
128
  *_old*
 
5
 
6
  # ignored files with suffix
7
  *.html
8
+ *.png
9
+ *.jpeg
10
+ *.jpg
11
  *.pt
12
  *.gif
13
  *.pth
 
122
  .mypy_cache/
123
 
124
  # project
125
+ CodeFormer/results/
126
+ output/
127
  dlib/
128
  *.pth
129
  *_old*
CodeFormer/basicsr/utils/misc.py CHANGED
@@ -1,13 +1,36 @@
1
- import numpy as np
2
  import os
 
3
  import random
4
  import time
5
  import torch
 
6
  from os import path as osp
7
 
8
  from .dist_util import master_only
9
  from .logger import get_root_logger
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  def set_random_seed(seed):
13
  """Set random seeds."""
@@ -131,4 +154,4 @@ def sizeof_fmt(size, suffix='B'):
131
  if abs(size) < 1024.0:
132
  return f'{size:3.1f} {unit}{suffix}'
133
  size /= 1024.0
134
- return f'{size:3.1f} Y{suffix}'
 
 
1
  import os
2
+ import re
3
  import random
4
  import time
5
  import torch
6
+ import numpy as np
7
  from os import path as osp
8
 
9
  from .dist_util import master_only
10
  from .logger import get_root_logger
11
 
12
+ IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
13
+ torch.__version__)[0][:3])] >= [1, 12, 0]
14
+
15
+ def gpu_is_available():
16
+ if IS_HIGH_VERSION:
17
+ if torch.backends.mps.is_available():
18
+ return True
19
+ return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
20
+
21
+ def get_device(gpu_id=None):
22
+ if gpu_id is None:
23
+ gpu_str = ''
24
+ elif isinstance(gpu_id, int):
25
+ gpu_str = f':{gpu_id}'
26
+ else:
27
+ raise TypeError('Input should be int value.')
28
+
29
+ if IS_HIGH_VERSION:
30
+ if torch.backends.mps.is_available():
31
+ return torch.device('mps'+gpu_str)
32
+ return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
33
+
34
 
35
  def set_random_seed(seed):
36
  """Set random seeds."""
 
154
  if abs(size) < 1024.0:
155
  return f'{size:3.1f} {unit}{suffix}'
156
  size /= 1024.0
157
+ return f'{size:3.1f} Y{suffix}'
CodeFormer/basicsr/version.py CHANGED
@@ -1,5 +1,5 @@
1
  # GENERATED VERSION FILE
2
- # TIME: Sun Aug 7 15:14:26 2022
3
  __version__ = '1.3.2'
4
- __gitsha__ = '6f94023'
5
  version_info = (1, 3, 2)
 
1
  # GENERATED VERSION FILE
2
+ # TIME: Sat Sep 21 15:31:46 2024
3
  __version__ = '1.3.2'
4
+ __gitsha__ = '1.3.2'
5
  version_info = (1, 3, 2)
CodeFormer/facelib/utils/face_restoration_helper.py CHANGED
@@ -6,8 +6,14 @@ from torchvision.transforms.functional import normalize
6
 
7
  from facelib.detection import init_detection_model
8
  from facelib.parsing import init_parsing_model
9
- from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
 
 
10
 
 
 
 
 
11
 
12
  def get_largest_face(det_faces, h, w):
13
 
@@ -64,8 +70,15 @@ class FaceRestoreHelper(object):
64
  self.crop_ratio = crop_ratio # (h, w)
65
  assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
66
  self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
67
-
68
- if self.template_3points:
 
 
 
 
 
 
 
69
  self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
70
  else:
71
  # standard 5 landmarks for FFHQ faces with 512 x 512
@@ -77,7 +90,6 @@ class FaceRestoreHelper(object):
77
  # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
78
  # [198.22603, 372.82502], [313.91018, 372.75659]])
79
 
80
-
81
  self.face_template = self.face_template * (face_size / 512.0)
82
  if self.crop_ratio[0] > 1:
83
  self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
@@ -97,12 +109,16 @@ class FaceRestoreHelper(object):
97
  self.pad_input_imgs = []
98
 
99
  if device is None:
100
- self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
101
  else:
102
  self.device = device
103
 
104
  # init face detection model
105
- self.face_det = init_detection_model(det_model, half=False, device=self.device)
 
 
 
106
 
107
  # init face parsing model
108
  self.use_parse = use_parse
@@ -125,7 +141,7 @@ class FaceRestoreHelper(object):
125
  img = img[:, :, 0:3]
126
 
127
  self.input_img = img
128
- self.is_gray = is_gray(img, threshold=5)
129
  if self.is_gray:
130
  print('Grayscale input: True')
131
 
@@ -133,25 +149,72 @@ class FaceRestoreHelper(object):
133
  f = 512.0/min(self.input_img.shape[:2])
134
  self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  def get_face_landmarks_5(self,
137
  only_keep_largest=False,
138
  only_center_face=False,
139
  resize=None,
140
  blur_ratio=0.01,
141
  eye_dist_threshold=None):
 
 
 
142
  if resize is None:
143
  scale = 1
144
  input_img = self.input_img
145
  else:
146
  h, w = self.input_img.shape[0:2]
147
  scale = resize / min(h, w)
148
- scale = max(1, scale) # always scale up
149
  h, w = int(h * scale), int(w * scale)
150
  interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
151
  input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
152
 
153
  with torch.no_grad():
154
- bboxes = self.face_det.detect_faces(input_img)
155
 
156
  if bboxes is None or bboxes.shape[0] == 0:
157
  return 0
@@ -298,10 +361,12 @@ class FaceRestoreHelper(object):
298
  torch.save(inverse_affine, save_path)
299
 
300
 
301
- def add_restored_face(self, face):
302
  if self.is_gray:
303
- face = bgr2gray(face) # convert img into grayscale
304
- self.restored_faces.append(face)
 
 
305
 
306
 
307
  def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
 
6
 
7
  from facelib.detection import init_detection_model
8
  from facelib.parsing import init_parsing_model
9
+ from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray, adain_npy
10
+ from basicsr.utils.download_util import load_file_from_url
11
+ from basicsr.utils.misc import get_device
12
 
13
+ dlib_model_url = {
14
+ 'face_detector': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat',
15
+ 'shape_predictor_5': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat'
16
+ }
17
 
18
  def get_largest_face(det_faces, h, w):
19
 
 
70
  self.crop_ratio = crop_ratio # (h, w)
71
  assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
72
  self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
73
+ self.det_model = det_model
74
+
75
+ if self.det_model == 'dlib':
76
+ # standard 5 landmarks for FFHQ faces with 1024 x 1024
77
+ self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
78
+ [337.91089109, 488.38613861], [437.95049505, 493.51485149],
79
+ [513.58415842, 678.5049505]])
80
+ self.face_template = self.face_template / (1024 // face_size)
81
+ elif self.template_3points:
82
  self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
83
  else:
84
  # standard 5 landmarks for FFHQ faces with 512 x 512
 
90
  # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
91
  # [198.22603, 372.82502], [313.91018, 372.75659]])
92
 
 
93
  self.face_template = self.face_template * (face_size / 512.0)
94
  if self.crop_ratio[0] > 1:
95
  self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
 
109
  self.pad_input_imgs = []
110
 
111
  if device is None:
112
+ # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
113
+ self.device = get_device()
114
  else:
115
  self.device = device
116
 
117
  # init face detection model
118
+ if self.det_model == 'dlib':
119
+ self.face_detector, self.shape_predictor_5 = self.init_dlib(dlib_model_url['face_detector'], dlib_model_url['shape_predictor_5'])
120
+ else:
121
+ self.face_detector = init_detection_model(det_model, half=False, device=self.device)
122
 
123
  # init face parsing model
124
  self.use_parse = use_parse
 
141
  img = img[:, :, 0:3]
142
 
143
  self.input_img = img
144
+ self.is_gray = is_gray(img, threshold=10)
145
  if self.is_gray:
146
  print('Grayscale input: True')
147
 
 
149
  f = 512.0/min(self.input_img.shape[:2])
150
  self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
151
 
152
+ def init_dlib(self, detection_path, landmark5_path):
153
+ """Initialize the dlib detectors and predictors."""
154
+ try:
155
+ import dlib
156
+ except ImportError:
157
+ print('Please install dlib by running:' 'conda install -c conda-forge dlib')
158
+ detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
159
+ landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
160
+ face_detector = dlib.cnn_face_detection_model_v1(detection_path)
161
+ shape_predictor_5 = dlib.shape_predictor(landmark5_path)
162
+ return face_detector, shape_predictor_5
163
+
164
+ def get_face_landmarks_5_dlib(self,
165
+ only_keep_largest=False,
166
+ scale=1):
167
+ det_faces = self.face_detector(self.input_img, scale)
168
+
169
+ if len(det_faces) == 0:
170
+ print('No face detected. Try to increase upsample_num_times.')
171
+ return 0
172
+ else:
173
+ if only_keep_largest:
174
+ print('Detect several faces and only keep the largest.')
175
+ face_areas = []
176
+ for i in range(len(det_faces)):
177
+ face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
178
+ det_faces[i].rect.bottom() - det_faces[i].rect.top())
179
+ face_areas.append(face_area)
180
+ largest_idx = face_areas.index(max(face_areas))
181
+ self.det_faces = [det_faces[largest_idx]]
182
+ else:
183
+ self.det_faces = det_faces
184
+
185
+ if len(self.det_faces) == 0:
186
+ return 0
187
+
188
+ for face in self.det_faces:
189
+ shape = self.shape_predictor_5(self.input_img, face.rect)
190
+ landmark = np.array([[part.x, part.y] for part in shape.parts()])
191
+ self.all_landmarks_5.append(landmark)
192
+
193
+ return len(self.all_landmarks_5)
194
+
195
+
196
  def get_face_landmarks_5(self,
197
  only_keep_largest=False,
198
  only_center_face=False,
199
  resize=None,
200
  blur_ratio=0.01,
201
  eye_dist_threshold=None):
202
+ if self.det_model == 'dlib':
203
+ return self.get_face_landmarks_5_dlib(only_keep_largest)
204
+
205
  if resize is None:
206
  scale = 1
207
  input_img = self.input_img
208
  else:
209
  h, w = self.input_img.shape[0:2]
210
  scale = resize / min(h, w)
211
+ # scale = max(1, scale) # always scale up; comment this out for HD images, e.g., AIGC faces.
212
  h, w = int(h * scale), int(w * scale)
213
  interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
214
  input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
215
 
216
  with torch.no_grad():
217
+ bboxes = self.face_detector.detect_faces(input_img)
218
 
219
  if bboxes is None or bboxes.shape[0] == 0:
220
  return 0
 
361
  torch.save(inverse_affine, save_path)
362
 
363
 
364
+ def add_restored_face(self, restored_face, input_face=None):
365
  if self.is_gray:
366
+ restored_face = bgr2gray(restored_face) # convert img into grayscale
367
+ if input_face is not None:
368
+ restored_face = adain_npy(restored_face, input_face) # transfer the color
369
+ self.restored_faces.append(restored_face)
370
 
371
 
372
  def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
CodeFormer/facelib/utils/misc.py CHANGED
@@ -7,13 +7,13 @@ import torch
7
  from torch.hub import download_url_to_file, get_dir
8
  from urllib.parse import urlparse
9
  # from basicsr.utils.download_util import download_file_from_google_drive
10
- # import gdown
11
-
12
 
13
  ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14
 
15
 
16
  def download_pretrained_models(file_ids, save_path_root):
 
 
17
  os.makedirs(save_path_root, exist_ok=True)
18
 
19
  for file_name, file_id in file_ids.items():
@@ -23,7 +23,7 @@ def download_pretrained_models(file_ids, save_path_root):
23
  user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
24
  if user_response.lower() == 'y':
25
  print(f'Covering {file_name} to {save_path}')
26
- # gdown.download(file_url, save_path, quiet=False)
27
  # download_file_from_google_drive(file_id, save_path)
28
  elif user_response.lower() == 'n':
29
  print(f'Skipping {file_name}')
@@ -31,7 +31,7 @@ def download_pretrained_models(file_ids, save_path_root):
31
  raise ValueError('Wrong input. Only accepts Y/N.')
32
  else:
33
  print(f'Downloading {file_name} to {save_path}')
34
- # gdown.download(file_url, save_path, quiet=False)
35
  # download_file_from_google_drive(file_id, save_path)
36
 
37
 
@@ -172,3 +172,31 @@ def bgr2gray(img, out_channel=3):
172
  if out_channel == 3:
173
  gray = gray[:,:,np.newaxis].repeat(3, axis=2)
174
  return gray
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from torch.hub import download_url_to_file, get_dir
8
  from urllib.parse import urlparse
9
  # from basicsr.utils.download_util import download_file_from_google_drive
 
 
10
 
11
  ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
 
13
 
14
  def download_pretrained_models(file_ids, save_path_root):
15
+ import gdown
16
+
17
  os.makedirs(save_path_root, exist_ok=True)
18
 
19
  for file_name, file_id in file_ids.items():
 
23
  user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
24
  if user_response.lower() == 'y':
25
  print(f'Covering {file_name} to {save_path}')
26
+ gdown.download(file_url, save_path, quiet=False)
27
  # download_file_from_google_drive(file_id, save_path)
28
  elif user_response.lower() == 'n':
29
  print(f'Skipping {file_name}')
 
31
  raise ValueError('Wrong input. Only accepts Y/N.')
32
  else:
33
  print(f'Downloading {file_name} to {save_path}')
34
+ gdown.download(file_url, save_path, quiet=False)
35
  # download_file_from_google_drive(file_id, save_path)
36
 
37
 
 
172
  if out_channel == 3:
173
  gray = gray[:,:,np.newaxis].repeat(3, axis=2)
174
  return gray
175
+
176
+
177
+ def calc_mean_std(feat, eps=1e-5):
178
+ """
179
+ Args:
180
+ feat (numpy): 3D [w h c]s
181
+ """
182
+ size = feat.shape
183
+ assert len(size) == 3, 'The input feature should be 3D tensor.'
184
+ c = size[2]
185
+ feat_var = feat.reshape(-1, c).var(axis=0) + eps
186
+ feat_std = np.sqrt(feat_var).reshape(1, 1, c)
187
+ feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c)
188
+ return feat_mean, feat_std
189
+
190
+
191
+ def adain_npy(content_feat, style_feat):
192
+ """Adaptive instance normalization for numpy.
193
+
194
+ Args:
195
+ content_feat (numpy): The input feature.
196
+ style_feat (numpy): The reference feature.
197
+ """
198
+ size = content_feat.shape
199
+ style_mean, style_std = calc_mean_std(style_feat)
200
+ content_mean, content_std = calc_mean_std(content_feat)
201
+ normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size)
202
+ return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size)
CodeFormer/inference_codeformer.py CHANGED
@@ -1,4 +1,3 @@
1
- # Modified by Shangchen Zhou from: https://github.com/TencentARC/GFPGAN/blob/master/inference_gfpgan.py
2
  import os
3
  import cv2
4
  import argparse
@@ -7,8 +6,9 @@ import torch
7
  from torchvision.transforms.functional import normalize
8
  from basicsr.utils import imwrite, img2tensor, tensor2img
9
  from basicsr.utils.download_util import load_file_from_url
 
10
  from facelib.utils.face_restoration_helper import FaceRestoreHelper
11
- import torch.nn.functional as F
12
 
13
  from basicsr.utils.registry import ARCH_REGISTRY
14
 
@@ -17,51 +17,104 @@ pretrain_model_url = {
17
  }
18
 
19
  def set_realesrgan():
20
- if not torch.cuda.is_available(): # CPU
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  import warnings
22
- warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
23
- 'If you really want to use it, please modify the corresponding codes.',
 
24
  category=RuntimeWarning)
25
- bg_upsampler = None
26
- else:
27
- from basicsr.archs.rrdbnet_arch import RRDBNet
28
- from basicsr.utils.realesrgan_utils import RealESRGANer
29
- model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
30
- bg_upsampler = RealESRGANer(
31
- scale=2,
32
- model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
33
- model=model,
34
- tile=args.bg_tile,
35
- tile_pad=40,
36
- pre_pad=0,
37
- half=True) # need to set False in CPU mode
38
- return bg_upsampler
39
 
40
  if __name__ == '__main__':
41
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
42
  parser = argparse.ArgumentParser()
43
 
44
- parser.add_argument('--w', type=float, default=0.5, help='Balance the quality and fidelity')
45
- parser.add_argument('--upscale', type=int, default=2, help='The final upsampling scale of the image. Default: 2')
46
- parser.add_argument('--test_path', type=str, default='./inputs/cropped_faces')
47
- parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces')
48
- parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face')
 
 
 
 
 
 
49
  # large det_model: 'YOLOv5l', 'retinaface_resnet50'
50
  # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
51
- parser.add_argument('--detection_model', type=str, default='retinaface_resnet50')
52
- parser.add_argument('--draw_box', action='store_true')
53
- parser.add_argument('--bg_upsampler', type=str, default='None', help='background upsampler. Optional: realesrgan')
54
- parser.add_argument('--face_upsample', action='store_true', help='face upsampler after enhancement.')
 
55
  parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
 
 
56
 
57
  args = parser.parse_args()
58
 
59
  # ------------------------ input & output ------------------------
60
- if args.test_path.endswith('/'): # solve when path ends with /
61
- args.test_path = args.test_path[:-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- w = args.w
64
- result_root = f'results/{os.path.basename(args.test_path)}_{w}'
 
 
65
 
66
  # ------------------ set up background upsampler ------------------
67
  if args.bg_upsampler == 'realesrgan':
@@ -109,19 +162,27 @@ if __name__ == '__main__':
109
  device=device)
110
 
111
  # -------------------- start to processing ---------------------
112
- # scan all the jpg and png images
113
- for img_path in sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
114
  # clean all the intermediate results to process the next image
115
  face_helper.clean_all()
116
 
117
- img_name = os.path.basename(img_path)
118
- print(f'Processing: {img_name}')
119
- basename, ext = os.path.splitext(img_name)
120
- img = cv2.imread(img_path, cv2.IMREAD_COLOR)
 
 
 
 
 
 
121
 
122
  if args.has_aligned:
123
  # the input faces are already cropped and aligned
124
  img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
 
 
 
125
  face_helper.cropped_faces = [img]
126
  else:
127
  face_helper.read_image(img)
@@ -150,7 +211,7 @@ if __name__ == '__main__':
150
  restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
151
 
152
  restored_face = restored_face.astype('uint8')
153
- face_helper.add_restored_face(restored_face)
154
 
155
  # paste_back
156
  if not args.has_aligned:
@@ -178,12 +239,36 @@ if __name__ == '__main__':
178
  save_face_name = f'{basename}.png'
179
  else:
180
  save_face_name = f'{basename}_{idx:02d}.png'
 
 
181
  save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
182
  imwrite(restored_face, save_restore_path)
183
 
184
  # save restored img
185
  if not args.has_aligned and restored_img is not None:
 
 
186
  save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
187
  imwrite(restored_img, save_restore_path)
188
 
189
- print(f'\nAll results are saved in {result_root}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import cv2
3
  import argparse
 
6
  from torchvision.transforms.functional import normalize
7
  from basicsr.utils import imwrite, img2tensor, tensor2img
8
  from basicsr.utils.download_util import load_file_from_url
9
+ from basicsr.utils.misc import gpu_is_available, get_device
10
  from facelib.utils.face_restoration_helper import FaceRestoreHelper
11
+ from facelib.utils.misc import is_gray
12
 
13
  from basicsr.utils.registry import ARCH_REGISTRY
14
 
 
17
  }
18
 
19
  def set_realesrgan():
20
+ from basicsr.archs.rrdbnet_arch import RRDBNet
21
+ from basicsr.utils.realesrgan_utils import RealESRGANer
22
+
23
+ use_half = False
24
+ if torch.cuda.is_available(): # set False in CPU/MPS mode
25
+ no_half_gpu_list = ['1650', '1660'] # set False for GPUs that don't support f16
26
+ if not True in [gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list]:
27
+ use_half = True
28
+
29
+ model = RRDBNet(
30
+ num_in_ch=3,
31
+ num_out_ch=3,
32
+ num_feat=64,
33
+ num_block=23,
34
+ num_grow_ch=32,
35
+ scale=2,
36
+ )
37
+ upsampler = RealESRGANer(
38
+ scale=2,
39
+ model_path="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
40
+ model=model,
41
+ tile=args.bg_tile,
42
+ tile_pad=40,
43
+ pre_pad=0,
44
+ half=use_half
45
+ )
46
+
47
+ if not gpu_is_available(): # CPU
48
  import warnings
49
+ warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.'
50
+ 'The unoptimized RealESRGAN is slow on CPU. '
51
+ 'If you want to disable it, please remove `--bg_upsampler` and `--face_upsample` in command.',
52
  category=RuntimeWarning)
53
+ return upsampler
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  if __name__ == '__main__':
56
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
57
+ device = get_device()
58
  parser = argparse.ArgumentParser()
59
 
60
+ parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
61
+ help='Input image, video or folder. Default: inputs/whole_imgs')
62
+ parser.add_argument('-o', '--output_path', type=str, default=None,
63
+ help='Output folder. Default: results/<input_name>_<w>')
64
+ parser.add_argument('-w', '--fidelity_weight', type=float, default=0.5,
65
+ help='Balance the quality and fidelity. Default: 0.5')
66
+ parser.add_argument('-s', '--upscale', type=int, default=2,
67
+ help='The final upsampling scale of the image. Default: 2')
68
+ parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces. Default: False')
69
+ parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False')
70
+ parser.add_argument('--draw_box', action='store_true', help='Draw the bounding box for the detected faces. Default: False')
71
  # large det_model: 'YOLOv5l', 'retinaface_resnet50'
72
  # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
73
+ parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
74
+ help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
75
+ Default: retinaface_resnet50')
76
+ parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: realesrgan')
77
+ parser.add_argument('--face_upsample', action='store_true', help='Face upsampler after enhancement. Default: False')
78
  parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
79
+ parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None')
80
+ parser.add_argument('--save_video_fps', type=float, default=None, help='Frame rate for saving video. Default: None')
81
 
82
  args = parser.parse_args()
83
 
84
  # ------------------------ input & output ------------------------
85
+ w = args.fidelity_weight
86
+ input_video = False
87
+ if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
88
+ input_img_list = [args.input_path]
89
+ result_root = f'results/test_img_{w}'
90
+ elif args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
91
+ from basicsr.utils.video_util import VideoReader, VideoWriter
92
+ input_img_list = []
93
+ vidreader = VideoReader(args.input_path)
94
+ image = vidreader.get_frame()
95
+ while image is not None:
96
+ input_img_list.append(image)
97
+ image = vidreader.get_frame()
98
+ audio = vidreader.get_audio()
99
+ fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
100
+ video_name = os.path.basename(args.input_path)[:-4]
101
+ result_root = f'results/{video_name}_{w}'
102
+ input_video = True
103
+ vidreader.close()
104
+ else: # input img folder
105
+ if args.input_path.endswith('/'): # solve when path ends with /
106
+ args.input_path = args.input_path[:-1]
107
+ # scan all the jpg and png images
108
+ input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
109
+ result_root = f'results/{os.path.basename(args.input_path)}_{w}'
110
+
111
+ if not args.output_path is None: # set output path
112
+ result_root = args.output_path
113
 
114
+ test_img_num = len(input_img_list)
115
+ if test_img_num == 0:
116
+ raise FileNotFoundError('No input image/video is found...\n'
117
+ '\tNote that --input_path for video should end with .mp4|.mov|.avi')
118
 
119
  # ------------------ set up background upsampler ------------------
120
  if args.bg_upsampler == 'realesrgan':
 
162
  device=device)
163
 
164
  # -------------------- start to processing ---------------------
165
+ for i, img_path in enumerate(input_img_list):
 
166
  # clean all the intermediate results to process the next image
167
  face_helper.clean_all()
168
 
169
+ if isinstance(img_path, str):
170
+ img_name = os.path.basename(img_path)
171
+ basename, ext = os.path.splitext(img_name)
172
+ print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
173
+ img = cv2.imread(img_path, cv2.IMREAD_COLOR)
174
+ else: # for video processing
175
+ basename = str(i).zfill(6)
176
+ img_name = f'{video_name}_{basename}' if input_video else basename
177
+ print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
178
+ img = img_path
179
 
180
  if args.has_aligned:
181
  # the input faces are already cropped and aligned
182
  img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
183
+ face_helper.is_gray = is_gray(img, threshold=10)
184
+ if face_helper.is_gray:
185
+ print('Grayscale input: True')
186
  face_helper.cropped_faces = [img]
187
  else:
188
  face_helper.read_image(img)
 
211
  restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
212
 
213
  restored_face = restored_face.astype('uint8')
214
+ face_helper.add_restored_face(restored_face, cropped_face)
215
 
216
  # paste_back
217
  if not args.has_aligned:
 
239
  save_face_name = f'{basename}.png'
240
  else:
241
  save_face_name = f'{basename}_{idx:02d}.png'
242
+ if args.suffix is not None:
243
+ save_face_name = f'{save_face_name[:-4]}_{args.suffix}.png'
244
  save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
245
  imwrite(restored_face, save_restore_path)
246
 
247
  # save restored img
248
  if not args.has_aligned and restored_img is not None:
249
+ if args.suffix is not None:
250
+ basename = f'{basename}_{args.suffix}'
251
  save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
252
  imwrite(restored_img, save_restore_path)
253
 
254
+ # save enhanced video
255
+ if input_video:
256
+ print('Video Saving...')
257
+ # load images
258
+ video_frames = []
259
+ img_list = sorted(glob.glob(os.path.join(result_root, 'final_results', '*.[jp][pn]g')))
260
+ for img_path in img_list:
261
+ img = cv2.imread(img_path)
262
+ video_frames.append(img)
263
+ # write images to video
264
+ height, width = video_frames[0].shape[:2]
265
+ if args.suffix is not None:
266
+ video_name = f'{video_name}_{args.suffix}.png'
267
+ save_restore_path = os.path.join(result_root, f'{video_name}.mp4')
268
+ vidwriter = VideoWriter(save_restore_path, height, width, fps, audio)
269
+
270
+ for f in video_frames:
271
+ vidwriter.write_frame(f)
272
+ vidwriter.close()
273
+
274
+ print(f'\nAll results are saved in {result_root}')
README.md CHANGED
@@ -9,4 +9,4 @@ app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
9
  pinned: false
10
  ---
11
 
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -16,9 +16,9 @@ from torchvision.transforms.functional import normalize
16
  from basicsr.utils import imwrite, img2tensor, tensor2img
17
  from basicsr.utils.download_util import load_file_from_url
18
  from facelib.utils.face_restoration_helper import FaceRestoreHelper
19
- from facelib.utils.misc import is_gray
20
  from basicsr.archs.rrdbnet_arch import RRDBNet
21
  from basicsr.utils.realesrgan_utils import RealESRGANer
 
22
 
23
  from basicsr.utils.registry import ARCH_REGISTRY
24
 
@@ -166,9 +166,7 @@ def inference(image, face_align, background_enhance, face_upsample, upscale, cod
166
  # face restoration for each cropped face
167
  for idx, cropped_face in enumerate(face_helper.cropped_faces):
168
  # prepare data
169
- cropped_face_t = img2tensor(
170
- cropped_face / 255.0, bgr2rgb=True, float32=True
171
- )
172
  normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
173
  cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
174
 
@@ -182,12 +180,10 @@ def inference(image, face_align, background_enhance, face_upsample, upscale, cod
182
  torch.cuda.empty_cache()
183
  except RuntimeError as error:
184
  print(f"Failed inference for CodeFormer: {error}")
185
- restored_face = tensor2img(
186
- cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
187
- )
188
 
189
  restored_face = restored_face.astype("uint8")
190
- face_helper.add_restored_face(restored_face)
191
 
192
  # paste_back
193
  if not has_aligned:
@@ -264,6 +260,12 @@ If you have any questions, please feel free to reach me out at <b>shangchenzhou@
264
  td {
265
  padding-right: 0px !important;
266
  }
 
 
 
 
 
 
267
  </style>
268
 
269
  <table>
@@ -302,5 +304,5 @@ demo = gr.Interface(
302
  )
303
 
304
  DEBUG = os.getenv('DEBUG') == '1'
305
- demo.launch(debug=DEBUG)
306
- # demo.launch(debug=DEBUG, share=True)
 
16
  from basicsr.utils import imwrite, img2tensor, tensor2img
17
  from basicsr.utils.download_util import load_file_from_url
18
  from facelib.utils.face_restoration_helper import FaceRestoreHelper
 
19
  from basicsr.archs.rrdbnet_arch import RRDBNet
20
  from basicsr.utils.realesrgan_utils import RealESRGANer
21
+ from facelib.utils.misc import is_gray
22
 
23
  from basicsr.utils.registry import ARCH_REGISTRY
24
 
 
166
  # face restoration for each cropped face
167
  for idx, cropped_face in enumerate(face_helper.cropped_faces):
168
  # prepare data
169
+ cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
 
 
170
  normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
171
  cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
172
 
 
180
  torch.cuda.empty_cache()
181
  except RuntimeError as error:
182
  print(f"Failed inference for CodeFormer: {error}")
183
+ restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
 
 
184
 
185
  restored_face = restored_face.astype("uint8")
186
+ face_helper.add_restored_face(restored_face, cropped_face)
187
 
188
  # paste_back
189
  if not has_aligned:
 
260
  td {
261
  padding-right: 0px !important;
262
  }
263
+
264
+ .gradio-container-4-37-2 .prose table, .gradio-container-4-37-2 .prose tr, .gradio-container-4-37-2 .prose td, .gradio-container-4-37-2 .prose th {
265
+ border: 0px solid #ffffff;
266
+ border-bottom: 0px solid #ffffff;
267
+ }
268
+
269
  </style>
270
 
271
  <table>
 
304
  )
305
 
306
  DEBUG = os.getenv('DEBUG') == '1'
307
+ # demo.launch(debug=DEBUG)
308
+ demo.launch(debug=DEBUG, share=True)