import json
import os

import gradio as gr
import numpy as np
import torch

from config import PipelineConfig
from src.pipeline import FashionPipeline, PipelineOutput


config = PipelineConfig()
# Fall back to CPU when CUDA is unavailable (inference will be much slower).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fashion_pipeline = FashionPipeline(config, device=device)


def process(
    input_image: np.ndarray,
    prompt: str,
    negative_prompt: str,
    generate_from_mask: bool,
    num_inference_steps: int,
    guidance_scale: float,
    conditioning_scale: float,
    target_image_size: int,
    max_image_size: int,
    seed: int,
):
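    """Run the fashion pipeline on one input image and return the generated image and control mask."""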
    output: PipelineOutput = fashion_pipeline(
        control_image=input_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        generate_from_mask=generate_from_mask,
        num_inference_steps=num_inference_steps,
        guidance_scale=float(guidance_scale),
        conditioning_scale=float(conditioning_scale),
        target_image_size=target_image_size,
        max_image_size=max_image_size,
        seed=seed,
    )
    return [
        output.generated_image,
        output.control_mask,
    ]


def read_content(file_path: str) -> str:
    """Read the content of target file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    return content


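# Expected layout of examples/prompts.json (hypothetical example; the keys are
# assumed to correspond one-to-one, in sorted-filename order, to the files in
# examples/images, and each value holds [prompt, negative_prompt]):
# {"dress.jpg": ["a red evening dress, studio photo", "blurry, low quality"], ...}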
image_dir = 'examples/images'
# Sort the listing: os.listdir() returns files in arbitrary order, and the zip
# below relies on images and prompts lining up.
image_list = [os.path.join(image_dir, file) for file in sorted(os.listdir(image_dir))]
with open('examples/prompts.json', 'r', encoding='utf-8') as f:
    prompts_list = json.load(f).values()
examples = [[image, prompt[0], prompt[1]] for image, prompt in zip(image_list, prompts_list)]


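# queue() enables request queuing so concurrent users don't run the single
# GPU pipeline simultaneously.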
block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.HTML(read_content('header.html'))
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type='numpy')
            prompt = gr.Textbox(label='Prompt')
            negative_prompt = gr.Textbox(label='Negative Prompt')
            with gr.Row():
                generate_from_mask = gr.Checkbox(label='Input image is already a control mask', value=False)
            run_button = gr.Button(value='Run')
            with gr.Accordion('Advanced options', open=False):
                target_image_size = gr.Slider(
                    label='Image target size:',
                    minimum=512,
                    maximum=2048,
                    value=768,
                    step=64,
                )
                max_image_size = gr.Slider(
                    label='Image max size:',
                    minimum=512,
                    maximum=2048,
                    value=1024,
                    step=64,
                )
                num_inference_steps = gr.Slider(label='Number of steps', minimum=1, maximum=100, value=20, step=1)
                guidance_scale = gr.Slider(label='Guidance scale', minimum=0.1, maximum=30.0, value=9.0, step=0.1)
                conditioning_scale = gr.Slider(label='Conditioning scale', minimum=0.0, maximum=5.0, value=1.0, step=0.1)
                seed = gr.Slider(label='Seed', minimum=0, maximum=config.max_seed, step=1, value=0)

            gr.Examples(
                examples=examples,
                inputs=[input_image, prompt, negative_prompt],
                label='Examples - Input Images',
                examples_per_page=12,
            )

            gr.HTML(
                """
                <div class="footer">
                    <p>
                        This demo builds on the U-Net from
                        <a style="text-decoration: underline;" href="https://huggingface.co/spaces/wildoctopus/cloth-segmentation">cloth-segmentation</a>.
                        It uses a pre-trained U2NET to extract Upper body (red), Lower body (green), and Full body (blue) masks,
                        then runs StableDiffusionXLControlNetPipeline with the trained controlnet_baseline to generate an image conditioned on these masks.
                    </p>
                </div>
                """)

        with gr.Column():
            generated_output = gr.Image(label='Generated', type='numpy', elem_id='generated')
            mask_output = gr.Image(label='Mask', type='numpy', elem_id='mask')

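    # Gradio passes inputs positionally: this list must match the order of
    # process()'s parameters exactly.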
    ips = [
        input_image, prompt, negative_prompt, generate_from_mask,
        num_inference_steps, guidance_scale, conditioning_scale,
        target_image_size, max_image_size, seed,
    ]
    run_button.click(fn=process, inputs=ips, outputs=[generated_output, mask_output])


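# Serves on http://127.0.0.1:7860 by default; pass share=True for a temporary public link.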
block.launch()