File size: 2,166 Bytes
afff12f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470c68c
 
afff12f
 
fd93f98
afff12f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import gradio as gr
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.python._framework_bindings import image as image_module
_Image = image_module.Image
from mediapipe.python._framework_bindings import image_frame
_ImageFormat = image_frame.ImageFormat

# Constants for the mask colors, in RGBA order (the comment previously said
# "gray", but (0, 0, 0, 255) is black with full opacity).
BG_COLOR = (0, 0, 0, 255)  # black with full opacity
MASK_COLOR = (255, 255, 255, 255)  # white with full opacity

# Create the options that will be used for ImageSegmenter.
# The model file is loaded from the working directory; output_category_mask=True
# makes segment() return a per-pixel category mask alongside the result.
base_options = python.BaseOptions(model_asset_path='emirhan.tflite')
options = vision.ImageSegmenterOptions(base_options=base_options,
                                       output_category_mask=True)

# Function to segment hair and generate mask
def segment_hair(image):
    """Segment hair in an input image and return a black/white mask.

    Args:
        image: H x W x 3 uint8 numpy array in RGB channel order, as delivered
            by ``gr.Image(type="numpy")``.

    Returns:
        H x W x 3 uint8 numpy array: white (255) where the model labels hair,
        black (0) elsewhere.
    """
    # BUG FIX: Gradio supplies RGB arrays, not BGR. The original code used
    # cv2.COLOR_BGR2RGBA, which swapped the red and blue channels before
    # inference, degrading the model's input. Convert from RGB instead.
    rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
    rgba_image[:, :, 3] = 0  # Set alpha channel to empty, as the model expects

    # Wrap the numpy array in a MediaPipe Image (SRGBA format).
    mp_image = _Image(image_format=_ImageFormat.SRGBA, data=rgba_image)

    # Create the image segmenter per call; the context manager releases the
    # underlying native resources when segmentation is done.
    with vision.ImageSegmenter.create_from_options(options) as segmenter:
        # Retrieve the per-pixel category mask for the segmented image.
        segmentation_result = segmenter.segment(mp_image)
        category_mask = segmentation_result.category_mask

        # Generate solid color images for composing the output mask.
        image_data = mp_image.numpy_view()
        fg_image = np.zeros(image_data.shape, dtype=np.uint8)
        fg_image[:] = MASK_COLOR
        bg_image = np.zeros(image_data.shape, dtype=np.uint8)
        bg_image[:] = BG_COLOR

        # Broadcast the single-channel mask across all 4 channels; pixels whose
        # category value exceeds 0.2 are treated as hair (foreground).
        condition = np.stack((category_mask.numpy_view(),) * 4, axis=-1) > 0.2
        output_image = np.where(condition, fg_image, bg_image)

        # Drop the alpha channel for display; R/G/B are equal (black/white),
        # so RGBA2RGB is a pure channel drop here.
        return cv2.cvtColor(output_image, cv2.COLOR_RGBA2RGB)

# Gradio interface: single numpy-image input -> numpy-image output.
# `examples` expects example.jpeg to exist next to this script.
iface = gr.Interface(
    fn=segment_hair,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
    title="Hair Segmentation",
    description="Upload an image to segment the hair and generate a mask.",
    examples=["example.jpeg"]
)

# Launch the Gradio app only when run as a script (not when imported).
if __name__ == "__main__":
    iface.launch()