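// Provenance (captured from the hosting page this file was copied from):
//   XiaoYun Zhang (presumably the author of the last commit)
//   last commit: 9c54c90 ("clean up")
//   file size: 1.42 kB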
using System;
using TorchSharp;
/// <summary>
/// Draws samples from a <see cref="DDPM"/> model by visiting a strided subset of
/// <c>TimeSteps</c> diffusion timesteps (DDIM-style), combining conditional and
/// unconditional predictions with classifier-free guidance at every step.
/// </summary>
public class DDIMSampler
{
    // Number of diffusion timesteps the sampling schedule is drawn from; sampling starts at TimeSteps - 1.
    private const int TimeSteps = 1000;

    private readonly DDPM _model;

    // Device on which timestep tensors are allocated (read from the model at construction).
    private readonly torch.Device _device;

    /// <summary>
    /// Creates a sampler bound to <paramref name="model"/>.
    /// </summary>
    /// <param name="model">Model providing the guidance predictions and the per-step update helpers.</param>
    /// <param name="scale">
    /// Unused; kept only for source compatibility. Pass the guidance scale to
    /// <see cref="Sample"/> instead.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public DDIMSampler(DDPM model, float scale = 9.0f)
    {
        _model = model ?? throw new ArgumentNullException(nameof(model));
        _device = model.Device;
    }

    /// <summary>
    /// Runs the sampling loop starting from <paramref name="img"/> and returns the final sample.
    /// </summary>
    /// <remarks>
    /// Timesteps visited: <c>TimeSteps - 1</c>, then downward with stride <c>TimeSteps / steps</c>
    /// while non-negative; the last step always targets t = 0. When <c>TimeSteps</c> is not a
    /// multiple of <paramref name="steps"/>, this is more than <paramref name="steps"/> iterations
    /// (e.g. steps = 3 visits 999, 666, 333, 0).
    /// Gradient tracking is disabled for the whole loop. The caller's <paramref name="img"/> is
    /// never disposed; intermediate samples and per-step temporaries are disposed, and the
    /// returned tensor is moved to the caller's dispose scope (if any).
    /// </remarks>
    /// <param name="img">
    /// Starting sample; its first dimension is used as the batch size
    /// (presumably Gaussian noise in the model's input space — confirm with callers).
    /// </param>
    /// <param name="condition">Conditioning passed to <c>DiffusionModel</c> for the conditional prediction.</param>
    /// <param name="unconditional_condition">Conditioning passed to <c>DiffusionModel</c> for the unconditional prediction.</param>
    /// <param name="steps">Requested number of sampling steps; must be between 1 and <c>TimeSteps</c>.</param>
    /// <param name="scale">
    /// Classifier-free guidance scale: 0 uses only the unconditional prediction, 1 only the
    /// conditional one, larger values extrapolate toward the conditional prediction.
    /// </param>
    /// <returns>The sample produced by the final (t = 0) step, or <paramref name="img"/> itself if no step ran.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="img"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="steps"/> is outside [1, <c>TimeSteps</c>].</exception>
    public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float scale = 9.0f)
    {
        if (img is null)
        {
            throw new ArgumentNullException(nameof(img));
        }

        // steps > TimeSteps would make the stride 0 (infinite loop); steps <= 0 would divide by
        // zero or make the stride negative (timesteps would walk upward forever).
        if (steps < 1 || steps > TimeSteps)
        {
            throw new ArgumentOutOfRangeException(nameof(steps), steps, $"steps must be between 1 and {TimeSteps}.");
        }

        var gap = TimeSteps / steps;
        var batch = img.shape[0];
        var input = img; // caller-owned; never disposed here

        using (torch.enable_grad(false))
        {
            for (var t = TimeSteps - 1; t >= 0; t -= gap)
            {
                torch.Tensor next;

                // Tensors created while this scope is active (timestep tensors, guidance outputs,
                // temporaries created inside the model calls) are released at the end of the step;
                // only the new sample is moved out.
                // NOTE(review): if DDPM lazily creates and caches tensors during these calls, they
                // would be disposed too — confirm it keeps no such per-call state.
                using (torch.NewDisposeScope())
                {
                    var tCur = torch.full(batch, t, dtype: torch.ScalarType.Int64, device: _device);
                    // The final step targets t = 0 even when t - gap would be negative.
                    var tPrev = torch.full(batch, Math.Max(t - gap, 0), dtype: torch.ScalarType.Int64, device: _device);

                    var (eUncond, eCond) = _model.DiffusionModel(img, condition, unconditional_condition, tCur);
                    // Classifier-free guidance.
                    var modelOutput = eUncond + scale * (eCond - eUncond);

                    // NOTE(review): the ZANDV names suggest a v-prediction model whose output is
                    // converted to a noise estimate and an x0 estimate here — confirm in DDPM.
                    var eps = _model.PredictEPSFromZANDV(img, tCur, modelOutput);
                    var predX0 = _model.PredictStartFromZANDV(img, tCur, modelOutput);

                    // Presumably forward-diffuses predX0 to tPrev using eps (a deterministic DDIM
                    // step) — confirm QSample's formula in DDPM.
                    next = _model.QSample(predX0, tPrev, eps).MoveToOuterDisposeScope();
                }

                // Release the previous intermediate sample, but never the caller's input.
                if (!ReferenceEquals(img, input) && !ReferenceEquals(img, next))
                {
                    img.Dispose();
                }

                img = next;
            }
        }

        return img;
    }
}