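"""Gradio demo for FashionPipeline: generates images with
StableDiffusionXLControlNetPipeline conditioned on U2NET body-part masks."""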
import argparse
import json
import os

import gradio as gr
import numpy as np
import torch

from config import PipelineConfig
from src.pipeline import FashionPipeline, PipelineOutput
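# One-time setup: build the pipeline at import so every request reuses the loaded
# weights, running on GPU when one is available.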
config = PipelineConfig()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fashion_pipeline = FashionPipeline(config, device=device)
def process(
input_image: np.ndarray,
prompt: str,
negative_prompt: str,
generate_from_mask: bool,
num_inference_steps: int,
guidance_scale: float,
conditioning_scale: float,
target_image_size: int,
max_image_size: int,
seed: int,
):
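    """Run FashionPipeline on the inputs and return the generated image and its control mask."""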
output: PipelineOutput = fashion_pipeline(
control_image=input_image,
prompt=prompt,
negative_prompt=negative_prompt,
generate_from_mask=generate_from_mask,
num_inference_steps=num_inference_steps,
guidance_scale=float(guidance_scale),
conditioning_scale=float(conditioning_scale),
target_image_size=target_image_size,
max_image_size=max_image_size,
seed=seed,
)
return [
output.generated_image,
output.control_mask,
]
def read_content(file_path: str) -> str:
"""Read the content of target file."""
with open(file_path, 'r', encoding='utf-8') as f:
content = f.read()
return content
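# Build the example gallery: each value in examples/prompts.json is assumed to be a
# (prompt, negative prompt) pair matched positionally to an image in examples/images.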
image_dir = 'examples/images'
image_list = [os.path.join(image_dir, file) for file in sorted(os.listdir(image_dir))]  # sort for a deterministic pairing
with open('examples/prompts.json', 'r') as f:
prompts_list = json.load(f).values()
examples = [[image, prompt[0], prompt[1]] for image, prompt in zip(image_list, prompts_list)]
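# Assemble the Gradio UI: input controls and examples on the left, outputs on the right.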
block = gr.Blocks().queue()
with block:
with gr.Row():
gr.HTML(read_content('header.html'))
with gr.Row():
with gr.Column():
input_image = gr.Image(type='numpy')
prompt = gr.Textbox(label='Prompt')
negative_prompt = gr.Textbox(label='Negative Prompt')
with gr.Row():
generate_from_mask = gr.Checkbox(label='Input image is already a control mask', value=False)
run_button = gr.Button(value='Run')
with gr.Accordion('Advanced options', open=False):
target_image_size = gr.Slider(
                    label='Image target size',
minimum=512,
maximum=2048,
value=768,
step=64,
)
max_image_size = gr.Slider(
                    label='Image max size',
minimum=512,
maximum=2048,
value=1024,
step=64,
)
num_inference_steps = gr.Slider(label='Number of steps', minimum=1, maximum=100, value=20, step=1)
guidance_scale = gr.Slider(label='Guidance scale', minimum=0.1, maximum=30.0, value=9.0, step=0.1)
conditioning_scale = gr.Slider(label='Conditioning scale', minimum=0.0, maximum=5.0, value=1.0, step=0.1)
seed = gr.Slider(label='Seed', minimum=0, maximum=config.max_seed, step=1, value=0)
gr.Examples(examples=examples, inputs=[input_image, prompt, negative_prompt], label='Examples - Input Images', examples_per_page=12)
            gr.HTML(
                """
                <div class="footer">
                    <p>This repo is based on the U-Net from <a style="text-decoration: underline;" href="https://huggingface.co/spaces/wildoctopus/cloth-segmentation">cloth-segmentation</a>.
                    It uses the pre-trained U2NET to extract Upper body (red), Lower body (green), and Full body (blue) masks, and then
                    runs StableDiffusionXLControlNetPipeline with the trained controlnet_baseline to generate an image conditioned on these masks.
                    </p>
                </div>
                """)
with gr.Column():
generated_output = gr.Image(label='Generated', type='numpy', elem_id='generated')
mask_output = gr.Image(label='Mask', type='numpy', elem_id='mask')
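    # Wire the Run button; the inputs list order must match process()'s parameters.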
ips = [input_image, prompt, negative_prompt, generate_from_mask, num_inference_steps, guidance_scale, conditioning_scale, target_image_size, max_image_size, seed]
run_button.click(fn=process, inputs=ips, outputs=[generated_output, mask_output])
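# Command-line entry point: pass --share to expose the app through a public Gradio link.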
parser = argparse.ArgumentParser()
parser.add_argument(
'--share',
'-s',
action="store_true",
default=False,
help='Create public link for the app.'
)
args = parser.parse_args()
block.launch(share=args.share)