Spaces:
Running
Running
File size: 4,767 Bytes
896437a 1c6ea49 896437a 1c6ea49 896437a 1c6ea49 896437a 1c6ea49 896437a 1c6ea49 7906bcb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
import PIL
import torch
import gradio as gr
import os
from process import load_seg_model, get_palette, generate_mask
# Device string handed to both load_seg_model and generate_mask. It is fixed
# to CPU here; ``torch`` is imported but not used to detect a GPU.
device = 'cpu'
def read_content(file_path: str) -> str:
    """Return the whole of ``file_path`` as a string, decoded as UTF-8.

    Used to inline static HTML fragments (such as ``header.html``) into the page.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
def initialize_and_load_models(checkpoint_path='model/cloth_segm.pth'):
    """Load the cloth-segmentation network from a checkpoint file.

    Args:
        checkpoint_path: Path of the weights file passed to ``load_seg_model``.
            Defaults to the checkpoint that ships with this app, so existing
            zero-argument calls behave exactly as before.

    Returns:
        Whatever ``load_seg_model`` returns for that checkpoint, loaded with the
        module-level ``device``; this is the ``net`` later given to
        ``generate_mask``.
    """
    return load_seg_model(checkpoint_path, device=device)
# Load the segmentation network once at import time; ``run`` reads this global,
# so every request reuses the same model instance.
net = initialize_and_load_models()
# Colour palette passed to generate_mask. NOTE(review): the argument 4 presumably
# is the number of segmentation classes (background + cloth classes) — confirm
# against process.get_palette.
palette = get_palette(4)
def run(img):
    """Segment the clothing in ``img`` and return the colourised mask.

    Args:
        img: The image from the input component (created with ``type="pil"``).
            Gradio passes ``None`` when "Run!" is pressed before any image has
            been uploaded.

    Returns:
        The result of ``generate_mask`` for ``img``, or ``None`` when there is
        no input, which leaves the output image empty.
    """
    # Don't push an empty input through the model; just clear the output.
    if img is None:
        return None
    return generate_mask(img, net=net, palette=palette, device=device)
# Standalone input/output component definitions. The Blocks layout below builds
# its own components and does not use these names; they are kept so anything
# importing them keeps working. ``gr.inputs``/``gr.outputs`` are deprecated
# (removed in Gradio 4), so the unified ``gr.Image`` class is used instead.
input_image = gr.Image(label="Input Image", type="pil")
cloth_seg_image = gr.Image(label="Cloth Segmentation", type="pil")
title = "Demo for Cloth Segmentation"
description = "An app for Cloth Segmentation"
inputs = [input_image]
outputs = [cloth_seg_image]
# Custom stylesheet injected into the page via ``gr.Blocks(css=css)`` below.
# It styles the upload area (#image_upload) and the footer/acknowledgments
# HTML. NOTE(review): #mask_radio, #word_mask, #share-btn-container and
# #share-btn match no elem_id in this file — possibly leftovers from another
# demo; confirm before pruning.
css = '''
.container {max-width: 1150px;margin: auto;padding-top: 1.5rem}
#image_upload{min-height:400px}
#image_upload [data-testid="image"], #image_upload [data-testid="image"] > div{min-height: 400px}
#mask_radio .gr-form{background:transparent; border: none}
#word_mask{margin-top: .75em !important}
#word_mask textarea:disabled{opacity: 0.3}
.footer {margin-bottom: 45px;margin-top: 35px;text-align: center;border-bottom: 1px solid #e5e5e5}
.footer>p {font-size: .8rem; display: inline-block; padding: 0 10px;transform: translateY(10px);background: white}
.dark .footer {border-color: #303030}
.dark .footer>p {background: #0b0f19}
.acknowledgments h4{margin: 1.25em 0 .25em 0;font-weight: bold;font-size: 115%}
#image_upload .touch-none{display: flex}
@keyframes spin {
from {
transform: rotate(0deg);
}
to {
transform: rotate(360deg);
}
}
#share-btn-container {
display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem;
}
#share-btn {
all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;
}
#share-btn * {
all: unset;
}
#share-btn-container div:nth-child(-n+2){
width: auto !important;
min-height: 0px !important;
}
#share-btn-container .wrap {
display: none !important;
}
'''
example = {}  # NOTE(review): not referenced anywhere else in this file
# Example gallery sources: every regular, non-hidden file in ``input/``, sorted
# by path so the gallery order is deterministic. Sub-directories and dotfiles
# (e.g. ``.gitkeep``, ``.DS_Store``) are skipped; otherwise they would be
# offered to gr.Examples as image inputs.
image_dir = 'input'
image_list = sorted(
    os.path.join(image_dir, name)
    for name in os.listdir(image_dir)
    if not name.startswith('.') and os.path.isfile(os.path.join(image_dir, name))
)
# Build and launch the page: header, input and output images side by side, the
# example gallery next to the "Run!" button, then a footer with credits.
image_blocks = gr.Blocks(css=css)
with image_blocks as demo:  # ``demo`` alias is not used elsewhere in this file
    gr.HTML(read_content("header.html"))
    with gr.Group():
        with gr.Box():
            with gr.Row():
                with gr.Column():
                    image = gr.Image(source='upload', elem_id="image_upload", type="pil", label="Input Image")
                with gr.Column():
                    image_out = gr.Image(label="Output", elem_id="output-img").style(height=400)
            with gr.Row():
                with gr.Column():
                    gr.Examples(image_list, inputs=[image], label="Examples - Input Images", examples_per_page=12)
                with gr.Column():
                    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
                        btn = gr.Button("Run!").style(
                            margin=False,
                            rounded=(False, True, True, False),
                            full_width=True,
                        )
            # Clicking "Run!" segments the input image into the output slot.
            btn.click(fn=run, inputs=[image], outputs=[image_out])
    # Footer and credits. The acknowledgments <div> is now closed, and the
    # heading is no longer wrapped in <p> (a block element inside <p> is invalid
    # HTML and left a stray empty paragraph).
    # NOTE(review): the "WildOctopus" link has an empty href, which just reopens
    # this page; fill in the author's URL.
    gr.HTML(
        """
<div class="footer">
<p>Model by <a href="" style="text-decoration: underline;" target="_blank">WildOctopus</a> - Gradio Demo by 🤗 Hugging Face
</p>
</div>
<div class="acknowledgments">
<h4>ACKNOWLEDGEMENTS</h4>
<p>
U2net model is from original u2net repo. Thanks to <a href="https://github.com/xuebinqin/U-2-Net" style="text-decoration: underline;" target="_blank">Xuebin Qin</a> for amazing repo.</p>
<p>Codes are modified from <a href="https://github.com/levindabhi/cloth-segmentation" style="text-decoration: underline;" target="_blank">levindabhi/cloth-segmentation</a>
</p>
</div>
"""
    )
image_blocks.launch()