domain-expansion / generate_aligned.py
alvan
Added gradio space for domain expansion
560a1b9
# Copyright 2023 Adobe Research. All rights reserved.
# To view a copy of the license, visit LICENSE.md.
import sys
sys.path.append('..')
import argparse
from pathlib import Path
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import dnnlib
import legacy
from expansion_utils import io_utils, latent_operations
def _to_unit_range(tensor):
    """Linearly rescale *tensor* so its minimum maps to 0 and its maximum to 1.

    The rescaling uses a single global min/max over the whole tensor (not per
    channel). A constant tensor has zero range; instead of dividing by zero
    and producing NaNs, it is mapped to all zeros.
    """
    lo = tensor.min()
    hi = tensor.max()
    span = hi - lo
    if span == 0:
        return torch.zeros_like(tensor)
    return (tensor - lo) / span


def generate_images(
    ckpt,
    num_samples,
    truncation_psi,
    *,
    device='cuda',
    edit_stride=4,
):
    """Sample random latents and render base and subspace-edited images.

    Args:
        ckpt: Network pickle location, opened with ``dnnlib.util.open_url``
            and loaded with ``legacy.load_network_pkl``. The loaded dict must
            contain ``'G_ema'``, ``'latent_basis'``, ``'subspace_distance'``
            and ``'repurposed_dims'``.
        num_samples: Number of independent random latents to draw.
        truncation_psi: Truncation coefficient forwarded to ``G.mapping``.
        device: Device the generator, basis and latents are moved to.
            Defaults to ``'cuda'`` (the previously hard-coded value).
        edit_stride: Render only every ``edit_stride``-th repurposed
            dimension. Defaults to 4 (the previously hard-coded value).
            Must be >= 1.

    Returns:
        A list with one entry per sample. Each entry is a list of PIL RGB
        images. The first element is the base image, followed by one edited
        image per rendered repurposed dimension. Every image is individually
        min/max-rescaled to [0, 1] before conversion.

    Raises:
        ValueError: If ``edit_stride`` is smaller than 1.
    """
    if edit_stride < 1:
        raise ValueError(f'edit_stride must be >= 1, got {edit_stride}')

    device = torch.device(device)
    with dnnlib.util.open_url(ckpt) as f:
        snapshot_dict = legacy.load_network_pkl(f)
    G = snapshot_dict['G_ema'].to(device)
    latent_basis = snapshot_dict['latent_basis'].to(device)
    subspace_distance = snapshot_dict['subspace_distance']
    repurposed_dims = snapshot_dict['repurposed_dims'].cpu()

    topil = T.ToPILImage(mode='RGB')

    all_imgs = []
    # Pure inference: skip autograd bookkeeping so callers that do not wrap
    # this call in torch.no_grad() do not keep activation graphs alive.
    with torch.no_grad():
        for _ in range(num_samples):
            # Sampled on the default (CPU) generator, then moved, as before.
            z = torch.randn((1, G.z_dim)).to(device)
            w = G.mapping(z, None, truncation_psi=truncation_psi)
            base_w, edit_ws = latent_operations.project_to_subspaces(
                w,
                latent_basis,
                repurposed_dims,
                step_size=subspace_distance,
                mean=G.mapping.w_avg,
            )
            edit_ws = edit_ws[0]  # Single step

            base_img = G.synthesis(base_w, noise_mode='const')
            per_sample_imgs = [topil(_to_unit_range(base_img.squeeze()))]

            # zip() caps the number of edits at len(repurposed_dims).
            for idx, (_, edit_w) in enumerate(zip(repurposed_dims, edit_ws)):
                if idx % edit_stride:
                    continue
                edit_img = G.synthesis(edit_w, noise_mode='const')
                per_sample_imgs.append(topil(_to_unit_range(edit_img.squeeze())))

            all_imgs.append(per_sample_imgs)
    return all_imgs
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--ckpt', help='Network pickle filename', required=True)
    parser.add_argument('--out_dir', help='Where to save the output images', type=str, required=True, metavar='DIR')
    # Default of 1 avoids range(None) when --num is omitted.
    parser.add_argument('--num', help='Number of independant samples', type=int, default=1)
    parser.add_argument('--truncation_psi', help='Coefficient for truncation', type=float, default=1)
    args = parser.parse_args()

    # generate_images() takes (ckpt, num_samples, truncation_psi) and returns
    # PIL images instead of writing files, so the CLI persists them itself.
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with torch.no_grad():
        all_imgs = generate_images(args.ckpt, args.num, args.truncation_psi)

    # File layout: <sample>_<k>.png, where k == 0 is the base image and
    # k >= 1 are the edited images in the order generate_images returned them.
    for sample_idx, sample_imgs in enumerate(all_imgs):
        for img_idx, img in enumerate(sample_imgs):
            img.save(out_dir / f'{sample_idx:05d}_{img_idx:03d}.png')