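"""Gradio demo: extract body-part control masks with U2NET and generate fashion
images conditioned on them via an SDXL ControlNet pipeline."""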
import json
import os
import numpy as np
import torch
import gradio as gr
from config import PipelineConfig
from src.pipeline import FashionPipeline, PipelineOutput
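# Instantiate the pipeline once at import time (this presumably loads the model
# weights, so startup can take a while).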
config = PipelineConfig()
# Fall back to CPU when CUDA is unavailable so the app can still start without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fashion_pipeline = FashionPipeline(config, device=device)
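# Inference entry point for the UI: forwards the Gradio inputs to the pipeline
# and returns the generated image together with the control mask that was used.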
def process(
    input_image: np.ndarray,
    prompt: str,
    negative_prompt: str,
    generate_from_mask: bool,
    num_inference_steps: int,
    guidance_scale: float,
    conditioning_scale: float,
    target_image_size: int,
    max_image_size: int,
    seed: int,
):
    output: PipelineOutput = fashion_pipeline(
        control_image=input_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        generate_from_mask=generate_from_mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=float(guidance_scale),
        conditioning_scale=float(conditioning_scale),
        target_image_size=target_image_size,
        max_image_size=max_image_size,
        seed=seed,
    )
    return [
        output.generated_image,
        output.control_mask,
    ]
def read_content(file_path: str) -> str:
    """Read the content of the target file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content
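# Load example images and their prompt pairs. prompts.json is assumed to map
# each example image to a two-element [prompt, negative_prompt] list.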
image_dir = 'examples/images'
# Sort the directory listing: os.listdir order is arbitrary, and the values in
# prompts.json are assumed to be stored in the same sorted order.
image_list = [os.path.join(image_dir, file) for file in sorted(os.listdir(image_dir))]
with open('examples/prompts.json', 'r') as f:
    prompts_list = json.load(f).values()
examples = [[image, prompt[0], prompt[1]] for image, prompt in zip(image_list, prompts_list)]
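# UI layout: a header row, then an input column (image, prompts, advanced
# sliders) next to an output column (generated image plus the control mask).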
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.HTML(read_content('header.html'))
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='numpy')
            prompt = gr.Textbox(label='Prompt')
            negative_prompt = gr.Textbox(label='Negative Prompt')
            with gr.Row():
                generate_from_mask = gr.Checkbox(label='Input image is already a control mask', value=False)
                run_button = gr.Button(value='Run')
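            # Advanced generation controls. 'Image target size' is presumably the
            # resolution the pipeline resizes toward, with 'Image max size' as a hard cap.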
            with gr.Accordion('Advanced options', open=False):
                target_image_size = gr.Slider(
                    label='Image target size:',
                    minimum=512,
                    maximum=2048,
                    value=768,
                    step=64,
                )
                max_image_size = gr.Slider(
                    label='Image max size:',
                    minimum=512,
                    maximum=2048,
                    value=1024,
                    step=64,
                )
                num_inference_steps = gr.Slider(label='Number of steps', minimum=1, maximum=100, value=20, step=1)
                guidance_scale = gr.Slider(label='Guidance scale', minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                conditioning_scale = gr.Slider(label='Conditioning scale', minimum=0.0, maximum=5.0, value=1.0, step=0.1)
                seed = gr.Slider(label='Seed', minimum=0, maximum=config.max_seed, step=1, value=0)
            gr.Examples(
                examples=examples,
                inputs=[input_image, prompt, negative_prompt],
                label='Examples - Input Images',
                examples_per_page=12,
            )
            gr.HTML(
                """
                <div class="footer">
                    <p> This demo builds on the U-Net from <a style="text-decoration: underline;" href="https://huggingface.co/spaces/wildoctopus/cloth-segmentation">cloth-segmentation</a>.
                    It uses a pre-trained U2NET to extract upper-body (red), lower-body (green), and full-body (blue) masks, and then
                    runs StableDiffusionXLControlNetPipeline with the trained controlnet_baseline to generate an image conditioned on these masks.
                    </p>
                </div>
                """)
        with gr.Column():
            generated_output = gr.Image(label='Generated', type='numpy', elem_id='generated')
            mask_output = gr.Image(label='Mask', type='numpy', elem_id='mask')
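    # Wire the Run button; the input order here must match process()'s signature.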
    ips = [
        input_image, prompt, negative_prompt, generate_from_mask,
        num_inference_steps, guidance_scale, conditioning_scale,
        target_image_size, max_image_size, seed,
    ]
    run_button.click(fn=process, inputs=ips, outputs=[generated_output, mask_output])

block.launch()