# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# | |
# NVIDIA CORPORATION and its licensors retain all intellectual property | |
# and proprietary rights in and to this software, related documentation | |
# and any modifications thereto. Any use, reproduction, disclosure or | |
# distribution of this software and related documentation without an express | |
# license agreement from NVIDIA CORPORATION is strictly prohibited. | |
import sys | |
import copy | |
import traceback | |
import numpy as np | |
import torch | |
import torch.fft | |
import torch.nn | |
import matplotlib.cm | |
import dnnlib | |
import torch.nn.functional as F | |
from torch_utils import misc | |
from torch_utils.ops import upfirdn2d | |
from training.networks import Generator | |
import legacy # pylint: disable=import-error | |
#---------------------------------------------------------------------------- | |
class CapturedException(Exception):
    """Snapshot of the currently-active exception as a plain string.

    When constructed without a message inside an ``except`` block, it records
    either the wrapped CapturedException's message (to avoid nesting full
    tracebacks) or the full formatted traceback of the active exception.
    """
    def __init__(self, msg=None):
        if msg is None:
            _exc_type, exc_value, _tb = sys.exc_info()
            assert exc_value is not None  # Must be called while an exception is being handled.
            if isinstance(exc_value, CapturedException):
                msg = str(exc_value)  # Already captured once; keep the original text.
            else:
                msg = traceback.format_exc()
        assert isinstance(msg, str)
        super().__init__(msg)
#---------------------------------------------------------------------------- | |
class CaptureSuccess(Exception):
    """Control-flow exception used to abort a forward pass early.

    Raised from a forward hook once the requested layer's output has been
    captured; the caught exception carries that output in ``out``.
    """
    def __init__(self, out):
        super().__init__()
        self.out = out  # The captured tensor.
#---------------------------------------------------------------------------- | |
def _sinc(x): | |
y = (x * np.pi).abs() | |
z = torch.sin(y) / y.clamp(1e-30, float('inf')) | |
return torch.where(y < 1e-30, torch.ones_like(x), z) | |
def _lanczos_window(x, a):
    """Lanczos window of size ``a``: sinc(|x| / a) for |x| < a, zero elsewhere."""
    u = x.abs() / a
    return torch.where(u < 1, _sinc(u), torch.zeros_like(u))
#---------------------------------------------------------------------------- | |
def _construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
    """Construct a 2D FIR filter that band-limits with respect to an affine transform.

    The filter is the (windowed) convolution of a sinc low-pass oriented in the
    input coordinate space with one oriented in the output space obtained by
    applying ``mat[:2, :2]``.

    Args:
        mat:        Affine matrix; only the top-left 2x2 block is used here.
        a:          Lanczos window size.
        amax:       Half-support (in pixels) of the final, cropped filter. Must satisfy a <= amax < aflt.
        aflt:       Half-support used during construction, before cropping.
        up:         Upsampling factor — the filter has ``up`` taps per pixel.
        cutoff_in:  Cutoff frequency of the input-space sinc.
        cutoff_out: Cutoff frequency of the output-space sinc.

    Returns:
        Square float32 tensor of taps with side length ``amax * 2 * up - 1``.
    """
    assert a <= amax < aflt
    mat = torch.as_tensor(mat).to(torch.float32)

    # Construct 2D filter taps in input & output coordinate spaces.
    # roll() puts tap 0 first so the subsequent FFT-based convolution is aligned.
    taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
    yi, xi = torch.meshgrid(taps, taps)  # NOTE(review): relies on the legacy default indexing='ij'; newer PyTorch warns here.
    xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)

    # Convolution of two oriented 2D sinc filters, done in the frequency domain
    # (pointwise product of FFTs == circular convolution).
    fi = _sinc(xi * cutoff_in) * _sinc(yi * cutoff_in)
    fo = _sinc(xo * cutoff_out) * _sinc(yo * cutoff_out)
    f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real

    # Convolution of two oriented 2D Lanczos windows, same trick.
    wi = _lanczos_window(xi, a) * _lanczos_window(yi, a)
    wo = _lanczos_window(xo, a) * _lanczos_window(yo, a)
    w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real

    # Construct windowed FIR filter.
    f = f * w

    # Finalize: re-center the taps, crop from aflt down to amax support,
    # then normalize each of the up*up polyphase components to unit DC gain.
    c = (aflt - amax) * up
    f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
    f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
    f = f / f.sum([0,2], keepdim=True) / (up ** 2)
    f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
    return f
