littlelittlecloud committed on
Commit
7779efa
1 Parent(s): e3192e0

split model into small ones

Browse files
.gitattributes CHANGED
@@ -34,3 +34,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  cat.png filter=lfs diff=lfs merge=lfs -text
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  cat.png filter=lfs diff=lfs merge=lfs -text
37
+ autoencoder_kl.ckpt filter=lfs diff=lfs merge=lfs -text
38
+ clip_encoder.ckpt filter=lfs diff=lfs merge=lfs -text
AutoencoderKL.cs ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
using TorchSharp;

/// <summary>
/// Wraps a TorchScript KL-autoencoder decoder model: rescales a latent tensor
/// by 1/scale and runs it through the scripted module to produce image-space
/// output. The model is loaded once, moved to the given device, and put in
/// eval mode.
/// </summary>
public class AutoencoderKL
{
    private readonly torch.jit.ScriptModule _model;
    // Latent scaling factor applied as 1/_scale before decoding.
    // 0.18215f is the conventional Stable Diffusion latent scale.
    private readonly float _scale;

    /// <summary>Device the scripted model was moved to.</summary>
    public torch.Device Device { get; }

    /// <summary>
    /// Loads the scripted decoder from <paramref name="modelPath"/>, moves it
    /// to <paramref name="device"/>, and switches it to eval mode.
    /// </summary>
    /// <param name="modelPath">Path to the TorchScript checkpoint.</param>
    /// <param name="device">Target device (e.g. cuda:0).</param>
    /// <param name="scale">Latent scale; input is multiplied by 1/scale.</param>
    public AutoencoderKL(string modelPath, torch.Device device, float scale = 0.18215f)
    {
        _model = TorchSharp.torch.jit.load(modelPath);
        Device = device;
        _model.to(Device);
        _model.eval();
        _scale = scale;
    }

    /// <summary>
    /// Decodes a latent tensor into image space (inference only).
    /// </summary>
    /// <param name="tokenTensor">Latent tensor to decode.</param>
    /// <returns>The decoded tensor produced by the scripted model.</returns>
    public torch.Tensor Forward(torch.Tensor tokenTensor)
    {
        // FIX: the grad-disabling guard was previously assigned to a local and
        // never disposed, leaking the scope and leaving autograd disabled
        // beyond this call. Scope it with `using`, matching DDIMSampler.
        using (var context = torch.enable_grad(false))
        {
            tokenTensor = 1.0f / _scale * tokenTensor;
            return (torch.Tensor)_model.forward(tokenTensor);
        }
    }
}
ClipEnocder.cs ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
using TorchSharp;

/// <summary>
/// Wraps a TorchScript CLIP text-encoder module: loads the scripted model,
/// moves it to the given device, puts it in eval mode, and exposes a simple
/// Forward pass over a token tensor.
/// </summary>
public class ClipEncoder
{
    private readonly torch.jit.ScriptModule _model;

    /// <summary>Device the scripted model was moved to.</summary>
    public torch.Device Device { get; }

    /// <summary>
    /// Loads the scripted encoder from <paramref name="modelPath"/>, moves it
    /// to <paramref name="device"/>, and switches it to eval mode.
    /// </summary>
    /// <param name="modelPath">Path to the TorchScript checkpoint.</param>
    /// <param name="device">Target device (e.g. cuda:0).</param>
    public ClipEncoder(string modelPath, torch.Device device)
    {
        _model = TorchSharp.torch.jit.load(modelPath);
        Device = device;
        _model.to(Device);
        _model.eval();
    }

    /// <summary>
    /// Encodes a token tensor into its conditioning embedding (inference only).
    /// </summary>
    /// <param name="tokenTensor">Int64 token ids to encode.</param>
    /// <returns>The embedding tensor produced by the scripted model.</returns>
    public torch.Tensor Forward(torch.Tensor tokenTensor)
    {
        // Disable autograd during inference for consistency with
        // AutoencoderKL.Forward and DDIMSampler.Sample, and to avoid
        // building an unused gradient graph.
        using (var context = torch.enable_grad(false))
        {
            return (torch.Tensor)_model.forward(tokenTensor);
        }
    }
}
DDIMSampler.cs CHANGED
@@ -1,3 +1,4 @@
 
1
  using TorchSharp;
2
 
3
  public class DDIMSampler
@@ -15,17 +16,20 @@ public class DDIMSampler
15
  public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float scale = 9.0f)
