File size: 3,577 Bytes
45bcca5
 
 
 
c8910cc
45bcca5
 
 
fed7f36
 
 
f2c9715
45bcca5
 
c8910cc
f2c9715
 
45bcca5
 
c8910cc
 
f2c9715
 
 
fed7f36
 
 
c8910cc
f2c9715
45bcca5
 
fed7f36
f2c9715
 
 
 
45bcca5
 
fed7f36
 
f2c9715
45bcca5
 
de28386
 
f2c9715
de28386
 
 
f2c9715
de28386
 
 
f2c9715
 
 
de28386
f2c9715
 
 
de28386
f2c9715
de28386
f2c9715
de28386
f2c9715
de28386
f2c9715
de28386
 
 
 
f2c9715
de28386
f2c9715
de28386
 
 
 
4872e62
 
 
 
f2c9715
 
 
 
 
 
 
 
 
 
4872e62
 
 
 
 
 
 
 
 
 
de28386
4872e62
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python

from __future__ import annotations

import pathlib

import gradio as gr
import numpy as np

from model import Model

# Markdown heading shown at the top of the demo page (passed to gr.Markdown
# below); the heading text links to the upstream project repository.
DESCRIPTION = "# [Self-Distilled StyleGAN](https://github.com/self-distilled-stylegan/self-distilled-internet-photos)"


def get_sample_image_url(name: str) -> str:
    """Return the URL of the sample-image grid called *name*.

    The result is the hosting Space's ``samples`` directory followed by
    ``/<name>.jpg``. No validation of *name* is performed.
    """
    base_url = "https://huggingface.co/spaces/hysts/Self-Distilled-StyleGAN/resolve/main/samples"
    return "{}/{}.jpg".format(base_url, name)


def get_sample_image_markdown(name: str) -> str:
    """Build the Markdown snippet describing the sample grid called *name*.

    *name* is split on underscores: the second field is reported as the
    image size, and every field from the third onward (re-joined with
    underscores) is reported as the truncation type. The seed range and
    truncation value in the snippet are fixed text.

    Raises:
        IndexError: if *name* contains no underscore, so there is no
            second field to use as the size.
    """
    fields = name.split("_")
    size = fields[1]
    truncation_type = "_".join(fields[2:])
    url = get_sample_image_url(name)
    return f"""
    - size: {size}x{size}
    - seed: 0-99
    - truncation: 0.7
    - truncation type: {truncation_type}
    ![sample images]({url})"""


def get_cluster_center_image_url(model_name: str) -> str:
    """Return the URL of the cluster-center image for *model_name*.

    The result is the hosting Space's ``cluster_center_images`` directory
    followed by ``/<model_name>.jpg``. No validation is performed.
    """
    base_url = "https://huggingface.co/spaces/hysts/Self-Distilled-StyleGAN/resolve/main/cluster_center_images"
    return "{}/{}.jpg".format(base_url, model_name)


def get_cluster_center_image_markdown(model_name: str) -> str:
    """Return a Markdown image tag for the cluster-center image of *model_name*.

    The image target is whatever ``get_cluster_center_image_url`` returns
    for the same name; the alt text is fixed.
    """
    return "![cluster center images]({})".format(get_cluster_center_image_url(model_name))


# Single Model instance created at import time and shared by every callback
# wired up below. Model comes from the project-local `model` module, so its
# constructor behaviour is not visible in this file.
model = Model()

# Build the UI. "style.css" is handed to gr.Blocks as its `css` argument.
# NOTE(review): whether that string is treated as a file path or as literal
# CSS depends on the installed Gradio version — confirm.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Tabs():
        # Tab 1: interactive generation. Inputs in the left column, the
        # resulting image in the right column.
        with gr.TabItem("App"):
            with gr.Row():
                with gr.Column():
                    with gr.Group():
                        # Both dropdowns default to the first entry of the
                        # corresponding list exposed by the Model instance.
                        model_name = gr.Dropdown(label="Model", choices=model.MODEL_NAMES, value=model.MODEL_NAMES[0])
                        # Seed slider covers 0 .. max uint32 in integer steps.
                        seed = gr.Slider(label="Seed", minimum=0, maximum=np.iinfo(np.uint32).max, step=1, value=0)
                        psi = gr.Slider(label="Truncation psi", minimum=0, maximum=2, step=0.05, value=0.7)
                        truncation_type = gr.Dropdown(
                            label="Truncation Type", choices=model.TRUNCATION_TYPES, value=model.TRUNCATION_TYPES[0]
                        )
                        run_button = gr.Button("Run")
                with gr.Column():
                    result = gr.Image(label="Result", elem_id="result")

        # Tab 2: pre-rendered sample grids.
        with gr.TabItem("Sample Images"):
            with gr.Row():
                # Choices are the file stems found in the local "samples"
                # directory at startup, in sorted path order.
                paths = sorted(pathlib.Path("samples").glob("*"))
                names = [path.stem for path in paths]
                # NOTE(review): the default is hard-coded and is assumed to be
                # one of the discovered names — verify against the samples dir.
                model_name2 = gr.Dropdown(label="Type", choices=names, value="dogs_1024_multimodal_lpips")
            with gr.Row():
                # Initial Markdown is rendered from the dropdown's default value.
                text = get_sample_image_markdown(model_name2.value)
                sample_images = gr.Markdown(text)

        # Tab 3: cluster-center images, one per model name.
        with gr.TabItem("Cluster Center Images"):
            with gr.Row():
                model_name3 = gr.Dropdown(label="Model", choices=model.MODEL_NAMES, value=model.MODEL_NAMES[0])
            with gr.Row():
                # `text` is reused here; the value from the previous tab has
                # already been consumed by `sample_images`.
                text = get_cluster_center_image_markdown(model_name3.value)
                cluster_center_images = gr.Markdown(value=text)

    # Event wiring.
    # Changing the model dropdown calls model.set_model with the new value; no
    # output component is bound. Presumably this switches/preloads the model —
    # confirm in model.py.
    model_name.change(
        fn=model.set_model,
        inputs=model_name,
    )
    # "Run" passes all four inputs to model.set_model_and_generate_image and
    # sends its return value to the result image.
    run_button.click(
        fn=model.set_model_and_generate_image,
        inputs=[
            model_name,
            seed,
            psi,
            truncation_type,
        ],
        outputs=result,
    )
    # Selecting a different sample type re-renders the Markdown for that name.
    model_name2.change(
        fn=get_sample_image_markdown,
        inputs=model_name2,
        outputs=sample_images,
    )
    # Selecting a different model re-renders its cluster-center image Markdown.
    model_name3.change(
        fn=get_cluster_center_image_markdown,
        inputs=model_name3,
        outputs=cluster_center_images,
    )

# Launch only when executed as a script; queue() is called with max_size=10
# before launch().
if __name__ == "__main__":
    demo.queue(max_size=10).launch()