File size: 1,480 Bytes
884f26e
2cd83c7
 
884f26e
 
 
110c7d2
2cd83c7
 
 
 
884f26e
 
 
 
 
 
 
 
 
 
110c7d2
fb3e824
884f26e
 
 
 
 
 
 
 
 
 
 
 
2cd83c7
884f26e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr
from PIL import Image as PILImage
import numpy as np
from depth_map_generator import generate_depth_map
from depth_segmentation import segment_image_by_depth

def on_process(image, num_segments):
    """Run depth estimation and depth-based segmentation on an uploaded image.

    Args:
        image: Value delivered by the ``gr.Image`` input. Presumably a NumPy
            array (converted to PIL below) or a PIL image; ``None`` when the
            user pressed the button without uploading anything.
        num_segments: Number of depth bands requested via the slider.

    Returns:
        A 2-tuple ``(segmented_images, segmented_images)``: the same result is
        sent to both the gallery and the download component.

    Raises:
        gr.Error: If no image was uploaded, so the user sees a readable
            message in the UI instead of a stack trace.
    """
    if image is None:
        raise gr.Error("Please upload an image before processing.")

    # Convert the NumPy array to a PIL Image
    if isinstance(image, np.ndarray):
        image = PILImage.fromarray(image)

    # JPEG cannot store alpha/palette modes (e.g. RGBA, P); saving such an
    # image as .jpg raises OSError, so normalise to 3-channel RGB first.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Save the uploaded image so the depth pipeline can read it from disk.
    # NOTE(review): fixed path in the working directory — if requests are
    # handled concurrently they overwrite each other's file; consider a
    # per-request temporary path for multi-user deployments.
    original_img_path = "original.jpg"
    image.save(original_img_path)

    # Generate the depth map
    depth_map_path = generate_depth_map(original_img_path)

    # Slider values may arrive as floats; pass a whole number of segments.
    segmented_images = segment_image_by_depth(
        original_img_path, depth_map_path, int(num_segments)
    )

    # The same result feeds both the gallery and the download component.
    return segmented_images, segmented_images

# Gradio Interface: inputs on the left, results on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Depth Map Segmentation Tool")

    with gr.Row():
        # Left column: image upload, segment count and the trigger button.
        with gr.Column():
            image_input = gr.Image(label="Upload Original Image")
            num_segments = gr.Slider(label="Number of Segments", minimum=3, maximum=5, step=1, value=3)
            process_button = gr.Button("Process Image")

        # Right column: segmentation results.
        with gr.Column():
            # Layout is configured in the constructor: Component.style() is
            # deprecated and was removed in Gradio 4. `columns` accepts the same
            # per-breakpoint list the old `grid` style argument did.
            output_gallery = gr.Gallery(label="Segmented Images", columns=[3, 2])
            # gr.Files is a template that already sets file_count="multiple", so
            # passing file_count to it again raises TypeError; gr.File with an
            # explicit file_count is the portable spelling.
            download_button = gr.File(label="Download All", file_count="multiple")

    # on_process returns (gallery_items, download_files) in this output order.
    process_button.click(on_process, [image_input, num_segments], [output_gallery, download_button])

# NOTE(review): the server starts at import time; consider an
# `if __name__ == "__main__":` guard if this module is ever imported.
demo.launch()