Spaces:
Running
on
Zero
Running
on
Zero
""" | |
https://arxiv.org/abs/2312.07537 | |
""" | |
import math | |
import torch | |
import torch.fft as fft | |
import torch.nn.functional as F | |
def freq_mix_3d(x, noise, LPF): | |
""" | |
Noise reinitialization. | |
Args: | |
x: diffused latent | |
noise: randomly sampled noise | |
LPF: low pass filter | |
""" | |
# FFT | |
x_freq = fft.fftn(x, dim=(-3, -2, -1)) | |
x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1)) | |
noise_freq = fft.fftn(noise, dim=(-3, -2, -1)) | |
noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1)) | |
# frequency mix | |
HPF = 1 - LPF | |
x_freq_low = x_freq * LPF | |
noise_freq_high = noise_freq * HPF | |
x_freq_mixed = x_freq_low + noise_freq_high # mix in freq domain | |
# IFFT | |
x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1)) | |
x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real | |
return x_mixed | |
def get_freq_filter(shape, device, filter_type, n, d_s, d_t): | |
""" | |
Form the frequency filter for noise reinitialization. | |
Args: | |
shape: shape of latent (B, C, T, H, W) | |
filter_type: type of the freq filter | |
n: (only for butterworth) order of the filter, larger n ~ ideal, smaller n ~ gaussian | |
d_s: normalized stop frequency for spatial dimensions (0.0-1.0) | |
d_t: normalized stop frequency for temporal dimension (0.0-1.0) | |
""" | |
if filter_type == "gaussian": | |
return gaussian_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) | |
elif filter_type == "ideal": | |
return ideal_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) | |
elif filter_type == "box": | |
return box_low_pass_filter(shape=shape, d_s=d_s, d_t=d_t).to(device) | |
elif filter_type == "butterworth": | |
return butterworth_low_pass_filter(shape=shape, n=n, d_s=d_s, d_t=d_t).to(device) | |
else: | |
raise NotImplementedError | |
def gaussian_low_pass_filter(shape, d_s=0.25, d_t=0.25): | |
""" | |
Compute the gaussian low pass filter mask. | |
Args: | |
shape: shape of the filter (volume) | |
d_s: normalized stop frequency for spatial dimensions (0.0-1.0) | |
d_t: normalized stop frequency for temporal dimension (0.0-1.0) | |
""" | |
T, H, W = shape[-3], shape[-2], shape[-1] | |
mask = torch.zeros(shape) | |
if d_s==0 or d_t==0: | |
return mask | |
for t in range(T): | |
for h in range(H): | |
for w in range(W): | |
d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) | |
mask[..., t,h,w] = math.exp(-1/(2*d_s**2) * d_square) | |
return mask | |
def butterworth_low_pass_filter(shape, n=4, d_s=0.25, d_t=0.25): | |
""" | |
Compute the butterworth low pass filter mask. | |
Args: | |
shape: shape of the filter (volume) | |
n: order of the filter, larger n ~ ideal, smaller n ~ gaussian | |
d_s: normalized stop frequency for spatial dimensions (0.0-1.0) | |
d_t: normalized stop frequency for temporal dimension (0.0-1.0) | |
""" | |
T, H, W = shape[-3], shape[-2], shape[-1] | |
mask = torch.zeros(shape) | |
if d_s==0 or d_t==0: | |
return mask | |
for t in range(T): | |
for h in range(H): | |
for w in range(W): | |
d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) | |
mask[..., t,h,w] = 1 / (1 + (d_square / d_s**2)**n) | |
return mask | |
def ideal_low_pass_filter(shape, d_s=0.25, d_t=0.25): | |
""" | |
Compute the ideal low pass filter mask. | |
Args: | |
shape: shape of the filter (volume) | |
d_s: normalized stop frequency for spatial dimensions (0.0-1.0) | |
d_t: normalized stop frequency for temporal dimension (0.0-1.0) | |
""" | |
T, H, W = shape[-3], shape[-2], shape[-1] | |
mask = torch.zeros(shape) | |
if d_s==0 or d_t==0: | |
return mask | |
for t in range(T): | |
for h in range(H): | |
for w in range(W): | |
d_square = (((d_s/d_t)*(2*t/T-1))**2 + (2*h/H-1)**2 + (2*w/W-1)**2) | |
mask[..., t,h,w] = 1 if d_square <= d_s*2 else 0 | |
return mask | |
def box_low_pass_filter(shape, d_s=0.25, d_t=0.25): | |
""" | |
Compute the ideal low pass filter mask (approximated version). | |
Args: | |
shape: shape of the filter (volume) | |
d_s: normalized stop frequency for spatial dimensions (0.0-1.0) | |
d_t: normalized stop frequency for temporal dimension (0.0-1.0) | |
""" | |
T, H, W = shape[-3], shape[-2], shape[-1] | |
mask = torch.zeros(shape) | |
if d_s==0 or d_t==0: | |
return mask | |
threshold_s = round(int(H // 2) * d_s) | |
threshold_t = round(T // 2 * d_t) | |
cframe, crow, ccol = T // 2, H // 2, W //2 | |
mask[..., cframe - threshold_t:cframe + threshold_t, crow - threshold_s:crow + threshold_s, ccol - threshold_s:ccol + threshold_s] = 1.0 | |
return mask | |
def init_filter(video_length, height, width, filter_params_method="gaussian", filter_params_n=4, filter_params_d_s=0.25, filter_params_d_t=0.25, num_channels_latents=4, device='cpu'): | |
# initialize frequency filter for noise reinitialization | |
batch_size = 1 | |
num_channels_latents = num_channels_latents | |
filter_shape = [ | |
batch_size, | |
num_channels_latents, | |
video_length, | |
height, | |
width, | |
] | |
freq_filter = get_freq_filter( | |
filter_shape, | |
device=device, | |
filter_type=filter_params_method, | |
n=filter_params_n if filter_params_method=="butterworth" else None, | |
d_s=filter_params_d_s, | |
d_t=filter_params_d_t | |
) | |
return freq_filter | |
def FFTInit(noisy_latent, noise): | |
dtype = noisy_latent.dtype | |
freq_filter = init_filter( | |
video_length=noisy_latent.shape[2], | |
height=noisy_latent.shape[3], | |
width=noisy_latent.shape[4], | |
device=noisy_latent.device | |
) | |
# make it float32 to accept any kinds of resolution | |
latents = freq_mix_3d(noisy_latent.to(dtype=torch.float32), noise.to(dtype=torch.float32), LPF=freq_filter) | |
latents = latents.to(dtype) | |
return latents |