# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Based on: https://github.com/crowsonkb/k-diffusion
"""

import random

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

# from piq import LPIPS
from utils.ssim import SSIM

from modules.diffusion.karras.random_utils import get_generator


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))
def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]
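# Illustrative sketch (not part of the original module): append_dims is used to
# broadcast a per-sample sigma of shape (B,) against a batch tensor of shape
# (B, C, T). For example, append_dims(th.ones(4), 3) has shape (4, 1, 1), so
# `x + noise * append_dims(sigmas, x.ndim)` scales the noise per sample.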
def append_zero(x):
    return th.cat([x, x.new_zeros([1])])
def get_weightings(weight_schedule, snrs, sigma_data):
    if weight_schedule == "snr":
        weightings = snrs
    elif weight_schedule == "snr+1":
        weightings = snrs + 1
    elif weight_schedule == "karras":
        weightings = snrs + 1.0 / sigma_data**2
    elif weight_schedule == "truncated-snr":
        weightings = th.clamp(snrs, min=1.0)
    elif weight_schedule == "uniform":
        weightings = th.ones_like(snrs)
    else:
        raise NotImplementedError()
    return weightings
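# A minimal worked example (assumed values, not from the original module) of the
# "karras" schedule: with sigma_data = 0.5 and sigma = 1.0, snr = sigma**-2 = 1.0
# and the weight is snr + 1 / sigma_data**2 = 1.0 + 4.0 = 5.0, i.e. the EDM loss
# weighting lambda(sigma) = (sigma**2 + sigma_data**2) / (sigma * sigma_data)**2.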
class KarrasDenoiser:
    def __init__(
        self,
        sigma_data: float = 0.5,
        sigma_max=80.0,
        sigma_min=0.002,
        rho=7.0,
        weight_schedule="karras",
        distillation=False,
        loss_norm="l2",
    ):
        self.sigma_data = sigma_data
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.weight_schedule = weight_schedule
        self.distillation = distillation
        self.loss_norm = loss_norm
        # if loss_norm == "lpips":
        #     self.lpips_loss = LPIPS(replace_pooling=True, reduction="none")
        if loss_norm == "ssim":
            self.ssim_loss = SSIM()
        self.rho = rho
        self.num_timesteps = 40
    def get_snr(self, sigmas):
        return sigmas**-2

    def get_sigmas(self, sigmas):
        return sigmas

    def get_scalings(self, sigma):
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        return c_skip, c_out, c_in
    def get_scalings_for_boundary_condition(self, sigma):
        c_skip = self.sigma_data**2 / (
            (sigma - self.sigma_min) ** 2 + self.sigma_data**2
        )
        c_out = (
            (sigma - self.sigma_min)
            * self.sigma_data
            / (sigma**2 + self.sigma_data**2) ** 0.5
        )
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        return c_skip, c_out, c_in
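    # Sketch of the boundary condition these scalings enforce (assuming the
    # defaults sigma_min=0.002, sigma_data=0.5): at sigma == sigma_min we get
    # c_skip == 1 and c_out == 0, so denoise() returns x_t unchanged. The
    # consistency model is therefore the identity at the smallest noise level,
    # as consistency training/distillation requires.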
    def training_losses(self, model, x_start, sigmas, condition=None, noise=None):
        if noise is None:
            noise = th.randn_like(x_start)

        terms = {}

        dims = x_start.ndim
        x_t = x_start + noise * append_dims(sigmas, dims)
        model_output, denoised = self.denoise(model, x_t, sigmas, condition)

        snrs = self.get_snr(sigmas)
        weights = append_dims(
            get_weightings(self.weight_schedule, snrs, self.sigma_data), dims
        )
        # terms["xs_mse"] = mean_flat((denoised - x_start) ** 2)
        terms["mse"] = mean_flat(weights * (denoised - x_start) ** 2)
        # terms["mae"] = mean_flat(weights * th.abs(denoised - x_start))
        # terms["mse"] = nn.MSELoss(reduction="none")(denoised, x_start)

        # if "vb" in terms:
        #     terms["loss"] = terms["mse"] + terms["vb"]
        # else:
        terms["loss"] = terms["mse"]

        return terms
    def consistency_losses(
        self,
        model,
        x_start,
        num_scales,
        # model_kwargs=None,
        condition=None,
        target_model=None,
        teacher_model=None,
        teacher_diffusion=None,
        noise=None,
    ):
        if noise is None:
            noise = th.randn_like(x_start)

        dims = x_start.ndim

        def denoise_fn(x, t):
            return self.denoise(model, x, t, condition)[1]

        if target_model:

            def target_denoise_fn(x, t):
                return self.denoise(target_model, x, t, condition)[1]

        else:
            raise NotImplementedError("Must have a target model")

        if teacher_model:

            def teacher_denoise_fn(x, t):
                return teacher_diffusion.denoise(teacher_model, x, t, condition)[1]

        def heun_solver(samples, t, next_t, x0):
            x = samples
            if teacher_model is None:
                denoiser = x0
            else:
                denoiser = teacher_denoise_fn(x, t)

            d = (x - denoiser) / append_dims(t, dims)
            samples = x + d * append_dims(next_t - t, dims)
            if teacher_model is None:
                denoiser = x0
            else:
                denoiser = teacher_denoise_fn(samples, next_t)

            next_d = (samples - denoiser) / append_dims(next_t, dims)
            samples = x + (d + next_d) * append_dims((next_t - t) / 2, dims)

            return samples

        def euler_solver(samples, t, next_t, x0):
            x = samples
            if teacher_model is None:
                denoiser = x0
            else:
                denoiser = teacher_denoise_fn(x, t)
            d = (x - denoiser) / append_dims(t, dims)
            samples = x + d * append_dims(next_t - t, dims)

            return samples

        indices = th.randint(
            0, num_scales - 1, (x_start.shape[0],), device=x_start.device
        )

        t = self.sigma_max ** (1 / self.rho) + indices / (num_scales - 1) * (
            self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
        )
        t = t**self.rho

        t2 = self.sigma_max ** (1 / self.rho) + (indices + 1) / (num_scales - 1) * (
            self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
        )
        t2 = t2**self.rho

        x_t = x_start + noise * append_dims(t, dims)

        dropout_state = th.get_rng_state()
        distiller = denoise_fn(x_t, t)

        if teacher_model is None:
            x_t2 = euler_solver(x_t, t, t2, x_start).detach()
        else:
            x_t2 = heun_solver(x_t, t, t2, x_start).detach()

        th.set_rng_state(dropout_state)
        distiller_target = target_denoise_fn(x_t2, t2)
        distiller_target = distiller_target.detach()

        snrs = self.get_snr(t)
        weights = get_weightings(self.weight_schedule, snrs, self.sigma_data)
        if self.loss_norm == "l1":
            diffs = th.abs(distiller - distiller_target)
            loss = mean_flat(diffs) * weights
        elif self.loss_norm == "l2":
            # diffs = (distiller - distiller_target) ** 2
            loss = F.mse_loss(distiller, distiller_target)
            # loss = mean_flat(diffs) * weights
        elif self.loss_norm == "ssim":
            loss = self.ssim_loss(distiller, distiller_target) * weights
        # elif self.loss_norm == "l2-32":
        #     distiller = F.interpolate(distiller, size=32, mode="bilinear")
        #     distiller_target = F.interpolate(
        #         distiller_target,
        #         size=32,
        #         mode="bilinear",
        #     )
        #     diffs = (distiller - distiller_target) ** 2
        #     loss = mean_flat(diffs) * weights
        # elif self.loss_norm == "lpips":
        #     if x_start.shape[-1] < 256:
        #         distiller = F.interpolate(distiller, size=224, mode="bilinear")
        #         distiller_target = F.interpolate(
        #             distiller_target, size=224, mode="bilinear"
        #         )
        #     loss = (
        #         self.lpips_loss(
        #             (distiller + 1) / 2.0,
        #             (distiller_target + 1) / 2.0,
        #         )
        #         * weights
        #     )
        else:
            raise ValueError(f"Unknown loss norm {self.loss_norm}")

        terms = {}
        terms["loss"] = loss

        return terms
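    # A minimal consistency-training sketch (illustrative only; `online_model`,
    # `ema_model`, `mels`, and `cond` are assumed names, not part of this module):
    #
    #   diffusion = KarrasDenoiser(distillation=False, loss_norm="l2")
    #   losses = diffusion.consistency_losses(
    #       online_model, mels, num_scales=40, condition=cond,
    #       target_model=ema_model,  # teacher_model=None -> Euler solver above
    #   )
    #   losses["loss"].mean().backward()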
    # def progdist_losses(
    #     self,
    #     model,
    #     x_start,
    #     num_scales,
    #     model_kwargs=None,
    #     teacher_model=None,
    #     teacher_diffusion=None,
    #     noise=None,
    # ):
    #     if model_kwargs is None:
    #         model_kwargs = {}
    #     if noise is None:
    #         noise = th.randn_like(x_start)

    #     dims = x_start.ndim

    #     def denoise_fn(x, t):
    #         return self.denoise(model, x, t, **model_kwargs)[1]

    #     @th.no_grad()
    #     def teacher_denoise_fn(x, t):
    #         return teacher_diffusion.denoise(teacher_model, x, t, **model_kwargs)[1]

    #     @th.no_grad()
    #     def euler_solver(samples, t, next_t):
    #         x = samples
    #         denoiser = teacher_denoise_fn(x, t)
    #         d = (x - denoiser) / append_dims(t, dims)
    #         samples = x + d * append_dims(next_t - t, dims)

    #         return samples

    #     @th.no_grad()
    #     def euler_to_denoiser(x_t, t, x_next_t, next_t):
    #         denoiser = x_t - append_dims(t, dims) * (x_next_t - x_t) / append_dims(
    #             next_t - t, dims
    #         )
    #         return denoiser

    #     indices = th.randint(0, num_scales, (x_start.shape[0],), device=x_start.device)

    #     t = self.sigma_max ** (1 / self.rho) + indices / num_scales * (
    #         self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
    #     )
    #     t = t**self.rho

    #     t2 = self.sigma_max ** (1 / self.rho) + (indices + 0.5) / num_scales * (
    #         self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
    #     )
    #     t2 = t2**self.rho

    #     t3 = self.sigma_max ** (1 / self.rho) + (indices + 1) / num_scales * (
    #         self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
    #     )
    #     t3 = t3**self.rho

    #     x_t = x_start + noise * append_dims(t, dims)

    #     denoised_x = denoise_fn(x_t, t)

    #     x_t2 = euler_solver(x_t, t, t2).detach()
    #     x_t3 = euler_solver(x_t2, t2, t3).detach()

    #     target_x = euler_to_denoiser(x_t, t, x_t3, t3).detach()

    #     snrs = self.get_snr(t)
    #     weights = get_weightings(self.weight_schedule, snrs, self.sigma_data)
    #     if self.loss_norm == "l1":
    #         diffs = th.abs(denoised_x - target_x)
    #         loss = mean_flat(diffs) * weights
    #     elif self.loss_norm == "l2":
    #         diffs = (denoised_x - target_x) ** 2
    #         loss = mean_flat(diffs) * weights
    #     elif self.loss_norm == "lpips":
    #         if x_start.shape[-1] < 256:
    #             denoised_x = F.interpolate(denoised_x, size=224, mode="bilinear")
    #             target_x = F.interpolate(target_x, size=224, mode="bilinear")

    #         loss = (
    #             self.lpips_loss(
    #                 (denoised_x + 1) / 2.0,
    #                 (target_x + 1) / 2.0,
    #             )
    #             * weights
    #         )
    #     else:
    #         raise ValueError(f"Unknown loss norm {self.loss_norm}")

    #     terms = {}
    #     terms["loss"] = loss

    #     return terms
    def denoise(self, model, x_t, sigmas, condition):
        if not self.distillation:
            c_skip, c_out, c_in = [
                append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)
            ]
        else:
            c_skip, c_out, c_in = [
                append_dims(x, x_t.ndim)
                for x in self.get_scalings_for_boundary_condition(sigmas)
            ]
        rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
        # rescaled_t = rescaled_t[:, None]
        model_output = model(c_in * x_t, rescaled_t, condition)
        denoised = c_out * model_output + c_skip * x_t
        return model_output, denoised
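    # Sketch of the preconditioning applied above (EDM-style, Karras et al. 2022):
    # the network sees the scaled input c_in * x_t together with a log-sigma
    # timestep embedding, and its raw output is mixed back with the noisy input as
    #   denoised = c_out * F(c_in * x_t, t) + c_skip * x_t,
    # so the network only predicts a residual whose scale stays roughly constant
    # across noise levels.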
def karras_sample(
    diffusion,
    model,
    shape,
    steps,
    clip_denoised=True,
    progress=True,
    callback=None,
    # model_kwargs=None,
    condition=None,
    device=None,
    sigma_min=0.002,
    sigma_max=80,  # higher for highres?
    rho=7.0,
    sampler="heun",
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    generator=None,
    ts=None,
):
    if generator is None:
        generator = get_generator("dummy")

    if sampler == "progdist":
        sigmas = get_sigmas_karras(steps + 1, sigma_min, sigma_max, rho, device=device)
    else:
        sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)

    th.manual_seed(42)
    x_T = generator.randn(*shape, device=device) * sigma_max
    sigmas = sigmas.unsqueeze(-1)

    sample_fn = {
        "heun": sample_heun,
        "dpm": sample_dpm,
        "ancestral": sample_euler_ancestral,
        "onestep": sample_onestep,
        "progdist": sample_progdist,
        "euler": sample_euler,
        "multistep": stochastic_iterative_sampler,
    }[sampler]

    if sampler in ["heun", "dpm"]:
        sampler_args = dict(
            s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise
        )
    elif sampler == "multistep":
        sampler_args = dict(
            ts=ts, t_min=sigma_min, t_max=sigma_max, rho=diffusion.rho, steps=steps
        )
    else:
        sampler_args = {}

    def denoiser(x_t, sigma):
        _, denoised = diffusion.denoise(model, x_t, sigma, condition)
        if clip_denoised:
            denoised = denoised.clamp(-1, 1)
        return denoised

    x_0 = sample_fn(
        denoiser,
        x_T,
        sigmas,
        generator,
        progress=progress,
        callback=callback,
        **sampler_args,
    )
    return x_0.clamp(-1, 1)
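# A minimal sampling sketch (illustrative only; `diffusion`, `model`, `cond`, and
# the mel-spectrogram shape are assumed, not defined in this module):
#
#   x_0 = karras_sample(
#       diffusion, model, shape=(1, 80, 400), steps=40,
#       condition=cond, device="cuda", sampler="heun",
#   )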
def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """Constructs the noise schedule of Karras et al. (2022)."""
    ramp = th.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas).to(device)
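# Worked example (approximate values, assuming the defaults sigma_min=0.002,
# sigma_max=80, rho=7): get_sigmas_karras(4, 0.002, 80.0) interpolates linearly in
# sigma**(1/7) space, giving roughly [80.0, 9.7, 0.47, 0.002] plus the appended
# terminal 0.0, i.e. a schedule that concentrates most steps at low noise levels.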
def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    return (x - denoised) / append_dims(sigma, x.ndim)
def get_ancestral_step(sigma_from, sigma_to):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
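# Sanity check of the split above (a note, not part of the original code): by
# construction sigma_down**2 + sigma_up**2 == sigma_to**2, so taking a
# deterministic Euler step down to sigma_down and then adding Gaussian noise with
# std sigma_up leaves the sample at exactly the target noise level sigma_to.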
def sample_euler_ancestral(model, x, sigmas, generator, progress=False, callback=None):
    """Ancestral sampling with Euler method steps."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        denoised = model(x, sigmas[i] * s_in)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigmas[i],
                    "denoised": denoised,
                }
            )
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        x = x + generator.randn_like(x) * sigma_up
    return x
def sample_midpoint_ancestral(model, x, ts, generator, progress=False, callback=None):
    """Ancestral sampling with midpoint method steps."""
    s_in = x.new_ones([x.shape[0]])
    step_size = 1 / len(ts)
    if progress:
        from tqdm.auto import tqdm

        ts = tqdm(ts)

    for tn in ts:
        dn = model(x, tn * s_in)
        dn_2 = model(x + (step_size / 2) * dn, (tn + step_size / 2) * s_in)
        x = x + step_size * dn_2
        if callback is not None:
            callback({"x": x, "tn": tn, "dn": dn, "dn_2": dn_2})
    return x
def sample_heun(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigmas[i] <= s_tmax
            else 0.0
        )
        eps = generator.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigma_hat,
                    "denoised": denoised,
                }
            )
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method
            x_2 = x + d * dt
            denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x
def sample_euler(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
):
    """Deterministic sampling with plain Euler steps (the first-order variant of
    Algorithm 2 from Karras et al. (2022), without churn or the Heun correction)."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        sigma = sigmas[i]
        denoised = denoiser(x, sigma * s_in)
        d = to_d(x, sigma, denoised)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "denoised": denoised,
                }
            )
        dt = sigmas[i + 1] - sigma
        x = x + d * dt
    return x
def sample_dpm(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigmas[i] <= s_tmax
            else 0.0
        )
        eps = generator.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigma_hat,
                    "denoised": denoised,
                }
            )
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
        dt_1 = sigma_mid - sigma_hat
        dt_2 = sigmas[i + 1] - sigma_hat
        x_2 = x + d * dt_1
        denoised_2 = denoiser(x_2, sigma_mid * s_in)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * dt_2
    return x
def sample_onestep(
    distiller,
    x,
    sigmas,
    generator=None,
    progress=False,
    callback=None,
):
    """Single-step generation from a distilled model."""
    s_in = x.new_ones([x.shape[0]])
    return distiller(x, sigmas[0] * s_in)
def stochastic_iterative_sampler(
    distiller,
    x,
    sigmas,
    generator,
    ts,
    progress=False,
    callback=None,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
):
    t_max_rho = t_max ** (1 / rho)
    t_min_rho = t_min ** (1 / rho)
    s_in = x.new_ones([x.shape[0]])

    for i in range(len(ts) - 1):
        t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
        x0 = distiller(x, t * s_in)
        next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
        next_t = np.clip(next_t, t_min, t_max)
        x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)

    return x
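# Multistep consistency sampling sketch (illustrative only; the timestep index
# list `ts` below is an assumed choice, not mandated by this module): with
# steps=40, calling karras_sample(..., sampler="multistep", ts=[0, 22, 39])
# denoises once from sigma_max, re-noises to an intermediate sigma, and denoises
# again, trading one extra network call for higher sample quality than "onestep".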
def sample_progdist(
    denoiser,
    x,
    sigmas,
    generator=None,
    progress=False,
    callback=None,
):
    s_in = x.new_ones([x.shape[0]])
    sigmas = sigmas[:-1]  # skip the zero sigma

    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        sigma = sigmas[i]
        denoised = denoiser(x, sigma * s_in)
        d = to_d(x, sigma, denoised)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigma,
                    "denoised": denoised,
                }
            )
        dt = sigmas[i + 1] - sigma
        x = x + d * dt

    return x
# @th.no_grad()
# def iterative_colorization(
#     distiller,
#     images,
#     x,
#     ts,
#     t_min=0.002,
#     t_max=80.0,
#     rho=7.0,
#     steps=40,
#     generator=None,
# ):
#     def obtain_orthogonal_matrix():
#         vector = np.asarray([0.2989, 0.5870, 0.1140])
#         vector = vector / np.linalg.norm(vector)
#         matrix = np.eye(3)
#         matrix[:, 0] = vector
#         matrix = np.linalg.qr(matrix)[0]
#         if np.sum(matrix[:, 0]) < 0:
#             matrix = -matrix
#         return matrix

#     Q = th.from_numpy(obtain_orthogonal_matrix()).to(dist_util.dev()).to(th.float32)
#     mask = th.zeros(*x.shape[1:], device=dist_util.dev())
#     mask[0, ...] = 1.0

#     def replacement(x0, x1):
#         x0 = th.einsum("bchw,cd->bdhw", x0, Q)
#         x1 = th.einsum("bchw,cd->bdhw", x1, Q)

#         x_mix = x0 * mask + x1 * (1.0 - mask)
#         x_mix = th.einsum("bdhw,cd->bchw", x_mix, Q)
#         return x_mix

#     t_max_rho = t_max ** (1 / rho)
#     t_min_rho = t_min ** (1 / rho)
#     s_in = x.new_ones([x.shape[0]])
#     images = replacement(images, th.zeros_like(images))

#     for i in range(len(ts) - 1):
#         t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
#         x0 = distiller(x, t * s_in)
#         x0 = th.clamp(x0, -1.0, 1.0)
#         x0 = replacement(images, x0)
#         next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
#         next_t = np.clip(next_t, t_min, t_max)
#         x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)

#     return x, images


# @th.no_grad()
# def iterative_inpainting(
#     distiller,
#     images,
#     x,
#     ts,
#     t_min=0.002,
#     t_max=80.0,
#     rho=7.0,
#     steps=40,
#     generator=None,
# ):
#     from PIL import Image, ImageDraw, ImageFont

#     image_size = x.shape[-1]

#     # create a blank image with a white background
#     img = Image.new("RGB", (image_size, image_size), color="white")

#     # get a drawing context for the image
#     draw = ImageDraw.Draw(img)

#     # load a font
#     font = ImageFont.truetype("arial.ttf", 250)

#     # draw the letter "C" in black
#     draw.text((50, 0), "S", font=font, fill=(0, 0, 0))

#     # convert the image to a numpy array
#     img_np = np.array(img)
#     img_np = img_np.transpose(2, 0, 1)
#     img_th = th.from_numpy(img_np).to(dist_util.dev())

#     mask = th.zeros(*x.shape, device=dist_util.dev())
#     mask = mask.reshape(-1, 7, 3, image_size, image_size)

#     mask[::2, :, img_th > 0.5] = 1.0
#     mask[1::2, :, img_th < 0.5] = 1.0
#     mask = mask.reshape(-1, 3, image_size, image_size)

#     def replacement(x0, x1):
#         x_mix = x0 * mask + x1 * (1 - mask)
#         return x_mix

#     t_max_rho = t_max ** (1 / rho)
#     t_min_rho = t_min ** (1 / rho)
#     s_in = x.new_ones([x.shape[0]])
#     images = replacement(images, -th.ones_like(images))

#     for i in range(len(ts) - 1):
#         t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
#         x0 = distiller(x, t * s_in)
#         x0 = th.clamp(x0, -1.0, 1.0)
#         x0 = replacement(images, x0)
#         next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
#         next_t = np.clip(next_t, t_min, t_max)
#         x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)

#     return x, images


# @th.no_grad()
# def iterative_superres(
#     distiller,
#     images,
#     x,
#     ts,
#     t_min=0.002,
#     t_max=80.0,
#     rho=7.0,
#     steps=40,
#     generator=None,
# ):
#     patch_size = 8

#     def obtain_orthogonal_matrix():
#         vector = np.asarray([1] * patch_size**2)
#         vector = vector / np.linalg.norm(vector)
#         matrix = np.eye(patch_size**2)
#         matrix[:, 0] = vector
#         matrix = np.linalg.qr(matrix)[0]
#         if np.sum(matrix[:, 0]) < 0:
#             matrix = -matrix
#         return matrix

#     Q = th.from_numpy(obtain_orthogonal_matrix()).to(dist_util.dev()).to(th.float32)

#     image_size = x.shape[-1]

#     def replacement(x0, x1):
#         x0_flatten = (
#             x0.reshape(-1, 3, image_size, image_size)
#             .reshape(
#                 -1,
#                 3,
#                 image_size // patch_size,
#                 patch_size,
#                 image_size // patch_size,
#                 patch_size,
#             )
#             .permute(0, 1, 2, 4, 3, 5)
#             .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
#         )
#         x1_flatten = (
#             x1.reshape(-1, 3, image_size, image_size)
#             .reshape(
#                 -1,
#                 3,
#                 image_size // patch_size,
#                 patch_size,
#                 image_size // patch_size,
#                 patch_size,
#             )
#             .permute(0, 1, 2, 4, 3, 5)
#             .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
#         )
#         x0 = th.einsum("bcnd,de->bcne", x0_flatten, Q)
#         x1 = th.einsum("bcnd,de->bcne", x1_flatten, Q)
#         x_mix = x0.new_zeros(x0.shape)
#         x_mix[..., 0] = x0[..., 0]
#         x_mix[..., 1:] = x1[..., 1:]
#         x_mix = th.einsum("bcne,de->bcnd", x_mix, Q)
#         x_mix = (
#             x_mix.reshape(
#                 -1,
#                 3,
#                 image_size // patch_size,
#                 image_size // patch_size,
#                 patch_size,
#                 patch_size,
#             )
#             .permute(0, 1, 2, 4, 3, 5)
#             .reshape(-1, 3, image_size, image_size)
#         )
#         return x_mix

#     def average_image_patches(x):
#         x_flatten = (
#             x.reshape(-1, 3, image_size, image_size)
#             .reshape(
#                 -1,
#                 3,
#                 image_size // patch_size,
#                 patch_size,
#                 image_size // patch_size,
#                 patch_size,
#             )
#             .permute(0, 1, 2, 4, 3, 5)
#             .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
#         )
#         x_flatten[..., :] = x_flatten.mean(dim=-1, keepdim=True)
#         return (
#             x_flatten.reshape(
#                 -1,
#                 3,
#                 image_size // patch_size,
#                 image_size // patch_size,
#                 patch_size,
#                 patch_size,
#             )
#             .permute(0, 1, 2, 4, 3, 5)
#             .reshape(-1, 3, image_size, image_size)
#         )

#     t_max_rho = t_max ** (1 / rho)
#     t_min_rho = t_min ** (1 / rho)
#     s_in = x.new_ones([x.shape[0]])
#     images = average_image_patches(images)

#     for i in range(len(ts) - 1):
#         t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
#         x0 = distiller(x, t * s_in)
#         x0 = th.clamp(x0, -1.0, 1.0)
#         x0 = replacement(images, x0)
#         next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
#         next_t = np.clip(next_t, t_min, t_max)
#         x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)

#     return x, images