// Text-to-image sampling script (TorchSharp, C# top-level statements).
// Encodes a prompt with the "clip_encoder" method of a TorchScript checkpoint, iterates the
// "ddim_sampler" method over a strided timestep schedule, decodes the latents with
// "decode_image", and writes one PNG per batch element ("0.png", "1.png", ...).
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using TorchSharp;

torchvision.io.DefaultImager = new torchvision.io.SkiaImager();

var device = TorchSharp.torch.device("cuda:0");

// NOTE(review): the argument/return conventions of the scripted methods invoked below are
// defined by the code that exported the checkpoint — confirm them there.
var ddpm_v_sampler = TorchSharp.torch.jit.load("ddim_v_sampler.ckpt");
ddpm_v_sampler.to(device);
ddpm_v_sampler.eval();

// Special start/end token ids and the fixed text-encoder sequence length
// (start + word tokens + end + zero padding).
var start_token = 49406;
var end_token = 49407;
var context_length = 77;

// Minimal hand-written vocabulary: only these words may appear in the prompt.
var dictionary = new Dictionary<string, int>
{
    { "cat", 2368 },
    { "a", 320 },
    { "cute", 2242 },
    { "blue", 1746 },
    { "wild", 3220 },
    { "green", 1901 },
};

var batch = 1;
var prompt = "a wild cute green cat";

if (batch < 1)
{
    throw new InvalidOperationException($"batch must be at least 1, got {batch}.");
}

// RemoveEmptyEntries tolerates repeated spaces instead of looking up "" in the vocabulary.
var words = prompt.Split(' ', StringSplitOptions.RemoveEmptyEntries);

// Report every out-of-vocabulary word at once instead of a bare KeyNotFoundException.
var unknown_words = words.Where(w => !dictionary.ContainsKey(w)).Distinct().ToList();
if (unknown_words.Count > 0)
{
    throw new InvalidOperationException(
        $"Prompt contains words missing from the vocabulary: {string.Join(", ", unknown_words)}");
}

// Two of the context slots are reserved for the start and end tokens; without this check a long
// prompt would make the padding count negative and Enumerable.Repeat would throw.
if (words.Length > context_length - 2)
{
    throw new InvalidOperationException(
        $"Prompt has {words.Length} words; at most {context_length - 2} fit in a {context_length}-token context.");
}

// [start, word tokens..., end, 0, 0, ...], exactly context_length long.
var tokens = new List<int>(context_length) { start_token };
tokens.AddRange(words.Select(w => dictionary[w]));
tokens.Add(end_token);
tokens.AddRange(Enumerable.Repeat(0, context_length - tokens.Count));

// Empty prompt used as the unconditional input: [start, end, 0, 0, ...].
var unconditional_tokens = new List<int>(context_length) { start_token, end_token };
unconditional_tokens.AddRange(Enumerable.Repeat(0, context_length - unconditional_tokens.Count));

// Build one (1, context_length) row and tile it across the batch. Reshaping a single
// 77-element row straight to (batch, -1) only works when batch == 1.
var tokenTensor = torch.tensor(tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device)
    .reshape(1L, -1L)
    .repeat((long)batch, 1L);
var unconditional_tokenTensor = torch.tensor(unconditional_tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device)
    .reshape(1L, -1L)
    .repeat((long)batch, 1L);

// Initial noise of shape (batch, 4, 96, 96), float32.
var img = torch.randn(batch, 4, 96, 96, dtype: torch.ScalarType.Float32, device: device);

var timesteps = 1000;
var ddim_steps = 50;
int gap = timesteps / ddim_steps;

// Everything below is inference only, so it runs with autograd disabled — including the
// text encoding, which previously ran outside this scope.
using (var context = torch.enable_grad(false))
{
    var condition = ddpm_v_sampler.invoke("clip_encoder", tokenTensor);
    var unconditional_condition = ddpm_v_sampler.invoke("clip_encoder", unconditional_tokenTensor);
    Console.WriteLine(condition);

    // Visits t = timesteps-1, timesteps-1-gap, ... while t >= 0; the last step's
    // predecessor timestep is clamped to 0.
    for (var i = timesteps - 1; i >= 0; i -= gap)
    {
        // Dispose per-step tensors eagerly; otherwise their native (GPU) memory is only
        // reclaimed whenever the .NET GC finalizes the wrappers.
        using var t_cur = torch.full(batch, i, dtype: torch.ScalarType.Int64, device: device);
        using var t_prev = torch.full(batch, Math.Max(i - gap, 0), dtype: torch.ScalarType.Int64, device: device);

        var next_img = (torch.Tensor)ddpm_v_sampler.invoke("ddim_sampler", img, condition, unconditional_condition, t_cur, t_prev);
        if (!ReferenceEquals(next_img, img))
        {
            img.Dispose();
        }
        img = next_img;

        Console.WriteLine($"step {i}");
    }

    using var raw_images = (torch.Tensor)ddpm_v_sampler.invoke("decode_image", img);
    // Affine map [-1, 1] -> [0, 1], then clamp (presumably the decoder emits roughly [-1, 1]).
    using var decoded_images = torch.clamp((raw_images + 1.0) / 2.0, 0.0, 1.0);

    for (int i = 0; i != batch; ++i)
    {
        // c * h * w
        // Round to the nearest 8-bit level instead of truncating: values are in [0, 1], so
        // x * 255 + 0.5 lies in [0.5, 255.5] and the Byte cast yields [0, 255].
        using var image = (decoded_images[i] * 255.0 + 0.5).to(torch.ScalarType.Byte).cpu();
        torchvision.io.write_image(image, $"{i}.png", torchvision.ImageFormat.Png);
    }
}