File size: 5,599 Bytes
4fb3c5e
 
 
 
 
 
 
 
 
 
 
70d5056
4fb3c5e
 
 
ea424ac
4fb3c5e
ea424ac
4fb3c5e
 
 
 
 
7d37aeb
 
ea424ac
 
7d37aeb
 
ea424ac
4fb3c5e
 
ea424ac
 
 
 
7d37aeb
 
ea424ac
 
4fb3c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
1489344
4fb3c5e
 
 
 
 
 
 
 
 
 
 
 
 
cb229bd
 
 
 
 
 
 
 
4fb3c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1489344
4fb3c5e
 
 
 
 
 
 
 
cb229bd
4fb3c5e
 
 
 
 
ea424ac
4fb3c5e
 
ea424ac
4fb3c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb229bd
 
 
4fb3c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#!/usr/bin/env python

from __future__ import annotations

import argparse

import gradio as gr

from model import Model

# Markdown heading rendered at the top of the page.
TITLE = ('# Anime Face Generation with '
         '[Diffusers](https://github.com/huggingface/diffusers)')
# Timing note shown at the top of the "Advanced Mode" tab.
DESCRIPTION = ('Expected execution time on Hugging Face Spaces: '
               '5s (DDIM, 20 steps), '
               '6s (PNDM, 20 steps), '
               '247s (DDPM, 1000 steps)')
# Raw HTML visitor-counter badge rendered below the tabs.
FOOTER = ('<img id="visitor-badge" '
          'src="https://visitor-badge.glitch.me/badge?page_id=hysts.diffusers-anime-faces" '
          'alt="visitor badge" />')


def get_sample_image_url(file_name: str) -> str:
    """Return the URL of a file in the Space's ``samples`` directory.

    Args:
        file_name: File name relative to the ``samples`` directory of the
            hysts/diffusers-anime-faces Space (``main`` revision).

    Returns:
        The absolute ``https://huggingface.co/...`` URL for that file.
    """
    base_url = 'https://huggingface.co/spaces/hysts/diffusers-anime-faces/resolve/main/samples'
    return '{}/{}'.format(base_url, file_name)


def get_sample_image_markdown(name: str) -> str:
    """Build the Markdown shown in the "Sample Images" tab for one entry.

    Args:
        name: One of the sample-image labels, e.g.
            ``'ddpm-128-exp000 (DDPM)'``. The first whitespace-separated
            token is taken as the model name and used in the image file name.

    Returns:
        A Markdown string listing the generation settings (size, seeds,
        scheduler, steps) followed by an embedded image that points at the
        matching file in the Space's ``samples`` directory.

    Raises:
        ValueError: If ``name`` is not a known sample-image label.
    """
    # label -> (scheduler, number of steps, suffix appended to the model
    # name to form the sample file name)
    sample_settings = {
        'ddpm-128-exp000 (DDPM)': ('DDPM', 1000, ''),
        'ddpm-128-exp000 (DDIM, 20 steps)': ('DDIM', 20, '-ddim-20steps'),
    }
    # Validate before touching `name` further so every bad label (including
    # '' or a non-string) surfaces as a descriptive ValueError.
    try:
        scheduler, steps, suffix = sample_settings[name]
    except (KeyError, TypeError):
        raise ValueError(f'Unknown sample image set: {name!r}') from None
    model_name = name.split()[0]
    file_name = f'{model_name}{suffix}.png'
    url = get_sample_image_url(file_name)
    # NOTE(review): the lines below keep their source indentation; this
    # presumably relies on gr.Markdown dedenting its value before rendering
    # (otherwise they would render as an indented code block) — confirm.
    text = f'''
            - size: 128x128
            - seed: 0-99
            - scheduler: {scheduler}
            - steps: {steps}

            ![sample images]({url})'''
    return text


def update_scheduler_type(name: str) -> dict:
    """Return a Slider update that adapts the step slider to a scheduler.

    Args:
        name: Selected scheduler label ('DDPM', 'DDIM' or 'PNDM' in the UI).

    Returns:
        The result of ``gr.Slider.update``: the slider is hidden for 'DDPM'.
        Its range is 4-100 for 'PNDM' and 1-200 otherwise. The value is
        always reset to 20.
    """
    # NOTE(review): DDPM presumably always runs a fixed number of steps
    # (DESCRIPTION mentions 1000), hence the hidden slider — confirm in model.py.
    lower, upper = (4, 100) if name == 'PNDM' else (1, 200)
    return gr.Slider.update(value=20,
                            minimum=lower,
                            maximum=upper,
                            visible=(name != 'DDPM'))