16
  {
17
  var gap = DDIMSampler.TIME_STEPS / steps;
 
 
18
  using(var context = torch.enable_grad(false))
19
  {
20
  for(var i = DDIMSampler.TIME_STEPS-1; i >=0; i -= gap)
21
  {
22
- var t_cur = torch.full(1, i, dtype: torch.ScalarType.Int64, device: _device);
23
- var t_prev = torch.full(1, i - gap >= 0? i - gap: 0, dtype: torch.ScalarType.Int64, device: _device);
24
  (var e_t_uncond, var e_t) = _model.DiffusionModel(img, condition, unconditional_condition, t_cur);
25
  var model_output = e_t_uncond + scale * (e_t - e_t_uncond);
26
  e_t = _model.PredictEPSFromZANDV(img, t_cur, model_output);
27
  var pred_x0 = _model.PredictStartFromZANDV(img, t_cur, model_output);
28
  img = _model.QSample(pred_x0, t_prev, e_t);
 
29
  }
30
 
31
  return img;
 
1
+ using System;
2
  using TorchSharp;
3
 
4
  public class DDIMSampler
 
16
  public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float scale = 9.0f)
17
  {
18
  var gap = DDIMSampler.TIME_STEPS / steps;
19
+ var batch = img.shape[0];
20
+
21
  using(var context = torch.enable_grad(false))
22
  {
23
  for(var i = DDIMSampler.TIME_STEPS-1; i >=0; i -= gap)
24
  {
25
+ var t_cur = torch.full(batch, i, dtype: torch.ScalarType.Int64, device: _device);
26
+ var t_prev = torch.full(batch, i - gap >= 0? i - gap: 0, dtype: torch.ScalarType.Int64, device: _device);
27
  (var e_t_uncond, var e_t) = _model.DiffusionModel(img, condition, unconditional_condition, t_cur);
28
  var model_output = e_t_uncond + scale * (e_t - e_t_uncond);
29
  e_t = _model.PredictEPSFromZANDV(img, t_cur, model_output);
30
  var pred_x0 = _model.PredictStartFromZANDV(img, t_cur, model_output);
31
  img = _model.QSample(pred_x0, t_prev, e_t);
32
+ Console.WriteLine(img);
33
  }
34
 
35
  return img;
DDPM.cs CHANGED
@@ -1,3 +1,4 @@
 
1
  using TorchSharp;
2
 
3
  public class DDPM
@@ -21,16 +22,6 @@ public class DDPM
21
  return (res[0], res[1]);
22
  }
23
 
24
- public torch.Tensor DecodeImage(torch.Tensor img)
25
- {
26
- return _model.invoke<torch.Tensor>("decode_image", img);
27
- }
28
-
29
- public torch.Tensor ClipEncoder(torch.Tensor tokenTensor)
30
- {
31
- return _model.invoke<torch.Tensor>("clip_encoder", tokenTensor);
32
- }
33
-
34
  public torch.Tensor QSample(torch.Tensor z, torch.Tensor t, torch.Tensor v)
35
  {
36
  return _model.invoke<torch.Tensor>("q_sample",z, t, v);
 
1
+ using System;
2
  using TorchSharp;
3
 
4
  public class DDPM
 
22
  return (res[0], res[1]);
23
  }
24
 
 
 
 
 
 
 
 
 
 
 
25
  public torch.Tensor QSample(torch.Tensor z, torch.Tensor t, torch.Tensor v)
26
  {
27
  return _model.invoke<torch.Tensor>("q_sample",z, t, v);
Program.cs CHANGED
@@ -8,7 +8,8 @@ torchvision.io.DefaultImager = new torchvision.io.SkiaImager();
8
  var device = TorchSharp.torch.device("cuda:0");
9
  var ddpm = new DDPM("ddim_v_sampler.ckpt", device);
10
  var ddimSampler = new DDIMSampler(ddpm);
11
-
 
12
  var start_token = 49406;
13
  var end_token = 49407;
14
  var dictionary = new Dictionary<string, long>(){
@@ -20,7 +21,7 @@ var dictionary = new Dictionary<string, long>(){
20
  {"green", 1901},
21
  };
22
 
23
- var batch = 1;
24
 
25
  var prompt = "a wild cute green cat";
26
  var tokens = prompt.Split(' ').Select(x => dictionary[x]).ToList();
@@ -29,17 +30,21 @@ tokens = tokens.Append(end_token).ToList();
29
  tokens = tokens.Concat(Enumerable.Repeat<long>(0, 77 - tokens.Count)).ToList();
30
  var uncontional_tokens = new[]{start_token, end_token}.Concat(Enumerable.Repeat(0, 75)).ToList();
31
  var tokenTensor = torch.tensor(tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device);
32
- tokenTensor = tokenTensor.reshape((long)batch, -1);
33
  var unconditional_tokenTensor = torch.tensor(uncontional_tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device);
34
- unconditional_tokenTensor = unconditional_tokenTensor.reshape((long)batch, -1);
35
  var img = torch.randn(batch, 4, 96, 96, dtype: torch.ScalarType.Float32, device: device);
36
- var t = torch.ones(batch, dtype: torch.ScalarType.Int32, device: device);
37
- var condition = ddpm.ClipEncoder(tokenTensor);
38
- var unconditional_condition = ddpm.ClipEncoder(unconditional_tokenTensor);
39
  var ddim_steps = 50;
40
  img = ddimSampler.Sample(img, condition, unconditional_condition, ddim_steps);
41
- var decoded_images = (torch.Tensor)ddpm.DecodeImage(img);
42
  decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
43
- var image = decoded_images[0];
44
- image = (image * 255.0).to(torch.ScalarType.Byte).cpu();
45
- torchvision.io.write_image(image, $"0.png", torchvision.ImageFormat.Png);
 
 
 
 
 
8
  var device = TorchSharp.torch.device("cuda:0");
9
  var ddpm = new DDPM("ddim_v_sampler.ckpt", device);
10
  var ddimSampler = new DDIMSampler(ddpm);
11
+ var autoencoderKL = new AutoencoderKL("autoencoder_kl.ckpt", device);
12
+ var clipEncoder = new ClipEncoder("clip_encoder.ckpt", device);
13
  var start_token = 49406;
14
  var end_token = 49407;
15
  var dictionary = new Dictionary<string, long>(){
 
21
  {"green", 1901},
22
  };
23
 
24
+ var batch = 2;
25
 
26
  var prompt = "a wild cute green cat";
27
  var tokens = prompt.Split(' ').Select(x => dictionary[x]).ToList();
 
30
  tokens = tokens.Concat(Enumerable.Repeat<long>(0, 77 - tokens.Count)).ToList();
31
  var uncontional_tokens = new[]{start_token, end_token}.Concat(Enumerable.Repeat(0, 75)).ToList();
32
  var tokenTensor = torch.tensor(tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device);
33
+ tokenTensor = tokenTensor.repeat(batch, 1);
34
  var unconditional_tokenTensor = torch.tensor(uncontional_tokens.ToArray(), dtype: torch.ScalarType.Int64, device: device);
35
+ unconditional_tokenTensor = unconditional_tokenTensor.repeat(batch, 1);
36
  var img = torch.randn(batch, 4, 96, 96, dtype: torch.ScalarType.Float32, device: device);
37
+ var t = torch.full(new[]{batch, 1L}, value: batch, dtype: torch.ScalarType.Int32, device: device);
38
+ var condition = clipEncoder.Forward(tokenTensor);
39
+ var unconditional_condition = clipEncoder.Forward(unconditional_tokenTensor);
40
  var ddim_steps = 50;
41
  img = ddimSampler.Sample(img, condition, unconditional_condition, ddim_steps);
42
+ var decoded_images = (torch.Tensor)autoencoderKL.Forward(img);
43
  decoded_images = torch.clamp((decoded_images + 1.0) / 2.0, 0.0, 1.0);
44
+
45
+ for(int i = 0; i!= batch; ++i)
46
+ {
47
+ var image = decoded_images[i];
48
+ image = (image * 255.0).to(torch.ScalarType.Byte).cpu();
49
+ torchvision.io.write_image(image, $"{i}.png", torchvision.ImageFormat.Png);
50
+ }
autoencoder_kl.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f5b15ed1a0f81a0ec4a274ac368a5f4fb84f0ce7c3676e683de527e69a59840
3
+ size 334940269
clip_encoder.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ef9706f02a78b2cf93acff22f3036bc3e629d0a5b595c640ada1f73788826f37
3
+ size 1416615515
ddim_v_sampler.ckpt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:22b16b2fc18c3b20c0eb74ed49a8f1834388fbfd84a49110340943f22fd30fa1
3
- size 5216915007
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffa5c521f78e160bb4907a197f8308fa498f21bc3738ff49aded45afe9dbc47d
3
+ size 3465251643