# Page header captured with this file: "Spaces: Runtime error" (not part of the program).
import os

import gradio as gr
import imageio
import ipywidgets as widgets
import nltk
import numpy as np
import PIL
import torch
from IPython.display import display
from IPython.utils import io
from ipywidgets import fixed
from PIL import Image
from skimage import img_as_ubyte

from config import Config
from decomposition import get_or_compute
from models import get_instrumented_model
# Fetch the NLTK 'wordnet' corpus at start-up.
# NOTE(review): presumably required by one of the imported project modules - confirm.
nltk.download('wordnet')

# @title Load Model
selected_model = 'lookbook'

# This script only runs inference: turn gradient tracking off globally and
# let cuDNN benchmark its kernels.
torch.autograd.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True

# Settings handed to the project's Config; the same object is later passed to
# get_or_compute() to locate (or build) the decomposition file.
config = Config(
    model='StyleGAN2',
    layer='style',
    output_class=selected_model,
    components=80,
    use_w=True,
    batch_size=5_000,  # style layer quite small
)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Instrumented wrapper around the generator; `model` is the object used for
# sampling everywhere below.
inst = get_instrumented_model(
    config.model,
    config.output_class,
    config.layer,
    torch.device(device),
    use_w=config.use_w,
)
path_to_components = get_or_compute(config, inst)
model = inst.model
# Read the decomposition archive and collect one direction vector and one
# standard deviation per component into plain Python lists.
comps = np.load(path_to_components)
lst = comps.files

# When True the activation-space arrays ('act_comp' / 'act_stdev') are used,
# otherwise the latent-space ones ('lat_comp' / 'lat_stdev').
load_activations = False
_prefix = 'act' if load_activations else 'lat'
_comp_key = _prefix + '_comp'
_stdev_key = _prefix + '_stdev'

# Index the archive once per key: every comps[...] lookup loads the array
# again, so indexing inside a per-row loop repeats that work for each row.
# A key missing from the archive yields an empty list.
latent_dirs = list(comps[_comp_key]) if _comp_key in lst else []
latent_stdevs = list(comps[_stdev_key]) if _stdev_key in lst else []
def mix_w(w1, w2, content, style):
    """Blend two 16-entry per-layer latent lists, writing the result into ``w2``.

    Layers 0-4 are interpolated with weight ``content`` and layers 5-15 with
    weight ``style``. A weight of 0 keeps the ``w1`` entry and a weight of 1
    keeps the ``w2`` entry. ``w2`` is modified in place and also returned.
    """
    for layer in range(16):
        # The first five layers follow `content`; the remaining ones `style`.
        weight = content if layer < 5 else style
        w2[layer] = w1[layer] * (1 - weight) + w2[layer] * weight
    return w2
def display_sample_pytorch(seed, truncation, directions, distances, scale, start, end, w=None, disp=True, save=None, noise_spec=None):
    """Render one image after shifting selected latent layers along directions.

    Args:
        seed: Seed passed to ``model.sample_latent`` when ``w`` is None; also
            used in the file name when ``save`` is given.
        truncation: Value assigned to ``model.truncation`` before rendering.
        directions: Sequence of direction vectors added to the edited layers.
        distances: Per-direction multipliers, indexed like ``directions``.
        scale: Additional multiplier applied to every direction.
        start: First layer index to edit (inclusive).
        end: Layer index at which editing stops (exclusive). Values larger
            than the number of layers are clamped instead of raising.
        w: Optional sequence of per-layer latents. Each entry gets a leading
            axis via ``np.expand_dims``. When None, one latent is sampled and
            reused for every layer.
        disp: When truthy, show the image through IPython's ``display``.
        save: Optional integer. When given, the image is written to
            ``out/{seed}_{save:05}.png``.
        noise_spec: Unused; kept so existing callers keep working.

    Returns:
        The rendered image as a 500x500 PIL image.
    """
    model.truncation = truncation
    if w is None:
        latent = model.sample_latent(1, seed=seed).detach().cpu().numpy()
        w = [latent] * model.get_max_latents()  # one per layer
    else:
        w = [np.expand_dims(x, 0) for x in w]

    # The offsets are the same for every layer, so build them once up front.
    offsets = [
        direction * distances[i] * scale
        for i, direction in enumerate(directions)
    ]
    # Clamp the upper bound so an over-large `end` does not index past the
    # last layer.
    for layer in range(start, min(end, len(w))):
        for offset in offsets:
            w[layer] = w[layer] + offset

    torch.cuda.empty_cache()

    # Render, convert to 8-bit and resize for display.
    out = model.sample_np(w)
    final_im = Image.fromarray(
        (out * 255).astype(np.uint8)).resize((500, 500), Image.LANCZOS)

    if save is not None:
        if not disp:
            print(save)
        # Make sure the target directory exists before writing into it.
        os.makedirs('out', exist_ok=True)
        final_im.save(f'out/{seed}_{save:05}.png')
    if disp:
        display(final_im)
    return final_im