#---------------------------------------------------------------------------- | |
def _apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    """Resample an image batch through an affine transform with band-limiting.

    The image is first upsampled with a transform-aware FIR filter
    (``_construct_affine_bandlimit_filter``), then warped by the inverse of
    ``mat`` via bilinear grid sampling.

    Args:
        x:   Image batch of shape (N, C, H, W).
        mat: 3x3 affine matrix (must be invertible).
        up:  Upsampling factor used for the intermediate, filtered image.
        **filter_kwargs: Forwarded to _construct_affine_bandlimit_filter.

    Returns:
        (z, m): the transformed image and a nearest-sampled validity mask of
        the same shape (1 where the output comes from real input pixels).
    """
    _N, _C, H, W = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)

    # Construct filter.
    f = _construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
    p = f.shape[0] // 2  # Half-width of the (odd-sized) filter; also the upsample padding.

    # Construct sampling grid. affine_grid expects the inverse mapping
    # (output coords -> input coords) in normalized [-1, 1] space.
    theta = mat.inverse()
    theta[:2, 2] *= 2
    theta[0, 2] += 1 / up / W   # Half-texel offset of the upsampled grid.
    theta[1, 2] += 1 / up / H
    theta[0, :] *= W / (W + p / up * 2)  # Compensate for the padding added by upsample2d below.
    theta[1, :] *= H / (H + p / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)

    # Resample image: band-limit + upsample, then warp back down to (H, W).
    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)

    # Form mask marking the interior region unaffected by the filter padding,
    # warped with nearest sampling so it stays binary.
    m = torch.zeros_like(y)
    c = p * 2 + 1
    m[:, :, c:-c, c:-c] = 1
    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
    return z, m
#---------------------------------------------------------------------------- | |
def set_random_seed(seed):
    """Seed both the global NumPy and PyTorch RNGs for reproducible sampling."""
    np.random.seed(seed)
    torch.manual_seed(seed)
