zhiweili commited on
Commit
012d7e8
1 Parent(s): bf2ad73

only convert mask area to gray

Browse files
Files changed (3) hide show
  1. app_haircolor.py +2 -2
  2. inversion_run_adapter.py +0 -7
  3. segment_utils.py +27 -0
app_haircolor.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
 
6
  from PIL import Image
7
  from segment_utils import(
8
- segment_image,
9
  restore_result,
10
  )
11
  from enhance_utils import enhance_image
@@ -111,7 +111,7 @@ def create_demo() -> gr.Blocks:
111
  generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
112
 
113
  g_btn.click(
114
- fn=segment_image,
115
  inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
116
  outputs=[origin_area_image, croper],
117
  ).success(
 
5
 
6
  from PIL import Image
7
  from segment_utils import(
8
+ segment_image_with_gray,
9
  restore_result,
10
  )
11
  from enhance_utils import enhance_image
 
111
  generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)
112
 
113
  g_btn.click(
114
+ fn=segment_image_with_gray,
115
  inputs=[input_image, category, generate_size, mask_expansion, mask_dilation],
116
  outputs=[origin_area_image, croper],
117
  ).success(
inversion_run_adapter.py CHANGED
@@ -177,9 +177,6 @@ def run(
177
  adapter_conditioning_scale=conditioning_scale,
178
  )
179
 
180
- # convert to grayscale
181
- input_image = convert_to_grayscale(input_image)
182
-
183
  x_0_image = input_image
184
  x_0 = encode_image(x_0_image, pipeline)
185
  x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
@@ -283,7 +280,3 @@ def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=N
283
  return timesteps, num_inference_steps
284
 
285
  return timesteps, num_inference_steps - t_start
286
-
287
- def convert_to_grayscale(pil_image: Image):
288
- gray_image = pil_image.convert('L')
289
- return Image.merge('RGB', (gray_image, gray_image, gray_image))
 
177
  adapter_conditioning_scale=conditioning_scale,
178
  )
179
 
 
 
 
180
  x_0_image = input_image
181
  x_0 = encode_image(x_0_image, pipeline)
182
  x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
 
280
  return timesteps, num_inference_steps
281
 
282
  return timesteps, num_inference_steps - t_start
 
 
 
 
segment_utils.py CHANGED
@@ -58,6 +58,33 @@ def segment_image(input_image, category, input_size, mask_expansion, mask_dilati
58
 
59
  return origin_area_image, croper
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def get_face_mask(category_mask_np, dilation=1):
62
  face_skin_mask = category_mask_np == 3
63
  if dilation > 0:
 
58
 
59
  return origin_area_image, croper
60
 
61
+ def segment_image_with_gray(input_image, category, input_size, mask_expansion, mask_dilation):
62
+ mask_size = int(input_size)
63
+ mask_expansion = int(mask_expansion)
64
+
65
+ image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
66
+ segmentation_result = segmenter.segment(image)
67
+ category_mask = segmentation_result.category_mask
68
+ category_mask_np = category_mask.numpy_view()
69
+
70
+ if category == "hair":
71
+ target_mask = get_hair_mask(category_mask_np, mask_dilation)
72
+ elif category == "clothes":
73
+ target_mask = get_clothes_mask(category_mask_np, mask_dilation)
74
+ elif category == "face":
75
+ target_mask = get_face_mask(category_mask_np, mask_dilation)
76
+ else:
77
+ target_mask = get_face_mask(category_mask_np, mask_dilation)
78
+
79
+ croper = Croper(input_image, target_mask, mask_size, mask_expansion)
80
+ croper.corp_mask_image()
81
+ origin_area_image = croper.resized_square_image
82
+ mask_image = croper.resized_square_mask_image
83
+ gray_area_image = origin_area_image.convert('L')
84
+ origin_area_image.paste(gray_area_image, (0, 0), mask_image)
85
+
86
+ return origin_area_image, croper
87
+
88
  def get_face_mask(category_mask_np, dilation=1):
89
  face_skin_mask = category_mask_np == 3
90
  if dilation > 0: