File size: 5,546 Bytes
2a8bf2e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import copy
import random
from collections import deque
from typing import Any, Deque, Dict, List, Optional

import numpy as np

from rl_algo_impls.runner.config import Config
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
from rl_algo_impls.wrappers.vectorable_wrapper import (
    VecEnvObs,
    VecEnvStepReturn,
    VecotarableWrapper,
)


class SelfPlayWrapper(VecotarableWrapper):
    """Vectorized-env wrapper that plays the learner against other policies.

    Every slot of the wrapped env has an entry in ``policy_assignments``. A
    ``None`` entry means the slot is driven by the learner (the caller of
    ``step``); any other entry is a policy that this wrapper queries for
    actions itself. Callers therefore only see observations, rewards, dones,
    infos and action masks for the learner slots.

    Opponents come from two sources:

    * checkpointed copies of the learner, added through ``checkpoint_policy``
      and periodically re-sampled by ``swap_policy``; these occupy the first
      ``2 * num_old_policies`` slots, one slot per consecutive pair
      (presumably the two players of one match; confirm against the wrapped
      env);
    * pre-trained bots loaded from ``selfplay_bots``, placed in the slots
      that follow.

    Args:
        env: Wrapped vectorized env. Must expose ``num_envs`` and
            ``get_action_mask()``.
        config: Run configuration, used to build the bot policies.
        num_old_policies: Number of slots driven by checkpointed policies.
            Must be even and a multiple of ``swap_window_size``.
        save_steps: Stored on the instance; not otherwise used by this class.
        swap_steps: A window of old-policy slots is re-sampled once its step
            counter exceeds this value.
        window: Maximum number of checkpointed policies kept; the oldest is
            dropped when a new one is appended to a full deque.
        swap_window_size: Number of slot pairs that are swapped together.
        selfplay_bots: Mapping of model path to the number of slots that bot
            should drive.
        bot_always_player_2: If true, opponents always take the second slot
            of a pair instead of alternating between the two.
    """

    next_obs: VecEnvObs
    next_action_masks: Optional[np.ndarray]

    def __init__(
        self,
        env,
        config: Config,
        num_old_policies: int = 0,
        save_steps: int = 20_000,
        swap_steps: int = 10_000,
        window: int = 10,
        swap_window_size: int = 2,
        selfplay_bots: Optional[Dict[str, Any]] = None,
        bot_always_player_2: bool = False,
    ) -> None:
        super().__init__(env)
        assert num_old_policies % 2 == 0, "num_old_policies must be even"
        assert (
            num_old_policies % swap_window_size == 0
        ), "num_old_policies must be a multiple of swap_window_size"

        self.config = config
        self.num_old_policies = num_old_policies
        self.save_steps = save_steps
        self.swap_steps = swap_steps
        self.swap_window_size = swap_window_size
        self.selfplay_bots = selfplay_bots
        self.bot_always_player_2 = bot_always_player_2

        # Rolling history of checkpointed learner copies to sample from.
        self.policies: Deque[Policy] = deque(maxlen=window)
        # One entry per wrapped-env slot; None marks a learner-driven slot.
        self.policy_assignments: List[Optional[Policy]] = [None] * env.num_envs
        self.steps_since_swap = np.zeros(env.num_envs)

        # Loaded bot policies keyed by the path they were loaded from.
        self.selfplay_policies: Dict[str, Policy] = {}

        # Slots handed to opponents are hidden from the caller.
        self.num_envs = env.num_envs - num_old_policies

        if self.selfplay_bots:
            self.num_envs -= sum(self.selfplay_bots.values())
            self.initialize_selfplay_bots()

    def get_action_mask(self) -> Optional[np.ndarray]:
        """Return the wrapped env's action mask restricted to learner slots.

        Returns:
            The masked rows for the learner slots, or ``None`` when the
            wrapped env reports no action mask.
        """
        action_mask = self.env.get_action_mask()
        if action_mask is None:
            return None
        return action_mask[self.learner_indexes()]

    def learner_indexes(self) -> List[bool]:
        """Return a boolean mask, one entry per slot, true for learner slots."""
        return [p is None for p in self.policy_assignments]

    def checkpoint_policy(self, copied_policy: Policy) -> None:
        """Record a frozen learner copy as a candidate opponent.

        The policy is switched out of training mode and appended to the
        history. If no old-policy slot has been assigned yet, the policy is
        also installed in all of them.

        Args:
            copied_policy: Copy of the learner policy; the caller is expected
                to have copied it, since it is mutated via ``train(False)``.
        """
        copied_policy.train(False)
        self.policies.append(copied_policy)

        old_policy_slots = self.policy_assignments[: 2 * self.num_old_policies]
        if all(p is None for p in old_policy_slots):
            for i in range(self.num_old_policies):
                # Switch between player 1 and 2 unless pinned to player 2.
                player_offset = 1 if self.bot_always_player_2 else i % 2
                self.policy_assignments[2 * i + player_offset] = copied_policy

    def swap_policy(self, idx: int, swap_window_size: int = 1) -> None:
        """Replace the opponents in a window of slots with a random checkpoint.

        Args:
            idx: Slot index inside the window; rounded down to an even index.
            swap_window_size: Number of slot pairs covered by the window.

        Raises:
            IndexError: If no policy has been checkpointed yet, because
                ``random.choice`` is called on an empty history.
        """
        policy = random.choice(self.policies)
        idx = idx // 2 * 2
        window_end = idx + swap_window_size * 2
        for slot in range(idx, window_end):
            # Only opponent slots are replaced; learner slots stay None.
            if self.policy_assignments[slot] is not None:
                self.policy_assignments[slot] = policy
        self.steps_since_swap[idx:window_end] = 0

    def initialize_selfplay_bots(self) -> None:
        """Load the configured bot policies and assign them to env slots.

        Bots are placed after the ``2 * num_old_policies`` slots reserved for
        checkpointed policies. A bot configured with ``n`` takes one slot in
        each of the next ``n`` pairs. Does nothing when no bots are configured.
        """
        if not self.selfplay_bots:
            return
        # Imported here rather than at module level, presumably to avoid a
        # circular import with the runner package (unconfirmed).
        from rl_algo_impls.runner.running_utils import get_device, make_policy

        env = self.env  # type: ignore
        device = get_device(self.config, env)
        start_idx = 2 * self.num_old_policies
        for model_path, n in self.selfplay_bots.items():
            policy = make_policy(
                self.config.algo,
                env,
                device,
                load_path=model_path,
                **self.config.policy_hyperparams,
            ).eval()
            # Key by the actual path so each bot keeps its own entry.
            self.selfplay_policies[model_path] = policy
            for idx in range(start_idx, start_idx + 2 * n, 2):
                # Alternate the bot between the two slots of successive
                # pairs unless it is pinned to the second slot.
                bot_idx = (
                    (idx + 1) if self.bot_always_player_2 else (idx + idx // 2 % 2)
                )
                self.policy_assignments[bot_idx] = policy
            start_idx += 2 * n

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """Step the wrapped env with learner actions plus opponent actions.

        Args:
            actions: Actions for the learner slots only, in slot order.

        Returns:
            Observations, rewards, dones and infos for the learner slots.
        """
        env = self.env  # type: ignore
        all_actions = np.zeros((env.num_envs,) + actions.shape[1:], dtype=actions.dtype)
        orig_learner_indexes = self.learner_indexes()

        all_actions[orig_learner_indexes] = actions
        # Query each distinct opponent once for all of the slots it drives.
        for policy in {p for p in self.policy_assignments if p is not None}:
            policy_indexes = [policy == p for p in self.policy_assignments]
            all_actions[policy_indexes] = policy.act(
                self.next_obs[policy_indexes],
                deterministic=False,
                action_masks=self.next_action_masks[policy_indexes]
                if self.next_action_masks is not None
                else None,
            )
        # Cache the full observation and mask: opponents need them next step.
        self.next_obs, rew, done, info = env.step(all_actions)
        self.next_action_masks = self.env.get_action_mask()

        rew = rew[orig_learner_indexes]
        info = [i for i, b in zip(info, orig_learner_indexes) if b]

        self.steps_since_swap += 1
        # The first slot of each window stands in for the whole window.
        for idx in range(0, self.num_old_policies * 2, 2 * self.swap_window_size):
            if self.steps_since_swap[idx] > self.swap_steps:
                self.swap_policy(idx, self.swap_window_size)

        new_learner_indexes = self.learner_indexes()
        return self.next_obs[new_learner_indexes], rew, done[new_learner_indexes], info

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env and return the learner slots' observations."""
        self.next_obs = super().reset()
        self.next_action_masks = self.env.get_action_mask()
        return self.next_obs[self.learner_indexes()]