sgoodfriend commited on
Commit
c0392b0
1 Parent(s): 76b718a

A2C playing CartPole-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +16 -15
  2. environment.yml +1 -1
  3. pyproject.toml +6 -3
  4. replay.meta.json +1 -1
  5. replay.mp4 +0 -0
  6. rl_algo_impls/a2c/a2c.py +11 -9
  7. rl_algo_impls/a2c/optimize.py +9 -5
  8. rl_algo_impls/dqn/dqn.py +15 -8
  9. rl_algo_impls/dqn/q_net.py +1 -1
  10. rl_algo_impls/huggingface_publish.py +6 -7
  11. rl_algo_impls/hyperparams/a2c.yml +13 -12
  12. rl_algo_impls/hyperparams/dqn.yml +3 -3
  13. rl_algo_impls/hyperparams/ppo.yml +123 -10
  14. rl_algo_impls/hyperparams/vpg.yml +6 -6
  15. rl_algo_impls/optimize.py +61 -28
  16. rl_algo_impls/ppo/ppo.py +27 -16
  17. rl_algo_impls/runner/config.py +12 -3
  18. rl_algo_impls/runner/evaluate.py +5 -6
  19. rl_algo_impls/runner/running_utils.py +15 -23
  20. rl_algo_impls/runner/selfplay_evaluate.py +142 -0
  21. rl_algo_impls/runner/train.py +36 -21
  22. rl_algo_impls/selfplay_enjoy.py +53 -0
  23. rl_algo_impls/shared/actor/__init__.py +1 -1
  24. rl_algo_impls/shared/actor/actor.py +10 -9
  25. rl_algo_impls/shared/actor/categorical.py +3 -3
  26. rl_algo_impls/shared/actor/gaussian.py +3 -3
  27. rl_algo_impls/shared/actor/gridnet.py +4 -4
  28. rl_algo_impls/shared/actor/gridnet_decoder.py +3 -4
  29. rl_algo_impls/shared/actor/make_actor.py +8 -5
  30. rl_algo_impls/shared/actor/multi_discrete.py +3 -3
  31. rl_algo_impls/shared/actor/state_dependent_noise.py +14 -15
  32. rl_algo_impls/shared/algorithm.py +5 -5
  33. rl_algo_impls/shared/callbacks/__init__.py +1 -0
  34. rl_algo_impls/shared/callbacks/eval_callback.py +24 -4
  35. rl_algo_impls/shared/callbacks/microrts_reward_decay_callback.py +36 -0
  36. rl_algo_impls/shared/callbacks/optimize_callback.py +1 -1
  37. rl_algo_impls/shared/callbacks/self_play_callback.py +34 -0
  38. rl_algo_impls/shared/encoder/cnn.py +1 -1
  39. rl_algo_impls/shared/encoder/encoder.py +1 -1
  40. rl_algo_impls/shared/encoder/gridnet_encoder.py +1 -1
  41. rl_algo_impls/shared/encoder/impala_cnn.py +1 -1
  42. rl_algo_impls/shared/encoder/microrts_cnn.py +1 -1
  43. rl_algo_impls/shared/encoder/nature_cnn.py +1 -1
  44. rl_algo_impls/shared/gae.py +1 -1
  45. rl_algo_impls/shared/module/{module.py → utils.py} +0 -0
  46. rl_algo_impls/shared/policy/{on_policy.py → actor_critic.py} +62 -95
  47. rl_algo_impls/shared/policy/actor_critic_network/__init__.py +11 -0
  48. rl_algo_impls/shared/policy/actor_critic_network/connected_trio.py +118 -0
  49. rl_algo_impls/shared/policy/actor_critic_network/network.py +57 -0
  50. rl_algo_impls/shared/policy/actor_critic_network/separate_actor_critic.py +128 -0
README.md CHANGED
@@ -23,17 +23,17 @@ model-index:
23
 
24
  This is a trained model of a **A2C** agent playing **CartPole-v1** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
25
 
26
- All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/7lx79bf0.
27
 
28
  ## Training Results
29
 
30
- This model was trained from 3 trainings of **A2C** agents using different initial seeds. These agents were trained by checking out [0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
31
 
32
  | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
33
  |:-------|:------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
34
- | a2c | CartPole-v1 | 1 | 500 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/fou5cee8) |
35
- | a2c | CartPole-v1 | 2 | 500 | 0 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/vu1741ee) |
36
- | a2c | CartPole-v1 | 3 | 500 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/lcpgx9n6) |
37
 
38
 
39
  ### Prerequisites: Weights & Biases (WandB)
@@ -53,10 +53,10 @@ login`.
53
  Note: While the model state dictionary and hyperaparameters are saved, the latest
54
  implementation could be sufficiently different to not be able to reproduce similar
55
  results. You might need to checkout the commit the agent was trained on:
56
- [0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c).
57
  ```
58
  # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
59
- python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/vu1741ee
60
  ```
61
 
62
  Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -68,11 +68,11 @@ notebook.
68
 
69
  ## Training
70
  If you want the highest chance to reproduce these results, you'll want to checkout the
71
- commit the agent was trained on: [0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c). While
72
  training is deterministic, different hardware will give different results.
73
 
74
  ```
75
- python train.py --algo a2c --env CartPole-v1 --seed 2
76
  ```
77
 
78
  Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -83,7 +83,7 @@ notebook.
83
 
84
 
85
  ## Benchmarking (with Lambda Labs instance)
86
- This and other models from https://api.wandb.ai/links/sgoodfriend/7lx79bf0 were generated by running a script on a Lambda
87
  Labs instance. In a Lambda Labs instance terminal:
88
  ```
89
  git clone [email protected]:sgoodfriend/rl-algo-impls.git
@@ -113,18 +113,19 @@ env: CartPole-v1
113
  env_hyperparams:
114
  n_envs: 8
115
  env_id: null
116
- eval_params: {}
 
117
  n_timesteps: 500000
118
  policy_hyperparams: {}
119
- seed: 2
120
  use_deterministic_algorithms: true
121
  wandb_entity: null
122
  wandb_group: null
123
  wandb_project_name: rl-algo-impls-benchmarks
124
  wandb_tags:
125
- - benchmark_0511de3
126
- - host_152-67-249-42
127
  - branch_main
128
- - v0.0.8
129
 
130
  ```
 
23
 
24
  This is a trained model of a **A2C** agent playing **CartPole-v1** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
25
 
26
+ All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/ysd5gj7p.
27
 
28
  ## Training Results
29
 
30
+ This model was trained from 3 trainings of **A2C** agents using different initial seeds. These agents were trained by checking out [983cb75](https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
31
 
32
  | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
33
  |:-------|:------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
34
+ | a2c | CartPole-v1 | 1 | 500 | 0 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/y4b3inhl) |
35
+ | a2c | CartPole-v1 | 2 | 500 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/16d9dk25) |
36
+ | a2c | CartPole-v1 | 3 | 500 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/jh46gegy) |
37
 
38
 
39
  ### Prerequisites: Weights & Biases (WandB)
 
53
  Note: While the model state dictionary and hyperaparameters are saved, the latest
54
  implementation could be sufficiently different to not be able to reproduce similar
55
  results. You might need to checkout the commit the agent was trained on:
56
+ [983cb75](https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b).
57
  ```
58
  # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
59
+ python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/y4b3inhl
60
  ```
61
 
62
  Setup hasn't been completely worked out yet, so you might be best served by using Google
 
68
 
69
  ## Training
70
  If you want the highest chance to reproduce these results, you'll want to checkout the
71
+ commit the agent was trained on: [983cb75](https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b). While
72
  training is deterministic, different hardware will give different results.
73
 
74
  ```
75
+ python train.py --algo a2c --env CartPole-v1 --seed 1
76
  ```
77
 
78
  Setup hasn't been completely worked out yet, so you might be best served by using Google
 
83
 
84
 
85
  ## Benchmarking (with Lambda Labs instance)
86
+ This and other models from https://api.wandb.ai/links/sgoodfriend/ysd5gj7p were generated by running a script on a Lambda
87
  Labs instance. In a Lambda Labs instance terminal:
88
  ```
89
  git clone [email protected]:sgoodfriend/rl-algo-impls.git
 
113
  env_hyperparams:
114
  n_envs: 8
115
  env_id: null
116
+ eval_hyperparams: {}
117
+ microrts_reward_decay_callback: false
118
  n_timesteps: 500000
119
  policy_hyperparams: {}
120
+ seed: 1
121
  use_deterministic_algorithms: true
122
  wandb_entity: null
123
  wandb_group: null
124
  wandb_project_name: rl-algo-impls-benchmarks
125
  wandb_tags:
126
+ - benchmark_983cb75
127
+ - host_129-159-43-75
128
  - branch_main
129
+ - v0.0.9
130
 
