Spaces: Running on Zero

zhiweili committed · Commit 8ae56d4 · Parent(s): f93286f
init commit

Files changed:
- app.py +10 -0
- app_img2img.py +110 -0
- checkpoints/selfie_multiclass_256x256.tflite +3 -0
- croper.py +108 -0
- requirements.txt +8 -0
- segment_utils.py +88 -0
app.py
ADDED
@@ -0,0 +1,10 @@
import gradio as gr

from app_img2img import create_demo as create_demo_face

with gr.Blocks(css="style.css") as demo:
    with gr.Tabs():
        with gr.Tab(label="Face"):
            create_demo_face()

demo.launch()
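A note on structure: app.py only mounts the Face tab, and each feature module is expected to expose a create_demo() factory that builds its own gr.Blocks UI. A minimal sketch of how a second tab could be mounted the same way (the module app_hair2hair and its factory are hypothetical, not part of this commit):

# Hypothetical extension sketch: mount a second feature tab alongside "Face".
# app_hair2hair / create_demo_hair are assumed names, not files in this commit.
import gradio as gr

from app_img2img import create_demo as create_demo_face
from app_hair2hair import create_demo as create_demo_hair  # hypothetical module

with gr.Blocks(css="style.css") as demo:
    with gr.Tabs():
        with gr.Tab(label="Face"):
            create_demo_face()
        with gr.Tab(label="Hair"):  # hypothetical second tab
            create_demo_hair()

demo.launch()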
app_img2img.py
ADDED
@@ -0,0 +1,110 @@
import spaces
import gradio as gr
import time
import torch

from PIL import Image
from segment_utils import (
    segment_image,
    restore_result,
)
from diffusers import (
    StableDiffusionXLImg2ImgPipeline
)

BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

DEFAULT_EDIT_PROMPT = "a beautiful hollywood woman,photo,detailed,8k,high quality,highly detailed,high resolution"
DEFAULT_NEGATIVE_PROMPT = "nude, nudity, nsfw, nipple, Bare-chested, palm hand, hands, fingers, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, cloned face, disfigured"

DEFAULT_CATEGORY = "face"

basepipeline = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)

basepipeline = basepipeline.to(DEVICE)


@spaces.GPU(duration=15)
def image_to_image(
    input_image: Image,
    edit_prompt: str,
    seed: int,
    num_steps: int,
    guidance_scale: float,
):
    run_task_time = 0
    time_cost_str = ''
    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
    generator = torch.Generator(device=DEVICE).manual_seed(seed)
    generated_image = basepipeline(
        generator=generator,
        prompt=edit_prompt,
        negative_prompt=DEFAULT_NEGATIVE_PROMPT,
        image=input_image,
        guidance_scale=guidance_scale,
        num_inference_steps=num_steps,
    ).images[0]

    run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)

    return generated_image, time_cost_str

def get_time_cost(run_task_time, time_cost_str):
    now_time = int(time.time()*1000)
    if run_task_time == 0:
        time_cost_str = 'start'
    else:
        if time_cost_str != '':
            time_cost_str += '-->'
        time_cost_str += f'{now_time - run_task_time}'
    run_task_time = now_time
    return run_task_time, time_cost_str

def create_demo() -> gr.Blocks:
    with gr.Blocks() as demo:
        croper = gr.State()
        with gr.Row():
            with gr.Column():
                edit_prompt = gr.Textbox(lines=1, label="Edit Prompt", value=DEFAULT_EDIT_PROMPT)
                generate_size = gr.Number(label="Generate Size", value=1024)
                category = gr.Textbox(label="Category", value=DEFAULT_CATEGORY, visible=False)
            with gr.Column():
                num_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Num Steps")
                guidance_scale = gr.Slider(minimum=0, maximum=30, value=15, step=0.5, label="Guidance Scale")
                mask_expansion = gr.Number(label="Mask Expansion", value=300, visible=False)
            with gr.Column():
                mask_dilation = gr.Slider(minimum=0, maximum=10, value=2, step=1, label="Mask Dilation")
                seed = gr.Number(label="Seed", value=8)
                g_btn = gr.Button("Edit Image")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input Image", type="pil")
            with gr.Column():
                restored_image = gr.Image(label="Restored Image", type="pil", interactive=False)
            with gr.Column():
                origin_area_image = gr.Image(label="Origin Area Image", type="pil", interactive=False)
                generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
                generated_cost = gr.Textbox(label="Time cost by step (ms):", visible=True, interactive=False)

        g_btn.click(
            fn=segment_image,
            inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
            outputs=[origin_area_image, croper],
        ).success(
            fn=image_to_image,
            inputs=[origin_area_image, edit_prompt, seed, num_steps, guidance_scale],
            outputs=[generated_image, generated_cost],
        ).success(
            fn=restore_result,
            inputs=[croper, category, generated_image],
            outputs=[restored_image],
        )

    return demo
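The click handler chains three steps: segment_image crops the selected region to a square working image, image_to_image runs SDXL img2img on that crop, and restore_result pastes the edited crop back into the original photo. A minimal sketch of driving the same chain without the Gradio UI (the input/output file names are assumptions; importing app_img2img loads the SDXL pipeline at import time):

# Minimal sketch (assumed local paths, not part of the commit): the same
# segment -> img2img -> restore chain that the Gradio click handlers run,
# driven directly from a script.
from PIL import Image

from segment_utils import segment_image, restore_result
from app_img2img import image_to_image, DEFAULT_EDIT_PROMPT

input_image = Image.open("portrait.jpg").convert("RGB")  # hypothetical input file

# 1. Segment the face region and crop it to a square working image.
origin_area_image, croper = segment_image(
    input_image, category="face", generate_size=1024, mask_expansion=300, mask_dilation=2
)

# 2. Run SDXL img2img on the cropped square.
generated_image, time_cost = image_to_image(
    origin_area_image, DEFAULT_EDIT_PROMPT, seed=8, num_steps=30, guidance_scale=15.0
)

# 3. Paste the edited region back into the original photo using a fresh mask.
restored_image = restore_result(croper, "face", generated_image)
restored_image.save("restored.jpg")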
checkpoints/selfie_multiclass_256x256.tflite
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
size 16371837
croper.py
ADDED
@@ -0,0 +1,108 @@
import PIL
import numpy as np

from PIL import Image

class Croper:
    def __init__(
        self,
        input_image: PIL.Image,
        target_mask: np.ndarray,
        mask_size: int = 256,
        mask_expansion: int = 20,
    ):
        self.input_image = input_image
        self.target_mask = target_mask
        self.mask_size = mask_size
        self.mask_expansion = mask_expansion

    def corp_mask_image(self):
        target_mask = self.target_mask
        input_image = self.input_image
        mask_expansion = self.mask_expansion
        original_width, original_height = input_image.size
        mask_indices = np.where(target_mask)
        start_y = np.min(mask_indices[0])
        end_y = np.max(mask_indices[0])
        start_x = np.min(mask_indices[1])
        end_x = np.max(mask_indices[1])
        mask_height = end_y - start_y
        mask_width = end_x - start_x
        # choose the max side length
        max_side_length = max(mask_height, mask_width)
        # expand the mask area
        height_diff = (max_side_length - mask_height) // 2
        width_diff = (max_side_length - mask_width) // 2
        start_y = start_y - mask_expansion - height_diff
        if start_y < 0:
            start_y = 0
        end_y = end_y + mask_expansion + height_diff
        if end_y > original_height:
            end_y = original_height
        start_x = start_x - mask_expansion - width_diff
        if start_x < 0:
            start_x = 0
        end_x = end_x + mask_expansion + width_diff
        if end_x > original_width:
            end_x = original_width
        expanded_height = end_y - start_y
        expanded_width = end_x - start_x
        expanded_max_side_length = max(expanded_height, expanded_width)
        # calculate the crop area
        crop_mask = target_mask[start_y:end_y, start_x:end_x]
        crop_mask_start_y = (expanded_max_side_length - expanded_height) // 2
        crop_mask_end_y = crop_mask_start_y + expanded_height
        crop_mask_start_x = (expanded_max_side_length - expanded_width) // 2
        crop_mask_end_x = crop_mask_start_x + expanded_width
        # create a square mask
        square_mask = np.zeros((expanded_max_side_length, expanded_max_side_length), dtype=target_mask.dtype)
        square_mask[crop_mask_start_y:crop_mask_end_y, crop_mask_start_x:crop_mask_end_x] = crop_mask
        square_mask_image = Image.fromarray((square_mask * 255).astype(np.uint8))

        crop_image = input_image.crop((start_x, start_y, end_x, end_y))
        square_image = Image.new("RGB", (expanded_max_side_length, expanded_max_side_length))
        square_image.paste(crop_image, (crop_mask_start_x, crop_mask_start_y))

        self.origin_start_x = start_x
        self.origin_start_y = start_y
        self.origin_end_x = end_x
        self.origin_end_y = end_y

        self.square_start_x = crop_mask_start_x
        self.square_start_y = crop_mask_start_y
        self.square_end_x = crop_mask_end_x
        self.square_end_y = crop_mask_end_y

        self.square_length = expanded_max_side_length
        self.square_mask_image = square_mask_image
        self.square_image = square_image
        self.corp_mask = crop_mask

        mask_size = self.mask_size
        self.resized_square_mask_image = square_mask_image.resize((mask_size, mask_size))
        self.resized_square_image = square_image.resize((mask_size, mask_size))

        return self.resized_square_mask_image

    def restore_result(self, generated_image):
        square_length = self.square_length
        generated_image = generated_image.resize((square_length, square_length))
        square_mask_image = self.square_mask_image
        cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))
        cropped_square_mask_image = square_mask_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))

        restored_image = self.input_image.copy()
        restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y), cropped_square_mask_image)

        return restored_image

    def restore_result_v2(self, generated_image):
        square_length = self.square_length
        generated_image = generated_image.resize((square_length, square_length))
        cropped_generated_image = generated_image.crop((self.square_start_x, self.square_start_y, self.square_end_x, self.square_end_y))

        restored_image = self.input_image.copy()
        restored_image.paste(cropped_generated_image, (self.origin_start_x, self.origin_start_y))

        return restored_image
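corp_mask_image expands the mask's bounding box by mask_expansion pixels, pads the shorter side so the crop becomes square, clamps the box to the image borders, and records both the original and square coordinates so restore_result can paste the edit back in place. A small illustrative sketch with a synthetic mask (the image size and mask rectangle are made-up values, not from the commit):

# Illustrative sketch with a synthetic mask: the mask's bounding box is
# expanded to a square, cropped, resized to mask_size, and later pasted back
# at the recorded original coordinates.
import numpy as np
from PIL import Image

from croper import Croper

image = Image.new("RGB", (640, 480), "gray")   # placeholder photo
mask = np.zeros((480, 640), dtype=bool)
mask[100:300, 200:350] = True                  # pretend this is the face region

croper = Croper(image, mask, mask_size=256, mask_expansion=20)
resized_mask = croper.corp_mask_image()        # 256x256 square mask image
square_crop = croper.resized_square_image      # matching 256x256 image crop

# After editing square_crop elsewhere, the result can be resized back and
# pasted into the original photo:
restored = croper.restore_result(square_crop)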
requirements.txt
ADDED
@@ -0,0 +1,8 @@
gradio
torch
torchvision
diffusers
transformers
accelerate
mediapipe
spaces
segment_utils.py
ADDED
@@ -0,0 +1,88 @@
import numpy as np
import mediapipe as mp

from PIL import Image
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation
from croper import Croper

segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
base_options = python.BaseOptions(model_asset_path=segment_model)
options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
segmenter = vision.ImageSegmenter.create_from_options(options)

def restore_result(croper, category, generated_image):
    square_length = croper.square_length
    generated_image = generated_image.resize((square_length, square_length))

    cropped_generated_image = generated_image.crop((croper.square_start_x, croper.square_start_y, croper.square_end_x, croper.square_end_y))
    cropped_square_mask_image = get_restore_mask_image(croper, category, cropped_generated_image)

    restored_image = croper.input_image.copy()
    restored_image.paste(cropped_generated_image, (croper.origin_start_x, croper.origin_start_y), cropped_square_mask_image)

    return restored_image

def segment_image(input_image, category, generate_size, mask_expansion, mask_dilation):
    mask_size = int(generate_size)
    mask_expansion = int(mask_expansion)

    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, mask_dilation)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np, mask_dilation)
    elif category == "face":
        target_mask = get_face_mask(category_mask_np, mask_dilation)
    else:
        target_mask = get_face_mask(category_mask_np, mask_dilation)

    croper = Croper(input_image, target_mask, mask_size, mask_expansion)
    croper.corp_mask_image()
    origin_area_image = croper.resized_square_image

    return origin_area_image, croper

def get_face_mask(category_mask_np, dilation=1):
    face_skin_mask = category_mask_np == 3
    if dilation > 0:
        face_skin_mask = binary_dilation(face_skin_mask, iterations=dilation)

    return face_skin_mask

def get_clothes_mask(category_mask_np, dilation=1):
    body_skin_mask = category_mask_np == 2
    clothes_mask = category_mask_np == 4
    combined_mask = np.logical_or(body_skin_mask, clothes_mask)
    combined_mask = binary_dilation(combined_mask, iterations=4)
    if dilation > 0:
        combined_mask = binary_dilation(combined_mask, iterations=dilation)
    return combined_mask

def get_hair_mask(category_mask_np, dilation=1):
    hair_mask = category_mask_np == 1
    if dilation > 0:
        hair_mask = binary_dilation(hair_mask, iterations=dilation)
    return hair_mask

def get_restore_mask_image(croper, category, generated_image):
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(generated_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, 0)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np, 0)
    elif category == "face":
        target_mask = get_face_mask(category_mask_np, 0)

    combined_mask = np.logical_or(target_mask, croper.corp_mask)
    mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
    return mask_image
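The category indices tested above follow the MediaPipe selfie multiclass model: 0 background, 1 hair, 2 body-skin, 3 face-skin, 4 clothes, 5 others. A hedged inspection sketch (the input file name is an assumption, not part of the commit) that reuses the module's segmenter to count pixels per category:

# Inspection sketch (assumed local file "selfie.jpg"): run the same segmenter
# and count pixels per category to sanity-check the masks built above.
import numpy as np
import mediapipe as mp
from PIL import Image

from segment_utils import segmenter  # reuses the segmenter configured above

input_image = Image.open("selfie.jpg").convert("RGB")
image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
category_mask_np = segmenter.segment(image).category_mask.numpy_view()

# Category indices of the selfie multiclass model:
# 0 background, 1 hair, 2 body-skin, 3 face-skin, 4 clothes, 5 others.
labels = ["background", "hair", "body-skin", "face-skin", "clothes", "others"]
for idx, name in enumerate(labels):
    print(f"{name}: {np.sum(category_mask_np == idx)} pixels")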