cuhksz-text2image / discriminatorEmbedding.py
bryandts's picture
Create discriminatorEmbedding.py
8610a1c verified
raw
history blame contribute delete
No virus
649 Bytes
import torch
import torch.nn as nn
class Embedding(nn.Module):
    """Project a text embedding and append it to a feature map as extra channels.

    The text vector goes through Linear -> BatchNorm1d(1) -> LeakyReLU(0.2),
    is tiled over the spatial grid of the feature map, and is concatenated to
    the feature map along the channel axis (dim 1).

    Args:
        size_in: Number of input features of the text embedding.
        size_out: Number of features (and thus extra channels) produced.
    """

    def __init__(self, size_in, size_out):
        super().__init__()
        self.text_embedding = nn.Sequential(
            nn.Linear(size_in, size_out),
            # BatchNorm1d(1) normalises a single channel, so the text input
            # must carry a channel axis of size 1: (batch, 1, size_in).
            nn.BatchNorm1d(1),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x, text):
        """Concatenate the projected text embedding to ``x`` channel-wise.

        Args:
            x: Feature map of shape (batch, channels, height, width).
            text: Text embedding of shape (batch, 1, size_in).

        Returns:
            Tensor of shape (batch, channels + size_out, height, width) whose
            last ``size_out`` channels hold the projected text embedding,
            constant across every spatial position.
        """
        embed_out = self.text_embedding(text)  # (batch, 1, size_out)

        # Tile to the spatial size of `x` rather than a hard-coded 4x4 so the
        # module works for any feature-map resolution.
        height, width = x.shape[2], x.shape[3]

        # repeat() with four factors on a 3-D tensor prepends an axis:
        # (1, batch, 1, size_out) -> (height, batch, width, size_out);
        # permute then yields (batch, size_out, height, width).
        embed_out_resize = embed_out.repeat(height, 1, width, 1).permute(1, 3, 0, 2)

        # Concatenate text embedding with the input feature map on channels.
        return torch.cat([x, embed_out_resize], 1)