using System;
using TorchSharp;

/// <summary>
/// Wrapper around a TorchScript module that exposes its <c>diffusion_model</c>,
/// <c>q_sample</c>, <c>predict_eps_from_z_and_v</c> and <c>predict_start_from_z_and_v</c>
/// methods as typed C# calls.
/// </summary>
public class DDPM : IDisposable
{
    // Set to null by Dispose(); all calls go through Model so that use after disposal
    // fails with ObjectDisposedException rather than a NullReferenceException.
    private torch.jit.ScriptModule _model;

    /// <summary>The device the loaded module was moved to in the constructor.</summary>
    public torch.Device Device { get; }

    /// <summary>
    /// Loads the TorchScript module at <paramref name="modelPath"/>, moves it to
    /// <paramref name="device"/>, and switches it to evaluation mode.
    /// </summary>
    /// <param name="modelPath">Path of the serialized TorchScript module.</param>
    /// <param name="device">Device the module is moved to.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="modelPath"/> or <paramref name="device"/> is null.
    /// </exception>
    public DDPM(string modelPath, torch.Device device)
    {
        if (modelPath is null)
        {
            throw new ArgumentNullException(nameof(modelPath));
        }

        Device = device ?? throw new ArgumentNullException(nameof(device));
        _model = torch.jit.load(modelPath);
        _model.to(Device);
        _model.eval();
    }

    /// <summary>The loaded module, or an exception if this instance was disposed.</summary>
    private torch.jit.ScriptModule Model =>
        _model ?? throw new ObjectDisposedException(nameof(DDPM));

    /// <summary>
    /// Invokes the TorchScript method <paramref name="method"/> and returns its result as a tensor.
    /// </summary>
    /// <remarks>
    /// <c>ScriptModule.invoke</c> is untyped (it returns <c>object</c>), so the result is cast here.
    /// Assumes each invoked script method returns a single tensor; TODO confirm against the exported model.
    /// </remarks>
    private torch.Tensor InvokeTensor(string method, params torch.Tensor[] args) =>
        (torch.Tensor)Model.invoke(method, args);

    /// <summary>
    /// Runs the module's <c>diffusion_model</c> once on a doubled batch and splits the result.
    /// <paramref name="img"/> and <paramref name="t"/> are each concatenated with themselves along
    /// dim 0. <paramref name="unconditional_condition"/> and <paramref name="condition"/> are
    /// concatenated along dim 0, unconditional first. The output is split into two chunks along dim 0.
    /// </summary>
    /// <remarks>
    /// This looks like the batching used for classifier-free guidance; confirm with the caller.
    /// Inputs are not moved to <see cref="Device"/>. Presumably the caller passes tensors that are
    /// already there.
    /// The concatenated inputs and the combined output are disposed here, so their native memory is
    /// released immediately rather than at GC time. The returned chunks are separate tensor handles
    /// and remain valid.
    /// </remarks>
    /// <returns>
    /// The first half of the output (paired with <paramref name="unconditional_condition"/>) and the
    /// second half (paired with <paramref name="condition"/>).
    /// </returns>
    public (torch.Tensor e_T_Uncondition, torch.Tensor e_T) DiffusionModel(
        torch.Tensor img,
        torch.Tensor condition,
        torch.Tensor unconditional_condition,
        torch.Tensor t)
    {
        using (var xIn = torch.cat(new[] { img, img }))
        using (var conditionIn = torch.cat(new[] { unconditional_condition, condition }))
        using (var tIn = torch.cat(new[] { t, t }))
        using (var output = InvokeTensor("diffusion_model", xIn, tIn, conditionIn))
        {
            var halves = output.chunk(2);
            return (halves[0], halves[1]);
        }
    }

    /// <summary>Invokes the module's <c>q_sample</c> method with (z, t, v).</summary>
    public torch.Tensor QSample(torch.Tensor z, torch.Tensor t, torch.Tensor v) =>
        InvokeTensor("q_sample", z, t, v);

    /// <summary>Invokes the module's <c>predict_eps_from_z_and_v</c> method with (z, t, v).</summary>
    public torch.Tensor PredictEPSFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v) =>
        InvokeTensor("predict_eps_from_z_and_v", z, t, v);

    /// <summary>Invokes the module's <c>predict_start_from_z_and_v</c> method with (z, t, v).</summary>
    public torch.Tensor PredictStartFromZANDV(torch.Tensor z, torch.Tensor t, torch.Tensor v) =>
        InvokeTensor("predict_start_from_z_and_v", z, t, v);

    /// <summary>Releases the loaded module. Safe to call more than once.</summary>
    public void Dispose()
    {
        var model = _model;
        _model = null;
        model?.Dispose();
    }
}