File size: 4,682 Bytes
8ae56d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8219169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ae56d4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import numpy as np
import mediapipe as mp

from PIL import Image
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from scipy.ndimage import binary_dilation
from croper import Croper

# Location of the TFLite segmentation model loaded at import time.
segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
base_options = python.BaseOptions(model_asset_path=segment_model)
# output_category_mask=True makes the segmenter return the per-pixel
# category mask that the functions in this module read.
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
)
# One module-level segmenter instance, reused by every call in this file.
segmenter = vision.ImageSegmenter.create_from_options(options)

def restore_result(croper, category, generated_image):
    """Paste the generated region back onto a copy of the original image.

    Args:
        croper: Object exposing ``square_length``, the ``square_*`` crop box,
            the ``origin_start_*`` paste offset and ``input_image``.
            (presumably the ``Croper`` returned by ``segment_image`` -- confirm)
        category: Category name forwarded to ``get_restore_mask_image``.
        generated_image: Image to resize to ``square_length`` and crop.

    Returns:
        A copy of ``croper.input_image`` with the masked part of the
        generated image pasted in; the original image is not modified.
    """
    side = croper.square_length
    upscaled = generated_image.resize((side, side))

    # Cut the area of interest back out of the square canvas.
    crop_box = (
        croper.square_start_x,
        croper.square_start_y,
        croper.square_end_x,
        croper.square_end_y,
    )
    patch = upscaled.crop(crop_box)
    paste_mask = get_restore_mask_image(croper, category, patch)

    # Work on a copy so the caller's input image stays untouched.
    result = croper.input_image.copy()
    paste_origin = (croper.origin_start_x, croper.origin_start_y)
    result.paste(patch, paste_origin, paste_mask)

    return result

def segment_image(input_image, category, generate_size, mask_expansion, mask_dilation):
    """Segment ``input_image`` and crop the square area around one category.

    Args:
        input_image: Image to segment; it is converted with ``np.asarray``
            and handed to MediaPipe as SRGB data.
        category: ``"hair"``, ``"clothes"`` or ``"face"``. Any other value
            falls back to the face mask.
        generate_size: Target size passed to ``Croper`` (coerced to ``int``).
        mask_expansion: Expansion passed to ``Croper`` (coerced to ``int``).
        mask_dilation: Number of dilation iterations applied to the category
            mask (coerced to ``int``).

    Returns:
        A ``(origin_area_image, croper)`` tuple, where ``origin_area_image``
        is ``croper.resized_square_image``.
    """
    mask_size = int(generate_size)
    mask_expansion = int(mask_expansion)
    # Coerce like the other numeric arguments: scipy's binary_dilation only
    # accepts an integer iteration count, so a float (e.g. 2.0) would fail.
    mask_dilation = int(mask_dilation)

    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    # Unknown categories deliberately share the face branch.
    if category == "hair":
        build_mask = get_hair_mask
    elif category == "clothes":
        build_mask = get_clothes_mask
    else:
        build_mask = get_face_mask
    target_mask = build_mask(category_mask_np, mask_dilation)

    croper = Croper(input_image, target_mask, mask_size, mask_expansion)
    croper.corp_mask_image()
    origin_area_image = croper.resized_square_image

    return origin_area_image, croper