131
  ```
environment.yml CHANGED
@@ -4,7 +4,7 @@ channels:
4
  - conda-forge
5
  - nodefaults
6
  dependencies:
7
- - python>=3.8, <3.11
8
  - mamba
9
  - pip
10
  - pytorch
 
4
  - conda-forge
5
  - nodefaults
6
  dependencies:
7
+ - python>=3.8, <3.10
8
  - mamba
9
  - pip
10
  - pytorch
pyproject.toml CHANGED
@@ -1,6 +1,6 @@
1
  [project]
2
  name = "rl_algo_impls"
3
- version = "0.0.8"
4
  description = "Implementations of reinforcement learning algorithms"
5
  authors = [
6
  {name = "Scott Goodfriend", email = "[email protected]"},
@@ -56,14 +56,17 @@ procgen = [
56
  "glfw >= 1.12.0, < 1.13",
57
  "procgen; platform_machine=='x86_64'",
58
  ]
59
- microrts-old = [
60
  "numpy < 1.24.0", # Support for gym-microrts < 0.6.0
61
  "gym-microrts == 0.2.0", # Match ppo-implementation-details
62
  ]
63
- microrts = [
64
  "numpy < 1.24.0", # Support for gym-microrts < 0.6.0
65
  "gym-microrts == 0.3.2",
66
  ]
 
 
 
67
  jupyter = [
68
  "jupyter",
69
  "notebook"
 
1
  [project]
2
  name = "rl_algo_impls"
3
+ version = "0.0.9"
4
  description = "Implementations of reinforcement learning algorithms"
5
  authors = [
6
  {name = "Scott Goodfriend", email = "[email protected]"},
 
56
  "glfw >= 1.12.0, < 1.13",
57
  "procgen; platform_machine=='x86_64'",
58
  ]
59
+ microrts-ppo = [
60
  "numpy < 1.24.0", # Support for gym-microrts < 0.6.0
61
  "gym-microrts == 0.2.0", # Match ppo-implementation-details
62
  ]
63
+ microrts-paper = [
64
  "numpy < 1.24.0", # Support for gym-microrts < 0.6.0
65
  "gym-microrts == 0.3.2",
66
  ]
67
+ microrts = [
68
+ "gym-microrts",
69
+ ]
70
  jupyter = [
71
  "jupyter",
72
  "notebook"
replay.meta.json CHANGED
@@ -1 +1 @@
1
- {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/tmp/tmpljyuzlcc/a2c-CartPole-v1/replay.mp4"]}, "episode": {"r": 500.0, "l": 500, "t": 3.352876}}
 
1
+ {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/tmp/tmpc0iptd3_/a2c-CartPole-v1/replay.mp4"]}, "episodes": [{"r": 500.0, "l": 500, "t": 3.551568}]}
replay.mp4 CHANGED
Binary files a/replay.mp4 and b/replay.mp4 differ
 
rl_algo_impls/a2c/a2c.py CHANGED
@@ -1,23 +1,23 @@
1
  import logging
 
 
 
2
  import numpy as np
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
6
-
7
- from time import perf_counter
8
  from torch.utils.tensorboard.writer import SummaryWriter
9
- from typing import Optional, TypeVar
10
 
11
  from rl_algo_impls.shared.algorithm import Algorithm
12
- from rl_algo_impls.shared.callbacks.callback import Callback
13
  from rl_algo_impls.shared.gae import compute_advantages
14
- from rl_algo_impls.shared.policy.on_policy import ActorCritic
15
  from rl_algo_impls.shared.schedule import schedule, update_learning_rate
16
  from rl_algo_impls.shared.stats import log_scalars
17
  from rl_algo_impls.wrappers.vectorable_wrapper import (
18
  VecEnv,
19
- single_observation_space,
20
  single_action_space,
 
21
  )
22
 
23
  A2CSelf = TypeVar("A2CSelf", bound="A2C")
@@ -70,7 +70,7 @@ class A2C(Algorithm):
70
  def learn(
71
  self: A2CSelf,
72
  train_timesteps: int,
73
- callback: Optional[Callback] = None,
74
  total_timesteps: Optional[int] = None,
75
  start_timesteps: int = 0,
76
  ) -> A2CSelf:
@@ -193,8 +193,10 @@ class A2C(Algorithm):
193
  timesteps_elapsed,
194
  )
195
 
196
- if callback:
197
- if not callback.on_step(timesteps_elapsed=rollout_steps):
 
 
198
  logging.info(
199
  f"Callback terminated training at {timesteps_elapsed} timesteps"
200
  )
 
1
  import logging
2
+ from time import perf_counter
3
+ from typing import List, Optional, TypeVar
4
+
5
  import numpy as np
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
 
 
9
  from torch.utils.tensorboard.writer import SummaryWriter
 
10
 
11
  from rl_algo_impls.shared.algorithm import Algorithm
12
+ from rl_algo_impls.shared.callbacks import Callback
13
  from rl_algo_impls.shared.gae import compute_advantages
14
+ from rl_algo_impls.shared.policy.actor_critic import ActorCritic
15
  from rl_algo_impls.shared.schedule import schedule, update_learning_rate
16
  from rl_algo_impls.shared.stats import log_scalars
17
  from rl_algo_impls.wrappers.vectorable_wrapper import (
18
  VecEnv,
 
19
  single_action_space,
20
+ single_observation_space,
21
  )
22
 
23
  A2CSelf = TypeVar("A2CSelf", bound="A2C")
 
70
  def learn(
71
  self: A2CSelf,
72
  train_timesteps: int,
73
+ callbacks: Optional[List[Callback]] = None,
74
  total_timesteps: Optional[int] = None,
75
  start_timesteps: int = 0,
76
  ) -> A2CSelf:
 
193
  timesteps_elapsed,
194
  )
195
 
196
+ if callbacks:
197
+ if not all(
198
+ c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
199
+ ):
200
  logging.info(
201
  f"Callback terminated training at {timesteps_elapsed} timesteps"
202
  )
rl_algo_impls/a2c/optimize.py CHANGED
@@ -1,10 +1,10 @@
1
- import optuna
2
-
3
  from copy import deepcopy
4
 
5
- from rl_algo_impls.runner.config import Config, Hyperparams, EnvHyperparams
6
- from rl_algo_impls.shared.vec_env import make_eval_env
 
7
  from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
 
8
  from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
9
 
10
 
@@ -16,7 +16,11 @@ def sample_params(
16
  hyperparams = deepcopy(base_hyperparams)
17
 
18
  base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams)
19
- env = make_eval_env(base_config, base_env_hyperparams, override_n_envs=1)
 
 
 
 
20
 
21
  # env_hyperparams
22
  env_hyperparams = sample_env_hyperparams(trial, hyperparams.env_hyperparams, env)
 
 
 
1
  from copy import deepcopy
2
 
3
+ import optuna
4
+
5
+ from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams
6
  from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
7
+ from rl_algo_impls.shared.vec_env import make_eval_env
8
  from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
9
 
10
 
 
16
  hyperparams = deepcopy(base_hyperparams)
17
 
18
  base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams)
19
+ env = make_eval_env(
20
+ base_config,
21
+ base_env_hyperparams,
22
+ override_hparams={"n_envs": 1},
23
+ )
24
 
25
  # env_hyperparams
26
  env_hyperparams = sample_env_hyperparams(trial, hyperparams.env_hyperparams, env)
rl_algo_impls/dqn/dqn.py CHANGED
@@ -1,18 +1,19 @@
1
  import copy
2
- import numpy as np
3
  import random
 
 
 
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
-
8
- from collections import deque
9
  from torch.optim import Adam
10
  from torch.utils.tensorboard.writer import SummaryWriter
11
- from typing import NamedTuple, Optional, TypeVar
12
 
13
  from rl_algo_impls.dqn.policy import DQNPolicy
14
  from rl_algo_impls.shared.algorithm import Algorithm
15
- from rl_algo_impls.shared.callbacks.callback import Callback
16
  from rl_algo_impls.shared.schedule import linear_schedule
17
  from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
18
 
@@ -118,7 +119,7 @@ class DQN(Algorithm):
118
  self.max_grad_norm = max_grad_norm
119
 
120
  def learn(
121
- self: DQNSelf, total_timesteps: int, callback: Optional[Callback] = None
122
  ) -> DQNSelf:
123
  self.policy.train(True)
124
  obs = self.env.reset()
@@ -140,8 +141,14 @@ class DQN(Algorithm):
140
  if steps_since_target_update >= self.target_update_interval:
141
  self._update_target()
142
  steps_since_target_update = 0
143
- if callback:
144
- callback.on_step(timesteps_elapsed=rollout_steps)
 
 
 
 
 
 
145
  return self
146
 
147
  def train(self) -> None:
 
1
  import copy
2
+ import logging
3
  import random
4
+ from collections import deque
5
+ from typing import List, NamedTuple, Optional, TypeVar
6
+
7
+ import numpy as np
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
 
 
11
  from torch.optim import Adam
12
  from torch.utils.tensorboard.writer import SummaryWriter
 
13
 
14
  from rl_algo_impls.dqn.policy import DQNPolicy
15
  from rl_algo_impls.shared.algorithm import Algorithm
16
+ from rl_algo_impls.shared.callbacks import Callback
17
  from rl_algo_impls.shared.schedule import linear_schedule
18
  from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
19
 
 
119
  self.max_grad_norm = max_grad_norm
120
 
121
  def learn(
122
+ self: DQNSelf, total_timesteps: int, callbacks: Optional[List[Callback]] = None
123
  ) -> DQNSelf:
124
  self.policy.train(True)
125
  obs = self.env.reset()
 
141
  if steps_since_target_update >= self.target_update_interval:
142
  self._update_target()
143
  steps_since_target_update = 0
144
+ if callbacks:
145
+ if not all(
146
+ c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
147
+ ):
148
+ logging.info(
149
+ f"Callback terminated training at {timesteps_elapsed} timesteps"
150
+ )
151
+ break
152
  return self
153
 
154
  def train(self) -> None:
rl_algo_impls/dqn/q_net.py CHANGED
@@ -6,7 +6,7 @@ import torch.nn as nn
6
  from gym.spaces import Discrete
7
 
8
  from rl_algo_impls.shared.encoder import Encoder
9
- from rl_algo_impls.shared.module.module import mlp
10
 
11
 
12
  class QNetwork(nn.Module):
 
6
  from gym.spaces import Discrete
7
 
8
  from rl_algo_impls.shared.encoder import Encoder
9
+ from rl_algo_impls.shared.module.utils import mlp
10
 
11
 
12
  class QNetwork(nn.Module):
rl_algo_impls/huggingface_publish.py CHANGED
@@ -3,24 +3,23 @@ import os
3
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4
 
5
  import argparse
6
- import requests
7
  import shutil
8
  import subprocess
9
  import tempfile
10
- import wandb
11
- import wandb.apis.public
12
-
13
  from typing import List, Optional
14
 
 
 
15
  from huggingface_hub.hf_api import HfApi, upload_folder
16
  from huggingface_hub.repocard import metadata_save
17
  from pyvirtualdisplay.display import Display
18
 
 
19
  from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
20
  from rl_algo_impls.runner.config import EnvHyperparams
21
  from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
22
- from rl_algo_impls.shared.vec_env import make_eval_env
23
  from rl_algo_impls.shared.callbacks.eval_callback import evaluate
 
24
  from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
25
 
26
 
@@ -134,7 +133,7 @@ def publish(
134
  make_eval_env(
135
  config,
136
  EnvHyperparams(**config.env_hyperparams),
137
- override_n_envs=1,
138
  normalize_load_path=model_path,
139
  ),
140
  os.path.join(repo_dir_path, "replay"),
@@ -144,7 +143,7 @@ def publish(
144
  video_env,
145
  policy,
146
  1,
147
- deterministic=config.eval_params.get("deterministic", True),
148
  )
149
 
150
  api = HfApi()
 
3
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
4
 
5
  import argparse
 
6
  import shutil
7
  import subprocess
8
  import tempfile
 
 
 
9
  from typing import List, Optional
10
 
11
+ import requests
12
+ import wandb.apis.public
13
  from huggingface_hub.hf_api import HfApi, upload_folder
14
  from huggingface_hub.repocard import metadata_save
15
  from pyvirtualdisplay.display import Display
16
 
17
+ import wandb
18
  from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
19
  from rl_algo_impls.runner.config import EnvHyperparams
20
  from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
 
21
  from rl_algo_impls.shared.callbacks.eval_callback import evaluate
22
+ from rl_algo_impls.shared.vec_env import make_eval_env
23
  from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
24
 
25
 
 
133
  make_eval_env(
134
  config,
135
  EnvHyperparams(**config.env_hyperparams),
136
+ override_hparams={"n_envs": 1},
137
  normalize_load_path=model_path,
138
  ),
139
  os.path.join(repo_dir_path, "replay"),
 
143
  video_env,
144
  policy,
145
  1,
146
+ deterministic=config.eval_hyperparams.get("deterministic", True),
147
  )
148
 
149
  api = HfApi()
rl_algo_impls/hyperparams/a2c.yml CHANGED
@@ -101,31 +101,32 @@ HopperBulletEnv-v0:
101
  CarRacing-v0:
102
  n_timesteps: !!float 4e6
103
  env_hyperparams:
104
- n_envs: 16
105
  frame_stack: 4
106
  normalize: true
107
  normalize_kwargs:
108
  norm_obs: false
109
  norm_reward: true
110
  policy_hyperparams:
111
- use_sde: false
112
- log_std_init: -1.3502584927786276
113
  init_layers_orthogonal: true
114
  activation_fn: tanh
115
  share_features_extractor: false
116
  cnn_flatten_dim: 256
117
  hidden_sizes: [256]
118
  algo_hyperparams:
119
- n_steps: 16
120
- learning_rate: 0.000025630993245026736
121
- learning_rate_decay: linear
122
- gamma: 0.99957617037542
123
- gae_lambda: 0.949455676599436
124
- ent_coef: !!float 1.707983205298309e-7
125
- vf_coef: 0.10428178193833336
126
- max_grad_norm: 0.5406643389792273
127
- normalize_advantage: true
128
  use_rms_prop: false
 
129
 
130
  _atari: &atari-defaults
131
  n_timesteps: !!float 1e7
 
101
  CarRacing-v0:
102
  n_timesteps: !!float 4e6
103
  env_hyperparams:
104
+ n_envs: 4
105
  frame_stack: 4
106
  normalize: true
107
  normalize_kwargs:
108
  norm_obs: false
109
  norm_reward: true
110
  policy_hyperparams:
111
+ use_sde: true
112
+ log_std_init: -4.839609092563
113
  init_layers_orthogonal: true
114
  activation_fn: tanh
115
  share_features_extractor: false
116
  cnn_flatten_dim: 256
117
  hidden_sizes: [256]
118
  algo_hyperparams:
119
+ n_steps: 64
120
+ learning_rate: 0.000018971962220405576
121
+ gamma: 0.9942776405534832
122
+ gae_lambda: 0.9549244758833236
123
+ ent_coef: 0.0000015666550584860516
124
+ ent_coef_decay: linear
125
+ vf_coef: 0.12164696385898476
126
+ max_grad_norm: 2.2574480552177127
127
+ normalize_advantage: false
128
  use_rms_prop: false
129
+ sde_sample_freq: 16
130
 
131
  _atari: &atari-defaults
132
  n_timesteps: !!float 1e7
rl_algo_impls/hyperparams/dqn.yml CHANGED
@@ -15,7 +15,7 @@ CartPole-v1: &cartpole-defaults
15
  gradient_steps: 128
16
  exploration_fraction: 0.16
17
  exploration_final_eps: 0.04
18
- eval_params:
19
  step_freq: !!float 1e4
20
 
21
  CartPole-v0:
@@ -76,7 +76,7 @@ LunarLander-v2:
76
  exploration_fraction: 0.12
77
  exploration_final_eps: 0.1
78
  max_grad_norm: 0.5
79
- eval_params:
80
  step_freq: 25_000
81
 
82
  _atari: &atari-defaults
@@ -97,7 +97,7 @@ _atari: &atari-defaults
97
  gradient_steps: 2
98
  exploration_fraction: 0.1
99
  exploration_final_eps: 0.01
100
- eval_params:
101
  deterministic: false
102
 
103
  PongNoFrameskip-v4:
 
15
  gradient_steps: 128
16
  exploration_fraction: 0.16
17
  exploration_final_eps: 0.04
18
+ eval_hyperparams:
19
  step_freq: !!float 1e4
20
 
21
  CartPole-v0:
 
76
  exploration_fraction: 0.12
77
  exploration_final_eps: 0.1
78
  max_grad_norm: 0.5
79
+ eval_hyperparams:
80
  step_freq: 25_000
81
 
82
  _atari: &atari-defaults
 
97
  gradient_steps: 2
98
  exploration_fraction: 0.1
99
  exploration_final_eps: 0.01
100
+ eval_hyperparams:
101
  deterministic: false
102
 
103
  PongNoFrameskip-v4:
rl_algo_impls/hyperparams/ppo.yml CHANGED
@@ -13,7 +13,7 @@ CartPole-v1: &cartpole-defaults
13
  learning_rate_decay: linear
14
  clip_range: 0.2
15
  clip_range_decay: linear
16
- eval_params:
17
  step_freq: !!float 2.5e4
18
 
19
  CartPole-v0:
@@ -52,7 +52,7 @@ MountainCarContinuous-v0:
52
  gae_lambda: 0.9
53
  max_grad_norm: 5
54
  vf_coef: 0.19
55
- eval_params:
56
  step_freq: 5000
57
 
58
  Acrobot-v1:
@@ -162,7 +162,7 @@ _atari: &atari-defaults
162
  clip_range_decay: linear
163
  vf_coef: 0.5
164
  ent_coef: 0.01
165
- eval_params:
166
  deterministic: false
167
 
168
  _norm-rewards-atari: &norm-rewards-atari-default
@@ -228,7 +228,7 @@ _microrts: &microrts-defaults
228
  clip_range_decay: none
229
  clip_range_vf: 0.1
230
  ppo2_vf_coef_halving: true
231
- eval_params:
232
  deterministic: false # Good idea because MultiCategorical mode isn't great
233
 
234
  _no-mask-microrts: &no-mask-microrts-defaults
@@ -252,15 +252,15 @@ MicrortsRandomEnemyShapedReward3-v1-NoMask:
252
  _microrts_ai: &microrts-ai-defaults
253
  <<: *microrts-defaults
254
  n_timesteps: !!float 100e6
255
- additional_keys_to_log: ["microrts_stats"]
256
  env_hyperparams: &microrts-ai-env-defaults
257
  n_envs: 24
258
  env_type: microrts
259
- make_kwargs:
260
  num_selfplay_envs: 0
261
- max_steps: 2000
262
  render_theme: 2
263
- map_path: maps/16x16/basesWorkers16x16.xml
264
  reward_weight: [10.0, 1.0, 1.0, 0.2, 1.0, 4.0]
265
  policy_hyperparams: &microrts-ai-policy-defaults
266
  <<: *microrts-policy-defaults
@@ -278,6 +278,15 @@ _microrts_ai: &microrts-ai-defaults
278
  max_grad_norm: 0.5
279
  clip_range: 0.1
280
  clip_range_vf: 0.1
 
 
 
 
 
 
 
 
 
281
 
282
  MicrortsAttackPassiveEnemySparseReward-v3:
283
  <<: *microrts-ai-defaults
@@ -305,6 +314,18 @@ enc-dec-MicrortsDefeatRandomEnemySparseReward-v3:
305
  actor_head_style: gridnet_decoder
306
  v_hidden_sizes: [128]
307
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  MicrortsDefeatCoacAIShaped-v3: &microrts-coacai-defaults
309
  <<: *microrts-ai-defaults
310
  env_id: MicrortsDefeatCoacAIShaped-v3 # Workaround to keep model name simple
@@ -313,6 +334,27 @@ MicrortsDefeatCoacAIShaped-v3: &microrts-coacai-defaults
313
  <<: *microrts-ai-env-defaults
314
  bots:
315
  coacAI: 24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
  MicrortsDefeatCoacAIShaped-v3-diverseBots: &microrts-diverse-defaults
318
  <<: *microrts-coacai-defaults
@@ -325,6 +367,7 @@ MicrortsDefeatCoacAIShaped-v3-diverseBots: &microrts-diverse-defaults
325
  workerRushAI: 2
326
 
327
  enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
 
328
  <<: *microrts-diverse-defaults
329
  policy_hyperparams:
330
  <<: *microrts-ai-policy-defaults
@@ -332,6 +375,76 @@ enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
332
  actor_head_style: gridnet_decoder
333
  v_hidden_sizes: [128]
334
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
335
  HalfCheetahBulletEnv-v0: &pybullet-defaults
336
  n_timesteps: !!float 2e6
337
  env_hyperparams: &pybullet-env-defaults
@@ -418,7 +531,7 @@ _procgen: &procgen-defaults
418
  learning_rate: !!float 5e-4
419
  # learning_rate_decay: linear
420
  vf_coef: 0.5
421
- eval_params: &procgen-eval-defaults
422
  ignore_first_episode: true
423
  # deterministic: false
424
  step_freq: !!float 1e5
@@ -466,7 +579,7 @@ _procgen-hard: &procgen-hard-defaults
466
  batch_size: 8192
467
  clip_range_decay: linear
468
  learning_rate_decay: linear
469
- eval_params:
470
  <<: *procgen-eval-defaults
471
  step_freq: !!float 5e5
472
 
 
13
  learning_rate_decay: linear
14
  clip_range: 0.2
15
  clip_range_decay: linear
16
+ eval_hyperparams:
17
  step_freq: !!float 2.5e4
18
 
19
  CartPole-v0:
 
52
  gae_lambda: 0.9
53
  max_grad_norm: 5
54
  vf_coef: 0.19
55
+ eval_hyperparams:
56
  step_freq: 5000
57
 
58
  Acrobot-v1:
 
162
  clip_range_decay: linear
163
  vf_coef: 0.5
164
  ent_coef: 0.01
165
+ eval_hyperparams:
166
  deterministic: false
167
 
168
  _norm-rewards-atari: &norm-rewards-atari-default
 
228
  clip_range_decay: none
229
  clip_range_vf: 0.1
230
  ppo2_vf_coef_halving: true
231
+ eval_hyperparams: &microrts-eval-defaults
232
  deterministic: false # Good idea because MultiCategorical mode isn't great
233
 
234
  _no-mask-microrts: &no-mask-microrts-defaults
 
252
  _microrts_ai: &microrts-ai-defaults
253
  <<: *microrts-defaults
254
  n_timesteps: !!float 100e6
255
+ additional_keys_to_log: ["microrts_stats", "microrts_results"]
256
  env_hyperparams: &microrts-ai-env-defaults
257
  n_envs: 24
258
  env_type: microrts
259
+ make_kwargs: &microrts-ai-env-make-kwargs-defaults
260
  num_selfplay_envs: 0
261
+ max_steps: 4000
262
  render_theme: 2
263
+ map_paths: [maps/16x16/basesWorkers16x16.xml]
264
  reward_weight: [10.0, 1.0, 1.0, 0.2, 1.0, 4.0]
265
  policy_hyperparams: &microrts-ai-policy-defaults
266
  <<: *microrts-policy-defaults
 
278
  max_grad_norm: 0.5
279
  clip_range: 0.1
280
  clip_range_vf: 0.1
281
+ eval_hyperparams: &microrts-ai-eval-defaults
282
+ <<: *microrts-eval-defaults
283
+ score_function: mean
284
+ max_video_length: 4000
285
+ env_overrides: &microrts-ai-eval-env-overrides
286
+ make_kwargs:
287
+ <<: *microrts-ai-env-make-kwargs-defaults
288
+ max_steps: 4000
289
+ reward_weight: [1.0, 0, 0, 0, 0, 0]
290
 
291
  MicrortsAttackPassiveEnemySparseReward-v3:
292
  <<: *microrts-ai-defaults
 
314
  actor_head_style: gridnet_decoder
315
  v_hidden_sizes: [128]
316
 
317
+ unet-MicrortsDefeatRandomEnemySparseReward-v3:
318
+ <<: *microrts-random-ai-defaults
319
+ # device: cpu
320
+ policy_hyperparams:
321
+ <<: *microrts-ai-policy-defaults
322
+ actor_head_style: unet
323
+ v_hidden_sizes: [256, 128]
324
+ algo_hyperparams:
325
+ <<: *microrts-ai-algo-defaults
326
+ learning_rate: !!float 2.5e-4
327
+ learning_rate_decay: spike
328
+
329
  MicrortsDefeatCoacAIShaped-v3: &microrts-coacai-defaults
330
  <<: *microrts-ai-defaults
331
  env_id: MicrortsDefeatCoacAIShaped-v3 # Workaround to keep model name simple
 
334
  <<: *microrts-ai-env-defaults
335
  bots:
336
  coacAI: 24
337
+ eval_hyperparams: &microrts-coacai-eval-defaults
338
+ <<: *microrts-ai-eval-defaults
339
+ step_freq: !!float 1e6
340
+ n_episodes: 26
341
+ env_overrides: &microrts-coacai-eval-env-overrides
342
+ <<: *microrts-ai-eval-env-overrides
343
+ n_envs: 26
344
+ bots:
345
+ coacAI: 2
346
+ randomBiasedAI: 2
347
+ randomAI: 2
348
+ passiveAI: 2
349
+ workerRushAI: 2
350
+ lightRushAI: 2
351
+ naiveMCTSAI: 2
352
+ mixedBot: 2
353
+ rojo: 2
354
+ izanagi: 2
355
+ tiamat: 2
356
+ droplet: 2
357
+ guidedRojoA3N: 2
358
 
359
  MicrortsDefeatCoacAIShaped-v3-diverseBots: &microrts-diverse-defaults
360
  <<: *microrts-coacai-defaults
 
367
  workerRushAI: 2
368
 
369
  enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
370
+ &microrts-env-dec-diverse-defaults
371
  <<: *microrts-diverse-defaults
372
  policy_hyperparams:
373
  <<: *microrts-ai-policy-defaults
 
375
  actor_head_style: gridnet_decoder
376
  v_hidden_sizes: [128]
377
 
378
+ debug-enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
379
+ <<: *microrts-env-dec-diverse-defaults
380
+ n_timesteps: !!float 1e6
381
+
382
+ unet-MicrortsDefeatCoacAIShaped-v3-diverseBots: &microrts-unet-defaults
383
+ <<: *microrts-diverse-defaults
384
+ policy_hyperparams:
385
+ <<: *microrts-ai-policy-defaults
386
+ actor_head_style: unet
387
+ v_hidden_sizes: [256, 128]
388
+ algo_hyperparams: &microrts-unet-algo-defaults
389
+ <<: *microrts-ai-algo-defaults
390
+ learning_rate: !!float 2.5e-4
391
+ learning_rate_decay: spike
392
+
393
+ Microrts-selfplay-unet: &microrts-selfplay-defaults
394
+ <<: *microrts-unet-defaults
395
+ env_hyperparams: &microrts-selfplay-env-defaults
396
+ <<: *microrts-ai-env-defaults
397
+ make_kwargs: &microrts-selfplay-env-make-kwargs-defaults
398
+ <<: *microrts-ai-env-make-kwargs-defaults
399
+ num_selfplay_envs: 36
400
+ self_play_kwargs:
401
+ num_old_policies: 12
402
+ save_steps: 300000
403
+ swap_steps: 6000
404
+ swap_window_size: 4
405
+ window: 33
406
+ eval_hyperparams: &microrts-selfplay-eval-defaults
407
+ <<: *microrts-coacai-eval-defaults
408
+ env_overrides: &microrts-selfplay-eval-env-overrides
409
+ <<: *microrts-coacai-eval-env-overrides
410
+ self_play_kwargs: {}
411
+
412
+ Microrts-selfplay-unet-winloss: &microrts-selfplay-winloss-defaults
413
+ <<: *microrts-selfplay-defaults
414
+ env_hyperparams:
415
+ <<: *microrts-selfplay-env-defaults
416
+ make_kwargs:
417
+ <<: *microrts-selfplay-env-make-kwargs-defaults
418
+ reward_weight: [1.0, 0, 0, 0, 0, 0]
419
+ algo_hyperparams: &microrts-selfplay-winloss-algo-defaults
420
+ <<: *microrts-unet-algo-defaults
421
+ gamma: 0.999
422
+
423
+ Microrts-selfplay-unet-decay: &microrts-selfplay-decay-defaults
424
+ <<: *microrts-selfplay-defaults
425
+ microrts_reward_decay_callback: true
426
+ algo_hyperparams:
427
+ <<: *microrts-unet-algo-defaults
428
+ gamma_end: 0.999
429
+
430
+ Microrts-selfplay-unet-debug: &microrts-selfplay-debug-defaults
431
+ <<: *microrts-selfplay-decay-defaults
432
+ eval_hyperparams:
433
+ <<: *microrts-selfplay-eval-defaults
434
+ step_freq: !!float 1e5
435
+ env_overrides:
436
+ <<: *microrts-selfplay-eval-env-overrides
437
+ n_envs: 24
438
+ bots:
439
+ coacAI: 12
440
+ randomBiasedAI: 4
441
+ workerRushAI: 4
442
+ lightRushAI: 4
443
+
444
+ Microrts-selfplay-unet-debug-mps:
445
+ <<: *microrts-selfplay-debug-defaults
446
+ device: mps
447
+
448
  HalfCheetahBulletEnv-v0: &pybullet-defaults
449
  n_timesteps: !!float 2e6
450
  env_hyperparams: &pybullet-env-defaults
 
531
  learning_rate: !!float 5e-4
532
  # learning_rate_decay: linear
533
  vf_coef: 0.5
534
+ eval_hyperparams: &procgen-eval-defaults
535
  ignore_first_episode: true
536
  # deterministic: false
537
  step_freq: !!float 1e5
 
579
  batch_size: 8192
580
  clip_range_decay: linear
581
  learning_rate_decay: linear
582
+ eval_hyperparams:
583
  <<: *procgen-eval-defaults
584
  step_freq: !!float 5e5
585
 
rl_algo_impls/hyperparams/vpg.yml CHANGED
@@ -7,7 +7,7 @@ CartPole-v1: &cartpole-defaults
7
  gae_lambda: 1
8
  val_lr: 0.01
9
  train_v_iters: 80
10
- eval_params:
11
  step_freq: !!float 2.5e4
12
 
13
  CartPole-v0:
@@ -52,7 +52,7 @@ MountainCarContinuous-v0:
52
  val_lr: !!float 1e-3
53
  train_v_iters: 80
54
  max_grad_norm: 5
55
- eval_params:
56
  step_freq: 5000
57
 
58
  Acrobot-v1:
@@ -78,7 +78,7 @@ LunarLander-v2:
78
  val_lr: 0.0001
79
  train_v_iters: 80
80
  max_grad_norm: 0.5
81
- eval_params:
82
  deterministic: false
83
 
84
  BipedalWalker-v3:
@@ -96,7 +96,7 @@ BipedalWalker-v3:
96
  val_lr: !!float 1e-4
97
  train_v_iters: 80
98
  max_grad_norm: 0.5
99
- eval_params:
100
  deterministic: false
101
 
102
  CarRacing-v0:
@@ -169,7 +169,7 @@ FrozenLake-v1:
169
  val_lr: 0.01
170
  train_v_iters: 80
171
  max_grad_norm: 0.5
172
- eval_params:
173
  step_freq: !!float 5e4
174
  n_episodes: 10
175
  save_best: true
@@ -193,5 +193,5 @@ _atari: &atari-defaults
193
  train_v_iters: 80
194
  max_grad_norm: 0.5
195
  ent_coef: 0.01
196
- eval_params:
197
  deterministic: false
 
7
  gae_lambda: 1
8
  val_lr: 0.01
9
  train_v_iters: 80
10
+ eval_hyperparams:
11
  step_freq: !!float 2.5e4
12
 
13
  CartPole-v0:
 
52
  val_lr: !!float 1e-3
53
  train_v_iters: 80
54
  max_grad_norm: 5
55
+ eval_hyperparams:
56
  step_freq: 5000
57
 
58
  Acrobot-v1:
 
78
  val_lr: 0.0001
79
  train_v_iters: 80
80
  max_grad_norm: 0.5
81
+ eval_hyperparams:
82
  deterministic: false
83
 
84
  BipedalWalker-v3:
 
96
  val_lr: !!float 1e-4
97
  train_v_iters: 80
98
  max_grad_norm: 0.5
99
+ eval_hyperparams:
100
  deterministic: false
101
 
102
  CarRacing-v0:
 
169
  val_lr: 0.01
170
  train_v_iters: 80
171
  max_grad_norm: 0.5
172
+ eval_hyperparams:
173
  step_freq: !!float 5e4
174
  n_episodes: 10
175
  save_best: true
 
193
  train_v_iters: 80
194
  max_grad_norm: 0.5
195
  ent_coef: 0.01
196
+ eval_hyperparams:
197
  deterministic: false
rl_algo_impls/optimize.py CHANGED
@@ -2,37 +2,44 @@ import dataclasses
2
  import gc
3
  import inspect
4
  import logging
 
 
 
 
5
  import numpy as np
6
  import optuna
7
- import os
8
  import torch
9
- import wandb
10
-
11
- from dataclasses import asdict, dataclass
12
  from optuna.pruners import HyperbandPruner
13
  from optuna.samplers import TPESampler
14
  from optuna.visualization import plot_optimization_history, plot_param_importances
15
  from torch.utils.tensorboard.writer import SummaryWriter
16
- from typing import Callable, List, NamedTuple, Optional, Sequence, Union
17
 
 
18
  from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
19
  from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
20
- from rl_algo_impls.shared.vec_env import make_env, make_eval_env
21
  from rl_algo_impls.runner.running_utils import (
 
22
  base_parser,
23
- load_hyperparams,
24
- set_seeds,
25
  get_device,
26
- make_policy,
27
- ALGOS,
28
  hparam_dict,
 
 
 
 
 
 
 
29
  )
30
  from rl_algo_impls.shared.callbacks.optimize_callback import (
31
  Evaluation,
32
  OptimizeCallback,
33
  evaluation,
34
  )
 
35
  from rl_algo_impls.shared.stats import EpisodesStats
 
 
 
36
 
37
 
38
  @dataclass
@@ -195,29 +202,38 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -
195
  config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
196
  )
197
  device = get_device(config, env)
198
- policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
 
 
 
199
  algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
200
 
201
  eval_env = make_eval_env(
202
  config,
203
  EnvHyperparams(**config.env_hyperparams),
204
- override_n_envs=study_args.n_eval_envs,
205
  )
206
- callback = OptimizeCallback(
207
  policy,
208
  eval_env,
209
  trial,
210
  tb_writer,
211
  step_freq=config.n_timesteps // study_args.n_evaluations,
212
  n_episodes=study_args.n_eval_episodes,
213
- deterministic=config.eval_params.get("deterministic", True),
214
  )
 
 
 
 
 
 
215
  try:
216
- algo.learn(config.n_timesteps, callback=callback)
217
 
218
- if not callback.is_pruned:
219
- callback.evaluate()
220
- if not callback.is_pruned:
221
  policy.save(config.model_dir_path(best=False))
222
 
223
  eval_stat: EpisodesStats = callback.last_eval_stat # type: ignore
@@ -230,8 +246,8 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -
230
  "hparam/last_result": eval_stat.score.mean - eval_stat.score.std,
231
  "hparam/train_mean": train_stat.score.mean,
232
  "hparam/train_result": train_stat.score.mean - train_stat.score.std,
233
- "hparam/score": callback.last_score,
234
- "hparam/is_pruned": callback.is_pruned,
235
  },
236
  None,
237
  config.run_name(),
@@ -239,13 +255,15 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -
239
  tb_writer.close()
240
 
241
  if wandb_enabled:
242
- wandb.run.summary["state"] = "Pruned" if callback.is_pruned else "Complete"
 
 
243
  wandb.finish(quiet=True)
244
 
245
- if callback.is_pruned:
246
  raise optuna.exceptions.TrialPruned()
247
 
248
- return callback.last_score
249
  except AssertionError as e:
250
  logging.warning(e)
251
  return np.nan
@@ -299,7 +317,10 @@ def stepwise_optimize(
299
  tb_writer=tb_writer,
300
  )
301
  device = get_device(config, env)
302
- policy = make_policy(arg.algo, env, device, **config.policy_hyperparams)
 
 
 
303
  if i > 0:
304
  policy.load(config.model_dir_path())
305
  algo = ALGOS[arg.algo](
@@ -310,7 +331,7 @@ def stepwise_optimize(
310
  config,
311
  EnvHyperparams(**config.env_hyperparams),
312
  normalize_load_path=config.model_dir_path() if i > 0 else None,
313
- override_n_envs=study_args.n_eval_envs,
314
  )
315
 
316
  start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
@@ -319,10 +340,22 @@ def stepwise_optimize(
319
  - start_timesteps
320
  )
321
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  try:
323
  algo.learn(
324
  train_timesteps,
325
- callback=None,
326
  total_timesteps=config.n_timesteps,
327
  start_timesteps=start_timesteps,
328
  )
@@ -333,7 +366,7 @@ def stepwise_optimize(
333
  eval_env,
334
  tb_writer,
335
  study_args.n_eval_episodes,
336
- config.eval_params.get("deterministic", True),
337
  start_timesteps + train_timesteps,
338
  )
339
  )
@@ -379,7 +412,7 @@ def stepwise_optimize(
379
 
380
 
381
  def wandb_finish(state: str) -> None:
382
- wandb.run.summary["state"] = state
383
  wandb.finish(quiet=True)
384
 
385
 
 
2
  import gc
3
  import inspect
4
  import logging
5
+ import os
6
+ from dataclasses import asdict, dataclass
7
+ from typing import Callable, List, NamedTuple, Optional, Sequence, Union
8
+
9
  import numpy as np
10
  import optuna
 
11
  import torch
 
 
 
12
  from optuna.pruners import HyperbandPruner
13
  from optuna.samplers import TPESampler
14
  from optuna.visualization import plot_optimization_history, plot_param_importances
15
  from torch.utils.tensorboard.writer import SummaryWriter
 
16
 
17
+ import wandb
18
  from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
19
  from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
 
20
  from rl_algo_impls.runner.running_utils import (
21
+ ALGOS,
22
  base_parser,
 
 
23
  get_device,
 
 
24
  hparam_dict,
25
+ load_hyperparams,
26
+ make_policy,
27
+ set_seeds,
28
+ )
29
+ from rl_algo_impls.shared.callbacks import Callback
30
+ from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import (
31
+ MicrortsRewardDecayCallback,
32
  )
33
  from rl_algo_impls.shared.callbacks.optimize_callback import (
34
  Evaluation,
35
  OptimizeCallback,
36
  evaluation,
37
  )
38
+ from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback
39
  from rl_algo_impls.shared.stats import EpisodesStats
40
+ from rl_algo_impls.shared.vec_env import make_env, make_eval_env
41
+ from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
42
+ from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper
43
 
44
 
45
  @dataclass
 
202
  config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
203
  )
204
  device = get_device(config, env)
205
+ policy_factory = lambda: make_policy(
206
+ args.algo, env, device, **config.policy_hyperparams
207
+ )
208
+ policy = policy_factory()
209
  algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
210
 
211
  eval_env = make_eval_env(
212
  config,
213
  EnvHyperparams(**config.env_hyperparams),
214
+ override_hparams={"n_envs": study_args.n_eval_envs},
215
  )
216
+ optimize_callback = OptimizeCallback(
217
  policy,
218
  eval_env,
219
  trial,
220
  tb_writer,
221
  step_freq=config.n_timesteps // study_args.n_evaluations,
222
  n_episodes=study_args.n_eval_episodes,
223
+ deterministic=config.eval_hyperparams.get("deterministic", True),
224
  )
225
+ callbacks: List[Callback] = [optimize_callback]
226
+ if config.hyperparams.microrts_reward_decay_callback:
227
+ callbacks.append(MicrortsRewardDecayCallback(config, env))
228
+ selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
229
+ if selfPlayWrapper:
230
+ callbacks.append(SelfPlayCallback(policy, policy_factory, selfPlayWrapper))
231
  try:
232
+ algo.learn(config.n_timesteps, callbacks=callbacks)
233
 
234
+ if not optimize_callback.is_pruned:
235
+ optimize_callback.evaluate()
236
+ if not optimize_callback.is_pruned:
237
  policy.save(config.model_dir_path(best=False))
238
 
239
  eval_stat: EpisodesStats = callback.last_eval_stat # type: ignore
 
246
  "hparam/last_result": eval_stat.score.mean - eval_stat.score.std,
247
  "hparam/train_mean": train_stat.score.mean,
248
  "hparam/train_result": train_stat.score.mean - train_stat.score.std,
249
+ "hparam/score": optimize_callback.last_score,
250
+ "hparam/is_pruned": optimize_callback.is_pruned,
251
  },
252
  None,
253
  config.run_name(),
 
255
  tb_writer.close()
256
 
257
  if wandb_enabled:
258
+ wandb.run.summary["state"] = ( # type: ignore
259
+ "Pruned" if optimize_callback.is_pruned else "Complete"
260
+ )
261
  wandb.finish(quiet=True)
262
 
263
+ if optimize_callback.is_pruned:
264
  raise optuna.exceptions.TrialPruned()
265
 
266
+ return optimize_callback.last_score
267
  except AssertionError as e:
268
  logging.warning(e)
269
  return np.nan
 
317
  tb_writer=tb_writer,
318
  )
319
  device = get_device(config, env)
320
+ policy_factory = lambda: make_policy(
321
+ arg.algo, env, device, **config.policy_hyperparams
322
+ )
323
+ policy = policy_factory()
324
  if i > 0:
325
  policy.load(config.model_dir_path())
326
  algo = ALGOS[arg.algo](
 
331
  config,
332
  EnvHyperparams(**config.env_hyperparams),
333
  normalize_load_path=config.model_dir_path() if i > 0 else None,
334
+ override_hparams={"n_envs": study_args.n_eval_envs},
335
  )
336
 
337
  start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
 
340
  - start_timesteps
341
  )
342
 
343
+ callbacks = []
344
+ if config.hyperparams.microrts_reward_decay_callback:
345
+ callbacks.append(
346
+ MicrortsRewardDecayCallback(
347
+ config, env, start_timesteps=start_timesteps
348
+ )
349
+ )
350
+ selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
351
+ if selfPlayWrapper:
352
+ callbacks.append(
353
+ SelfPlayCallback(policy, policy_factory, selfPlayWrapper)
354
+ )
355
  try:
356
  algo.learn(
357
  train_timesteps,
358
+ callbacks=callbacks,
359
  total_timesteps=config.n_timesteps,
360
  start_timesteps=start_timesteps,
361
  )
 
366
  eval_env,
367
  tb_writer,
368
  study_args.n_eval_episodes,
369
+ config.eval_hyperparams.get("deterministic", True),
370
  start_timesteps + train_timesteps,
371
  )
372
  )
 
412
 
413
 
414
  def wandb_finish(state: str) -> None:
415
+ wandb.run.summary["state"] = state # type: ignore
416
  wandb.finish(quiet=True)
417
 
418
 
rl_algo_impls/ppo/ppo.py CHANGED
@@ -10,12 +10,16 @@ from torch.optim import Adam
10
  from torch.utils.tensorboard.writer import SummaryWriter
11
 
12
  from rl_algo_impls.shared.algorithm import Algorithm
13
- from rl_algo_impls.shared.callbacks.callback import Callback
14
  from rl_algo_impls.shared.gae import compute_advantages
15
- from rl_algo_impls.shared.policy.on_policy import ActorCritic
16
- from rl_algo_impls.shared.schedule import schedule, update_learning_rate
 
 
 
 
 
17
  from rl_algo_impls.shared.stats import log_scalars
18
- from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
19
  from rl_algo_impls.wrappers.vectorable_wrapper import (
20
  VecEnv,
21
  single_action_space,
@@ -102,12 +106,17 @@ class PPO(Algorithm):
102
  sde_sample_freq: int = -1,
103
  update_advantage_between_epochs: bool = True,
104
  update_returns_between_epochs: bool = False,
 
105
  ) -> None:
106
  super().__init__(policy, env, device, tb_writer)
107
  self.policy = policy
108
- self.action_masker = find_action_masker(env)
109
 
110
- self.gamma = gamma
 
 
 
 
111
  self.gae_lambda = gae_lambda
112
  self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
113
  self.lr_schedule = schedule(learning_rate_decay, learning_rate)
@@ -138,7 +147,7 @@ class PPO(Algorithm):
138
  def learn(
139
  self: PPOSelf,
140
  train_timesteps: int,
141
- callback: Optional[Callback] = None,
142
  total_timesteps: Optional[int] = None,
143
  start_timesteps: int = 0,
144
  ) -> PPOSelf:
@@ -153,15 +162,13 @@ class PPO(Algorithm):
153
  act_shape = self.policy.action_shape
154
 
155
  next_obs = self.env.reset()
156
- next_action_masks = (
157
- self.action_masker.action_masks() if self.action_masker else None
158
- )
159
- next_episode_starts = np.full(step_dim, True, dtype=np.bool8)
160
 
161
  obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype) # type: ignore
162
  actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype) # type: ignore
163
  rewards = np.zeros(epoch_dim, dtype=np.float32)
164
- episode_starts = np.zeros(epoch_dim, dtype=np.bool8)
165
  values = np.zeros(epoch_dim, dtype=np.float32)
166
  logprobs = np.zeros(epoch_dim, dtype=np.float32)
167
  action_masks = (
@@ -181,10 +188,12 @@ class PPO(Algorithm):
181
  learning_rate = self.lr_schedule(progress)
182
  update_learning_rate(self.optimizer, learning_rate)
183
  pi_clip = self.clip_range_schedule(progress)
 
184
  chart_scalars = {
185
  "learning_rate": self.optimizer.param_groups[0]["lr"],
186
  "ent_coef": ent_coef,
187
  "pi_clip": pi_clip,
 
188
  }
189
  if self.clip_range_vf_schedule:
190
  v_clip = self.clip_range_vf_schedule(progress)
@@ -215,7 +224,7 @@ class PPO(Algorithm):
215
  clamped_action
216
  )
217
  next_action_masks = (
218
- self.action_masker.action_masks() if self.action_masker else None
219
  )
220
 
221
  self.policy.train()
@@ -251,7 +260,7 @@ class PPO(Algorithm):
251
  next_episode_starts,
252
  next_obs,
253
  self.policy,
254
- self.gamma,
255
  self.gae_lambda,
256
  )
257
  b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
@@ -364,8 +373,10 @@ class PPO(Algorithm):
364
  timesteps_elapsed,
365
  )
366
 
367
- if callback:
368
- if not callback.on_step(timesteps_elapsed=rollout_steps):
 
 
369
  logging.info(
370
  f"Callback terminated training at {timesteps_elapsed} timesteps"
371
  )
 
10
  from torch.utils.tensorboard.writer import SummaryWriter
11
 
12
  from rl_algo_impls.shared.algorithm import Algorithm
13
+ from rl_algo_impls.shared.callbacks import Callback
14
  from rl_algo_impls.shared.gae import compute_advantages
15
+ from rl_algo_impls.shared.policy.actor_critic import ActorCritic
16
+ from rl_algo_impls.shared.schedule import (
17
+ constant_schedule,
18
+ linear_schedule,
19
+ schedule,
20
+ update_learning_rate,
21
+ )
22
  from rl_algo_impls.shared.stats import log_scalars
 
23
  from rl_algo_impls.wrappers.vectorable_wrapper import (
24
  VecEnv,
25
  single_action_space,
 
106
  sde_sample_freq: int = -1,
107
  update_advantage_between_epochs: bool = True,
108
  update_returns_between_epochs: bool = False,
109
+ gamma_end: Optional[float] = None,
110
  ) -> None:
111
  super().__init__(policy, env, device, tb_writer)
112
  self.policy = policy
113
+ self.get_action_mask = getattr(env, "get_action_mask", None)
114
 
115
+ self.gamma_schedule = (
116
+ linear_schedule(gamma, gamma_end)
117
+ if gamma_end is not None
118
+ else constant_schedule(gamma)
119
+ )
120
  self.gae_lambda = gae_lambda
121
  self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
122
  self.lr_schedule = schedule(learning_rate_decay, learning_rate)
 
147
  def learn(
148
  self: PPOSelf,
149
  train_timesteps: int,
150
+ callbacks: Optional[List[Callback]] = None,
151
  total_timesteps: Optional[int] = None,
152
  start_timesteps: int = 0,
153
  ) -> PPOSelf:
 
162
  act_shape = self.policy.action_shape
163
 
164
  next_obs = self.env.reset()
165
+ next_action_masks = self.get_action_mask() if self.get_action_mask else None
166
+ next_episode_starts = np.full(step_dim, True, dtype=np.bool_)
 
 
167
 
168
  obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype) # type: ignore
169
  actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype) # type: ignore
170
  rewards = np.zeros(epoch_dim, dtype=np.float32)
171
+ episode_starts = np.zeros(epoch_dim, dtype=np.bool_)
172
  values = np.zeros(epoch_dim, dtype=np.float32)
173
  logprobs = np.zeros(epoch_dim, dtype=np.float32)
174
  action_masks = (
 
188
  learning_rate = self.lr_schedule(progress)
189
  update_learning_rate(self.optimizer, learning_rate)
190
  pi_clip = self.clip_range_schedule(progress)
191
+ gamma = self.gamma_schedule(progress)
192
  chart_scalars = {
193
  "learning_rate": self.optimizer.param_groups[0]["lr"],
194
  "ent_coef": ent_coef,
195
  "pi_clip": pi_clip,
196
+ "gamma": gamma,
197
  }
198
  if self.clip_range_vf_schedule:
199
  v_clip = self.clip_range_vf_schedule(progress)
 
224
  clamped_action
225
  )
226
  next_action_masks = (
227
+ self.get_action_mask() if self.get_action_mask else None
228
  )
229
 
230
  self.policy.train()
 
260
  next_episode_starts,
261
  next_obs,
262
  self.policy,
263
+ gamma,
264
  self.gae_lambda,
265
  )
266
  b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
 
373
  timesteps_elapsed,
374
  )
375
 
376
+ if callbacks:
377
+ if not all(
378
+ c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
379
+ ):
380
  logging.info(
381
  f"Callback terminated training at {timesteps_elapsed} timesteps"
382
  )
rl_algo_impls/runner/config.py CHANGED
@@ -51,6 +51,8 @@ class EnvHyperparams:
51
  normalize_type: Optional[str] = None
52
  mask_actions: bool = False
53
  bots: Optional[Dict[str, int]] = None
 
 
54
 
55
 
56
  HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
@@ -63,9 +65,10 @@ class Hyperparams:
63
  env_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
64
  policy_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
65
  algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
66
- eval_params: Dict[str, Any] = dataclasses.field(default_factory=dict)
67
  env_id: Optional[str] = None
68
  additional_keys_to_log: List[str] = dataclasses.field(default_factory=list)
 
69
 
70
  @classmethod
71
  def from_dict_with_extra_fields(
@@ -110,8 +113,14 @@ class Config:
110
  return self.hyperparams.algo_hyperparams
111
 
112
  @property
113
- def eval_params(self) -> Dict[str, Any]:
114
- return self.hyperparams.eval_params
 
 
 
 
 
 
115
 
116
  @property
117
  def algo(self) -> str:
 
51
  normalize_type: Optional[str] = None
52
  mask_actions: bool = False
53
  bots: Optional[Dict[str, int]] = None
54
+ self_play_kwargs: Optional[Dict[str, Any]] = None
55
+ selfplay_bots: Optional[Dict[str, int]] = None
56
 
57
 
58
  HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
 
65
  env_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
66
  policy_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
67
  algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
68
+ eval_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
69
  env_id: Optional[str] = None
70
  additional_keys_to_log: List[str] = dataclasses.field(default_factory=list)
71
+ microrts_reward_decay_callback: bool = False
72
 
73
  @classmethod
74
  def from_dict_with_extra_fields(
 
113
  return self.hyperparams.algo_hyperparams
114
 
115
  @property
116
+ def eval_hyperparams(self) -> Dict[str, Any]:
117
+ return self.hyperparams.eval_hyperparams
118
+
119
+ def eval_callback_params(self) -> Dict[str, Any]:
120
+ eval_hyperparams = self.eval_hyperparams.copy()
121
+ if "env_overrides" in eval_hyperparams:
122
+ del eval_hyperparams["env_overrides"]
123
+ return eval_hyperparams
124
 
125
  @property
126
  def algo(self) -> str:
rl_algo_impls/runner/evaluate.py CHANGED
@@ -1,20 +1,19 @@
1
  import os
2
  import shutil
3
-
4
  from dataclasses import dataclass
5
  from typing import NamedTuple, Optional
6
 
7
- from rl_algo_impls.shared.vec_env import make_eval_env
8
  from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
9
  from rl_algo_impls.runner.running_utils import (
10
- load_hyperparams,
11
- set_seeds,
12
  get_device,
 
13
  make_policy,
 
14
  )
15
  from rl_algo_impls.shared.callbacks.eval_callback import evaluate
16
  from rl_algo_impls.shared.policy.policy import Policy
17
  from rl_algo_impls.shared.stats import EpisodesStats
 
18
 
19
 
20
  @dataclass
@@ -71,7 +70,7 @@ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
71
  env = make_eval_env(
72
  config,
73
  EnvHyperparams(**config.env_hyperparams),
74
- override_n_envs=args.n_envs,
75
  render=args.render,
76
  normalize_load_path=model_path,
77
  )
@@ -87,7 +86,7 @@ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
87
  deterministic = (
88
  args.deterministic_eval
89
  if args.deterministic_eval is not None
90
- else config.eval_params.get("deterministic", True)
91
  )
92
  return Evaluation(
93
  policy,
 
1
  import os
2
  import shutil
 
3
  from dataclasses import dataclass
4
  from typing import NamedTuple, Optional
5
 
 
6
  from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
7
  from rl_algo_impls.runner.running_utils import (
 
 
8
  get_device,
9
+ load_hyperparams,
10
  make_policy,
11
+ set_seeds,
12
  )
13
  from rl_algo_impls.shared.callbacks.eval_callback import evaluate
14
  from rl_algo_impls.shared.policy.policy import Policy
15
  from rl_algo_impls.shared.stats import EpisodesStats
16
+ from rl_algo_impls.shared.vec_env import make_eval_env
17
 
18
 
19
  @dataclass
 
70
  env = make_eval_env(
71
  config,
72
  EnvHyperparams(**config.env_hyperparams),
73
+ override_hparams={"n_envs": args.n_envs} if args.n_envs else None,
74
  render=args.render,
75
  normalize_load_path=model_path,
76
  )
 
86
  deterministic = (
87
  args.deterministic_eval
88
  if args.deterministic_eval is not None
89
+ else config.eval_hyperparams.get("deterministic", True)
90
  )
91
  return Evaluation(
92
  policy,
rl_algo_impls/runner/running_utils.py CHANGED
@@ -22,7 +22,7 @@ from rl_algo_impls.ppo.ppo import PPO
22
  from rl_algo_impls.runner.config import Config, Hyperparams
23
  from rl_algo_impls.shared.algorithm import Algorithm
24
  from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
25
- from rl_algo_impls.shared.policy.on_policy import ActorCritic
26
  from rl_algo_impls.shared.policy.policy import Policy
27
  from rl_algo_impls.shared.vec_env.utils import import_for_env_id, is_microrts
28
  from rl_algo_impls.vpg.policy import VPGActorCritic
@@ -97,29 +97,21 @@ def get_device(config: Config, env: VecEnv) -> torch.device:
97
  # cuda by default
98
  if device == "auto":
99
  device = "cuda"
100
- # Apple MPS is a second choice (sometimes)
101
- if device == "cuda" and not torch.cuda.is_available():
102
- device = "mps"
103
- # If no MPS, fallback to cpu
104
- if device == "mps" and not torch.backends.mps.is_available():
105
- device = "cpu"
106
- # Simple environments like Discreet and 1-D Boxes might also be better
107
- # served with the CPU.
108
- if device == "mps":
109
- obs_space = single_observation_space(env)
110
- if isinstance(obs_space, Discrete):
111
  device = "cpu"
112
- elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
113
- device = "cpu"
114
- if is_microrts(config):
115
- try:
116
- from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
117
-
118
- # Models that move more than one unit at a time should use mps
119
- if not isinstance(env.unwrapped, MicroRTSGridModeVecEnv):
120
- device = "cpu"
121
- except ModuleNotFoundError:
122
- # Likely on gym_microrts v0.0.2 to match ppo-implementation-details
123
  device = "cpu"
124
  print(f"Device: {device}")
125
  return torch.device(device)
 
22
  from rl_algo_impls.runner.config import Config, Hyperparams
23
  from rl_algo_impls.shared.algorithm import Algorithm
24
  from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
25
+ from rl_algo_impls.shared.policy.actor_critic import ActorCritic
26
  from rl_algo_impls.shared.policy.policy import Policy
27
  from rl_algo_impls.shared.vec_env.utils import import_for_env_id, is_microrts
28
  from rl_algo_impls.vpg.policy import VPGActorCritic
 
97
  # cuda by default
98
  if device == "auto":
99
  device = "cuda"
100
+ # Apple MPS is a second choice (sometimes)
101
+ if device == "cuda" and not torch.cuda.is_available():
102
+ device = "mps"
103
+ # If no MPS, fallback to cpu
104
+ if device == "mps" and not torch.backends.mps.is_available():
 
 
 
 
 
 
105
  device = "cpu"
106
+ # Simple environments like Discreet and 1-D Boxes might also be better
107
+ # served with the CPU.
108
+ if device == "mps":
109
+ obs_space = single_observation_space(env)
110
+ if isinstance(obs_space, Discrete):
111
+ device = "cpu"
112
+ elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
113
+ device = "cpu"
114
+ if is_microrts(config):
 
 
115
  device = "cpu"
116
  print(f"Device: {device}")
117
  return torch.device(device)
rl_algo_impls/runner/selfplay_evaluate.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import dataclasses
3
+ import os
4
+ import shutil
5
+ from dataclasses import dataclass
6
+ from typing import List, NamedTuple, Optional
7
+
8
+ import numpy as np
9
+
10
+ import wandb
11
+ from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
12
+ from rl_algo_impls.runner.evaluate import Evaluation
13
+ from rl_algo_impls.runner.running_utils import (
14
+ get_device,
15
+ load_hyperparams,
16
+ make_policy,
17
+ set_seeds,
18
+ )
19
+ from rl_algo_impls.shared.callbacks.eval_callback import evaluate
20
+ from rl_algo_impls.shared.vec_env import make_eval_env
21
+ from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
22
+
23
+
24
+ @dataclass
25
+ class SelfplayEvalArgs(RunArgs):
26
+ # Either wandb_run_paths or model_file_paths must have 2 elements in it.
27
+ wandb_run_paths: List[str] = dataclasses.field(default_factory=list)
28
+ model_file_paths: List[str] = dataclasses.field(default_factory=list)
29
+ render: bool = False
30
+ best: bool = True
31
+ n_envs: int = 1
32
+ n_episodes: int = 1
33
+ deterministic_eval: Optional[bool] = None
34
+ no_print_returns: bool = False
35
+ video_path: Optional[str] = None
36
+
37
+
38
+ def selfplay_evaluate(args: SelfplayEvalArgs, root_dir: str) -> Evaluation:
39
+ if args.wandb_run_paths:
40
+ api = wandb.Api()
41
+ args, config, player_1_model_path = load_player(
42
+ api, args.wandb_run_paths[0], args, root_dir
43
+ )
44
+ _, _, player_2_model_path = load_player(
45
+ api, args.wandb_run_paths[1], args, root_dir
46
+ )
47
+ elif args.model_file_paths:
48
+ hyperparams = load_hyperparams(args.algo, args.env)
49
+
50
+ config = Config(args, hyperparams, root_dir)
51
+ player_1_model_path, player_2_model_path = args.model_file_paths
52
+ else:
53
+ raise ValueError("Must specify 2 wandb_run_paths or 2 model_file_paths")
54
+
55
+ print(args)
56
+
57
+ set_seeds(args.seed, args.use_deterministic_algorithms)
58
+
59
+ env_make_kwargs = (
60
+ config.eval_hyperparams.get("env_overrides", {}).get("make_kwargs", {}).copy()
61
+ )
62
+ env_make_kwargs["num_selfplay_envs"] = args.n_envs * 2
63
+ env = make_eval_env(
64
+ config,
65
+ EnvHyperparams(**config.env_hyperparams),
66
+ override_hparams={
67
+ "n_envs": args.n_envs,
68
+ "selfplay_bots": {
69
+ player_2_model_path: args.n_envs,
70
+ },
71
+ "self_play_kwargs": {
72
+ "num_old_policies": 0,
73
+ "save_steps": np.inf,
74
+ "swap_steps": np.inf,
75
+ "bot_always_player_2": True,
76
+ },
77
+ "bots": None,
78
+ "make_kwargs": env_make_kwargs,
79
+ },
80
+ render=args.render,
81
+ normalize_load_path=player_1_model_path,
82
+ )
83
+ if args.video_path:
84
+ env = VecEpisodeRecorder(
85
+ env, args.video_path, max_video_length=18000, num_episodes=args.n_episodes
86
+ )
87
+ device = get_device(config, env)
88
+ policy = make_policy(
89
+ args.algo,
90
+ env,
91
+ device,
92
+ load_path=player_1_model_path,
93
+ **config.policy_hyperparams,
94
+ ).eval()
95
+
96
+ deterministic = (
97
+ args.deterministic_eval
98
+ if args.deterministic_eval is not None
99
+ else config.eval_hyperparams.get("deterministic", True)
100
+ )
101
+ return Evaluation(
102
+ policy,
103
+ evaluate(
104
+ env,
105
+ policy,
106
+ args.n_episodes,
107
+ render=args.render,
108
+ deterministic=deterministic,
109
+ print_returns=not args.no_print_returns,
110
+ ),
111
+ config,
112
+ )
113
+
114
+
115
+ class PlayerData(NamedTuple):
116
+ args: SelfplayEvalArgs
117
+ config: Config
118
+ model_path: str
119
+
120
+
121
+ def load_player(
122
+ api: wandb.Api, run_path: str, args: SelfplayEvalArgs, root_dir: str
123
+ ) -> PlayerData:
124
+ args = copy.copy(args)
125
+
126
+ run = api.run(run_path)
127
+ params = run.config
128
+ args.algo = params["algo"]
129
+ args.env = params["env"]
130
+ args.seed = params.get("seed", None)
131
+ args.use_deterministic_algorithms = params.get("use_deterministic_algorithms", True)
132
+ config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir)
133
+ model_path = config.model_dir_path(best=args.best, downloaded=True)
134
+
135
+ model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
136
+ run.file(model_archive_name).download()
137
+ if os.path.isdir(model_path):
138
+ shutil.rmtree(model_path)
139
+ shutil.unpack_archive(model_archive_name, model_path)
140
+ os.remove(model_archive_name)
141
+
142
+ return PlayerData(args, config, model_path)
rl_algo_impls/runner/train.py CHANGED
@@ -1,12 +1,17 @@
1
  # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
  import os
3
 
 
 
 
 
 
4
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
 
6
  import dataclasses
7
  import shutil
8
  from dataclasses import asdict, dataclass
9
- from typing import Any, Dict, Optional, Sequence
10
 
11
  import yaml
12
  from torch.utils.tensorboard.writer import SummaryWriter
@@ -23,6 +28,9 @@ from rl_algo_impls.runner.running_utils import (
23
  set_seeds,
24
  )
25
  from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
 
 
 
26
  from rl_algo_impls.shared.stats import EpisodesStats
27
  from rl_algo_impls.shared.vec_env import make_env, make_eval_env
28
 
@@ -41,7 +49,7 @@ def train(args: TrainArgs):
41
  print(hyperparams)
42
  config = Config(args, hyperparams, os.getcwd())
43
 
44
- wandb_enabled = args.wandb_project_name
45
  if wandb_enabled:
46
  wandb.tensorboard.patch(
47
  root_logdir=config.tensorboard_summary_path, pytorch=True
@@ -66,14 +74,17 @@ def train(args: TrainArgs):
66
  config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
67
  )
68
  device = get_device(config, env)
69
- policy = make_policy(args.algo, env, device, **config.policy_hyperparams)
 
 
 
70
  algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
71
 
72
  num_parameters = policy.num_parameters()
73
  num_trainable_parameters = policy.num_trainable_parameters()
74
  if wandb_enabled:
75
- wandb.run.summary["num_parameters"] = num_parameters
76
- wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters
77
  else:
78
  print(
79
  f"num_parameters = {num_parameters} ; "
@@ -81,40 +92,49 @@ def train(args: TrainArgs):
81
  )
82
 
83
  eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams))
84
- record_best_videos = config.eval_params.get("record_best_videos", True)
85
- callback = EvalCallback(
86
  policy,
87
  eval_env,
88
  tb_writer,
89
  best_model_path=config.model_dir_path(best=True),
90
- **config.eval_params,
91
  video_env=make_eval_env(
92
- config, EnvHyperparams(**config.env_hyperparams), override_n_envs=1
 
 
93
  )
94
  if record_best_videos
95
  else None,
96
  best_video_dir=config.best_videos_dir,
97
  additional_keys_to_log=config.additional_keys_to_log,
 
98
  )
99
- algo.learn(config.n_timesteps, callback=callback)
 
 
 
 
 
 
100
 
101
  policy.save(config.model_dir_path(best=False))
102
 
103
- eval_stats = callback.evaluate(n_episodes=10, print_returns=True)
104
 
105
- plot_eval_callback(callback, tb_writer, config.run_name())
106
 
107
  log_dict: Dict[str, Any] = {
108
  "eval": eval_stats._asdict(),
109
  }
110
- if callback.best:
111
- log_dict["best_eval"] = callback.best._asdict()
112
  log_dict.update(asdict(hyperparams))
113
  log_dict.update(vars(args))
114
  with open(config.logs_path, "a") as f:
115
  yaml.dump({config.run_name(): log_dict}, f)
116
 
117
- best_eval_stats: EpisodesStats = callback.best # type: ignore
118
  tb_writer.add_hparams(
119
  hparam_dict(hyperparams, vars(args)),
120
  {
@@ -132,13 +152,8 @@ def train(args: TrainArgs):
132
 
133
  if wandb_enabled:
134
  shutil.make_archive(
135
- os.path.join(wandb.run.dir, config.model_dir_name()),
136
  "zip",
137
  config.model_dir_path(),
138
  )
139
- shutil.make_archive(
140
- os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
141
- "zip",
142
- config.model_dir_path(best=True),
143
- )
144
  wandb.finish()
 
1
  # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
  import os
3
 
4
+ from rl_algo_impls.shared.callbacks import Callback
5
+ from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback
6
+ from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
7
+ from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper
8
+
9
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
10
 
11
  import dataclasses
12
  import shutil
13
  from dataclasses import asdict, dataclass
14
+ from typing import Any, Dict, List, Optional, Sequence
15
 
16
  import yaml
17
  from torch.utils.tensorboard.writer import SummaryWriter
 
28
  set_seeds,
29
  )
30
  from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
31
+ from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import (
32
+ MicrortsRewardDecayCallback,
33
+ )
34
  from rl_algo_impls.shared.stats import EpisodesStats
35
  from rl_algo_impls.shared.vec_env import make_env, make_eval_env
36
 
 
49
  print(hyperparams)
50
  config = Config(args, hyperparams, os.getcwd())
51
 
52
+ wandb_enabled = bool(args.wandb_project_name)
53
  if wandb_enabled:
54
  wandb.tensorboard.patch(
55
  root_logdir=config.tensorboard_summary_path, pytorch=True
 
74
  config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
75
  )
76
  device = get_device(config, env)
77
+ policy_factory = lambda: make_policy(
78
+ args.algo, env, device, **config.policy_hyperparams
79
+ )
80
+ policy = policy_factory()
81
  algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
82
 
83
  num_parameters = policy.num_parameters()
84
  num_trainable_parameters = policy.num_trainable_parameters()
85
  if wandb_enabled:
86
+ wandb.run.summary["num_parameters"] = num_parameters # type: ignore
87
+ wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters # type: ignore
88
  else:
89
  print(
90
  f"num_parameters = {num_parameters} ; "
 
92
  )
93
 
94
  eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams))
95
+ record_best_videos = config.eval_hyperparams.get("record_best_videos", True)
96
+ eval_callback = EvalCallback(
97
  policy,
98
  eval_env,
99
  tb_writer,
100
  best_model_path=config.model_dir_path(best=True),
101
+ **config.eval_callback_params(),
102
  video_env=make_eval_env(
103
+ config,
104
+ EnvHyperparams(**config.env_hyperparams),
105
+ override_hparams={"n_envs": 1},
106
  )
107
  if record_best_videos
108
  else None,
109
  best_video_dir=config.best_videos_dir,
110
  additional_keys_to_log=config.additional_keys_to_log,
111
+ wandb_enabled=wandb_enabled,
112
  )
113
+ callbacks: List[Callback] = [eval_callback]
114
+ if config.hyperparams.microrts_reward_decay_callback:
115
+ callbacks.append(MicrortsRewardDecayCallback(config, env))
116
+ selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
117
+ if selfPlayWrapper:
118
+ callbacks.append(SelfPlayCallback(policy, policy_factory, selfPlayWrapper))
119
+ algo.learn(config.n_timesteps, callbacks=callbacks)
120
 
121
  policy.save(config.model_dir_path(best=False))
122
 
123
+ eval_stats = eval_callback.evaluate(n_episodes=10, print_returns=True)
124
 
125
+ plot_eval_callback(eval_callback, tb_writer, config.run_name())
126
 
127
  log_dict: Dict[str, Any] = {
128
  "eval": eval_stats._asdict(),
129
  }
130
+ if eval_callback.best:
131
+ log_dict["best_eval"] = eval_callback.best._asdict()
132
  log_dict.update(asdict(hyperparams))
133
  log_dict.update(vars(args))
134
  with open(config.logs_path, "a") as f:
135
  yaml.dump({config.run_name(): log_dict}, f)
136
 
137
+ best_eval_stats: EpisodesStats = eval_callback.best # type: ignore
138
  tb_writer.add_hparams(
139
  hparam_dict(hyperparams, vars(args)),
140
  {
 
152
 
153
  if wandb_enabled:
154
  shutil.make_archive(
155
+ os.path.join(wandb.run.dir, config.model_dir_name()), # type: ignore
156
  "zip",
157
  config.model_dir_path(),
158
  )
 
 
 
 
 
159
  wandb.finish()
rl_algo_impls/selfplay_enjoy.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
2
+ import os
3
+
4
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
5
+
6
+ from rl_algo_impls.runner.running_utils import base_parser
7
+ from rl_algo_impls.runner.selfplay_evaluate import SelfplayEvalArgs, selfplay_evaluate
8
+
9
+
10
+ def selfplay_enjoy() -> None:
11
+ parser = base_parser(multiple=False)
12
+ parser.add_argument(
13
+ "--wandb-run-paths",
14
+ type=str,
15
+ nargs="*",
16
+ help="WandB run paths to load players from. Must be 0 or 2",
17
+ )
18
+ parser.add_argument(
19
+ "--model-file-paths",
20
+ type=str,
21
+ help="File paths to load players from. Must be 0 or 2",
22
+ )
23
+ parser.add_argument("--render", action="store_true")
24
+ parser.add_argument("--n-envs", default=1, type=int)
25
+ parser.add_argument("--n-episodes", default=1, type=int)
26
+ parser.add_argument("--deterministic-eval", default=None, type=bool)
27
+ parser.add_argument(
28
+ "--no-print-returns", action="store_true", help="Limit printing"
29
+ )
30
+ parser.add_argument(
31
+ "--video-path", type=str, help="Path to save video of all plays"
32
+ )
33
+ # parser.set_defaults(
34
+ # algo=["ppo"],
35
+ # env=["Microrts-selfplay-unet-decay"],
36
+ # n_episodes=10,
37
+ # model_file_paths=[
38
+ # "downloaded_models/ppo-Microrts-selfplay-unet-decay-S3-best",
39
+ # "downloaded_models/ppo-Microrts-selfplay-unet-decay-S2-best",
40
+ # ],
41
+ # video_path="/Users/sgoodfriend/Desktop/decay3-vs-decay2",
42
+ # )
43
+ args = parser.parse_args()
44
+ args.algo = args.algo[0]
45
+ args.env = args.env[0]
46
+ args.seed = args.seed[0]
47
+ args = SelfplayEvalArgs(**vars(args))
48
+
49
+ selfplay_evaluate(args, os.getcwd())
50
+
51
+
52
+ if __name__ == "__main__":
53
+ selfplay_enjoy()
rl_algo_impls/shared/actor/__init__.py CHANGED
@@ -1,2 +1,2 @@
1
- from rl_algo_impls.shared.actor.actor import Actor, PiForward
2
  from rl_algo_impls.shared.actor.make_actor import actor_head
 
1
+ from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward
2
  from rl_algo_impls.shared.actor.make_actor import actor_head
rl_algo_impls/shared/actor/actor.py CHANGED
@@ -31,12 +31,13 @@ class Actor(nn.Module, ABC):
31
  def action_shape(self) -> Tuple[int, ...]:
32
  ...
33
 
34
- def pi_forward(
35
- self, distribution: Distribution, actions: Optional[torch.Tensor] = None
36
- ) -> PiForward:
37
- logp_a = None
38
- entropy = None
39
- if actions is not None:
40
- logp_a = distribution.log_prob(actions)
41
- entropy = distribution.entropy()
42
- return PiForward(distribution, logp_a, entropy)
 
 
31
  def action_shape(self) -> Tuple[int, ...]:
32
  ...
33
 
34
+
35
+ def pi_forward(
36
+ distribution: Distribution, actions: Optional[torch.Tensor] = None
37
+ ) -> PiForward:
38
+ logp_a = None
39
+ entropy = None
40
+ if actions is not None:
41
+ logp_a = distribution.log_prob(actions)
42
+ entropy = distribution.entropy()
43
+ return PiForward(distribution, logp_a, entropy)
rl_algo_impls/shared/actor/categorical.py CHANGED
@@ -4,8 +4,8 @@ import torch
4
  import torch.nn as nn
5
  from torch.distributions import Categorical
6
 
7
- from rl_algo_impls.shared.actor import Actor, PiForward
8
- from rl_algo_impls.shared.module.module import mlp
9
 
10
 
11
  class MaskedCategorical(Categorical):
@@ -57,7 +57,7 @@ class CategoricalActorHead(Actor):
57
  ) -> PiForward:
58
  logits = self._fc(obs)
59
  pi = MaskedCategorical(logits=logits, mask=action_masks)
60
- return self.pi_forward(pi, actions)
61
 
62
  @property
63
  def action_shape(self) -> Tuple[int, ...]:
 
4
  import torch.nn as nn
5
  from torch.distributions import Categorical
6
 
7
+ from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
8
+ from rl_algo_impls.shared.module.utils import mlp
9
 
10
 
11
  class MaskedCategorical(Categorical):
 
57
  ) -> PiForward:
58
  logits = self._fc(obs)
59
  pi = MaskedCategorical(logits=logits, mask=action_masks)
60
+ return pi_forward(pi, actions)
61
 
62
  @property
63
  def action_shape(self) -> Tuple[int, ...]:
rl_algo_impls/shared/actor/gaussian.py CHANGED
@@ -4,8 +4,8 @@ import torch
4
  import torch.nn as nn
5
  from torch.distributions import Distribution, Normal
6
 
7
- from rl_algo_impls.shared.actor.actor import Actor, PiForward
8
- from rl_algo_impls.shared.module.module import mlp
9
 
10
 
11
  class GaussianDistribution(Normal):
@@ -54,7 +54,7 @@ class GaussianActorHead(Actor):
54
  not action_masks
55
  ), f"{self.__class__.__name__} does not support action_masks"
56
  pi = self._distribution(obs)
57
- return self.pi_forward(pi, actions)
58
 
59
  @property
60
  def action_shape(self) -> Tuple[int, ...]:
 
4
  import torch.nn as nn
5
  from torch.distributions import Distribution, Normal
6
 
7
+ from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward
8
+ from rl_algo_impls.shared.module.utils import mlp
9
 
10
 
11
  class GaussianDistribution(Normal):
 
54
  not action_masks
55
  ), f"{self.__class__.__name__} does not support action_masks"
56
  pi = self._distribution(obs)
57
+ return pi_forward(pi, actions)
58
 
59
  @property
60
  def action_shape(self) -> Tuple[int, ...]:
rl_algo_impls/shared/actor/gridnet.py CHANGED
@@ -6,10 +6,10 @@ import torch.nn as nn
6
  from numpy.typing import NDArray
7
  from torch.distributions import Distribution, constraints
8
 
9
- from rl_algo_impls.shared.actor import Actor, PiForward
10
  from rl_algo_impls.shared.actor.categorical import MaskedCategorical
11
  from rl_algo_impls.shared.encoder import EncoderOutDim
12
- from rl_algo_impls.shared.module.module import mlp
13
 
14
 
15
  class GridnetDistribution(Distribution):
@@ -25,7 +25,7 @@ class GridnetDistribution(Distribution):
25
  self.action_vec = action_vec
26
 
27
  masks = masks.view(-1, masks.shape[-1])
28
- split_masks = torch.split(masks[:, 1:], action_vec.tolist(), dim=1)
29
 
30
  grid_logits = logits.reshape(-1, action_vec.sum())
31
  split_logits = torch.split(grid_logits, action_vec.tolist(), dim=1)
@@ -101,7 +101,7 @@ class GridnetActorHead(Actor):
101
  ), f"No mask case unhandled in {self.__class__.__name__}"
102
  logits = self._fc(obs)
103
  pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
104
- return self.pi_forward(pi, actions)
105
 
106
  @property
107
  def action_shape(self) -> Tuple[int, ...]:
 
6
  from numpy.typing import NDArray
7
  from torch.distributions import Distribution, constraints
8
 
9
+ from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
10
  from rl_algo_impls.shared.actor.categorical import MaskedCategorical
11
  from rl_algo_impls.shared.encoder import EncoderOutDim
12
+ from rl_algo_impls.shared.module.utils import mlp
13
 
14
 
15
  class GridnetDistribution(Distribution):
 
25
  self.action_vec = action_vec
26
 
27
  masks = masks.view(-1, masks.shape[-1])
28
+ split_masks = torch.split(masks, action_vec.tolist(), dim=1)
29
 
30
  grid_logits = logits.reshape(-1, action_vec.sum())
31
  split_logits = torch.split(grid_logits, action_vec.tolist(), dim=1)
 
101
  ), f"No mask case unhandled in {self.__class__.__name__}"
102
  logits = self._fc(obs)
103
  pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
104
+ return pi_forward(pi, actions)
105
 
106
  @property
107
  def action_shape(self) -> Tuple[int, ...]:
rl_algo_impls/shared/actor/gridnet_decoder.py CHANGED
@@ -5,11 +5,10 @@ import torch
5
  import torch.nn as nn
6
  from numpy.typing import NDArray
7
 
8
- from rl_algo_impls.shared.actor import Actor, PiForward
9
- from rl_algo_impls.shared.actor.categorical import MaskedCategorical
10
  from rl_algo_impls.shared.actor.gridnet import GridnetDistribution
11
  from rl_algo_impls.shared.encoder import EncoderOutDim
12
- from rl_algo_impls.shared.module.module import layer_init
13
 
14
 
15
  class Transpose(nn.Module):
@@ -73,7 +72,7 @@ class GridnetDecoder(Actor):
73
  ), f"No mask case unhandled in {self.__class__.__name__}"
74
  logits = self.deconv(obs)
75
  pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
76
- return self.pi_forward(pi, actions)
77
 
78
  @property
79
  def action_shape(self) -> Tuple[int, ...]:
 
5
  import torch.nn as nn
6
  from numpy.typing import NDArray
7
 
8
+ from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
 
9
  from rl_algo_impls.shared.actor.gridnet import GridnetDistribution
10
  from rl_algo_impls.shared.encoder import EncoderOutDim
11
+ from rl_algo_impls.shared.module.utils import layer_init
12
 
13
 
14
  class Transpose(nn.Module):
 
72
  ), f"No mask case unhandled in {self.__class__.__name__}"
73
  logits = self.deconv(obs)
74
  pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
75
+ return pi_forward(pi, actions)
76
 
77
  @property
78
  def action_shape(self) -> Tuple[int, ...]:
rl_algo_impls/shared/actor/make_actor.py CHANGED
@@ -1,4 +1,4 @@
1
- from typing import Tuple, Type
2
 
3
  import gym
4
  import torch.nn as nn
@@ -27,6 +27,7 @@ def actor_head(
27
  full_std: bool = True,
28
  squash_output: bool = False,
29
  actor_head_style: str = "single",
 
30
  ) -> Actor:
31
  assert not use_sde or isinstance(
32
  action_space, Box
@@ -73,18 +74,20 @@ def actor_head(
73
  init_layers_orthogonal=init_layers_orthogonal,
74
  )
75
  elif actor_head_style == "gridnet":
 
76
  return GridnetActorHead(
77
- action_space.nvec[0], # type: ignore
78
- action_space.nvec[1:], # type: ignore
79
  in_dim=in_dim,
80
  hidden_sizes=hidden_sizes,
81
  activation=activation,
82
  init_layers_orthogonal=init_layers_orthogonal,
83
  )
84
  elif actor_head_style == "gridnet_decoder":
 
85
  return GridnetDecoder(
86
- action_space.nvec[0], # type: ignore
87
- action_space.nvec[1:], # type: ignore
88
  in_dim=in_dim,
89
  activation=activation,
90
  init_layers_orthogonal=init_layers_orthogonal,
 
1
+ from typing import Optional, Tuple, Type
2
 
3
  import gym
4
  import torch.nn as nn
 
27
  full_std: bool = True,
28
  squash_output: bool = False,
29
  actor_head_style: str = "single",
30
+ action_plane_space: Optional[bool] = None,
31
  ) -> Actor:
32
  assert not use_sde or isinstance(
33
  action_space, Box
 
74
  init_layers_orthogonal=init_layers_orthogonal,
75
  )
76
  elif actor_head_style == "gridnet":
77
+ assert isinstance(action_plane_space, MultiDiscrete)
78
  return GridnetActorHead(
79
+ len(action_space.nvec) // len(action_plane_space.nvec), # type: ignore
80
+ action_plane_space.nvec, # type: ignore
81
  in_dim=in_dim,
82
  hidden_sizes=hidden_sizes,
83
  activation=activation,
84
  init_layers_orthogonal=init_layers_orthogonal,
85
  )
86
  elif actor_head_style == "gridnet_decoder":
87
+ assert isinstance(action_plane_space, MultiDiscrete)
88
  return GridnetDecoder(
89
+ len(action_space.nvec) // len(action_plane_space.nvec), # type: ignore
90
+ action_plane_space.nvec, # type: ignore
91
  in_dim=in_dim,
92
  activation=activation,
93
  init_layers_orthogonal=init_layers_orthogonal,
rl_algo_impls/shared/actor/multi_discrete.py CHANGED
@@ -6,10 +6,10 @@ import torch.nn as nn
6
  from numpy.typing import NDArray
7
  from torch.distributions import Distribution, constraints
8
 
9
- from rl_algo_impls.shared.actor.actor import Actor, PiForward
10
  from rl_algo_impls.shared.actor.categorical import MaskedCategorical
11
  from rl_algo_impls.shared.encoder import EncoderOutDim
12
- from rl_algo_impls.shared.module.module import mlp
13
 
14
 
15
  class MultiCategorical(Distribution):
@@ -94,7 +94,7 @@ class MultiDiscreteActorHead(Actor):
94
  ) -> PiForward:
95
  logits = self._fc(obs)
96
  pi = MultiCategorical(self.nvec, logits=logits, masks=action_masks)
97
- return self.pi_forward(pi, actions)
98
 
99
  @property
100
  def action_shape(self) -> Tuple[int, ...]:
 
6
  from numpy.typing import NDArray
7
  from torch.distributions import Distribution, constraints
8
 
9
+ from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward
10
  from rl_algo_impls.shared.actor.categorical import MaskedCategorical
11
  from rl_algo_impls.shared.encoder import EncoderOutDim
12
+ from rl_algo_impls.shared.module.utils import mlp
13
 
14
 
15
  class MultiCategorical(Distribution):
 
94
  ) -> PiForward:
95
  logits = self._fc(obs)
96
  pi = MultiCategorical(self.nvec, logits=logits, masks=action_masks)
97
+ return pi_forward(pi, actions)
98
 
99
  @property
100
  def action_shape(self) -> Tuple[int, ...]:
rl_algo_impls/shared/actor/state_dependent_noise.py CHANGED
@@ -5,7 +5,7 @@ import torch.nn as nn
5
  from torch.distributions import Distribution, Normal
6
 
7
  from rl_algo_impls.shared.actor.actor import Actor, PiForward
8
- from rl_algo_impls.shared.module.module import mlp
9
 
10
 
11
  class TanhBijector:
@@ -172,7 +172,7 @@ class StateDependentNoiseActorHead(Actor):
172
  not action_masks
173
  ), f"{self.__class__.__name__} does not support action_masks"
174
  pi = self._distribution(obs)
175
- return self.pi_forward(pi, actions)
176
 
177
  def sample_weights(self, batch_size: int = 1) -> None:
178
  std = self._get_std()
@@ -185,16 +185,15 @@ class StateDependentNoiseActorHead(Actor):
185
  def action_shape(self) -> Tuple[int, ...]:
186
  return (self.act_dim,)
187
 
188
- def pi_forward(
189
- self, distribution: Distribution, actions: Optional[torch.Tensor] = None
190
- ) -> PiForward:
191
- logp_a = None
192
- entropy = None
193
- if actions is not None:
194
- logp_a = distribution.log_prob(actions)
195
- entropy = (
196
- -logp_a
197
- if self.bijector
198
- else sum_independent_dims(distribution.entropy())
199
- )
200
- return PiForward(distribution, logp_a, entropy)
 
5
  from torch.distributions import Distribution, Normal
6
 
7
  from rl_algo_impls.shared.actor.actor import Actor, PiForward
8
+ from rl_algo_impls.shared.module.utils import mlp
9
 
10
 
11
  class TanhBijector:
 
172
  not action_masks
173
  ), f"{self.__class__.__name__} does not support action_masks"
174
  pi = self._distribution(obs)
175
+ return pi_forward(pi, actions, self.bijector)
176
 
177
  def sample_weights(self, batch_size: int = 1) -> None:
178
  std = self._get_std()
 
185
  def action_shape(self) -> Tuple[int, ...]:
186
  return (self.act_dim,)
187
 
188
+
189
+ def pi_forward(
190
+ distribution: Distribution,
191
+ actions: Optional[torch.Tensor] = None,
192
+ bijector: Optional[TanhBijector] = None,
193
+ ) -> PiForward:
194
+ logp_a = None
195
+ entropy = None
196
+ if actions is not None:
197
+ logp_a = distribution.log_prob(actions)
198
+ entropy = -logp_a if bijector else sum_independent_dims(distribution.entropy())
199
+ return PiForward(distribution, logp_a, entropy)
 
rl_algo_impls/shared/algorithm.py CHANGED
@@ -1,11 +1,11 @@
 
 
 
1
  import gym
2
  import torch
3
-
4
- from abc import ABC, abstractmethod
5
  from torch.utils.tensorboard.writer import SummaryWriter
6
- from typing import Optional, TypeVar
7
 
8
- from rl_algo_impls.shared.callbacks.callback import Callback
9
  from rl_algo_impls.shared.policy.policy import Policy
10
  from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
11
 
@@ -32,7 +32,7 @@ class Algorithm(ABC):
32
  def learn(
33
  self: AlgorithmSelf,
34
  train_timesteps: int,
35
- callback: Optional[Callback] = None,
36
  total_timesteps: Optional[int] = None,
37
  start_timesteps: int = 0,
38
  ) -> AlgorithmSelf:
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Optional, TypeVar
3
+
4
  import gym
5
  import torch
 
 
6
  from torch.utils.tensorboard.writer import SummaryWriter
 
7
 
8
+ from rl_algo_impls.shared.callbacks import Callback
9
  from rl_algo_impls.shared.policy.policy import Policy
10
  from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
11
 
 
32
  def learn(
33
  self: AlgorithmSelf,
34
  train_timesteps: int,
35
+ callbacks: Optional[List[Callback]] = None,
36
  total_timesteps: Optional[int] = None,
37
  start_timesteps: int = 0,
38
  ) -> AlgorithmSelf:
rl_algo_impls/shared/callbacks/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from rl_algo_impls.shared.callbacks.callback import Callback
rl_algo_impls/shared/callbacks/eval_callback.py CHANGED
@@ -1,12 +1,13 @@
1
  import itertools
2
  import os
 
3
  from time import perf_counter
4
  from typing import Dict, List, Optional, Union
5
 
6
  import numpy as np
7
  from torch.utils.tensorboard.writer import SummaryWriter
8
 
9
- from rl_algo_impls.shared.callbacks.callback import Callback
10
  from rl_algo_impls.shared.policy.policy import Policy
11
  from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
12
  from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
@@ -80,6 +81,7 @@ def evaluate(
80
  print_returns: bool = True,
81
  ignore_first_episode: bool = False,
82
  additional_keys_to_log: Optional[List[str]] = None,
 
83
  ) -> EpisodesStats:
84
  policy.sync_normalization(env)
85
  policy.eval()
@@ -93,18 +95,21 @@ def evaluate(
93
  )
94
 
95
  obs = env.reset()
96
- action_masker = find_action_masker(env)
97
  while not episodes.is_done():
98
  act = policy.act(
99
  obs,
100
  deterministic=deterministic,
101
- action_masks=action_masker.action_masks() if action_masker else None,
102
  )
103
  obs, rew, done, info = env.step(act)
104
  episodes.step(rew, done, info)
105
  if render:
106
  env.render()
107
- stats = EpisodesStats(episodes.episodes)
 
 
 
108
  if print_returns:
109
  print(stats)
110
  return stats
@@ -127,6 +132,8 @@ class EvalCallback(Callback):
127
  max_video_length: int = 3600,
128
  ignore_first_episode: bool = False,
129
  additional_keys_to_log: Optional[List[str]] = None,
 
 
130
  ) -> None:
131
  super().__init__()
132
  self.policy = policy
@@ -151,6 +158,8 @@ class EvalCallback(Callback):
151
  self.best_video_base_path = None
152
  self.ignore_first_episode = ignore_first_episode
153
  self.additional_keys_to_log = additional_keys_to_log
 
 
154
 
155
  def on_step(self, timesteps_elapsed: int = 1) -> bool:
156
  super().on_step(timesteps_elapsed)
@@ -170,6 +179,7 @@ class EvalCallback(Callback):
170
  print_returns=print_returns or False,
171
  ignore_first_episode=self.ignore_first_episode,
172
  additional_keys_to_log=self.additional_keys_to_log,
 
173
  )
174
  end_time = perf_counter()
175
  self.tb_writer.add_scalar(
@@ -189,6 +199,15 @@ class EvalCallback(Callback):
189
  assert self.best_model_path
190
  self.policy.save(self.best_model_path)
191
  print("Saved best model")
 
 
 
 
 
 
 
 
 
192
  self.best.write_to_tensorboard(
193
  self.tb_writer, "best_eval", self.timesteps_elapsed
194
  )
@@ -208,6 +227,7 @@ class EvalCallback(Callback):
208
  1,
209
  deterministic=self.deterministic,
210
  print_returns=False,
 
211
  )
212
  print(f"Saved best video: {video_stats}")
213
 
 
1
  import itertools
2
  import os
3
+ import shutil
4
  from time import perf_counter
5
  from typing import Dict, List, Optional, Union
6
 
7
  import numpy as np
8
  from torch.utils.tensorboard.writer import SummaryWriter
9
 
10
+ from rl_algo_impls.shared.callbacks import Callback
11
  from rl_algo_impls.shared.policy.policy import Policy
12
  from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
13
  from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
 
81
  print_returns: bool = True,
82
  ignore_first_episode: bool = False,
83
  additional_keys_to_log: Optional[List[str]] = None,
84
+ score_function: str = "mean-std",
85
  ) -> EpisodesStats:
86
  policy.sync_normalization(env)
87
  policy.eval()
 
95
  )
96
 
97
  obs = env.reset()
98
+ get_action_mask = getattr(env, "get_action_mask", None)
99
  while not episodes.is_done():
100
  act = policy.act(
101
  obs,
102
  deterministic=deterministic,
103
+ action_masks=get_action_mask() if get_action_mask else None,
104
  )
105
  obs, rew, done, info = env.step(act)
106
  episodes.step(rew, done, info)
107
  if render:
108
  env.render()
109
+ stats = EpisodesStats(
110
+ episodes.episodes,
111
+ score_function=score_function,
112
+ )
113
  if print_returns:
114
  print(stats)
115
  return stats
 
132
  max_video_length: int = 3600,
133
  ignore_first_episode: bool = False,
134
  additional_keys_to_log: Optional[List[str]] = None,
135
+ score_function: str = "mean-std",
136
+ wandb_enabled: bool = False,
137
  ) -> None:
138
  super().__init__()
139
  self.policy = policy
 
158
  self.best_video_base_path = None
159
  self.ignore_first_episode = ignore_first_episode
160
  self.additional_keys_to_log = additional_keys_to_log
161
+ self.score_function = score_function
162
+ self.wandb_enabled = wandb_enabled
163
 
164
  def on_step(self, timesteps_elapsed: int = 1) -> bool:
165
  super().on_step(timesteps_elapsed)
 
179
  print_returns=print_returns or False,
180
  ignore_first_episode=self.ignore_first_episode,
181
  additional_keys_to_log=self.additional_keys_to_log,
182
+ score_function=self.score_function,
183
  )
184
  end_time = perf_counter()
185
  self.tb_writer.add_scalar(
 
199
  assert self.best_model_path
200
  self.policy.save(self.best_model_path)
201
  print("Saved best model")
202
+ if self.wandb_enabled:
203
+ import wandb
204
+
205
+ best_model_name = os.path.split(self.best_model_path)[-1]
206
+ shutil.make_archive(
207
+ os.path.join(wandb.run.dir, best_model_name), # type: ignore
208
+ "zip",
209
+ self.best_model_path,
210
+ )
211
  self.best.write_to_tensorboard(
212
  self.tb_writer, "best_eval", self.timesteps_elapsed
213
  )
 
227
  1,
228
  deterministic=self.deterministic,
229
  print_returns=False,
230
+ score_function=self.score_function,
231
  )
232
  print(f"Saved best video: {video_stats}")
233
 
rl_algo_impls/shared/callbacks/microrts_reward_decay_callback.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from rl_algo_impls.runner.config import Config
4
+ from rl_algo_impls.shared.callbacks import Callback
5
+ from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
6
+
7
+
8
+ class MicrortsRewardDecayCallback(Callback):
9
+ def __init__(
10
+ self,
11
+ config: Config,
12
+ env: VecEnv,
13
+ start_timesteps: int = 0,
14
+ ) -> None:
15
+ super().__init__()
16
+ from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
17
+
18
+ unwrapped = env.unwrapped
19
+ assert isinstance(unwrapped, MicroRTSGridModeVecEnv)
20
+ self.microrts_env = unwrapped
21
+ self.base_reward_weights = self.microrts_env.reward_weight
22
+
23
+ self.total_train_timesteps = config.n_timesteps
24
+ self.timesteps_elapsed = start_timesteps
25
+
26
+ def on_step(self, timesteps_elapsed: int = 1) -> bool:
27
+ super().on_step(timesteps_elapsed)
28
+
29
+ progress = self.timesteps_elapsed / self.total_train_timesteps
30
+ # Decay all rewards except WinLoss
31
+ reward_weights = self.base_reward_weights * np.array(
32
+ [1] + [1 - progress] * (len(self.base_reward_weights) - 1)
33
+ )
34
+ self.microrts_env.reward_weight = reward_weights
35
+
36
+ return True
rl_algo_impls/shared/callbacks/optimize_callback.py CHANGED
@@ -5,7 +5,7 @@ from time import perf_counter
5
  from torch.utils.tensorboard.writer import SummaryWriter
6
  from typing import NamedTuple, Union
7
 
8
- from rl_algo_impls.shared.callbacks.callback import Callback
9
  from rl_algo_impls.shared.callbacks.eval_callback import evaluate
10
  from rl_algo_impls.shared.policy.policy import Policy
11
  from rl_algo_impls.shared.stats import EpisodesStats
 
5
  from torch.utils.tensorboard.writer import SummaryWriter
6
  from typing import NamedTuple, Union
7
 
8
+ from rl_algo_impls.shared.callbacks import Callback
9
  from rl_algo_impls.shared.callbacks.eval_callback import evaluate
10
  from rl_algo_impls.shared.policy.policy import Policy
11
  from rl_algo_impls.shared.stats import EpisodesStats
rl_algo_impls/shared/callbacks/self_play_callback.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+
3
+ from rl_algo_impls.shared.callbacks import Callback
4
+ from rl_algo_impls.shared.policy.policy import Policy
5
+ from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
6
+
7
+
8
+ class SelfPlayCallback(Callback):
9
+ def __init__(
10
+ self,
11
+ policy: Policy,
12
+ policy_factory: Callable[[], Policy],
13
+ selfPlayWrapper: SelfPlayWrapper,
14
+ ) -> None:
15
+ super().__init__()
16
+ self.policy = policy
17
+ self.policy_factory = policy_factory
18
+ self.selfPlayWrapper = selfPlayWrapper
19
+ self.checkpoint_policy()
20
+
21
+ def on_step(self, timesteps_elapsed: int = 1) -> bool:
22
+ super().on_step(timesteps_elapsed)
23
+ if (
24
+ self.timesteps_elapsed
25
+ >= self.last_checkpoint_step + self.selfPlayWrapper.save_steps
26
+ ):
27
+ self.checkpoint_policy()
28
+ return True
29
+
30
+ def checkpoint_policy(self):
31
+ self.selfPlayWrapper.checkpoint_policy(
32
+ self.policy_factory().load_from(self.policy)
33
+ )
34
+ self.last_checkpoint_step = self.timesteps_elapsed
rl_algo_impls/shared/encoder/cnn.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
6
  import torch
7
  import torch.nn as nn
8
 
9
- from rl_algo_impls.shared.module.module import layer_init
10
 
11
  EncoderOutDim = Union[int, Tuple[int, ...]]
12
 
 
6
  import torch
7
  import torch.nn as nn
8
 
9
+ from rl_algo_impls.shared.module.utils import layer_init
10
 
11
  EncoderOutDim = Union[int, Tuple[int, ...]]
12
 
rl_algo_impls/shared/encoder/encoder.py CHANGED
@@ -12,7 +12,7 @@ from rl_algo_impls.shared.encoder.gridnet_encoder import GridnetEncoder
12
  from rl_algo_impls.shared.encoder.impala_cnn import ImpalaCnn
13
  from rl_algo_impls.shared.encoder.microrts_cnn import MicrortsCnn
14
  from rl_algo_impls.shared.encoder.nature_cnn import NatureCnn
15
- from rl_algo_impls.shared.module.module import layer_init
16
 
17
  CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnEncoder]] = {
18
  "nature": NatureCnn,
 
12
  from rl_algo_impls.shared.encoder.impala_cnn import ImpalaCnn
13
  from rl_algo_impls.shared.encoder.microrts_cnn import MicrortsCnn
14
  from rl_algo_impls.shared.encoder.nature_cnn import NatureCnn
15
+ from rl_algo_impls.shared.module.utils import layer_init
16
 
17
  CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnEncoder]] = {
18
  "nature": NatureCnn,
rl_algo_impls/shared/encoder/gridnet_encoder.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import torch.nn as nn
6
 
7
  from rl_algo_impls.shared.encoder.cnn import CnnEncoder, EncoderOutDim
8
- from rl_algo_impls.shared.module.module import layer_init
9
 
10
 
11
  class GridnetEncoder(CnnEncoder):
 
5
  import torch.nn as nn
6
 
7
  from rl_algo_impls.shared.encoder.cnn import CnnEncoder, EncoderOutDim
8
+ from rl_algo_impls.shared.module.utils import layer_init
9
 
10
 
11
  class GridnetEncoder(CnnEncoder):
rl_algo_impls/shared/encoder/impala_cnn.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import torch.nn as nn
6
 
7
  from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
8
- from rl_algo_impls.shared.module.module import layer_init
9
 
10
 
11
  class ResidualBlock(nn.Module):
 
5
  import torch.nn as nn
6
 
7
  from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
8
+ from rl_algo_impls.shared.module.utils import layer_init
9
 
10
 
11
  class ResidualBlock(nn.Module):
rl_algo_impls/shared/encoder/microrts_cnn.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import torch.nn as nn
6
 
7
  from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
8
- from rl_algo_impls.shared.module.module import layer_init
9
 
10
 
11
  class MicrortsCnn(FlattenedCnnEncoder):
 
5
  import torch.nn as nn
6
 
7
  from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
8
+ from rl_algo_impls.shared.module.utils import layer_init
9
 
10
 
11
  class MicrortsCnn(FlattenedCnnEncoder):
rl_algo_impls/shared/encoder/nature_cnn.py CHANGED
@@ -4,7 +4,7 @@ import gym
4
  import torch.nn as nn
5
 
6
  from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
7
- from rl_algo_impls.shared.module.module import layer_init
8
 
9
 
10
  class NatureCnn(FlattenedCnnEncoder):
 
4
  import torch.nn as nn
5
 
6
  from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
7
+ from rl_algo_impls.shared.module.utils import layer_init
8
 
9
 
10
  class NatureCnn(FlattenedCnnEncoder):
rl_algo_impls/shared/gae.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
 
4
  from typing import NamedTuple, Sequence
5
 
6
- from rl_algo_impls.shared.policy.on_policy import OnPolicy
7
  from rl_algo_impls.shared.trajectory import Trajectory
8
  from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvObs
9
 
 
3
 
4
  from typing import NamedTuple, Sequence
5
 
6
+ from rl_algo_impls.shared.policy.actor_critic import OnPolicy
7
  from rl_algo_impls.shared.trajectory import Trajectory
8
  from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvObs
9
 
rl_algo_impls/shared/module/{module.py → utils.py} RENAMED
File without changes
rl_algo_impls/shared/policy/{on_policy.py → actor_critic.py} RENAMED
@@ -4,12 +4,14 @@ from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
4
  import gym
5
  import numpy as np
6
  import torch
7
- from gym.spaces import Box, Discrete, Space
8
 
9
- from rl_algo_impls.shared.actor import PiForward, actor_head
10
- from rl_algo_impls.shared.encoder import Encoder
11
- from rl_algo_impls.shared.policy.critic import CriticHead
12
- from rl_algo_impls.shared.policy.policy import ACTIVATION, Policy
 
 
13
  from rl_algo_impls.wrappers.vectorable_wrapper import (
14
  VecEnv,
15
  VecEnvObs,
@@ -52,21 +54,6 @@ def clamp_actions(
52
  return actions
53
 
54
 
55
- def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
56
- if isinstance(obs_space, Box):
57
- if len(obs_space.shape) == 3:
58
- # By default feature extractor to output has no hidden layers
59
- return []
60
- elif len(obs_space.shape) == 1:
61
- return [64, 64]
62
- else:
63
- raise ValueError(f"Unsupported observation space: {obs_space}")
64
- elif isinstance(obs_space, Discrete):
65
- return [64]
66
- else:
67
- raise ValueError(f"Unsupported observation space: {obs_space}")
68
-
69
-
70
  class OnPolicy(Policy):
71
  @abstractmethod
72
  def value(self, obs: VecEnvObs) -> np.ndarray:
@@ -106,78 +93,59 @@ class ActorCritic(OnPolicy):
106
 
107
  observation_space = single_observation_space(env)
108
  action_space = single_action_space(env)
 
109
 
110
- pi_hidden_sizes = (
111
- pi_hidden_sizes
112
- if pi_hidden_sizes is not None
113
- else default_hidden_sizes(observation_space)
114
- )
115
- v_hidden_sizes = (
116
- v_hidden_sizes
117
- if v_hidden_sizes is not None
118
- else default_hidden_sizes(observation_space)
119
- )
120
-
121
- activation = ACTIVATION[activation_fn]
122
  self.action_space = action_space
123
  self.squash_output = squash_output
124
- self.share_features_extractor = share_features_extractor
125
- self._feature_extractor = Encoder(
126
- observation_space,
127
- activation,
128
- init_layers_orthogonal=init_layers_orthogonal,
129
- cnn_flatten_dim=cnn_flatten_dim,
130
- cnn_style=cnn_style,
131
- cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
132
- impala_channels=impala_channels,
133
- )
134
- self._pi = actor_head(
135
- self.action_space,
136
- self._feature_extractor.out_dim,
137
- tuple(pi_hidden_sizes),
138
- init_layers_orthogonal,
139
- activation,
140
- log_std_init=log_std_init,
141
- use_sde=use_sde,
142
- full_std=full_std,
143
- squash_output=squash_output,
144
- actor_head_style=actor_head_style,
145
- )
146
 
147
- if not share_features_extractor:
148
- self._v_feature_extractor = Encoder(
149
  observation_space,
150
- activation,
 
 
151
  init_layers_orthogonal=init_layers_orthogonal,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  cnn_flatten_dim=cnn_flatten_dim,
153
  cnn_style=cnn_style,
154
  cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
 
 
 
155
  )
156
- critic_in_dim = self._v_feature_extractor.out_dim
157
  else:
158
- self._v_feature_extractor = None
159
- critic_in_dim = self._feature_extractor.out_dim
160
- self._v = CriticHead(
161
- in_dim=critic_in_dim,
162
- hidden_sizes=v_hidden_sizes,
163
- activation=activation,
164
- init_layers_orthogonal=init_layers_orthogonal,
165
- )
166
-
167
- def _pi_forward(
168
- self,
169
- obs: torch.Tensor,
170
- action_masks: Optional[torch.Tensor],
171
- action: Optional[torch.Tensor] = None,
172
- ) -> Tuple[PiForward, torch.Tensor]:
173
- p_fe = self._feature_extractor(obs)
174
- pi_forward = self._pi(p_fe, actions=action, action_masks=action_masks)
175
-
176
- return pi_forward, p_fe
177
-
178
- def _v_forward(self, obs: torch.Tensor, p_fc: torch.Tensor) -> torch.Tensor:
179
- v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
180
- return self._v(v_fe)
181
 
182
  def forward(
183
  self,
@@ -185,8 +153,7 @@ class ActorCritic(OnPolicy):
185
  action: torch.Tensor,
186
  action_masks: Optional[torch.Tensor] = None,
187
  ) -> ACForward:
188
- (_, logp_a, entropy), p_fc = self._pi_forward(obs, action_masks, action=action)
189
- v = self._v_forward(obs, p_fc)
190
 
191
  assert logp_a is not None
192
  assert entropy is not None
@@ -195,24 +162,17 @@ class ActorCritic(OnPolicy):
195
  def value(self, obs: VecEnvObs) -> np.ndarray:
196
  o = self._as_tensor(obs)
197
  with torch.no_grad():
198
- fe = (
199
- self._v_feature_extractor(o)
200
- if self._v_feature_extractor
201
- else self._feature_extractor(o)
202
- )
203
- v = self._v(fe)
204
  return v.cpu().numpy()
205
 
206
  def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
207
  o = self._as_tensor(obs)
208
  a_masks = self._as_tensor(action_masks) if action_masks is not None else None
209
  with torch.no_grad():
210
- (pi, _, _), p_fc = self._pi_forward(o, action_masks=a_masks)
211
  a = pi.sample()
212
  logp_a = pi.log_prob(a)
213
 
214
- v = self._v_forward(o, p_fc)
215
-
216
  a_np = a.cpu().numpy()
217
  clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
218
  return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
@@ -231,7 +191,9 @@ class ActorCritic(OnPolicy):
231
  self._as_tensor(action_masks) if action_masks is not None else None
232
  )
233
  with torch.no_grad():
234
- (pi, _, _), _ = self._pi_forward(o, action_masks=a_masks)
 
 
235
  a = pi.mode
236
  return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
237
 
@@ -239,11 +201,16 @@ class ActorCritic(OnPolicy):
239
  super().load(path)
240
  self.reset_noise()
241
 
 
 
 
 
 
242
  def reset_noise(self, batch_size: Optional[int] = None) -> None:
243
- self._pi.sample_weights(
244
  batch_size=batch_size if batch_size else self.env.num_envs
245
  )
246
 
247
  @property
248
  def action_shape(self) -> Tuple[int, ...]:
249
- return self._pi.action_shape
 
4
  import gym
5
  import numpy as np
6
  import torch
7
+ from gym.spaces import Box, Space
8
 
9
+ from rl_algo_impls.shared.policy.actor_critic_network import (
10
+ ConnectedTrioActorCriticNetwork,
11
+ SeparateActorCriticNetwork,
12
+ UNetActorCriticNetwork,
13
+ )
14
+ from rl_algo_impls.shared.policy.policy import Policy
15
  from rl_algo_impls.wrappers.vectorable_wrapper import (
16
  VecEnv,
17
  VecEnvObs,
 
54
  return actions
55
 
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  class OnPolicy(Policy):
58
  @abstractmethod
59
  def value(self, obs: VecEnvObs) -> np.ndarray:
 
93
 
94
  observation_space = single_observation_space(env)
95
  action_space = single_action_space(env)
96
+ action_plane_space = getattr(env, "action_plane_space", None)
97
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  self.action_space = action_space
99
  self.squash_output = squash_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
+ if actor_head_style == "unet":
102
+ self.network = UNetActorCriticNetwork(
103
  observation_space,
104
+ action_space,
105
+ action_plane_space,
106
+ v_hidden_sizes=v_hidden_sizes,
107
  init_layers_orthogonal=init_layers_orthogonal,
108
+ activation_fn=activation_fn,
109
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
110
+ )
111
+ elif share_features_extractor:
112
+ self.network = ConnectedTrioActorCriticNetwork(
113
+ observation_space,
114
+ action_space,
115
+ pi_hidden_sizes=pi_hidden_sizes,
116
+ v_hidden_sizes=v_hidden_sizes,
117
+ init_layers_orthogonal=init_layers_orthogonal,
118
+ activation_fn=activation_fn,
119
+ log_std_init=log_std_init,
120
+ use_sde=use_sde,
121
+ full_std=full_std,
122
+ squash_output=squash_output,
123
  cnn_flatten_dim=cnn_flatten_dim,
124
  cnn_style=cnn_style,
125
  cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
126
+ impala_channels=impala_channels,
127
+ actor_head_style=actor_head_style,
128
+ action_plane_space=action_plane_space,
129
  )
 
130
  else:
131
+ self.network = SeparateActorCriticNetwork(
132
+ observation_space,
133
+ action_space,
134
+ pi_hidden_sizes=pi_hidden_sizes,
135
+ v_hidden_sizes=v_hidden_sizes,
136
+ init_layers_orthogonal=init_layers_orthogonal,
137
+ activation_fn=activation_fn,
138
+ log_std_init=log_std_init,
139
+ use_sde=use_sde,
140
+ full_std=full_std,
141
+ squash_output=squash_output,
142
+ cnn_flatten_dim=cnn_flatten_dim,
143
+ cnn_style=cnn_style,
144
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
145
+ impala_channels=impala_channels,
146
+ actor_head_style=actor_head_style,
147
+ action_plane_space=action_plane_space,
148
+ )
 
 
 
 
 
149
 
150
  def forward(
151
  self,
 
153
  action: torch.Tensor,
154
  action_masks: Optional[torch.Tensor] = None,
155
  ) -> ACForward:
156
+ (_, logp_a, entropy), v = self.network(obs, action, action_masks=action_masks)
 
157
 
158
  assert logp_a is not None
159
  assert entropy is not None
 
162
  def value(self, obs: VecEnvObs) -> np.ndarray:
163
  o = self._as_tensor(obs)
164
  with torch.no_grad():
165
+ v = self.network.value(o)
 
 
 
 
 
166
  return v.cpu().numpy()
167
 
168
  def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
169
  o = self._as_tensor(obs)
170
  a_masks = self._as_tensor(action_masks) if action_masks is not None else None
171
  with torch.no_grad():
172
+ (pi, _, _), v = self.network.distribution_and_value(o, action_masks=a_masks)
173
  a = pi.sample()
174
  logp_a = pi.log_prob(a)
175
 
 
 
176
  a_np = a.cpu().numpy()
177
  clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
178
  return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
 
191
  self._as_tensor(action_masks) if action_masks is not None else None
192
  )
193
  with torch.no_grad():
194
+ (pi, _, _), _ = self.network.distribution_and_value(
195
+ o, action_masks=a_masks
196
+ )
197
  a = pi.mode
198
  return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
199
 
 
201
  super().load(path)
202
  self.reset_noise()
203
 
204
+ def load_from(self: ActorCriticSelf, policy: ActorCriticSelf) -> ActorCriticSelf:
205
+ super().load_from(policy)
206
+ self.reset_noise()
207
+ return self
208
+
209
  def reset_noise(self, batch_size: Optional[int] = None) -> None:
210
+ self.network.reset_noise(
211
  batch_size=batch_size if batch_size else self.env.num_envs
212
  )
213
 
214
  @property
215
  def action_shape(self) -> Tuple[int, ...]:
216
+ return self.network.action_shape
rl_algo_impls/shared/policy/actor_critic_network/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from rl_algo_impls.shared.policy.actor_critic_network.connected_trio import (
2
+ ConnectedTrioActorCriticNetwork,
3
+ )
4
+ from rl_algo_impls.shared.policy.actor_critic_network.network import (
5
+ ActorCriticNetwork,
6
+ default_hidden_sizes,
7
+ )
8
+ from rl_algo_impls.shared.policy.actor_critic_network.separate_actor_critic import (
9
+ SeparateActorCriticNetwork,
10
+ )
11
+ from rl_algo_impls.shared.policy.actor_critic_network.unet import UNetActorCriticNetwork
rl_algo_impls/shared/policy/actor_critic_network/connected_trio.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Sequence, Tuple
2
+
3
+ import torch
4
+ from gym.spaces import Space
5
+
6
+ from rl_algo_impls.shared.actor import actor_head
7
+ from rl_algo_impls.shared.encoder import Encoder
8
+ from rl_algo_impls.shared.policy.actor_critic_network.network import (
9
+ ACNForward,
10
+ ActorCriticNetwork,
11
+ default_hidden_sizes,
12
+ )
13
+ from rl_algo_impls.shared.policy.critic import CriticHead
14
+ from rl_algo_impls.shared.policy.policy import ACTIVATION
15
+
16
+
17
+ class ConnectedTrioActorCriticNetwork(ActorCriticNetwork):
18
+ """Encode (feature extractor), decoder (actor head), critic head networks"""
19
+
20
+ def __init__(
21
+ self,
22
+ observation_space: Space,
23
+ action_space: Space,
24
+ pi_hidden_sizes: Optional[Sequence[int]] = None,
25
+ v_hidden_sizes: Optional[Sequence[int]] = None,
26
+ init_layers_orthogonal: bool = True,
27
+ activation_fn: str = "tanh",
28
+ log_std_init: float = -0.5,
29
+ use_sde: bool = False,
30
+ full_std: bool = True,
31
+ squash_output: bool = False,
32
+ cnn_flatten_dim: int = 512,
33
+ cnn_style: str = "nature",
34
+ cnn_layers_init_orthogonal: Optional[bool] = None,
35
+ impala_channels: Sequence[int] = (16, 32, 32),
36
+ actor_head_style: str = "single",
37
+ action_plane_space: Optional[Space] = None,
38
+ ) -> None:
39
+ super().__init__()
40
+
41
+ pi_hidden_sizes = (
42
+ pi_hidden_sizes
43
+ if pi_hidden_sizes is not None
44
+ else default_hidden_sizes(observation_space)
45
+ )
46
+ v_hidden_sizes = (
47
+ v_hidden_sizes
48
+ if v_hidden_sizes is not None
49
+ else default_hidden_sizes(observation_space)
50
+ )
51
+
52
+ activation = ACTIVATION[activation_fn]
53
+ self._feature_extractor = Encoder(
54
+ observation_space,
55
+ activation,
56
+ init_layers_orthogonal=init_layers_orthogonal,
57
+ cnn_flatten_dim=cnn_flatten_dim,
58
+ cnn_style=cnn_style,
59
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
60
+ impala_channels=impala_channels,
61
+ )
62
+ self._pi = actor_head(
63
+ action_space,
64
+ self._feature_extractor.out_dim,
65
+ tuple(pi_hidden_sizes),
66
+ init_layers_orthogonal,
67
+ activation,
68
+ log_std_init=log_std_init,
69
+ use_sde=use_sde,
70
+ full_std=full_std,
71
+ squash_output=squash_output,
72
+ actor_head_style=actor_head_style,
73
+ action_plane_space=action_plane_space,
74
+ )
75
+
76
+ self._v = CriticHead(
77
+ in_dim=self._feature_extractor.out_dim,
78
+ hidden_sizes=v_hidden_sizes,
79
+ activation=activation,
80
+ init_layers_orthogonal=init_layers_orthogonal,
81
+ )
82
+
83
+ def forward(
84
+ self,
85
+ obs: torch.Tensor,
86
+ action: torch.Tensor,
87
+ action_masks: Optional[torch.Tensor] = None,
88
+ ) -> ACNForward:
89
+ return self._distribution_and_value(
90
+ obs, action=action, action_masks=action_masks
91
+ )
92
+
93
+ def distribution_and_value(
94
+ self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None
95
+ ) -> ACNForward:
96
+ return self._distribution_and_value(obs, action_masks=action_masks)
97
+
98
+ def _distribution_and_value(
99
+ self,
100
+ obs: torch.Tensor,
101
+ action: Optional[torch.Tensor] = None,
102
+ action_masks: Optional[torch.Tensor] = None,
103
+ ) -> ACNForward:
104
+ encoded = self._feature_extractor(obs)
105
+ pi_forward = self._pi(encoded, actions=action, action_masks=action_masks)
106
+ v = self._v(encoded)
107
+ return ACNForward(pi_forward, v)
108
+
109
+ def value(self, obs: torch.Tensor) -> torch.Tensor:
110
+ encoded = self._feature_extractor(obs)
111
+ return self._v(encoded)
112
+
113
+ def reset_noise(self, batch_size: int) -> None:
114
+ self._pi.sample_weights(batch_size=batch_size)
115
+
116
+ @property
117
+ def action_shape(self) -> Tuple[int, ...]:
118
+ return self._pi.action_shape
rl_algo_impls/shared/policy/actor_critic_network/network.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import NamedTuple, Optional, Sequence, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from gym.spaces import Box, Discrete, Space
7
+
8
+ from rl_algo_impls.shared.actor import PiForward
9
+
10
+
11
+ class ACNForward(NamedTuple):
12
+ pi_forward: PiForward
13
+ v: torch.Tensor
14
+
15
+
16
+ class ActorCriticNetwork(nn.Module, ABC):
17
+ @abstractmethod
18
+ def forward(
19
+ self,
20
+ obs: torch.Tensor,
21
+ action: torch.Tensor,
22
+ action_masks: Optional[torch.Tensor] = None,
23
+ ) -> ACNForward:
24
+ ...
25
+
26
+ @abstractmethod
27
+ def distribution_and_value(
28
+ self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None
29
+ ) -> ACNForward:
30
+ ...
31
+
32
+ @abstractmethod
33
+ def value(self, obs: torch.Tensor) -> torch.Tensor:
34
+ ...
35
+
36
+ @abstractmethod
37
+ def reset_noise(self, batch_size: Optional[int] = None) -> None:
38
+ ...
39
+
40
+ @property
41
+ def action_shape(self) -> Tuple[int, ...]:
42
+ ...
43
+
44
+
45
+ def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
46
+ if isinstance(obs_space, Box):
47
+ if len(obs_space.shape) == 3: # type: ignore
48
+ # By default feature extractor to output has no hidden layers
49
+ return []
50
+ elif len(obs_space.shape) == 1: # type: ignore
51
+ return [64, 64]
52
+ else:
53
+ raise ValueError(f"Unsupported observation space: {obs_space}")
54
+ elif isinstance(obs_space, Discrete):
55
+ return [64]
56
+ else:
57
+ raise ValueError(f"Unsupported observation space: {obs_space}")
rl_algo_impls/shared/policy/actor_critic_network/separate_actor_critic.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Sequence, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from gym.spaces import Space
6
+
7
+ from rl_algo_impls.shared.actor import actor_head
8
+ from rl_algo_impls.shared.encoder import Encoder
9
+ from rl_algo_impls.shared.policy.actor_critic_network.network import (
10
+ ACNForward,
11
+ ActorCriticNetwork,
12
+ default_hidden_sizes,
13
+ )
14
+ from rl_algo_impls.shared.policy.critic import CriticHead
15
+ from rl_algo_impls.shared.policy.policy import ACTIVATION
16
+
17
+
18
+ class SeparateActorCriticNetwork(ActorCriticNetwork):
19
+ def __init__(
20
+ self,
21
+ observation_space: Space,
22
+ action_space: Space,
23
+ pi_hidden_sizes: Optional[Sequence[int]] = None,
24
+ v_hidden_sizes: Optional[Sequence[int]] = None,
25
+ init_layers_orthogonal: bool = True,
26
+ activation_fn: str = "tanh",
27
+ log_std_init: float = -0.5,
28
+ use_sde: bool = False,
29
+ full_std: bool = True,
30
+ squash_output: bool = False,
31
+ cnn_flatten_dim: int = 512,
32
+ cnn_style: str = "nature",
33
+ cnn_layers_init_orthogonal: Optional[bool] = None,
34
+ impala_channels: Sequence[int] = (16, 32, 32),
35
+ actor_head_style: str = "single",
36
+ action_plane_space: Optional[Space] = None,
37
+ ) -> None:
38
+ super().__init__()
39
+
40
+ pi_hidden_sizes = (
41
+ pi_hidden_sizes
42
+ if pi_hidden_sizes is not None
43
+ else default_hidden_sizes(observation_space)
44
+ )
45
+ v_hidden_sizes = (
46
+ v_hidden_sizes
47
+ if v_hidden_sizes is not None
48
+ else default_hidden_sizes(observation_space)
49
+ )
50
+
51
+ activation = ACTIVATION[activation_fn]
52
+ self._feature_extractor = Encoder(
53
+ observation_space,
54
+ activation,
55
+ init_layers_orthogonal=init_layers_orthogonal,
56
+ cnn_flatten_dim=cnn_flatten_dim,
57
+ cnn_style=cnn_style,
58
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
59
+ impala_channels=impala_channels,
60
+ )
61
+ self._pi = actor_head(
62
+ action_space,
63
+ self._feature_extractor.out_dim,
64
+ tuple(pi_hidden_sizes),
65
+ init_layers_orthogonal,
66
+ activation,
67
+ log_std_init=log_std_init,
68
+ use_sde=use_sde,
69
+ full_std=full_std,
70
+ squash_output=squash_output,
71
+ actor_head_style=actor_head_style,
72
+ action_plane_space=action_plane_space,
73
+ )
74
+
75
+ v_encoder = Encoder(
76
+ observation_space,
77
+ activation,
78
+ init_layers_orthogonal=init_layers_orthogonal,
79
+ cnn_flatten_dim=cnn_flatten_dim,
80
+ cnn_style=cnn_style,
81
+ cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
82
+ )
83
+ self._v = nn.Sequential(
84
+ v_encoder,
85
+ CriticHead(
86
+ in_dim=v_encoder.out_dim,
87
+ hidden_sizes=v_hidden_sizes,
88
+ activation=activation,
89
+ init_layers_orthogonal=init_layers_orthogonal,
90
+ ),
91
+ )
92
+
93
+ def forward(
94
+ self,
95
+ obs: torch.Tensor,
96
+ action: torch.Tensor,
97
+ action_masks: Optional[torch.Tensor] = None,
98
+ ) -> ACNForward:
99
+ return self._distribution_and_value(
100
+ obs, action=action, action_masks=action_masks
101
+ )
102
+
103
+ def distribution_and_value(
104
+ self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None
105
+ ) -> ACNForward:
106
+ return self._distribution_and_value(obs, action_masks=action_masks)
107
+
108
+ def _distribution_and_value(
109
+ self,
110
+ obs: torch.Tensor,
111
+ action: Optional[torch.Tensor] = None,
112
+ action_masks: Optional[torch.Tensor] = None,
113
+ ) -> ACNForward:
114
+ pi_forward = self._pi(
115
+ self._feature_extractor(obs), actions=action, action_masks=action_masks
116
+ )
117
+ v = self._v(obs)
118
+ return ACNForward(pi_forward, v)
119
+
120
+ def value(self, obs: torch.Tensor) -> torch.Tensor:
121
+ return self._v(obs)
122
+
123
+ def reset_noise(self, batch_size: int) -> None:
124
+ self._pi.sample_weights(batch_size=batch_size)
125
+
126
+ @property
127
+ def action_shape(self) -> Tuple[int, ...]:
128
+ return self._pi.action_shape