vpg-CartPole-v1 / shared /algorithm.py
sgoodfriend's picture
VPG playing CartPole-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/e8bc541d8b5e67bb4d3f2075282463fb61f5f2c6
9dc837c
raw
history blame
933 Bytes
import gym
import torch
from abc import ABC, abstractmethod
from stable_baselines3.common.vec_env.base_vec_env import VecEnv
from torch.utils.tensorboard.writer import SummaryWriter
from typing import List, Optional, TypeVar
from shared.callbacks.callback import Callback
from shared.policy.policy import Policy
from shared.stats import EpisodesStats
# Type variable bound to "Algorithm"; used to annotate both `self` and the
# return value of `Algorithm.learn`, so a type checker reports the caller's
# own subclass type instead of the base class.
AlgorithmSelf = TypeVar("AlgorithmSelf", bound="Algorithm")
class Algorithm(ABC):
    """Abstract base class for learning algorithms.

    Holds the four collaborators handed to the constructor (policy,
    vectorized environment, torch device, TensorBoard writer) and declares
    the abstract ``learn`` entry point. Both ``__init__`` and ``learn`` are
    abstract, so concrete subclasses must override each of them.
    """

    @abstractmethod
    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        **kwargs,
    ) -> None:
        """Store the shared collaborators on the instance.

        Args:
            policy: Stored as ``self.policy``.
            env: Stored as ``self.env``.
            device: Stored as ``self.device``.
            tb_writer: Stored as ``self.tb_writer``.
            **kwargs: Accepted but not used by this base implementation.
        """
        super().__init__()
        self.policy = policy
        self.env = env
        self.device = device
        self.tb_writer = tb_writer

    @abstractmethod
    def learn(
        self: AlgorithmSelf,
        total_timesteps: int,
        callback: Optional[Callback] = None,
    ) -> AlgorithmSelf:
        """Run training; must be implemented by subclasses.

        Args:
            total_timesteps: Training budget (presumably counted in
                environment steps; confirm against the implementations).
            callback: Optional callback object; defaults to ``None``.

        Returns:
            An instance of the caller's own type, per the annotation
            (presumably ``self``, to allow chaining; confirm against the
            implementations).
        """
        ...