zhiweili committed
Commit e99c825
1 Parent(s): 28e87e1

add app_gfp

Files changed (6):
  1. README.md +1 -0
  2. app.py +10 -0
  3. app_gfp.py +109 -0
  4. croper.py +108 -0
  5. requirements.txt +11 -0
  6. segment_utils.py +88 -0
README.md CHANGED
@@ -10,4 +10,5 @@ pinned: false
  license: mit
  ---
 
+ Modified from: https://huggingface.co/spaces/turboedit/turbo_edit
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,10 @@
+ import gradio as gr
+
+ from app_gfp import create_demo as create_demo_face
+
+ with gr.Blocks(css="style.css") as demo:
+     with gr.Tabs():
+         with gr.Tab(label="Face"):
+             create_demo_face()
+
+ demo.launch()
app_gfp.py ADDED
@@ -0,0 +1,109 @@
+ import os
+ import time
+ import spaces
+ import cv2
+ import gradio as gr
+
+ from gfpgan.utils import GFPGANer
+
+ os.system("pip freeze")  # log the environment for debugging
+
+ # download weights
+ if not os.path.exists('GFPGANv1.2.pth'):
+     os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth -P .")
+ if not os.path.exists('GFPGANv1.3.pth'):
+     os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth -P .")
+ if not os.path.exists('GFPGANv1.4.pth'):
+     os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -P .")
+ if not os.path.exists('RestoreFormer.pth'):
+     os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth -P .")
+ if not os.path.exists('CodeFormer.pth'):
+     os.system("wget https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth -P .")
+
+ @spaces.GPU(duration=10)
+ def enhance(
+     img_path: str,
+     version: str = 'v1.4',  # must match the radio choices below
+     scale: int = 2,
+ ):
+     run_task_time = 0
+     time_cost_str = ''
+     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+     img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+     if len(img.shape) == 3 and img.shape[2] == 4:
+         img_mode = 'RGBA'
+     elif len(img.shape) == 2:  # for gray inputs
+         img_mode = None
+         img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+     else:
+         img_mode = None
+
+     h, w = img.shape[0:2]
+     if h < 300:  # upsample small inputs before enhancement
+         img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)
+
+     if version == 'v1.2':
+         face_enhancer = GFPGANer(model_path='GFPGANv1.2.pth', upscale=2, arch='clean', channel_multiplier=2)
+     elif version == 'v1.3':
+         face_enhancer = GFPGANer(model_path='GFPGANv1.3.pth', upscale=2, arch='clean', channel_multiplier=2)
+     elif version == 'v1.4':
+         face_enhancer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=2, arch='clean', channel_multiplier=2)
+     elif version == 'RestoreFormer':
+         face_enhancer = GFPGANer(model_path='RestoreFormer.pth', upscale=2, arch='RestoreFormer', channel_multiplier=2)
+     elif version == 'CodeFormer':
+         # caveat: upstream GFPGANer only documents the 'clean', 'original', 'bilinear'
+         # and 'RestoreFormer' archs; these two branches may need dedicated runners
+         face_enhancer = GFPGANer(model_path='CodeFormer.pth', upscale=2, arch='CodeFormer', channel_multiplier=2)
+     elif version == 'RealESR-General-x4v3':
+         face_enhancer = GFPGANer(model_path='realesr-general-x4v3.pth', upscale=2, arch='realesr-general', channel_multiplier=2)
+     else:
+         raise gr.Error(f'Unknown version: {version}')
+
+     _, _, output = face_enhancer.enhance(img, has_aligned=False, only_center_face=True, paste_back=True)
+     if scale != 2:
+         interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
+         h, w = img.shape[0:2]
+         output = cv2.resize(output, (int(w * scale / 2), int(h * scale / 2)), interpolation=interpolation)
+
+     # RGBA images should be saved in png format
+     extension = 'png' if img_mode == 'RGBA' else 'jpg'
+     os.makedirs('output', exist_ok=True)
+     save_path = f'output/out.{extension}'
+     cv2.imwrite(save_path, output)
+
+     output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
+     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+     return output, save_path, time_cost_str
+
+
+ def get_time_cost(run_task_time, time_cost_str):
+     now_time = int(time.time() * 1000)
+     if run_task_time == 0:
+         time_cost_str = 'start'
+     else:
+         if time_cost_str != '':
+             time_cost_str += '-->'
+         time_cost_str += f'{now_time - run_task_time}'
+     run_task_time = now_time
+     return run_task_time, time_cost_str
+
+ def create_demo() -> gr.Blocks:
+     with gr.Blocks() as demo:
+         with gr.Row():
+             with gr.Column():
+                 version = gr.Radio(['v1.2', 'v1.3', 'v1.4', 'RestoreFormer', 'CodeFormer', 'RealESR-General-x4v3'], type="value", value='v1.4', label='version')
+                 scale = gr.Number(label="Rescaling factor", value=2)
+             with gr.Column():
+                 g_btn = gr.Button("Enhance")
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(label="Input Image", type="filepath")
+             with gr.Column():
+                 restored_image = gr.Image(label="Restored Image", type="numpy", interactive=False)
+                 download_path = gr.File(label="Download the output image", interactive=False)
+                 restored_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
+
+         g_btn.click(
+             fn=enhance,
+             inputs=[input_image, version, scale],
+             outputs=[restored_image, download_path, restored_cost],
+         )
+
+     return demo
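
For a quick smoke test outside the UI, enhance can be called directly. A minimal sketch, assuming the weights above have been downloaded, the spaces package is installed (its GPU decorator should be a pass-through outside a ZeroGPU Space), and a hypothetical local photo face.jpg:

    from app_gfp import enhance

    restored_rgb, save_path, timings = enhance("face.jpg", version="v1.4", scale=2)
    print(save_path)  # output/out.jpg
    print(timings)    # e.g. "start-->1234" (per-step cost in ms)
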
croper.py ADDED
@@ -0,0 +1,108 @@
+ import PIL
+ import numpy as np
+
+ from PIL import Image
+
+ class Croper:
+     def __init__(
+         self,
+         input_image: PIL.Image.Image,
+         target_mask: np.ndarray,
+         mask_size: int = 256,
+         mask_expansion: int = 20,
+     ):
+         self.input_image = input_image
+         self.target_mask = target_mask
+         self.mask_size = mask_size
+         self.mask_expansion = mask_expansion
+
+     def crop_mask_image(self):
+         target_mask = self.target_mask
+         input_image = self.input_image
+         mask_expansion = self.mask_expansion
+         original_width, original_height = input_image.size
+         mask_indices = np.where(target_mask)
+         start_y = np.min(mask_indices[0])
+         end_y = np.max(mask_indices[0])
+         start_x = np.min(mask_indices[1])
+         end_x = np.max(mask_indices[1])
+         mask_height = end_y - start_y
+         mask_width = end_x - start_x
+         # choose the max side length
+         max_side_length = max(mask_height, mask_width)
+         # expand the mask area, clamped to the image bounds
+         height_diff = (max_side_length - mask_height) // 2
+         width_diff = (max_side_length - mask_width) // 2
+         start_y = max(start_y - mask_expansion - height_diff, 0)
+         end_y = min(end_y + mask_expansion + height_diff, original_height)
+         start_x = max(start_x - mask_expansion - width_diff, 0)
+         end_x = min(end_x + mask_expansion + width_diff, original_width)
+         expanded_height = end_y - start_y
+         expanded_width = end_x - start_x
+         expanded_max_side_length = max(expanded_height, expanded_width)
+         # calculate the crop area
+         crop_mask = target_mask[start_y:end_y, start_x:end_x]
+         crop_mask_start_y = (expanded_max_side_length - expanded_height) // 2
+         crop_mask_end_y = crop_mask_start_y + expanded_height
+         crop_mask_start_x = (expanded_max_side_length - expanded_width) // 2
+         crop_mask_end_x = crop_mask_start_x + expanded_width
+         # create a square mask with the crop centered in it
+         square_mask = np.zeros((expanded_max_side_length, expanded_max_side_length), dtype=target_mask.dtype)
+         square_mask[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
+         square_mask_image = Image.fromarray((square_mask * 255).astype(np.uint8))
+
+         crop_image = input_image.crop((start_x, start_y, end_x, end_y))
+         square_image = Image.new("RGB", (expanded_max_side_length, expanded_max_side_length))
+         square_image.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
+
+         self.origin_start_x = start_x
+         self.origin_start_y = start_y
+         self.origin_end_x = end_x
+         self.origin_end_y = end_y
+
+         self.square_start_x = crop_mask_start_x
+         self.square_start_y = crop_mask_start_y
+         self.square_end_x = crop_mask_end_x
+         self.square_end_y = crop_mask_end_y
+
+         self.square_length = expanded_max_side_length
+         self.square_mask_image = square_mask_image
+         self.square_image = square_image
+         self.crop_mask = crop_mask
+
+         mask_size = self.mask_size
+         self.resized_square_mask_image = square_mask_image.resize((mask_size, mask_size))
+         self.resized_square_image = square_image.resize((mask_size, mask_size))
+
+         return self.resized_square_mask_image
+
+     def restore_result(self, generated_image):
+         square_length = self.square_length
+         generated_image = generated_image.resize((square_length, square_length))
+         square_mask_image = self.square_mask_image
+         cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
+         cropped_square_mask_image = square_mask_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
+
+         # paste the generated crop back, masked to the segmented region
+         restored_image = self.input_image.copy()
+         restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y), cropped_square_mask_image)
+
+         return restored_image
+
+     def restore_result_v2(self, generated_image):
+         square_length = self.square_length
+         generated_image = generated_image.resize((square_length, square_length))
+         cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
+
+         # paste the whole generated crop back without a mask
+         restored_image = self.input_image.copy()
+         restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y))
+
+         return restored_image
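
A minimal round-trip sketch for Croper, using only numpy and Pillow; the image size and the rectangular mask are made up for illustration:

    import numpy as np
    from PIL import Image
    from croper import Croper

    image = Image.new("RGB", (640, 480), "gray")
    mask = np.zeros((480, 640), dtype=bool)
    mask[120:360, 200:440] = True  # stand-in for a segmented face region

    croper = Croper(image, mask, mask_size=256, mask_expansion=20)
    croper.crop_mask_image()                    # caches the crop geometry on the instance
    model_input = croper.resized_square_image   # 256x256 square crop for an image model
    restored = croper.restore_result_v2(model_input)  # pastes the (here unedited) crop back
    restored.save("round_trip.jpg")
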
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ ml-collections
+ gradio
+ torch
+ diffusers
+ transformers
+ accelerate
+ mediapipe
+ spaces
+ sentencepiece
+ compel
+ gfpgan
segment_utils.py ADDED
@@ -0,0 +1,88 @@
+ import numpy as np
+ import mediapipe as mp
+
+ from PIL import Image
+ from mediapipe.tasks import python
+ from mediapipe.tasks.python import vision
+ from scipy.ndimage import binary_dilation
+ from croper import Croper
+
+ segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
+ base_options = python.BaseOptions(model_asset_path=segment_model)
+ options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
+ segmenter = vision.ImageSegmenter.create_from_options(options)
+
+ def restore_result(croper, category, generated_image):
+     square_length = croper.square_length
+     generated_image = generated_image.resize((square_length, square_length))
+
+     cropped_generated_image = generated_image.crop((croper.square_start_x, croper.square_start_y, croper.square_end_x, croper.square_end_y))
+     cropped_square_mask_image = get_restore_mask_image(croper, category, cropped_generated_image)
+
+     restored_image = croper.input_image.copy()
+     restored_image.paste(cropped_generated_image, (croper.origin_start_x, croper.origin_start_y), cropped_square_mask_image)
+
+     return restored_image
+
+ def segment_image(input_image, category, generate_size, mask_expansion, mask_dilation):
+     mask_size = int(generate_size)
+     mask_expansion = int(mask_expansion)
+     mask_dilation = int(mask_dilation)
+
+     image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
+     segmentation_result = segmenter.segment(image)
+     category_mask = segmentation_result.category_mask
+     category_mask_np = category_mask.numpy_view()
+
+     if category == "hair":
+         target_mask = get_hair_mask(category_mask_np, mask_dilation)
+     elif category == "clothes":
+         target_mask = get_clothes_mask(category_mask_np, mask_dilation)
+     else:  # "face" and any unrecognized category fall back to the face mask
+         target_mask = get_face_mask(category_mask_np, mask_dilation)
+
+     croper = Croper(input_image, target_mask, mask_size, mask_expansion)
+     croper.crop_mask_image()
+     origin_area_image = croper.resized_square_image
+
+     return origin_area_image, croper
+
+ def get_face_mask(category_mask_np, dilation=1):
+     # category 3 is face skin in the selfie multiclass model
+     face_skin_mask = category_mask_np == 3
+     if dilation > 0:
+         face_skin_mask = binary_dilation(face_skin_mask, iterations=dilation)
+
+     return face_skin_mask
+
+ def get_clothes_mask(category_mask_np, dilation=1):
+     # category 2 is body skin, category 4 is clothes
+     body_skin_mask = category_mask_np == 2
+     clothes_mask = category_mask_np == 4
+     combined_mask = np.logical_or(body_skin_mask, clothes_mask)
+     combined_mask = binary_dilation(combined_mask, iterations=4)
+     if dilation > 0:
+         combined_mask = binary_dilation(combined_mask, iterations=dilation)
+     return combined_mask
+
+ def get_hair_mask(category_mask_np, dilation=1):
+     # category 1 is hair
+     hair_mask = category_mask_np == 1
+     if dilation > 0:
+         hair_mask = binary_dilation(hair_mask, iterations=dilation)
+     return hair_mask
+
+ def get_restore_mask_image(croper, category, generated_image):
+     image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(generated_image))
+     segmentation_result = segmenter.segment(image)
+     category_mask = segmentation_result.category_mask
+     category_mask_np = category_mask.numpy_view()
+
+     if category == "hair":
+         target_mask = get_hair_mask(category_mask_np, 0)
+     elif category == "clothes":
+         target_mask = get_clothes_mask(category_mask_np, 0)
+     else:  # fall back to the face mask so target_mask is always bound
+         target_mask = get_face_mask(category_mask_np, 0)
+
+     combined_mask = np.logical_or(target_mask, croper.crop_mask)
+     mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
+     return mask_image
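
A minimal sketch of the whole segment-edit-restore flow, assuming the MediaPipe model file exists at checkpoints/selfie_multiclass_256x256.tflite and a hypothetical portrait.jpg contains a person; the identity "edit" stands in for a real image-to-image model:

    from PIL import Image
    from segment_utils import segment_image, restore_result

    input_image = Image.open("portrait.jpg").convert("RGB")
    cropped, croper = segment_image(
        input_image, "face", generate_size=512, mask_expansion=20, mask_dilation=2
    )
    edited = cropped  # replace with a real generator call on the 512x512 crop
    restored = restore_result(croper, "face", edited)
    restored.save("restored.jpg")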