sudemai's picture
Upload 81 files
55d914b verified
raw
history blame
1.79 kB
import torch
def sample_x0(x1):
    """Draw standard-normal noise matching the layout of ``x1``.

    Args:
        x1: a data tensor ``[batch, *dim]``, or a list/tuple of tensors.

    Returns:
        For a tensor input, a tensor with the same size, dtype and device as
        ``x1`` filled by ``torch.randn_like``.  For a list/tuple input, a
        list containing one such noise tensor per element, in order.
    """
    if not isinstance(x1, (list, tuple)):
        return torch.randn_like(x1)
    return list(map(torch.randn_like, x1))
def sample_timestep(x1, *, logit_mean=0.0, logit_std=1.0):
    """Sample one timestep per batch element from a logit-normal distribution.

    Draws ``u ~ Normal(logit_mean, logit_std)`` and returns ``sigmoid(u)``,
    which maps every sample into (0, 1).  The defaults give the standard
    logit-normal ``sigmoid(N(0, 1))``.

    Args:
        x1: data batch; a tensor ``[batch, *dim]`` or a list/tuple of tensors.
            Only ``len(x1)`` and the dtype/device of ``x1[0]`` are used.
        logit_mean: mean of the normal distribution before the sigmoid.
        logit_std: standard deviation of the normal before the sigmoid.

    Returns:
        1-D tensor of shape ``[len(x1)]`` with the dtype and device of ``x1[0]``.
    """
    u = torch.normal(mean=logit_mean, std=logit_std, size=(len(x1),))
    t = torch.sigmoid(u)
    # Match dtype/device of the data; indexing element 0 works for both a
    # batched tensor and a list/tuple of per-sample tensors.
    return t.to(x1[0])
def training_losses(model, x1, model_kwargs=None, snr_type='uniform'):
    """Compute the velocity-matching (rectified-flow) training loss.

    Draws Gaussian noise ``x0`` (via ``sample_x0``) and per-sample timesteps
    ``t`` (via ``sample_timestep``), forms the linear interpolant
    ``xt = t * x1 + (1 - t) * x0`` and regresses the model output at
    ``(xt, t)`` onto the velocity target ``ut = x1 - x0`` with a squared error.

    Args:
        model: callable invoked as ``model(xt, t, **model_kwargs)``.  Its output
            is compared directly against ``x1 - x0``, so it must predict the
            velocity.  For list/tuple ``x1`` it must return a sequence with one
            prediction per sample.
        x1: data batch; a tensor ``[batch, *dim]`` or a list/tuple of
            per-sample tensors (which may have different shapes).
        model_kwargs: extra keyword arguments forwarded to ``model``.
        snr_type: unused; accepted for backward compatibility.

    Returns:
        dict whose ``"loss"`` entry holds one mean-squared error per sample.

    Raises:
        AssertionError: for list/tuple ``x1``, if the model returns a sequence
            whose length differs from the batch size.
    """
    if model_kwargs is None:
        model_kwargs = {}
    x0 = sample_x0(x1)
    t = sample_timestep(x1)
    is_sequence = isinstance(x1, (list, tuple))
    if is_sequence:
        xt = [t_i * x1_i + (1 - t_i) * x0_i for t_i, x1_i, x0_i in zip(t, x1, x0)]
        ut = [x1_i - x0_i for x1_i, x0_i in zip(x1, x0)]
    else:
        # Reshape t to [batch, 1, ..., 1] so it broadcasts over x1's trailing dims.
        t_ = t.view(t.size(0), *([1] * (x1.dim() - 1)))
        xt = t_ * x1 + (1 - t_) * x0
        ut = x1 - x0
    model_output = model(xt, t, **model_kwargs)
    terms = {}
    if is_sequence:
        assert len(model_output) == len(ut) == len(x1)
        # Samples may differ in shape, so reduce each to a scalar MSE before
        # stacking into a [batch] tensor; computed once, not once per sample.
        terms["loss"] = torch.stack(
            [((u_i - out_i) ** 2).mean() for u_i, out_i in zip(ut, model_output)],
            dim=0,
        )
    else:
        terms["loss"] = mean_flat((model_output - ut) ** 2)
    return terms
def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.

    Args:
        x: tensor whose first dimension is the batch dimension.

    Returns:
        Tensor of shape ``[x.shape[0]]`` holding the per-sample mean.  A tensor
        with at most one dimension has no non-batch dimensions to average over
        and is returned unchanged.
    """
    if x.dim() <= 1:
        # torch.mean with an empty ``dim`` list reduces over every dimension,
        # which would collapse a 1-D batch to a single scalar.
        return x
    return torch.mean(x, dim=list(range(1, x.dim())))