# Source: cuhksz-text2image / discriminatorModel.py
# Revision: "Update discriminatorModel.py" (bryandts, commit 98688ba, verified), 2.24 kB
import torch
import torch.nn as nn
class Embedding(nn.Module):
    """Project a text embedding and append it to an image feature map.

    The text tensor is passed through ``Linear -> BatchNorm1d(1) ->
    LeakyReLU(0.2)``, tiled over the spatial grid of the feature map and
    concatenated to that map along the channel axis.

    Args:
        size_in: Number of features of the incoming text embedding.
        size_out: Number of features after projection; this is also the
            number of channels appended to the feature map.
    """

    def __init__(self, size_in, size_out):
        super(Embedding, self).__init__()
        self.text_embedding = nn.Sequential(
            nn.Linear(size_in, size_out),
            # num_features=1: the projected text must carry a singleton
            # channel axis for this layer to accept it.
            nn.BatchNorm1d(1),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x, text):
        """Concatenate the projected text embedding onto ``x``.

        Args:
            x: Feature map of shape ``(batch, channels, height, width)``.
            text: Text embedding; expected to be ``(batch, 1, size_in)``,
                which is the layout ``BatchNorm1d(1)`` and the tiling below
                are consistent with.

        Returns:
            Tensor of shape ``(batch, channels + size_out, height, width)``.
        """
        embed_out = self.text_embedding(text)

        # Tile the embedding over the spatial grid of ``x``. The grid size is
        # taken from ``x`` itself rather than fixed at 4x4, so feature maps
        # of any resolution can be conditioned.
        height, width = x.size(2), x.size(3)
        # (batch, 1, D) --repeat--> (height, batch, width, D)
        #               --permute-> (batch, D, height, width)
        embed_out_resize = embed_out.repeat(height, 1, width, 1).permute(1, 3, 0, 2)

        # Append the text channels after the image feature channels.
        out = torch.cat([x, embed_out_resize], 1)
        return out
# The Discriminator model
class Discriminator(nn.Module):
    """Text-conditioned convolutional discriminator.

    Five stride-2 convolution stages turn the image into a 512-channel
    feature map. The projected text embedding is concatenated onto that map
    and a 4x4 convolution followed by a sigmoid produces the score.

    Args:
        channels: Number of channels of the input image.
        embed_dim: Size of the raw text embedding passed to ``Embedding``.
        embed_out_dim: Size of the projected text embedding, i.e. the number
            of extra channels seen by the output convolution.
    """

    def __init__(self, channels, embed_dim=1024, embed_out_dim=128):
        super(Discriminator, self).__init__()
        self.channels = channels
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        # Feature extractor: every stage uses kernel 4 / stride 2 / padding 1,
        # which halves an even spatial size (five stages: a factor of 32).
        self.model = nn.Sequential(
            *self._create_layer(self.channels, 32, 4, 2, 1, normalize=False),
            *self._create_layer(32, 64, 4, 2, 1),
            *self._create_layer(64, 128, 4, 2, 1),
            *self._create_layer(128, 256, 4, 2, 1),
            *self._create_layer(256, 512, 4, 2, 1),
        )

        # Projects the text embedding and concatenates it onto the features.
        self.text_embedding = Embedding(self.embed_dim, self.embed_out_dim)

        # Unpadded 4x4 convolution to one channel, squashed to (0, 1).
        self.output = nn.Sequential(
            nn.Conv2d(512 + self.embed_out_dim, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, normalize=True):
        """Build one stage: ``Conv2d [-> BatchNorm2d] -> LeakyReLU(0.2)``.

        Args:
            size_in: Input channels of the convolution.
            size_out: Output channels of the convolution.
            kernel_size: Convolution kernel size.
            stride: Convolution stride.
            padding: Convolution padding.
            normalize: Whether to insert ``BatchNorm2d`` after the convolution.

        Returns:
            List of modules, in application order.
        """
        layers = [
            nn.Conv2d(size_in, size_out, kernel_size=kernel_size, stride=stride, padding=padding)
        ]
        if normalize:
            layers.append(nn.BatchNorm2d(size_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, x, text):
        """Score images against their text embeddings.

        Args:
            x: Batch of images, ``(batch, channels, height, width)``.
            text: Text embeddings, forwarded unchanged to ``Embedding``.

        Returns:
            Tuple ``(scores, features)``: ``scores`` is the sigmoid output
            with its singleton channel/spatial axes removed, ``features`` is
            the feature map produced by ``self.model``.
        """
        x_out = self.model(x)  # image features
        out = self.text_embedding(x_out, text)  # features + text channels
        out = self.output(out)  # (batch, 1, h, w) scores

        # Remove only the channel and spatial axes when they are singleton.
        # A bare ``squeeze()`` would also drop the batch axis for a batch of
        # one, yielding a 0-d tensor that no longer lines up with per-sample
        # targets; for larger batches the result is the same as before.
        scores = out.squeeze(3).squeeze(2).squeeze(1)
        return scores, x_out