File size: 3,098 Bytes
a1b524b
 
 
 
 
 
 
 
 
 
dda8135
a1b524b
26fa884
dda8135
a1b524b
 
 
dda8135
a1b524b
 
 
 
 
 
52cfbaf
a1b524b
 
 
 
52cfbaf
 
a1b524b
 
 
26fa884
 
dda8135
26fa884
52cfbaf
dda8135
26fa884
 
 
dda8135
26fa884
dda8135
26fa884
dda8135
26fa884
dda8135
52cfbaf
26fa884
 
dda8135
 
26fa884
52cfbaf
dda8135
26fa884
 
 
 
dda8135
 
26fa884
 
dda8135
26fa884
dda8135
26fa884
dda8135
26fa884
 
dda8135
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26fa884
52cfbaf
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python

from __future__ import annotations

import pathlib

import gradio as gr

from model import Model

# Markdown rendered at the top of the demo page: a title linking to the
# upstream HairCLIP repository followed by a centered teaser image.
DESCRIPTION = (
    "# [HairCLIP](https://github.com/wty-ustc/HairCLIP)\n"
    "\n"
    '<center><img id="teaser" src="https://raw.githubusercontent.com/wty-ustc/HairCLIP/main/assets/teaser.png" alt="teaser"></center>\n'
)


def load_hairstyle_list(path: str | pathlib.Path = "HairCLIP/mapper/hairstyle_list.txt") -> list[str]:
    """Read the hairstyle names listed one per line in *path*.

    Each line is stripped of surrounding whitespace and, when it ends with the
    suffix ``" hairstyle"``, that suffix is removed (``"afro hairstyle"`` ->
    ``"afro"``). Lines without the suffix are kept as-is rather than having a
    fixed number of characters chopped off.

    Line order, including any blank lines, is preserved. The Step 2 dropdown
    uses ``type="index"``, so positions in this list are what get passed on.

    Args:
        path: File to read. Defaults to the HairCLIP mapper's list, resolved
            relative to the current working directory.

    Returns:
        The hairstyle names, in file order.
    """
    suffix = " hairstyle"
    # Explicit encoding so the result does not depend on the platform locale.
    with open(path, encoding="utf-8") as f:
        stripped = [line.strip() for line in f]
    return [name[: -len(suffix)] if name.endswith(suffix) else name for name in stripped]


def set_example_image(example: list) -> gr.Image:
    """Return an Image component update showing the first entry of an example row.

    Args:
        example: One row of example data; only ``example[0]`` is used, as the
            image value.

    Returns:
        A ``gr.Image`` whose ``value`` is ``example[0]``. Returned from an event
        handler, Gradio applies it as an update to the bound output component.

    NOTE(review): this helper is not attached to any event in the wiring of
    this file; it may be dead code.
    """
    return gr.Image(value=example[0])


def update_step2_components(choice: str) -> tuple[gr.Dropdown, gr.Textbox]:
    """Toggle which Step 2 inputs are visible for the selected editing type.

    Args:
        choice: The selected editing type: ``"hairstyle"``, ``"color"`` or
            ``"both"``.

    Returns:
        A ``(Dropdown, Textbox)`` pair of component updates. The hairstyle
        dropdown is visible for ``"hairstyle"`` and ``"both"``. The color
        textbox is visible for ``"color"`` and ``"both"``.
    """
    show_hairstyle = choice in ("hairstyle", "both")
    show_color = choice in ("color", "both")
    return gr.Dropdown(visible=show_hairstyle), gr.Textbox(visible=show_color)


# Single model instance, created when this module is loaded and shared by every
# event handler wired up below.
model = Model()

# UI layout. Components are created inside nested Blocks/Group/Row/Column
# contexts, so their on-page placement follows the order of these statements.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    # Step 1: pick an input photo, align the face, and reconstruct it.
    with gr.Group():
        gr.Markdown("## Step 1")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    # type="filepath": handlers receive the image as a file path string.
                    input_image = gr.Image(label="Input Image", type="filepath")
                with gr.Row():
                    preprocess_button = gr.Button("Preprocess")
            with gr.Column():
                # Filled by the preprocess step only (not user-editable); handlers get a PIL image.
                aligned_face = gr.Image(label="Aligned Face", type="pil", interactive=False)
            with gr.Column():
                reconstructed_face = gr.Image(label="Reconstructed Face", type="numpy")
                # Per-session holder for the latent returned by model.reconstruct_face
                # and later passed to model.generate; it is not displayed.
                latent = gr.State()

        with gr.Row():
            # Example inputs: every *.jpg under ./images (relative to the current
            # working directory), sorted so the gallery order is stable.
            paths = sorted(pathlib.Path("images").glob("*.jpg"))
            gr.Examples(examples=[[path.as_posix()] for path in paths], inputs=input_image)

    # Step 2: choose what to edit and run the edit on the stored latent.
    with gr.Group():
        gr.Markdown("## Step 2")
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    editing_type = gr.Radio(
                        label="Editing Type", choices=["hairstyle", "color", "both"], value="both", type="value"
                    )
                with gr.Row():
                    hairstyles = load_hairstyle_list()
                    # type="index": handlers receive the position of the selected name in
                    # `hairstyles`, not the name itself.
                    # NOTE(review): presumably model.generate indexes the same hairstyle
                    # list in the same order — confirm against model.py.
                    hairstyle_index = gr.Dropdown(label="Hairstyle", choices=hairstyles, value="afro", type="index")
                with gr.Row():
                    # Free-text color description used for color edits.
                    color_description = gr.Textbox(label="Color", value="red")
                with gr.Row():
                    run_button = gr.Button("Run")

            with gr.Column():
                result = gr.Image(label="Result")

    # Event wiring.
    # Preprocess: input image path -> aligned face.
    preprocess_button.click(fn=model.detect_and_align_face, inputs=input_image, outputs=aligned_face)
    # Whenever the aligned face changes (e.g. after preprocessing), reconstruct it,
    # filling both the preview image and the latent state.
    aligned_face.change(fn=model.reconstruct_face, inputs=aligned_face, outputs=[reconstructed_face, latent])
    # Show/hide the hairstyle dropdown and color textbox for the chosen editing type.
    editing_type.change(fn=update_step2_components, inputs=editing_type, outputs=[hairstyle_index, color_description])
    # Run the edit from the current editing type, hairstyle index, color text and latent.
    run_button.click(
        fn=model.generate,
        inputs=[
            editing_type,
            hairstyle_index,
            color_description,
            latent,
        ],
        outputs=result,
    )

# Script entry point: enable Gradio's request queue (holding at most 10
# pending events) and start serving the app.
if __name__ == "__main__":
    queued_app = demo.queue(max_size=10)
    queued_app.launch()