sgoodfriend
commited on
Commit
•
1cde088
1
Parent(s):
946448b
PPO playing CartPole-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +16 -15
- environment.yml +1 -1
- pyproject.toml +6 -3
- replay.meta.json +1 -1
- replay.mp4 +0 -0
- rl_algo_impls/a2c/a2c.py +11 -9
- rl_algo_impls/a2c/optimize.py +9 -5
- rl_algo_impls/dqn/dqn.py +15 -8
- rl_algo_impls/dqn/q_net.py +1 -1
- rl_algo_impls/huggingface_publish.py +6 -7
- rl_algo_impls/hyperparams/a2c.yml +13 -12
- rl_algo_impls/hyperparams/dqn.yml +3 -3
- rl_algo_impls/hyperparams/ppo.yml +123 -10
- rl_algo_impls/hyperparams/vpg.yml +6 -6
- rl_algo_impls/optimize.py +61 -28
- rl_algo_impls/ppo/ppo.py +27 -16
- rl_algo_impls/runner/config.py +12 -3
- rl_algo_impls/runner/evaluate.py +5 -6
- rl_algo_impls/runner/running_utils.py +15 -23
- rl_algo_impls/runner/selfplay_evaluate.py +142 -0
- rl_algo_impls/runner/train.py +36 -21
- rl_algo_impls/selfplay_enjoy.py +53 -0
- rl_algo_impls/shared/actor/__init__.py +1 -1
- rl_algo_impls/shared/actor/actor.py +10 -9
- rl_algo_impls/shared/actor/categorical.py +3 -3
- rl_algo_impls/shared/actor/gaussian.py +3 -3
- rl_algo_impls/shared/actor/gridnet.py +4 -4
- rl_algo_impls/shared/actor/gridnet_decoder.py +3 -4
- rl_algo_impls/shared/actor/make_actor.py +8 -5
- rl_algo_impls/shared/actor/multi_discrete.py +3 -3
- rl_algo_impls/shared/actor/state_dependent_noise.py +14 -15
- rl_algo_impls/shared/algorithm.py +5 -5
- rl_algo_impls/shared/callbacks/__init__.py +1 -0
- rl_algo_impls/shared/callbacks/eval_callback.py +24 -4
- rl_algo_impls/shared/callbacks/microrts_reward_decay_callback.py +36 -0
- rl_algo_impls/shared/callbacks/optimize_callback.py +1 -1
- rl_algo_impls/shared/callbacks/self_play_callback.py +34 -0
- rl_algo_impls/shared/encoder/cnn.py +1 -1
- rl_algo_impls/shared/encoder/encoder.py +1 -1
- rl_algo_impls/shared/encoder/gridnet_encoder.py +1 -1
- rl_algo_impls/shared/encoder/impala_cnn.py +1 -1
- rl_algo_impls/shared/encoder/microrts_cnn.py +1 -1
- rl_algo_impls/shared/encoder/nature_cnn.py +1 -1
- rl_algo_impls/shared/gae.py +1 -1
- rl_algo_impls/shared/module/{module.py → utils.py} +0 -0
- rl_algo_impls/shared/policy/{on_policy.py → actor_critic.py} +62 -95
- rl_algo_impls/shared/policy/actor_critic_network/__init__.py +11 -0
- rl_algo_impls/shared/policy/actor_critic_network/connected_trio.py +118 -0
- rl_algo_impls/shared/policy/actor_critic_network/network.py +57 -0
- rl_algo_impls/shared/policy/actor_critic_network/separate_actor_critic.py +128 -0
README.md
CHANGED
@@ -23,17 +23,17 @@ model-index:
|
|
23 |
|
24 |
This is a trained model of a **PPO** agent playing **CartPole-v1** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
|
25 |
|
26 |
-
All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/
|
27 |
|
28 |
## Training Results
|
29 |
|
30 |
-
This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [
|
31 |
|
32 |
| algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
|
33 |
|:-------|:------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
|
34 |
-
| ppo | CartPole-v1 | 1 | 500 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/
|
35 |
-
| ppo | CartPole-v1 | 2 | 500 | 0 | 16 |
|
36 |
-
| ppo | CartPole-v1 | 3 | 500 | 0 | 16 |
|
37 |
|
38 |
|
39 |
### Prerequisites: Weights & Biases (WandB)
|
@@ -53,10 +53,10 @@ login`.
|
|
53 |
Note: While the model state dictionary and hyperaparameters are saved, the latest
|
54 |
implementation could be sufficiently different to not be able to reproduce similar
|
55 |
results. You might need to checkout the commit the agent was trained on:
|
56 |
-
[
|
57 |
```
|
58 |
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
|
59 |
-
python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/
|
60 |
```
|
61 |
|
62 |
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
@@ -68,11 +68,11 @@ notebook.
|
|
68 |
|
69 |
## Training
|
70 |
If you want the highest chance to reproduce these results, you'll want to checkout the
|
71 |
-
commit the agent was trained on: [
|
72 |
training is deterministic, different hardware will give different results.
|
73 |
|
74 |
```
|
75 |
-
python train.py --algo ppo --env CartPole-v1 --seed
|
76 |
```
|
77 |
|
78 |
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
@@ -83,7 +83,7 @@ notebook.
|
|
83 |
|
84 |
|
85 |
## Benchmarking (with Lambda Labs instance)
|
86 |
-
This and other models from https://api.wandb.ai/links/sgoodfriend/
|
87 |
Labs instance. In a Lambda Labs instance terminal:
|
88 |
```
|
89 |
git clone [email protected]:sgoodfriend/rl-algo-impls.git
|
@@ -123,19 +123,20 @@ env: CartPole-v1
|
|
123 |
env_hyperparams:
|
124 |
n_envs: 8
|
125 |
env_id: null
|
126 |
-
|
127 |
step_freq: 25000
|
|
|
128 |
n_timesteps: 100000
|
129 |
policy_hyperparams: {}
|
130 |
-
seed:
|
131 |
use_deterministic_algorithms: true
|
132 |
wandb_entity: null
|
133 |
wandb_group: null
|
134 |
wandb_project_name: rl-algo-impls-benchmarks
|
135 |
wandb_tags:
|
136 |
-
-
|
137 |
-
-
|
138 |
- branch_main
|
139 |
-
- v0.0.
|
140 |
|
141 |
```
|
|
|
23 |
|
24 |
This is a trained model of a **PPO** agent playing **CartPole-v1** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
|
25 |
|
26 |
+
All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/ysd5gj7p.
|
27 |
|
28 |
## Training Results
|
29 |
|
30 |
+
This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [983cb75](https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
|
31 |
|
32 |
| algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
|
33 |
|:-------|:------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
|
34 |
+
| ppo | CartPole-v1 | 1 | 500 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/h6q6ybro) |
|
35 |
+
| ppo | CartPole-v1 | 2 | 500 | 0 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/t43i90su) |
|
36 |
+
| ppo | CartPole-v1 | 3 | 500 | 0 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/bf8ho2cx) |
|
37 |
|
38 |
|
39 |
### Prerequisites: Weights & Biases (WandB)
|
|
|
53 |
Note: While the model state dictionary and hyperaparameters are saved, the latest
|
54 |
implementation could be sufficiently different to not be able to reproduce similar
|
55 |
results. You might need to checkout the commit the agent was trained on:
|
56 |
+
[983cb75](https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b).
|
57 |
```
|
58 |
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
|
59 |
+
python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/t43i90su
|
60 |
```
|
61 |
|
62 |
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
|
|
68 |
|
69 |
## Training
|
70 |
If you want the highest chance to reproduce these results, you'll want to checkout the
|
71 |
+
commit the agent was trained on: [983cb75](https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b). While
|
72 |
training is deterministic, different hardware will give different results.
|
73 |
|
74 |
```
|
75 |
+
python train.py --algo ppo --env CartPole-v1 --seed 2
|
76 |
```
|
77 |
|
78 |
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
|
|
83 |
|
84 |
|
85 |
## Benchmarking (with Lambda Labs instance)
|
86 |
+
This and other models from https://api.wandb.ai/links/sgoodfriend/ysd5gj7p were generated by running a script on a Lambda
|
87 |
Labs instance. In a Lambda Labs instance terminal:
|
88 |
```
|
89 |
git clone [email protected]:sgoodfriend/rl-algo-impls.git
|
|
|
123 |
env_hyperparams:
|
124 |
n_envs: 8
|
125 |
env_id: null
|
126 |
+
eval_hyperparams:
|
127 |
step_freq: 25000
|
128 |
+
microrts_reward_decay_callback: false
|
129 |
n_timesteps: 100000
|
130 |
policy_hyperparams: {}
|
131 |
+
seed: 2
|
132 |
use_deterministic_algorithms: true
|
133 |
wandb_entity: null
|
134 |
wandb_group: null
|
135 |
wandb_project_name: rl-algo-impls-benchmarks
|
136 |
wandb_tags:
|
137 |
+
- benchmark_983cb75
|
138 |
+
- host_129-159-43-75
|
139 |
- branch_main
|
140 |
+
- v0.0.9
|
141 |
|
142 |
```
|
environment.yml
CHANGED
@@ -4,7 +4,7 @@ channels:
|
|
4 |
- conda-forge
|
5 |
- nodefaults
|
6 |
dependencies:
|
7 |
-
- python>=3.8, <3.
|
8 |
- mamba
|
9 |
- pip
|
10 |
- pytorch
|
|
|
4 |
- conda-forge
|
5 |
- nodefaults
|
6 |
dependencies:
|
7 |
+
- python>=3.8, <3.10
|
8 |
- mamba
|
9 |
- pip
|
10 |
- pytorch
|
pyproject.toml
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
[project]
|
2 |
name = "rl_algo_impls"
|
3 |
-
version = "0.0.
|
4 |
description = "Implementations of reinforcement learning algorithms"
|
5 |
authors = [
|
6 |
{name = "Scott Goodfriend", email = "[email protected]"},
|
@@ -56,14 +56,17 @@ procgen = [
|
|
56 |
"glfw >= 1.12.0, < 1.13",
|
57 |
"procgen; platform_machine=='x86_64'",
|
58 |
]
|
59 |
-
microrts-
|
60 |
"numpy < 1.24.0", # Support for gym-microrts < 0.6.0
|
61 |
"gym-microrts == 0.2.0", # Match ppo-implementation-details
|
62 |
]
|
63 |
-
microrts = [
|
64 |
"numpy < 1.24.0", # Support for gym-microrts < 0.6.0
|
65 |
"gym-microrts == 0.3.2",
|
66 |
]
|
|
|
|
|
|
|
67 |
jupyter = [
|
68 |
"jupyter",
|
69 |
"notebook"
|
|
|
1 |
[project]
|
2 |
name = "rl_algo_impls"
|
3 |
+
version = "0.0.9"
|
4 |
description = "Implementations of reinforcement learning algorithms"
|
5 |
authors = [
|
6 |
{name = "Scott Goodfriend", email = "[email protected]"},
|
|
|
56 |
"glfw >= 1.12.0, < 1.13",
|
57 |
"procgen; platform_machine=='x86_64'",
|
58 |
]
|
59 |
+
microrts-ppo = [
|
60 |
"numpy < 1.24.0", # Support for gym-microrts < 0.6.0
|
61 |
"gym-microrts == 0.2.0", # Match ppo-implementation-details
|
62 |
]
|
63 |
+
microrts-paper = [
|
64 |
"numpy < 1.24.0", # Support for gym-microrts < 0.6.0
|
65 |
"gym-microrts == 0.3.2",
|
66 |
]
|
67 |
+
microrts = [
|
68 |
+
"gym-microrts",
|
69 |
+
]
|
70 |
jupyter = [
|
71 |
"jupyter",
|
72 |
"notebook"
|
replay.meta.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/tmp/
|
|
|
1 |
+
{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/tmp/tmp6t0jvdwb/ppo-CartPole-v1/replay.mp4"]}, "episodes": [{"r": 500.0, "l": 500, "t": 3.30735}]}
|
replay.mp4
CHANGED
Binary files a/replay.mp4 and b/replay.mp4 differ
|
|
rl_algo_impls/a2c/a2c.py
CHANGED
@@ -1,23 +1,23 @@
|
|
1 |
import logging
|
|
|
|
|
|
|
2 |
import numpy as np
|
3 |
import torch
|
4 |
import torch.nn as nn
|
5 |
import torch.nn.functional as F
|
6 |
-
|
7 |
-
from time import perf_counter
|
8 |
from torch.utils.tensorboard.writer import SummaryWriter
|
9 |
-
from typing import Optional, TypeVar
|
10 |
|
11 |
from rl_algo_impls.shared.algorithm import Algorithm
|
12 |
-
from rl_algo_impls.shared.callbacks
|
13 |
from rl_algo_impls.shared.gae import compute_advantages
|
14 |
-
from rl_algo_impls.shared.policy.
|
15 |
from rl_algo_impls.shared.schedule import schedule, update_learning_rate
|
16 |
from rl_algo_impls.shared.stats import log_scalars
|
17 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
18 |
VecEnv,
|
19 |
-
single_observation_space,
|
20 |
single_action_space,
|
|
|
21 |
)
|
22 |
|
23 |
A2CSelf = TypeVar("A2CSelf", bound="A2C")
|
@@ -70,7 +70,7 @@ class A2C(Algorithm):
|
|
70 |
def learn(
|
71 |
self: A2CSelf,
|
72 |
train_timesteps: int,
|
73 |
-
|
74 |
total_timesteps: Optional[int] = None,
|
75 |
start_timesteps: int = 0,
|
76 |
) -> A2CSelf:
|
@@ -193,8 +193,10 @@ class A2C(Algorithm):
|
|
193 |
timesteps_elapsed,
|
194 |
)
|
195 |
|
196 |
-
if
|
197 |
-
if not
|
|
|
|
|
198 |
logging.info(
|
199 |
f"Callback terminated training at {timesteps_elapsed} timesteps"
|
200 |
)
|
|
|
1 |
import logging
|
2 |
+
from time import perf_counter
|
3 |
+
from typing import List, Optional, TypeVar
|
4 |
+
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
import torch.nn.functional as F
|
|
|
|
|
9 |
from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
10 |
|
11 |
from rl_algo_impls.shared.algorithm import Algorithm
|
12 |
+
from rl_algo_impls.shared.callbacks import Callback
|
13 |
from rl_algo_impls.shared.gae import compute_advantages
|
14 |
+
from rl_algo_impls.shared.policy.actor_critic import ActorCritic
|
15 |
from rl_algo_impls.shared.schedule import schedule, update_learning_rate
|
16 |
from rl_algo_impls.shared.stats import log_scalars
|
17 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
18 |
VecEnv,
|
|
|
19 |
single_action_space,
|
20 |
+
single_observation_space,
|
21 |
)
|
22 |
|
23 |
A2CSelf = TypeVar("A2CSelf", bound="A2C")
|
|
|
70 |
def learn(
|
71 |
self: A2CSelf,
|
72 |
train_timesteps: int,
|
73 |
+
callbacks: Optional[List[Callback]] = None,
|
74 |
total_timesteps: Optional[int] = None,
|
75 |
start_timesteps: int = 0,
|
76 |
) -> A2CSelf:
|
|
|
193 |
timesteps_elapsed,
|
194 |
)
|
195 |
|
196 |
+
if callbacks:
|
197 |
+
if not all(
|
198 |
+
c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
|
199 |
+
):
|
200 |
logging.info(
|
201 |
f"Callback terminated training at {timesteps_elapsed} timesteps"
|
202 |
)
|
rl_algo_impls/a2c/optimize.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
-
import optuna
|
2 |
-
|
3 |
from copy import deepcopy
|
4 |
|
5 |
-
|
6 |
-
|
|
|
7 |
from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
|
|
|
8 |
from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
|
9 |
|
10 |
|
@@ -16,7 +16,11 @@ def sample_params(
|
|
16 |
hyperparams = deepcopy(base_hyperparams)
|
17 |
|
18 |
base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams)
|
19 |
-
env = make_eval_env(
|
|
|
|
|
|
|
|
|
20 |
|
21 |
# env_hyperparams
|
22 |
env_hyperparams = sample_env_hyperparams(trial, hyperparams.env_hyperparams, env)
|
|
|
|
|
|
|
1 |
from copy import deepcopy
|
2 |
|
3 |
+
import optuna
|
4 |
+
|
5 |
+
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams
|
6 |
from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams
|
7 |
+
from rl_algo_impls.shared.vec_env import make_eval_env
|
8 |
from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams
|
9 |
|
10 |
|
|
|
16 |
hyperparams = deepcopy(base_hyperparams)
|
17 |
|
18 |
base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams)
|
19 |
+
env = make_eval_env(
|
20 |
+
base_config,
|
21 |
+
base_env_hyperparams,
|
22 |
+
override_hparams={"n_envs": 1},
|
23 |
+
)
|
24 |
|
25 |
# env_hyperparams
|
26 |
env_hyperparams = sample_env_hyperparams(trial, hyperparams.env_hyperparams, env)
|
rl_algo_impls/dqn/dqn.py
CHANGED
@@ -1,18 +1,19 @@
|
|
1 |
import copy
|
2 |
-
import
|
3 |
import random
|
|
|
|
|
|
|
|
|
4 |
import torch
|
5 |
import torch.nn as nn
|
6 |
import torch.nn.functional as F
|
7 |
-
|
8 |
-
from collections import deque
|
9 |
from torch.optim import Adam
|
10 |
from torch.utils.tensorboard.writer import SummaryWriter
|
11 |
-
from typing import NamedTuple, Optional, TypeVar
|
12 |
|
13 |
from rl_algo_impls.dqn.policy import DQNPolicy
|
14 |
from rl_algo_impls.shared.algorithm import Algorithm
|
15 |
-
from rl_algo_impls.shared.callbacks
|
16 |
from rl_algo_impls.shared.schedule import linear_schedule
|
17 |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
|
18 |
|
@@ -118,7 +119,7 @@ class DQN(Algorithm):
|
|
118 |
self.max_grad_norm = max_grad_norm
|
119 |
|
120 |
def learn(
|
121 |
-
self: DQNSelf, total_timesteps: int,
|
122 |
) -> DQNSelf:
|
123 |
self.policy.train(True)
|
124 |
obs = self.env.reset()
|
@@ -140,8 +141,14 @@ class DQN(Algorithm):
|
|
140 |
if steps_since_target_update >= self.target_update_interval:
|
141 |
self._update_target()
|
142 |
steps_since_target_update = 0
|
143 |
-
if
|
144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
return self
|
146 |
|
147 |
def train(self) -> None:
|
|
|
1 |
import copy
|
2 |
+
import logging
|
3 |
import random
|
4 |
+
from collections import deque
|
5 |
+
from typing import List, NamedTuple, Optional, TypeVar
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
import torch
|
9 |
import torch.nn as nn
|
10 |
import torch.nn.functional as F
|
|
|
|
|
11 |
from torch.optim import Adam
|
12 |
from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
13 |
|
14 |
from rl_algo_impls.dqn.policy import DQNPolicy
|
15 |
from rl_algo_impls.shared.algorithm import Algorithm
|
16 |
+
from rl_algo_impls.shared.callbacks import Callback
|
17 |
from rl_algo_impls.shared.schedule import linear_schedule
|
18 |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs
|
19 |
|
|
|
119 |
self.max_grad_norm = max_grad_norm
|
120 |
|
121 |
def learn(
|
122 |
+
self: DQNSelf, total_timesteps: int, callbacks: Optional[List[Callback]] = None
|
123 |
) -> DQNSelf:
|
124 |
self.policy.train(True)
|
125 |
obs = self.env.reset()
|
|
|
141 |
if steps_since_target_update >= self.target_update_interval:
|
142 |
self._update_target()
|
143 |
steps_since_target_update = 0
|
144 |
+
if callbacks:
|
145 |
+
if not all(
|
146 |
+
c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
|
147 |
+
):
|
148 |
+
logging.info(
|
149 |
+
f"Callback terminated training at {timesteps_elapsed} timesteps"
|
150 |
+
)
|
151 |
+
break
|
152 |
return self
|
153 |
|
154 |
def train(self) -> None:
|
rl_algo_impls/dqn/q_net.py
CHANGED
@@ -6,7 +6,7 @@ import torch.nn as nn
|
|
6 |
from gym.spaces import Discrete
|
7 |
|
8 |
from rl_algo_impls.shared.encoder import Encoder
|
9 |
-
from rl_algo_impls.shared.module.
|
10 |
|
11 |
|
12 |
class QNetwork(nn.Module):
|
|
|
6 |
from gym.spaces import Discrete
|
7 |
|
8 |
from rl_algo_impls.shared.encoder import Encoder
|
9 |
+
from rl_algo_impls.shared.module.utils import mlp
|
10 |
|
11 |
|
12 |
class QNetwork(nn.Module):
|
rl_algo_impls/huggingface_publish.py
CHANGED
@@ -3,24 +3,23 @@ import os
|
|
3 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
4 |
|
5 |
import argparse
|
6 |
-
import requests
|
7 |
import shutil
|
8 |
import subprocess
|
9 |
import tempfile
|
10 |
-
import wandb
|
11 |
-
import wandb.apis.public
|
12 |
-
|
13 |
from typing import List, Optional
|
14 |
|
|
|
|
|
15 |
from huggingface_hub.hf_api import HfApi, upload_folder
|
16 |
from huggingface_hub.repocard import metadata_save
|
17 |
from pyvirtualdisplay.display import Display
|
18 |
|
|
|
19 |
from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
|
20 |
from rl_algo_impls.runner.config import EnvHyperparams
|
21 |
from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
|
22 |
-
from rl_algo_impls.shared.vec_env import make_eval_env
|
23 |
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
|
|
|
24 |
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
|
25 |
|
26 |
|
@@ -134,7 +133,7 @@ def publish(
|
|
134 |
make_eval_env(
|
135 |
config,
|
136 |
EnvHyperparams(**config.env_hyperparams),
|
137 |
-
|
138 |
normalize_load_path=model_path,
|
139 |
),
|
140 |
os.path.join(repo_dir_path, "replay"),
|
@@ -144,7 +143,7 @@ def publish(
|
|
144 |
video_env,
|
145 |
policy,
|
146 |
1,
|
147 |
-
deterministic=config.
|
148 |
)
|
149 |
|
150 |
api = HfApi()
|
|
|
3 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
4 |
|
5 |
import argparse
|
|
|
6 |
import shutil
|
7 |
import subprocess
|
8 |
import tempfile
|
|
|
|
|
|
|
9 |
from typing import List, Optional
|
10 |
|
11 |
+
import requests
|
12 |
+
import wandb.apis.public
|
13 |
from huggingface_hub.hf_api import HfApi, upload_folder
|
14 |
from huggingface_hub.repocard import metadata_save
|
15 |
from pyvirtualdisplay.display import Display
|
16 |
|
17 |
+
import wandb
|
18 |
from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text
|
19 |
from rl_algo_impls.runner.config import EnvHyperparams
|
20 |
from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model
|
|
|
21 |
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
|
22 |
+
from rl_algo_impls.shared.vec_env import make_eval_env
|
23 |
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
|
24 |
|
25 |
|
|
|
133 |
make_eval_env(
|
134 |
config,
|
135 |
EnvHyperparams(**config.env_hyperparams),
|
136 |
+
override_hparams={"n_envs": 1},
|
137 |
normalize_load_path=model_path,
|
138 |
),
|
139 |
os.path.join(repo_dir_path, "replay"),
|
|
|
143 |
video_env,
|
144 |
policy,
|
145 |
1,
|
146 |
+
deterministic=config.eval_hyperparams.get("deterministic", True),
|
147 |
)
|
148 |
|
149 |
api = HfApi()
|
rl_algo_impls/hyperparams/a2c.yml
CHANGED
@@ -101,31 +101,32 @@ HopperBulletEnv-v0:
|
|
101 |
CarRacing-v0:
|
102 |
n_timesteps: !!float 4e6
|
103 |
env_hyperparams:
|
104 |
-
n_envs:
|
105 |
frame_stack: 4
|
106 |
normalize: true
|
107 |
normalize_kwargs:
|
108 |
norm_obs: false
|
109 |
norm_reward: true
|
110 |
policy_hyperparams:
|
111 |
-
use_sde:
|
112 |
-
log_std_init: -
|
113 |
init_layers_orthogonal: true
|
114 |
activation_fn: tanh
|
115 |
share_features_extractor: false
|
116 |
cnn_flatten_dim: 256
|
117 |
hidden_sizes: [256]
|
118 |
algo_hyperparams:
|
119 |
-
n_steps:
|
120 |
-
learning_rate: 0.
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
vf_coef: 0.
|
126 |
-
max_grad_norm:
|
127 |
-
normalize_advantage:
|
128 |
use_rms_prop: false
|
|
|
129 |
|
130 |
_atari: &atari-defaults
|
131 |
n_timesteps: !!float 1e7
|
|
|
101 |
CarRacing-v0:
|
102 |
n_timesteps: !!float 4e6
|
103 |
env_hyperparams:
|
104 |
+
n_envs: 4
|
105 |
frame_stack: 4
|
106 |
normalize: true
|
107 |
normalize_kwargs:
|
108 |
norm_obs: false
|
109 |
norm_reward: true
|
110 |
policy_hyperparams:
|
111 |
+
use_sde: true
|
112 |
+
log_std_init: -4.839609092563
|
113 |
init_layers_orthogonal: true
|
114 |
activation_fn: tanh
|
115 |
share_features_extractor: false
|
116 |
cnn_flatten_dim: 256
|
117 |
hidden_sizes: [256]
|
118 |
algo_hyperparams:
|
119 |
+
n_steps: 64
|
120 |
+
learning_rate: 0.000018971962220405576
|
121 |
+
gamma: 0.9942776405534832
|
122 |
+
gae_lambda: 0.9549244758833236
|
123 |
+
ent_coef: 0.0000015666550584860516
|
124 |
+
ent_coef_decay: linear
|
125 |
+
vf_coef: 0.12164696385898476
|
126 |
+
max_grad_norm: 2.2574480552177127
|
127 |
+
normalize_advantage: false
|
128 |
use_rms_prop: false
|
129 |
+
sde_sample_freq: 16
|
130 |
|
131 |
_atari: &atari-defaults
|
132 |
n_timesteps: !!float 1e7
|
rl_algo_impls/hyperparams/dqn.yml
CHANGED
@@ -15,7 +15,7 @@ CartPole-v1: &cartpole-defaults
|
|
15 |
gradient_steps: 128
|
16 |
exploration_fraction: 0.16
|
17 |
exploration_final_eps: 0.04
|
18 |
-
|
19 |
step_freq: !!float 1e4
|
20 |
|
21 |
CartPole-v0:
|
@@ -76,7 +76,7 @@ LunarLander-v2:
|
|
76 |
exploration_fraction: 0.12
|
77 |
exploration_final_eps: 0.1
|
78 |
max_grad_norm: 0.5
|
79 |
-
|
80 |
step_freq: 25_000
|
81 |
|
82 |
_atari: &atari-defaults
|
@@ -97,7 +97,7 @@ _atari: &atari-defaults
|
|
97 |
gradient_steps: 2
|
98 |
exploration_fraction: 0.1
|
99 |
exploration_final_eps: 0.01
|
100 |
-
|
101 |
deterministic: false
|
102 |
|
103 |
PongNoFrameskip-v4:
|
|
|
15 |
gradient_steps: 128
|
16 |
exploration_fraction: 0.16
|
17 |
exploration_final_eps: 0.04
|
18 |
+
eval_hyperparams:
|
19 |
step_freq: !!float 1e4
|
20 |
|
21 |
CartPole-v0:
|
|
|
76 |
exploration_fraction: 0.12
|
77 |
exploration_final_eps: 0.1
|
78 |
max_grad_norm: 0.5
|
79 |
+
eval_hyperparams:
|
80 |
step_freq: 25_000
|
81 |
|
82 |
_atari: &atari-defaults
|
|
|
97 |
gradient_steps: 2
|
98 |
exploration_fraction: 0.1
|
99 |
exploration_final_eps: 0.01
|
100 |
+
eval_hyperparams:
|
101 |
deterministic: false
|
102 |
|
103 |
PongNoFrameskip-v4:
|
rl_algo_impls/hyperparams/ppo.yml
CHANGED
@@ -13,7 +13,7 @@ CartPole-v1: &cartpole-defaults
|
|
13 |
learning_rate_decay: linear
|
14 |
clip_range: 0.2
|
15 |
clip_range_decay: linear
|
16 |
-
|
17 |
step_freq: !!float 2.5e4
|
18 |
|
19 |
CartPole-v0:
|
@@ -52,7 +52,7 @@ MountainCarContinuous-v0:
|
|
52 |
gae_lambda: 0.9
|
53 |
max_grad_norm: 5
|
54 |
vf_coef: 0.19
|
55 |
-
|
56 |
step_freq: 5000
|
57 |
|
58 |
Acrobot-v1:
|
@@ -162,7 +162,7 @@ _atari: &atari-defaults
|
|
162 |
clip_range_decay: linear
|
163 |
vf_coef: 0.5
|
164 |
ent_coef: 0.01
|
165 |
-
|
166 |
deterministic: false
|
167 |
|
168 |
_norm-rewards-atari: &norm-rewards-atari-default
|
@@ -228,7 +228,7 @@ _microrts: µrts-defaults
|
|
228 |
clip_range_decay: none
|
229 |
clip_range_vf: 0.1
|
230 |
ppo2_vf_coef_halving: true
|
231 |
-
|
232 |
deterministic: false # Good idea because MultiCategorical mode isn't great
|
233 |
|
234 |
_no-mask-microrts: &no-mask-microrts-defaults
|
@@ -252,15 +252,15 @@ MicrortsRandomEnemyShapedReward3-v1-NoMask:
|
|
252 |
_microrts_ai: µrts-ai-defaults
|
253 |
<<: *microrts-defaults
|
254 |
n_timesteps: !!float 100e6
|
255 |
-
additional_keys_to_log: ["microrts_stats"]
|
256 |
env_hyperparams: µrts-ai-env-defaults
|
257 |
n_envs: 24
|
258 |
env_type: microrts
|
259 |
-
make_kwargs:
|
260 |
num_selfplay_envs: 0
|
261 |
-
max_steps:
|
262 |
render_theme: 2
|
263 |
-
|
264 |
reward_weight: [10.0, 1.0, 1.0, 0.2, 1.0, 4.0]
|
265 |
policy_hyperparams: µrts-ai-policy-defaults
|
266 |
<<: *microrts-policy-defaults
|
@@ -278,6 +278,15 @@ _microrts_ai: µrts-ai-defaults
|
|
278 |
max_grad_norm: 0.5
|
279 |
clip_range: 0.1
|
280 |
clip_range_vf: 0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
|
282 |
MicrortsAttackPassiveEnemySparseReward-v3:
|
283 |
<<: *microrts-ai-defaults
|
@@ -305,6 +314,18 @@ enc-dec-MicrortsDefeatRandomEnemySparseReward-v3:
|
|
305 |
actor_head_style: gridnet_decoder
|
306 |
v_hidden_sizes: [128]
|
307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
308 |
MicrortsDefeatCoacAIShaped-v3: µrts-coacai-defaults
|
309 |
<<: *microrts-ai-defaults
|
310 |
env_id: MicrortsDefeatCoacAIShaped-v3 # Workaround to keep model name simple
|
@@ -313,6 +334,27 @@ MicrortsDefeatCoacAIShaped-v3: µrts-coacai-defaults
|
|
313 |
<<: *microrts-ai-env-defaults
|
314 |
bots:
|
315 |
coacAI: 24
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
316 |
|
317 |
MicrortsDefeatCoacAIShaped-v3-diverseBots: µrts-diverse-defaults
|
318 |
<<: *microrts-coacai-defaults
|
@@ -325,6 +367,7 @@ MicrortsDefeatCoacAIShaped-v3-diverseBots: µrts-diverse-defaults
|
|
325 |
workerRushAI: 2
|
326 |
|
327 |
enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
|
|
|
328 |
<<: *microrts-diverse-defaults
|
329 |
policy_hyperparams:
|
330 |
<<: *microrts-ai-policy-defaults
|
@@ -332,6 +375,76 @@ enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
|
|
332 |
actor_head_style: gridnet_decoder
|
333 |
v_hidden_sizes: [128]
|
334 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
335 |
HalfCheetahBulletEnv-v0: &pybullet-defaults
|
336 |
n_timesteps: !!float 2e6
|
337 |
env_hyperparams: &pybullet-env-defaults
|
@@ -418,7 +531,7 @@ _procgen: &procgen-defaults
|
|
418 |
learning_rate: !!float 5e-4
|
419 |
# learning_rate_decay: linear
|
420 |
vf_coef: 0.5
|
421 |
-
|
422 |
ignore_first_episode: true
|
423 |
# deterministic: false
|
424 |
step_freq: !!float 1e5
|
@@ -466,7 +579,7 @@ _procgen-hard: &procgen-hard-defaults
|
|
466 |
batch_size: 8192
|
467 |
clip_range_decay: linear
|
468 |
learning_rate_decay: linear
|
469 |
-
|
470 |
<<: *procgen-eval-defaults
|
471 |
step_freq: !!float 5e5
|
472 |
|
|
|
13 |
learning_rate_decay: linear
|
14 |
clip_range: 0.2
|
15 |
clip_range_decay: linear
|
16 |
+
eval_hyperparams:
|
17 |
step_freq: !!float 2.5e4
|
18 |
|
19 |
CartPole-v0:
|
|
|
52 |
gae_lambda: 0.9
|
53 |
max_grad_norm: 5
|
54 |
vf_coef: 0.19
|
55 |
+
eval_hyperparams:
|
56 |
step_freq: 5000
|
57 |
|
58 |
Acrobot-v1:
|
|
|
162 |
clip_range_decay: linear
|
163 |
vf_coef: 0.5
|
164 |
ent_coef: 0.01
|
165 |
+
eval_hyperparams:
|
166 |
deterministic: false
|
167 |
|
168 |
_norm-rewards-atari: &norm-rewards-atari-default
|
|
|
228 |
clip_range_decay: none
|
229 |
clip_range_vf: 0.1
|
230 |
ppo2_vf_coef_halving: true
|
231 |
+
eval_hyperparams: µrts-eval-defaults
|
232 |
deterministic: false # Good idea because MultiCategorical mode isn't great
|
233 |
|
234 |
_no-mask-microrts: &no-mask-microrts-defaults
|
|
|
252 |
_microrts_ai: µrts-ai-defaults
|
253 |
<<: *microrts-defaults
|
254 |
n_timesteps: !!float 100e6
|
255 |
+
additional_keys_to_log: ["microrts_stats", "microrts_results"]
|
256 |
env_hyperparams: µrts-ai-env-defaults
|
257 |
n_envs: 24
|
258 |
env_type: microrts
|
259 |
+
make_kwargs: µrts-ai-env-make-kwargs-defaults
|
260 |
num_selfplay_envs: 0
|
261 |
+
max_steps: 4000
|
262 |
render_theme: 2
|
263 |
+
map_paths: [maps/16x16/basesWorkers16x16.xml]
|
264 |
reward_weight: [10.0, 1.0, 1.0, 0.2, 1.0, 4.0]
|
265 |
policy_hyperparams: µrts-ai-policy-defaults
|
266 |
<<: *microrts-policy-defaults
|
|
|
278 |
max_grad_norm: 0.5
|
279 |
clip_range: 0.1
|
280 |
clip_range_vf: 0.1
|
281 |
+
eval_hyperparams: µrts-ai-eval-defaults
|
282 |
+
<<: *microrts-eval-defaults
|
283 |
+
score_function: mean
|
284 |
+
max_video_length: 4000
|
285 |
+
env_overrides: µrts-ai-eval-env-overrides
|
286 |
+
make_kwargs:
|
287 |
+
<<: *microrts-ai-env-make-kwargs-defaults
|
288 |
+
max_steps: 4000
|
289 |
+
reward_weight: [1.0, 0, 0, 0, 0, 0]
|
290 |
|
291 |
MicrortsAttackPassiveEnemySparseReward-v3:
|
292 |
<<: *microrts-ai-defaults
|
|
|
314 |
actor_head_style: gridnet_decoder
|
315 |
v_hidden_sizes: [128]
|
316 |
|
317 |
+
unet-MicrortsDefeatRandomEnemySparseReward-v3:
|
318 |
+
<<: *microrts-random-ai-defaults
|
319 |
+
# device: cpu
|
320 |
+
policy_hyperparams:
|
321 |
+
<<: *microrts-ai-policy-defaults
|
322 |
+
actor_head_style: unet
|
323 |
+
v_hidden_sizes: [256, 128]
|
324 |
+
algo_hyperparams:
|
325 |
+
<<: *microrts-ai-algo-defaults
|
326 |
+
learning_rate: !!float 2.5e-4
|
327 |
+
learning_rate_decay: spike
|
328 |
+
|
329 |
MicrortsDefeatCoacAIShaped-v3: µrts-coacai-defaults
|
330 |
<<: *microrts-ai-defaults
|
331 |
env_id: MicrortsDefeatCoacAIShaped-v3 # Workaround to keep model name simple
|
|
|
334 |
<<: *microrts-ai-env-defaults
|
335 |
bots:
|
336 |
coacAI: 24
|
337 |
+
eval_hyperparams: µrts-coacai-eval-defaults
|
338 |
+
<<: *microrts-ai-eval-defaults
|
339 |
+
step_freq: !!float 1e6
|
340 |
+
n_episodes: 26
|
341 |
+
env_overrides: µrts-coacai-eval-env-overrides
|
342 |
+
<<: *microrts-ai-eval-env-overrides
|
343 |
+
n_envs: 26
|
344 |
+
bots:
|
345 |
+
coacAI: 2
|
346 |
+
randomBiasedAI: 2
|
347 |
+
randomAI: 2
|
348 |
+
passiveAI: 2
|
349 |
+
workerRushAI: 2
|
350 |
+
lightRushAI: 2
|
351 |
+
naiveMCTSAI: 2
|
352 |
+
mixedBot: 2
|
353 |
+
rojo: 2
|
354 |
+
izanagi: 2
|
355 |
+
tiamat: 2
|
356 |
+
droplet: 2
|
357 |
+
guidedRojoA3N: 2
|
358 |
|
359 |
MicrortsDefeatCoacAIShaped-v3-diverseBots: µrts-diverse-defaults
|
360 |
<<: *microrts-coacai-defaults
|
|
|
367 |
workerRushAI: 2
|
368 |
|
369 |
enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
|
370 |
+
µrts-env-dec-diverse-defaults
|
371 |
<<: *microrts-diverse-defaults
|
372 |
policy_hyperparams:
|
373 |
<<: *microrts-ai-policy-defaults
|
|
|
375 |
actor_head_style: gridnet_decoder
|
376 |
v_hidden_sizes: [128]
|
377 |
|
378 |
+
debug-enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots:
|
379 |
+
<<: *microrts-env-dec-diverse-defaults
|
380 |
+
n_timesteps: !!float 1e6
|
381 |
+
|
382 |
+
unet-MicrortsDefeatCoacAIShaped-v3-diverseBots: µrts-unet-defaults
|
383 |
+
<<: *microrts-diverse-defaults
|
384 |
+
policy_hyperparams:
|
385 |
+
<<: *microrts-ai-policy-defaults
|
386 |
+
actor_head_style: unet
|
387 |
+
v_hidden_sizes: [256, 128]
|
388 |
+
algo_hyperparams: µrts-unet-algo-defaults
|
389 |
+
<<: *microrts-ai-algo-defaults
|
390 |
+
learning_rate: !!float 2.5e-4
|
391 |
+
learning_rate_decay: spike
|
392 |
+
|
393 |
+
Microrts-selfplay-unet: µrts-selfplay-defaults
|
394 |
+
<<: *microrts-unet-defaults
|
395 |
+
env_hyperparams: µrts-selfplay-env-defaults
|
396 |
+
<<: *microrts-ai-env-defaults
|
397 |
+
make_kwargs: µrts-selfplay-env-make-kwargs-defaults
|
398 |
+
<<: *microrts-ai-env-make-kwargs-defaults
|
399 |
+
num_selfplay_envs: 36
|
400 |
+
self_play_kwargs:
|
401 |
+
num_old_policies: 12
|
402 |
+
save_steps: 300000
|
403 |
+
swap_steps: 6000
|
404 |
+
swap_window_size: 4
|
405 |
+
window: 33
|
406 |
+
eval_hyperparams: µrts-selfplay-eval-defaults
|
407 |
+
<<: *microrts-coacai-eval-defaults
|
408 |
+
env_overrides: µrts-selfplay-eval-env-overrides
|
409 |
+
<<: *microrts-coacai-eval-env-overrides
|
410 |
+
self_play_kwargs: {}
|
411 |
+
|
412 |
+
Microrts-selfplay-unet-winloss: µrts-selfplay-winloss-defaults
|
413 |
+
<<: *microrts-selfplay-defaults
|
414 |
+
env_hyperparams:
|
415 |
+
<<: *microrts-selfplay-env-defaults
|
416 |
+
make_kwargs:
|
417 |
+
<<: *microrts-selfplay-env-make-kwargs-defaults
|
418 |
+
reward_weight: [1.0, 0, 0, 0, 0, 0]
|
419 |
+
algo_hyperparams: µrts-selfplay-winloss-algo-defaults
|
420 |
+
<<: *microrts-unet-algo-defaults
|
421 |
+
gamma: 0.999
|
422 |
+
|
423 |
+
Microrts-selfplay-unet-decay: µrts-selfplay-decay-defaults
|
424 |
+
<<: *microrts-selfplay-defaults
|
425 |
+
microrts_reward_decay_callback: true
|
426 |
+
algo_hyperparams:
|
427 |
+
<<: *microrts-unet-algo-defaults
|
428 |
+
gamma_end: 0.999
|
429 |
+
|
430 |
+
Microrts-selfplay-unet-debug: µrts-selfplay-debug-defaults
|
431 |
+
<<: *microrts-selfplay-decay-defaults
|
432 |
+
eval_hyperparams:
|
433 |
+
<<: *microrts-selfplay-eval-defaults
|
434 |
+
step_freq: !!float 1e5
|
435 |
+
env_overrides:
|
436 |
+
<<: *microrts-selfplay-eval-env-overrides
|
437 |
+
n_envs: 24
|
438 |
+
bots:
|
439 |
+
coacAI: 12
|
440 |
+
randomBiasedAI: 4
|
441 |
+
workerRushAI: 4
|
442 |
+
lightRushAI: 4
|
443 |
+
|
444 |
+
Microrts-selfplay-unet-debug-mps:
|
445 |
+
<<: *microrts-selfplay-debug-defaults
|
446 |
+
device: mps
|
447 |
+
|
448 |
HalfCheetahBulletEnv-v0: &pybullet-defaults
|
449 |
n_timesteps: !!float 2e6
|
450 |
env_hyperparams: &pybullet-env-defaults
|
|
|
531 |
learning_rate: !!float 5e-4
|
532 |
# learning_rate_decay: linear
|
533 |
vf_coef: 0.5
|
534 |
+
eval_hyperparams: &procgen-eval-defaults
|
535 |
ignore_first_episode: true
|
536 |
# deterministic: false
|
537 |
step_freq: !!float 1e5
|
|
|
579 |
batch_size: 8192
|
580 |
clip_range_decay: linear
|
581 |
learning_rate_decay: linear
|
582 |
+
eval_hyperparams:
|
583 |
<<: *procgen-eval-defaults
|
584 |
step_freq: !!float 5e5
|
585 |
|
rl_algo_impls/hyperparams/vpg.yml
CHANGED
@@ -7,7 +7,7 @@ CartPole-v1: &cartpole-defaults
|
|
7 |
gae_lambda: 1
|
8 |
val_lr: 0.01
|
9 |
train_v_iters: 80
|
10 |
-
|
11 |
step_freq: !!float 2.5e4
|
12 |
|
13 |
CartPole-v0:
|
@@ -52,7 +52,7 @@ MountainCarContinuous-v0:
|
|
52 |
val_lr: !!float 1e-3
|
53 |
train_v_iters: 80
|
54 |
max_grad_norm: 5
|
55 |
-
|
56 |
step_freq: 5000
|
57 |
|
58 |
Acrobot-v1:
|
@@ -78,7 +78,7 @@ LunarLander-v2:
|
|
78 |
val_lr: 0.0001
|
79 |
train_v_iters: 80
|
80 |
max_grad_norm: 0.5
|
81 |
-
|
82 |
deterministic: false
|
83 |
|
84 |
BipedalWalker-v3:
|
@@ -96,7 +96,7 @@ BipedalWalker-v3:
|
|
96 |
val_lr: !!float 1e-4
|
97 |
train_v_iters: 80
|
98 |
max_grad_norm: 0.5
|
99 |
-
|
100 |
deterministic: false
|
101 |
|
102 |
CarRacing-v0:
|
@@ -169,7 +169,7 @@ FrozenLake-v1:
|
|
169 |
val_lr: 0.01
|
170 |
train_v_iters: 80
|
171 |
max_grad_norm: 0.5
|
172 |
-
|
173 |
step_freq: !!float 5e4
|
174 |
n_episodes: 10
|
175 |
save_best: true
|
@@ -193,5 +193,5 @@ _atari: &atari-defaults
|
|
193 |
train_v_iters: 80
|
194 |
max_grad_norm: 0.5
|
195 |
ent_coef: 0.01
|
196 |
-
|
197 |
deterministic: false
|
|
|
7 |
gae_lambda: 1
|
8 |
val_lr: 0.01
|
9 |
train_v_iters: 80
|
10 |
+
eval_hyperparams:
|
11 |
step_freq: !!float 2.5e4
|
12 |
|
13 |
CartPole-v0:
|
|
|
52 |
val_lr: !!float 1e-3
|
53 |
train_v_iters: 80
|
54 |
max_grad_norm: 5
|
55 |
+
eval_hyperparams:
|
56 |
step_freq: 5000
|
57 |
|
58 |
Acrobot-v1:
|
|
|
78 |
val_lr: 0.0001
|
79 |
train_v_iters: 80
|
80 |
max_grad_norm: 0.5
|
81 |
+
eval_hyperparams:
|
82 |
deterministic: false
|
83 |
|
84 |
BipedalWalker-v3:
|
|
|
96 |
val_lr: !!float 1e-4
|
97 |
train_v_iters: 80
|
98 |
max_grad_norm: 0.5
|
99 |
+
eval_hyperparams:
|
100 |
deterministic: false
|
101 |
|
102 |
CarRacing-v0:
|
|
|
169 |
val_lr: 0.01
|
170 |
train_v_iters: 80
|
171 |
max_grad_norm: 0.5
|
172 |
+
eval_hyperparams:
|
173 |
step_freq: !!float 5e4
|
174 |
n_episodes: 10
|
175 |
save_best: true
|
|
|
193 |
train_v_iters: 80
|
194 |
max_grad_norm: 0.5
|
195 |
ent_coef: 0.01
|
196 |
+
eval_hyperparams:
|
197 |
deterministic: false
|
rl_algo_impls/optimize.py
CHANGED
@@ -2,37 +2,44 @@ import dataclasses
|
|
2 |
import gc
|
3 |
import inspect
|
4 |
import logging
|
|
|
|
|
|
|
|
|
5 |
import numpy as np
|
6 |
import optuna
|
7 |
-
import os
|
8 |
import torch
|
9 |
-
import wandb
|
10 |
-
|
11 |
-
from dataclasses import asdict, dataclass
|
12 |
from optuna.pruners import HyperbandPruner
|
13 |
from optuna.samplers import TPESampler
|
14 |
from optuna.visualization import plot_optimization_history, plot_param_importances
|
15 |
from torch.utils.tensorboard.writer import SummaryWriter
|
16 |
-
from typing import Callable, List, NamedTuple, Optional, Sequence, Union
|
17 |
|
|
|
18 |
from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
|
19 |
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
|
20 |
-
from rl_algo_impls.shared.vec_env import make_env, make_eval_env
|
21 |
from rl_algo_impls.runner.running_utils import (
|
|
|
22 |
base_parser,
|
23 |
-
load_hyperparams,
|
24 |
-
set_seeds,
|
25 |
get_device,
|
26 |
-
make_policy,
|
27 |
-
ALGOS,
|
28 |
hparam_dict,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
)
|
30 |
from rl_algo_impls.shared.callbacks.optimize_callback import (
|
31 |
Evaluation,
|
32 |
OptimizeCallback,
|
33 |
evaluation,
|
34 |
)
|
|
|
35 |
from rl_algo_impls.shared.stats import EpisodesStats
|
|
|
|
|
|
|
36 |
|
37 |
|
38 |
@dataclass
|
@@ -195,29 +202,38 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -
|
|
195 |
config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
|
196 |
)
|
197 |
device = get_device(config, env)
|
198 |
-
|
|
|
|
|
|
|
199 |
algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
|
200 |
|
201 |
eval_env = make_eval_env(
|
202 |
config,
|
203 |
EnvHyperparams(**config.env_hyperparams),
|
204 |
-
|
205 |
)
|
206 |
-
|
207 |
policy,
|
208 |
eval_env,
|
209 |
trial,
|
210 |
tb_writer,
|
211 |
step_freq=config.n_timesteps // study_args.n_evaluations,
|
212 |
n_episodes=study_args.n_eval_episodes,
|
213 |
-
deterministic=config.
|
214 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
try:
|
216 |
-
algo.learn(config.n_timesteps,
|
217 |
|
218 |
-
if not
|
219 |
-
|
220 |
-
if not
|
221 |
policy.save(config.model_dir_path(best=False))
|
222 |
|
223 |
eval_stat: EpisodesStats = callback.last_eval_stat # type: ignore
|
@@ -230,8 +246,8 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -
|
|
230 |
"hparam/last_result": eval_stat.score.mean - eval_stat.score.std,
|
231 |
"hparam/train_mean": train_stat.score.mean,
|
232 |
"hparam/train_result": train_stat.score.mean - train_stat.score.std,
|
233 |
-
"hparam/score":
|
234 |
-
"hparam/is_pruned":
|
235 |
},
|
236 |
None,
|
237 |
config.run_name(),
|
@@ -239,13 +255,15 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) -
|
|
239 |
tb_writer.close()
|
240 |
|
241 |
if wandb_enabled:
|
242 |
-
wandb.run.summary["state"] =
|
|
|
|
|
243 |
wandb.finish(quiet=True)
|
244 |
|
245 |
-
if
|
246 |
raise optuna.exceptions.TrialPruned()
|
247 |
|
248 |
-
return
|
249 |
except AssertionError as e:
|
250 |
logging.warning(e)
|
251 |
return np.nan
|
@@ -299,7 +317,10 @@ def stepwise_optimize(
|
|
299 |
tb_writer=tb_writer,
|
300 |
)
|
301 |
device = get_device(config, env)
|
302 |
-
|
|
|
|
|
|
|
303 |
if i > 0:
|
304 |
policy.load(config.model_dir_path())
|
305 |
algo = ALGOS[arg.algo](
|
@@ -310,7 +331,7 @@ def stepwise_optimize(
|
|
310 |
config,
|
311 |
EnvHyperparams(**config.env_hyperparams),
|
312 |
normalize_load_path=config.model_dir_path() if i > 0 else None,
|
313 |
-
|
314 |
)
|
315 |
|
316 |
start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
|
@@ -319,10 +340,22 @@ def stepwise_optimize(
|
|
319 |
- start_timesteps
|
320 |
)
|
321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
322 |
try:
|
323 |
algo.learn(
|
324 |
train_timesteps,
|
325 |
-
|
326 |
total_timesteps=config.n_timesteps,
|
327 |
start_timesteps=start_timesteps,
|
328 |
)
|
@@ -333,7 +366,7 @@ def stepwise_optimize(
|
|
333 |
eval_env,
|
334 |
tb_writer,
|
335 |
study_args.n_eval_episodes,
|
336 |
-
config.
|
337 |
start_timesteps + train_timesteps,
|
338 |
)
|
339 |
)
|
@@ -379,7 +412,7 @@ def stepwise_optimize(
|
|
379 |
|
380 |
|
381 |
def wandb_finish(state: str) -> None:
|
382 |
-
wandb.run.summary["state"] = state
|
383 |
wandb.finish(quiet=True)
|
384 |
|
385 |
|
|
|
2 |
import gc
|
3 |
import inspect
|
4 |
import logging
|
5 |
+
import os
|
6 |
+
from dataclasses import asdict, dataclass
|
7 |
+
from typing import Callable, List, NamedTuple, Optional, Sequence, Union
|
8 |
+
|
9 |
import numpy as np
|
10 |
import optuna
|
|
|
11 |
import torch
|
|
|
|
|
|
|
12 |
from optuna.pruners import HyperbandPruner
|
13 |
from optuna.samplers import TPESampler
|
14 |
from optuna.visualization import plot_optimization_history, plot_param_importances
|
15 |
from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
16 |
|
17 |
+
import wandb
|
18 |
from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params
|
19 |
from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs
|
|
|
20 |
from rl_algo_impls.runner.running_utils import (
|
21 |
+
ALGOS,
|
22 |
base_parser,
|
|
|
|
|
23 |
get_device,
|
|
|
|
|
24 |
hparam_dict,
|
25 |
+
load_hyperparams,
|
26 |
+
make_policy,
|
27 |
+
set_seeds,
|
28 |
+
)
|
29 |
+
from rl_algo_impls.shared.callbacks import Callback
|
30 |
+
from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import (
|
31 |
+
MicrortsRewardDecayCallback,
|
32 |
)
|
33 |
from rl_algo_impls.shared.callbacks.optimize_callback import (
|
34 |
Evaluation,
|
35 |
OptimizeCallback,
|
36 |
evaluation,
|
37 |
)
|
38 |
+
from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback
|
39 |
from rl_algo_impls.shared.stats import EpisodesStats
|
40 |
+
from rl_algo_impls.shared.vec_env import make_env, make_eval_env
|
41 |
+
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
|
42 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper
|
43 |
|
44 |
|
45 |
@dataclass
|
|
|
202 |
config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
|
203 |
)
|
204 |
device = get_device(config, env)
|
205 |
+
policy_factory = lambda: make_policy(
|
206 |
+
args.algo, env, device, **config.policy_hyperparams
|
207 |
+
)
|
208 |
+
policy = policy_factory()
|
209 |
algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
|
210 |
|
211 |
eval_env = make_eval_env(
|
212 |
config,
|
213 |
EnvHyperparams(**config.env_hyperparams),
|
214 |
+
override_hparams={"n_envs": study_args.n_eval_envs},
|
215 |
)
|
216 |
+
optimize_callback = OptimizeCallback(
|
217 |
policy,
|
218 |
eval_env,
|
219 |
trial,
|
220 |
tb_writer,
|
221 |
step_freq=config.n_timesteps // study_args.n_evaluations,
|
222 |
n_episodes=study_args.n_eval_episodes,
|
223 |
+
deterministic=config.eval_hyperparams.get("deterministic", True),
|
224 |
)
|
225 |
+
callbacks: List[Callback] = [optimize_callback]
|
226 |
+
if config.hyperparams.microrts_reward_decay_callback:
|
227 |
+
callbacks.append(MicrortsRewardDecayCallback(config, env))
|
228 |
+
selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
|
229 |
+
if selfPlayWrapper:
|
230 |
+
callbacks.append(SelfPlayCallback(policy, policy_factory, selfPlayWrapper))
|
231 |
try:
|
232 |
+
algo.learn(config.n_timesteps, callbacks=callbacks)
|
233 |
|
234 |
+
if not optimize_callback.is_pruned:
|
235 |
+
optimize_callback.evaluate()
|
236 |
+
if not optimize_callback.is_pruned:
|
237 |
policy.save(config.model_dir_path(best=False))
|
238 |
|
239 |
eval_stat: EpisodesStats = callback.last_eval_stat # type: ignore
|
|
|
246 |
"hparam/last_result": eval_stat.score.mean - eval_stat.score.std,
|
247 |
"hparam/train_mean": train_stat.score.mean,
|
248 |
"hparam/train_result": train_stat.score.mean - train_stat.score.std,
|
249 |
+
"hparam/score": optimize_callback.last_score,
|
250 |
+
"hparam/is_pruned": optimize_callback.is_pruned,
|
251 |
},
|
252 |
None,
|
253 |
config.run_name(),
|
|
|
255 |
tb_writer.close()
|
256 |
|
257 |
if wandb_enabled:
|
258 |
+
wandb.run.summary["state"] = ( # type: ignore
|
259 |
+
"Pruned" if optimize_callback.is_pruned else "Complete"
|
260 |
+
)
|
261 |
wandb.finish(quiet=True)
|
262 |
|
263 |
+
if optimize_callback.is_pruned:
|
264 |
raise optuna.exceptions.TrialPruned()
|
265 |
|
266 |
+
return optimize_callback.last_score
|
267 |
except AssertionError as e:
|
268 |
logging.warning(e)
|
269 |
return np.nan
|
|
|
317 |
tb_writer=tb_writer,
|
318 |
)
|
319 |
device = get_device(config, env)
|
320 |
+
policy_factory = lambda: make_policy(
|
321 |
+
arg.algo, env, device, **config.policy_hyperparams
|
322 |
+
)
|
323 |
+
policy = policy_factory()
|
324 |
if i > 0:
|
325 |
policy.load(config.model_dir_path())
|
326 |
algo = ALGOS[arg.algo](
|
|
|
331 |
config,
|
332 |
EnvHyperparams(**config.env_hyperparams),
|
333 |
normalize_load_path=config.model_dir_path() if i > 0 else None,
|
334 |
+
override_hparams={"n_envs": study_args.n_eval_envs},
|
335 |
)
|
336 |
|
337 |
start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations)
|
|
|
340 |
- start_timesteps
|
341 |
)
|
342 |
|
343 |
+
callbacks = []
|
344 |
+
if config.hyperparams.microrts_reward_decay_callback:
|
345 |
+
callbacks.append(
|
346 |
+
MicrortsRewardDecayCallback(
|
347 |
+
config, env, start_timesteps=start_timesteps
|
348 |
+
)
|
349 |
+
)
|
350 |
+
selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
|
351 |
+
if selfPlayWrapper:
|
352 |
+
callbacks.append(
|
353 |
+
SelfPlayCallback(policy, policy_factory, selfPlayWrapper)
|
354 |
+
)
|
355 |
try:
|
356 |
algo.learn(
|
357 |
train_timesteps,
|
358 |
+
callbacks=callbacks,
|
359 |
total_timesteps=config.n_timesteps,
|
360 |
start_timesteps=start_timesteps,
|
361 |
)
|
|
|
366 |
eval_env,
|
367 |
tb_writer,
|
368 |
study_args.n_eval_episodes,
|
369 |
+
config.eval_hyperparams.get("deterministic", True),
|
370 |
start_timesteps + train_timesteps,
|
371 |
)
|
372 |
)
|
|
|
412 |
|
413 |
|
414 |
def wandb_finish(state: str) -> None:
|
415 |
+
wandb.run.summary["state"] = state # type: ignore
|
416 |
wandb.finish(quiet=True)
|
417 |
|
418 |
|
rl_algo_impls/ppo/ppo.py
CHANGED
@@ -10,12 +10,16 @@ from torch.optim import Adam
|
|
10 |
from torch.utils.tensorboard.writer import SummaryWriter
|
11 |
|
12 |
from rl_algo_impls.shared.algorithm import Algorithm
|
13 |
-
from rl_algo_impls.shared.callbacks
|
14 |
from rl_algo_impls.shared.gae import compute_advantages
|
15 |
-
from rl_algo_impls.shared.policy.
|
16 |
-
from rl_algo_impls.shared.schedule import
|
|
|
|
|
|
|
|
|
|
|
17 |
from rl_algo_impls.shared.stats import log_scalars
|
18 |
-
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
|
19 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
20 |
VecEnv,
|
21 |
single_action_space,
|
@@ -102,12 +106,17 @@ class PPO(Algorithm):
|
|
102 |
sde_sample_freq: int = -1,
|
103 |
update_advantage_between_epochs: bool = True,
|
104 |
update_returns_between_epochs: bool = False,
|
|
|
105 |
) -> None:
|
106 |
super().__init__(policy, env, device, tb_writer)
|
107 |
self.policy = policy
|
108 |
-
self.
|
109 |
|
110 |
-
self.
|
|
|
|
|
|
|
|
|
111 |
self.gae_lambda = gae_lambda
|
112 |
self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
|
113 |
self.lr_schedule = schedule(learning_rate_decay, learning_rate)
|
@@ -138,7 +147,7 @@ class PPO(Algorithm):
|
|
138 |
def learn(
|
139 |
self: PPOSelf,
|
140 |
train_timesteps: int,
|
141 |
-
|
142 |
total_timesteps: Optional[int] = None,
|
143 |
start_timesteps: int = 0,
|
144 |
) -> PPOSelf:
|
@@ -153,15 +162,13 @@ class PPO(Algorithm):
|
|
153 |
act_shape = self.policy.action_shape
|
154 |
|
155 |
next_obs = self.env.reset()
|
156 |
-
next_action_masks = (
|
157 |
-
|
158 |
-
)
|
159 |
-
next_episode_starts = np.full(step_dim, True, dtype=np.bool8)
|
160 |
|
161 |
obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype) # type: ignore
|
162 |
actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype) # type: ignore
|
163 |
rewards = np.zeros(epoch_dim, dtype=np.float32)
|
164 |
-
episode_starts = np.zeros(epoch_dim, dtype=np.
|
165 |
values = np.zeros(epoch_dim, dtype=np.float32)
|
166 |
logprobs = np.zeros(epoch_dim, dtype=np.float32)
|
167 |
action_masks = (
|
@@ -181,10 +188,12 @@ class PPO(Algorithm):
|
|
181 |
learning_rate = self.lr_schedule(progress)
|
182 |
update_learning_rate(self.optimizer, learning_rate)
|
183 |
pi_clip = self.clip_range_schedule(progress)
|
|
|
184 |
chart_scalars = {
|
185 |
"learning_rate": self.optimizer.param_groups[0]["lr"],
|
186 |
"ent_coef": ent_coef,
|
187 |
"pi_clip": pi_clip,
|
|
|
188 |
}
|
189 |
if self.clip_range_vf_schedule:
|
190 |
v_clip = self.clip_range_vf_schedule(progress)
|
@@ -215,7 +224,7 @@ class PPO(Algorithm):
|
|
215 |
clamped_action
|
216 |
)
|
217 |
next_action_masks = (
|
218 |
-
self.
|
219 |
)
|
220 |
|
221 |
self.policy.train()
|
@@ -251,7 +260,7 @@ class PPO(Algorithm):
|
|
251 |
next_episode_starts,
|
252 |
next_obs,
|
253 |
self.policy,
|
254 |
-
|
255 |
self.gae_lambda,
|
256 |
)
|
257 |
b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
|
@@ -364,8 +373,10 @@ class PPO(Algorithm):
|
|
364 |
timesteps_elapsed,
|
365 |
)
|
366 |
|
367 |
-
if
|
368 |
-
if not
|
|
|
|
|
369 |
logging.info(
|
370 |
f"Callback terminated training at {timesteps_elapsed} timesteps"
|
371 |
)
|
|
|
10 |
from torch.utils.tensorboard.writer import SummaryWriter
|
11 |
|
12 |
from rl_algo_impls.shared.algorithm import Algorithm
|
13 |
+
from rl_algo_impls.shared.callbacks import Callback
|
14 |
from rl_algo_impls.shared.gae import compute_advantages
|
15 |
+
from rl_algo_impls.shared.policy.actor_critic import ActorCritic
|
16 |
+
from rl_algo_impls.shared.schedule import (
|
17 |
+
constant_schedule,
|
18 |
+
linear_schedule,
|
19 |
+
schedule,
|
20 |
+
update_learning_rate,
|
21 |
+
)
|
22 |
from rl_algo_impls.shared.stats import log_scalars
|
|
|
23 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
24 |
VecEnv,
|
25 |
single_action_space,
|
|
|
106 |
sde_sample_freq: int = -1,
|
107 |
update_advantage_between_epochs: bool = True,
|
108 |
update_returns_between_epochs: bool = False,
|
109 |
+
gamma_end: Optional[float] = None,
|
110 |
) -> None:
|
111 |
super().__init__(policy, env, device, tb_writer)
|
112 |
self.policy = policy
|
113 |
+
self.get_action_mask = getattr(env, "get_action_mask", None)
|
114 |
|
115 |
+
self.gamma_schedule = (
|
116 |
+
linear_schedule(gamma, gamma_end)
|
117 |
+
if gamma_end is not None
|
118 |
+
else constant_schedule(gamma)
|
119 |
+
)
|
120 |
self.gae_lambda = gae_lambda
|
121 |
self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7)
|
122 |
self.lr_schedule = schedule(learning_rate_decay, learning_rate)
|
|
|
147 |
def learn(
|
148 |
self: PPOSelf,
|
149 |
train_timesteps: int,
|
150 |
+
callbacks: Optional[List[Callback]] = None,
|
151 |
total_timesteps: Optional[int] = None,
|
152 |
start_timesteps: int = 0,
|
153 |
) -> PPOSelf:
|
|
|
162 |
act_shape = self.policy.action_shape
|
163 |
|
164 |
next_obs = self.env.reset()
|
165 |
+
next_action_masks = self.get_action_mask() if self.get_action_mask else None
|
166 |
+
next_episode_starts = np.full(step_dim, True, dtype=np.bool_)
|
|
|
|
|
167 |
|
168 |
obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype) # type: ignore
|
169 |
actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype) # type: ignore
|
170 |
rewards = np.zeros(epoch_dim, dtype=np.float32)
|
171 |
+
episode_starts = np.zeros(epoch_dim, dtype=np.bool_)
|
172 |
values = np.zeros(epoch_dim, dtype=np.float32)
|
173 |
logprobs = np.zeros(epoch_dim, dtype=np.float32)
|
174 |
action_masks = (
|
|
|
188 |
learning_rate = self.lr_schedule(progress)
|
189 |
update_learning_rate(self.optimizer, learning_rate)
|
190 |
pi_clip = self.clip_range_schedule(progress)
|
191 |
+
gamma = self.gamma_schedule(progress)
|
192 |
chart_scalars = {
|
193 |
"learning_rate": self.optimizer.param_groups[0]["lr"],
|
194 |
"ent_coef": ent_coef,
|
195 |
"pi_clip": pi_clip,
|
196 |
+
"gamma": gamma,
|
197 |
}
|
198 |
if self.clip_range_vf_schedule:
|
199 |
v_clip = self.clip_range_vf_schedule(progress)
|
|
|
224 |
clamped_action
|
225 |
)
|
226 |
next_action_masks = (
|
227 |
+
self.get_action_mask() if self.get_action_mask else None
|
228 |
)
|
229 |
|
230 |
self.policy.train()
|
|
|
260 |
next_episode_starts,
|
261 |
next_obs,
|
262 |
self.policy,
|
263 |
+
gamma,
|
264 |
self.gae_lambda,
|
265 |
)
|
266 |
b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device)
|
|
|
373 |
timesteps_elapsed,
|
374 |
)
|
375 |
|
376 |
+
if callbacks:
|
377 |
+
if not all(
|
378 |
+
c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks
|
379 |
+
):
|
380 |
logging.info(
|
381 |
f"Callback terminated training at {timesteps_elapsed} timesteps"
|
382 |
)
|
rl_algo_impls/runner/config.py
CHANGED
@@ -51,6 +51,8 @@ class EnvHyperparams:
|
|
51 |
normalize_type: Optional[str] = None
|
52 |
mask_actions: bool = False
|
53 |
bots: Optional[Dict[str, int]] = None
|
|
|
|
|
54 |
|
55 |
|
56 |
HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
|
@@ -63,9 +65,10 @@ class Hyperparams:
|
|
63 |
env_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
64 |
policy_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
65 |
algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
66 |
-
|
67 |
env_id: Optional[str] = None
|
68 |
additional_keys_to_log: List[str] = dataclasses.field(default_factory=list)
|
|
|
69 |
|
70 |
@classmethod
|
71 |
def from_dict_with_extra_fields(
|
@@ -110,8 +113,14 @@ class Config:
|
|
110 |
return self.hyperparams.algo_hyperparams
|
111 |
|
112 |
@property
|
113 |
-
def
|
114 |
-
return self.hyperparams.
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
@property
|
117 |
def algo(self) -> str:
|
|
|
51 |
normalize_type: Optional[str] = None
|
52 |
mask_actions: bool = False
|
53 |
bots: Optional[Dict[str, int]] = None
|
54 |
+
self_play_kwargs: Optional[Dict[str, Any]] = None
|
55 |
+
selfplay_bots: Optional[Dict[str, int]] = None
|
56 |
|
57 |
|
58 |
HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams")
|
|
|
65 |
env_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
66 |
policy_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
67 |
algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
68 |
+
eval_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
69 |
env_id: Optional[str] = None
|
70 |
additional_keys_to_log: List[str] = dataclasses.field(default_factory=list)
|
71 |
+
microrts_reward_decay_callback: bool = False
|
72 |
|
73 |
@classmethod
|
74 |
def from_dict_with_extra_fields(
|
|
|
113 |
return self.hyperparams.algo_hyperparams
|
114 |
|
115 |
@property
|
116 |
+
def eval_hyperparams(self) -> Dict[str, Any]:
|
117 |
+
return self.hyperparams.eval_hyperparams
|
118 |
+
|
119 |
+
def eval_callback_params(self) -> Dict[str, Any]:
|
120 |
+
eval_hyperparams = self.eval_hyperparams.copy()
|
121 |
+
if "env_overrides" in eval_hyperparams:
|
122 |
+
del eval_hyperparams["env_overrides"]
|
123 |
+
return eval_hyperparams
|
124 |
|
125 |
@property
|
126 |
def algo(self) -> str:
|
rl_algo_impls/runner/evaluate.py
CHANGED
@@ -1,20 +1,19 @@
|
|
1 |
import os
|
2 |
import shutil
|
3 |
-
|
4 |
from dataclasses import dataclass
|
5 |
from typing import NamedTuple, Optional
|
6 |
|
7 |
-
from rl_algo_impls.shared.vec_env import make_eval_env
|
8 |
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
|
9 |
from rl_algo_impls.runner.running_utils import (
|
10 |
-
load_hyperparams,
|
11 |
-
set_seeds,
|
12 |
get_device,
|
|
|
13 |
make_policy,
|
|
|
14 |
)
|
15 |
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
|
16 |
from rl_algo_impls.shared.policy.policy import Policy
|
17 |
from rl_algo_impls.shared.stats import EpisodesStats
|
|
|
18 |
|
19 |
|
20 |
@dataclass
|
@@ -71,7 +70,7 @@ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
|
|
71 |
env = make_eval_env(
|
72 |
config,
|
73 |
EnvHyperparams(**config.env_hyperparams),
|
74 |
-
|
75 |
render=args.render,
|
76 |
normalize_load_path=model_path,
|
77 |
)
|
@@ -87,7 +86,7 @@ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
|
|
87 |
deterministic = (
|
88 |
args.deterministic_eval
|
89 |
if args.deterministic_eval is not None
|
90 |
-
else config.
|
91 |
)
|
92 |
return Evaluation(
|
93 |
policy,
|
|
|
1 |
import os
|
2 |
import shutil
|
|
|
3 |
from dataclasses import dataclass
|
4 |
from typing import NamedTuple, Optional
|
5 |
|
|
|
6 |
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
|
7 |
from rl_algo_impls.runner.running_utils import (
|
|
|
|
|
8 |
get_device,
|
9 |
+
load_hyperparams,
|
10 |
make_policy,
|
11 |
+
set_seeds,
|
12 |
)
|
13 |
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
|
14 |
from rl_algo_impls.shared.policy.policy import Policy
|
15 |
from rl_algo_impls.shared.stats import EpisodesStats
|
16 |
+
from rl_algo_impls.shared.vec_env import make_eval_env
|
17 |
|
18 |
|
19 |
@dataclass
|
|
|
70 |
env = make_eval_env(
|
71 |
config,
|
72 |
EnvHyperparams(**config.env_hyperparams),
|
73 |
+
override_hparams={"n_envs": args.n_envs} if args.n_envs else None,
|
74 |
render=args.render,
|
75 |
normalize_load_path=model_path,
|
76 |
)
|
|
|
86 |
deterministic = (
|
87 |
args.deterministic_eval
|
88 |
if args.deterministic_eval is not None
|
89 |
+
else config.eval_hyperparams.get("deterministic", True)
|
90 |
)
|
91 |
return Evaluation(
|
92 |
policy,
|
rl_algo_impls/runner/running_utils.py
CHANGED
@@ -22,7 +22,7 @@ from rl_algo_impls.ppo.ppo import PPO
|
|
22 |
from rl_algo_impls.runner.config import Config, Hyperparams
|
23 |
from rl_algo_impls.shared.algorithm import Algorithm
|
24 |
from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
|
25 |
-
from rl_algo_impls.shared.policy.
|
26 |
from rl_algo_impls.shared.policy.policy import Policy
|
27 |
from rl_algo_impls.shared.vec_env.utils import import_for_env_id, is_microrts
|
28 |
from rl_algo_impls.vpg.policy import VPGActorCritic
|
@@ -97,29 +97,21 @@ def get_device(config: Config, env: VecEnv) -> torch.device:
|
|
97 |
# cuda by default
|
98 |
if device == "auto":
|
99 |
device = "cuda"
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
device = "cpu"
|
106 |
-
# Simple environments like Discreet and 1-D Boxes might also be better
|
107 |
-
# served with the CPU.
|
108 |
-
if device == "mps":
|
109 |
-
obs_space = single_observation_space(env)
|
110 |
-
if isinstance(obs_space, Discrete):
|
111 |
device = "cpu"
|
112 |
-
|
113 |
-
|
114 |
-
if
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
except ModuleNotFoundError:
|
122 |
-
# Likely on gym_microrts v0.0.2 to match ppo-implementation-details
|
123 |
device = "cpu"
|
124 |
print(f"Device: {device}")
|
125 |
return torch.device(device)
|
|
|
22 |
from rl_algo_impls.runner.config import Config, Hyperparams
|
23 |
from rl_algo_impls.shared.algorithm import Algorithm
|
24 |
from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
|
25 |
+
from rl_algo_impls.shared.policy.actor_critic import ActorCritic
|
26 |
from rl_algo_impls.shared.policy.policy import Policy
|
27 |
from rl_algo_impls.shared.vec_env.utils import import_for_env_id, is_microrts
|
28 |
from rl_algo_impls.vpg.policy import VPGActorCritic
|
|
|
97 |
# cuda by default
|
98 |
if device == "auto":
|
99 |
device = "cuda"
|
100 |
+
# Apple MPS is a second choice (sometimes)
|
101 |
+
if device == "cuda" and not torch.cuda.is_available():
|
102 |
+
device = "mps"
|
103 |
+
# If no MPS, fallback to cpu
|
104 |
+
if device == "mps" and not torch.backends.mps.is_available():
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
device = "cpu"
|
106 |
+
# Simple environments like Discreet and 1-D Boxes might also be better
|
107 |
+
# served with the CPU.
|
108 |
+
if device == "mps":
|
109 |
+
obs_space = single_observation_space(env)
|
110 |
+
if isinstance(obs_space, Discrete):
|
111 |
+
device = "cpu"
|
112 |
+
elif isinstance(obs_space, Box) and len(obs_space.shape) == 1:
|
113 |
+
device = "cpu"
|
114 |
+
if is_microrts(config):
|
|
|
|
|
115 |
device = "cpu"
|
116 |
print(f"Device: {device}")
|
117 |
return torch.device(device)
|
rl_algo_impls/runner/selfplay_evaluate.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import dataclasses
|
3 |
+
import os
|
4 |
+
import shutil
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import List, NamedTuple, Optional
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
import wandb
|
11 |
+
from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs
|
12 |
+
from rl_algo_impls.runner.evaluate import Evaluation
|
13 |
+
from rl_algo_impls.runner.running_utils import (
|
14 |
+
get_device,
|
15 |
+
load_hyperparams,
|
16 |
+
make_policy,
|
17 |
+
set_seeds,
|
18 |
+
)
|
19 |
+
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
|
20 |
+
from rl_algo_impls.shared.vec_env import make_eval_env
|
21 |
+
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
|
22 |
+
|
23 |
+
|
24 |
+
@dataclass
|
25 |
+
class SelfplayEvalArgs(RunArgs):
|
26 |
+
# Either wandb_run_paths or model_file_paths must have 2 elements in it.
|
27 |
+
wandb_run_paths: List[str] = dataclasses.field(default_factory=list)
|
28 |
+
model_file_paths: List[str] = dataclasses.field(default_factory=list)
|
29 |
+
render: bool = False
|
30 |
+
best: bool = True
|
31 |
+
n_envs: int = 1
|
32 |
+
n_episodes: int = 1
|
33 |
+
deterministic_eval: Optional[bool] = None
|
34 |
+
no_print_returns: bool = False
|
35 |
+
video_path: Optional[str] = None
|
36 |
+
|
37 |
+
|
38 |
+
def selfplay_evaluate(args: SelfplayEvalArgs, root_dir: str) -> Evaluation:
|
39 |
+
if args.wandb_run_paths:
|
40 |
+
api = wandb.Api()
|
41 |
+
args, config, player_1_model_path = load_player(
|
42 |
+
api, args.wandb_run_paths[0], args, root_dir
|
43 |
+
)
|
44 |
+
_, _, player_2_model_path = load_player(
|
45 |
+
api, args.wandb_run_paths[1], args, root_dir
|
46 |
+
)
|
47 |
+
elif args.model_file_paths:
|
48 |
+
hyperparams = load_hyperparams(args.algo, args.env)
|
49 |
+
|
50 |
+
config = Config(args, hyperparams, root_dir)
|
51 |
+
player_1_model_path, player_2_model_path = args.model_file_paths
|
52 |
+
else:
|
53 |
+
raise ValueError("Must specify 2 wandb_run_paths or 2 model_file_paths")
|
54 |
+
|
55 |
+
print(args)
|
56 |
+
|
57 |
+
set_seeds(args.seed, args.use_deterministic_algorithms)
|
58 |
+
|
59 |
+
env_make_kwargs = (
|
60 |
+
config.eval_hyperparams.get("env_overrides", {}).get("make_kwargs", {}).copy()
|
61 |
+
)
|
62 |
+
env_make_kwargs["num_selfplay_envs"] = args.n_envs * 2
|
63 |
+
env = make_eval_env(
|
64 |
+
config,
|
65 |
+
EnvHyperparams(**config.env_hyperparams),
|
66 |
+
override_hparams={
|
67 |
+
"n_envs": args.n_envs,
|
68 |
+
"selfplay_bots": {
|
69 |
+
player_2_model_path: args.n_envs,
|
70 |
+
},
|
71 |
+
"self_play_kwargs": {
|
72 |
+
"num_old_policies": 0,
|
73 |
+
"save_steps": np.inf,
|
74 |
+
"swap_steps": np.inf,
|
75 |
+
"bot_always_player_2": True,
|
76 |
+
},
|
77 |
+
"bots": None,
|
78 |
+
"make_kwargs": env_make_kwargs,
|
79 |
+
},
|
80 |
+
render=args.render,
|
81 |
+
normalize_load_path=player_1_model_path,
|
82 |
+
)
|
83 |
+
if args.video_path:
|
84 |
+
env = VecEpisodeRecorder(
|
85 |
+
env, args.video_path, max_video_length=18000, num_episodes=args.n_episodes
|
86 |
+
)
|
87 |
+
device = get_device(config, env)
|
88 |
+
policy = make_policy(
|
89 |
+
args.algo,
|
90 |
+
env,
|
91 |
+
device,
|
92 |
+
load_path=player_1_model_path,
|
93 |
+
**config.policy_hyperparams,
|
94 |
+
).eval()
|
95 |
+
|
96 |
+
deterministic = (
|
97 |
+
args.deterministic_eval
|
98 |
+
if args.deterministic_eval is not None
|
99 |
+
else config.eval_hyperparams.get("deterministic", True)
|
100 |
+
)
|
101 |
+
return Evaluation(
|
102 |
+
policy,
|
103 |
+
evaluate(
|
104 |
+
env,
|
105 |
+
policy,
|
106 |
+
args.n_episodes,
|
107 |
+
render=args.render,
|
108 |
+
deterministic=deterministic,
|
109 |
+
print_returns=not args.no_print_returns,
|
110 |
+
),
|
111 |
+
config,
|
112 |
+
)
|
113 |
+
|
114 |
+
|
115 |
+
class PlayerData(NamedTuple):
|
116 |
+
args: SelfplayEvalArgs
|
117 |
+
config: Config
|
118 |
+
model_path: str
|
119 |
+
|
120 |
+
|
121 |
+
def load_player(
|
122 |
+
api: wandb.Api, run_path: str, args: SelfplayEvalArgs, root_dir: str
|
123 |
+
) -> PlayerData:
|
124 |
+
args = copy.copy(args)
|
125 |
+
|
126 |
+
run = api.run(run_path)
|
127 |
+
params = run.config
|
128 |
+
args.algo = params["algo"]
|
129 |
+
args.env = params["env"]
|
130 |
+
args.seed = params.get("seed", None)
|
131 |
+
args.use_deterministic_algorithms = params.get("use_deterministic_algorithms", True)
|
132 |
+
config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir)
|
133 |
+
model_path = config.model_dir_path(best=args.best, downloaded=True)
|
134 |
+
|
135 |
+
model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
|
136 |
+
run.file(model_archive_name).download()
|
137 |
+
if os.path.isdir(model_path):
|
138 |
+
shutil.rmtree(model_path)
|
139 |
+
shutil.unpack_archive(model_archive_name, model_path)
|
140 |
+
os.remove(model_archive_name)
|
141 |
+
|
142 |
+
return PlayerData(args, config, model_path)
|
rl_algo_impls/runner/train.py
CHANGED
@@ -1,12 +1,17 @@
|
|
1 |
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
|
2 |
import os
|
3 |
|
|
|
|
|
|
|
|
|
|
|
4 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
5 |
|
6 |
import dataclasses
|
7 |
import shutil
|
8 |
from dataclasses import asdict, dataclass
|
9 |
-
from typing import Any, Dict, Optional, Sequence
|
10 |
|
11 |
import yaml
|
12 |
from torch.utils.tensorboard.writer import SummaryWriter
|
@@ -23,6 +28,9 @@ from rl_algo_impls.runner.running_utils import (
|
|
23 |
set_seeds,
|
24 |
)
|
25 |
from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
|
|
|
|
|
|
|
26 |
from rl_algo_impls.shared.stats import EpisodesStats
|
27 |
from rl_algo_impls.shared.vec_env import make_env, make_eval_env
|
28 |
|
@@ -41,7 +49,7 @@ def train(args: TrainArgs):
|
|
41 |
print(hyperparams)
|
42 |
config = Config(args, hyperparams, os.getcwd())
|
43 |
|
44 |
-
wandb_enabled = args.wandb_project_name
|
45 |
if wandb_enabled:
|
46 |
wandb.tensorboard.patch(
|
47 |
root_logdir=config.tensorboard_summary_path, pytorch=True
|
@@ -66,14 +74,17 @@ def train(args: TrainArgs):
|
|
66 |
config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
|
67 |
)
|
68 |
device = get_device(config, env)
|
69 |
-
|
|
|
|
|
|
|
70 |
algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
|
71 |
|
72 |
num_parameters = policy.num_parameters()
|
73 |
num_trainable_parameters = policy.num_trainable_parameters()
|
74 |
if wandb_enabled:
|
75 |
-
wandb.run.summary["num_parameters"] = num_parameters
|
76 |
-
wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters
|
77 |
else:
|
78 |
print(
|
79 |
f"num_parameters = {num_parameters} ; "
|
@@ -81,40 +92,49 @@ def train(args: TrainArgs):
|
|
81 |
)
|
82 |
|
83 |
eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams))
|
84 |
-
record_best_videos = config.
|
85 |
-
|
86 |
policy,
|
87 |
eval_env,
|
88 |
tb_writer,
|
89 |
best_model_path=config.model_dir_path(best=True),
|
90 |
-
**config.
|
91 |
video_env=make_eval_env(
|
92 |
-
config,
|
|
|
|
|
93 |
)
|
94 |
if record_best_videos
|
95 |
else None,
|
96 |
best_video_dir=config.best_videos_dir,
|
97 |
additional_keys_to_log=config.additional_keys_to_log,
|
|
|
98 |
)
|
99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
policy.save(config.model_dir_path(best=False))
|
102 |
|
103 |
-
eval_stats =
|
104 |
|
105 |
-
plot_eval_callback(
|
106 |
|
107 |
log_dict: Dict[str, Any] = {
|
108 |
"eval": eval_stats._asdict(),
|
109 |
}
|
110 |
-
if
|
111 |
-
log_dict["best_eval"] =
|
112 |
log_dict.update(asdict(hyperparams))
|
113 |
log_dict.update(vars(args))
|
114 |
with open(config.logs_path, "a") as f:
|
115 |
yaml.dump({config.run_name(): log_dict}, f)
|
116 |
|
117 |
-
best_eval_stats: EpisodesStats =
|
118 |
tb_writer.add_hparams(
|
119 |
hparam_dict(hyperparams, vars(args)),
|
120 |
{
|
@@ -132,13 +152,8 @@ def train(args: TrainArgs):
|
|
132 |
|
133 |
if wandb_enabled:
|
134 |
shutil.make_archive(
|
135 |
-
os.path.join(wandb.run.dir, config.model_dir_name()),
|
136 |
"zip",
|
137 |
config.model_dir_path(),
|
138 |
)
|
139 |
-
shutil.make_archive(
|
140 |
-
os.path.join(wandb.run.dir, config.model_dir_name(best=True)),
|
141 |
-
"zip",
|
142 |
-
config.model_dir_path(best=True),
|
143 |
-
)
|
144 |
wandb.finish()
|
|
|
1 |
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
|
2 |
import os
|
3 |
|
4 |
+
from rl_algo_impls.shared.callbacks import Callback
|
5 |
+
from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback
|
6 |
+
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
|
7 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper
|
8 |
+
|
9 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
10 |
|
11 |
import dataclasses
|
12 |
import shutil
|
13 |
from dataclasses import asdict, dataclass
|
14 |
+
from typing import Any, Dict, List, Optional, Sequence
|
15 |
|
16 |
import yaml
|
17 |
from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
28 |
set_seeds,
|
29 |
)
|
30 |
from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback
|
31 |
+
from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import (
|
32 |
+
MicrortsRewardDecayCallback,
|
33 |
+
)
|
34 |
from rl_algo_impls.shared.stats import EpisodesStats
|
35 |
from rl_algo_impls.shared.vec_env import make_env, make_eval_env
|
36 |
|
|
|
49 |
print(hyperparams)
|
50 |
config = Config(args, hyperparams, os.getcwd())
|
51 |
|
52 |
+
wandb_enabled = bool(args.wandb_project_name)
|
53 |
if wandb_enabled:
|
54 |
wandb.tensorboard.patch(
|
55 |
root_logdir=config.tensorboard_summary_path, pytorch=True
|
|
|
74 |
config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer
|
75 |
)
|
76 |
device = get_device(config, env)
|
77 |
+
policy_factory = lambda: make_policy(
|
78 |
+
args.algo, env, device, **config.policy_hyperparams
|
79 |
+
)
|
80 |
+
policy = policy_factory()
|
81 |
algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams)
|
82 |
|
83 |
num_parameters = policy.num_parameters()
|
84 |
num_trainable_parameters = policy.num_trainable_parameters()
|
85 |
if wandb_enabled:
|
86 |
+
wandb.run.summary["num_parameters"] = num_parameters # type: ignore
|
87 |
+
wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters # type: ignore
|
88 |
else:
|
89 |
print(
|
90 |
f"num_parameters = {num_parameters} ; "
|
|
|
92 |
)
|
93 |
|
94 |
eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams))
|
95 |
+
record_best_videos = config.eval_hyperparams.get("record_best_videos", True)
|
96 |
+
eval_callback = EvalCallback(
|
97 |
policy,
|
98 |
eval_env,
|
99 |
tb_writer,
|
100 |
best_model_path=config.model_dir_path(best=True),
|
101 |
+
**config.eval_callback_params(),
|
102 |
video_env=make_eval_env(
|
103 |
+
config,
|
104 |
+
EnvHyperparams(**config.env_hyperparams),
|
105 |
+
override_hparams={"n_envs": 1},
|
106 |
)
|
107 |
if record_best_videos
|
108 |
else None,
|
109 |
best_video_dir=config.best_videos_dir,
|
110 |
additional_keys_to_log=config.additional_keys_to_log,
|
111 |
+
wandb_enabled=wandb_enabled,
|
112 |
)
|
113 |
+
callbacks: List[Callback] = [eval_callback]
|
114 |
+
if config.hyperparams.microrts_reward_decay_callback:
|
115 |
+
callbacks.append(MicrortsRewardDecayCallback(config, env))
|
116 |
+
selfPlayWrapper = find_wrapper(env, SelfPlayWrapper)
|
117 |
+
if selfPlayWrapper:
|
118 |
+
callbacks.append(SelfPlayCallback(policy, policy_factory, selfPlayWrapper))
|
119 |
+
algo.learn(config.n_timesteps, callbacks=callbacks)
|
120 |
|
121 |
policy.save(config.model_dir_path(best=False))
|
122 |
|
123 |
+
eval_stats = eval_callback.evaluate(n_episodes=10, print_returns=True)
|
124 |
|
125 |
+
plot_eval_callback(eval_callback, tb_writer, config.run_name())
|
126 |
|
127 |
log_dict: Dict[str, Any] = {
|
128 |
"eval": eval_stats._asdict(),
|
129 |
}
|
130 |
+
if eval_callback.best:
|
131 |
+
log_dict["best_eval"] = eval_callback.best._asdict()
|
132 |
log_dict.update(asdict(hyperparams))
|
133 |
log_dict.update(vars(args))
|
134 |
with open(config.logs_path, "a") as f:
|
135 |
yaml.dump({config.run_name(): log_dict}, f)
|
136 |
|
137 |
+
best_eval_stats: EpisodesStats = eval_callback.best # type: ignore
|
138 |
tb_writer.add_hparams(
|
139 |
hparam_dict(hyperparams, vars(args)),
|
140 |
{
|
|
|
152 |
|
153 |
if wandb_enabled:
|
154 |
shutil.make_archive(
|
155 |
+
os.path.join(wandb.run.dir, config.model_dir_name()), # type: ignore
|
156 |
"zip",
|
157 |
config.model_dir_path(),
|
158 |
)
|
|
|
|
|
|
|
|
|
|
|
159 |
wandb.finish()
|
rl_algo_impls/selfplay_enjoy.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
|
2 |
+
import os
|
3 |
+
|
4 |
+
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
5 |
+
|
6 |
+
from rl_algo_impls.runner.running_utils import base_parser
|
7 |
+
from rl_algo_impls.runner.selfplay_evaluate import SelfplayEvalArgs, selfplay_evaluate
|
8 |
+
|
9 |
+
|
10 |
+
def selfplay_enjoy() -> None:
|
11 |
+
parser = base_parser(multiple=False)
|
12 |
+
parser.add_argument(
|
13 |
+
"--wandb-run-paths",
|
14 |
+
type=str,
|
15 |
+
nargs="*",
|
16 |
+
help="WandB run paths to load players from. Must be 0 or 2",
|
17 |
+
)
|
18 |
+
parser.add_argument(
|
19 |
+
"--model-file-paths",
|
20 |
+
type=str,
|
21 |
+
help="File paths to load players from. Must be 0 or 2",
|
22 |
+
)
|
23 |
+
parser.add_argument("--render", action="store_true")
|
24 |
+
parser.add_argument("--n-envs", default=1, type=int)
|
25 |
+
parser.add_argument("--n-episodes", default=1, type=int)
|
26 |
+
parser.add_argument("--deterministic-eval", default=None, type=bool)
|
27 |
+
parser.add_argument(
|
28 |
+
"--no-print-returns", action="store_true", help="Limit printing"
|
29 |
+
)
|
30 |
+
parser.add_argument(
|
31 |
+
"--video-path", type=str, help="Path to save video of all plays"
|
32 |
+
)
|
33 |
+
# parser.set_defaults(
|
34 |
+
# algo=["ppo"],
|
35 |
+
# env=["Microrts-selfplay-unet-decay"],
|
36 |
+
# n_episodes=10,
|
37 |
+
# model_file_paths=[
|
38 |
+
# "downloaded_models/ppo-Microrts-selfplay-unet-decay-S3-best",
|
39 |
+
# "downloaded_models/ppo-Microrts-selfplay-unet-decay-S2-best",
|
40 |
+
# ],
|
41 |
+
# video_path="/Users/sgoodfriend/Desktop/decay3-vs-decay2",
|
42 |
+
# )
|
43 |
+
args = parser.parse_args()
|
44 |
+
args.algo = args.algo[0]
|
45 |
+
args.env = args.env[0]
|
46 |
+
args.seed = args.seed[0]
|
47 |
+
args = SelfplayEvalArgs(**vars(args))
|
48 |
+
|
49 |
+
selfplay_evaluate(args, os.getcwd())
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
selfplay_enjoy()
|
rl_algo_impls/shared/actor/__init__.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
-
from rl_algo_impls.shared.actor.actor import Actor, PiForward
|
2 |
from rl_algo_impls.shared.actor.make_actor import actor_head
|
|
|
1 |
+
from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward
|
2 |
from rl_algo_impls.shared.actor.make_actor import actor_head
|
rl_algo_impls/shared/actor/actor.py
CHANGED
@@ -31,12 +31,13 @@ class Actor(nn.Module, ABC):
|
|
31 |
def action_shape(self) -> Tuple[int, ...]:
|
32 |
...
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
31 |
def action_shape(self) -> Tuple[int, ...]:
|
32 |
...
|
33 |
|
34 |
+
|
35 |
+
def pi_forward(
|
36 |
+
distribution: Distribution, actions: Optional[torch.Tensor] = None
|
37 |
+
) -> PiForward:
|
38 |
+
logp_a = None
|
39 |
+
entropy = None
|
40 |
+
if actions is not None:
|
41 |
+
logp_a = distribution.log_prob(actions)
|
42 |
+
entropy = distribution.entropy()
|
43 |
+
return PiForward(distribution, logp_a, entropy)
|
rl_algo_impls/shared/actor/categorical.py
CHANGED
@@ -4,8 +4,8 @@ import torch
|
|
4 |
import torch.nn as nn
|
5 |
from torch.distributions import Categorical
|
6 |
|
7 |
-
from rl_algo_impls.shared.actor import Actor, PiForward
|
8 |
-
from rl_algo_impls.shared.module.
|
9 |
|
10 |
|
11 |
class MaskedCategorical(Categorical):
|
@@ -57,7 +57,7 @@ class CategoricalActorHead(Actor):
|
|
57 |
) -> PiForward:
|
58 |
logits = self._fc(obs)
|
59 |
pi = MaskedCategorical(logits=logits, mask=action_masks)
|
60 |
-
return
|
61 |
|
62 |
@property
|
63 |
def action_shape(self) -> Tuple[int, ...]:
|
|
|
4 |
import torch.nn as nn
|
5 |
from torch.distributions import Categorical
|
6 |
|
7 |
+
from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
|
8 |
+
from rl_algo_impls.shared.module.utils import mlp
|
9 |
|
10 |
|
11 |
class MaskedCategorical(Categorical):
|
|
|
57 |
) -> PiForward:
|
58 |
logits = self._fc(obs)
|
59 |
pi = MaskedCategorical(logits=logits, mask=action_masks)
|
60 |
+
return pi_forward(pi, actions)
|
61 |
|
62 |
@property
|
63 |
def action_shape(self) -> Tuple[int, ...]:
|
rl_algo_impls/shared/actor/gaussian.py
CHANGED
@@ -4,8 +4,8 @@ import torch
|
|
4 |
import torch.nn as nn
|
5 |
from torch.distributions import Distribution, Normal
|
6 |
|
7 |
-
from rl_algo_impls.shared.actor.actor import Actor, PiForward
|
8 |
-
from rl_algo_impls.shared.module.
|
9 |
|
10 |
|
11 |
class GaussianDistribution(Normal):
|
@@ -54,7 +54,7 @@ class GaussianActorHead(Actor):
|
|
54 |
not action_masks
|
55 |
), f"{self.__class__.__name__} does not support action_masks"
|
56 |
pi = self._distribution(obs)
|
57 |
-
return
|
58 |
|
59 |
@property
|
60 |
def action_shape(self) -> Tuple[int, ...]:
|
|
|
4 |
import torch.nn as nn
|
5 |
from torch.distributions import Distribution, Normal
|
6 |
|
7 |
+
from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward
|
8 |
+
from rl_algo_impls.shared.module.utils import mlp
|
9 |
|
10 |
|
11 |
class GaussianDistribution(Normal):
|
|
|
54 |
not action_masks
|
55 |
), f"{self.__class__.__name__} does not support action_masks"
|
56 |
pi = self._distribution(obs)
|
57 |
+
return pi_forward(pi, actions)
|
58 |
|
59 |
@property
|
60 |
def action_shape(self) -> Tuple[int, ...]:
|
rl_algo_impls/shared/actor/gridnet.py
CHANGED
@@ -6,10 +6,10 @@ import torch.nn as nn
|
|
6 |
from numpy.typing import NDArray
|
7 |
from torch.distributions import Distribution, constraints
|
8 |
|
9 |
-
from rl_algo_impls.shared.actor import Actor, PiForward
|
10 |
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
|
11 |
from rl_algo_impls.shared.encoder import EncoderOutDim
|
12 |
-
from rl_algo_impls.shared.module.
|
13 |
|
14 |
|
15 |
class GridnetDistribution(Distribution):
|
@@ -25,7 +25,7 @@ class GridnetDistribution(Distribution):
|
|
25 |
self.action_vec = action_vec
|
26 |
|
27 |
masks = masks.view(-1, masks.shape[-1])
|
28 |
-
split_masks = torch.split(masks
|
29 |
|
30 |
grid_logits = logits.reshape(-1, action_vec.sum())
|
31 |
split_logits = torch.split(grid_logits, action_vec.tolist(), dim=1)
|
@@ -101,7 +101,7 @@ class GridnetActorHead(Actor):
|
|
101 |
), f"No mask case unhandled in {self.__class__.__name__}"
|
102 |
logits = self._fc(obs)
|
103 |
pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
|
104 |
-
return
|
105 |
|
106 |
@property
|
107 |
def action_shape(self) -> Tuple[int, ...]:
|
|
|
6 |
from numpy.typing import NDArray
|
7 |
from torch.distributions import Distribution, constraints
|
8 |
|
9 |
+
from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
|
10 |
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
|
11 |
from rl_algo_impls.shared.encoder import EncoderOutDim
|
12 |
+
from rl_algo_impls.shared.module.utils import mlp
|
13 |
|
14 |
|
15 |
class GridnetDistribution(Distribution):
|
|
|
25 |
self.action_vec = action_vec
|
26 |
|
27 |
masks = masks.view(-1, masks.shape[-1])
|
28 |
+
split_masks = torch.split(masks, action_vec.tolist(), dim=1)
|
29 |
|
30 |
grid_logits = logits.reshape(-1, action_vec.sum())
|
31 |
split_logits = torch.split(grid_logits, action_vec.tolist(), dim=1)
|
|
|
101 |
), f"No mask case unhandled in {self.__class__.__name__}"
|
102 |
logits = self._fc(obs)
|
103 |
pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
|
104 |
+
return pi_forward(pi, actions)
|
105 |
|
106 |
@property
|
107 |
def action_shape(self) -> Tuple[int, ...]:
|
rl_algo_impls/shared/actor/gridnet_decoder.py
CHANGED
@@ -5,11 +5,10 @@ import torch
|
|
5 |
import torch.nn as nn
|
6 |
from numpy.typing import NDArray
|
7 |
|
8 |
-
from rl_algo_impls.shared.actor import Actor, PiForward
|
9 |
-
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
|
10 |
from rl_algo_impls.shared.actor.gridnet import GridnetDistribution
|
11 |
from rl_algo_impls.shared.encoder import EncoderOutDim
|
12 |
-
from rl_algo_impls.shared.module.
|
13 |
|
14 |
|
15 |
class Transpose(nn.Module):
|
@@ -73,7 +72,7 @@ class GridnetDecoder(Actor):
|
|
73 |
), f"No mask case unhandled in {self.__class__.__name__}"
|
74 |
logits = self.deconv(obs)
|
75 |
pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
|
76 |
-
return
|
77 |
|
78 |
@property
|
79 |
def action_shape(self) -> Tuple[int, ...]:
|
|
|
5 |
import torch.nn as nn
|
6 |
from numpy.typing import NDArray
|
7 |
|
8 |
+
from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward
|
|
|
9 |
from rl_algo_impls.shared.actor.gridnet import GridnetDistribution
|
10 |
from rl_algo_impls.shared.encoder import EncoderOutDim
|
11 |
+
from rl_algo_impls.shared.module.utils import layer_init
|
12 |
|
13 |
|
14 |
class Transpose(nn.Module):
|
|
|
72 |
), f"No mask case unhandled in {self.__class__.__name__}"
|
73 |
logits = self.deconv(obs)
|
74 |
pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks)
|
75 |
+
return pi_forward(pi, actions)
|
76 |
|
77 |
@property
|
78 |
def action_shape(self) -> Tuple[int, ...]:
|
rl_algo_impls/shared/actor/make_actor.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from typing import Tuple, Type
|
2 |
|
3 |
import gym
|
4 |
import torch.nn as nn
|
@@ -27,6 +27,7 @@ def actor_head(
|
|
27 |
full_std: bool = True,
|
28 |
squash_output: bool = False,
|
29 |
actor_head_style: str = "single",
|
|
|
30 |
) -> Actor:
|
31 |
assert not use_sde or isinstance(
|
32 |
action_space, Box
|
@@ -73,18 +74,20 @@ def actor_head(
|
|
73 |
init_layers_orthogonal=init_layers_orthogonal,
|
74 |
)
|
75 |
elif actor_head_style == "gridnet":
|
|
|
76 |
return GridnetActorHead(
|
77 |
-
action_space.nvec
|
78 |
-
|
79 |
in_dim=in_dim,
|
80 |
hidden_sizes=hidden_sizes,
|
81 |
activation=activation,
|
82 |
init_layers_orthogonal=init_layers_orthogonal,
|
83 |
)
|
84 |
elif actor_head_style == "gridnet_decoder":
|
|
|
85 |
return GridnetDecoder(
|
86 |
-
action_space.nvec
|
87 |
-
|
88 |
in_dim=in_dim,
|
89 |
activation=activation,
|
90 |
init_layers_orthogonal=init_layers_orthogonal,
|
|
|
1 |
+
from typing import Optional, Tuple, Type
|
2 |
|
3 |
import gym
|
4 |
import torch.nn as nn
|
|
|
27 |
full_std: bool = True,
|
28 |
squash_output: bool = False,
|
29 |
actor_head_style: str = "single",
|
30 |
+
action_plane_space: Optional[bool] = None,
|
31 |
) -> Actor:
|
32 |
assert not use_sde or isinstance(
|
33 |
action_space, Box
|
|
|
74 |
init_layers_orthogonal=init_layers_orthogonal,
|
75 |
)
|
76 |
elif actor_head_style == "gridnet":
|
77 |
+
assert isinstance(action_plane_space, MultiDiscrete)
|
78 |
return GridnetActorHead(
|
79 |
+
len(action_space.nvec) // len(action_plane_space.nvec), # type: ignore
|
80 |
+
action_plane_space.nvec, # type: ignore
|
81 |
in_dim=in_dim,
|
82 |
hidden_sizes=hidden_sizes,
|
83 |
activation=activation,
|
84 |
init_layers_orthogonal=init_layers_orthogonal,
|
85 |
)
|
86 |
elif actor_head_style == "gridnet_decoder":
|
87 |
+
assert isinstance(action_plane_space, MultiDiscrete)
|
88 |
return GridnetDecoder(
|
89 |
+
len(action_space.nvec) // len(action_plane_space.nvec), # type: ignore
|
90 |
+
action_plane_space.nvec, # type: ignore
|
91 |
in_dim=in_dim,
|
92 |
activation=activation,
|
93 |
init_layers_orthogonal=init_layers_orthogonal,
|
rl_algo_impls/shared/actor/multi_discrete.py
CHANGED
@@ -6,10 +6,10 @@ import torch.nn as nn
|
|
6 |
from numpy.typing import NDArray
|
7 |
from torch.distributions import Distribution, constraints
|
8 |
|
9 |
-
from rl_algo_impls.shared.actor.actor import Actor, PiForward
|
10 |
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
|
11 |
from rl_algo_impls.shared.encoder import EncoderOutDim
|
12 |
-
from rl_algo_impls.shared.module.
|
13 |
|
14 |
|
15 |
class MultiCategorical(Distribution):
|
@@ -94,7 +94,7 @@ class MultiDiscreteActorHead(Actor):
|
|
94 |
) -> PiForward:
|
95 |
logits = self._fc(obs)
|
96 |
pi = MultiCategorical(self.nvec, logits=logits, masks=action_masks)
|
97 |
-
return
|
98 |
|
99 |
@property
|
100 |
def action_shape(self) -> Tuple[int, ...]:
|
|
|
6 |
from numpy.typing import NDArray
|
7 |
from torch.distributions import Distribution, constraints
|
8 |
|
9 |
+
from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward
|
10 |
from rl_algo_impls.shared.actor.categorical import MaskedCategorical
|
11 |
from rl_algo_impls.shared.encoder import EncoderOutDim
|
12 |
+
from rl_algo_impls.shared.module.utils import mlp
|
13 |
|
14 |
|
15 |
class MultiCategorical(Distribution):
|
|
|
94 |
) -> PiForward:
|
95 |
logits = self._fc(obs)
|
96 |
pi = MultiCategorical(self.nvec, logits=logits, masks=action_masks)
|
97 |
+
return pi_forward(pi, actions)
|
98 |
|
99 |
@property
|
100 |
def action_shape(self) -> Tuple[int, ...]:
|
rl_algo_impls/shared/actor/state_dependent_noise.py
CHANGED
@@ -5,7 +5,7 @@ import torch.nn as nn
|
|
5 |
from torch.distributions import Distribution, Normal
|
6 |
|
7 |
from rl_algo_impls.shared.actor.actor import Actor, PiForward
|
8 |
-
from rl_algo_impls.shared.module.
|
9 |
|
10 |
|
11 |
class TanhBijector:
|
@@ -172,7 +172,7 @@ class StateDependentNoiseActorHead(Actor):
|
|
172 |
not action_masks
|
173 |
), f"{self.__class__.__name__} does not support action_masks"
|
174 |
pi = self._distribution(obs)
|
175 |
-
return
|
176 |
|
177 |
def sample_weights(self, batch_size: int = 1) -> None:
|
178 |
std = self._get_std()
|
@@ -185,16 +185,15 @@ class StateDependentNoiseActorHead(Actor):
|
|
185 |
def action_shape(self) -> Tuple[int, ...]:
|
186 |
return (self.act_dim,)
|
187 |
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
return PiForward(distribution, logp_a, entropy)
|
|
|
5 |
from torch.distributions import Distribution, Normal
|
6 |
|
7 |
from rl_algo_impls.shared.actor.actor import Actor, PiForward
|
8 |
+
from rl_algo_impls.shared.module.utils import mlp
|
9 |
|
10 |
|
11 |
class TanhBijector:
|
|
|
172 |
not action_masks
|
173 |
), f"{self.__class__.__name__} does not support action_masks"
|
174 |
pi = self._distribution(obs)
|
175 |
+
return pi_forward(pi, actions, self.bijector)
|
176 |
|
177 |
def sample_weights(self, batch_size: int = 1) -> None:
|
178 |
std = self._get_std()
|
|
|
185 |
def action_shape(self) -> Tuple[int, ...]:
|
186 |
return (self.act_dim,)
|
187 |
|
188 |
+
|
189 |
+
def pi_forward(
|
190 |
+
distribution: Distribution,
|
191 |
+
actions: Optional[torch.Tensor] = None,
|
192 |
+
bijector: Optional[TanhBijector] = None,
|
193 |
+
) -> PiForward:
|
194 |
+
logp_a = None
|
195 |
+
entropy = None
|
196 |
+
if actions is not None:
|
197 |
+
logp_a = distribution.log_prob(actions)
|
198 |
+
entropy = -logp_a if bijector else sum_independent_dims(distribution.entropy())
|
199 |
+
return PiForward(distribution, logp_a, entropy)
|
|
rl_algo_impls/shared/algorithm.py
CHANGED
@@ -1,11 +1,11 @@
|
|
|
|
|
|
|
|
1 |
import gym
|
2 |
import torch
|
3 |
-
|
4 |
-
from abc import ABC, abstractmethod
|
5 |
from torch.utils.tensorboard.writer import SummaryWriter
|
6 |
-
from typing import Optional, TypeVar
|
7 |
|
8 |
-
from rl_algo_impls.shared.callbacks
|
9 |
from rl_algo_impls.shared.policy.policy import Policy
|
10 |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
|
11 |
|
@@ -32,7 +32,7 @@ class Algorithm(ABC):
|
|
32 |
def learn(
|
33 |
self: AlgorithmSelf,
|
34 |
train_timesteps: int,
|
35 |
-
|
36 |
total_timesteps: Optional[int] = None,
|
37 |
start_timesteps: int = 0,
|
38 |
) -> AlgorithmSelf:
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import List, Optional, TypeVar
|
3 |
+
|
4 |
import gym
|
5 |
import torch
|
|
|
|
|
6 |
from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
7 |
|
8 |
+
from rl_algo_impls.shared.callbacks import Callback
|
9 |
from rl_algo_impls.shared.policy.policy import Policy
|
10 |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
|
11 |
|
|
|
32 |
def learn(
|
33 |
self: AlgorithmSelf,
|
34 |
train_timesteps: int,
|
35 |
+
callbacks: Optional[List[Callback]] = None,
|
36 |
total_timesteps: Optional[int] = None,
|
37 |
start_timesteps: int = 0,
|
38 |
) -> AlgorithmSelf:
|
rl_algo_impls/shared/callbacks/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from rl_algo_impls.shared.callbacks.callback import Callback
|
rl_algo_impls/shared/callbacks/eval_callback.py
CHANGED
@@ -1,12 +1,13 @@
|
|
1 |
import itertools
|
2 |
import os
|
|
|
3 |
from time import perf_counter
|
4 |
from typing import Dict, List, Optional, Union
|
5 |
|
6 |
import numpy as np
|
7 |
from torch.utils.tensorboard.writer import SummaryWriter
|
8 |
|
9 |
-
from rl_algo_impls.shared.callbacks
|
10 |
from rl_algo_impls.shared.policy.policy import Policy
|
11 |
from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
|
12 |
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
|
@@ -80,6 +81,7 @@ def evaluate(
|
|
80 |
print_returns: bool = True,
|
81 |
ignore_first_episode: bool = False,
|
82 |
additional_keys_to_log: Optional[List[str]] = None,
|
|
|
83 |
) -> EpisodesStats:
|
84 |
policy.sync_normalization(env)
|
85 |
policy.eval()
|
@@ -93,18 +95,21 @@ def evaluate(
|
|
93 |
)
|
94 |
|
95 |
obs = env.reset()
|
96 |
-
|
97 |
while not episodes.is_done():
|
98 |
act = policy.act(
|
99 |
obs,
|
100 |
deterministic=deterministic,
|
101 |
-
action_masks=
|
102 |
)
|
103 |
obs, rew, done, info = env.step(act)
|
104 |
episodes.step(rew, done, info)
|
105 |
if render:
|
106 |
env.render()
|
107 |
-
stats = EpisodesStats(
|
|
|
|
|
|
|
108 |
if print_returns:
|
109 |
print(stats)
|
110 |
return stats
|
@@ -127,6 +132,8 @@ class EvalCallback(Callback):
|
|
127 |
max_video_length: int = 3600,
|
128 |
ignore_first_episode: bool = False,
|
129 |
additional_keys_to_log: Optional[List[str]] = None,
|
|
|
|
|
130 |
) -> None:
|
131 |
super().__init__()
|
132 |
self.policy = policy
|
@@ -151,6 +158,8 @@ class EvalCallback(Callback):
|
|
151 |
self.best_video_base_path = None
|
152 |
self.ignore_first_episode = ignore_first_episode
|
153 |
self.additional_keys_to_log = additional_keys_to_log
|
|
|
|
|
154 |
|
155 |
def on_step(self, timesteps_elapsed: int = 1) -> bool:
|
156 |
super().on_step(timesteps_elapsed)
|
@@ -170,6 +179,7 @@ class EvalCallback(Callback):
|
|
170 |
print_returns=print_returns or False,
|
171 |
ignore_first_episode=self.ignore_first_episode,
|
172 |
additional_keys_to_log=self.additional_keys_to_log,
|
|
|
173 |
)
|
174 |
end_time = perf_counter()
|
175 |
self.tb_writer.add_scalar(
|
@@ -189,6 +199,15 @@ class EvalCallback(Callback):
|
|
189 |
assert self.best_model_path
|
190 |
self.policy.save(self.best_model_path)
|
191 |
print("Saved best model")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
self.best.write_to_tensorboard(
|
193 |
self.tb_writer, "best_eval", self.timesteps_elapsed
|
194 |
)
|
@@ -208,6 +227,7 @@ class EvalCallback(Callback):
|
|
208 |
1,
|
209 |
deterministic=self.deterministic,
|
210 |
print_returns=False,
|
|
|
211 |
)
|
212 |
print(f"Saved best video: {video_stats}")
|
213 |
|
|
|
1 |
import itertools
|
2 |
import os
|
3 |
+
import shutil
|
4 |
from time import perf_counter
|
5 |
from typing import Dict, List, Optional, Union
|
6 |
|
7 |
import numpy as np
|
8 |
from torch.utils.tensorboard.writer import SummaryWriter
|
9 |
|
10 |
+
from rl_algo_impls.shared.callbacks import Callback
|
11 |
from rl_algo_impls.shared.policy.policy import Policy
|
12 |
from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
|
13 |
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
|
|
|
81 |
print_returns: bool = True,
|
82 |
ignore_first_episode: bool = False,
|
83 |
additional_keys_to_log: Optional[List[str]] = None,
|
84 |
+
score_function: str = "mean-std",
|
85 |
) -> EpisodesStats:
|
86 |
policy.sync_normalization(env)
|
87 |
policy.eval()
|
|
|
95 |
)
|
96 |
|
97 |
obs = env.reset()
|
98 |
+
get_action_mask = getattr(env, "get_action_mask", None)
|
99 |
while not episodes.is_done():
|
100 |
act = policy.act(
|
101 |
obs,
|
102 |
deterministic=deterministic,
|
103 |
+
action_masks=get_action_mask() if get_action_mask else None,
|
104 |
)
|
105 |
obs, rew, done, info = env.step(act)
|
106 |
episodes.step(rew, done, info)
|
107 |
if render:
|
108 |
env.render()
|
109 |
+
stats = EpisodesStats(
|
110 |
+
episodes.episodes,
|
111 |
+
score_function=score_function,
|
112 |
+
)
|
113 |
if print_returns:
|
114 |
print(stats)
|
115 |
return stats
|
|
|
132 |
max_video_length: int = 3600,
|
133 |
ignore_first_episode: bool = False,
|
134 |
additional_keys_to_log: Optional[List[str]] = None,
|
135 |
+
score_function: str = "mean-std",
|
136 |
+
wandb_enabled: bool = False,
|
137 |
) -> None:
|
138 |
super().__init__()
|
139 |
self.policy = policy
|
|
|
158 |
self.best_video_base_path = None
|
159 |
self.ignore_first_episode = ignore_first_episode
|
160 |
self.additional_keys_to_log = additional_keys_to_log
|
161 |
+
self.score_function = score_function
|
162 |
+
self.wandb_enabled = wandb_enabled
|
163 |
|
164 |
def on_step(self, timesteps_elapsed: int = 1) -> bool:
|
165 |
super().on_step(timesteps_elapsed)
|
|
|
179 |
print_returns=print_returns or False,
|
180 |
ignore_first_episode=self.ignore_first_episode,
|
181 |
additional_keys_to_log=self.additional_keys_to_log,
|
182 |
+
score_function=self.score_function,
|
183 |
)
|
184 |
end_time = perf_counter()
|
185 |
self.tb_writer.add_scalar(
|
|
|
199 |
assert self.best_model_path
|
200 |
self.policy.save(self.best_model_path)
|
201 |
print("Saved best model")
|
202 |
+
if self.wandb_enabled:
|
203 |
+
import wandb
|
204 |
+
|
205 |
+
best_model_name = os.path.split(self.best_model_path)[-1]
|
206 |
+
shutil.make_archive(
|
207 |
+
os.path.join(wandb.run.dir, best_model_name), # type: ignore
|
208 |
+
"zip",
|
209 |
+
self.best_model_path,
|
210 |
+
)
|
211 |
self.best.write_to_tensorboard(
|
212 |
self.tb_writer, "best_eval", self.timesteps_elapsed
|
213 |
)
|
|
|
227 |
1,
|
228 |
deterministic=self.deterministic,
|
229 |
print_returns=False,
|
230 |
+
score_function=self.score_function,
|
231 |
)
|
232 |
print(f"Saved best video: {video_stats}")
|
233 |
|
rl_algo_impls/shared/callbacks/microrts_reward_decay_callback.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
from rl_algo_impls.runner.config import Config
|
4 |
+
from rl_algo_impls.shared.callbacks import Callback
|
5 |
+
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
|
6 |
+
|
7 |
+
|
8 |
+
class MicrortsRewardDecayCallback(Callback):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
config: Config,
|
12 |
+
env: VecEnv,
|
13 |
+
start_timesteps: int = 0,
|
14 |
+
) -> None:
|
15 |
+
super().__init__()
|
16 |
+
from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
|
17 |
+
|
18 |
+
unwrapped = env.unwrapped
|
19 |
+
assert isinstance(unwrapped, MicroRTSGridModeVecEnv)
|
20 |
+
self.microrts_env = unwrapped
|
21 |
+
self.base_reward_weights = self.microrts_env.reward_weight
|
22 |
+
|
23 |
+
self.total_train_timesteps = config.n_timesteps
|
24 |
+
self.timesteps_elapsed = start_timesteps
|
25 |
+
|
26 |
+
def on_step(self, timesteps_elapsed: int = 1) -> bool:
|
27 |
+
super().on_step(timesteps_elapsed)
|
28 |
+
|
29 |
+
progress = self.timesteps_elapsed / self.total_train_timesteps
|
30 |
+
# Decay all rewards except WinLoss
|
31 |
+
reward_weights = self.base_reward_weights * np.array(
|
32 |
+
[1] + [1 - progress] * (len(self.base_reward_weights) - 1)
|
33 |
+
)
|
34 |
+
self.microrts_env.reward_weight = reward_weights
|
35 |
+
|
36 |
+
return True
|
rl_algo_impls/shared/callbacks/optimize_callback.py
CHANGED
@@ -5,7 +5,7 @@ from time import perf_counter
|
|
5 |
from torch.utils.tensorboard.writer import SummaryWriter
|
6 |
from typing import NamedTuple, Union
|
7 |
|
8 |
-
from rl_algo_impls.shared.callbacks
|
9 |
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
|
10 |
from rl_algo_impls.shared.policy.policy import Policy
|
11 |
from rl_algo_impls.shared.stats import EpisodesStats
|
|
|
5 |
from torch.utils.tensorboard.writer import SummaryWriter
|
6 |
from typing import NamedTuple, Union
|
7 |
|
8 |
+
from rl_algo_impls.shared.callbacks import Callback
|
9 |
from rl_algo_impls.shared.callbacks.eval_callback import evaluate
|
10 |
from rl_algo_impls.shared.policy.policy import Policy
|
11 |
from rl_algo_impls.shared.stats import EpisodesStats
|
rl_algo_impls/shared/callbacks/self_play_callback.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Callable
|
2 |
+
|
3 |
+
from rl_algo_impls.shared.callbacks import Callback
|
4 |
+
from rl_algo_impls.shared.policy.policy import Policy
|
5 |
+
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
|
6 |
+
|
7 |
+
|
8 |
+
class SelfPlayCallback(Callback):
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
policy: Policy,
|
12 |
+
policy_factory: Callable[[], Policy],
|
13 |
+
selfPlayWrapper: SelfPlayWrapper,
|
14 |
+
) -> None:
|
15 |
+
super().__init__()
|
16 |
+
self.policy = policy
|
17 |
+
self.policy_factory = policy_factory
|
18 |
+
self.selfPlayWrapper = selfPlayWrapper
|
19 |
+
self.checkpoint_policy()
|
20 |
+
|
21 |
+
def on_step(self, timesteps_elapsed: int = 1) -> bool:
|
22 |
+
super().on_step(timesteps_elapsed)
|
23 |
+
if (
|
24 |
+
self.timesteps_elapsed
|
25 |
+
>= self.last_checkpoint_step + self.selfPlayWrapper.save_steps
|
26 |
+
):
|
27 |
+
self.checkpoint_policy()
|
28 |
+
return True
|
29 |
+
|
30 |
+
def checkpoint_policy(self):
|
31 |
+
self.selfPlayWrapper.checkpoint_policy(
|
32 |
+
self.policy_factory().load_from(self.policy)
|
33 |
+
)
|
34 |
+
self.last_checkpoint_step = self.timesteps_elapsed
|
rl_algo_impls/shared/encoder/cnn.py
CHANGED
@@ -6,7 +6,7 @@ import numpy as np
|
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
|
9 |
-
from rl_algo_impls.shared.module.
|
10 |
|
11 |
EncoderOutDim = Union[int, Tuple[int, ...]]
|
12 |
|
|
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
|
9 |
+
from rl_algo_impls.shared.module.utils import layer_init
|
10 |
|
11 |
EncoderOutDim = Union[int, Tuple[int, ...]]
|
12 |
|
rl_algo_impls/shared/encoder/encoder.py
CHANGED
@@ -12,7 +12,7 @@ from rl_algo_impls.shared.encoder.gridnet_encoder import GridnetEncoder
|
|
12 |
from rl_algo_impls.shared.encoder.impala_cnn import ImpalaCnn
|
13 |
from rl_algo_impls.shared.encoder.microrts_cnn import MicrortsCnn
|
14 |
from rl_algo_impls.shared.encoder.nature_cnn import NatureCnn
|
15 |
-
from rl_algo_impls.shared.module.
|
16 |
|
17 |
CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnEncoder]] = {
|
18 |
"nature": NatureCnn,
|
|
|
12 |
from rl_algo_impls.shared.encoder.impala_cnn import ImpalaCnn
|
13 |
from rl_algo_impls.shared.encoder.microrts_cnn import MicrortsCnn
|
14 |
from rl_algo_impls.shared.encoder.nature_cnn import NatureCnn
|
15 |
+
from rl_algo_impls.shared.module.utils import layer_init
|
16 |
|
17 |
CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnEncoder]] = {
|
18 |
"nature": NatureCnn,
|
rl_algo_impls/shared/encoder/gridnet_encoder.py
CHANGED
@@ -5,7 +5,7 @@ import torch
|
|
5 |
import torch.nn as nn
|
6 |
|
7 |
from rl_algo_impls.shared.encoder.cnn import CnnEncoder, EncoderOutDim
|
8 |
-
from rl_algo_impls.shared.module.
|
9 |
|
10 |
|
11 |
class GridnetEncoder(CnnEncoder):
|
|
|
5 |
import torch.nn as nn
|
6 |
|
7 |
from rl_algo_impls.shared.encoder.cnn import CnnEncoder, EncoderOutDim
|
8 |
+
from rl_algo_impls.shared.module.utils import layer_init
|
9 |
|
10 |
|
11 |
class GridnetEncoder(CnnEncoder):
|
rl_algo_impls/shared/encoder/impala_cnn.py
CHANGED
@@ -5,7 +5,7 @@ import torch
|
|
5 |
import torch.nn as nn
|
6 |
|
7 |
from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
|
8 |
-
from rl_algo_impls.shared.module.
|
9 |
|
10 |
|
11 |
class ResidualBlock(nn.Module):
|
|
|
5 |
import torch.nn as nn
|
6 |
|
7 |
from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
|
8 |
+
from rl_algo_impls.shared.module.utils import layer_init
|
9 |
|
10 |
|
11 |
class ResidualBlock(nn.Module):
|
rl_algo_impls/shared/encoder/microrts_cnn.py
CHANGED
@@ -5,7 +5,7 @@ import torch
|
|
5 |
import torch.nn as nn
|
6 |
|
7 |
from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
|
8 |
-
from rl_algo_impls.shared.module.
|
9 |
|
10 |
|
11 |
class MicrortsCnn(FlattenedCnnEncoder):
|
|
|
5 |
import torch.nn as nn
|
6 |
|
7 |
from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
|
8 |
+
from rl_algo_impls.shared.module.utils import layer_init
|
9 |
|
10 |
|
11 |
class MicrortsCnn(FlattenedCnnEncoder):
|
rl_algo_impls/shared/encoder/nature_cnn.py
CHANGED
@@ -4,7 +4,7 @@ import gym
|
|
4 |
import torch.nn as nn
|
5 |
|
6 |
from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
|
7 |
-
from rl_algo_impls.shared.module.
|
8 |
|
9 |
|
10 |
class NatureCnn(FlattenedCnnEncoder):
|
|
|
4 |
import torch.nn as nn
|
5 |
|
6 |
from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder
|
7 |
+
from rl_algo_impls.shared.module.utils import layer_init
|
8 |
|
9 |
|
10 |
class NatureCnn(FlattenedCnnEncoder):
|
rl_algo_impls/shared/gae.py
CHANGED
@@ -3,7 +3,7 @@ import torch
|
|
3 |
|
4 |
from typing import NamedTuple, Sequence
|
5 |
|
6 |
-
from rl_algo_impls.shared.policy.
|
7 |
from rl_algo_impls.shared.trajectory import Trajectory
|
8 |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvObs
|
9 |
|
|
|
3 |
|
4 |
from typing import NamedTuple, Sequence
|
5 |
|
6 |
+
from rl_algo_impls.shared.policy.actor_critic import OnPolicy
|
7 |
from rl_algo_impls.shared.trajectory import Trajectory
|
8 |
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvObs
|
9 |
|
rl_algo_impls/shared/module/{module.py → utils.py}
RENAMED
File without changes
|
rl_algo_impls/shared/policy/{on_policy.py → actor_critic.py}
RENAMED
@@ -4,12 +4,14 @@ from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
|
|
4 |
import gym
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
-
from gym.spaces import Box,
|
8 |
|
9 |
-
from rl_algo_impls.shared.
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
13 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
14 |
VecEnv,
|
15 |
VecEnvObs,
|
@@ -52,21 +54,6 @@ def clamp_actions(
|
|
52 |
return actions
|
53 |
|
54 |
|
55 |
-
def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
|
56 |
-
if isinstance(obs_space, Box):
|
57 |
-
if len(obs_space.shape) == 3:
|
58 |
-
# By default feature extractor to output has no hidden layers
|
59 |
-
return []
|
60 |
-
elif len(obs_space.shape) == 1:
|
61 |
-
return [64, 64]
|
62 |
-
else:
|
63 |
-
raise ValueError(f"Unsupported observation space: {obs_space}")
|
64 |
-
elif isinstance(obs_space, Discrete):
|
65 |
-
return [64]
|
66 |
-
else:
|
67 |
-
raise ValueError(f"Unsupported observation space: {obs_space}")
|
68 |
-
|
69 |
-
|
70 |
class OnPolicy(Policy):
|
71 |
@abstractmethod
|
72 |
def value(self, obs: VecEnvObs) -> np.ndarray:
|
@@ -106,78 +93,59 @@ class ActorCritic(OnPolicy):
|
|
106 |
|
107 |
observation_space = single_observation_space(env)
|
108 |
action_space = single_action_space(env)
|
|
|
109 |
|
110 |
-
pi_hidden_sizes = (
|
111 |
-
pi_hidden_sizes
|
112 |
-
if pi_hidden_sizes is not None
|
113 |
-
else default_hidden_sizes(observation_space)
|
114 |
-
)
|
115 |
-
v_hidden_sizes = (
|
116 |
-
v_hidden_sizes
|
117 |
-
if v_hidden_sizes is not None
|
118 |
-
else default_hidden_sizes(observation_space)
|
119 |
-
)
|
120 |
-
|
121 |
-
activation = ACTIVATION[activation_fn]
|
122 |
self.action_space = action_space
|
123 |
self.squash_output = squash_output
|
124 |
-
self.share_features_extractor = share_features_extractor
|
125 |
-
self._feature_extractor = Encoder(
|
126 |
-
observation_space,
|
127 |
-
activation,
|
128 |
-
init_layers_orthogonal=init_layers_orthogonal,
|
129 |
-
cnn_flatten_dim=cnn_flatten_dim,
|
130 |
-
cnn_style=cnn_style,
|
131 |
-
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
132 |
-
impala_channels=impala_channels,
|
133 |
-
)
|
134 |
-
self._pi = actor_head(
|
135 |
-
self.action_space,
|
136 |
-
self._feature_extractor.out_dim,
|
137 |
-
tuple(pi_hidden_sizes),
|
138 |
-
init_layers_orthogonal,
|
139 |
-
activation,
|
140 |
-
log_std_init=log_std_init,
|
141 |
-
use_sde=use_sde,
|
142 |
-
full_std=full_std,
|
143 |
-
squash_output=squash_output,
|
144 |
-
actor_head_style=actor_head_style,
|
145 |
-
)
|
146 |
|
147 |
-
if
|
148 |
-
self.
|
149 |
observation_space,
|
150 |
-
|
|
|
|
|
151 |
init_layers_orthogonal=init_layers_orthogonal,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
cnn_flatten_dim=cnn_flatten_dim,
|
153 |
cnn_style=cnn_style,
|
154 |
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
|
|
|
|
|
|
155 |
)
|
156 |
-
critic_in_dim = self._v_feature_extractor.out_dim
|
157 |
else:
|
158 |
-
self.
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
return pi_forward, p_fe
|
177 |
-
|
178 |
-
def _v_forward(self, obs: torch.Tensor, p_fc: torch.Tensor) -> torch.Tensor:
|
179 |
-
v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc
|
180 |
-
return self._v(v_fe)
|
181 |
|
182 |
def forward(
|
183 |
self,
|
@@ -185,8 +153,7 @@ class ActorCritic(OnPolicy):
|
|
185 |
action: torch.Tensor,
|
186 |
action_masks: Optional[torch.Tensor] = None,
|
187 |
) -> ACForward:
|
188 |
-
(_, logp_a, entropy),
|
189 |
-
v = self._v_forward(obs, p_fc)
|
190 |
|
191 |
assert logp_a is not None
|
192 |
assert entropy is not None
|
@@ -195,24 +162,17 @@ class ActorCritic(OnPolicy):
|
|
195 |
def value(self, obs: VecEnvObs) -> np.ndarray:
|
196 |
o = self._as_tensor(obs)
|
197 |
with torch.no_grad():
|
198 |
-
|
199 |
-
self._v_feature_extractor(o)
|
200 |
-
if self._v_feature_extractor
|
201 |
-
else self._feature_extractor(o)
|
202 |
-
)
|
203 |
-
v = self._v(fe)
|
204 |
return v.cpu().numpy()
|
205 |
|
206 |
def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
|
207 |
o = self._as_tensor(obs)
|
208 |
a_masks = self._as_tensor(action_masks) if action_masks is not None else None
|
209 |
with torch.no_grad():
|
210 |
-
(pi, _, _),
|
211 |
a = pi.sample()
|
212 |
logp_a = pi.log_prob(a)
|
213 |
|
214 |
-
v = self._v_forward(o, p_fc)
|
215 |
-
|
216 |
a_np = a.cpu().numpy()
|
217 |
clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
|
218 |
return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
|
@@ -231,7 +191,9 @@ class ActorCritic(OnPolicy):
|
|
231 |
self._as_tensor(action_masks) if action_masks is not None else None
|
232 |
)
|
233 |
with torch.no_grad():
|
234 |
-
(pi, _, _), _ = self.
|
|
|
|
|
235 |
a = pi.mode
|
236 |
return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
|
237 |
|
@@ -239,11 +201,16 @@ class ActorCritic(OnPolicy):
|
|
239 |
super().load(path)
|
240 |
self.reset_noise()
|
241 |
|
|
|
|
|
|
|
|
|
|
|
242 |
def reset_noise(self, batch_size: Optional[int] = None) -> None:
|
243 |
-
self.
|
244 |
batch_size=batch_size if batch_size else self.env.num_envs
|
245 |
)
|
246 |
|
247 |
@property
|
248 |
def action_shape(self) -> Tuple[int, ...]:
|
249 |
-
return self.
|
|
|
4 |
import gym
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
+
from gym.spaces import Box, Space
|
8 |
|
9 |
+
from rl_algo_impls.shared.policy.actor_critic_network import (
|
10 |
+
ConnectedTrioActorCriticNetwork,
|
11 |
+
SeparateActorCriticNetwork,
|
12 |
+
UNetActorCriticNetwork,
|
13 |
+
)
|
14 |
+
from rl_algo_impls.shared.policy.policy import Policy
|
15 |
from rl_algo_impls.wrappers.vectorable_wrapper import (
|
16 |
VecEnv,
|
17 |
VecEnvObs,
|
|
|
54 |
return actions
|
55 |
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
class OnPolicy(Policy):
|
58 |
@abstractmethod
|
59 |
def value(self, obs: VecEnvObs) -> np.ndarray:
|
|
|
93 |
|
94 |
observation_space = single_observation_space(env)
|
95 |
action_space = single_action_space(env)
|
96 |
+
action_plane_space = getattr(env, "action_plane_space", None)
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
self.action_space = action_space
|
99 |
self.squash_output = squash_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
100 |
|
101 |
+
if actor_head_style == "unet":
|
102 |
+
self.network = UNetActorCriticNetwork(
|
103 |
observation_space,
|
104 |
+
action_space,
|
105 |
+
action_plane_space,
|
106 |
+
v_hidden_sizes=v_hidden_sizes,
|
107 |
init_layers_orthogonal=init_layers_orthogonal,
|
108 |
+
activation_fn=activation_fn,
|
109 |
+
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
110 |
+
)
|
111 |
+
elif share_features_extractor:
|
112 |
+
self.network = ConnectedTrioActorCriticNetwork(
|
113 |
+
observation_space,
|
114 |
+
action_space,
|
115 |
+
pi_hidden_sizes=pi_hidden_sizes,
|
116 |
+
v_hidden_sizes=v_hidden_sizes,
|
117 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
118 |
+
activation_fn=activation_fn,
|
119 |
+
log_std_init=log_std_init,
|
120 |
+
use_sde=use_sde,
|
121 |
+
full_std=full_std,
|
122 |
+
squash_output=squash_output,
|
123 |
cnn_flatten_dim=cnn_flatten_dim,
|
124 |
cnn_style=cnn_style,
|
125 |
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
126 |
+
impala_channels=impala_channels,
|
127 |
+
actor_head_style=actor_head_style,
|
128 |
+
action_plane_space=action_plane_space,
|
129 |
)
|
|
|
130 |
else:
|
131 |
+
self.network = SeparateActorCriticNetwork(
|
132 |
+
observation_space,
|
133 |
+
action_space,
|
134 |
+
pi_hidden_sizes=pi_hidden_sizes,
|
135 |
+
v_hidden_sizes=v_hidden_sizes,
|
136 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
137 |
+
activation_fn=activation_fn,
|
138 |
+
log_std_init=log_std_init,
|
139 |
+
use_sde=use_sde,
|
140 |
+
full_std=full_std,
|
141 |
+
squash_output=squash_output,
|
142 |
+
cnn_flatten_dim=cnn_flatten_dim,
|
143 |
+
cnn_style=cnn_style,
|
144 |
+
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
145 |
+
impala_channels=impala_channels,
|
146 |
+
actor_head_style=actor_head_style,
|
147 |
+
action_plane_space=action_plane_space,
|
148 |
+
)
|
|
|
|
|
|
|
|
|
|
|
149 |
|
150 |
def forward(
|
151 |
self,
|
|
|
153 |
action: torch.Tensor,
|
154 |
action_masks: Optional[torch.Tensor] = None,
|
155 |
) -> ACForward:
|
156 |
+
(_, logp_a, entropy), v = self.network(obs, action, action_masks=action_masks)
|
|
|
157 |
|
158 |
assert logp_a is not None
|
159 |
assert entropy is not None
|
|
|
162 |
def value(self, obs: VecEnvObs) -> np.ndarray:
|
163 |
o = self._as_tensor(obs)
|
164 |
with torch.no_grad():
|
165 |
+
v = self.network.value(o)
|
|
|
|
|
|
|
|
|
|
|
166 |
return v.cpu().numpy()
|
167 |
|
168 |
def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step:
|
169 |
o = self._as_tensor(obs)
|
170 |
a_masks = self._as_tensor(action_masks) if action_masks is not None else None
|
171 |
with torch.no_grad():
|
172 |
+
(pi, _, _), v = self.network.distribution_and_value(o, action_masks=a_masks)
|
173 |
a = pi.sample()
|
174 |
logp_a = pi.log_prob(a)
|
175 |
|
|
|
|
|
176 |
a_np = a.cpu().numpy()
|
177 |
clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output)
|
178 |
return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np)
|
|
|
191 |
self._as_tensor(action_masks) if action_masks is not None else None
|
192 |
)
|
193 |
with torch.no_grad():
|
194 |
+
(pi, _, _), _ = self.network.distribution_and_value(
|
195 |
+
o, action_masks=a_masks
|
196 |
+
)
|
197 |
a = pi.mode
|
198 |
return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output)
|
199 |
|
|
|
201 |
super().load(path)
|
202 |
self.reset_noise()
|
203 |
|
204 |
+
def load_from(self: ActorCriticSelf, policy: ActorCriticSelf) -> ActorCriticSelf:
|
205 |
+
super().load_from(policy)
|
206 |
+
self.reset_noise()
|
207 |
+
return self
|
208 |
+
|
209 |
def reset_noise(self, batch_size: Optional[int] = None) -> None:
|
210 |
+
self.network.reset_noise(
|
211 |
batch_size=batch_size if batch_size else self.env.num_envs
|
212 |
)
|
213 |
|
214 |
@property
|
215 |
def action_shape(self) -> Tuple[int, ...]:
|
216 |
+
return self.network.action_shape
|
rl_algo_impls/shared/policy/actor_critic_network/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from rl_algo_impls.shared.policy.actor_critic_network.connected_trio import (
|
2 |
+
ConnectedTrioActorCriticNetwork,
|
3 |
+
)
|
4 |
+
from rl_algo_impls.shared.policy.actor_critic_network.network import (
|
5 |
+
ActorCriticNetwork,
|
6 |
+
default_hidden_sizes,
|
7 |
+
)
|
8 |
+
from rl_algo_impls.shared.policy.actor_critic_network.separate_actor_critic import (
|
9 |
+
SeparateActorCriticNetwork,
|
10 |
+
)
|
11 |
+
from rl_algo_impls.shared.policy.actor_critic_network.unet import UNetActorCriticNetwork
|
rl_algo_impls/shared/policy/actor_critic_network/connected_trio.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Sequence, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from gym.spaces import Space
|
5 |
+
|
6 |
+
from rl_algo_impls.shared.actor import actor_head
|
7 |
+
from rl_algo_impls.shared.encoder import Encoder
|
8 |
+
from rl_algo_impls.shared.policy.actor_critic_network.network import (
|
9 |
+
ACNForward,
|
10 |
+
ActorCriticNetwork,
|
11 |
+
default_hidden_sizes,
|
12 |
+
)
|
13 |
+
from rl_algo_impls.shared.policy.critic import CriticHead
|
14 |
+
from rl_algo_impls.shared.policy.policy import ACTIVATION
|
15 |
+
|
16 |
+
|
17 |
+
class ConnectedTrioActorCriticNetwork(ActorCriticNetwork):
|
18 |
+
"""Encode (feature extractor), decoder (actor head), critic head networks"""
|
19 |
+
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
observation_space: Space,
|
23 |
+
action_space: Space,
|
24 |
+
pi_hidden_sizes: Optional[Sequence[int]] = None,
|
25 |
+
v_hidden_sizes: Optional[Sequence[int]] = None,
|
26 |
+
init_layers_orthogonal: bool = True,
|
27 |
+
activation_fn: str = "tanh",
|
28 |
+
log_std_init: float = -0.5,
|
29 |
+
use_sde: bool = False,
|
30 |
+
full_std: bool = True,
|
31 |
+
squash_output: bool = False,
|
32 |
+
cnn_flatten_dim: int = 512,
|
33 |
+
cnn_style: str = "nature",
|
34 |
+
cnn_layers_init_orthogonal: Optional[bool] = None,
|
35 |
+
impala_channels: Sequence[int] = (16, 32, 32),
|
36 |
+
actor_head_style: str = "single",
|
37 |
+
action_plane_space: Optional[Space] = None,
|
38 |
+
) -> None:
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
pi_hidden_sizes = (
|
42 |
+
pi_hidden_sizes
|
43 |
+
if pi_hidden_sizes is not None
|
44 |
+
else default_hidden_sizes(observation_space)
|
45 |
+
)
|
46 |
+
v_hidden_sizes = (
|
47 |
+
v_hidden_sizes
|
48 |
+
if v_hidden_sizes is not None
|
49 |
+
else default_hidden_sizes(observation_space)
|
50 |
+
)
|
51 |
+
|
52 |
+
activation = ACTIVATION[activation_fn]
|
53 |
+
self._feature_extractor = Encoder(
|
54 |
+
observation_space,
|
55 |
+
activation,
|
56 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
57 |
+
cnn_flatten_dim=cnn_flatten_dim,
|
58 |
+
cnn_style=cnn_style,
|
59 |
+
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
60 |
+
impala_channels=impala_channels,
|
61 |
+
)
|
62 |
+
self._pi = actor_head(
|
63 |
+
action_space,
|
64 |
+
self._feature_extractor.out_dim,
|
65 |
+
tuple(pi_hidden_sizes),
|
66 |
+
init_layers_orthogonal,
|
67 |
+
activation,
|
68 |
+
log_std_init=log_std_init,
|
69 |
+
use_sde=use_sde,
|
70 |
+
full_std=full_std,
|
71 |
+
squash_output=squash_output,
|
72 |
+
actor_head_style=actor_head_style,
|
73 |
+
action_plane_space=action_plane_space,
|
74 |
+
)
|
75 |
+
|
76 |
+
self._v = CriticHead(
|
77 |
+
in_dim=self._feature_extractor.out_dim,
|
78 |
+
hidden_sizes=v_hidden_sizes,
|
79 |
+
activation=activation,
|
80 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
81 |
+
)
|
82 |
+
|
83 |
+
def forward(
|
84 |
+
self,
|
85 |
+
obs: torch.Tensor,
|
86 |
+
action: torch.Tensor,
|
87 |
+
action_masks: Optional[torch.Tensor] = None,
|
88 |
+
) -> ACNForward:
|
89 |
+
return self._distribution_and_value(
|
90 |
+
obs, action=action, action_masks=action_masks
|
91 |
+
)
|
92 |
+
|
93 |
+
def distribution_and_value(
|
94 |
+
self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None
|
95 |
+
) -> ACNForward:
|
96 |
+
return self._distribution_and_value(obs, action_masks=action_masks)
|
97 |
+
|
98 |
+
def _distribution_and_value(
|
99 |
+
self,
|
100 |
+
obs: torch.Tensor,
|
101 |
+
action: Optional[torch.Tensor] = None,
|
102 |
+
action_masks: Optional[torch.Tensor] = None,
|
103 |
+
) -> ACNForward:
|
104 |
+
encoded = self._feature_extractor(obs)
|
105 |
+
pi_forward = self._pi(encoded, actions=action, action_masks=action_masks)
|
106 |
+
v = self._v(encoded)
|
107 |
+
return ACNForward(pi_forward, v)
|
108 |
+
|
109 |
+
def value(self, obs: torch.Tensor) -> torch.Tensor:
|
110 |
+
encoded = self._feature_extractor(obs)
|
111 |
+
return self._v(encoded)
|
112 |
+
|
113 |
+
def reset_noise(self, batch_size: int) -> None:
|
114 |
+
self._pi.sample_weights(batch_size=batch_size)
|
115 |
+
|
116 |
+
@property
|
117 |
+
def action_shape(self) -> Tuple[int, ...]:
|
118 |
+
return self._pi.action_shape
|
rl_algo_impls/shared/policy/actor_critic_network/network.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import NamedTuple, Optional, Sequence, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from gym.spaces import Box, Discrete, Space
|
7 |
+
|
8 |
+
from rl_algo_impls.shared.actor import PiForward
|
9 |
+
|
10 |
+
|
11 |
+
class ACNForward(NamedTuple):
|
12 |
+
pi_forward: PiForward
|
13 |
+
v: torch.Tensor
|
14 |
+
|
15 |
+
|
16 |
+
class ActorCriticNetwork(nn.Module, ABC):
|
17 |
+
@abstractmethod
|
18 |
+
def forward(
|
19 |
+
self,
|
20 |
+
obs: torch.Tensor,
|
21 |
+
action: torch.Tensor,
|
22 |
+
action_masks: Optional[torch.Tensor] = None,
|
23 |
+
) -> ACNForward:
|
24 |
+
...
|
25 |
+
|
26 |
+
@abstractmethod
|
27 |
+
def distribution_and_value(
|
28 |
+
self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None
|
29 |
+
) -> ACNForward:
|
30 |
+
...
|
31 |
+
|
32 |
+
@abstractmethod
|
33 |
+
def value(self, obs: torch.Tensor) -> torch.Tensor:
|
34 |
+
...
|
35 |
+
|
36 |
+
@abstractmethod
|
37 |
+
def reset_noise(self, batch_size: Optional[int] = None) -> None:
|
38 |
+
...
|
39 |
+
|
40 |
+
@property
|
41 |
+
def action_shape(self) -> Tuple[int, ...]:
|
42 |
+
...
|
43 |
+
|
44 |
+
|
45 |
+
def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
|
46 |
+
if isinstance(obs_space, Box):
|
47 |
+
if len(obs_space.shape) == 3: # type: ignore
|
48 |
+
# By default feature extractor to output has no hidden layers
|
49 |
+
return []
|
50 |
+
elif len(obs_space.shape) == 1: # type: ignore
|
51 |
+
return [64, 64]
|
52 |
+
else:
|
53 |
+
raise ValueError(f"Unsupported observation space: {obs_space}")
|
54 |
+
elif isinstance(obs_space, Discrete):
|
55 |
+
return [64]
|
56 |
+
else:
|
57 |
+
raise ValueError(f"Unsupported observation space: {obs_space}")
|
rl_algo_impls/shared/policy/actor_critic_network/separate_actor_critic.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional, Sequence, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from gym.spaces import Space
|
6 |
+
|
7 |
+
from rl_algo_impls.shared.actor import actor_head
|
8 |
+
from rl_algo_impls.shared.encoder import Encoder
|
9 |
+
from rl_algo_impls.shared.policy.actor_critic_network.network import (
|
10 |
+
ACNForward,
|
11 |
+
ActorCriticNetwork,
|
12 |
+
default_hidden_sizes,
|
13 |
+
)
|
14 |
+
from rl_algo_impls.shared.policy.critic import CriticHead
|
15 |
+
from rl_algo_impls.shared.policy.policy import ACTIVATION
|
16 |
+
|
17 |
+
|
18 |
+
class SeparateActorCriticNetwork(ActorCriticNetwork):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
observation_space: Space,
|
22 |
+
action_space: Space,
|
23 |
+
pi_hidden_sizes: Optional[Sequence[int]] = None,
|
24 |
+
v_hidden_sizes: Optional[Sequence[int]] = None,
|
25 |
+
init_layers_orthogonal: bool = True,
|
26 |
+
activation_fn: str = "tanh",
|
27 |
+
log_std_init: float = -0.5,
|
28 |
+
use_sde: bool = False,
|
29 |
+
full_std: bool = True,
|
30 |
+
squash_output: bool = False,
|
31 |
+
cnn_flatten_dim: int = 512,
|
32 |
+
cnn_style: str = "nature",
|
33 |
+
cnn_layers_init_orthogonal: Optional[bool] = None,
|
34 |
+
impala_channels: Sequence[int] = (16, 32, 32),
|
35 |
+
actor_head_style: str = "single",
|
36 |
+
action_plane_space: Optional[Space] = None,
|
37 |
+
) -> None:
|
38 |
+
super().__init__()
|
39 |
+
|
40 |
+
pi_hidden_sizes = (
|
41 |
+
pi_hidden_sizes
|
42 |
+
if pi_hidden_sizes is not None
|
43 |
+
else default_hidden_sizes(observation_space)
|
44 |
+
)
|
45 |
+
v_hidden_sizes = (
|
46 |
+
v_hidden_sizes
|
47 |
+
if v_hidden_sizes is not None
|
48 |
+
else default_hidden_sizes(observation_space)
|
49 |
+
)
|
50 |
+
|
51 |
+
activation = ACTIVATION[activation_fn]
|
52 |
+
self._feature_extractor = Encoder(
|
53 |
+
observation_space,
|
54 |
+
activation,
|
55 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
56 |
+
cnn_flatten_dim=cnn_flatten_dim,
|
57 |
+
cnn_style=cnn_style,
|
58 |
+
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
59 |
+
impala_channels=impala_channels,
|
60 |
+
)
|
61 |
+
self._pi = actor_head(
|
62 |
+
action_space,
|
63 |
+
self._feature_extractor.out_dim,
|
64 |
+
tuple(pi_hidden_sizes),
|
65 |
+
init_layers_orthogonal,
|
66 |
+
activation,
|
67 |
+
log_std_init=log_std_init,
|
68 |
+
use_sde=use_sde,
|
69 |
+
full_std=full_std,
|
70 |
+
squash_output=squash_output,
|
71 |
+
actor_head_style=actor_head_style,
|
72 |
+
action_plane_space=action_plane_space,
|
73 |
+
)
|
74 |
+
|
75 |
+
v_encoder = Encoder(
|
76 |
+
observation_space,
|
77 |
+
activation,
|
78 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
79 |
+
cnn_flatten_dim=cnn_flatten_dim,
|
80 |
+
cnn_style=cnn_style,
|
81 |
+
cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
|
82 |
+
)
|
83 |
+
self._v = nn.Sequential(
|
84 |
+
v_encoder,
|
85 |
+
CriticHead(
|
86 |
+
in_dim=v_encoder.out_dim,
|
87 |
+
hidden_sizes=v_hidden_sizes,
|
88 |
+
activation=activation,
|
89 |
+
init_layers_orthogonal=init_layers_orthogonal,
|
90 |
+
),
|
91 |
+
)
|
92 |
+
|
93 |
+
def forward(
|
94 |
+
self,
|
95 |
+
obs: torch.Tensor,
|
96 |
+
action: torch.Tensor,
|
97 |
+
action_masks: Optional[torch.Tensor] = None,
|
98 |
+
) -> ACNForward:
|
99 |
+
return self._distribution_and_value(
|
100 |
+
obs, action=action, action_masks=action_masks
|
101 |
+
)
|
102 |
+
|
103 |
+
def distribution_and_value(
|
104 |
+
self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None
|
105 |
+
) -> ACNForward:
|
106 |
+
return self._distribution_and_value(obs, action_masks=action_masks)
|
107 |
+
|
108 |
+
def _distribution_and_value(
|
109 |
+
self,
|
110 |
+
obs: torch.Tensor,
|
111 |
+
action: Optional[torch.Tensor] = None,
|
112 |
+
action_masks: Optional[torch.Tensor] = None,
|
113 |
+
) -> ACNForward:
|
114 |
+
pi_forward = self._pi(
|
115 |
+
self._feature_extractor(obs), actions=action, action_masks=action_masks
|
116 |
+
)
|
117 |
+
v = self._v(obs)
|
118 |
+
return ACNForward(pi_forward, v)
|
119 |
+
|
120 |
+
def value(self, obs: torch.Tensor) -> torch.Tensor:
|
121 |
+
return self._v(obs)
|
122 |
+
|
123 |
+
def reset_noise(self, batch_size: int) -> None:
|
124 |
+
self._pi.sample_weights(batch_size=batch_size)
|
125 |
+
|
126 |
+
@property
|
127 |
+
def action_shape(self) -> Tuple[int, ...]:
|
128 |
+
return self._pi.action_shape
|