DQN playing CartPole-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
2a8bf2e
from abc import ABC, abstractmethod | |
from typing import NamedTuple, Optional, Tuple | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch.distributions import Distribution | |
class PiForward(NamedTuple):
    """Bundle returned by a policy forward pass.

    Holds the action distribution together with two optional tensors.
    As populated by ``pi_forward`` in this module, ``logp_a`` and
    ``entropy`` are both ``None`` when no actions were supplied.
    """

    # Action distribution produced by the policy.
    pi: Distribution
    # Result of ``pi.log_prob(actions)``; None when no actions were given.
    logp_a: Optional[torch.Tensor]
    # Result of ``pi.entropy()``; None when no actions were given.
    entropy: Optional[torch.Tensor]
class Actor(nn.Module, ABC):
    """Abstract base class for policy ("actor") modules.

    Concrete subclasses must implement ``forward`` and ``action_shape``.
    Both are declared abstract so that instantiating ``Actor`` itself, or a
    subclass that forgot to override one of them, fails immediately with
    ``TypeError`` instead of silently returning ``None`` from a stub.
    """

    @abstractmethod
    def forward(
        self,
        obs: torch.Tensor,
        actions: Optional[torch.Tensor] = None,
        action_masks: Optional[torch.Tensor] = None,
    ) -> PiForward:
        """Run the policy on ``obs`` and return a ``PiForward``.

        Args:
            obs: Observation tensor fed to the policy.
            actions: Optional actions tensor. Presumably used to fill the
                log-probability/entropy fields of the result; confirm
                against the concrete subclasses.
            action_masks: Optional mask tensor. Its exact semantics are
                defined by subclasses and are not visible here.

        Returns:
            A ``PiForward`` describing the action distribution.
        """
        ...

    def sample_weights(self, batch_size: int = 1) -> None:
        """Optional hook; the base implementation does nothing.

        Subclasses may override this. NOTE(review): the name suggests
        resampling of stochastic weights per batch; verify against the
        overriding subclasses.
        """
        pass

    @abstractmethod
    def action_shape(self) -> Tuple[int, ...]:
        """Return the action shape as a tuple of ints.

        NOTE(review): declared as a plain method here; if subclasses expose
        it as a property instead, confirm how callers access it.
        """
        ...
def pi_forward(
    distribution: Distribution, actions: Optional[torch.Tensor] = None
) -> PiForward:
    """Wrap ``distribution`` in a ``PiForward``.

    Args:
        distribution: The action distribution to wrap.
        actions: Optional actions to evaluate under ``distribution``.

    Returns:
        ``PiForward(distribution, None, None)`` when ``actions`` is ``None``;
        otherwise a ``PiForward`` whose ``logp_a`` is
        ``distribution.log_prob(actions)`` and whose ``entropy`` is
        ``distribution.entropy()``.
    """
    # Without actions there is nothing to evaluate; skip both computations.
    if actions is None:
        return PiForward(distribution, None, None)

    log_prob = distribution.log_prob(actions)
    dist_entropy = distribution.entropy()
    return PiForward(distribution, log_prob, dist_entropy)