update.
- CodeFormer/.gitignore → .gitignore +5 -4
- CodeFormer/basicsr/utils/misc.py +25 -2
- CodeFormer/basicsr/version.py +2 -2
- CodeFormer/facelib/utils/face_restoration_helper.py +77 -12
- CodeFormer/facelib/utils/misc.py +32 -4
- CodeFormer/inference_codeformer.py +126 -41
- README.md +1 -1
- app.py +12 -10

CodeFormer/.gitignore → .gitignore
RENAMED

@@ -5,9 +5,9 @@ version.py
 
 # ignored files with suffix
 *.html
-
-
-
+*.png
+*.jpeg
+*.jpg
 *.pt
 *.gif
 *.pth
@@ -122,7 +122,8 @@ venv.bak/
 .mypy_cache/
 
 # project
-results/
+CodeFormer/results/
+output/
 dlib/
 *.pth
 *_old*

CodeFormer/basicsr/utils/misc.py
CHANGED

@@ -1,13 +1,36 @@
-import numpy as np
 import os
+import re
 import random
 import time
 import torch
+import numpy as np
 from os import path as osp
 
 from .dist_util import master_only
 from .logger import get_root_logger
 
+IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
+    torch.__version__)[0][:3])] >= [1, 12, 0]
+
+def gpu_is_available():
+    if IS_HIGH_VERSION:
+        if torch.backends.mps.is_available():
+            return True
+    return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
+
+def get_device(gpu_id=None):
+    if gpu_id is None:
+        gpu_str = ''
+    elif isinstance(gpu_id, int):
+        gpu_str = f':{gpu_id}'
+    else:
+        raise TypeError('Input should be int value.')
+
+    if IS_HIGH_VERSION:
+        if torch.backends.mps.is_available():
+            return torch.device('mps'+gpu_str)
+    return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
+
 
 def set_random_seed(seed):
     """Set random seeds."""
@@ -131,4 +154,4 @@ def sizeof_fmt(size, suffix='B'):
         if abs(size) < 1024.0:
             return f'{size:3.1f} {unit}{suffix}'
         size /= 1024.0
-    return f'{size:3.1f} Y{suffix}'
+    return f'{size:3.1f} Y{suffix}'

(The last hunk is a whitespace-only change; the line text is unchanged.)
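
The new gpu_is_available()/get_device() helpers centralize backend selection: on torch >= 1.12 (detected by the version regex above) the MPS backend is preferred when present, then CUDA with cuDNN, then CPU. A minimal usage sketch, assuming the CodeFormer repo is on PYTHONPATH:

import torch
from basicsr.utils.misc import gpu_is_available, get_device

device = get_device()  # torch.device('mps'), 'cuda' (or 'cuda:N' via get_device(N)), or 'cpu'
print(f'accelerator available: {gpu_is_available()}, selected: {device}')

x = torch.zeros(1, 3, 512, 512, device=device)  # models and tensors simply target `device`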

CodeFormer/basicsr/version.py
CHANGED

@@ -1,5 +1,5 @@
 # GENERATED VERSION FILE
-# TIME:
+# TIME: Sat Sep 21 15:31:46 2024
 __version__ = '1.3.2'
-__gitsha__ = '
+__gitsha__ = '1.3.2'
 version_info = (1, 3, 2)

CodeFormer/facelib/utils/face_restoration_helper.py
CHANGED

@@ -6,8 +6,14 @@ from torchvision.transforms.functional import normalize
 
 from facelib.detection import init_detection_model
 from facelib.parsing import init_parsing_model
-from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray
+from facelib.utils.misc import img2tensor, imwrite, is_gray, bgr2gray, adain_npy
+from basicsr.utils.download_util import load_file_from_url
+from basicsr.utils.misc import get_device
 
+dlib_model_url = {
+    'face_detector': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/mmod_human_face_detector-4cb19393.dat',
+    'shape_predictor_5': 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/shape_predictor_5_face_landmarks-c4b1e980.dat'
+}
 
 def get_largest_face(det_faces, h, w):
 
@@ -64,8 +70,15 @@ class FaceRestoreHelper(object):
         self.crop_ratio = crop_ratio  # (h, w)
         assert (self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1), 'crop ration only supports >=1'
         self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0]))
-
-        if self.template_3points:
+        self.det_model = det_model
+
+        if self.det_model == 'dlib':
+            # standard 5 landmarks for FFHQ faces with 1024 x 1024
+            self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
+                                           [337.91089109, 488.38613861], [437.95049505, 493.51485149],
+                                           [513.58415842, 678.5049505]])
+            self.face_template = self.face_template / (1024 // face_size)
+        elif self.template_3points:
             self.face_template = np.array([[192, 240], [319, 240], [257, 371]])
         else:
             # standard 5 landmarks for FFHQ faces with 512 x 512
@@ -77,7 +90,6 @@ class FaceRestoreHelper(object):
         # self.face_template = np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894],
         #                                [198.22603, 372.82502], [313.91018, 372.75659]])
 
-
         self.face_template = self.face_template * (face_size / 512.0)
         if self.crop_ratio[0] > 1:
             self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2
@@ -97,12 +109,16 @@ class FaceRestoreHelper(object):
         self.pad_input_imgs = []
 
         if device is None:
-            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+            self.device = get_device()
         else:
             self.device = device
 
         # init face detection model
-        self.face_detector = init_detection_model(det_model, half=False, device=self.device)
+        if self.det_model == 'dlib':
+            self.face_detector, self.shape_predictor_5 = self.init_dlib(dlib_model_url['face_detector'], dlib_model_url['shape_predictor_5'])
+        else:
+            self.face_detector = init_detection_model(det_model, half=False, device=self.device)
 
         # init face parsing model
         self.use_parse = use_parse
@@ -125,7 +141,7 @@ class FaceRestoreHelper(object):
         img = img[:, :, 0:3]
 
         self.input_img = img
-        self.is_gray = is_gray(img, threshold=
+        self.is_gray = is_gray(img, threshold=10)
         if self.is_gray:
             print('Grayscale input: True')
 
@@ -133,25 +149,72 @@
             f = 512.0/min(self.input_img.shape[:2])
             self.input_img = cv2.resize(self.input_img, (0,0), fx=f, fy=f, interpolation=cv2.INTER_LINEAR)
 
+    def init_dlib(self, detection_path, landmark5_path):
+        """Initialize the dlib detectors and predictors."""
+        try:
+            import dlib
+        except ImportError:
+            print('Please install dlib by running:' 'conda install -c conda-forge dlib')
+        detection_path = load_file_from_url(url=detection_path, model_dir='weights/dlib', progress=True, file_name=None)
+        landmark5_path = load_file_from_url(url=landmark5_path, model_dir='weights/dlib', progress=True, file_name=None)
+        face_detector = dlib.cnn_face_detection_model_v1(detection_path)
+        shape_predictor_5 = dlib.shape_predictor(landmark5_path)
+        return face_detector, shape_predictor_5
+
+    def get_face_landmarks_5_dlib(self,
+                                  only_keep_largest=False,
+                                  scale=1):
+        det_faces = self.face_detector(self.input_img, scale)
+
+        if len(det_faces) == 0:
+            print('No face detected. Try to increase upsample_num_times.')
+            return 0
+        else:
+            if only_keep_largest:
+                print('Detect several faces and only keep the largest.')
+                face_areas = []
+                for i in range(len(det_faces)):
+                    face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
+                        det_faces[i].rect.bottom() - det_faces[i].rect.top())
+                    face_areas.append(face_area)
+                largest_idx = face_areas.index(max(face_areas))
+                self.det_faces = [det_faces[largest_idx]]
+            else:
+                self.det_faces = det_faces
+
+        if len(self.det_faces) == 0:
+            return 0
+
+        for face in self.det_faces:
+            shape = self.shape_predictor_5(self.input_img, face.rect)
+            landmark = np.array([[part.x, part.y] for part in shape.parts()])
+            self.all_landmarks_5.append(landmark)
+
+        return len(self.all_landmarks_5)
+
+
     def get_face_landmarks_5(self,
                              only_keep_largest=False,
                              only_center_face=False,
                              resize=None,
                              blur_ratio=0.01,
                              eye_dist_threshold=None):
+        if self.det_model == 'dlib':
+            return self.get_face_landmarks_5_dlib(only_keep_largest)
+
         if resize is None:
             scale = 1
             input_img = self.input_img
         else:
             h, w = self.input_img.shape[0:2]
             scale = resize / min(h, w)
-            scale = max(1, scale)  # always scale up
+            # scale = max(1, scale)  # always scale up; comment this out for HD images, e.g., AIGC faces.
             h, w = int(h * scale), int(w * scale)
             interp = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
             input_img = cv2.resize(self.input_img, (w, h), interpolation=interp)
 
         with torch.no_grad():
-            bboxes = self.
+            bboxes = self.face_detector.detect_faces(input_img)
 
         if bboxes is None or bboxes.shape[0] == 0:
             return 0
@@ -298,10 +361,12 @@
             torch.save(inverse_affine, save_path)
 
 
-    def add_restored_face(self, restored_face):
+    def add_restored_face(self, restored_face, input_face=None):
         if self.is_gray:
-            restored_face = bgr2gray(restored_face)
-        self.restored_faces.append(restored_face)
+            restored_face = bgr2gray(restored_face) # convert img into grayscale
+            if input_face is not None:
+                restored_face = adain_npy(restored_face, input_face) # transfer the color
+        self.restored_faces.append(restored_face)
 
 
     def paste_faces_to_input_image(self, save_path=None, upsample_img=None, draw_box=False, face_upsampler=None):
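
Taken together, these changes let the helper run detection fully on CPU via dlib (the weights are fetched on first use from the dlib_model_url entries) and keep the grayscale bookkeeping needed for the new color transfer in add_restored_face. A rough usage sketch, assuming the constructor keywords of the existing FaceRestoreHelper signature and an illustrative input path:

import cv2
from facelib.utils.face_restoration_helper import FaceRestoreHelper

helper = FaceRestoreHelper(upscale_factor=2, face_size=512, det_model='dlib', use_parse=True)
helper.read_image(cv2.imread('inputs/whole_imgs/00.jpg'))        # illustrative path
num_faces = helper.get_face_landmarks_5(only_keep_largest=True)  # dispatches to the dlib branch
helper.align_warp_face()  # crops/aligns each face using the 1024-based landmark template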

CodeFormer/facelib/utils/misc.py
CHANGED

@@ -7,13 +7,13 @@ import torch
 from torch.hub import download_url_to_file, get_dir
 from urllib.parse import urlparse
 # from basicsr.utils.download_util import download_file_from_google_drive
-# import gdown
-
 
 ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 
 
 def download_pretrained_models(file_ids, save_path_root):
+    import gdown
+
     os.makedirs(save_path_root, exist_ok=True)
 
     for file_name, file_id in file_ids.items():
@@ -23,7 +23,7 @@ def download_pretrained_models(file_ids, save_path_root):
             user_response = input(f'{file_name} already exist. Do you want to cover it? Y/N\n')
             if user_response.lower() == 'y':
                 print(f'Covering {file_name} to {save_path}')
-
+                gdown.download(file_url, save_path, quiet=False)
                 # download_file_from_google_drive(file_id, save_path)
             elif user_response.lower() == 'n':
                 print(f'Skipping {file_name}')
@@ -31,7 +31,7 @@ def download_pretrained_models(file_ids, save_path_root):
             raise ValueError('Wrong input. Only accepts Y/N.')
         else:
             print(f'Downloading {file_name} to {save_path}')
-
+            gdown.download(file_url, save_path, quiet=False)
             # download_file_from_google_drive(file_id, save_path)
 
 
@@ -172,3 +172,31 @@ def bgr2gray(img, out_channel=3):
     if out_channel == 3:
         gray = gray[:,:,np.newaxis].repeat(3, axis=2)
     return gray
+
+
+def calc_mean_std(feat, eps=1e-5):
+    """
+    Args:
+        feat (numpy): 3D [w h c]s
+    """
+    size = feat.shape
+    assert len(size) == 3, 'The input feature should be 3D tensor.'
+    c = size[2]
+    feat_var = feat.reshape(-1, c).var(axis=0) + eps
+    feat_std = np.sqrt(feat_var).reshape(1, 1, c)
+    feat_mean = feat.reshape(-1, c).mean(axis=0).reshape(1, 1, c)
+    return feat_mean, feat_std
+
+
+def adain_npy(content_feat, style_feat):
+    """Adaptive instance normalization for numpy.
+
+    Args:
+        content_feat (numpy): The input feature.
+        style_feat (numpy): The reference feature.
+    """
+    size = content_feat.shape
+    style_mean, style_std = calc_mean_std(style_feat)
+    content_mean, content_std = calc_mean_std(content_feat)
+    normalized_feat = (content_feat - np.broadcast_to(content_mean, size)) / np.broadcast_to(content_std, size)
+    return normalized_feat * np.broadcast_to(style_std, size) + np.broadcast_to(style_mean, size)
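
calc_mean_std/adain_npy implement plain AdaIN over HWC numpy arrays: each channel of the content image is z-scored and then re-scaled/shifted to the style image's per-channel statistics, which is how add_restored_face transfers the original crop's color onto a restored grayscale face. A small self-check of that property (illustrative shapes):

import numpy as np
from facelib.utils.misc import adain_npy, calc_mean_std

content = np.random.rand(64, 64, 3).astype(np.float32)  # e.g. a restored face, HWC
style = np.random.rand(64, 64, 3).astype(np.float32)    # e.g. the original input face

out = adain_npy(content, style)
out_mean, out_std = calc_mean_std(out)
style_mean, style_std = calc_mean_std(style)
assert np.allclose(out_mean, style_mean, atol=1e-4)  # channel means now match the style
assert np.allclose(out_std, style_std, atol=1e-4)    # channel stds match too (up to eps)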

CodeFormer/inference_codeformer.py
CHANGED

@@ -1,4 +1,3 @@
-# Modified by Shangchen Zhou from: https://github.com/TencentARC/GFPGAN/blob/master/inference_gfpgan.py
 import os
 import cv2
 import argparse
@@ -7,8 +6,9 @@ import torch
 from torchvision.transforms.functional import normalize
 from basicsr.utils import imwrite, img2tensor, tensor2img
 from basicsr.utils.download_util import load_file_from_url
+from basicsr.utils.misc import gpu_is_available, get_device
 from facelib.utils.face_restoration_helper import FaceRestoreHelper
-
+from facelib.utils.misc import is_gray
 
 from basicsr.utils.registry import ARCH_REGISTRY
 
@@ -17,51 +17,104 @@ pretrain_model_url = {
 }
 
 def set_realesrgan():
-    if not torch.cuda.is_available():  # CPU
+    from basicsr.archs.rrdbnet_arch import RRDBNet
+    from basicsr.utils.realesrgan_utils import RealESRGANer
+
+    use_half = False
+    if torch.cuda.is_available(): # set False in CPU/MPS mode
+        no_half_gpu_list = ['1650', '1660'] # set False for GPUs that don't support f16
+        if not True in [gpu in torch.cuda.get_device_name(0) for gpu in no_half_gpu_list]:
+            use_half = True
+
+    model = RRDBNet(
+        num_in_ch=3,
+        num_out_ch=3,
+        num_feat=64,
+        num_block=23,
+        num_grow_ch=32,
+        scale=2,
+    )
+    upsampler = RealESRGANer(
+        scale=2,
+        model_path="https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth",
+        model=model,
+        tile=args.bg_tile,
+        tile_pad=40,
+        pre_pad=0,
+        half=use_half
+    )
+
+    if not gpu_is_available():  # CPU
         import warnings
-        warnings.warn('
-            '
+        warnings.warn('Running on CPU now! Make sure your PyTorch version matches your CUDA.'
+                      'The unoptimized RealESRGAN is slow on CPU. '
+                      'If you want to disable it, please remove `--bg_upsampler` and `--face_upsample` in command.',
             category=RuntimeWarning)
-
-    else:
-        from basicsr.archs.rrdbnet_arch import RRDBNet
-        from basicsr.utils.realesrgan_utils import RealESRGANer
-        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
-        bg_upsampler = RealESRGANer(
-            scale=2,
-            model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
-            model=model,
-            tile=args.bg_tile,
-            tile_pad=40,
-            pre_pad=0,
-            half=True)  # need to set False in CPU mode
-        return bg_upsampler
+    return upsampler
 
 if __name__ == '__main__':
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+    device = get_device()
     parser = argparse.ArgumentParser()
 
-    parser.add_argument('--
-
-    parser.add_argument('--
-
-    parser.add_argument('
+    parser.add_argument('-i', '--input_path', type=str, default='./inputs/whole_imgs',
+            help='Input image, video or folder. Default: inputs/whole_imgs')
+    parser.add_argument('-o', '--output_path', type=str, default=None,
+            help='Output folder. Default: results/<input_name>_<w>')
+    parser.add_argument('-w', '--fidelity_weight', type=float, default=0.5,
+            help='Balance the quality and fidelity. Default: 0.5')
+    parser.add_argument('-s', '--upscale', type=int, default=2,
+            help='The final upsampling scale of the image. Default: 2')
+    parser.add_argument('--has_aligned', action='store_true', help='Input are cropped and aligned faces. Default: False')
+    parser.add_argument('--only_center_face', action='store_true', help='Only restore the center face. Default: False')
+    parser.add_argument('--draw_box', action='store_true', help='Draw the bounding box for the detected faces. Default: False')
     # large det_model: 'YOLOv5l', 'retinaface_resnet50'
     # small det_model: 'YOLOv5n', 'retinaface_mobile0.25'
-    parser.add_argument('--detection_model', type=str, default='retinaface_resnet50'
-
-
-    parser.add_argument('--
+    parser.add_argument('--detection_model', type=str, default='retinaface_resnet50',
+            help='Face detector. Optional: retinaface_resnet50, retinaface_mobile0.25, YOLOv5l, YOLOv5n, dlib. \
+                Default: retinaface_resnet50')
+    parser.add_argument('--bg_upsampler', type=str, default='None', help='Background upsampler. Optional: realesrgan')
+    parser.add_argument('--face_upsample', action='store_true', help='Face upsampler after enhancement. Default: False')
     parser.add_argument('--bg_tile', type=int, default=400, help='Tile size for background sampler. Default: 400')
+    parser.add_argument('--suffix', type=str, default=None, help='Suffix of the restored faces. Default: None')
+    parser.add_argument('--save_video_fps', type=float, default=None, help='Frame rate for saving video. Default: None')
 
     args = parser.parse_args()
 
     # ------------------------ input & output ------------------------
-
-
+    w = args.fidelity_weight
+    input_video = False
+    if args.input_path.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
+        input_img_list = [args.input_path]
+        result_root = f'results/test_img_{w}'
+    elif args.input_path.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
+        from basicsr.utils.video_util import VideoReader, VideoWriter
+        input_img_list = []
+        vidreader = VideoReader(args.input_path)
+        image = vidreader.get_frame()
+        while image is not None:
+            input_img_list.append(image)
+            image = vidreader.get_frame()
+        audio = vidreader.get_audio()
+        fps = vidreader.get_fps() if args.save_video_fps is None else args.save_video_fps
+        video_name = os.path.basename(args.input_path)[:-4]
+        result_root = f'results/{video_name}_{w}'
+        input_video = True
+        vidreader.close()
+    else: # input img folder
+        if args.input_path.endswith('/'): # solve when path ends with /
+            args.input_path = args.input_path[:-1]
+        # scan all the jpg and png images
+        input_img_list = sorted(glob.glob(os.path.join(args.input_path, '*.[jpJP][pnPN]*[gG]')))
+        result_root = f'results/{os.path.basename(args.input_path)}_{w}'
+
+    if not args.output_path is None: # set output path
+        result_root = args.output_path
 
-
-
+    test_img_num = len(input_img_list)
+    if test_img_num == 0:
+        raise FileNotFoundError('No input image/video is found...\n'
+            '\tNote that --input_path for video should end with .mp4|.mov|.avi')
 
     # ------------------ set up background upsampler ------------------
     if args.bg_upsampler == 'realesrgan':
@@ -109,19 +162,27 @@
             device=device)
 
     # -------------------- start to processing ---------------------
-
-    for img_path in sorted(glob.glob(os.path.join(args.test_path, '*.[jp][pn]g'))):
+    for i, img_path in enumerate(input_img_list):
         # clean all the intermediate results to process the next image
         face_helper.clean_all()
 
-
-
-
-
+        if isinstance(img_path, str):
+            img_name = os.path.basename(img_path)
+            basename, ext = os.path.splitext(img_name)
+            print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
+            img = cv2.imread(img_path, cv2.IMREAD_COLOR)
+        else: # for video processing
+            basename = str(i).zfill(6)
+            img_name = f'{video_name}_{basename}' if input_video else basename
+            print(f'[{i+1}/{test_img_num}] Processing: {img_name}')
+            img = img_path
 
         if args.has_aligned:
            # the input faces are already cropped and aligned
            img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR)
+           face_helper.is_gray = is_gray(img, threshold=10)
+           if face_helper.is_gray:
+               print('Grayscale input: True')
            face_helper.cropped_faces = [img]
        else:
            face_helper.read_image(img)
@@ -150,7 +211,7 @@
             restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
 
             restored_face = restored_face.astype('uint8')
-            face_helper.add_restored_face(restored_face)
+            face_helper.add_restored_face(restored_face, cropped_face)
 
             # paste_back
             if not args.has_aligned:
@@ -178,12 +239,36 @@
                 save_face_name = f'{basename}.png'
             else:
                 save_face_name = f'{basename}_{idx:02d}.png'
+            if args.suffix is not None:
+                save_face_name = f'{save_face_name[:-4]}_{args.suffix}.png'
             save_restore_path = os.path.join(result_root, 'restored_faces', save_face_name)
             imwrite(restored_face, save_restore_path)
 
         # save restored img
         if not args.has_aligned and restored_img is not None:
+            if args.suffix is not None:
+                basename = f'{basename}_{args.suffix}'
             save_restore_path = os.path.join(result_root, 'final_results', f'{basename}.png')
             imwrite(restored_img, save_restore_path)
 
-
+    # save enhanced video
+    if input_video:
+        print('Video Saving...')
+        # load images
+        video_frames = []
+        img_list = sorted(glob.glob(os.path.join(result_root, 'final_results', '*.[jp][pn]g')))
+        for img_path in img_list:
+            img = cv2.imread(img_path)
+            video_frames.append(img)
+        # write images to video
+        height, width = video_frames[0].shape[:2]
+        if args.suffix is not None:
+            video_name = f'{video_name}_{args.suffix}.png'
+        save_restore_path = os.path.join(result_root, f'{video_name}.mp4')
+        vidwriter = VideoWriter(save_restore_path, height, width, fps, audio)
+
+        for f in video_frames:
+            vidwriter.write_frame(f)
+        vidwriter.close()
+
+    print(f'\nAll results are saved in {result_root}')
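
For reference, the reworked CLI can be exercised along these lines (paths are illustrative; the flags are the ones defined in the diff above):

# single image, fidelity weight 0.7, dlib-based detection
python inference_codeformer.py -i inputs/face.jpg -w 0.7 --detection_model dlib

# a folder of images, with Real-ESRGAN background upsampling and face upsampling
python inference_codeformer.py -i inputs/whole_imgs -w 0.5 --bg_upsampler realesrgan --face_upsample

# a video: frames are restored one by one, then re-encoded with the source audio/fps
python inference_codeformer.py -i inputs/clip.mp4 -o results/clip_restored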

README.md
CHANGED

@@ -9,4 +9,4 @@ app_file: app.py
 pinned: false
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

(Whitespace-only change; the line text is unchanged.)

app.py
CHANGED

@@ -16,9 +16,9 @@ from torchvision.transforms.functional import normalize
 from basicsr.utils import imwrite, img2tensor, tensor2img
 from basicsr.utils.download_util import load_file_from_url
 from facelib.utils.face_restoration_helper import FaceRestoreHelper
-from facelib.utils.misc import is_gray
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from basicsr.utils.realesrgan_utils import RealESRGANer
+from facelib.utils.misc import is_gray
 
 from basicsr.utils.registry import ARCH_REGISTRY
 
@@ -166,9 +166,7 @@ def inference(image, face_align, background_enhance, face_upsample, upscale, cod
     # face restoration for each cropped face
     for idx, cropped_face in enumerate(face_helper.cropped_faces):
         # prepare data
-        cropped_face_t = img2tensor(
-            cropped_face / 255.0, bgr2rgb=True, float32=True
-        )
+        cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
         normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
         cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
 
@@ -182,12 +180,10 @@ def inference(image, face_align, background_enhance, face_upsample, upscale, cod
             torch.cuda.empty_cache()
         except RuntimeError as error:
             print(f"Failed inference for CodeFormer: {error}")
-            restored_face = tensor2img(
-                cropped_face_t, rgb2bgr=True, min_max=(-1, 1)
-            )
+            restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1))
 
         restored_face = restored_face.astype("uint8")
-        face_helper.add_restored_face(restored_face)
+        face_helper.add_restored_face(restored_face, cropped_face)
 
         # paste_back
         if not has_aligned:
@@ -264,6 +260,12 @@ If you have any questions, please feel free to reach me out at <b>shangchenzhou@
 td {
     padding-right: 0px !important;
 }
+
+.gradio-container-4-37-2 .prose table, .gradio-container-4-37-2 .prose tr, .gradio-container-4-37-2 .prose td, .gradio-container-4-37-2 .prose th {
+    border: 0px solid #ffffff;
+    border-bottom: 0px solid #ffffff;
+}
+
 </style>
 
 <table>
@@ -302,5 +304,5 @@ demo = gr.Interface(
 )
 
 DEBUG = os.getenv('DEBUG') == '1'
-demo.launch(debug=DEBUG)
-
+# demo.launch(debug=DEBUG)
+demo.launch(debug=DEBUG, share=True)
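
One behavioral note on the last hunk: share=True makes Gradio open a temporary public *.gradio.live tunnel in addition to the local server. If that should stay optional, the flag can be gated through the environment the same way DEBUG already is; a sketch against the existing demo object (the SHARE variable is hypothetical, not part of this commit):

import os

DEBUG = os.getenv('DEBUG') == '1'
SHARE = os.getenv('SHARE', '1') == '1'  # hypothetical toggle; default keeps the commit's behavior
demo.launch(debug=DEBUG, share=SHARE)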