import torch
import torch.nn as nn
import torch.nn.functional as F


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        try:
            nn.init.xavier_uniform_(m.weight.data)
            m.bias.data.fill_(0)
        except AttributeError:
            print("Skipping initialization of ", classname)


class GatedActivation(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Split the channels in half and gate one half with the other
        x, y = x.chunk(2, dim=1)
        return torch.tanh(x) * torch.sigmoid(y)


class GatedMaskedConv1d(nn.Module):
    def __init__(self, mask_type, dim, kernel, residual, n_classes=10):
        super().__init__()
        assert kernel % 2 == 1, "Kernel size must be odd"
        self.mask_type = mask_type
        self.residual = residual

        self.class_cond_embedding = nn.Embedding(
            n_classes, 2 * dim
        )

        kernel_shp = kernel // 2 + 1  # ceil(n / 2): taps on the current and past positions only
        padding_shp = kernel // 2
        self.vert_stack = nn.Conv1d(
            dim, dim * 2,
            kernel_shp, 1, padding_shp
        )

        self.gate = GatedActivation()

        if self.residual:
            self.res = nn.Conv1d(dim, dim, 1)

    def make_causal(self):
        self.vert_stack.weight.data[:, :, -1].zero_()  # Mask the tap on the current position

    def forward(self, x, h):
        if self.mask_type == 'A':
            self.make_causal()

        h = self.class_cond_embedding(h)
        h_vert = self.vert_stack(x)
        h_vert = h_vert[:, :, :x.size(-1)]  # Crop to the input length so step t only sees inputs <= t
        out = self.gate(h_vert + h[:, :, None])

        if self.residual:
            out = self.res(out) + x

        return out
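
# Shape sketch for one GatedMaskedConv1d layer (illustrative only; assumes
# dim=64, kernel=3 and an input sequence of width W=8):
#   x                                                        (B, 64, 8)
#   vert_stack: Conv1d(64, 128, kernel_size=2, padding=1) -> (B, 128, 9)
#   crop to the first W steps (causal)                       (B, 128, 8)
#   + class embedding broadcast over time                    (B, 128, 8)
#   GatedActivation (chunk into two halves, tanh * sigmoid)  (B, 64, 8)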


class GatedPixelCNN(nn.Module):
    def __init__(self, input_dim=256, dim=64, n_layers=15, n_classes=10):
        super().__init__()
        self.dim = dim

        self.embedding_aud_mo = nn.Conv1d(512, dim, 1, 1, padding=0)
        self.fusion = nn.Conv1d(dim * 2, dim, 1, 1, padding=0)

        # Create embedding layer to embed input
        self.embedding = nn.Embedding(input_dim, dim)

        # Building the PixelCNN layer by layer
        self.layers = nn.ModuleList()

        # Initial block with Mask-A convolution
        # Rest with Mask-B convolutions
        for i in range(n_layers):
            mask_type = 'A' if i == 0 else 'B'
            kernel = 7 if i == 0 else 3
            residual = (i != 0)

            self.layers.append(
                GatedMaskedConv1d(mask_type, dim, kernel, residual, n_classes)
            )

        # Add the output layer
        self.output_conv = nn.Sequential(
            nn.Conv1d(dim, 512, 1),
            nn.ReLU(True),
            nn.Conv1d(512, input_dim, 1)
        )

        self.apply(weights_init)

        self.dp = nn.Dropout(0.1)

    def forward(self, x, label, c=None):
        # x holds discrete latent codes (B, W); embed them to (B, C, W)
        x = self.embedding(x).permute(0, 2, 1)

        for i, layer in enumerate(self.layers):
            if i == 1 and c is not None:
                # Project the 512-d audio features to (B, C, W) and fuse them
                # with the latent stream after the first masked layer
                c = self.embedding_aud_mo(c)
                x = self.fusion(torch.cat([x, c], dim=1))
            x = layer(x, label)

        return self.output_conv(x)

    def generate(self, label, shape=(8,), batch_size=64, aud_feat=None, pre_latents=None, pre_audio=None):
        # shape[0] is the number of new code positions to sample
        param = next(self.parameters())
        x = torch.zeros(
            (batch_size, *shape),
            dtype=torch.int64, device=param.device
        )

        if pre_latents is not None:
            # Prepend previously generated codes (and their audio features)
            # and only sample the new positions
            x = torch.cat([pre_latents, x], dim=1)
            aud_feat = torch.cat([pre_audio, aud_feat], dim=2)
            h0 = pre_latents.shape[1]
            h = h0 + shape[0]
        else:
            h0 = 0
            h = shape[0]

        for i in range(h0, h):
            if aud_feat is not None:
                logits = self.forward(x, label, aud_feat)
            else:
                logits = self.forward(x, label)
            probs = F.softmax(logits[:, :, i], -1)
            x.data[:, i].copy_(
                probs.multinomial(1).squeeze().data
            )

        return x[:, h0:h]
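

# Minimal usage sketch (not part of the model definition). Shapes and values
# below are assumptions for illustration: a batch of 2 sequences of 8 codes
# from a codebook of size 256, with 512-channel audio features aligned to them.
if __name__ == "__main__":
    model = GatedPixelCNN(input_dim=256, dim=64, n_layers=4, n_classes=10)

    codes = torch.randint(0, 256, (2, 8))   # (B, W) discrete latent codes
    labels = torch.randint(0, 10, (2,))     # (B,) class labels
    audio = torch.randn(2, 512, 8)          # (B, 512, W) audio features

    logits = model(codes, labels, audio)
    print(logits.shape)                     # (B, 256, W)

    samples = model.generate(labels, shape=(8,), batch_size=2, aud_feat=audio)
    print(samples.shape)                    # (B, 8)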