"""Gradio demo for StyleMC text-guided StyleGAN image manipulation.

Step 1 computes a global manipulation direction for a text prompt with the
selected generator (FFHQ or MetFaces). Step 2 embeds an uploaded image with
the pSp/e4e encoder wrapper and moves it along a previously found direction.

NOTE: importing this module loads every checkpoint and launches the server.
"""

import gradio as gr
import numpy as np
import torch

import legacy
import dnnlib
import find_direction
import generator
import psp_wrapper

psp_encoder_path = "./pretrained/e4e_ffhq_encode.pt"
landmarks_path = "./pretrained/shape_predictor_68_face_landmarks.dat"
e4e_embedder = psp_wrapper.psp_encoder(psp_encoder_path, landmarks_path)

G_ffhq_path = "./pretrained/ffhq.pkl"
G_metfaces_path = "./pretrained/metfaces.pkl"

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): .pkl checkpoints are presumably unpickled by
# legacy.load_network_pkl -- only ever point these paths at trusted files.
with dnnlib.util.open_url(G_ffhq_path) as f:
    G_ffhq = legacy.load_network_pkl(f)['G_ema'].to(device)
with dnnlib.util.open_url(G_metfaces_path) as f:
    G_metfaces = legacy.load_network_pkl(f)['G_ema'].to(device)

# Generators selectable in the UI, keyed by the "StyleGAN Type" radio value.
G_dict = {"FFHQ": G_ffhq, "MetFaces": G_metfaces}

DESCRIPTION = '''# StyleMC: Multi-Channel Based Fast Text-Guided Image Generation and Manipulation
'''
FOOTER = 'This space is built by Catlab Team.'

# Process-wide registry of found directions, shared by every user session.
# direction_map: name -> {"direction": <find_direction result>, "stylegan_type": key of G_dict}
# direction_list: the same names, in discovery order (used as radio choices).
direction_map = {}
direction_list = []


def add_direction(prompt, stylegan_type, id_loss_w):
    """Find a manipulation direction for ``prompt`` and register it.

    Args:
        prompt: Text prompt from the textbox. ``None`` or blank text is a no-op.
        stylegan_type: Key of ``G_dict`` ("FFHQ" or "MetFaces").
        id_loss_w: Identity loss weight from the slider. It is only used in the
            direction's display name and log message (see TODO below).

    Returns:
        A ``gr.Radio`` update listing every known direction, with no selection.
    """
    # Check before building the name: concatenating None would raise TypeError.
    if prompt is not None and prompt.strip():
        new_dir_name = prompt + " " + stylegan_type + " w_id_loss" + str(id_loss_w)
        # direction_map holds the same names as direction_list; dict lookup is O(1).
        if new_dir_name not in direction_map:
            print("adding direction with id:", new_dir_name)
            # TODO(review): id_loss_w is not forwarded to find_direction, so the
            # weight does not influence this call -- confirm the callee's signature.
            direction = find_direction.find_direction(G_dict[stylegan_type], prompt)
            print(f"new direction calculated with {stylegan_type} and id loss weight = {id_loss_w}")
            direction_list.append(new_dir_name)
            direction_map[new_dir_name] = {"direction": direction, "stylegan_type": stylegan_type}
    return gr.Radio.update(choices=direction_list, value=None, visible=True)


def generate_output_image(image_path, direction_id, change_power):
    """Embed the uploaded image and move it along the selected direction.

    Args:
        image_path: File path of the uploaded image (``gr.Image(type="filepath")``).
        direction_id: Name of a direction registered by ``add_direction``.
        change_power: Manipulation strength from the slider.

    Returns:
        The result of ``generator.generate_from_style``, shown in a
        ``gr.Image(type="pil")`` component.

    Raises:
        ValueError: If no image was uploaded or no known direction is selected.
    """
    if image_path is None:
        raise ValueError("Please upload an input image first.")
    if direction_id not in direction_map:
        raise ValueError("Please find and select a direction first.")
    entry = direction_map[direction_id]
    # Use the generator the direction was computed with (previously always
    # FFHQ, which mismatched MetFaces directions).
    # NOTE(review): the encoder checkpoint is the FFHQ e4e model, so for
    # MetFaces the FFHQ-derived w is fed to the MetFaces generator.
    G = G_dict[entry["stylegan_type"]]
    w = e4e_embedder.get_w(image_path)
    s = generator.w_to_s(GIn=G, wsIn=w)
    output_image = generator.generate_from_style(
        GIn=G,
        styles=s,
        styles_direction=entry["direction"],
        change_power=change_power,
        outdir='.',
    )
    return output_image


with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Box():
        # Real newlines: a backslash-newline in a Python string literal is
        # removed, which merged the bullet into the heading.
        gr.Markdown(
            "### Step 1) Finding a global manipulation direction\n"
            "- Please enter the target **text prompt** and **identity loss weight** "
            "to find global manipulation direction."
        )
        with gr.Row():
            with gr.Column():
                style_gan_type = gr.Radio(
                    ["FFHQ", "MetFaces"], value="FFHQ", label="StyleGAN Type", interactive=True
                )
            with gr.Column():
                identity_loss_weight = gr.Slider(
                    0.1, 10, value=0.5, step=0.1, label="Identity Loss Weight", interactive=True
                )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Textbox(
                        label="Enter your text prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your text prompt",
                    ).style(container=False)
                    find_direction_btn = gr.Button("Find Direction").style(full_width=False)

    with gr.Box():
        gr.Markdown(
            "### Step 2) Text-guided manipulation\n"
            "- Please upload an image:\n"
            "- You can select any of the previously found **directions** and set the "
            "**manipulation strength** to manipulate the image."
        )
        with gr.Row():
            direction_radio = gr.Radio(direction_list, label="List of Directions")
        with gr.Row():
            manipulation_strength = gr.Slider(
                0.1, 25, value=10, step=0.1, label="Manipulation Strength", interactive=True
            )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    input_image = gr.Image(label="Input Image", type="filepath")
                with gr.Row():
                    generate_btn = gr.Button("Generate")
            with gr.Column():
                with gr.Row():
                    generated_image = gr.Image(label="Generated Image", type="pil", interactive=False)

    find_direction_btn.click(
        add_direction,
        inputs=[text, style_gan_type, identity_loss_weight],
        outputs=direction_radio,
    )
    generate_btn.click(
        generate_output_image,
        inputs=[input_image, direction_radio, manipulation_strength],
        outputs=generated_image,
    )

demo.launch(debug=True)