vwxyzjn commited on
Commit
4e85d86
1 Parent(s): 06c1c03

pushing model

Browse files
.gitattributes CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  videos/BeamRider-v5__cleanba_ppo_envpool_impala_atari_wrapper__2__ec54eea3-e012-4a37-b5e0-2d1f851254c8-eval/0.mp4 filter=lfs diff=lfs merge=lfs -text
36
  replay.mp4 filter=lfs diff=lfs merge=lfs -text
37
  cleanba_ppo_envpool_impala_atari_wrapper.cleanrl_model filter=lfs diff=lfs merge=lfs -text
 
 
35
  videos/BeamRider-v5__cleanba_ppo_envpool_impala_atari_wrapper__2__ec54eea3-e012-4a37-b5e0-2d1f851254c8-eval/0.mp4 filter=lfs diff=lfs merge=lfs -text
36
  replay.mp4 filter=lfs diff=lfs merge=lfs -text
37
  cleanba_ppo_envpool_impala_atari_wrapper.cleanrl_model filter=lfs diff=lfs merge=lfs -text
38
+ videos/BeamRider-v5__cleanba_ppo_envpool_impala_atari_wrapper__2__6f600660-2e9e-43ae-be8d-4fdd4ef75e94-eval/0.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -16,7 +16,7 @@ model-index:
16
  type: BeamRider-v5
17
  metrics:
18
  - type: mean_reward
19
- value: 34678.80 +/- 8978.15
20
  name: mean_reward
21
  verified: false
22
  ---
@@ -46,7 +46,7 @@ curl -OL https://huggingface.co/cleanrl/BeamRider-v5-cleanba_ppo_envpool_impala_
46
  curl -OL https://huggingface.co/cleanrl/BeamRider-v5-cleanba_ppo_envpool_impala_atari_wrapper-seed2/raw/main/pyproject.toml
47
  curl -OL https://huggingface.co/cleanrl/BeamRider-v5-cleanba_ppo_envpool_impala_atari_wrapper-seed2/raw/main/poetry.lock
48
  poetry install --all-extras
49
- python cleanba_ppo_envpool_impala_atari_wrapper.py --distributed --learner-device-ids 1 2 3 --track --save-model --upload-model --hf-entity cleanrl --env-id BeamRider-v5 --seed 2
50
  ```
51
 
52
  # Hyperparameters
@@ -59,6 +59,7 @@ python cleanba_ppo_envpool_impala_atari_wrapper.py --distributed --learner-devic
59
  'batch_size': 15360,
60
  'capture_video': False,
61
  'clip_coef': 0.1,
 
62
  'cuda': True,
63
  'distributed': True,
64
  'ent_coef': 0.01,
@@ -99,7 +100,7 @@ python cleanba_ppo_envpool_impala_atari_wrapper.py --distributed --learner-devic
99
  'upload_model': True,
100
  'vf_coef': 0.5,
101
  'wandb_entity': None,
102
- 'wandb_project_name': 'cleanRL',
103
  'world_size': 2}
104
  ```
105
 
 
16
  type: BeamRider-v5
17
  metrics:
18
  - type: mean_reward
19
+ value: 38510.80 +/- 14083.44
20
  name: mean_reward
21
  verified: false
22
  ---
 
46
  curl -OL https://huggingface.co/cleanrl/BeamRider-v5-cleanba_ppo_envpool_impala_atari_wrapper-seed2/raw/main/pyproject.toml
47
  curl -OL https://huggingface.co/cleanrl/BeamRider-v5-cleanba_ppo_envpool_impala_atari_wrapper-seed2/raw/main/poetry.lock
48
  poetry install --all-extras
49
+ python cleanba_ppo_envpool_impala_atari_wrapper.py --distributed --learner-device-ids 1 2 3 --track --wandb-project-name cleanba --save-model --upload-model --hf-entity cleanrl --env-id BeamRider-v5 --seed 2
50
  ```
51
 
52
  # Hyperparameters
 
59
  'batch_size': 15360,
60
  'capture_video': False,
61
  'clip_coef': 0.1,
62
+ 'concurrency': True,
63
  'cuda': True,
64
  'distributed': True,
65
  'ent_coef': 0.01,
 
100
  'upload_model': True,
101
  'vf_coef': 0.5,
102
  'wandb_entity': None,
103
+ 'wandb_project_name': 'cleanba',
104
  'world_size': 2}
105
  ```
