File size: 4,234 Bytes
45bcca5
 
 
 
c8910cc
45bcca5
 
 
fed7f36
 
 
de28386
45bcca5
 
c8910cc
fed7f36
c8910cc
45bcca5
 
c8910cc
 
 
 
fed7f36
 
 
 
c8910cc
fed7f36
45bcca5
 
fed7f36
 
 
45bcca5
 
fed7f36
 
 
45bcca5
 
de28386
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python

from __future__ import annotations

import pathlib

import gradio as gr
import numpy as np

from model import Model

DESCRIPTION = '# [Self-Distilled StyleGAN](https://github.com/self-distilled-stylegan/self-distilled-internet-photos)'


def get_sample_image_url(name: str) -> str:
    """Return the URL of the hosted sample-image sheet called ``name``.

    The URL points at ``<name>.jpg`` inside the ``samples`` directory of the
    Hugging Face Space.
    """
    base_url = 'https://huggingface.co/spaces/hysts/Self-Distilled-StyleGAN/resolve/main/samples'
    filename = f'{name}.jpg'
    return '/'.join((base_url, filename))


def get_sample_image_markdown(name: str) -> str:
    """Build the Markdown snippet describing one sample-image sheet.

    ``name`` is split on underscores: the second field is reported as the
    image size, and every field from the third onward (re-joined with
    underscores) as the truncation type. For ``'dogs_1024_multimodal_lpips'``
    that yields size ``1024`` and truncation type ``multimodal_lpips``.

    The seed range and truncation value in the text are fixed strings.

    Raises:
        IndexError: if ``name`` contains no underscore.
    """
    fields = name.split('_')
    size = fields[1]
    truncation_type = '_'.join(fields[2:])
    url = get_sample_image_url(name)
    return f'''
    - size: {size}x{size}
    - seed: 0-99
    - truncation: 0.7
    - truncation type: {truncation_type}
    ![sample images]({url})'''


def get_cluster_center_image_url(model_name: str) -> str:
    """Return the URL of the hosted cluster-center image for ``model_name``.

    The URL points at ``<model_name>.jpg`` inside the
    ``cluster_center_images`` directory of the Hugging Face Space.
    """
    base_url = 'https://huggingface.co/spaces/hysts/Self-Distilled-StyleGAN/resolve/main/cluster_center_images'
    filename = f'{model_name}.jpg'
    return '/'.join((base_url, filename))


def get_cluster_center_image_markdown(model_name: str) -> str:
    """Return a Markdown image tag embedding the cluster-center image.

    The image URL is obtained from ``get_cluster_center_image_url``; the alt
    text is always ``cluster center images``.
    """
    return f'![cluster center images]({get_cluster_center_image_url(model_name)})'


# One Model instance is created at import time and shared by every event
# handler wired up below.
model = Model()

# Build the UI. Components are laid out in the order they are created inside
# the nested context managers, so statement order here defines the layout.
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        # Tab 1: interactive generation. Controls in the left column, the
        # generated image in the right column.
        with gr.TabItem('App'):
            with gr.Row():
                with gr.Column():
                    with gr.Group():
                        model_name = gr.Dropdown(label='Model',
                                                 choices=model.MODEL_NAMES,
                                                 value=model.MODEL_NAMES[0])
                        # Integer seed covering the full unsigned 32-bit range.
                        seed = gr.Slider(label='Seed',
                                         minimum=0,
                                         maximum=np.iinfo(np.uint32).max,
                                         step=1,
                                         value=0)
                        psi = gr.Slider(label='Truncation psi',
                                        minimum=0,
                                        maximum=2,
                                        step=0.05,
                                        value=0.7)
                        truncation_type = gr.Dropdown(
                            label='Truncation Type',
                            choices=model.TRUNCATION_TYPES,
                            value=model.TRUNCATION_TYPES[0])
                        run_button = gr.Button('Run')
                with gr.Column():
                    result = gr.Image(label='Result', elem_id='result')

        # Tab 2: pre-rendered sample sheets, selected by name.
        with gr.TabItem('Sample Images'):
            with gr.Row():
                # Dropdown choices are the stems (file names without their
                # extension) of the entries in the local `samples` directory,
                # in sorted path order.
                paths = sorted(pathlib.Path('samples').glob('*'))
                names = [path.stem for path in paths]
                # NOTE(review): the default value is hard-coded; presumably a
                # file with this stem exists in `samples` - confirm.
                model_name2 = gr.Dropdown(label='Type',
                                          choices=names,
                                          value='dogs_1024_multimodal_lpips')
            with gr.Row():
                # Initial Markdown is rendered from the dropdown's default value.
                text = get_sample_image_markdown(model_name2.value)
                sample_images = gr.Markdown(text)

        # Tab 3: cluster-center images, one per model name.
        with gr.TabItem('Cluster Center Images'):
            with gr.Row():
                model_name3 = gr.Dropdown(label='Model',
                                          choices=model.MODEL_NAMES,
                                          value=model.MODEL_NAMES[0])
            with gr.Row():
                # `text` is reused here; the previous value has already been
                # consumed by the Sample Images Markdown component.
                text = get_cluster_center_image_markdown(model_name3.value)
                cluster_center_images = gr.Markdown(value=text)

    # Event wiring.
    # Changing the model dropdown calls model.set_model with the new name; no
    # outputs are wired, so it is presumably invoked for its side effect only
    # (verify against model.py).
    model_name.change(fn=model.set_model, inputs=model_name)
    # "Run" passes the four controls to the model and shows the returned value
    # in the result image component.
    run_button.click(fn=model.set_model_and_generate_image,
                     inputs=[
                         model_name,
                         seed,
                         psi,
                         truncation_type,
                     ],
                     outputs=result)
    # Selecting a different sample sheet / model re-renders the corresponding
    # Markdown component.
    model_name2.change(fn=get_sample_image_markdown,
                       inputs=model_name2,
                       outputs=sample_images)
    model_name3.change(fn=get_cluster_center_image_markdown,
                       inputs=model_name3,
                       outputs=cluster_center_images)

# Enable the request queue (max_size=10) and start the app. This runs at
# import time; there is no `if __name__ == '__main__'` guard.
demo.queue(max_size=10).launch()