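"""Wrapper around the pixel2style2pixel (pSp) encoder: aligns a face photo
with dlib and inverts it into StyleGAN W+ latents alongside a reconstruction
of the input image.
"""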
from argparse import Namespace
import time

import dlib
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from pixel2style2pixel.models.psp import pSp
from pixel2style2pixel.scripts.align_all_parallel import align_face
from pixel2style2pixel.utils.common import tensor2im

class InversionModel:
    def __init__(self, checkpoint_path: str, dlib_path: str) -> None:
        # dlib landmark predictor used to align faces before encoding.
        self.dlib_path = dlib_path
        self.dlib_predictor = dlib.shape_predictor(dlib_path)
        # pSp expects 256x256 inputs normalized to [-1, 1].
        self.transform_image = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        # Rebuild the options the checkpoint was trained with, overriding the
        # fields needed to load it for inference.
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        opts = ckpt["opts"]
        opts["checkpoint_path"] = checkpoint_path
        opts["learn_in_w"] = False
        opts["output_size"] = 1024
        self.opts = Namespace(**opts)
        self.net = pSp(self.opts)
        self.net.eval()
        self.net.cuda()
        print("Model successfully loaded!")

    def run_alignment(self, image_path: str):
        # Crop and align the face using dlib landmarks; pSp expects
        # FFHQ-style aligned inputs.
        aligned_image = align_face(filepath=image_path, predictor=self.dlib_predictor)
        print("Aligned image has shape: {}".format(aligned_image.size))
        return aligned_image

    def inference(self, image_path: str):
        input_image = self.run_alignment(image_path)
        input_image = input_image.resize((256, 256))
        transformed_image = self.transform_image(input_image)

        with torch.no_grad():
            tic = time.time()
            # Encode the image into W+ latents and decode a reconstruction
            # in a single forward pass.
            result_image, latents = self.net(
                transformed_image.unsqueeze(0).to("cuda").float(),
                return_latents=True,
                randomize_noise=False,
            )
            toc = time.time()
            print("Inference took {:.4f} seconds.".format(toc - tic))

        res_image = tensor2im(result_image[0])
        # Return the inverted latents twice: "w1" as a working copy and
        # "w1_initial" as an untouched reference.
        w1 = latents.detach().cpu().numpy()
        return (
            res_image,
            {
                "w1": w1,
                "w1_initial": w1.copy(),
            },
        )
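

if __name__ == "__main__":
    # Minimal usage sketch. The file paths below are placeholder assumptions,
    # not files shipped with this module: supply your own pSp checkpoint,
    # dlib shape-predictor model, and test photo.
    model = InversionModel(
        checkpoint_path="pretrained_models/psp_ffhq_encode.pt",  # hypothetical path
        dlib_path="pretrained_models/shape_predictor_68_face_landmarks.dat",  # hypothetical path
    )
    reconstruction, latents = model.inference("test_face.jpg")  # hypothetical input image
    reconstruction.save("reconstruction.jpg")
    # W+ latent codes for a 1024px StyleGAN2 generator have shape (1, 18, 512).
    print("Latent code shape:", latents["w1"].shape)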