106
 
cleanba_ppo_envpool_impala_atari_wrapper.cleanrl_model CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:33f41a8906bb907d438ed1b563cdff59832891d28df06d2c8cb64b4918be0020
3
- size 4369294
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fc08a5ce8cf660de0a2f8e04b495bf70504b33b161b66700c1807275eb9df801
3
+ size 4369307
cleanba_ppo_envpool_impala_atari_wrapper.py CHANGED
@@ -1,4 +1,3 @@
1
- # docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/ppo/#ppo_atari_envpool_async_jax_scan_impalanet_machadopy
2
  import argparse
3
  import os
4
  import random
@@ -26,7 +25,7 @@ import numpy as np
26
  import optax
27
  from flax.linen.initializers import constant, orthogonal
28
  from flax.training.train_state import TrainState
29
- from torch.utils.tensorboard import SummaryWriter
30
 
31
 
32
  def parse_args():
@@ -47,7 +46,7 @@ def parse_args():
47
  parser.add_argument("--wandb-entity", type=str, default=None,
48
  help="the entity (team) of wandb's project")
49
  parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
50
- help="weather to capture videos of the agent performances (check out `videos` folder)")
51
  parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
52
  help="whether to save model into the `runs/{run_name}` folder")
53
  parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
@@ -97,6 +96,8 @@ def parse_args():
97
  help="the device ids that learner workers will use")
98
  parser.add_argument("--distributed", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
99
  help="whether to use `jax.distirbuted`")
 
 
100
  parser.add_argument("--profile", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
101
  help="whether to call block_until_ready() for profiling")
