# segment-hair / app.py
# Author: emirhanno · commit fd93f98 ("bugfix") · 2.17 kB · scan: no virus
import gradio as gr
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from mediapipe.python._framework_bindings import image as image_module
_Image = image_module.Image
from mediapipe.python._framework_bindings import image_frame
_ImageFormat = image_frame.ImageFormat
# Constants for colors (RGBA tuples used to paint the output mask)
BG_COLOR = (0, 0, 0, 255)  # black with full opacity
MASK_COLOR = (255, 255, 255, 255)  # white with full opacity

# Create the options that will be used for ImageSegmenter.
# The model path is relative, so it is resolved against the working
# directory the app is launched from.
base_options = python.BaseOptions(model_asset_path='emirhan.tflite')
# output_category_mask=True asks the segmenter for a per-pixel category mask,
# which segment_hair thresholds to build the black/white output.
options = vision.ImageSegmenterOptions(
    base_options=base_options,
    output_category_mask=True,
)
# Function to segment hair and generate mask
def segment_hair(image):
    """Segment the hair in an image and return a black-and-white mask.

    Args:
        image: uint8 image array as delivered by ``gr.Image(type="numpy")``,
            which uses RGB channel order, shape ``(H, W, 3)``. Grayscale
            (``(H, W)`` or ``(H, W, 1)``) and RGBA ``(H, W, 4)`` arrays are
            also accepted. ``None`` (no image submitted) is passed through.

    Returns:
        An ``(H, W, 3)`` uint8 RGB array: ``MASK_COLOR`` (white) where the
        segmenter's category mask is above the threshold, ``BG_COLOR``
        (black) elsewhere -- or ``None`` when *image* is ``None``.
    """
    if image is None:
        # Gradio passes None when the input is empty; cv2.cvtColor would
        # raise on it, so return "no output" instead of an error.
        return None

    # Build a fresh RGBA array. Gradio's numpy images are RGB (not OpenCV's
    # BGR), so a BGR conversion here would swap the red and blue channels
    # that the model sees.
    if image.ndim == 2 or image.shape[-1] == 1:
        rgba_image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGBA)
    elif image.shape[-1] == 4:
        # Contiguous copy so the caller's array is not modified below.
        rgba_image = np.array(image, order="C")
    else:
        rgba_image = cv2.cvtColor(image, cv2.COLOR_RGB2RGBA)
    # Zero the alpha channel. NOTE(review): presumably the model expects a
    # 4-channel input with an empty alpha plane -- confirm against the model.
    rgba_image[:, :, 3] = 0

    # Wrap the array as a MediaPipe image for the segmenter.
    mp_image = _Image(image_format=_ImageFormat.SRGBA, data=rgba_image)

    # A new segmenter is created from `options` and closed on every call.
    with vision.ImageSegmenter.create_from_options(options) as segmenter:
        # Retrieve the masks for the segmented image
        segmentation_result = segmenter.segment(mp_image)
        category_mask = segmentation_result.category_mask

        # Solid-colour RGBA canvases for the mask and the background.
        image_data = mp_image.numpy_view()
        fg_image = np.zeros(image_data.shape, dtype=np.uint8)
        fg_image[:] = MASK_COLOR
        bg_image = np.zeros(image_data.shape, dtype=np.uint8)
        bg_image[:] = BG_COLOR

        # Replicate the mask (presumably (H, W)) across all four channels.
        # Assuming it holds integer category ids, "> 0.2" selects every
        # non-zero category.
        condition = np.stack((category_mask.numpy_view(),) * 4, axis=-1) > 0.2
        output_image = np.where(condition, fg_image, bg_image)

    # Drop the alpha channel so Gradio receives a plain RGB image.
    return cv2.cvtColor(output_image, cv2.COLOR_RGBA2RGB)
# Gradio UI: a single numpy image in, a single numpy mask image out.
iface = gr.Interface(
    title="Hair Segmentation",
    description="Upload an image to segment the hair and generate a mask.",
    fn=segment_hair,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="numpy"),
    examples=["example.jpeg"],
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()