# vpg-CartPole-v1 / ppo/policy.py
# Hosting-page header lists user: sgoodfriend
# VPG playing CartPole-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/e8bc541d8b5e67bb4d3f2075282463fb61f5f2c6
# Commit 9dc837c; file size 916 bytes.
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from typing import Optional, Sequence
from gym.spaces import Box, Discrete
from shared.policy.on_policy import ActorCritic, default_hidden_sizes
class PPOActorCritic(ActorCritic):
    """``ActorCritic`` subclass that fills in missing hidden-layer sizes.

    The only behavior added on top of the base class is in the constructor:
    any hidden-size argument left as ``None`` is replaced by the result of
    ``default_hidden_sizes(env.observation_space)`` before delegating to
    ``ActorCritic.__init__``.
    """

    def __init__(
        self,
        env: VecEnv,
        pi_hidden_sizes: Optional[Sequence[int]] = None,
        v_hidden_sizes: Optional[Sequence[int]] = None,
        **kwargs,
    ) -> None:
        """Resolve hidden sizes and initialize the base ``ActorCritic``.

        Args:
            env: Vectorized environment; its ``observation_space`` is read
                only when a default has to be computed.
            pi_hidden_sizes: Hidden-layer sizes forwarded as the second
                positional argument of the base constructor (presumably the
                policy network -- confirm against ``ActorCritic``). ``None``
                selects the default for the observation space.
            v_hidden_sizes: Hidden-layer sizes forwarded as the third
                positional argument of the base constructor (presumably the
                value network -- confirm against ``ActorCritic``). ``None``
                selects the default for the observation space.
            **kwargs: Passed through unchanged to ``ActorCritic.__init__``.
        """
        # Compare against None explicitly so that a caller-supplied empty
        # sequence is kept as-is rather than replaced by the default.
        actor_sizes = pi_hidden_sizes
        if actor_sizes is None:
            actor_sizes = default_hidden_sizes(env.observation_space)

        critic_sizes = v_hidden_sizes
        if critic_sizes is None:
            critic_sizes = default_hidden_sizes(env.observation_space)

        super().__init__(env, actor_sizes, critic_sizes, **kwargs)