def main():
    """Parse command-line options, build the Gradio UI and launch it.

    Command-line options:
        --device: device string passed to ``Model`` (default ``'cpu'``).

    The page has three tabs. "Simple Mode" is a single Generate button.
    "Advanced Mode" exposes model, scheduler, step count and seed.
    "Sample Images" shows pre-generated samples described by
    ``get_sample_image_markdown``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    args = parser.parse_args()

    # One shared Model instance serves every event handler wired below.
    model = Model(args.device)

    # Components appear on the page in the order they are created inside
    # this context, so the statement order below defines the layout.
    with gr.Blocks(css='style.css') as demo:
        gr.Markdown(TITLE)

        with gr.Tabs():
            with gr.TabItem('Simple Mode'):
                run_button_simple = gr.Button('Generate')
                result_simple = gr.Image(show_label=False,
                                         elem_id='result-grid')

            with gr.TabItem('Advanced Mode'):
                gr.Markdown(DESCRIPTION)

                with gr.Row():
                    # Left column: generation controls.
                    with gr.Column():
                        with gr.Group():
                            # interactive=False: users cannot change the
                            # model; it stays at MODEL_NAMES[0].
                            model_name = gr.Dropdown(
                                model.MODEL_NAMES,
                                value=model.MODEL_NAMES[0],
                                label='Model',
                                interactive=False)
                            scheduler_type = gr.Radio(
                                choices=['DDPM', 'DDIM', 'PNDM'],
                                value='DDIM',
                                label='Scheduler')
                            # Initial range/value match what
                            # update_scheduler_type returns for non-PNDM schedulers.
                            num_steps = gr.Slider(1,
                                                  200,
                                                  step=1,
                                                  value=20,
                                                  label='Number of Steps')
                            seed = gr.Slider(0,
                                             100000,
                                             step=1,
                                             value=1234,
                                             label='Seed')
                            run_button = gr.Button('Run')
                    # Right column: generated image.
                    with gr.Column():
                        result = gr.Image(show_label=False, elem_id='result')

            with gr.TabItem('Sample Images'):
                with gr.Row():
                    model_name2 = gr.Dropdown([
                        'ddpm-128-exp000 (DDPM)',
                        'ddpm-128-exp000 (DDIM, 20 steps)',
                    ],
                                              value='ddpm-128-exp000 (DDPM)',
                                              label='Model')
                with gr.Row():
                    # Render the Markdown for the dropdown's initial value
                    # once at build time; later changes go through the
                    # model_name2.change handler below.
                    text = get_sample_image_markdown(model_name2.value)
                    sample_images = gr.Markdown(text)

        gr.Markdown(FOOTER)

        # --- Event wiring ---
        # NOTE(review): model_name is non-interactive, so this handler is
        # presumably only relevant if more models are enabled later.
        model_name.change(fn=model.set_pipeline,
                          inputs=[
                              model_name,
                              scheduler_type,
                          ],
                          outputs=None)
        # Two handlers on the same scheduler change: this one adjusts the
        # step slider (queue=False keeps this lightweight UI update out of
        # the request queue), the next passes the new selection to
        # model.set_pipeline.
        scheduler_type.change(fn=update_scheduler_type,
                              inputs=scheduler_type,
                              outputs=num_steps,
                              queue=False)
        scheduler_type.change(fn=model.set_pipeline,
                              inputs=[
                                  model_name,
                                  scheduler_type,
                              ],
                              outputs=None)
        run_button_simple.click(fn=model.run_simple,
                                inputs=None,
                                outputs=result_simple)
        run_button.click(fn=model.run,
                         inputs=[
                             model_name,
                             scheduler_type,
                             num_steps,
                             seed,
                         ],
                         outputs=result)
        model_name2.change(fn=get_sample_image_markdown,
                           inputs=model_name2,
                           outputs=sample_images)

    # NOTE(review): `enable_queue` in launch() belongs to the older Gradio 3.x
    # API (as does gr.Slider.update); newer releases use demo.queue() instead.
    demo.launch(enable_queue=True, share=False)


if __name__ == '__main__':
    # Build and launch the demo only when executed as a script, not on import.
    main()