# Page-capture residue (not part of the script): "Spaces: Runtime error".
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Minimal script for generating an image using a pre-trained StyleGAN generator."""
import os
import pickle

import numpy as np
import PIL.Image

# Modules imported by bare name; presumably they live alongside this script
# (they are not standard library) — confirm against the repository layout.
import dnnlib
import dnnlib.tflib as tflib
import config
def main(*,
         url='https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ',  # karras2019stylegan-ffhq-1024x1024.pkl
         seed=5,
         truncation_psi=0.7,
         randomize_noise=True,
         png_name='example.png'):
    """Generate one image with a pre-trained StyleGAN generator and save it as a PNG.

    Loads the pickled networks from ``url`` (through ``dnnlib.util.open_url``
    with ``config.cache_dir`` as the cache directory) and prints the layer
    summary of ``Gs``. It then draws a single latent vector from a seeded
    NumPy ``RandomState``, runs ``Gs`` on it and writes the first resulting
    image to ``config.result_dir/png_name``.

    All parameters are keyword-only. Their defaults reproduce the original
    hard-coded behaviour, so a bare ``main()`` call is unchanged.

    Args:
        url: Location of the network pickle to load.
        seed: Seed for the ``np.random.RandomState`` that draws the latent.
        truncation_psi: Forwarded unchanged to ``Gs.run``.
        randomize_noise: Forwarded unchanged to ``Gs.run``.
            NOTE(review): judging by its name, ``True`` presumably makes the
            output differ between runs even for a fixed ``seed``. Confirm
            whether ``False`` is wanted for reproducible images.
        png_name: File name of the PNG written inside ``config.result_dir``.
    """
    # Initialize TensorFlow.
    tflib.init_tf()

    # Load pre-trained networks.
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # This pickle is fetched over the network, so only point `url` at a
    # source you trust.
    with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
        _G, _D, Gs = pickle.load(f)
        # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
        # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
        # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.

    # Print network details.
    Gs.print_layers()

    # Pick a latent vector: a batch of one row whose width comes from
    # Gs.input_shape[1] (assumed to be the latent size — confirm).
    rnd = np.random.RandomState(seed)
    latents = rnd.randn(1, Gs.input_shape[1])

    # Generate the image. The output transform requests uint8 conversion and
    # an NCHW -> NHWC transpose, which is the layout PIL expects below.
    fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
    images = Gs.run(latents, None,
                    truncation_psi=truncation_psi,
                    randomize_noise=randomize_noise,
                    output_transform=fmt)

    # Save the first (and only) image of the batch.
    os.makedirs(config.result_dir, exist_ok=True)
    png_filename = os.path.join(config.result_dir, png_name)
    PIL.Image.fromarray(images[0], 'RGB').save(png_filename)
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()