File size: 4,871 Bytes
4069116
e9e2aab
b5e8b97
 
e9e2aab
b5e8b97
 
1c472db
 
 
 
 
 
 
 
 
 
4069116
b5e8b97
 
1c472db
 
 
 
 
 
 
 
 
b5e8b97
 
 
 
 
 
6c139d0
286253d
b5e8b97
4069116
 
 
b6090a0
4069116
 
 
 
1c472db
 
 
 
 
 
 
4069116
1c472db
 
e9e2aab
 
1c472db
 
 
 
6c139d0
1c472db
 
 
 
 
 
 
 
 
 
 
 
 
 
4ad4ba1
9cf1e91
1c472db
 
 
 
 
 
 
 
 
 
 
339d9cc
1c472db
 
339d9cc
1c472db
 
74880fc
1c472db
 
4ad4ba1
 
9cf1e91
1c472db
 
 
 
f302484
1c472db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import gradio as gr
import legacy
import dnnlib
import numpy as np
import torch

import find_direction
import generator
import psp_wrapper


# Pretrained checkpoints and asset locations, relative to the working directory.
psp_encoder_path = "./pretrained/e4e_ffhq_encode.pt"
landmarks_path = "./pretrained/shape_predictor_68_face_landmarks.dat"
G_ffhq_path = "./pretrained/ffhq.pkl"
G_metfaces_path = "./pretrained/metfaces.pkl"
direction_folder = "./assets/directions/"

# Encoder used to invert uploaded images into latents (see get_w in
# generate_output_image). Built from the encoder checkpoint plus a face
# landmark model file (presumably dlib's 68-point predictor, per its name).
e4e_embedder = psp_wrapper.psp_encoder(psp_encoder_path, landmarks_path)

# Run on the first GPU when one is available, otherwise on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


def _load_generator(pkl_path):
    """Load the 'G_ema' network from a network pickle and move it to `device`."""
    with dnnlib.util.open_url(pkl_path) as fp:
        return legacy.load_network_pkl(fp)['G_ema'].to(device)


G_ffhq = _load_generator(G_ffhq_path)
G_metfaces = _load_generator(G_metfaces_path)

# Generators keyed by the names offered in the "StyleGAN Type" radio.
G_dict = {"FFHQ": G_ffhq, "MetFaces": G_metfaces}



# Markdown title shown at the top of the demo; links to the StyleMC repository.
DESCRIPTION = (
    '# <a href="https://github.com/catlab-team/stylemc"> StyleMC:</a> '
    'Multi-Channel Based Fast Text-Guided Image Generation and Manipulation\n'
)
# HTML credit line for the Catlab team.
FOOTER = (
    'This space is built by '
    '<a href = "https://github.com/catlab-team">Catlab Team</a>.'
)

# Registry of manipulation directions offered in the Step 2 radio list.
#   direction_list: direction names (the radio choices), in display order.
#   direction_map:  name -> {"direction": array passed to the generator as
#                   styles_direction, "stylegan_type": key into G_dict}
direction_map = {}
direction_list = []

# Preload the precomputed directions shipped in `direction_folder`. Each .npz
# file must hold its direction under key "s"; the file name minus ".npz"
# becomes the display name. Sorting keeps the radio order independent of
# filesystem listing order.
directions = sorted(f for f in os.listdir(direction_folder) if f.endswith(".npz"))
for d in directions:
    dir_name = d[: -len(".npz")]
    with np.load(os.path.join(direction_folder, d)) as data:
        # Read the array while the archive is open.
        direction = data["s"]
    # The radio choices must be the names, not the arrays. Arrays in the list
    # also break the `in` membership checks that compare against names.
    direction_list.append(dir_name)
    # NOTE(review): bundled directions are assumed to be FFHQ directions — confirm.
    direction_map[dir_name] = {"direction": direction, "stylegan_type": "FFHQ"}


