DQN playing SpaceInvadersNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
271ebe6
import copy | |
import random | |
from collections import deque | |
from typing import Any, Deque, Dict, List, Optional | |
import numpy as np | |
from rl_algo_impls.runner.config import Config | |
from rl_algo_impls.shared.policy.policy import Policy | |
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker | |
from rl_algo_impls.wrappers.vectorable_wrapper import ( | |
VecEnvObs, | |
VecEnvStepReturn, | |
VecotarableWrapper, | |
) | |
class SelfPlayWrapper(VecotarableWrapper):
    """Vector-env wrapper in which some env slots are driven by stored policies.

    ``policy_assignments`` holds one entry per slot of the wrapped env. A
    ``None`` entry means the slot is controlled by the caller (the learner);
    any other entry is the ``Policy`` that picks the action for that slot.
    Only the learner's slots are exposed through ``reset``/``step``/
    ``get_action_mask``, and ``num_envs`` is reduced accordingly.

    Slots are handled in consecutive pairs ``(2 * i, 2 * i + 1)``.
    NOTE(review): the pairs appear to be player 1 / player 2 of a single
    match -- confirm against the wrapped env.
    """

    # Full (unfiltered) observations from the last reset()/step() call.
    next_obs: VecEnvObs
    # Full (unfiltered) action masks from the last reset()/step() call.
    next_action_masks: Optional[np.ndarray]

    def __init__(
        self,
        env,
        config: Config,
        num_old_policies: int = 0,
        save_steps: int = 20_000,
        swap_steps: int = 10_000,
        window: int = 10,
        swap_window_size: int = 2,
        selfplay_bots: Optional[Dict[str, Any]] = None,
        bot_always_player_2: bool = False,
    ) -> None:
        """Set up slot bookkeeping and load any fixed self-play bots.

        Args:
            env: Wrapped vectorized env; must expose ``num_envs``.
            config: Run configuration, used to build the self-play bots.
            num_old_policies: Number of slot pairs reserved for checkpointed
                policies. Must be even and a multiple of ``swap_window_size``.
            save_steps: Stored on the instance; not read inside this class.
            swap_steps: A reserved window is re-assigned once its step counter
                exceeds this value.
            window: Maximum number of checkpointed policies retained.
            swap_window_size: Number of slot pairs re-assigned together.
            selfplay_bots: Optional mapping of model path to the number of
                slot pairs that model should play in.
            bot_always_player_2: If true, non-learner policies always take the
                second slot of a pair instead of alternating.

        Raises:
            AssertionError: If ``num_old_policies`` is odd or not a multiple
                of ``swap_window_size``.
        """
        super().__init__(env)
        assert num_old_policies % 2 == 0, "num_old_policies must be even"
        assert (
            num_old_policies % swap_window_size == 0
        ), "num_old_policies must be a multiple of swap_window_size"

        self.config = config
        self.num_old_policies = num_old_policies
        self.save_steps = save_steps
        self.swap_steps = swap_steps
        self.swap_window_size = swap_window_size
        self.selfplay_bots = selfplay_bots
        self.bot_always_player_2 = bot_always_player_2

        self.policies: Deque[Policy] = deque(maxlen=window)
        self.policy_assignments: List[Optional[Policy]] = [None] * env.num_envs
        self.steps_since_swap = np.zeros(env.num_envs)

        self.selfplay_policies: Dict[str, Policy] = {}

        # Each reserved pair gives up one of its two slots to a stored policy.
        self.num_envs = env.num_envs - num_old_policies

        if self.selfplay_bots:
            self.num_envs -= sum(self.selfplay_bots.values())
            self.initialize_selfplay_bots()

    def get_action_mask(self) -> Optional[np.ndarray]:
        """Return the wrapped env's action mask restricted to learner slots.

        Returns:
            The filtered mask, or ``None`` when the wrapped env provides none.
        """
        masks = self.env.get_action_mask()
        if masks is None:
            return None
        return masks[self.learner_indexes()]

    def learner_indexes(self) -> List[bool]:
        """Return a boolean mask, one entry per slot, true for learner slots."""
        return [p is None for p in self.policy_assignments]

    def checkpoint_policy(self, copied_policy: Policy) -> None:
        """Add a policy snapshot to the pool of past policies.

        The snapshot is switched out of training mode and appended to
        ``policies`` (the oldest entry is dropped once ``window`` is reached).
        If none of the reserved slots has a policy yet, the snapshot is also
        assigned to one slot of each reserved pair.

        Args:
            copied_policy: Snapshot to store; it is mutated via
                ``train(False)``.
        """
        copied_policy.train(False)
        self.policies.append(copied_policy)

        reserved = self.policy_assignments[: 2 * self.num_old_policies]
        if all(p is None for p in reserved):
            for i in range(self.num_old_policies):
                # Switch between player 1 and 2
                offset = 1 if self.bot_always_player_2 else i % 2
                self.policy_assignments[2 * i + offset] = copied_policy

    def swap_policy(self, idx: int, swap_window_size: int = 1) -> None:
        """Re-assign a window of slots to a randomly chosen stored policy.

        Only slots that already have a policy are overwritten; learner slots
        stay ``None``. The step counters of the whole window are reset.

        Args:
            idx: Slot index inside the first pair of the window; rounded down
                to the even slot that starts the pair.
            swap_window_size: Number of consecutive pairs in the window.
        """
        start = idx // 2 * 2
        stop = start + swap_window_size * 2
        # random.choice raises IndexError on an empty pool; with nothing
        # checkpointed yet there is nothing to swap in, so only reset counters.
        if self.policies:
            policy = random.choice(self.policies)
            for slot in range(start, stop):
                if self.policy_assignments[slot] is not None:
                    self.policy_assignments[slot] = policy
        self.steps_since_swap[start:stop] = 0

    def initialize_selfplay_bots(self) -> None:
        """Load each configured bot model and assign it to its slot pairs.

        Bot pairs are laid out directly after the pairs reserved for
        checkpointed policies, in the iteration order of ``selfplay_bots``.
        Loaded policies are recorded in ``selfplay_policies`` by model path.
        """
        if not self.selfplay_bots:
            return

        # Imported here rather than at module level, presumably to avoid a
        # circular import with the runner package -- TODO confirm.
        from rl_algo_impls.runner.running_utils import get_device, make_policy

        env = self.env  # type: ignore
        device = get_device(self.config, env)
        start_idx = 2 * self.num_old_policies
        for model_path, n in self.selfplay_bots.items():
            policy = make_policy(
                self.config.algo,
                env,
                device,
                load_path=model_path,
                **self.config.policy_hyperparams,
            ).eval()
            # Key by the actual path so several bots do not overwrite each
            # other under a single literal key.
            self.selfplay_policies[model_path] = policy
            for idx in range(start_idx, start_idx + 2 * n, 2):
                # Either always the second slot, or alternate per pair.
                bot_idx = (
                    (idx + 1) if self.bot_always_player_2 else (idx + idx // 2 % 2)
                )
                self.policy_assignments[bot_idx] = policy
            start_idx += 2 * n

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """Step every slot, filling non-learner actions from stored policies.

        Args:
            actions: Learner actions, one row per learner slot.

        Returns:
            ``(obs, rew, done, info)`` restricted to the learner slots.
        """
        env = self.env  # type: ignore
        all_actions = np.zeros((env.num_envs,) + actions.shape[1:], dtype=actions.dtype)
        orig_learner_indexes = self.learner_indexes()
        all_actions[orig_learner_indexes] = actions

        # Each distinct policy acts once on all of the slots it controls.
        for policy in set(p for p in self.policy_assignments if p is not None):
            policy_indexes = [policy == p for p in self.policy_assignments]
            if self.next_action_masks is not None:
                policy_masks = self.next_action_masks[policy_indexes]
            else:
                policy_masks = None
            all_actions[policy_indexes] = policy.act(
                self.next_obs[policy_indexes],
                deterministic=False,
                action_masks=policy_masks,
            )

        self.next_obs, rew, done, info = env.step(all_actions)
        self.next_action_masks = self.env.get_action_mask()

        rew = rew[orig_learner_indexes]
        info = [i for i, b in zip(info, orig_learner_indexes) if b]

        self.steps_since_swap += 1
        for idx in range(0, self.num_old_policies * 2, 2 * self.swap_window_size):
            if self.steps_since_swap[idx] > self.swap_steps:
                self.swap_policy(idx, self.swap_window_size)

        # Recomputed after the swaps; swap_policy only replaces slots that
        # already hold a policy, so the learner mask itself is not altered.
        new_learner_indexes = self.learner_indexes()
        return self.next_obs[new_learner_indexes], rew, done[new_learner_indexes], info

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env and return observations for learner slots."""
        self.next_obs = super().reset()
        self.next_action_masks = self.env.get_action_mask()
        return self.next_obs[self.learner_indexes()]