PTI / scripts /pti_styleclip.py
ucalyptus's picture
simp
2d7efb8
import glob
import os
from argparse import Namespace

from configs import paths_config
from models.StyleCLIP.mapper.scripts.inference import run
from scripts.run_pti import run_PTI
# StyleCLIP edit catalogue. Each key is the edit name that callers pass in
# ``edit_types`` (matched literally, so the existing spellings such as
# 'Jhonny Depp' and 'Hilary Clinton' are part of the lookup contract).
# Each value is
#     [checkpoint stem, no_coarse_mapper, no_medium_mapper, no_fine_mapper]
# where the checkpoint stem names ``<pretrained_mappers>/<stem>.pt`` and the
# three booleans are forwarded as the matching LevelsMapper options.
# In every entry only the fine-mapper flag differs, so the table below lists
# (edit name, checkpoint stem, no_fine_mapper) and the rest is filled in.
_STYLECLIP_EDIT_TABLE = (
    ('afro', 'afro', True),
    ('angry', 'angry', True),
    ('Beyonce', 'beyonce', False),
    ('bobcut', 'bobcut', True),
    ('bowlcut', 'bowlcut', True),
    ('curly hair', 'curly_hair', True),
    ('Hilary Clinton', 'hilary_clinton', False),
    ('Jhonny Depp', 'depp', False),
    ('mohawk', 'mohawk', True),
    ('purple hair', 'purple_hair', False),
    ('surprised', 'surprised', True),
    ('Taylor Swift', 'taylor_swift', False),
    ('trump', 'trump', False),
    ('Mark Zuckerberg', 'zuckerberg', False),
)

meta_data = {
    edit_label: [checkpoint_stem, False, False, skip_fine_mapper]
    for edit_label, checkpoint_stem, skip_fine_mapper in _STYLECLIP_EDIT_TABLE
}
def styleclip_edit(use_multi_id_G, run_id, use_wandb, edit_types):
    """Apply pretrained StyleCLIP mappers to every PTI-inverted input image.

    For each ``*.jpeg`` file in ``paths_config.input_data_path`` the latent at
    ``<embedding_base_dir>/<input_data_id>/<pti_results_keyword>/<image>/0.pt``
    is passed, together with each requested ``LevelsMapper`` checkpoint from
    ``paths_config.style_clip_pretrained_mappers``, to ``run`` from the
    StyleCLIP inference script.

    Args:
        use_multi_id_G: Forwarded unchanged to ``run``.
        run_id: Identifier of an existing PTI run. If empty (``''``) or
            ``None``, ``run_PTI`` is called first and its returned id is used.
        use_wandb: Forwarded to ``run_PTI`` when PTI is run here.
        edit_types: Iterable of edit names. Only names that are keys of
            ``meta_data`` are applied, in ``meta_data`` order; any other
            names are reported on stdout and skipped.
    """
    images_dir = paths_config.input_data_path
    pretrained_mappers = paths_config.style_clip_pretrained_mappers
    data_dir_name = paths_config.input_data_id

    if not run_id:
        run_id = run_PTI(run_name='', use_wandb=use_wandb, use_multi_id_training=False)

    # Resolve the requested edits once, in meta_data's order. Iterating a set
    # of strings would yield a different order on each interpreter run
    # (string hashes are randomized per process).
    requested = set(edit_types)
    selected_edits = [name for name in meta_data if name in requested]
    unknown = requested.difference(meta_data)
    if unknown:
        # Misspelled names would otherwise be dropped without any notice.
        print(f'Skipping unknown StyleCLIP edit types: {sorted(unknown, key=str)}')

    w_path_dir = f'{paths_config.embedding_base_dir}/{data_dir_name}'
    # glob returns matches in arbitrary order; sort for reproducible runs.
    for image_path in sorted(glob.glob(f"{images_dir}/*.jpeg")):
        # Keep the original "text before the first dot" rule, but apply it to
        # the file name only: splitting the whole path on '.' produced an
        # empty or wrong name whenever a directory contained a dot (e.g. './data').
        image_name = os.path.basename(image_path).split(".")[0]
        embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
        latent_path = f'{embedding_dir}/0.pt'

        for edit_type in selected_edits:
            edit_id, no_coarse, no_medium, no_fine = meta_data[edit_type]
            args = {
                "exp_dir": str(paths_config.styleclip_output_dir),
                "checkpoint_path": f"{pretrained_mappers}/{edit_id}.pt",
                "couple_outputs": False,
                "mapper_type": "LevelsMapper",
                "no_coarse_mapper": no_coarse,
                "no_medium_mapper": no_medium,
                "no_fine_mapper": no_fine,
                "stylegan_size": 1024,
                "test_batch_size": 1,
                "latents_test_path": latent_path,
                "test_workers": 1,
                "run_id": run_id,
                "image_name": image_name,
                'edit_name': edit_type,
                "data_dir_name": data_dir_name,
            }
            run(Namespace(**args), run_id, image_name, use_multi_id_G)