def add_direction(prompt, stylegan_type, id_loss_w):
    """Search for a text-driven direction and register it for Step 2.

    Args:
        prompt: Target text prompt from the textbox. None or blank prompts
            are ignored.
        stylegan_type: Key into ``G_dict`` ("FFHQ" or "MetFaces") that selects
            the generator the direction is searched in.
        id_loss_w: Identity-loss weight from the slider. Here it is only used
            in the direction's name and in the log message.

    Returns:
        A ``gr.Radio`` update that refreshes the choices to ``direction_list``
        and clears the current selection.
    """
    # Validate before building the name: concatenating None would raise a
    # TypeError, and an empty Textbox yields "" rather than None.
    if prompt is not None and prompt.strip():
        new_dir_name = f"{prompt} {stylegan_type} w_id_loss{id_loss_w}"
        # Dedupe on the map's string keys so that an identical request does
        # not re-run the (expensive) direction search.
        if new_dir_name not in direction_map:
            print("adding direction with id:", new_dir_name)
            # NOTE(review): id_loss_w is not passed to find_direction, so the
            # slider does not influence the search — confirm the helper's
            # signature and forward it if supported.
            direction = find_direction.find_direction(G_dict[stylegan_type], prompt)
            print(f"new direction calculated with {stylegan_type} and id loss weight = {id_loss_w}")
            direction_list.append(new_dir_name)
            direction_map[new_dir_name] = {"direction": direction, "stylegan_type": stylegan_type}

    return gr.Radio.update(choices=direction_list, value=None, visible=True)


def generate_output_image(image_path, direction_id, change_power):
    """Apply a registered direction to an uploaded image.

    Args:
        image_path: File path of the uploaded image (the input component uses
            ``type="filepath"``).
        direction_id: Name of the selected direction, a key of ``direction_map``.
        change_power: Manipulation strength from the slider, forwarded to
            ``generator.generate_from_style``.

    Returns:
        The result of ``generator.generate_from_style``. The output component
        expects a PIL image (``type="pil"``); confirm the helper returns one.

    Raises:
        ValueError: If no known direction is selected or no image is uploaded.
    """
    if direction_id is None or direction_id not in direction_map:
        raise ValueError("Please select a direction from the list first.")
    if image_path is None:
        raise ValueError("Please upload an input image first.")

    entry = direction_map[direction_id]
    # Use the generator the direction was found with, not always FFHQ, so the
    # "StyleGAN Type" chosen in Step 1 actually takes effect.
    # NOTE(review): the encoder checkpoint is the FFHQ e4e model (per its
    # path), so this assumes its latents are usable with the MetFaces
    # generator too — confirm.
    G = G_dict[entry.get("stylegan_type", "FFHQ")]

    w = e4e_embedder.get_w(image_path)
    s = generator.w_to_s(GIn=G, wsIn=w)
    return generator.generate_from_style(
        GIn=G,
        styles=s,
        styles_direction=entry["direction"],
        change_power=change_power,
        outdir='.',
    )
  
# Two-step demo: (1) search a direction for a text prompt, (2) apply a
# registered direction to an uploaded image.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    # Step 1 inputs: generator choice, identity-loss weight, and text prompt.
    with gr.Box():
        gr.Markdown('''### Step 1) Finding a global manipulation direction <br />
            - Please enter the target **text prompt** and **identity loss weight** to find global manipulation direction.''')
        with gr.Row():
            with gr.Column():
                style_gan_type = gr.Radio(["FFHQ", "MetFaces"], value = "FFHQ", label="StyleGAN Type", interactive=True)
            with gr.Column():
                identity_loss_weight = gr.Slider(
                    0.1, 10, value=0.5, step=0.1,label="Identity Loss Weight",interactive=True
                )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Textbox(
                        label="Enter your text prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your text prompt"
                    ).style(container=False)

                    find_direction_btn = gr.Button("Find Direction").style(full_width=False)

    # Step 2: pick a registered direction and strength, then manipulate an image.
    with gr.Box():
        gr.Markdown('''### Step 2) Text-guided manipulation <br />
            - Please upload an image. <br />
            - You can select any of the previously found **directions** and set the **manipulation strength** to manipulate the image.''')
        with gr.Row():
            direction_radio = gr.Radio(direction_list, label="List of Directions")
        with gr.Row():
            manipulation_strength = gr.Slider(
                0.1, 25, value=10, step=0.1, label="Manipulation Strength",interactive=True
            )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label="Input Image", type="filepath")
                with gr.Row():
                    generate_btn = gr.Button("Generate")
            with gr.Column():
                with gr.Row():
                    generated_image = gr.Image(label="Generated Image",type="pil",interactive=False)

    # Credit footer at the bottom of the page.
    gr.Markdown(FOOTER)

    # Step 1 refreshes the radio choices; Step 2 renders the manipulated image.
    find_direction_btn.click(add_direction, inputs=[text, style_gan_type, identity_loss_weight], outputs=direction_radio)
    generate_btn.click(generate_output_image, inputs=[input_image, direction_radio,manipulation_strength], outputs=generated_image)

# Start the server only when executed as a script. Importing the module then
# just builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch(debug=True)