File size: 1,133 Bytes
946448b 1cde088 946448b 1cde088 946448b 1cde088 946448b 1cde088 946448b 1cde088 946448b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from typing import Optional, Union
import numpy as np
from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnv,
VecotarableWrapper,
find_wrapper,
)
class IncompleteArrayError(Exception):
    """Exception type declared for this module's callers.

    NOTE(review): not raised anywhere in the visible part of this file;
    presumably signals an array that is not fully populated -- confirm
    against the call sites before relying on that meaning.
    """
class SingleActionMaskWrapper(VecotarableWrapper):
    """Gather the ``action_mask`` of every sub-environment into one array.

    The wrapped env's ``unwrapped`` object must expose a non-empty ``envs``
    collection, and each entry's ``unwrapped`` object must expose a
    non-``None`` ``action_mask`` attribute.
    """

    def get_action_mask(self) -> Optional[np.ndarray]:
        """Return the sub-environments' action masks stacked as booleans.

        Returns:
            ``np.ndarray`` of dtype ``bool`` built from the list of
            per-sub-env ``action_mask`` values, in ``envs`` order. The
            ``Optional`` annotation is kept for interface compatibility;
            this implementation either returns an array or raises.

        Raises:
            AssertionError: If the wrapped env has no (or an empty) ``envs``
                attribute, or if any sub-environment has no ``action_mask``.
        """
        envs = getattr(self.env.unwrapped, "envs", None)  # type: ignore
        # Explicit raises instead of ``assert`` statements: asserts are
        # stripped under ``python -O``, which would let a missing mask be
        # silently cast to ``False`` by ``np.array(..., dtype=bool)``.
        # AssertionError is kept so existing callers see the same type.
        if not envs:
            raise AssertionError(
                f"{self.__class__.__name__} expects to wrap synchronous vectorized env"
            )
        masks = [getattr(e.unwrapped, "action_mask", None) for e in envs]
        missing = [idx for idx, mask in enumerate(masks) if mask is None]
        if missing:
            raise AssertionError(
                f"{self.__class__.__name__}: sub-environments at indices "
                f"{missing} do not provide an action_mask"
            )
        return np.array(masks, dtype=np.bool_)
class MicrortsMaskWrapper(VecotarableWrapper):
    """Expose the wrapped env's own action mask as a boolean array."""

    def get_action_mask(self) -> np.ndarray:
        """Return ``self.env.get_action_mask()`` converted to ``bool`` dtype."""
        raw_mask = self.env.get_action_mask()  # type: ignore
        return raw_mask.astype(bool)
def find_action_masker(
    env: VecEnv,
) -> Optional[Union[SingleActionMaskWrapper, MicrortsMaskWrapper]]:
    """Locate an action-mask wrapper for ``env`` via ``find_wrapper``.

    ``SingleActionMaskWrapper`` is looked up first; ``MicrortsMaskWrapper``
    is only looked up when the first lookup yields a falsy result.

    Args:
        env: Environment passed to ``find_wrapper`` for both lookups.

    Returns:
        The first truthy lookup result, otherwise whatever the
        ``MicrortsMaskWrapper`` lookup returned.
    """
    single_masker = find_wrapper(env, SingleActionMaskWrapper)
    if single_masker:
        return single_masker
    return find_wrapper(env, MicrortsMaskWrapper)
|