# DQN playing CartPole-v1, from https://github.com/sgoodfriend/rl-algo-impls/tree/1d4094fbcc9082de7f53f4348dd4c7c354152907
# ff8c6a7
import gym | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from gym.spaces import Box, Discrete | |
from stable_baselines3.common.preprocessing import get_flattened_obs_dim | |
from typing import Sequence, Type | |
class FeatureExtractor(nn.Module):
    """Turn raw observations into flat feature vectors for a downstream head.

    Supported observation spaces:
      * ``Box`` with a 3-D shape ``(channels, height, width)``: the Nature-DQN
        CNN followed by a linear projection to ``cnn_feature_dim`` features.
      * ``Box`` with a 1-D shape: observations are cast to float and flattened.
      * ``Discrete``: observations are one-hot encoded to ``obs_space.n``
        features.

    After construction ``self.out_dim`` holds the width of the produced
    feature vector, and ``self.preprocess`` holds the per-space preprocessing
    function applied in ``forward``.

    Args:
        obs_space: the environment's observation space.
        activation: activation module class instantiated after each CNN layer
            and after the CNN projection layer.
        init_layers_orthogonal: forwarded to ``layer_init`` for every
            convolution / linear layer of the CNN branch.
        cnn_feature_dim: output width of the CNN branch (ignored otherwise).

    Raises:
        ValueError: for a ``Box`` space that is neither 1-D nor 3-D.
        NotImplementedError: for any other observation space type.
    """

    def __init__(
        self,
        obs_space: gym.Space,
        activation: Type[nn.Module],
        init_layers_orthogonal: bool = False,
        cnn_feature_dim: int = 512,
    ) -> None:
        super().__init__()
        if isinstance(obs_space, Box):
            # Conv2D: (channels, height, width)
            if len(obs_space.shape) == 3:
                # CNN from DQN Nature paper: Mnih, Volodymyr, et al.
                # "Human-level control through deep reinforcement learning."
                # Nature 518.7540 (2015): 529-533.
                cnn = nn.Sequential(
                    layer_init(
                        nn.Conv2d(obs_space.shape[0], 32, kernel_size=8, stride=4),
                        init_layers_orthogonal,
                    ),
                    activation(),
                    layer_init(
                        nn.Conv2d(32, 64, kernel_size=4, stride=2),
                        init_layers_orthogonal,
                    ),
                    activation(),
                    layer_init(
                        nn.Conv2d(64, 64, kernel_size=3, stride=1),
                        init_layers_orthogonal,
                    ),
                    activation(),
                    nn.Flatten(),
                )

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # Give a single (C, H, W) observation a batch dimension and
                    # rescale by 1/255 (presumably uint8 pixels -- confirm).
                    if obs.dim() == 3:
                        obs = obs.unsqueeze(0)
                    return obs.float() / 255.0

                # Push one sampled observation through the CNN to discover the
                # flattened conv output width needed by the projection layer.
                with torch.no_grad():
                    cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
                self.preprocess = preprocess
                self.feature_extractor = nn.Sequential(
                    cnn,
                    layer_init(
                        nn.Linear(cnn_out.shape[1], cnn_feature_dim),
                        init_layers_orthogonal,
                    ),
                    activation(),
                )
                self.out_dim = cnn_feature_dim
            elif len(obs_space.shape) == 1:

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # Give a single (D,) observation a batch dimension.
                    if obs.dim() == 1:
                        obs = obs.unsqueeze(0)
                    return obs.float()

                self.preprocess = preprocess
                self.feature_extractor = nn.Flatten()
                self.out_dim = get_flattened_obs_dim(obs_space)
            else:
                raise ValueError(f"Unsupported observation space: {obs_space}")
        elif isinstance(obs_space, Discrete):
            num_classes = int(obs_space.n)

            def preprocess(obs: torch.Tensor) -> torch.Tensor:
                # F.one_hot only accepts int64 indices, so cast explicitly. A
                # single 0-dim observation gets a batch dimension (as in the
                # Box branches) so nn.Flatten(start_dim=1) has a dim to flatten.
                if obs.dim() == 0:
                    obs = obs.unsqueeze(0)
                return F.one_hot(obs.long(), num_classes).float()

            self.preprocess = preprocess
            self.feature_extractor = nn.Flatten()
            self.out_dim = obs_space.n
        else:
            raise NotImplementedError

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Preprocess ``obs`` for this observation space and return its features."""
        if self.preprocess:
            obs = self.preprocess(obs)
        return self.feature_extractor(obs)
def mlp( | |
layer_sizes: Sequence[int], | |
activation: Type[nn.Module], | |
output_activation: Type[nn.Module] = nn.Identity, | |
init_layers_orthogonal: bool = False, | |
final_layer_gain: float = np.sqrt(2), | |
) -> nn.Module: | |
layers = [] | |
for i in range(len(layer_sizes) - 2): | |
layers.append( | |
layer_init( | |
nn.Linear(layer_sizes[i], layer_sizes[i + 1]), init_layers_orthogonal | |
) | |
) | |
layers.append(activation()) | |
layers.append( | |
layer_init( | |
nn.Linear(layer_sizes[-2], layer_sizes[-1]), | |
init_layers_orthogonal, | |
std=final_layer_gain, | |
) | |
) | |
layers.append(output_activation()) | |
return nn.Sequential(*layers) | |
def layer_init(
    layer: nn.Module, init_layers_orthogonal: bool, std: float = np.sqrt(2)
) -> nn.Module:
    """Optionally apply orthogonal weight / zero bias initialization in place.

    Args:
        layer: a module with a ``weight`` tensor (e.g. ``nn.Linear`` or
            ``nn.Conv2d``) and optionally a ``bias`` tensor.
        init_layers_orthogonal: when False the layer is returned untouched.
        std: gain passed to ``nn.init.orthogonal_`` for the weight.

    Returns:
        The same ``layer`` object, so calls can be nested inline.
    """
    if not init_layers_orthogonal:
        return layer
    nn.init.orthogonal_(layer.weight, std)  # type: ignore
    # Layers built with bias=False have ``bias = None``; skip them rather than
    # crashing inside nn.init.constant_.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, 0.0)
    return layer