import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions

# matplotlib is only used for figure configuration here; none of the model code
# below needs it, so don't make importing this module fail without it.
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None
else:
    plt.rcParams['figure.dpi'] = 200

torch.manual_seed(0)

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_activation(activation):
    """Return the activation callable registered under ``activation``.

    Args:
        activation: one of 'tanh', 'relu', 'mish', 'sigmoid', 'leakyrelu', 'exp'.

    Returns:
        A callable mapping a tensor to a tensor of the same shape.

    Raises:
        ValueError: if ``activation`` is not a recognised name.
    """
    activations = {
        'tanh': torch.tanh,
        'relu': F.relu,
        'mish': F.mish,
        'sigmoid': torch.sigmoid,
        'leakyrelu': F.leaky_relu,  # default negative_slope
        'exp': torch.exp,
    }
    try:
        return activations[activation]
    except KeyError:
        raise ValueError(
            f"Unknown activation {activation!r}; expected one of {sorted(activations)}"
        ) from None


class SimpleNet(nn.Module):
    """Fully connected feed-forward network.

    Layer sizes are ``[input_dim, *hidden_dims, output_dim]``. Every linear
    layer except the last is followed by dropout and the hidden activation;
    an optional ``final_activ`` is applied to the network output.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, activation, dropout,
                 final_activ=None):
        """Build the layer stack.

        Args:
            input_dim: number of input features.
            hidden_dims: sequence of hidden layer widths (may be empty, in which
                case the network is a single linear layer).
            output_dim: number of output features.
            activation: name passed to ``get_activation`` for hidden layers.
            dropout: dropout probability used for every ``nn.Dropout``.
            final_activ: optional activation name applied to the output.

        Raises:
            ValueError: if an activation name is not recognised.
        """
        super().__init__()
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.output_dim = output_dim

        # list(...) so tuples / other sequences of widths are accepted too.
        dims = [input_dim, *list(hidden_dims), output_dim]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            # One dropout per linear layer; the last one is never applied in
            # forward() but is kept so len(self.dropouts) == n_layers.
            self.dropouts.append(nn.Dropout(dropout))

        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        self.layer_range = range(self.n_layers)

        if final_activ is not None:
            self.final_activ = get_activation(final_activ)
            self.use_final_activ = True
        else:
            self.use_final_activ = False

    def forward(self, x):
        """Apply the network to ``x`` (features in the last dimension)."""
        last = self.n_layers - 1
        for i_layer, (layer, dropout) in enumerate(zip(self.linears, self.dropouts)):
            x = layer(x)
            if i_layer != last:
                # NOTE(review): dropout is applied *before* the activation. For
                # activations with f(0) != 0 ('sigmoid' -> 0.5, 'exp' -> 1) a
                # dropped unit is therefore not zeroed. The conventional order
                # is activation then dropout; kept as-is to preserve training
                # behaviour — confirm intent.
                x = self.activation(dropout(x))
        if self.use_final_activ:
            x = self.final_activ(x)
        return x