# DualStyleGAN / dualstylegan.py
# Author: hysts (Hugging Face staff)
# Last commit: 811cb03 ("Update")
from __future__ import annotations
import argparse
import os
import pathlib
import shlex
import subprocess
import sys
from typing import Callable
import dlib
import huggingface_hub
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torchvision.transforms as T
# Resolve every path from this file's location rather than the current working
# directory, so the app behaves the same no matter where it is launched from.
app_dir = pathlib.Path(__file__).parent
submodule_dir = app_dir / "DualStyleGAN"

# On a Hugging Face Space without CUDA, apply the local ``patch`` file to the
# DualStyleGAN submodule before its modules are imported.
# NOTE(review): presumably the patch swaps out CUDA-only code paths — confirm
# against the ``patch`` file shipped with the app.
if os.getenv("SYSTEM") == "spaces" and not torch.cuda.is_available():
    with open(app_dir / "patch") as f:
        # The return code is intentionally not checked: a failed or
        # already-applied patch does not abort start-up.
        subprocess.run(shlex.split("patch -p1"), cwd=submodule_dir, stdin=f)

# Make the submodule's ``model`` package importable.
sys.path.insert(0, submodule_dir.as_posix())
from model.dualstylegan import DualStyleGAN
from model.encoder.align_all_parallel import align_face
from model.encoder.psp import pSp
class Model:
    """Wrap the DualStyleGAN encoder, per-style generators and exstyle codes.

    Every network and code table is fetched with
    ``huggingface_hub.hf_hub_download`` in ``__init__``. Constructing a
    ``Model`` therefore downloads (or reuses cached copies of) all checkpoints
    for all style types up front.
    """

    def __init__(self):
        # Prefer the first CUDA device; fall back to CPU when CUDA is unavailable.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.landmark_model = self._create_dlib_landmark_model()
        self.encoder = self._load_encoder()
        self.transform = self._create_transform()
        # Each name must match a ``models/<style_type>/`` folder in the
        # ``public-data/DualStyleGAN`` Hub repository (see the loaders below).
        self.style_types = [
            "cartoon",
            "caricature",
            "anime",
            "arcane",
            "comic",
            "pixar",
            "slamdunk",
        ]
        # Eagerly load one generator and one exstyle-code table per style type.
        self.generator_dict = {style_type: self._load_generator(style_type) for style_type in self.style_types}
        self.exstyle_dict = {style_type: self._load_exstylecode(style_type) for style_type in self.style_types}

    @staticmethod
    def _create_dlib_landmark_model():
        """Download dlib's 68-point landmark model and return a ``dlib.shape_predictor``."""
        path = huggingface_hub.hf_hub_download(
            "public-data/dlib_face_landmark_model", "shape_predictor_68_face_landmarks.dat"
        )
        return dlib.shape_predictor(path)

    def _load_encoder(self) -> nn.Module:
        """Download the pSp encoder checkpoint and build the encoder on ``self.device``.

        Returns:
            The pSp encoder, moved to ``self.device`` and switched to eval mode.
        """
        ckpt_path = huggingface_hub.hf_hub_download("public-data/DualStyleGAN", "models/encoder.pt")
        # Load on CPU first; the finished model is moved to the target device below.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        opts = ckpt["opts"]
        # Override the stored options so they match this machine and the downloaded
        # file. NOTE(review): presumably pSp loads its weights from
        # opts.checkpoint_path (the state dict is not loaded here) — confirm.
        opts["device"] = self.device.type
        opts["checkpoint_path"] = ckpt_path
        opts = argparse.Namespace(**opts)
        model = pSp(opts)
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _create_transform() -> Callable:
        """Build the encoder's input pipeline.

        The pipeline resizes the shorter side to 256 px and center-crops to
        256x256. It then converts to a tensor and normalizes each of the three
        channels with mean 0.5 / std 0.5, which maps [0, 1] to [-1, 1].
        """
        transform = T.Compose(
            [
                T.Resize(256),
                T.CenterCrop(256),
                T.ToTensor(),
                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        return transform

    def _load_generator(self, style_type: str) -> nn.Module:
        """Load the DualStyleGAN generator (``g_ema`` weights) for ``style_type``.

        Returns:
            The generator, moved to ``self.device`` and switched to eval mode.
        """
        # NOTE(review): the positional arguments presumably are
        # (size, style_dim, n_mlp, channel_multiplier) — confirm in model/dualstylegan.py.
        model = DualStyleGAN(1024, 512, 8, 2, res_index=6)
        ckpt_path = huggingface_hub.hf_hub_download("public-data/DualStyleGAN", f"models/{style_type}/generator.pt")
        ckpt = torch.load(ckpt_path, map_location="cpu")
        model.load_state_dict(ckpt["g_ema"])
        model.to(self.device)
        model.eval()
        return model

    @staticmethod
    def _load_exstylecode(style_type: str) -> dict[str, np.ndarray]:
        """Load the table mapping style names to exstyle codes for ``style_type``.

        The cartoon, caricature and anime styles use the refined code file.
        All other styles use the plain one.
        """
        if style_type in ["cartoon", "caricature", "anime"]:
            filename = "refined_exstyle_code.npy"
        else:
            filename = "exstyle_code.npy"
        path = huggingface_hub.hf_hub_download("public-data/DualStyleGAN", f"models/{style_type}/{filename}")
        # The file stores a pickled dict inside a 0-d object array, hence
        # allow_pickle=True and .item(). Unpickling can execute code, so this is
        # only acceptable for a trusted repository.
        exstyles = np.load(path, allow_pickle=True).item()
        return exstyles

    def detect_and_align_face(self, image: str) -> np.ndarray:
        """Detect and align the face in the image file at path ``image``.

        NOTE(review): the return annotation says ``np.ndarray``. Confirm what
        ``align_face`` actually returns, and how it reports "no face found".
        """
        image = align_face(filepath=image, predictor=self.landmark_model)
        return image

    @staticmethod
    def denormalize(tensor: torch.Tensor) -> torch.Tensor:
        """Map values from [-1, 1] to uint8 [0, 255], clamping out-of-range values.

        The final cast truncates toward zero rather than rounding.
        """
        return torch.clamp((tensor + 1) / 2 * 255, 0, 255).to(torch.uint8)

    def postprocess(self, tensor: torch.Tensor) -> np.ndarray:
        """Convert a single image tensor into a channels-last uint8 NumPy array.

        Assumes a 3-D channels-first (C, H, W) input, as implied by the transpose.
        """
        tensor = self.denormalize(tensor)
        return tensor.cpu().numpy().transpose(1, 2, 0)

    @torch.inference_mode()
    def reconstruct_face(self, image: np.ndarray) -> tuple[np.ndarray, torch.Tensor]:
        """Encode ``image`` with pSp and return its reconstruction and latent.

        Args:
            image: An image array accepted by ``PIL.Image.fromarray``.

        Returns:
            A pair ``(reconstruction, instyle)``:

            - ``reconstruction`` is the encoder's output image as an HWC uint8
              array.
            - ``instyle`` is the latent returned by the encoder, which is later
              passed to ``generate``.
        """
        # Force 3 channels: the transform normalizes exactly three channels, so
        # grayscale or RGBA inputs would otherwise fail. This is a no-op for RGB.
        image = PIL.Image.fromarray(image).convert("RGB")
        input_data = self.transform(image).unsqueeze(0).to(self.device)
        img_rec, instyle = self.encoder(
            input_data,
            randomize_noise=False,
            return_latents=True,
            z_plus_latent=True,
            return_z_plus_latent=True,
            resize=False,
        )
        img_rec = torch.clamp(img_rec.detach(), -1, 1)
        img_rec = self.postprocess(img_rec[0])
        return img_rec, instyle

    @torch.inference_mode()
    def generate(
        self,
        style_type: str,
        style_id: int,
        structure_weight: float,
        color_weight: float,
        structure_only: bool,
        instyle: torch.Tensor,
    ) -> np.ndarray:
        """Stylize the face encoded by ``instyle`` with a reference style.

        Args:
            style_type: Key into ``self.generator_dict`` / ``self.exstyle_dict``.
            style_id: Position (in dict insertion order) of the reference style
                in the style type's exstyle table. It must lie in
                ``[0, number of styles)``.
            structure_weight: Interpolation weight for the first 7 layers.
            color_weight: Interpolation weight for the remaining 11 layers.
            structure_only: If true, layers 7-17 of the reference code are
                replaced by ``instyle``'s, so only the first 7 layers come from
                the reference style.
            instyle: Latent returned by ``reconstruct_face``.

        Returns:
            The generated image as an HWC uint8 array.

        Raises:
            KeyError: ``style_type`` is unknown.
            IndexError: ``style_id`` is outside ``[0, number of styles)``.
        """
        generator = self.generator_dict[style_type]
        exstyles = self.exstyle_dict[style_type]
        style_names = list(exstyles.keys())
        style_id = int(style_id)
        # Reject negative ids explicitly: plain list indexing would silently
        # wrap around and pick a style from the end of the table.
        if not 0 <= style_id < len(style_names):
            raise IndexError(
                f"style_id {style_id} is out of range for style type {style_type!r} "
                f"(expected 0 <= style_id < {len(style_names)})"
            )
        stylename = style_names[style_id]
        latent = torch.tensor(exstyles[stylename]).to(self.device)
        if structure_only:
            latent[0, 7:18] = instyle[0, 7:18]
        # Run the first two dims flattened through ``generator.generator.style``,
        # then restore the original shape.
        exstyle = generator.generator.style(
            latent.reshape(latent.shape[0] * latent.shape[1], latent.shape[2])
        ).reshape(latent.shape)
        img_gen, _ = generator(
            [instyle],
            exstyle,
            z_plus_latent=True,
            truncation=0.7,
            truncation_latent=0,
            use_res=True,
            interp_weights=[structure_weight] * 7 + [color_weight] * 11,
        )
        img_gen = torch.clamp(img_gen.detach(), -1, 1)
        img_gen = self.postprocess(img_gen[0])
        return img_gen