VPG playing CartPole-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/e8bc541d8b5e67bb4d3f2075282463fb61f5f2c6
9dc837c
import gym | |
import torch | |
import torch.nn as nn | |
from typing import Sequence, Type | |
from shared.module.feature_extractor import FeatureExtractor | |
from shared.module.module import mlp | |
class CriticHead(nn.Module):
    """Value head producing one scalar per input sample.

    The network is whatever ``mlp`` builds from ``hidden_sizes`` with a
    final width of 1 appended; ``forward`` then drops that trailing
    size-1 dimension.

    Args:
        hidden_sizes: Layer widths handed to ``mlp`` ahead of the final
            width of 1. NOTE(review): presumably the first entry has to
            match the width of the incoming features -- confirm against
            ``mlp``.
        activation: Activation module class forwarded to ``mlp``.
        init_layers_orthogonal: Forwarded to ``mlp`` unchanged.
    """

    def __init__(
        self,
        hidden_sizes: Sequence[int] = (32,),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
    ) -> None:
        super().__init__()
        # tuple() lets callers pass any sequence (e.g. a list) while still
        # allowing concatenation with the single-unit output width.
        layer_sizes = tuple(hidden_sizes) + (1,)
        self._fc = mlp(
            layer_sizes,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=1.0,
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Run ``obs`` through the network and squeeze the last dimension."""
        v = self._fc(obs)
        # squeeze(-1) only removes the last dimension when its size is 1.
        return v.squeeze(-1)