def segment_image_withmask(input_image, category, generate_size, mask_expansion, mask_dilation):
    """Segment ``input_image`` and return the cropped area together with its mask.

    Args:
        input_image: Image to segment; it is converted with ``np.asarray``
            and handed to MediaPipe as SRGB data.
        category: ``"hair"``, ``"clothes"`` or ``"face"``. Any other value
            falls back to the face mask.
        generate_size: Target size passed to ``Croper`` (coerced to ``int``).
        mask_expansion: Expansion passed to ``Croper`` (coerced to ``int``).
        mask_dilation: Number of dilation iterations applied to the category
            mask (coerced to ``int``).

    Returns:
        A ``(origin_area_image, mask_image, croper)`` tuple built from
        ``croper.resized_square_image`` and
        ``croper.resized_square_mask_image``.
    """
    mask_size = int(generate_size)
    mask_expansion = int(mask_expansion)
    # Coerce like the other numeric arguments: scipy's binary_dilation only
    # accepts an integer iteration count, so a float (e.g. 2.0) would fail.
    mask_dilation = int(mask_dilation)

    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    # Unknown categories deliberately share the face branch.
    if category == "hair":
        build_mask = get_hair_mask
    elif category == "clothes":
        build_mask = get_clothes_mask
    else:
        build_mask = get_face_mask
    target_mask = build_mask(category_mask_np, mask_dilation)

    croper = Croper(input_image, target_mask, mask_size, mask_expansion)
    croper.corp_mask_image()
    origin_area_image = croper.resized_square_image
    mask_image = croper.resized_square_mask_image

    return origin_area_image, mask_image, croper

def get_face_mask(category_mask_np, dilation=1):
    """Return a boolean mask of the pixels labelled as category 3 (face skin).

    Args:
        category_mask_np: Array of per-pixel category indices.
        dilation: Number of binary-dilation iterations; values ``<= 0``
            return the raw mask unchanged.

    Returns:
        Boolean array with the same shape as ``category_mask_np``.
    """
    selected = category_mask_np == 3
    if dilation > 0:
        # Grow the region outward by ``dilation`` iterations.
        return binary_dilation(selected, iterations=dilation)
    return selected

def get_clothes_mask(category_mask_np, dilation=1):
    """Return a boolean mask covering categories 2 (body skin) and 4 (clothes).

    The combined mask is always dilated by 4 iterations; ``dilation`` adds
    further iterations on top of that.

    Args:
        category_mask_np: Array of per-pixel category indices.
        dilation: Extra binary-dilation iterations; values ``<= 0`` add none.

    Returns:
        Boolean array with the same shape as ``category_mask_np``.
    """
    merged = np.logical_or(category_mask_np == 2, category_mask_np == 4)
    # Fixed base growth, applied regardless of the ``dilation`` argument.
    grown = binary_dilation(merged, iterations=4)
    if dilation > 0:
        return binary_dilation(grown, iterations=dilation)
    return grown

def get_hair_mask(category_mask_np, dilation=1):
    """Return a boolean mask of the pixels labelled as category 1 (hair).

    Args:
        category_mask_np: Array of per-pixel category indices.
        dilation: Number of binary-dilation iterations; values ``<= 0``
            return the raw mask unchanged.

    Returns:
        Boolean array with the same shape as ``category_mask_np``.
    """
    selected = category_mask_np == 1
    if dilation > 0:
        # Grow the region outward by ``dilation`` iterations.
        return binary_dilation(selected, iterations=dilation)
    return selected

def get_restore_mask_image(croper, category, generated_image):
    """Build the paste mask used when restoring a generated image.

    The generated image is segmented again and the resulting category mask
    is OR-ed with ``croper.corp_mask``.

    Args:
        croper: Object exposing ``corp_mask``.
            (assumed to match the shape of ``generated_image`` -- confirm)
        category: ``"hair"``, ``"clothes"`` or ``"face"``. Any other value
            falls back to the face mask, matching ``segment_image``.
        generated_image: Image to segment; converted with ``np.asarray``.

    Returns:
        A PIL image built from a ``uint8`` array holding 255 where either
        mask is set and 0 elsewhere.
    """
    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(generated_image))
    segmentation_result = segmenter.segment(image)
    category_mask = segmentation_result.category_mask
    category_mask_np = category_mask.numpy_view()

    # No user-requested dilation here (0 is passed to every builder).
    if category == "hair":
        target_mask = get_hair_mask(category_mask_np, 0)
    elif category == "clothes":
        target_mask = get_clothes_mask(category_mask_np, 0)
    else:
        # Previously an unrecognised category left target_mask unbound and
        # raised UnboundLocalError below; fall back to the face mask instead.
        target_mask = get_face_mask(category_mask_np, 0)

    combined_mask = np.logical_or(target_mask, croper.corp_mask)
    mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
    return mask_image