import copy
import os
import sys
import time
from typing import List, Optional, Tuple

import numpy as np
import PIL.Image
import torch

# Make the vendored StyleGAN2-ADA code (dnnlib, legacy) importable.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
stylegan2_dir = os.path.join(CURRENT_DIR, "stylegan2")
sys.path.insert(0, stylegan2_dir)

import dnnlib
import legacy

from . import utils
|
|
def load_model(
    network_pkl: str = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqdog.pkl",
    device: torch.device = torch.device("cuda"),
    fp16: bool = True,
) -> torch.nn.Module:
    """
    Loads a pretrained StyleGAN2-ADA generator network from a pickle file.

    Args:
        network_pkl (str): The URL or local path of the network pickle file.
        device (torch.device): The device to use for computation.
        fp16 (bool): Whether to use half-precision floating point format for the network weights.

    Returns:
        The pretrained generator network, frozen and with a feature hook attached.
    """
    print(f'Loading networks from "{network_pkl}"...')
    with dnnlib.util.open_url(network_pkl) as f:
        chkpt = legacy.load_network_pkl(f, force_fp16=fp16)
    G = chkpt["G_ema"].to(device).eval()
    for param in G.parameters():
        param.requires_grad_(False)

    # Capture intermediate feature maps from the synthesis network
    # (the same hook that forward_G re-registers on every call).
    register_hook(G)

    return G
|
|
def register_hook(G):
    """
    Registers a forward hook on the synthesis block at index 6 that stores the
    block's output on `G.activations` after every forward pass.
    """
    # Remove any previously registered hook so repeated calls don't stack hooks.
    if getattr(G, "_feature_hook", None) is not None:
        G._feature_hook.remove()
    G.activations = None

    def hook(module, input, output):
        G.activations = output

    for i, (name, module) in enumerate(G.synthesis.named_children()):
        if i == 6:
            print("Registering hook for:", name)
            G._feature_hook = module.register_forward_hook(hook)
    return G
|
|
def generate_W(
    _G: torch.nn.Module,
    seed: int = 0,
    network_pkl: Optional[str] = None,
    truncation_psi: float = 1.0,
    truncation_cutoff: Optional[int] = None,
    device: torch.device = torch.device("cuda"),
) -> np.ndarray:
    """
    Generates a latent code in W+ space from a pretrained StyleGAN2-ADA generator network.

    Args:
        _G (torch.nn.Module): The generator network; underscore-prefixed so streamlit caching skips hashing it.
        seed (int): The random seed used to sample the latent code.
        network_pkl (Optional[str]): Unused here; kept in the signature so caching layers (e.g. streamlit) can include it in the cache key.
        truncation_psi (float): The truncation psi value passed to the mapping network.
        truncation_cutoff (Optional[int]): The number of layers the truncation trick is applied to. If None (or -1), all layers are used.
        device (torch.device): The device to use for computation.

    Returns:
        The W+ latent as a numpy array of shape [1, num_layers, 512].
    """
    G = _G
    torch.manual_seed(seed)
    z = torch.randn(1, G.z_dim).to(device)
    num_layers = G.synthesis.num_ws
    if truncation_cutoff == -1:
        truncation_cutoff = None
    elif truncation_cutoff is not None:
        truncation_cutoff = min(num_layers, truncation_cutoff)
    W = G.mapping(
        z,
        None,
        truncation_psi=truncation_psi,
        truncation_cutoff=truncation_cutoff,
    )
    return W.cpu().numpy()
|
|
def forward_G(
    G: torch.nn.Module,
    W: torch.Tensor,
    device: torch.device,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Forward pass through the generator network.

    Args:
        G (torch.nn.Module): The generator network.
        W (torch.Tensor): The W+ latent code tensor of shape [batch_size, num_layers, 512].
        device (torch.device): The device to use for computation.

    Returns:
        A tuple containing the generated image tensor of shape [batch_size, 3, height, width]
        and the feature maps tensor of shape [batch_size, num_channels, height, width].
    """
    register_hook(G)

    if not isinstance(W, torch.Tensor):
        W = torch.from_numpy(W).to(device)

    img = G.synthesis(W, noise_mode="const", force_fp32=True)

    # The hooked synthesis block returns a (features, rgb) tuple; keep the feature maps.
    return img, G.activations[0]
|
|
def generate_image(
    W,
    _G: Optional[torch.nn.Module] = None,
    network_pkl: Optional[str] = None,
    class_idx=None,
    device=torch.device("cuda"),
) -> Tuple[PIL.Image.Image, torch.Tensor]:
    """
    Generates an image using a pretrained generator network.

    Args:
        W (torch.Tensor): A tensor of W+ latent codes of shape [batch_size, num_layers, 512].
        _G (Optional[torch.nn.Module]): The generator network. If None, the network will be loaded from `network_pkl`.
        network_pkl (Optional[str]): The path to the network pickle file. If None, the default network will be used.
        class_idx (Optional[int]): The class index to use for conditional generation. If None, generation is unconditional.
        device (torch.device): The device to use for computation.

    Returns:
        A tuple containing the generated image as a PIL Image and the feature maps tensor of shape [batch_size, num_channels, height, width].
    """
    if _G is None:
        assert network_pkl is not None
        _G = load_model(network_pkl, device)
    G = _G

    # Validate the class label. Synthesis below runs from W directly, so the
    # label only matters if W was produced by a conditional mapping network.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            raise ValueError(
                "Must specify class_idx when using a conditional network"
            )
        label[:, class_idx] = 1
    else:
        if class_idx is not None:
            print("warn: class_idx is ignored when running on an unconditional network")

    img, features = forward_G(G, W, device)

    img = utils.tensor_to_PIL(img)

    return img, features
|
|
def drag_gan(
    W,
    G,
    handle_points,
    target_points,
    mask,
    max_iters=1000,
    r1=3,
    r2=12,
    lam=20,
    d=2,
    lr=2e-3,
):
    """
    Iteratively optimizes the W+ latent so that each handle point is dragged
    toward its target point (the DragGAN optimization loop). Yields a tuple
    (image, latent, handle_points) after every optimization step.

    Note: `mask` and `lam` are accepted for the masked-region loss described
    in the DragGAN paper, but that term is not implemented here.
    """
    handle_points0 = copy.deepcopy(handle_points)
    handle_points = torch.stack(handle_points)
    handle_points0 = torch.stack(handle_points0)
    target_points = torch.stack(target_points)

    device = torch.device("cuda")

    img, F0 = forward_G(G, W, device)

    target_resolution = img.shape[-1]
    F0_resized = torch.nn.functional.interpolate(
        F0,
        size=(target_resolution, target_resolution),
        mode="bilinear",
        align_corners=True,
    ).detach()

    W = torch.from_numpy(W).to(device).float()
    W.requires_grad_(False)

    # Only the first 6 layers of the W+ latent are optimized; the rest stay fixed.
    W_layers_to_optimize = W[:, :6].clone()
    W_layers_to_optimize.requires_grad_(True)

    optimizer = torch.optim.Adam([W_layers_to_optimize], lr=lr)

    for _ in range(max_iters):
        start = time.perf_counter()
        # Stop once every handle point is within d pixels of its target.
        if torch.allclose(handle_points, target_points, atol=d):
            break

        optimizer.zero_grad()
        W_combined = torch.cat([W_layers_to_optimize, W[:, 6:].detach()], dim=1)

        img, F = forward_G(G, W_combined, device)
        F_resized = torch.nn.functional.interpolate(
            F,
            size=(target_resolution, target_resolution),
            mode="bilinear",
            align_corners=True,
        )

        loss = motion_supervision(handle_points, target_points, F_resized, r1, device)

        loss.backward()
        optimizer.step()

        print(
            f"Loss: {loss.item():0.2f}\tTime: {(time.perf_counter() - start) * 1000:.0f}ms"
        )

        with torch.no_grad():
            # Rebuild the combined latent: optimizer.step() updated
            # W_layers_to_optimize in place, so tracking must run on the
            # post-step image and features.
            W_combined = torch.cat([W_layers_to_optimize, W[:, 6:].detach()], dim=1)
            img, F = forward_G(G, W_combined, device)
            F_resized = torch.nn.functional.interpolate(
                F,
                size=(target_resolution, target_resolution),
                mode="bilinear",
                align_corners=True,
            )
            handle_points = point_tracking(
                F_resized, F0_resized, handle_points, handle_points0, r2, device
            )

        W_out = torch.cat([W_layers_to_optimize, W[:, 6:]], dim=1).detach().cpu().numpy()

        img = utils.tensor_to_PIL(img)
        yield img, W_out, handle_points
|
|
def motion_supervision(handle_points, target_points, F, r1, device):
    """
    Motion supervision loss: pushes the features in a radius-r1 patch around
    each handle point one step along the direction toward its target point.
    """
    loss = 0
    n = len(handle_points)
    for i in range(n):
        # Unit step from handle toward target; step directly onto the target
        # if it is less than one unit away.
        target2handle = target_points[i] - handle_points[i]
        d_i = target2handle / (torch.norm(target2handle) + 1e-7)
        if torch.norm(d_i) > torch.norm(target2handle):
            d_i = target2handle

        mask = utils.create_circular_mask(
            F.shape[2], F.shape[3], center=handle_points[i].tolist(), radius=r1
        ).to(device)

        coordinates = torch.nonzero(mask).float()
        shifted_coordinates = coordinates + d_i[None]

        h, w = F.shape[2], F.shape[3]

        # Features at the current patch locations (treated as constant below).
        F_qi = F[:, :, mask]

        # Normalize the shifted (row, col) coordinates to [-1, 1] for grid_sample.
        normalized_shifted_coordinates = shifted_coordinates.clone()
        normalized_shifted_coordinates[:, 0] = (
            2.0 * shifted_coordinates[:, 0] / (h - 1)
        ) - 1
        normalized_shifted_coordinates[:, 1] = (
            2.0 * shifted_coordinates[:, 1] / (w - 1)
        ) - 1

        # grid_sample expects a [N, H_out, W_out, 2] grid in (x, y) order.
        normalized_shifted_coordinates = normalized_shifted_coordinates.unsqueeze(0).unsqueeze(0)
        normalized_shifted_coordinates = normalized_shifted_coordinates.flip(-1)
        normalized_shifted_coordinates = normalized_shifted_coordinates.clamp(-1, 1)

        # Features sampled one step along d_i (these carry the gradient).
        F_qi_plus_di = torch.nn.functional.grid_sample(
            F, normalized_shifted_coordinates, mode="bilinear", align_corners=True
        )
        F_qi_plus_di = F_qi_plus_di.squeeze(2)

        loss += torch.nn.functional.l1_loss(F_qi.detach(), F_qi_plus_di)
    return loss
|
|
def point_tracking(
    F: torch.Tensor,
    F0: torch.Tensor,
    handle_points: torch.Tensor,
    handle_points0: torch.Tensor,
    r2: int = 3,
    device: torch.device = torch.device("cuda"),
) -> torch.Tensor:
    """
    Point tracking: relocates each handle point to the position within a
    radius-r2 square patch whose current feature is nearest (in L1 distance)
    to the feature at the original handle position.
    """
    n = handle_points.shape[0]
    new_handle_points = torch.zeros_like(handle_points)

    for i in range(n):
        patch = utils.create_square_mask(
            F.shape[2], F.shape[3], center=handle_points[i].tolist(), radius=r2
        ).to(device)

        patch_coordinates = torch.nonzero(patch)  # shape [N, 2], (row, col)

        # Current features inside the patch, and the original handle feature.
        F_qi = F[:, :, patch_coordinates[:, 0], patch_coordinates[:, 1]]
        f_i = F0[:, :, handle_points0[i][0].long(), handle_points0[i][1].long()]

        # The nearest neighbour in feature space becomes the new handle position.
        distances = torch.norm(F_qi - f_i[:, :, None], p=1, dim=1)
        min_index = torch.argmin(distances)
        new_handle_points[i] = patch_coordinates[min_index]

    return new_handle_points
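

# ---------------------------------------------------------------------------
# Minimal usage sketch, illustrative only and not part of the module's API.
# Assumptions: a CUDA device is available, the default afhqdog checkpoint is
# reachable, and the handle/target points below are hypothetical (row, col)
# pixel coordinates chosen for demonstration.
if __name__ == "__main__":
    G = load_model()                                # default afhqdog generator
    W = generate_W(G, seed=42, truncation_psi=0.8)  # W+ latent, shape [1, num_ws, 512]

    img, _ = generate_image(W, G)
    img.save("before.png")

    # Drag a single point 50 pixels to the right.
    handle = [torch.tensor([250.0, 250.0], device="cuda")]
    target = [torch.tensor([250.0, 300.0], device="cuda")]

    final_img = None
    for final_img, W, handle in drag_gan(
        W, G, handle, target, mask=None, max_iters=50
    ):
        pass  # each iteration yields the current image, latent, and handle points
    if final_img is not None:
        final_img.save("after.png")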
|
|