lmzjms's picture
Upload 16 files
62bee25
raw
history blame
998 Bytes
import torch
import torch.nn as nn
class Film(nn.Module):
def __init__(self, channels, cond_embedding_dim):
super(Film, self).__init__()
self.linear = nn.Sequential(
nn.Linear(cond_embedding_dim, channels * 2),
nn.ReLU(inplace=True),
nn.Linear(channels * 2, channels),
nn.ReLU(inplace=True)
)
def forward(self, data, cond_vec):
"""
:param data: [batchsize, channels, samples] or [batchsize, channels, T, F] or [batchsize, channels, F, T]
:param cond_vec: [batchsize, cond_embedding_dim]
:return:
"""
bias = self.linear(cond_vec) # [batchsize, channels]
if len(list(data.size())) == 3:
data = data + bias[..., None]
elif len(list(data.size())) == 4:
data = data + bias[..., None, None]
else:
print("Warning: The size of input tensor,", data.size(), "is not correct. Film is not working.")
return data