zhiweili committed
Commit 012d7e8 · Parent(s): bf2ad73

only convert mask area to gray

Files changed:
- app_haircolor.py (+2 -2)
- inversion_run_adapter.py (+0 -7)
- segment_utils.py (+27 -0)
app_haircolor.py CHANGED

@@ -5,7 +5,7 @@ import torch
 
 from PIL import Image
 from segment_utils import(
-    segment_image,
+    segment_image_with_gray,
     restore_result,
 )
 from enhance_utils import enhance_image

@@ -111,7 +111,7 @@ def create_demo() -> gr.Blocks:
     generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
 
     g_btn.click(
-        fn=segment_image,
+        fn=segment_image_with_gray,
         inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
         outputs=[origin_area_image, croper],
     ).success(
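The UI change is just the entry point: g_btn now calls segment_image_with_gray with the same inputs and outputs, so the crop passed on to generation already has the masked region desaturated.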
inversion_run_adapter.py CHANGED

@@ -177,9 +177,6 @@ def run(
         adapter_conditioning_scale=conditioning_scale,
     )
 
-    # convert to grayscale
-    input_image = convert_to_grayscale(input_image)
-
     x_0_image = input_image
     x_0 = encode_image(x_0_image, pipeline)
     x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)

@@ -283,7 +280,3 @@ def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=N
         return timesteps, num_inference_steps
 
     return timesteps, num_inference_steps - t_start
-
-def convert_to_grayscale(pil_image: Image):
-    gray_image = pil_image.convert('L')
-    return Image.merge('RGB', (gray_image, gray_image, gray_image))
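With the desaturation now handled at segmentation time, run() no longer converts the whole input image to grayscale before encoding x_0, and the unused convert_to_grayscale helper (which stacked the single L channel back into three RGB channels) is removed.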
segment_utils.py CHANGED

@@ -58,6 +58,33 @@ def segment_image(input_image, category, input_size, mask_expansion, mask_dilati
 
     return origin_area_image, croper
 
+def segment_image_with_gray(input_image, category, input_size, mask_expansion, mask_dilation):
+    mask_size = int(input_size)
+    mask_expansion = int(mask_expansion)
+
+    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
+    segmentation_result = segmenter.segment(image)
+    category_mask = segmentation_result.category_mask
+    category_mask_np = category_mask.numpy_view()
+
+    if category == "hair":
+        target_mask = get_hair_mask(category_mask_np, mask_dilation)
+    elif category == "clothes":
+        target_mask = get_clothes_mask(category_mask_np, mask_dilation)
+    elif category == "face":
+        target_mask = get_face_mask(category_mask_np, mask_dilation)
+    else:
+        target_mask = get_face_mask(category_mask_np, mask_dilation)
+
+    croper = Croper(input_image, target_mask, mask_size, mask_expansion)
+    croper.corp_mask_image()
+    origin_area_image = croper.resized_square_image
+    mask_image = croper.resized_square_mask_image
+    gray_area_image = origin_area_image.convert('L')
+    origin_area_image.paste(gray_area_image, (0, 0), mask_image)
+
+    return origin_area_image, croper
+
 def get_face_mask(category_mask_np, dilation=1):
     face_skin_mask = category_mask_np == 3
     if dilation > 0:
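The heart of the change is the masked paste at the end of segment_image_with_gray: PIL's Image.paste takes an optional mask and only copies pixels where the mask is non-zero, which is what restricts the grayscale conversion to the segmented area. (segmenter, mp, np, Croper, and the get_*_mask helpers are existing module-level names in segment_utils.py; the function otherwise mirrors segment_image.) A minimal standalone sketch of the paste step, with hypothetical file names standing in for the Croper outputs:

from PIL import Image

# Hypothetical stand-ins for croper.resized_square_image and
# croper.resized_square_mask_image: an RGB photo plus an L-mode mask
# that is white (255) over the region to desaturate.
image = Image.open("portrait.png").convert("RGB")
mask = Image.open("hair_mask.png").convert("L")

gray = image.convert("L")        # single-channel grayscale copy
image.paste(gray, (0, 0), mask)  # only pixels under the mask are replaced
image.save("portrait_gray_hair.png")

Intermediate mask values blend the two layers proportionally, so a feathered mask yields a soft transition at the boundary rather than a hard edge.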