|
import gradio as gr |
|
from PIL import Image as PILImage |
|
import numpy as np |
|
from depth_map_generator import generate_depth_map |
|
from depth_segmentation import segment_image_by_depth |
|
|
|
def on_process(image, num_segments):
    """Segment an uploaded image into depth layers for the UI callback.

    The image is written to ``original.jpg``, a depth map is generated from
    that file, and the image is then segmented using both paths.

    Args:
        image: The uploaded picture, either a NumPy array (converted to a
            PIL image here) or an object exposing PIL's ``mode``,
            ``convert`` and ``save``. ``None`` means nothing was uploaded.
        num_segments: Number of depth segments to request from
            ``segment_image_by_depth``.

    Returns:
        A 2-tuple containing the segmentation result twice (one entry per
        output component wired to this callback), or ``(None, None)`` when
        no image was supplied.
    """
    # Clicking the button before uploading passes None; without this guard
    # the save below fails with AttributeError.
    if image is None:
        return None, None

    if isinstance(image, np.ndarray):
        image = PILImage.fromarray(image)

    # Pillow's JPEG writer rejects modes such as RGBA or P. Convert only the
    # modes it cannot encode so inputs that already worked are untouched.
    if image.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        image = image.convert("RGB")

    # NOTE(review): fixed file name in the working directory — concurrent
    # requests would overwrite each other's input. Kept as-is because how
    # generate_depth_map derives its output path is not visible here.
    original_img_path = "original.jpg"
    image.save(original_img_path)

    depth_map_path = generate_depth_map(original_img_path)

    segmented_images = segment_image_by_depth(
        original_img_path, depth_map_path, num_segments
    )

    # The same result feeds both outputs of the click handler.
    return segmented_images, segmented_images
|
|
|
|
|
# Build the page: input controls in the left column, results in the right.
with gr.Blocks() as demo:
    gr.Markdown("# Depth Map Segmentation Tool")

    with gr.Row():
        # Left column: image upload, segment count and the trigger button.
        with gr.Column():
            image_input = gr.Image(label="Upload Original Image")
            num_segments = gr.Slider(
                minimum=3,
                maximum=5,
                step=1,
                value=3,
                label="Number of Segments",
            )
            process_button = gr.Button("Process Image")

        # Right column: gallery preview plus a multi-file download list.
        with gr.Column():
            # NOTE(review): `.style(...)` is deprecated in newer Gradio
            # releases — confirm against the pinned Gradio version.
            output_gallery = gr.Gallery(label="Segmented Images").style(grid=[3, 2])
            download_button = gr.Files(file_count="multiple", label="Download All")

    # on_process returns two values, one for each output component.
    process_button.click(
        fn=on_process,
        inputs=[image_input, num_segments],
        outputs=[output_gallery, download_button],
    )

demo.launch()
|
|