import os

import cv2
import numpy as np
import torch
from skimage import transform as trans

from basicsr.utils import imwrite

try:
    import dlib
except ImportError:
    print('Please install dlib before testing face restoration. '
          'Reference: https://github.com/davisking/dlib')


class FaceRestorationHelper(object):
    """Helper for the face restoration pipeline."""

    def __init__(self, upscale_factor, face_size=512):
        self.upscale_factor = upscale_factor
        self.face_size = (face_size, face_size)

        # standard 5 landmarks for FFHQ faces at 1024 x 1024 resolution
        self.face_template = np.array([[686.77227723, 488.62376238], [586.77227723, 493.59405941],
                                       [337.91089109, 488.38613861], [437.95049505, 493.51485149],
                                       [513.58415842, 678.5049505]])
        self.face_template = self.face_template / (1024 // face_size)
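        # e.g. with the default face_size=512, 1024 // 512 == 2, so the
        # template coordinates are halved to match the face_size crop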

        # for estimating the 2D similarity transformation
        self.similarity_trans = trans.SimilarityTransform()

        self.all_landmarks_5 = []
        self.all_landmarks_68 = []
        self.affine_matrices = []
        self.inverse_affine_matrices = []
        self.cropped_faces = []
        self.restored_faces = []

        self.save_png = True

    def init_dlib(self, detection_path, landmark5_path, landmark68_path):
        """Initialize the dlib detectors and predictors."""
        self.face_detector = dlib.cnn_face_detection_model_v1(detection_path)
        self.shape_predictor_5 = dlib.shape_predictor(landmark5_path)
        self.shape_predictor_68 = dlib.shape_predictor(landmark68_path)
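
    # NOTE: the paths passed to init_dlib are the standard dlib model files
    # (assumed filenames): mmod_human_face_detector.dat,
    # shape_predictor_5_face_landmarks.dat and
    # shape_predictor_68_face_landmarks.dat, downloadable from
    # http://dlib.net/files/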

    def free_dlib_gpu_memory(self):
        del self.face_detector
        del self.shape_predictor_5
        del self.shape_predictor_68

    def read_input_image(self, img_path):
        # self.input_img is a Numpy array of shape (h, w, c) in RGB order
        self.input_img = dlib.load_rgb_image(img_path)

    def detect_faces(self, img_path, upsample_num_times=1, only_keep_largest=False):
        """Detect faces in the input image.

        Args:
            img_path (str): Image path.
            upsample_num_times (int): Number of times to upsample the image
                before running the face detector.
            only_keep_largest (bool): Whether to keep only the largest
                detected face.

        Returns:
            int: Number of detected faces.
        """
        self.read_input_image(img_path)
        det_faces = self.face_detector(self.input_img, upsample_num_times)
        if len(det_faces) == 0:
            print('No face detected. Try increasing upsample_num_times.')
            self.det_faces = []
        else:
            if only_keep_largest:
                print('Detected several faces; only keeping the largest one.')
                face_areas = []
                for i in range(len(det_faces)):
                    face_area = (det_faces[i].rect.right() - det_faces[i].rect.left()) * (
                        det_faces[i].rect.bottom() - det_faces[i].rect.top())
                    face_areas.append(face_area)
                largest_idx = face_areas.index(max(face_areas))
                self.det_faces = [det_faces[largest_idx]]
            else:
                self.det_faces = det_faces
        return len(self.det_faces)

    def get_face_landmarks_5(self):
        """Get 5 landmarks for each detected face."""
        for face in self.det_faces:
            shape = self.shape_predictor_5(self.input_img, face.rect)
            landmark = np.array([[part.x, part.y] for part in shape.parts()])
            self.all_landmarks_5.append(landmark)
        return len(self.all_landmarks_5)

    def get_face_landmarks_68(self):
        """Get 68 landmarks for each cropped face.

        Each cropped image should contain at most one face. Entries may be
        None for crops where detection fails; downstream consumers should
        skip those.
        """
        num_detected_face = 0
        for idx, face in enumerate(self.cropped_faces):
            # face detection
            det_face = self.face_detector(face, 1)  # TODO: can we remove it?
            if len(det_face) == 0:
                print(f'Cannot find faces in cropped image with index {idx}.')
                self.all_landmarks_68.append(None)
            else:
                if len(det_face) > 1:
                    print('Detected several faces in the cropped image. Using the '
                          'largest one. Note that it will also cause overlap '
                          'during paste_faces_to_input_image.')
                    face_areas = []
                    for i in range(len(det_face)):
                        face_area = (det_face[i].rect.right() - det_face[i].rect.left()) * (
                            det_face[i].rect.bottom() - det_face[i].rect.top())
                        face_areas.append(face_area)
                    largest_idx = face_areas.index(max(face_areas))
                    face_rect = det_face[largest_idx].rect
                else:
                    face_rect = det_face[0].rect
                shape = self.shape_predictor_68(face, face_rect)
                landmark = np.array([[part.x, part.y] for part in shape.parts()])
                self.all_landmarks_68.append(landmark)
                num_detected_face += 1
        return num_detected_face

    def warp_crop_faces(self, save_cropped_path=None, save_inverse_affine_path=None):
        """Get the affine matrices; warp and crop the faces.

        Also get the inverse affine matrices for post-processing.
        """
        for idx, landmark in enumerate(self.all_landmarks_5):
            # use the 5 landmarks to get the affine matrix
            self.similarity_trans.estimate(landmark, self.face_template)
            affine_matrix = self.similarity_trans.params[0:2, :]
            self.affine_matrices.append(affine_matrix)
            # warp and crop the face
            cropped_face = cv2.warpAffine(self.input_img, affine_matrix, self.face_size)
            self.cropped_faces.append(cropped_face)
            # save the cropped face
            if save_cropped_path is not None:
                path, ext = os.path.splitext(save_cropped_path)
                if self.save_png:
                    save_path = f'{path}_{idx:02d}.png'
                else:
                    save_path = f'{path}_{idx:02d}{ext}'
                imwrite(cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path)

            # get the inverse affine matrix, mapping the face_size template
            # space back to the (upscaled) input image space
            self.similarity_trans.estimate(self.face_template, landmark * self.upscale_factor)
            inverse_affine = self.similarity_trans.params[0:2, :]
            self.inverse_affine_matrices.append(inverse_affine)
            # save the inverse affine matrix
            if save_inverse_affine_path is not None:
                path, _ = os.path.splitext(save_inverse_affine_path)
                save_path = f'{path}_{idx:02d}.pth'
                torch.save(inverse_affine, save_path)
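                # the saved matrix can later be reloaded for post-processing
                # with torch.load(save_path)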

    def add_restored_face(self, face):
        # restored faces are expected in BGR order at face_size resolution,
        # since paste_faces_to_input_image composites them without conversion
        self.restored_faces.append(face)

    def paste_faces_to_input_image(self, save_path):
        # operate in the BGR order, as imwrite (cv2) expects BGR images
        input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR)
        h, w, _ = input_img.shape
        h_up, w_up = h * self.upscale_factor, w * self.upscale_factor
        # simply resize the background
        upsample_img = cv2.resize(input_img, (w_up, h_up))
        assert len(self.restored_faces) == len(self.inverse_affine_matrices), (
            'The lengths of restored_faces and inverse_affine_matrices are different.')
        for restored_face, inverse_affine in zip(self.restored_faces, self.inverse_affine_matrices):
            inv_restored = cv2.warpAffine(restored_face, inverse_affine, (w_up, h_up))
            mask = np.ones((*self.face_size, 3), dtype=np.float32)
            inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
            # remove the black borders introduced by warping
            inv_mask_erosion = cv2.erode(
                inv_mask, np.ones((2 * self.upscale_factor, 2 * self.upscale_factor), np.uint8))
            inv_restored_remove_border = inv_mask_erosion * inv_restored
            total_face_area = np.sum(inv_mask_erosion) // 3
            # compute the fusion edge width based on the face area
            w_edge = int(total_face_area**0.5) // 20
            erosion_radius = w_edge * 2
            inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
            blur_size = w_edge * 2
            # soften the mask edge with a Gaussian blur for seamless blending
            inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0)
            upsample_img = inv_soft_mask * inv_restored_remove_border + (1 - inv_soft_mask) * upsample_img
        if self.save_png:
            save_path = save_path.replace('.jpg', '.png').replace('.jpeg', '.png')
        imwrite(upsample_img.astype(np.uint8), save_path)

    def clean_all(self):
        self.all_landmarks_5 = []
        self.all_landmarks_68 = []
        self.restored_faces = []
        self.affine_matrices = []
        self.cropped_faces = []
        self.inverse_affine_matrices = []
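

# ----------------------------------------------------------------------------
# Minimal usage sketch. The image and model paths below are hypothetical
# placeholders; a real pipeline would replace the identity "restoration" with
# a face restoration network applied to each cropped face.
if __name__ == '__main__':
    helper = FaceRestorationHelper(upscale_factor=2, face_size=512)
    helper.init_dlib(
        detection_path='mmod_human_face_detector.dat',
        landmark5_path='shape_predictor_5_face_landmarks.dat',
        landmark68_path='shape_predictor_68_face_landmarks.dat')

    num_faces = helper.detect_faces('test.jpg', upsample_num_times=1, only_keep_largest=True)
    if num_faces > 0:
        helper.get_face_landmarks_5()
        helper.warp_crop_faces(save_cropped_path='results/cropped.png')
        # stand-in for a restoration network: pass each crop through
        # unchanged, converted to BGR as paste_faces_to_input_image expects
        for cropped_face in helper.cropped_faces:
            helper.add_restored_face(cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR))
        helper.paste_faces_to_input_image('results/test_restored.png')

    helper.free_dlib_gpu_memory()
    helper.clean_all()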