"""SeFa."""
import os
import argparse
from tqdm import tqdm
import numpy as np
import torch
from models import parse_gan_type
from utils import to_tensor
from utils import postprocess
from utils import load_generator
from utils import factorize_weight
from utils import HtmlPageVisualizer
def parse_args():
    """Parses arguments."""
    parser = argparse.ArgumentParser(
        description='Discover semantics from the pre-trained weight.')
    # Each entry is (flags, options); arguments are registered in declaration
    # order so the generated --help output keeps the original layout.
    arg_specs = [
        (('model_name',),
         dict(type=str,
              help='Name to the pre-trained model.')),
        (('--save_dir',),
         dict(type=str, default='results',
              help='Directory to save the visualization pages. '
                   '(default: %(default)s)')),
        (('-L', '--layer_idx'),
         dict(type=str, default='all',
              help='Indices of layers to interpret. '
                   '(default: %(default)s)')),
        (('-N', '--num_samples'),
         dict(type=int, default=5,
              help='Number of samples used for visualization. '
                   '(default: %(default)s)')),
        (('-K', '--num_semantics'),
         dict(type=int, default=5,
              help='Number of semantic boundaries corresponding to '
                   'the top-k eigen values. (default: %(default)s)')),
        (('--start_distance',),
         dict(type=float, default=-3.0,
              help='Start point for manipulation on each semantic. '
                   '(default: %(default)s)')),
        (('--end_distance',),
         dict(type=float, default=3.0,
              help='Ending point for manipulation on each semantic. '
                   '(default: %(default)s)')),
        (('--step',),
         dict(type=int, default=11,
              help='Manipulation step on each semantic. '
                   '(default: %(default)s)')),
        (('--viz_size',),
         dict(type=int, default=256,
              help='Size of images to visualize on the HTML page. '
                   '(default: %(default)s)')),
        (('--trunc_psi',),
         dict(type=float, default=0.7,
              help='Psi factor used for truncation. This is '
                   'particularly applicable to StyleGAN (v1/v2). '
                   '(default: %(default)s)')),
        (('--trunc_layers',),
         dict(type=int, default=8,
              help='Number of layers to perform truncation. This is '
                   'particularly applicable to StyleGAN (v1/v2). '
                   '(default: %(default)s)')),
        (('--seed',),
         dict(type=int, default=0,
              help='Seed for sampling. (default: %(default)s)')),
        (('--gpu_id',),
         dict(type=str, default='0',
              help='GPU(s) to use. (default: %(default)s)')),
    ]
    for flags, options in arg_specs:
        parser.add_argument(*flags, **options)
    return parser.parse_args()
def main():
    """Main function.

    Loads a pre-trained generator, factorizes its weight to discover the
    top-k semantic directions (SeFa), manipulates sampled latent codes along
    each direction, and saves two HTML visualization pages under
    `args.save_dir`: one with rows grouped by semantic, one grouped by sample.

    Raises:
        NotImplementedError: If the loaded generator's GAN type is not one of
            `pggan`, `stylegan`, `stylegan2`.
    """
    args = parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id
    os.makedirs(args.save_dir, exist_ok=True)

    # Factorize weights.
    generator = load_generator(args.model_name)
    gan_type = parse_gan_type(generator)
    # Fail fast on unsupported models; otherwise `image` would never be
    # assigned below and the script would die with an opaque
    # UnboundLocalError deep inside the visualization loop.
    if gan_type not in ('pggan', 'stylegan', 'stylegan2'):
        raise NotImplementedError(f'Unsupported GAN type: `{gan_type}`!')
    layers, boundaries, values = factorize_weight(generator, args.layer_idx)

    # Set random seed for reproducible sampling.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Prepare latent codes: PGGAN manipulates in (pixel-normalized) Z space,
    # StyleGAN v1/v2 in the (truncated) W space produced by the mapping net.
    codes = torch.randn(args.num_samples, generator.z_space_dim).cuda()
    if gan_type == 'pggan':
        codes = generator.layer0.pixel_norm(codes)
    elif gan_type in ['stylegan', 'stylegan2']:
        codes = generator.mapping(codes)['w']
        codes = generator.truncation(codes,
                                     trunc_psi=args.trunc_psi,
                                     trunc_layers=args.trunc_layers)
    codes = codes.detach().cpu().numpy()

    # Generate visualization pages.
    distances = np.linspace(args.start_distance, args.end_distance, args.step)
    num_sam = args.num_samples
    num_sem = args.num_semantics
    # Page 1: one group of rows per semantic (header row + one row per
    # sample). Page 2: one group of rows per sample (header row + one row per
    # semantic). Column 0 holds labels, columns 1..step hold images.
    vizer_1 = HtmlPageVisualizer(num_rows=num_sem * (num_sam + 1),
                                 num_cols=args.step + 1,
                                 viz_size=args.viz_size)
    vizer_2 = HtmlPageVisualizer(num_rows=num_sam * (num_sem + 1),
                                 num_cols=args.step + 1,
                                 viz_size=args.viz_size)
    headers = [''] + [f'Distance {d:.2f}' for d in distances]
    vizer_1.set_headers(headers)
    vizer_2.set_headers(headers)
    for sem_id in range(num_sem):
        value = values[sem_id]
        vizer_1.set_cell(sem_id * (num_sam + 1), 0,
                         text=f'Semantic {sem_id:03d}<br>({value:.3f})',
                         highlight=True)
        for sam_id in range(num_sam):
            vizer_1.set_cell(sem_id * (num_sam + 1) + sam_id + 1, 0,
                             text=f'Sample {sam_id:03d}')
    for sam_id in range(num_sam):
        vizer_2.set_cell(sam_id * (num_sem + 1), 0,
                         text=f'Sample {sam_id:03d}',
                         highlight=True)
        for sem_id in range(num_sem):
            value = values[sem_id]
            vizer_2.set_cell(sam_id * (num_sem + 1) + sem_id + 1, 0,
                             text=f'Semantic {sem_id:03d}<br>({value:.3f})')

    for sam_id in tqdm(range(num_sam), desc='Sample ', leave=False):
        code = codes[sam_id:sam_id + 1]
        for sem_id in tqdm(range(num_sem), desc='Semantic ', leave=False):
            boundary = boundaries[sem_id:sem_id + 1]
            for col_id, d in enumerate(distances, start=1):
                temp_code = code.copy()
                if gan_type == 'pggan':
                    # PGGAN: shift the whole latent code.
                    temp_code += boundary * d
                    image = generator(to_tensor(temp_code))['image']
                elif gan_type in ['stylegan', 'stylegan2']:
                    # StyleGAN: shift only the selected layers' W codes.
                    temp_code[:, layers, :] += boundary * d
                    image = generator.synthesis(to_tensor(temp_code))['image']
                image = postprocess(image)[0]
                vizer_1.set_cell(sem_id * (num_sam + 1) + sam_id + 1, col_id,
                                 image=image)
                vizer_2.set_cell(sam_id * (num_sem + 1) + sem_id + 1, col_id,
                                 image=image)

    prefix = (f'{args.model_name}_'
              f'N{num_sam}_K{num_sem}_L{args.layer_idx}_seed{args.seed}')
    vizer_1.save(os.path.join(args.save_dir, f'{prefix}_sample_first.html'))
    vizer_2.save(os.path.join(args.save_dir, f'{prefix}_semantic_first.html'))


if __name__ == '__main__':
    main()