import torch


class StyleTransferController(torch.nn.Module):
    def __init__(
        self,
        num_control_params,
        edim,
        hidden_dim=256,
        agg_method="mlp",
    ):
        """Plugin parameter controller module to map from input to target style.

        Args:
            num_control_params (int): Number of plugin parameters to predict.
            edim (int): Size of the encoder representations.
            hidden_dim (int, optional): Hidden size of the 3-layer parameter predictor MLP. Default: 256
            agg_method (str, optional): Input/reference embedding aggregation method ["conv", "linear", or "mlp"]. Default: "mlp"
        """
        super().__init__()
        self.num_control_params = num_control_params
        self.edim = edim
        self.hidden_dim = hidden_dim
        self.agg_method = agg_method

        if agg_method == "conv":
            # treat the two embeddings as channels and mix them with a 1D conv
            self.agg = torch.nn.Conv1d(
                2,
                1,
                kernel_size=129,
                stride=1,
                padding="same",
                bias=False,
            )
            mlp_in_dim = edim
        elif agg_method == "linear":
            # project the concatenated embeddings back down to edim
            self.agg = torch.nn.Linear(edim * 2, edim)
            mlp_in_dim = edim
        elif agg_method == "mlp":
            # no separate aggregation; the MLP consumes the concatenation directly
            self.agg = None
            mlp_in_dim = edim * 2
        else:
            raise ValueError(f"Invalid agg_method = {self.agg_method}.")

        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(mlp_in_dim, hidden_dim),
            torch.nn.LeakyReLU(0.01),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.LeakyReLU(0.01),
            torch.nn.Linear(hidden_dim, num_control_params),
            torch.nn.Sigmoid(),  # normalize between 0 and 1
        )

    def forward(self, e_x, e_y, z=None):
        """Forward pass to generate plugin parameters.

        Args:
            e_x (tensor): Input signal embedding of shape (batch, edim)
            e_y (tensor): Target signal embedding of shape (batch, edim)
            z (tensor, optional): Unused in this module. Default: None

        Returns:
            p (tensor): Estimated control parameters of shape (batch, num_control_params)
        """
        # aggregate the two embeddings with a learnable projection
        if self.agg_method == "conv":
            e_xy = torch.stack((e_x, e_y), dim=1)  # stack on channel dim
            e_xy = self.agg(e_xy)
        elif self.agg_method == "linear":
            e_xy = torch.cat((e_x, e_y), dim=-1)  # concat on embed dim
            e_xy = self.agg(e_xy)
        else:
            e_xy = torch.cat((e_x, e_y), dim=-1)  # concat on embed dim

        # pass through MLP to project to control parameters
        p = self.mlp(e_xy.squeeze(1))

        return p
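

# Minimal usage sketch (not part of the original module): exercises each
# aggregation method with random embeddings and checks the output shape and
# range. The sizes below (edim=512, 24 control parameters, batch of 4) are
# illustrative assumptions, not values taken from the source.
if __name__ == "__main__":
    batch_size = 4
    edim = 512  # assumed encoder embedding size
    num_control_params = 24  # assumed number of plugin parameters

    e_x = torch.randn(batch_size, edim)  # input signal embedding
    e_y = torch.randn(batch_size, edim)  # target (reference) signal embedding

    for agg_method in ["conv", "linear", "mlp"]:
        controller = StyleTransferController(
            num_control_params,
            edim,
            agg_method=agg_method,
        )
        p = controller(e_x, e_y)
        # the final Sigmoid guarantees every parameter lies in [0, 1]
        assert p.shape == (batch_size, num_control_params)
        assert p.min() >= 0.0 and p.max() <= 1.0
        print(f"{agg_method}: p.shape = {tuple(p.shape)}")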