yyk19 commited on
Commit
bc1f1f4
1 Parent(s): 9c6e0e6

Transfer the EMA checkpoint to the clean version.

Browse files
Files changed (5) hide show
  1. app.py +4 -1
  2. config.yaml +2 -2
  3. config_cuda.yaml +88 -0
  4. model.ckpt +3 -0
  5. transfer.py +14 -0
app.py CHANGED
@@ -66,8 +66,11 @@ def process_multi_wrapper_only_show_rendered(rendered_txt_0, rendered_txt_1, ren
66
  shared_eta, shared_a_prompt, shared_n_prompt,
67
  only_show_rendered_image=True)
68
 
 
 
 
69
  cfg = OmegaConf.load("config.yaml")
70
- model = load_model_from_config(cfg, "model_states.pt", verbose=True)
71
 
72
  ddim_sampler = DDIMSampler(model)
73
  render_tool = Render_Text(model)
 
66
  shared_eta, shared_a_prompt, shared_n_prompt,
67
  only_show_rendered_image=True)
68
 
69
+ # cfg = OmegaConf.load("config.yaml")
70
+ # model = load_model_from_config(cfg, "model_states.pt", verbose=True)
71
+
72
  cfg = OmegaConf.load("config.yaml")
73
+ model = load_model_from_config(cfg, "model.ckpt", verbose=True)
74
 
75
  ddim_sampler = DDIMSampler(model)
76
  render_tool = Render_Text(model)
config.yaml CHANGED
@@ -18,7 +18,7 @@ model:
18
  scale_factor: 0.18215
19
  only_mid_control: False
20
  sd_locked: True
21
- use_ema: False #True #False #True #False
22
 
23
  control_stage_config:
24
  target: cldm.cldm.ControlNet
@@ -85,4 +85,4 @@ model:
85
  params:
86
  freeze: True
87
  layer: "penultimate"
88
- device: "cpu"
 
18
  scale_factor: 0.18215
19
  only_mid_control: False
20
  sd_locked: True
21
+ use_ema: False #True #TODO: specify
22
 
23
  control_stage_config:
24
  target: cldm.cldm.ControlNet
 
85
  params:
86
  freeze: True
87
  layer: "penultimate"
88
+ device: "cpu" #TODO: specify
config_cuda.yaml ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ base_learning_rate: 1.0e-6 #1.0e-5 #1.0e-4
3
+ target: cldm.cldm.ControlLDM
4
+ params:
5
+ linear_start: 0.00085
6
+ linear_end: 0.0120
7
+ num_timesteps_cond: 1
8
+ log_every_t: 200
9
+ timesteps: 1000
10
+ first_stage_key: "jpg"
11
+ cond_stage_key: "txt"
12
+ control_key: "hint"
13
+ image_size: 64
14
+ channels: 4
15
+ cond_stage_trainable: false
16
+ conditioning_key: crossattn
17
+ monitor: #val/loss_simple_ema
18
+ scale_factor: 0.18215
19
+ only_mid_control: False
20
+ sd_locked: True
21
+ use_ema: True #TODO: specify
22
+
23
+ control_stage_config:
24
+ target: cldm.cldm.ControlNet
25
+ params:
26
+ use_checkpoint: True
27
+ image_size: 32 # unused
28
+ in_channels: 4
29
+ hint_channels: 3
30
+ model_channels: 320
31
+ attention_resolutions: [ 4, 2, 1 ]
32
+ num_res_blocks: 2
33
+ channel_mult: [ 1, 2, 4, 4 ]
34
+ num_head_channels: 64 # need to fix for flash-attn
35
+ use_spatial_transformer: True
36
+ use_linear_in_transformer: True
37
+ transformer_depth: 1
38
+ context_dim: 1024
39
+ legacy: False
40
+
41
+ unet_config:
42
+ target: cldm.cldm.ControlledUnetModel
43
+ params:
44
+ use_checkpoint: True
45
+ image_size: 32 # unused
46
+ in_channels: 4
47
+ out_channels: 4
48
+ model_channels: 320
49
+ attention_resolutions: [ 4, 2, 1 ]
50
+ num_res_blocks: 2
51
+ channel_mult: [ 1, 2, 4, 4 ]
52
+ num_head_channels: 64 # need to fix for flash-attn
53
+ use_spatial_transformer: True
54
+ use_linear_in_transformer: True
55
+ transformer_depth: 1
56
+ context_dim: 1024
57
+ legacy: False
58
+
59
+ first_stage_config:
60
+ target: ldm.models.autoencoder.AutoencoderKL
61
+ params:
62
+ embed_dim: 4
63
+ monitor: val/rec_loss
64
+ ddconfig:
65
+ #attn_type: "vanilla-xformers"
66
+ double_z: true
67
+ z_channels: 4
68
+ resolution: 256
69
+ in_channels: 3
70
+ out_ch: 3
71
+ ch: 128
72
+ ch_mult:
73
+ - 1
74
+ - 2
75
+ - 4
76
+ - 4
77
+ num_res_blocks: 2
78
+ attn_resolutions: []
79
+ dropout: 0.0
80
+ lossconfig:
81
+ target: torch.nn.Identity
82
+
83
+ cond_stage_config:
84
+ target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
85
+ params:
86
+ freeze: True
87
+ layer: "penultimate"
88
+ # device: "cpu" #TODO: specify
model.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5f82f4af7d69b0ffdff6bf3d1b8dc6b13bbf81e28ea0fbacbf68824d2c1f652
3
+ size 8129070351
transfer.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from omegaconf import OmegaConf
2
+ from scripts.rendertext_tool import Render_Text, load_model_from_config
3
+ import torch
4
+ cfg = OmegaConf.load("config_cuda.yaml")
5
+ model = load_model_from_config(cfg, "model_states.pt", verbose=True)
6
+
7
+ from pytorch_lightning.callbacks import ModelCheckpoint
8
+ with model.ema_scope("store ema weights"):
9
+ file_content = {
10
+ 'state_dict': model.state_dict()
11
+ }
12
+ torch.save(file_content, "model.ckpt")
13
+ print("has stored the transfered ckpt.")
14
+ print("trial ends!")