Spaces:
Runtime error
Runtime error
import json | |
import os | |
import argparse | |
import numpy as np | |
import torch | |
import gradio as gr | |
from config import PipelineConfig | |
from src.pipeline import FashionPipeline, PipelineOutput | |
# Module-level set-up: the pipeline object is constructed once at import time
# and then reused by every call to `process`.
config = PipelineConfig()
# Always pass a `torch.device`. The previous expression produced a
# `torch.device` when CUDA was available but the plain string 'cpu' otherwise,
# so any downstream code relying on `device.type` (or other `torch.device`
# attributes) only worked on GPU hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fashion_pipeline = FashionPipeline(config, device=device)
def process(
    input_image: np.ndarray,
    prompt: str,
    negative_prompt: str,
    generate_from_mask: bool,
    num_inference_steps: int,
    guidance_scale: float,
    conditioning_scale: float,
    target_image_size: int,
    max_image_size: int,
    seed: int,
):
    """Run the module-level fashion pipeline for one request from the UI.

    Args:
        input_image: Image from the ``gr.Image(type='numpy')`` input; Gradio
            passes ``None`` when nothing has been uploaded.
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        generate_from_mask: When True, ``input_image`` is treated as an
            already-computed control mask (per the UI checkbox label).
        num_inference_steps: Value of the "Number of steps" slider.
        guidance_scale: Value of the "Guidance scale" slider.
        conditioning_scale: Value of the "Conditioning scale" slider.
        target_image_size: Value of the "Image target size" slider.
        max_image_size: Value of the "Image max size" slider.
        seed: Value of the "Seed" slider.

    Returns:
        A two-element list ``[generated_image, control_mask]`` taken from the
        pipeline's ``PipelineOutput``.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        # Surface a readable message in the UI instead of letting the pipeline
        # fail on a None image somewhere deep inside.
        raise gr.Error('Please upload an input image.')

    # Slider values can arrive as floats; cast the integer-valued knobs
    # explicitly, mirroring the float() casts already applied to the scales.
    output: PipelineOutput = fashion_pipeline(
        control_image=input_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        generate_from_mask=generate_from_mask,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
        conditioning_scale=float(conditioning_scale),
        target_image_size=int(target_image_size),
        max_image_size=int(max_image_size),
        seed=int(seed),
    )
    return [
        output.generated_image,
        output.control_mask,
    ]
def read_content(file_path: str) -> str:
    """Return the entire text of ``file_path``, decoded as UTF-8."""
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
# Example gallery: pair each file in `image_dir` with one entry of
# examples/prompts.json, positionally.
image_dir = 'examples/images'
# os.listdir() returns entries in arbitrary, filesystem-dependent order, which
# made the image/prompt pairing below nondeterministic. Sort for a stable order.
# NOTE(review): pairing is still positional — this assumes prompts.json lists
# its entries in sorted-filename order; verify against the JSON file.
image_list = [os.path.join(image_dir, file) for file in sorted(os.listdir(image_dir))]
with open('examples/prompts.json', 'r', encoding='utf-8') as f:
    prompts_list = json.load(f).values()
# Each JSON value is indexed as [prompt, negative_prompt] to match the
# gr.Examples inputs (input_image, prompt, negative_prompt).
examples = [[image, prompt[0], prompt[1]] for image, prompt in zip(image_list, prompts_list)]
# Gradio UI. The layout has a header row and then a two-column row. The left
# column holds the inputs, options, examples and footer. The right column holds
# the generated image and the mask returned by `process`.
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.HTML(read_content('header.html'))
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='numpy')
            prompt = gr.Textbox(label='Prompt')
            negative_prompt = gr.Textbox(label='Negative Prompt')
            with gr.Row():
                generate_from_mask = gr.Checkbox(label='Input image is already a control mask', value=False)
                run_button = gr.Button(value='Run')
            with gr.Accordion('Advanced options', open=False):
                target_image_size = gr.Slider(
                    label='Image target size:',
                    minimum=512,
                    maximum=2048,
                    value=768,
                    step=64,
                )
                max_image_size = gr.Slider(
                    label='Image max size:',
                    minimum=512,
                    maximum=2048,
                    value=1024,
                    step=64,
                )
                num_inference_steps = gr.Slider(label='Number of steps', minimum=1, maximum=100, value=20, step=1)
                guidance_scale = gr.Slider(label='Guidance scale', minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                conditioning_scale = gr.Slider(label='Conditioning scale', minimum=0.0, maximum=5.0, value=1.0, step=0.1)
                seed = gr.Slider(label='Seed', minimum=0, maximum=config.max_seed, step=1, value=0)
            gr.Examples(examples=examples, inputs=[input_image, prompt, negative_prompt], label='Examples - Input Images', examples_per_page=12)
            # Footer: the original markup never closed <div class="footer">,
            # and its text had several grammar errors.
            gr.HTML(
                """
                <div class="footer">
                    <p> This repo is based on Unet from <a style="text-decoration: underline;" href="https://huggingface.co/spaces/wildoctopus/cloth-segmentation">cloth-segmentation</a>.
                    It uses pre-trained U2NET to extract Upper body (red), Lower body (green) and Full body (blue) masks, and then
                    runs StableDiffusionXLControlNetPipeline with trained controlnet_baseline to generate an image conditioned on these masks.
                    </p>
                </div>
                """)
        with gr.Column():
            generated_output = gr.Image(label='Generated', type='numpy', elem_id='generated')
            mask_output = gr.Image(label='Mask', type='numpy', elem_id='mask')
    # Component order must match process()'s positional parameters.
    ips = [input_image, prompt, negative_prompt, generate_from_mask, num_inference_steps, guidance_scale, conditioning_scale, target_image_size, max_image_size, seed]
    # Outputs map to process()'s return value [generated_image, control_mask].
    run_button.click(fn=process, inputs=ips, outputs=[generated_output, mask_output])
# Command-line handling and server start-up. Note that this runs at module
# level, so the app is launched as soon as the file is executed.
parser = argparse.ArgumentParser()
parser.add_argument('--share', '-s', default=False, action="store_true",
                    help='Create public link for the app.')
args = parser.parse_args()
# When --share/-s is given, Gradio is asked to create a public link.
share_link = bool(args.share)
block.launch(share=share_link)