102
  parser.add_argument("--test-actor-learner-throughput", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
@@ -213,7 +214,7 @@ class AgentParams:
213
 
214
  @partial(jax.jit, static_argnums=(3))
215
  def get_action_and_value(
216
- params: TrainState,
217
  next_obs: np.ndarray,
218
  key: jax.random.PRNGKey,
219
  action_dim: int,
@@ -281,6 +282,20 @@ def prepare_data(
281
  return b_obs, b_actions, b_logprobs, b_advantages, b_returns
282
 
283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
  def rollout(
285
  key: jax.random.PRNGKey,
286
  args,
@@ -289,7 +304,7 @@ def rollout(
289
  writer,
290
  learner_devices,
291
  ):
292
- envs = make_env(args.env_id, args.seed, args.local_num_envs, args.async_batch_size)()
293
  len_actor_device_ids = len(args.actor_device_ids)
294
  global_step = 0
295
  # TRY NOT TO MODIFY: start the game
@@ -332,9 +347,13 @@ def rollout(
332
  # concurrently with the learning process. It also ensures the actor's policy version is only 1 step
333
  # behind the learner's policy version
334
  params_queue_get_time_start = time.time()
335
- if update != 2:
336
  params = params_queue.get()
337
  actor_policy_version += 1
 
 
 
 
338
  params_queue_get_time.append(time.time() - params_queue_get_time_start)
339
  writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
340
  rollout_time_start = time.time()
@@ -397,18 +416,29 @@ def rollout(
397
  writer.add_scalar("stats/inference_time", inference_time, global_step)
398
  writer.add_scalar("stats/storage_time", storage_time, global_step)
399
  writer.add_scalar("stats/env_send_time", env_send_time, global_step)
 
 
 
 
 
 
 
 
 
 
400
 
401
  payload = (
402
  global_step,
403
  actor_policy_version,
404
  update,
405
  obs,
406
- dones,
407
  values,
408
  actions,
409
  logprobs,
 
410
  env_ids,
411
  rewards,
 
412
  )
413
  if update == 1 or not args.test_actor_learner_throughput:
414
  rollout_queue_put_time_start = time.time()
@@ -717,15 +747,21 @@ if __name__ == "__main__":
717
  actor_policy_version,
718
  update,
719
  obs,
720
- dones,
721
  values,
722
  actions,
723
  logprobs,
 
724
  env_ids,
725
  rewards,
 
726
  ) = rollout_queue.get()
727
  rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
728
  writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
 
 
 
 
 
729
 
730
  data_transfer_time_start = time.time()
731
  b_obs, b_actions, b_logprobs, b_advantages, b_returns = prepare_data(
@@ -780,11 +816,22 @@ if __name__ == "__main__":
780
  break
781
 
782
  if args.save_model and args.local_rank == 0:
 
 
783
  agent_state = flax.jax_utils.unreplicate(agent_state)
784
  model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
785
  with open(model_path, "wb") as f:
786
  f.write(
787
- flax.serialization.to_bytes([ vars(args), [ agent_state.params.network_params, agent_state.params.actor_params, agent_state.params.critic_params, ],])
 
 
 
 
 
 
 
 
 
788
  )
789
  print(f"model saved to {model_path}")
790
  from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate
 
 
1
  import argparse
2
  import os
3
  import random
 
25
  import optax
26
  from flax.linen.initializers import constant, orthogonal
27
  from flax.training.train_state import TrainState
28
+ from tensorboardX import SummaryWriter
29
 
30
 
31
  def parse_args():
 
46
  parser.add_argument("--wandb-entity", type=str, default=None,
47
  help="the entity (team) of wandb's project")
48
  parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
49
+ help="whether to capture videos of the agent performances (check out `videos` folder)")
50
  parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
51
  help="whether to save model into the `runs/{run_name}` folder")
52
  parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
 
96
  help="the device ids that learner workers will use")
97
  parser.add_argument("--distributed", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
98
  help="whether to use `jax.distirbuted`")
99
+ parser.add_argument("--concurrency", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
100
+ help="whether to run the actor and learner concurrently")
101
  parser.add_argument("--profile", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
102
  help="whether to call block_until_ready() for profiling")
103
  parser.add_argument("--test-actor-learner-throughput", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
 
214
 
215
  @partial(jax.jit, static_argnums=(3))
216
  def get_action_and_value(
217
+ params: flax.core.FrozenDict,
218
  next_obs: np.ndarray,
219
  key: jax.random.PRNGKey,
220
  action_dim: int,
 
282
  return b_obs, b_actions, b_logprobs, b_advantages, b_returns
283
 
284
 
285
+ @jax.jit
286
+ def make_bulk_array(
287
+ obs: list,
288
+ values: list,
289
+ actions: list,
290
+ logprobs: list,
291
+ ):
292
+ obs = jnp.asarray(obs)
293
+ values = jnp.asarray(values)
294
+ actions = jnp.asarray(actions)
295
+ logprobs = jnp.asarray(logprobs)
296
+ return obs, values, actions, logprobs
297
+
298
+
299
  def rollout(
300
  key: jax.random.PRNGKey,
301
  args,
 
304
  writer,
305
  learner_devices,
306
  ):
307
+ envs = make_env(args.env_id, args.seed + jax.process_index(), args.local_num_envs, args.async_batch_size)()
308
  len_actor_device_ids = len(args.actor_device_ids)
309
  global_step = 0
310
  # TRY NOT TO MODIFY: start the game
 
347
  # concurrently with the learning process. It also ensures the actor's policy version is only 1 step
348
  # behind the learner's policy version
349
  params_queue_get_time_start = time.time()
350
+ if not args.concurrency:
351
  params = params_queue.get()
352
  actor_policy_version += 1
353
+ else:
354
+ if update != 2:
355
+ params = params_queue.get()
356
+ actor_policy_version += 1
357
  params_queue_get_time.append(time.time() - params_queue_get_time_start)
358
  writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
359
  rollout_time_start = time.time()
 
416
  writer.add_scalar("stats/inference_time", inference_time, global_step)
417
  writer.add_scalar("stats/storage_time", storage_time, global_step)
418
  writer.add_scalar("stats/env_send_time", env_send_time, global_step)
419
+ # `make_bulk_array` is actually important. It accumulates the data from the lists
420
+ # into single bulk arrays, which later makes transferring the data to the learner's
421
+ # device slightly faster. See https://wandb.ai/costa-huang/cleanRL/reports/data-transfer-optimization--VmlldzozNjU5MTg1
422
+ if args.learner_device_ids[0] != args.actor_device_ids[0]:
423
+ obs, values, actions, logprobs = make_bulk_array(
424
+ obs,
425
+ values,
426
+ actions,
427
+ logprobs,
428
+ )
429
 
430
  payload = (
431
  global_step,
432
  actor_policy_version,
433
  update,
434
  obs,
 
435
  values,
436
  actions,
437
  logprobs,
438
+ dones,
439
  env_ids,
440
  rewards,
441
+ np.mean(params_queue_get_time),
442
  )
443
  if update == 1 or not args.test_actor_learner_throughput:
444
  rollout_queue_put_time_start = time.time()
 
747
  actor_policy_version,
748
  update,
749
  obs,
 
750
  values,
751
  actions,
752
  logprobs,
753
+ dones,
754
  env_ids,
755
  rewards,
756
+ avg_params_queue_get_time,
757
  ) = rollout_queue.get()
758
  rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
759
  writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
760
+ writer.add_scalar(
761
+ "stats/rollout_params_queue_get_time_diff",
762
+ np.mean(rollout_queue_get_time) - avg_params_queue_get_time,
763
+ global_step,
764
+ )
765
 
766
  data_transfer_time_start = time.time()
767
  b_obs, b_actions, b_logprobs, b_advantages, b_returns = prepare_data(
 
816
  break
817
 
818
  if args.save_model and args.local_rank == 0:
819
+ if args.distributed:
820
+ jax.distributed.shutdown()
821
  agent_state = flax.jax_utils.unreplicate(agent_state)
822
  model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
823
  with open(model_path, "wb") as f:
824
  f.write(
825
+ flax.serialization.to_bytes(
826
+ [
827
+ vars(args),
828
+ [
829
+ agent_state.params.network_params,
830
+ agent_state.params.actor_params,
831
+ agent_state.params.critic_params,
832
+ ],
833
+ ]
834
+ )
835
  )
836
  print(f"model saved to {model_path}")
837
  from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate
events.out.tfevents.1676644137.ip-26-0-134-150.813450.0 → events.out.tfevents.1678205983.ip-26-0-138-178 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9a843352011c6d4d6ea12352653060463d13a2b98e2cdc39aeb8f106dc3f1dc8
3
- size 4754791
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58c2d0f32a4a85f6755842c826d14b1d3fa388c60e2064530b32d088addccd0b
3
+ size 5017750
poetry.lock CHANGED
The diff for this file is too large to render. See raw diff
 
pyproject.toml CHANGED
@@ -1,178 +1,34 @@
1
  [tool.poetry]
2
- name = "cleanrl"
3
- version = "1.1.0"
4
- description = "High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features"
5
  authors = ["Costa Huang <[email protected]>"]
 
6
  packages = [
7
- { include = "cleanrl" },
8
  { include = "cleanrl_utils" },
9
  ]
10
- keywords = ["reinforcement", "machine", "learning", "research"]
11
- license="MIT"
12
- readme = "README.md"
13
 
14
  [tool.poetry.dependencies]
15
- python = ">=3.7.1,<3.10"
16
- tensorboard = "^2.10.0"
17
- wandb = "^0.13.6"
 
 
 
 
 
 
 
 
18
  gym = "0.23.1"
19
- torch = ">=1.12.1"
20
- stable-baselines3 = "1.2.0"
21
- gymnasium = "^0.26.3"
22
  moviepy = "^1.0.3"
23
- pygame = "2.1.0"
24
- huggingface-hub = "^0.11.1"
25
 
26
- ale-py = {version = "0.7.4", optional = true}
27
- AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2"}
28
- opencv-python = {version = "^4.6.0.66", optional = true}
29
- pybullet = {version = "3.1.8", optional = true}
30
- procgen = {version = "^0.10.7", optional = true}
31
- pytest = {version = "^7.1.3", optional = true}
32
- mujoco = {version = "^2.2", optional = true}
33
- imageio = {version = "^2.14.1", optional = true}
34
- free-mujoco-py = {version = "^2.1.6", optional = true}
35
- mkdocs-material = {version = "^8.4.3", optional = true}
36
- markdown-include = {version = "^0.7.0", optional = true}
37
- jax = {version = "^0.3.17", optional = true}
38
- jaxlib = {version = "^0.3.15", optional = true}
39
- flax = {version = "^0.6.0", optional = true}
40
- optuna = {version = "^3.0.1", optional = true}
41
- optuna-dashboard = {version = "^0.7.2", optional = true}
42
- rich = {version = "<12.0", optional = true}
43
- envpool = {version = "^0.8.1", optional = true}
44
- PettingZoo = {version = "1.18.1", optional = true}
45
- SuperSuit = {version = "3.4.0", optional = true}
46
- multi-agent-ale-py = {version = "0.1.11", optional = true}
47
- boto3 = {version = "^1.24.70", optional = true}
48
- awscli = {version = "^1.25.71", optional = true}
49
- shimmy = {version = "^0.1.0", optional = true}
50
- dm-control = {version = "^1.0.8", optional = true}
51
 
52
  [tool.poetry.group.dev.dependencies]
53
- pre-commit = "^2.20.0"
54
-
55
- [tool.poetry.group.atari]
56
- optional = true
57
- [tool.poetry.group.atari.dependencies]
58
- ale-py = "0.7.4"
59
- AutoROM = {extras = ["accept-rom-license"], version = "^0.4.2"}
60
- opencv-python = "^4.6.0.66"
61
-
62
- [tool.poetry.group.pybullet]
63
- optional = true
64
- [tool.poetry.group.pybullet.dependencies]
65
- pybullet = "3.1.8"
66
-
67
- [tool.poetry.group.procgen]
68
- optional = true
69
- [tool.poetry.group.procgen.dependencies]
70
- procgen = "^0.10.7"
71
-
72
- [tool.poetry.group.pytest]
73
- optional = true
74
- [tool.poetry.group.pytest.dependencies]
75
- pytest = "^7.1.3"
76
-
77
- [tool.poetry.group.mujoco]
78
- optional = true
79
- [tool.poetry.group.mujoco.dependencies]
80
- mujoco = "^2.2"
81
- imageio = "^2.14.1"
82
-
83
- [tool.poetry.group.mujoco_py]
84
- optional = true
85
- [tool.poetry.group.mujoco_py.dependencies]
86
- free-mujoco-py = "^2.1.6"
87
-
88
- [tool.poetry.group.docs]
89
- optional = true
90
- [tool.poetry.group.docs.dependencies]
91
- mkdocs-material = "^8.4.3"
92
- markdown-include = "^0.7.0"
93
-
94
- [tool.poetry.group.jax]
95
- optional = true
96
- [tool.poetry.group.jax.dependencies]
97
- jax = "^0.3.17"
98
- jaxlib = "^0.3.15"
99
- flax = "^0.6.0"
100
-
101
- [tool.poetry.group.optuna]
102
- optional = true
103
- [tool.poetry.group.optuna.dependencies]
104
- optuna = "^3.0.1"
105
- optuna-dashboard = "^0.7.2"
106
- rich = "<12.0"
107
-
108
- [tool.poetry.group.envpool]
109
- optional = true
110
- [tool.poetry.group.envpool.dependencies]
111
- envpool = "^0.8.1"
112
-
113
- [tool.poetry.group.pettingzoo]
114
- optional = true
115
- [tool.poetry.group.pettingzoo.dependencies]
116
- PettingZoo = "1.18.1"
117
- SuperSuit = "3.4.0"
118
- multi-agent-ale-py = "0.1.11"
119
-
120
- [tool.poetry.group.cloud]
121
- optional = true
122
- [tool.poetry.group.cloud.dependencies]
123
- boto3 = "^1.24.70"
124
- awscli = "^1.25.71"
125
-
126
- [tool.poetry.group.isaacgym]
127
- optional = true
128
- [tool.poetry.group.isaacgym.dependencies]
129
- isaacgymenvs = {git = "https://github.com/vwxyzjn/IsaacGymEnvs.git", rev = "poetry"}
130
- isaacgym = {path = "cleanrl/ppo_continuous_action_isaacgym/isaacgym", develop = true}
131
-
132
- [tool.poetry.group.dm_control]
133
- optional = true
134
- [tool.poetry.group.dm_control.dependencies]
135
- shimmy = "^0.1.0"
136
- dm-control = "^1.0.8"
137
- mujoco = "^2.2"
138
 
139
  [build-system]
140
  requires = ["poetry-core"]
141
  build-backend = "poetry.core.masonry.api"
142
-
143
- [tool.poetry.extras]
144
- atari = ["ale-py", "AutoROM", "opencv-python"]
145
- pybullet = ["pybullet"]
146
- procgen = ["procgen"]
147
- plot = ["pandas", "seaborn"]
148
- pytest = ["pytest"]
149
- mujoco = ["mujoco", "imageio"]
150
- mujoco_py = ["free-mujoco-py"]
151
- jax = ["jax", "jaxlib", "flax"]
152
- docs = ["mkdocs-material", "markdown-include"]
153
- envpool = ["envpool"]
154
- optuna = ["optuna", "optuna-dashboard", "rich"]
155
- pettingzoo = ["PettingZoo", "SuperSuit", "multi-agent-ale-py"]
156
- cloud = ["boto3", "awscli"]
157
- dm_control = ["shimmy", "dm-control", "mujoco"]
158
-
159
- # dependencies for algorithm variant (useful when you want to run a specific algorithm)
160
- dqn = []
161
- dqn_atari = ["ale-py", "AutoROM", "opencv-python"]
162
- dqn_jax = ["jax", "jaxlib", "flax"]
163
- dqn_atari_jax = [
164
- "ale-py", "AutoROM", "opencv-python", # atari
165
- "jax", "jaxlib", "flax" # jax
166
- ]
167
- c51 = []
168
- c51_atari = ["ale-py", "AutoROM", "opencv-python"]
169
- c51_jax = ["jax", "jaxlib", "flax"]
170
- c51_atari_jax = [
171
- "ale-py", "AutoROM", "opencv-python", # atari
172
- "jax", "jaxlib", "flax" # jax
173
- ]
174
- ppo_atari_envpool_xla_jax_scan = [
175
- "ale-py", "AutoROM", "opencv-python", # atari
176
- "jax", "jaxlib", "flax", # jax
177
- "envpool", # envpool
178
- ]
 
1
  [tool.poetry]
2
+ name = "cleanba"
3
+ version = "0.1.0"
4
+ description = ""
5
  authors = ["Costa Huang <[email protected]>"]
6
+ readme = "README.md"
7
  packages = [
8
+ { include = "cleanba" },
9
  { include = "cleanrl_utils" },
10
  ]
 
 
 
11
 
12
  [tool.poetry.dependencies]
13
+ python = "^3.8"
14
+ tensorboard = "^2.12.0"
15
+ envpool = "^0.8.1"
16
+ jax = "0.3.25"
17
+ flax = "0.6.0"
18
+ optax = "0.1.3"
19
+ huggingface-hub = "^0.12.0"
20
+ jaxlib = "0.3.25"
21
+ wandb = "^0.13.10"
22
+ tensorboardx = "^2.5.1"
23
+ chex = "0.1.5"
24
  gym = "0.23.1"
25
+ opencv-python = "^4.7.0.68"
 
 
26
  moviepy = "^1.0.3"
 
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  [tool.poetry.group.dev.dependencies]
30
+ pre-commit = "^3.0.4"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  [build-system]
33
  requires = ["poetry-core"]
34
  build-backend = "poetry.core.masonry.api"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
replay.mp4 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:37da3d7f3ca25e53d6e8b580e71982bdd7e25405dbe67605f028ea81711a968b
3
- size 1616695
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab3d5ac955377ca6ba5689892187e23d8a4efba0edcb457eca069d68c853ff3a
3
+ size 1702554
videos/{BeamRider-v5__cleanba_ppo_envpool_impala_atari_wrapper__2__ec54eea3-e012-4a37-b5e0-2d1f851254c8-eval → BeamRider-v5__cleanba_ppo_envpool_impala_atari_wrapper__2__6f600660-2e9e-43ae-be8d-4fdd4ef75e94-eval}/0.mp4 RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:37da3d7f3ca25e53d6e8b580e71982bdd7e25405dbe67605f028ea81711a968b
3
- size 1616695
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab3d5ac955377ca6ba5689892187e23d8a4efba0edcb457eca069d68c853ff3a
3
+ size 1702554