class Renderer:
    """Loads generator pickles and renders display-ready images from them.

    All heavy objects (unpickled networks, tweaked copies, pinned staging
    buffers, colormap LUTs, layer listings) are cached on the instance.
    render() wraps _render_impl() with CUDA-event timing and converts any
    raised exception into a string in the result dict.
    """
    def __init__(self):
        self._device = torch.device('cuda')     # All rendering happens on CUDA.
        self._pkl_data = dict()                 # {pkl: dict | CapturedException, ...}
        self._networks = dict()                 # {cache_key: torch.nn.Module, ...}
        self._pinned_bufs = dict()              # {(shape, dtype): torch.Tensor, ...}
        self._cmaps = dict()                    # {name: torch.Tensor, ...}
        self._is_timing = False                 # Cleared by _ignore_timing() on cache misses.
        self._start_event = torch.cuda.Event(enable_timing=True)
        self._end_event = torch.cuda.Event(enable_timing=True)
        self._net_layers = dict()               # {cache_key: [dnnlib.EasyDict, ...], ...}

    def render(self, **args):
        """Run _render_impl with GPU timing; never raises.

        Returns a dnnlib.EasyDict that contains whatever _render_impl filled in,
        plus 'error' (stringified exception) on failure and 'render_time'
        (seconds) when the timing was not invalidated by a cache miss.
        """
        self._is_timing = True
        self._start_event.record(torch.cuda.current_stream(self._device))
        res = dnnlib.EasyDict()
        try:
            self._render_impl(res, **args)
        except:  # NOTE(review): bare except is deliberate — every failure becomes res.error.
            res.error = CapturedException()
        self._end_event.record(torch.cuda.current_stream(self._device))
        if 'error' in res:
            res.error = str(res.error)  # Convert the exception to plain text.
        if self._is_timing:
            self._end_event.synchronize()
            res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3  # elapsed_time() is ms -> seconds.
            self._is_timing = False
        return res

    def get_network(self, pkl, key, **tweak_kwargs):
        """Load network `key` from pickle `pkl`, with two-level caching.

        The raw unpickled bundle is cached per pkl, and the deep-copied /
        tweaked / device-moved module is cached per (module, device, tweaks).
        Failures are cached as CapturedException and re-raised on every access.
        """
        data = self._pkl_data.get(pkl, None)
        if data is None:
            # First access: open (possibly a URL) and unpickle the bundle.
            print(f'Loading "{pkl}"... ', end='', flush=True)
            try:
                with dnnlib.util.open_url(pkl, verbose=False) as f:
                    data = legacy.load_network_pkl(f)
                print('Done.')
            except:
                data = CapturedException()
                print('Failed!')
            self._pkl_data[pkl] = data
            self._ignore_timing()  # Loading must not count as render time.
        if isinstance(data, CapturedException):
            raise data

        orig_net = data[key]
        cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
        net = self._networks.get(cache_key, None)
        if net is None:
            try:
                net = copy.deepcopy(orig_net)  # Keep the cached original untouched.
                net = self._tweak_network(net, **tweak_kwargs)
                net.to(self._device)
            except:
                net = CapturedException()
            self._networks[cache_key] = net
            self._ignore_timing()
        if isinstance(net, CapturedException):
            raise net
        return net

    def get_camera_traj(self, gen, pitch, yaw, fov=12, batch_size=1, model_name='FFHQ512'):
        """Convert (pitch, yaw) angles into the generator's (u, v) camera
        coordinates and return camera matrices from gen.get_camera().

        For non-car models the pose range is compressed and mapped through the
        generator's configured (range_u, range_v); car models map the raw
        [-1, 1] angles directly to [0, 1].
        """
        range_u, range_v = gen.C.range_u, gen.C.range_v
        if not (('car' in model_name) or ('Car' in model_name)):  # TODO: hack, better option?
            yaw, pitch = 0.5 * yaw, 0.3 * pitch  # Restrict the pose range for face-like models.
            pitch = pitch + np.pi/2
            u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
            v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
        else:
            u = (yaw + 1) / 2
            v = (pitch + 1) / 2
        cam = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=self._device, fov=fov)
        return cam

    def _tweak_network(self, net):
        """Hook for per-load network modifications; currently a no-op."""
        # Print diagnostics.
        #for name, value in misc.named_params_and_buffers(net):
        #    if name.endswith('.magnitude_ema'):
        #        value = value.rsqrt().numpy()
        #        print(f'{name:<50s}{np.min(value):<16g}{np.max(value):g}')
        #    if name.endswith('.weight') and value.ndim == 4:
        #        value = value.square().mean([1,2,3]).sqrt().numpy()
        #        print(f'{name:<50s}{np.min(value):<16g}{np.max(value):g}')
        return net

    def _get_pinned_buf(self, ref):
        """Return a page-locked CPU staging buffer matching ref's shape/dtype (cached)."""
        key = (tuple(ref.shape), ref.dtype)
        buf = self._pinned_bufs.get(key, None)
        if buf is None:
            buf = torch.empty(ref.shape, dtype=ref.dtype).pin_memory()
            self._pinned_bufs[key] = buf
        return buf

    def to_device(self, buf):
        """Copy a CPU tensor to the GPU through a pinned staging buffer."""
        return self._get_pinned_buf(buf).copy_(buf).to(self._device)

    def to_cpu(self, buf):
        """Copy a GPU tensor to the CPU through a pinned staging buffer."""
        return self._get_pinned_buf(buf).copy_(buf).clone()  # clone() detaches from the shared pinned buffer.

    def _ignore_timing(self):
        """Invalidate the current render's timing (used after cache misses)."""
        self._is_timing = False

    def _apply_cmap(self, x, name='viridis'):
        """Map a scalar tensor (expected in [0, 1]) to RGB uint8 via a matplotlib colormap."""
        cmap = self._cmaps.get(name, None)
        if cmap is None:
            cmap = matplotlib.cm.get_cmap(name)
            cmap = cmap(np.linspace(0, 1, num=1024), bytes=True)[:, :3]  # 1024-entry RGB LUT, alpha dropped.
            cmap = self.to_device(torch.from_numpy(cmap))
            self._cmaps[name] = cmap
        hi = cmap.shape[0] - 1
        x = (x * hi + 0.5).clamp(0, hi).to(torch.int64)  # Round and clamp to LUT indices.
        x = torch.nn.functional.embedding(x, cmap)       # LUT lookup.
        return x

    def _render_impl(self, res,
        pkl             = None,         # Network pickle path/URL.
        w0_seeds        = [[0, 1]],     # [[seed, weight], ...] blended into the base latent.
        stylemix_idx    = [],           # W indices to overwrite with the style-mix seed.
        stylemix_seed   = 0,            # Seed supplying the mixed-in styles.
        trunc_psi       = 1,            # Truncation psi for the mapping network.
        trunc_cutoff    = 0,            # Truncation cutoff layer.
        random_seed     = 0,            # Seed for synthesis-time randomness (noise).
        noise_mode      = 'const',      # Passed to the synthesis network.
        force_fp32      = False,        # Passed to the synthesis network.
        layer_name      = None,         # If set, capture this intermediate layer instead of the final image.
        sel_channels    = 3,            # Number of channels to visualize.
        base_channel    = 0,            # First channel to visualize.
        img_scale_db    = 0,            # Display gain in decibels.
        img_normalize   = False,        # Normalize each channel by its max magnitude.
        fft_show        = False,        # Append an FFT magnitude panel.
        fft_all         = True,         # FFT over all channels (vs. selected ones).
        fft_range_db    = 50,           # Dynamic range of the FFT display.
        fft_beta        = 8,            # Kaiser window beta for the FFT.
        input_transform = None,         # Optional input transform matrix (inverted before use).
        untransform     = False,        # Undo the input transform on the output image.
        camera          = None,         # Camera parameters; assumes (pitch, yaw, ...) args for get_camera_traj — TODO confirm against caller.
        output_lowres   = False,        # Also show the low-res NeRF output side by side, if available.
        **unused,
    ):
        """Render one frame into `res` (image, stats, layer list, errors)."""
        # Dig up network details. Re-instantiating Generator lets locally-patched
        # code run with the pickled weights; fall back to the pickled module itself.
        _G = self.get_network(pkl, 'G_ema')
        try:
            G = Generator(*_G.init_args, **_G.init_kwargs).to(self._device)
            misc.copy_params_and_buffers(_G, G, require_all=False)
        except Exception:
            G = _G
        G.eval()
        res.img_resolution = G.img_resolution
        res.num_ws = G.num_ws
        res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
        res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))

        # Set input transform (identity if none given or the matrix is singular).
        if res.has_input_transform:
            m = np.eye(3)
            try:
                if input_transform is not None:
                    m = np.linalg.inv(np.asarray(input_transform))
            except np.linalg.LinAlgError:
                res.error = CapturedException()
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        # Generate random latents, one z (and optional one-hot c) per unique seed.
        all_seeds = [seed for seed, _weight in w0_seeds] + [stylemix_seed]
        all_seeds = list(set(all_seeds))
        all_zs = np.zeros([len(all_seeds), G.z_dim], dtype=np.float32)
        all_cs = np.zeros([len(all_seeds), G.c_dim], dtype=np.float32)
        for idx, seed in enumerate(all_seeds):
            rnd = np.random.RandomState(seed)
            all_zs[idx] = rnd.randn(G.z_dim)
            if G.c_dim > 0:
                all_cs[idx, rnd.randint(G.c_dim)] = 1

        # Run mapping network. Work in w-space relative to w_avg so the
        # weighted seed blend below stays centered.
        w_avg = G.mapping.w_avg
        all_zs = self.to_device(torch.from_numpy(all_zs))
        all_cs = self.to_device(torch.from_numpy(all_cs))
        all_ws = G.mapping(z=all_zs, c=all_cs, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff) - w_avg
        all_ws = dict(zip(all_seeds, all_ws))

        # Calculate final W: weighted sum of seed latents, then style mixing, then re-center.
        w = torch.stack([all_ws[seed] * weight for seed, weight in w0_seeds]).sum(dim=0, keepdim=True)
        stylemix_idx = [idx for idx in stylemix_idx if 0 <= idx < G.num_ws]
        if len(stylemix_idx) > 0:
            w[:, stylemix_idx] = all_ws[stylemix_seed][np.newaxis, stylemix_idx]
        w += w_avg

        # Run synthesis network.
        synthesis_kwargs = dnnlib.EasyDict(noise_mode=noise_mode, force_fp32=force_fp32)
        set_random_seed(random_seed)
        if hasattr(G.synthesis, 'C'):  # Camera-conditioned (NeRF-style) synthesis.
            synthesis_kwargs.update({'camera_matrices': camera})
        out, out_lowres, layers = self.run_synthesis_net(G.synthesis, w, capture_layer=layer_name, **synthesis_kwargs)

        # Update layer list. NOTE(review): the whole cache is reset on a miss,
        # so only the most recent configuration's layer list is retained.
        cache_key = (G.synthesis, tuple(sorted(synthesis_kwargs.items())))
        if cache_key not in self._net_layers:
            self._net_layers = dict()
            if layer_name is not None:
                # Re-run without capture to enumerate the full layer list.
                torch.manual_seed(random_seed)
                _out, _out2, layers = self.run_synthesis_net(G.synthesis, w, **synthesis_kwargs)
            self._net_layers[cache_key] = layers
        res.layers = self._net_layers[cache_key]

        # Untransform.
        if untransform and res.has_input_transform:
            out, _mask = _apply_affine_transformation(out.to(torch.float32), G.synthesis.input.transform, amax=6) # Override amax to hit the fast path in upfirdn2d.

        # Optionally append the upscaled low-res NeRF image side by side.
        if output_lowres and out_lowres is not None:
            out = torch.cat([out, F.interpolate(out_lowres, out.size(-1), mode='nearest')], -1)

        # Select channels and compute statistics.
        out = out[0].to(torch.float32)  # Drop the batch dimension.
        if sel_channels > out.shape[0]:
            sel_channels = 1
        base_channel = max(min(base_channel, out.shape[0] - sel_channels), 0)
        sel = out[base_channel : base_channel + sel_channels]
        res.stats = torch.stack([
            out.mean(), sel.mean(),
            out.std(), sel.std(),
            out.norm(float('inf')), sel.norm(float('inf')),
        ])
        res.stats = self.to_cpu(res.stats).numpy() # move to cpu

        # Scale and convert to uint8, HWC layout for display.
        img = sel
        if img_normalize:
            img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8)
        img = img * (10 ** (img_scale_db / 20))  # dB -> linear gain.
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)
        res.image = img

        # FFT: windowed power spectrum in dB, colormapped and appended to the image.
        if fft_show:
            sig = out if fft_all else sel
            sig = sig.to(torch.float32)
            sig = sig - sig.mean(dim=[1,2], keepdim=True)
            # Apply a separable Kaiser window to reduce spectral leakage.
            sig = sig * torch.kaiser_window(sig.shape[1], periodic=False, beta=fft_beta, device=self._device)[None, :, None]
            sig = sig * torch.kaiser_window(sig.shape[2], periodic=False, beta=fft_beta, device=self._device)[None, None, :]
            fft = torch.fft.fftn(sig, dim=[1,2]).abs().square().sum(dim=0)
            fft = fft.roll(shifts=[fft.shape[0] // 2, fft.shape[1] // 2], dims=[0,1])  # Center DC.
            fft = (fft / fft.mean()).log10() * 10 # dB
            fft = self._apply_cmap((fft / fft_range_db + 1) / 2)
            res.image = torch.cat([img.expand_as(fft), fft], dim=1)
        res.image = self.to_cpu(res.image).numpy() # move to cpu

    def run_synthesis_net(self, net, *args, capture_layer=None, **kwargs): # => out, out_lowres, layers
        """Run the synthesis network while recording every 4D/5D layer output.

        If `capture_layer` names a recorded layer, the forward pass is aborted
        as soon as that layer runs, and its output is returned as `out`.
        Returns (out, out_lowres, layers) where out_lowres is the 'img_nerf'
        entry when the network returns a dict, else None.
        """
        submodule_names = {mod: name for name, mod in net.named_modules()}
        unique_names = set()
        layers = []

        def module_hook(module, _inputs, outputs):
            # Keep only image-like tensor outputs.
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [out for out in outputs if isinstance(out, torch.Tensor) and out.ndim in [4, 5]]
            for idx, out in enumerate(outputs):
                if out.ndim == 5: # G-CNN => remove group dimension.
                    out = out.mean(2)
                name = submodule_names[module]
                if name == '':
                    name = 'output'  # The network's own (root-module) output.
                if len(outputs) > 1:
                    name += f':{idx}'
                # Disambiguate modules that fire more than once.
                if name in unique_names:
                    suffix = 2
                    while f'{name}_{suffix}' in unique_names:
                        suffix += 1
                    name += f'_{suffix}'
                unique_names.add(name)
                shape = [int(x) for x in out.shape]
                dtype = str(out.dtype).split('.')[-1]
                layers.append(dnnlib.EasyDict(name=name, shape=shape, dtype=dtype))
                if name == capture_layer:
                    raise CaptureSuccess(out)  # Abort the forward pass; caught below.

        hooks = []
        hooks = [module.register_forward_hook(module_hook) for module in net.modules()]
        try:
            if 'camera_matrices' in kwargs:
                # assumes kwargs['camera_matrices'] holds get_camera_traj positional args — TODO confirm.
                kwargs['camera_matrices'] = self.get_camera_traj(net, *kwargs['camera_matrices'])
            out = net(*args, **kwargs)
            out_lowres = None
            if isinstance(out, dict):
                if 'img_nerf' in out:
                    out_lowres = out['img_nerf']
                out = out['img']
        except CaptureSuccess as e:
            out = e.out
            out_lowres = None
        for hook in hooks:
            hook.remove()  # Always detach hooks so later forward passes are clean.
        return out, out_lowres, layers
#---------------------------------------------------------------------------- | |