radames's picture
add inversion
d9778ff
from argparse import Namespace
import time
import torch
import torchvision.transforms as transforms
import dlib
import numpy as np
from PIL import Image
from pixel2style2pixel.utils.common import tensor2im
from pixel2style2pixel.models.psp import pSp
from pixel2style2pixel.scripts.align_all_parallel import align_face
class InversionModel:
    """Wrap a pSp network to invert a face photo into latents.

    Construction loads a dlib shape predictor and a pSp checkpoint and moves
    the network to CUDA. ``inference`` aligns the face in an image file, runs
    it through the network, and returns the reconstruction with its latents.
    """

    def __init__(self, checkpoint_path: str, dlib_path: str) -> None:
        """Load the landmark predictor and build the pSp network.

        Args:
            checkpoint_path: Checkpoint file read with ``torch.load``; it must
                hold an ``"opts"`` entry that supports item assignment and
                ``**`` unpacking.
            dlib_path: Model file handed to ``dlib.shape_predictor``.
        """
        self.dlib_path = dlib_path
        self.dlib_predictor = dlib.shape_predictor(dlib_path)

        # The attribute keeps its original spelling: it is readable from
        # outside the class, so renaming it could break callers.
        preprocessing_steps = [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5] * 3, [0.5] * 3),
        ]
        self.tranform_image = transforms.Compose(preprocessing_steps)

        # Deserialize onto the CPU first; the network is moved to the GPU
        # explicitly once it has been built.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        stored_opts = checkpoint["opts"]
        overrides = {
            "checkpoint_path": checkpoint_path,
            "learn_in_w": False,
            "output_size": 1024,
        }
        for option_name, option_value in overrides.items():
            stored_opts[option_name] = option_value
        self.opts = Namespace(**stored_opts)

        self.net = pSp(self.opts)
        self.net.eval()
        self.net.cuda()
        print("Model successfully loaded!")

    def run_alignment(self, image_path: str):
        """Align the face found in ``image_path`` and return the result.

        Delegates to ``align_face`` with this instance's dlib predictor and
        prints the ``size`` of what comes back.
        """
        aligned = align_face(filepath=image_path, predictor=self.dlib_predictor)
        print(f"Aligned image has shape: {aligned.size}")
        return aligned

    def inference(self, image_path: str):
        """Invert the face in ``image_path``.

        Returns:
            A ``(image, latents)`` tuple. ``image`` is ``tensor2im`` applied
            to the first network output; ``latents`` is a dict whose ``"w1"``
            and ``"w1_initial"`` entries are separately converted NumPy
            copies of the same latent tensor.
        """
        face = self.run_alignment(image_path).resize((256, 256))
        network_input = self.tranform_image(face)

        with torch.no_grad():
            started = time.time()
            # Add the batch dimension and move to the GPU inside the timed
            # region, matching what the printed duration has always covered.
            batch = network_input.unsqueeze(0).to("cuda").float()
            reconstruction, latents = self.net(
                batch,
                return_latents=True,
                randomize_noise=False,
            )
            finished = time.time()
            print(f"Inference took {finished - started:.4f} seconds.")

        output_image = tensor2im(reconstruction[0])
        # One device-to-host conversion per key so the two arrays stay
        # independent of each other.
        latent_arrays = {
            key: latents.cpu().detach().numpy() for key in ("w1", "w1_initial")
        }
        return output_image, latent_arrays