Spaces:
Sleeping
Sleeping
File size: 649 Bytes
8610a1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
import torch
import torch.nn as nn
class Embedding(nn.Module):
    """Project a text embedding and tile it across a feature map's spatial grid.

    The projected embedding is replicated at every spatial position of ``x``
    and concatenated to ``x`` along the channel dimension.
    """

    def __init__(self, size_in, size_out):
        """Build the projection head.

        Args:
            size_in: Size of the last dimension of the incoming ``text`` tensor.
            size_out: Projected size; this many channels are appended to ``x``.
        """
        super().__init__()
        # BatchNorm1d(1) normalises a single channel, so the projection must be
        # applied to text laid out as (batch, 1, size_in) -> (batch, 1, size_out).
        self.text_embedding = nn.Sequential(
            nn.Linear(size_in, size_out),
            nn.BatchNorm1d(1),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x, text):
        """Concatenate the projected, spatially tiled ``text`` onto ``x``.

        Args:
            x: Feature map of shape (batch, channels, height, width).
            text: Text embedding of shape (batch, 1, size_in) or (batch, size_in).

        Returns:
            Tensor of shape (batch, channels + size_out, height, width) whose
            extra channels hold the projected embedding, constant over space.
        """
        if text.dim() == 2:
            # Insert the singleton channel dimension BatchNorm1d(1) expects.
            text = text.unsqueeze(1)
        embed_out = self.text_embedding(text)  # (batch, 1, size_out)
        batch, _, size_out = embed_out.shape
        height, width = x.shape[-2:]
        # Replicate the embedding over x's spatial grid (was hard-coded to 4x4).
        embed_out_resize = embed_out.reshape(batch, size_out, 1, 1).expand(
            -1, -1, height, width
        )
        out = torch.cat([x, embed_out_resize], 1)  # concat along channels
        return out