Upload 3 files
- config.json +9 -3
- config.yaml +14 -9
- model.safetensors +2 -2
config.json CHANGED
@@ -1,7 +1,7 @@
 {
     "n_obs_steps": 5,
     "n_action_pred_token": 7,
-    "
+    "action_chunk_size": 5,
     "input_shapes": {
         "observation.image": [
             3,
@@ -33,16 +33,22 @@
     "pretrained_backbone_weights": null,
     "use_group_norm": true,
     "spatial_softmax_num_keypoints": 32,
-    "
+    "n_vqvae_training_steps": 20000,
     "vqvae_groups": 2,
     "vqvae_n_embed": 16,
     "vqvae_embedding_dim": 256,
+    "vqvae_enc_hidden_dim": 128,
     "gpt_block_size": 500,
+    "gpt_input_dim": 512,
     "gpt_output_dim": 512,
     "gpt_n_layer": 8,
     "gpt_n_head": 8,
     "gpt_hidden_dim": 512,
     "dropout": 0.1,
     "mlp_hidden_dim": 1024,
-    "offset_loss_weight": 10000.0
+    "offset_loss_weight": 10000.0,
+    "primary_code_loss_weight": 5.0,
+    "secondary_code_loss_weight": 0.5,
+    "bet_softmax_temperature": 0.1,
+    "sequentially_select": false
 }
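The config.json diff adds the newer VQ-BeT keys (`action_chunk_size`, `n_vqvae_training_steps`, `vqvae_enc_hidden_dim`, `gpt_input_dim`, the code-loss weights, `bet_softmax_temperature`, `sequentially_select`). A quick way to sanity-check the uploaded file is to download it and assert those keys are present. A minimal sketch, assuming `huggingface_hub` is installed; the repo id below is a placeholder, not this repository's actual name:

```python
import json

from huggingface_hub import hf_hub_download

# Placeholder repo id; substitute the actual repository name.
path = hf_hub_download(repo_id="user/vqbet_pusht", filename="config.json")
with open(path) as f:
    cfg = json.load(f)

# Keys introduced by this commit, per the diff above.
new_keys = [
    "action_chunk_size",
    "n_vqvae_training_steps",
    "vqvae_enc_hidden_dim",
    "gpt_input_dim",
    "primary_code_loss_weight",
    "secondary_code_loss_weight",
    "bet_softmax_temperature",
    "sequentially_select",
]
missing = [k for k in new_keys if k not in cfg]
assert not missing, f"config.json is missing expected keys: {missing}"
```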
config.yaml CHANGED
@@ -1,3 +1,4 @@
+resume: false
 device: cuda
 use_amp: false
 seed: 100000
@@ -11,7 +12,8 @@ training:
   eval_freq: 20000
   save_freq: 20000
   log_freq: 250
-
+  save_checkpoint: true
+  num_workers: 4
   batch_size: 64
   grad_clip_norm: 10
   lr: 0.0001
@@ -23,7 +25,7 @@ training:
   adam_eps: 1.0e-08
   adam_weight_decay: 1.0e-06
   vqvae_lr: 0.001
-
+  n_vqvae_training_steps: 20000
   bet_weight_decay: 0.0002
   bet_learning_rate: 5.5e-05
   bet_betas:
@@ -71,13 +73,16 @@ fps: 10
 env:
   name: pusht
   task: PushT-v0
-  from_pixels: true
-  pixels_only: false
   image_size: 96
-  episode_length: 300
-  fps: ${fps}
   state_dim: 2
   action_dim: 2
+  fps: ${fps}
+  episode_length: 300
+  gym:
+    obs_type: pixels_agent_pos
+    render_mode: rgb_array
+    visualization_width: 384
+    visualization_height: 384
 override_dataset_stats:
   observation.image:
     mean:
@@ -106,7 +111,7 @@ policy:
   name: vqbet
   n_obs_steps: 5
   n_action_pred_token: 7
-
+  action_chunk_size: 5
   input_shapes:
     observation.image:
       - 3
@@ -130,7 +135,7 @@ policy:
   pretrained_backbone_weights: null
   use_group_norm: true
   spatial_softmax_num_keypoints: 32
-  discretize_step: ${training.
+  discretize_step: ${training.n_vqvae_training_steps}
   vqvae_groups: 2
   vqvae_n_embed: 16
   vqvae_embedding_dim: 256
@@ -141,10 +146,10 @@ policy:
   gpt_n_layer: 8
   gpt_n_head: 8
   gpt_hidden_dim: 512
-  gpt_num_obs_mode: 2
   dropout: 0.1
   mlp_hidden_dim: 1024
   offset_loss_weight: 10000.0
   primary_code_loss_weight: 5.0
   secondary_code_loss_weight: 0.5
   bet_softmax_temperature: 0.01
+  sequentially_select: false
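The notable wiring change in config.yaml is that `policy.discretize_step` now reads `${training.n_vqvae_training_steps}`, so the VQ-VAE pretraining budget is defined once under `training:` and referenced from the policy section. A minimal sketch of how that interpolation resolves, assuming the YAML is loaded with OmegaConf as in Hydra-style setups:

```python
from omegaconf import OmegaConf

# Toy config mirroring the two entries wired together in this commit.
cfg = OmegaConf.create(
    """
training:
  n_vqvae_training_steps: 20000
policy:
  discretize_step: ${training.n_vqvae_training_steps}
"""
)

# The interpolation resolves on access, so the policy always sees the
# value set under training: and the two entries cannot drift apart.
assert cfg.policy.discretize_step == 20000
```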
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:df4f83724946d68ed33205379f7d82a9fb152577c330d5de86e8b270761f37f3
+size 158154222
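Because model.safetensors is tracked with Git LFS, the diff only touches the pointer file: the `oid sha256:` line records the content hash of the weights blob and `size` its byte count. A minimal sketch for verifying a downloaded copy against the updated pointer values; the local file path is an assumption:

```python
import hashlib
import os

# Values from the updated LFS pointer above.
EXPECTED_OID = "df4f83724946d68ed33205379f7d82a9fb152577c330d5de86e8b270761f37f3"
EXPECTED_SIZE = 158154222

path = "model.safetensors"  # assumed local download location

# Cheap check first: the byte count from the pointer's `size` line.
assert os.path.getsize(path) == EXPECTED_SIZE, "size mismatch"

# Then the sha256 from the pointer's `oid` line, hashed in 1 MiB chunks.
sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)
assert sha.hexdigest() == EXPECTED_OID, "sha256 mismatch"
```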