shivakanthsujit committed on
Commit
8cc1b26
1 Parent(s): 42c1704

Init Commit

Browse files
Files changed (4) hide show
  1. README.md +3 -2
  2. config.yaml +97 -15
  3. model/env_stats.pickle +2 -2
  4. model/model.pth +2 -2
README.md CHANGED
@@ -6,8 +6,9 @@ tags:
6
  - reinforcement-learning
7
  - mbrl-lib
8
  ---
9
- # **OneDTransitionRewardModel** Agent playing **mbrl-continuous-cartpole**
10
- This is a trained model of a **OneDTransitionRewardModel** agent playing **mbrl-continuous-cartpole**
 
11
  using [MBRL-Lib](https://github.com/facebookresearch/mbrl-lib).
12
 
13
  ## Usage (with MBRL-Lib)
 
6
  - reinforcement-learning
7
  - mbrl-lib
8
  ---
9
+ # **OneDTransitionRewardModel w/ SACAgent** Agent playing **mbrl-continuous-cartpole**
10
+ This is a trained model of a **OneDTransitionRewardModel w/ SACAgent** agent
11
+ playing **mbrl-continuous-cartpole**
12
  using [MBRL-Lib](https://github.com/facebookresearch/mbrl-lib).
13
 
14
  ## Usage (with MBRL-Lib)
config.yaml CHANGED
@@ -1,20 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  dynamics_model:
2
  _target_: mbrl.models.GaussianMLP
3
- num_layers: 3
4
- ensemble_size: 5
5
- device: cpu
 
 
6
  hid_size: 200
7
  in_size: 5
8
- out_size: 4
9
- deterministic: false
10
- propagation_method: fixed_model
11
- activation_fn_cfg:
12
- _target_: torch.nn.LeakyReLU
13
- negative_slope: 0.01
14
- algorithm:
15
- learned_rewards: false
16
- target_is_delta: true
17
- normalize: true
18
  overrides:
19
- model_batch_size: 32
20
- validation_ratio: 0.05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ action_optimizer:
2
+ _target_: mbrl.planning.CEMOptimizer
3
+ alpha: 0.1
4
+ clipped_normal: false
5
+ device: cpu:0
6
+ elite_ratio: 0.1
7
+ lower_bound: ???
8
+ num_iterations: 5
9
+ population_size: 350
10
+ return_mean_elites: true
11
+ upper_bound: ???
12
+ algorithm:
13
+ agent:
14
+ _target_: mbrl.third_party.pytorch_sac_pranz24.sac.SAC
15
+ action_space:
16
+ _target_: gym.env.Box
17
+ high:
18
+ - 1.0
19
+ low:
20
+ - -1.0
21
+ shape:
22
+ - 1
23
+ args:
24
+ alpha: 0.2
25
+ automatic_entropy_tuning: true
26
+ device: cpu:0
27
+ gamma: 0.99
28
+ hidden_size: 256
29
+ lr: 0.0003
30
+ policy: Gaussian
31
+ target_entropy: -0.05
32
+ target_update_interval: 4
33
+ tau: 0.005
34
+ num_inputs: 4
35
+ freq_train_model: 200
36
+ initial_exploration_steps: 5000
37
+ learned_rewards: true
38
+ name: mbpo
39
+ normalize: true
40
+ normalize_double_precision: true
41
+ num_eval_episodes: 1
42
+ random_initial_explore: false
43
+ real_data_ratio: 0.0
44
+ sac_samples_action: true
45
+ target_is_delta: true
46
+ debug_mode: false
47
+ device: cpu:0
48
  dynamics_model:
49
  _target_: mbrl.models.GaussianMLP
50
+ activation_fn_cfg:
51
+ _target_: torch.nn.SiLU
52
+ deterministic: false
53
+ device: cpu:0
54
+ ensemble_size: 7
55
  hid_size: 200
56
  in_size: 5
57
+ learn_logvar_bounds: false
58
+ num_layers: 4
59
+ out_size: 5
60
+ propagation_method: random_model
61
+ experiment: default
62
+ log_frequency_agent: 1000
 
 
 
 
63
  overrides:
64
+ cem_alpha: 0.1
65
+ cem_clipped_normal: false
66
+ cem_elite_ratio: 0.1
67
+ cem_num_iters: 5
68
+ cem_population_size: 350
69
+ effective_model_rollouts_per_step: 400
70
+ env: cartpole_continuous
71
+ epoch_length: 200
72
+ freq_train_model: 200
73
+ model_batch_size: 256
74
+ model_lr: 0.001
75
+ model_wd: 5.0e-05
76
+ num_elites: 5
77
+ num_epochs_to_retain_sac_buffer: 1
78
+ num_sac_updates_per_step: 20
79
+ num_steps: 5000
80
+ patience: 5
81
+ planning_horizon: 15
82
+ rollout_schedule:
83
+ - 1
84
+ - 15
85
+ - 1
86
+ - 1
87
+ sac_alpha: 0.2
88
+ sac_automatic_entropy_tuning: true
89
+ sac_batch_size: 256
90
+ sac_gamma: 0.99
91
+ sac_hidden_size: 256
92
+ sac_lr: 0.0003
93
+ sac_policy: Gaussian
94
+ sac_target_entropy: -0.05
95
+ sac_target_update_interval: 4
96
+ sac_tau: 0.005
97
+ sac_updates_every_steps: 1
98
+ trial_length: 200
99
+ validation_ratio: 0.2
100
+ root_dir: ./logs
101
+ save_video: false
102
+ seed: 0
model/env_stats.pickle CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:69554307f755e9042cf296a7c89d744651bea8ae81d226f8e3150b50f8e8ac01
3
- size 238
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:41c1ac53edc417a20114e12671aea2434e2e9b5125ebfef999e87c267d9fb5c8
3
+ size 278
model/model.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ae823dd889cf4dc2f9c541ba0811c9f9ccb70575c7b768dfc01f35a5cc1073b9
3
- size 1667439
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a51a85c2bd7f4b05886da8820eb36cc4032084cd2acf14c0ae1d579ccfa9b2dc
3
+ size 3470565