# JoJoGAN Gradio demo: environment setup and base-model loading.
# Everything here runs at import time. Shell commands go through os.system
# and their exit statuses are not checked, so a failed download only shows up
# later, when the matching torch.load cannot find its file.
import os
from PIL import Image
import torch
import gradio as gr

# dlib is installed before the project modules below are imported
# (presumably because util's face alignment needs it -- confirm).
os.system("pip install dlib")

import torch
torch.backends.cudnn.benchmark = True
from torchvision import transforms, utils
from util import *
from PIL import Image
import math
import random
import numpy as np
from torch import nn, autograd, optim
from torch.nn import functional as F
from tqdm import tqdm
import lpips
from model import *
from e4e_projection import projection as e4e_projection
from copy import deepcopy
import imageio

# Working directories, created if missing.
for _workdir in ('inversion_codes', 'style_images', 'style_images_aligned', 'models'):
    os.makedirs(_workdir, exist_ok=True)
del _workdir

# Fetch dlib's 68-point landmark predictor, decompress it (keeping the .bz2)
# and move it under models/ with the file name used there.
for _shell_cmd in (
    "wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2",
    "bzip2 -dk shape_predictor_68_face_landmarks.dat.bz2",
    "mv shape_predictor_68_face_landmarks.dat models/dlibshape_predictor_68_face_landmarks.dat",
):
    os.system(_shell_cmd)
del _shell_cmd

device = 'cpu'


def _keep_storage(storage, loc):
    """torch.load ``map_location`` callback that returns each storage unchanged.

    torch.load hands the callback the storage as first deserialized (on CPU),
    so returning it as-is keeps every tensor on the CPU.
    """
    return storage


# Base StyleGAN2 checkpoint (presumably the stylegan2-ffhq-config-f.pt loaded below).
os.system("gdown https://drive.google.com/uc?id=1_cTsjqzD_X9DK3t3IZE53huKgnzj_btZ")
latent_dim = 512

# Base generator. Arguments are presumably (size, style_dim, n_mlp,
# channel_multiplier) -- confirm against model.Generator.
original_generator = Generator(1024, latent_dim, 8, 2).to(device)
ckpt = torch.load('stylegan2-ffhq-config-f.pt', map_location=_keep_storage)
original_generator.load_state_dict(ckpt["g_ema"], strict=False)
mean_latent = original_generator.mean_latent(10000)

# Independent per-style copies of the base generator ("to be finetuned").
generatorjojo, generatordisney, generatorjinx = (
    deepcopy(original_generator) for _ in range(3)
)

# Resize to 1024x1024, convert to a tensor, then map each channel
# from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# e4e encoder checkpoint (presumably e4e_ffhq_encode.pt), copied into models/.
os.system("gdown https://drive.google.com/uc?id=1jtCg8HQ6RlTmLdnbT2PfW1FJ2AYkWqsK")
os.system("cp e4e_ffhq_encode.pt models/e4e_ffhq_encode.pt")

# JoJo style checkpoint (presumably jojo.pt).
os.system("gdown https://drive.google.com/uc?id=1-8E0PFT37v5fZs-61oIrFbNpE28Unp2y")
ckptjojo = torch.load('jojo.pt', map_location=_keep_storage)
# Load the per-style weights into the generator copies. strict=False means
# keys missing from, or unexpected in, a checkpoint are silently ignored.
generatorjojo.load_state_dict(ckptjojo["g"], strict=False)

os.system("gdown https://drive.google.com/uc?id=1Bnh02DjfvN_Wm8c4JdOiNV4q9J7Z_tsi")
ckptdisney = torch.load('disney_preserve_color.pt', map_location=lambda storage, loc: storage)
generatordisney.load_state_dict(ckptdisney["g"], strict=False)

os.system("gdown https://drive.google.com/uc?id=1jElwHxaYPod5Itdy18izJk49K1nl4ney")
ckptjinx = torch.load('arcane_jinx_preserve_color.pt', map_location=lambda storage, loc: storage)
generatorjinx.load_state_dict(ckptjinx["g"], strict=False)


def _to_uint8_hwc(image):
    """Convert a (C, H, W) float tensor in [-1, 1] to an (H, W, C) uint8 array.

    Values outside [-1, 1] are clipped first. -1 maps to 0 and +1 maps to 255.
    """
    return (
        image.detach()
        .clamp(-1, 1)
        .add(1)
        .mul(127.5)
        .round()
        .to(torch.uint8)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
    )


def inference(img, model):
    """Stylize the face in ``img`` with the generator selected by ``model``.

    Args:
        img: the input image as handed over by the Gradio input below
            (``type="filepath"``, so a file path), passed to ``align_face``.
        model: 'JoJo' or 'Disney'. Any other value, including 'Jinx',
            selects the Jinx generator.

    Returns:
        The path of the written JPEG, always 'filename.jpeg'.
    """
    aligned_face = align_face(img)
    # e4e inversion of the aligned face. "test.pt" is a fixed file name handed
    # to the projection helper (presumably where it saves the latent -- confirm).
    # unsqueeze(0) adds a leading batch dimension.
    my_w = e4e_projection(aligned_face, "test.pt", device).unsqueeze(0)

    generator = {'JoJo': generatorjojo, 'Disney': generatordisney}.get(model, generatorjinx)
    with torch.no_grad():
        my_sample = generator(my_w, input_is_latent=True)

    # The generator output is float and presumably in about [-1, 1] (the same
    # range the Normalize(0.5, 0.5) input transform produces -- confirm).
    # Passing floats straight to imageio triggers its lossy implicit uint8
    # conversion, which rescales by each image's own min/max. Convert
    # explicitly instead. my_sample[0] must be a 3-D (C, H, W) tensor for the
    # permute inside the helper.
    npimage = _to_uint8_hwc(my_sample[0])
    imageio.imwrite('filename.jpeg', npimage)
    return 'filename.jpeg'


title = "JoJoGAN"
description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
# A plain "..." literal cannot span physical lines, so the line breaks are
# spelled out explicitly.
# NOTE(review): this text reads like HTML whose tags/links were stripped
# ("| Github Repo Pytorch", "samples from repo:"), while the description
# promises "links below". Restore the links if they are intended.
article = (
    "\n"
    "JoJoGAN: One Shot Face Stylization| Github Repo Pytorch\n"
    "samples from repo:\n"
)
examples = [['iu.jpeg', 'Jinx']]

# NOTE(review): gr.inputs / gr.outputs and enable_queue / allow_flagging are
# the legacy (Gradio 2.x-era) Interface API. Newer Gradio releases removed
# them, so pin the Gradio version accordingly.
gr.Interface(
    inference,
    [
        gr.inputs.Image(type="filepath", shape=(256, 256)),
        gr.inputs.Dropdown(choices=['JoJo', 'Disney', 'Jinx'], type="value", default='JoJo', label="Model"),
    ],
    gr.outputs.Image(type="file"),
    title=title,
    description=description,
    article=article,
    enable_queue=True,
    allow_flagging=False,
    examples=examples,
).launch()