VPG playing CartPole-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/e8bc541d8b5e67bb4d3f2075282463fb61f5f2c6
9dc837c
from stable_baselines3.common.vec_env.base_vec_env import VecEnv | |
from typing import Optional, Sequence | |
from gym.spaces import Box, Discrete | |
from shared.policy.on_policy import ActorCritic, default_hidden_sizes | |
class PPOActorCritic(ActorCritic):
    """``ActorCritic`` subclass that fills in unspecified hidden-layer sizes.

    Any hidden-size argument left as ``None`` is replaced with the result of
    ``default_hidden_sizes(env.observation_space)`` before the base class is
    initialised; explicitly supplied values are passed through unchanged.
    """

    def __init__(
        self,
        env: VecEnv,
        pi_hidden_sizes: Optional[Sequence[int]] = None,
        v_hidden_sizes: Optional[Sequence[int]] = None,
        **kwargs,
    ) -> None:
        """Resolve default hidden sizes and delegate to ``ActorCritic``.

        Args:
            env: Vectorised environment; its ``observation_space`` is only
                read when a hidden-size argument has to be defaulted.
            pi_hidden_sizes: Forwarded as the second positional argument of
                ``ActorCritic.__init__`` (presumably the policy network's
                layer widths -- confirm against ``ActorCritic``). ``None``
                selects the default.
            v_hidden_sizes: Forwarded as the third positional argument of
                ``ActorCritic.__init__`` (presumably the value network's
                layer widths -- confirm against ``ActorCritic``). ``None``
                selects the default.
            **kwargs: Passed through untouched to ``ActorCritic.__init__``.
        """
        # Each default is resolved by its own call (pi first, then v), and
        # only when the caller did not provide a value.
        if pi_hidden_sizes is None:
            pi_hidden_sizes = default_hidden_sizes(env.observation_space)
        if v_hidden_sizes is None:
            v_hidden_sizes = default_hidden_sizes(env.observation_space)

        super().__init__(env, pi_hidden_sizes, v_hidden_sizes, **kwargs)