# @title Demo UI
def generate_image(seed1, seed2, content, style, truncation, c0, c1, c2, c3, c4, c5, c6, start_layer, end_layer):
    """Build the two parent previews and the mixed, edited image for the UI.

    Args:
        seed1: Seed of the first parent (converted with ``int``).
        seed2: Seed of the second parent (converted with ``int``).
        content: Blend weight handed to ``mix_w`` for layers 0-4
            (0 keeps seed1, 1 keeps seed2).
        style: Blend weight handed to ``mix_w`` for layers 5-15.
        truncation: Value assigned to ``model.truncation``.
        c0, c1, c2, c3, c4, c5, c6: Distances along the first seven entries of
            ``latent_dirs``; ``c<i>`` pairs with ``latent_dirs[i]``.
        start_layer: First layer to edit (converted with ``int``).
        end_layer: Layer at which editing stops (converted with ``int``).

    Returns:
        A tuple ``(input_im, edited_im)``: the two parent renders joined along
        axis 1 as a PIL image, and the image returned by
        ``display_sample_pytorch`` for the mixed latents.
    """
    seed1 = int(seed1)
    seed2 = int(seed2)
    scale = 1

    # Slider c<i> controls component i, so the two lists are index-aligned.
    distances = [c0, c1, c2, c3, c4, c5, c6]
    directions = [latent_dirs[i] for i in range(len(distances))]

    # Apply the requested truncation before rendering anything. Otherwise the
    # parent previews are rendered with whatever value the previous request
    # left on the model (it was only assigned inside display_sample_pytorch).
    # NOTE(review): assumes sample_np reads model.truncation - confirm.
    model.truncation = truncation

    w1 = model.sample_latent(1, seed=seed1).detach().cpu().numpy()
    w1 = [w1] * model.get_max_latents()  # one per layer
    im1 = model.sample_np(w1)

    w2 = model.sample_latent(1, seed=seed2).detach().cpu().numpy()
    w2 = [w2] * model.get_max_latents()  # one per layer
    im2 = model.sample_np(w2)

    # Join both parent renders into a single preview image.
    combined_im = np.concatenate([im1, im2], axis=1)
    input_im = Image.fromarray((combined_im * 255).astype(np.uint8))

    # mix_w writes the blend into w2 and returns it.
    mixed_w = mix_w(w1, w2, content, style)
    edited_im = display_sample_pytorch(
        seed1, truncation, directions, distances, scale,
        int(start_layer), int(end_layer), w=mixed_w, disp=False)
    return input_im, edited_im
# Input widgets for generate_image. They are collected into `inputs` below in
# the same order as that function's parameters.
truncation = gr.inputs.Slider(
    label="Truncation", minimum=0, maximum=1, default=0.5)
start_layer = gr.inputs.Number(label="Start Layer", default=3)
end_layer = gr.inputs.Number(label="End Layer", default=14)
seed1 = gr.inputs.Number(label="Seed 1", default=0)
seed2 = gr.inputs.Number(label="Seed 2", default=0)
content = gr.inputs.Slider(
    label="Structure", minimum=0, maximum=1, default=0.5)
style = gr.inputs.Slider(
    label="Style", minimum=0, maximum=1, default=0.5)

# Shared range for the component sliders.
# NOTE(review): slider_step is defined but never passed to a slider.
slider_max_val = 20
slider_min_val = -20
slider_step = 1

# One slider per edited component; the label at position i belongs to c<i>.
_component_labels = (
    "Sleeve & Size",
    "Dress - Jacket",
    "Female Coat",
    "Coat",
    "Graphics",
    "Dark",
    "Less Cleavage",
)
c0, c1, c2, c3, c4, c5, c6 = [
    gr.inputs.Slider(
        label=text,
        minimum=slider_min_val,
        maximum=slider_max_val,
        default=0,
    )
    for text in _component_labels
]

scale = 1

inputs = [
    seed1, seed2, content, style, truncation,
    c0, c1, c2, c3, c4, c5, c6,
    start_layer, end_layer,
]
description = "Change the seed number to generate different parent design."

# Two image outputs: the parent preview and the edited result. live=True
# re-runs generate_image whenever an input changes.
gr.Interface(
    fn=generate_image,
    inputs=inputs,
    outputs=["image", "image"],
    description=description,
    live=True,
    title="",
).launch()