sgoodfriend's picture
PPO playing QbertNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
a861765
from typing import Callable
from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
class SelfPlayCallback(Callback):
def __init__(
self,
policy: Policy,
policy_factory: Callable[[], Policy],
selfPlayWrapper: SelfPlayWrapper,
) -> None:
super().__init__()
self.policy = policy
self.policy_factory = policy_factory
self.selfPlayWrapper = selfPlayWrapper
self.checkpoint_policy()
def on_step(self, timesteps_elapsed: int = 1) -> bool:
super().on_step(timesteps_elapsed)
if (
self.timesteps_elapsed
>= self.last_checkpoint_step + self.selfPlayWrapper.save_steps
):
self.checkpoint_policy()
return True
def checkpoint_policy(self):
self.selfPlayWrapper.checkpoint_policy(
self.policy_factory().load_from(self.policy)
)
self.last_checkpoint_step = self.timesteps_elapsed