zhiweili committed
Commit 8ae56d4
1 Parent(s): f93286f

init commit

app.py ADDED
@@ -0,0 +1,10 @@
+ import gradio as gr
+
+ from app_img2img import create_demo as create_demo_face
+
+ with gr.Blocks(css="style.css") as demo:
+     with gr.Tabs():
+         with gr.Tab(label="Face"):
+             create_demo_face()
+
+ demo.launch()
app_img2img.py ADDED
@@ -0,0 +1,110 @@
+ import spaces
+ import gradio as gr
+ import time
+ import torch
+
+ from PIL import Image
+ from segment_utils import (
+     segment_image,
+     restore_result,
+ )
+ from diffusers import (
+     StableDiffusionXLImg2ImgPipeline
+ )
+
+ BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ DEFAULT_EDIT_PROMPT = "a beautiful hollywood woman,photo,detailed,8k,high quality,highly detailed,high resolution"
+ DEFAULT_NEGATIVE_PROMPT = "nude, nudity, nsfw, nipple, Bare-chested, palm hand, hands, fingers, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, cloned face, disfigured"
+
+ DEFAULT_CATEGORY = "face"
+
+ basepipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+     BASE_MODEL,
+     torch_dtype=torch.float16,
+     variant="fp16",
+     use_safetensors=True,
+ )
+
+ basepipeline = basepipeline.to(DEVICE)
+
+
+ @spaces.GPU(duration=15)
+ def image_to_image(
+     input_image: Image,
+     edit_prompt: str,
+     seed: int,
+     num_steps: int,
+     guidance_scale: float,
+ ):
+     run_task_time = 0
+     time_cost_str = ''
+     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+     generator = torch.Generator(device=DEVICE).manual_seed(seed)
+     generated_image = basepipeline(
+         generator=generator,
+         prompt=edit_prompt,
+         negative_prompt=DEFAULT_NEGATIVE_PROMPT,
+         image=input_image,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_steps,
+     ).images[0]
+
+     run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
+
+     return generated_image, time_cost_str
+
+ def get_time_cost(run_task_time, time_cost_str):
+     now_time = int(time.time()*1000)
+     if run_task_time == 0:
+         time_cost_str = 'start'
+     else:
+         if time_cost_str != '':
+             time_cost_str += f'-->'
+         time_cost_str += f'{now_time - run_task_time}'
+     run_task_time = now_time
+     return run_task_time, time_cost_str
+
+ def create_demo() -> gr.Blocks:
+     with gr.Blocks() as demo:
+         croper = gr.State()
+         with gr.Row():
+             with gr.Column():
+                 edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
+                 generate_size = gr.Number(label="Generate Size", value=1024)
+                 category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
+             with gr.Column():
+                 num_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Num Steps")
+                 guidance_scale = gr.Slider(minimum=0, maximum=30, value=15, step=0.5, label="Guidance Scale")
+                 mask_expansion = gr.Number(label="Mask Expansion", value=300, visible=False)
+             with gr.Column():
+                 mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
+                 seed = gr.Number(label="Seed", value=8)
+                 g_btn = gr.Button("Edit Image")
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(label="Input Image", type="pil")
+             with gr.Column():
+                 restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
+             with gr.Column():
+                 origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
+                 generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
+                 generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)
+
+         g_btn.click(
+             fn=segment_image,
+             inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
+             outputs=[origin_area_image, croper],
+         ).success(
+             fn=image_to_image,
+             inputs=[origin_area_image, edit_prompt, seed, num_steps, guidance_scale],
+             outputs=[generated_image, generated_cost],
+         ).success(
+             fn=restore_result,
+             inputs=[croper, category, generated_image],
+             outputs=[restored_image],
+         )
+
+     return demo
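
For quick experiments outside the Gradio UI, the three callbacks chained on g_btn can also be driven directly. A minimal sketch, not part of the commit, assuming a local portrait saved as input.jpg, a CUDA machine for the fp16 SDXL pipeline that loads at import time, and that the @spaces.GPU decorator is a no-op when not running on Spaces; the parameter values below are simply the UI defaults from this file:

    from PIL import Image

    from app_img2img import image_to_image, DEFAULT_EDIT_PROMPT
    from segment_utils import segment_image, restore_result

    # Crop a square region around the detected face, edit it, then paste it back.
    source = Image.open("input.jpg").convert("RGB")  # assumed local file
    origin_area_image, croper = segment_image(source, "face", 1024, 300, 2)
    edited, time_cost = image_to_image(origin_area_image, DEFAULT_EDIT_PROMPT, 8, 30, 15)
    result = restore_result(croper, "face", edited)
    result.save("output.jpg")
    print(time_cost)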
checkpoints/selfie_multiclass_256x256.tflite ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
+ size 16371837
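
This is a Git LFS pointer rather than the segmentation model itself; the selfie_multiclass_256x256.tflite weights are fetched by LFS when the repository is cloned. An optional local check against the pointer values above, useful if the file came through as pointer text instead of the model:

    import hashlib, os

    # Values taken directly from the LFS pointer in this commit.
    path = "checkpoints/selfie_multiclass_256x256.tflite"
    assert os.path.getsize(path) == 16371837, "size does not match the LFS pointer"
    digest = hashlib.sha256(open(path, "rb").read()).hexdigest()
    assert digest == "c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0"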
croper.py ADDED
@@ -0,0 +1,108 @@
+ import PIL
+ import numpy as np
+
+ from PIL import Image
+
+ class Croper:
+     def __init__(
+         self,
+         input_image: PIL.Image,
+         target_mask: np.ndarray,
+         mask_size: int = 256,
+         mask_expansion: int = 20,
+     ):
+         self.input_image = input_image
+         self.target_mask = target_mask
+         self.mask_size = mask_size
+         self.mask_expansion = mask_expansion
+
+     def corp_mask_image(self):
+         target_mask = self.target_mask
+         input_image = self.input_image
+         mask_expansion = self.mask_expansion
+         original_width, original_height = input_image.size
+         mask_indices = np.where(target_mask)
+         start_y = np.min(mask_indices[0])
+         end_y = np.max(mask_indices[0])
+         start_x = np.min(mask_indices[1])
+         end_x = np.max(mask_indices[1])
+         mask_height = end_y - start_y
+         mask_width = end_x - start_x
+         # choose the max side length
+         max_side_length = max(mask_height, mask_width)
+         # expand the mask area
+         height_diff = (max_side_length - mask_height) // 2
+         width_diff = (max_side_length - mask_width) // 2
+         start_y = start_y - mask_expansion - height_diff
+         if start_y < 0:
+             start_y = 0
+         end_y = end_y + mask_expansion + height_diff
+         if end_y > original_height:
+             end_y = original_height
+         start_x = start_x - mask_expansion - width_diff
+         if start_x < 0:
+             start_x = 0
+         end_x = end_x + mask_expansion + width_diff
+         if end_x > original_width:
+             end_x = original_width
+         expanded_height = end_y - start_y
+         expanded_width = end_x - start_x
+         expanded_max_side_length = max(expanded_height, expanded_width)
+         # calculate the crop area
+         crop_mask = target_mask[start_y:end_y, start_x:end_x]
+         crop_mask_start_y = (expanded_max_side_length - expanded_height) // 2
+         crop_mask_end_y = crop_mask_start_y + expanded_height
+         crop_mask_start_x = (expanded_max_side_length - expanded_width) // 2
+         crop_mask_end_x = crop_mask_start_x + expanded_width
+         # create a square mask
+         square_mask = np.zeros((expanded_max_side_length, expanded_max_side_length), dtype=target_mask.dtype)
+         square_mask[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
+         square_mask_image = Image.fromarray((square_mask * 255).astype(np.uint8))
+
+         crop_image = input_image.crop((start_x, start_y, end_x, end_y))
+         square_image = Image.new("RGB", (expanded_max_side_length, expanded_max_side_length))
+         square_image.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))
+
+         self.origin_start_x = start_x
+         self.origin_start_y = start_y
+         self.origin_end_x = end_x
+         self.origin_end_y = end_y
+
+         self.square_start_x = crop_mask_start_x
+         self.square_start_y = crop_mask_start_y
+         self.square_end_x = crop_mask_end_x
+         self.square_end_y = crop_mask_end_y
+
+         self.square_length = expanded_max_side_length
+         self.square_mask_image = square_mask_image
+         self.square_image = square_image
+         self.corp_mask = crop_mask
+
+         mask_size = self.mask_size
+         self.resized_square_mask_image = square_mask_image.resize((mask_size, mask_size))
+         self.resized_square_image = square_image.resize((mask_size, mask_size))
+
+         return self.resized_square_mask_image
+
+     def restore_result(self, generated_image):
+         square_length = self.square_length
+         generated_image = generated_image.resize((square_length, square_length))
+         square_mask_image = self.square_mask_image
+         cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
+         cropped_square_mask_image = square_mask_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
+
+         restored_image = self.input_image.copy()
+         restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y), cropped_square_mask_image)
+
+         return restored_image
+
+     def restore_result_v2(self, generated_image):
+         square_length = self.square_length
+         generated_image = generated_image.resize((square_length, square_length))
+         cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
+
+         restored_image = self.input_image.copy()
+         restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y))
+
+         return restored_image
+
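
Croper is only consumed through segment_utils, but the squaring-and-restore geometry can be exercised on its own. An illustrative sketch with a synthetic rectangular mask (nothing below comes from the commit):

    import numpy as np
    from PIL import Image

    from croper import Croper

    # Pretend a segmenter marked rows 100..300 and columns 200..400 of a 640x480 image.
    image = Image.new("RGB", (640, 480), "white")
    mask = np.zeros((480, 640), dtype=bool)
    mask[100:300, 200:400] = True

    croper = Croper(image, mask, mask_size=256, mask_expansion=20)
    resized_mask = croper.corp_mask_image()    # 256x256 "L" mask of the squared crop
    square_crop = croper.resized_square_image  # matching 256x256 RGB crop
    # croper.restore_result(edited_image) pastes an edited crop back into the original.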
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio
+ torch
+ torchvision
+ diffusers
+ transformers
+ accelerate
+ mediapipe
+ spaces
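
One note on this list: segment_utils.py imports binary_dilation from scipy.ndimage, so the Space is relying on scipy being pulled in transitively by one of these packages. If it is not, the fix is a one-line addition to this file (hypothetical, not part of this commit):

    + scipy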
segment_utils.py ADDED
@@ -0,0 +1,88 @@
+ import numpy as np
+ import mediapipe as mp
+
+ from PIL import Image
+ from mediapipe.tasks import python
+ from mediapipe.tasks.python import vision
+ from scipy.ndimage import binary_dilation
+ from croper import Croper
+
+ segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
+ base_options = python.BaseOptions(model_asset_path=segment_model)
+ options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
+ segmenter = vision.ImageSegmenter.create_from_options(options)
+
+ def restore_result(croper, category, generated_image):
+     square_length = croper.square_length
+     generated_image = generated_image.resize((square_length, square_length))
+
+     cropped_generated_image = generated_image.crop((croper.square_start_x, croper.square_start_y, croper.square_end_x, croper.square_end_y))
+     cropped_square_mask_image = get_restore_mask_image(croper, category, cropped_generated_image)
+
+     restored_image = croper.input_image.copy()
+     restored_image.paste(cropped_generated_image, (croper.origin_start_x, croper.origin_start_y), cropped_square_mask_image)
+
+     return restored_image
+
+ def segment_image(input_image, category, generate_size, mask_expansion, mask_dilation):
+     mask_size = int(generate_size)
+     mask_expansion = int(mask_expansion)
+
+     image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
+     segmentation_result = segmenter.segment(image)
+     category_mask = segmentation_result.category_mask
+     category_mask_np = category_mask.numpy_view()
+
+     if category == "hair":
+         target_mask = get_hair_mask(category_mask_np, mask_dilation)
+     elif category == "clothes":
+         target_mask = get_clothes_mask(category_mask_np, mask_dilation)
+     elif category == "face":
+         target_mask = get_face_mask(category_mask_np, mask_dilation)
+     else:
+         target_mask = get_face_mask(category_mask_np, mask_dilation)
+
+     croper = Croper(input_image, target_mask, mask_size, mask_expansion)
+     croper.corp_mask_image()
+     origin_area_image = croper.resized_square_image
+
+     return origin_area_image, croper
+
+ def get_face_mask(category_mask_np, dilation=1):
+     face_skin_mask = category_mask_np == 3
+     if dilation > 0:
+         face_skin_mask = binary_dilation(face_skin_mask, iterations=dilation)
+
+     return face_skin_mask
+
+ def get_clothes_mask(category_mask_np, dilation=1):
+     body_skin_mask = category_mask_np == 2
+     clothes_mask = category_mask_np == 4
+     combined_mask = np.logical_or(body_skin_mask, clothes_mask)
+     combined_mask = binary_dilation(combined_mask, iterations=4)
+     if dilation > 0:
+         combined_mask = binary_dilation(combined_mask, iterations=dilation)
+     return combined_mask
+
+ def get_hair_mask(category_mask_np, dilation=1):
+     hair_mask = category_mask_np == 1
+     if dilation > 0:
+         hair_mask = binary_dilation(hair_mask, iterations=dilation)
+     return hair_mask
+
+ def get_restore_mask_image(croper, category, generated_image):
+     image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(generated_image))
+     segmentation_result = segmenter.segment(image)
+     category_mask = segmentation_result.category_mask
+     category_mask_np = category_mask.numpy_view()
+
+     if category == "hair":
+         target_mask = get_hair_mask(category_mask_np, 0)
+     elif category == "clothes":
+         target_mask = get_clothes_mask(category_mask_np, 0)
+     elif category == "face":
+         target_mask = get_face_mask(category_mask_np, 0)
+
+     combined_mask = np.logical_or(target_mask, croper.corp_mask)
+     mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
+     return mask_image
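
For reference, the integer comparisons above follow the category indexing documented for MediaPipe's selfie_multiclass_256x256 model (worth re-checking against the model card): 0 background, 1 hair, 2 body-skin, 3 face-skin, 4 clothes, 5 others. Note also that get_restore_mask_image has no fallback branch, so a category outside hair/clothes/face would raise a NameError; in practice the UI pins the category to "face". A hypothetical helper in the same style, showing how a combined hair-plus-face mask could be built (not part of the commit):

    import numpy as np
    from scipy.ndimage import binary_dilation

    def get_head_mask(category_mask_np, dilation=1):
        # Hypothetical: hair (index 1) OR face skin (index 3), dilated like the other helpers.
        head_mask = np.logical_or(category_mask_np == 1, category_mask_np == 3)
        if dilation > 0:
            head_mask = binary_dilation(head_mask, iterations=dilation)
        return head_mask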