import torch
import torch.nn as nn


# The Discriminator model
class Discriminator(nn.Module):
    def __init__(self, channels, embed_dim=1024, embed_out_dim=128):
        super(Discriminator, self).__init__()
        self.channels = channels
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        # Discriminator architecture: a stack of strided convolutions, each
        # halving the spatial resolution while doubling the channel count
        self.model = nn.Sequential(
            *self._create_layer(self.channels, 32, 4, 2, 1, normalize=False),
            *self._create_layer(32, 64, 4, 2, 1),
            *self._create_layer(64, 128, 4, 2, 1),
            *self._create_layer(128, 256, 4, 2, 1),
            *self._create_layer(256, 512, 4, 2, 1)
        )
        self.text_embedding = Embedding(self.embed_dim, self.embed_out_dim)  # Text embedding module
        self.output = nn.Sequential(
            nn.Conv2d(512 + self.embed_out_dim, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, normalize=True):
        layers = [nn.Conv2d(size_in, size_out, kernel_size=kernel_size, stride=stride, padding=padding)]
        if normalize:
            layers.append(nn.BatchNorm2d(size_out))
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def forward(self, x, text):
        x_out = self.model(x)                   # Extract features from the input image
        out = self.text_embedding(x_out, text)  # Fuse the text embedding with the image features
        out = self.output(out)                  # Final discriminator score
        return out.squeeze(), x_out
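

# NOTE: The original snippet references an `Embedding` module that is not
# defined anywhere in it. The class below is an assumed minimal sketch, not
# the original implementation: it projects the sentence embedding down to
# `embed_out_dim`, replicates it spatially to match the convolutional feature
# map, and concatenates the two along the channel dimension, which yields the
# `512 + embed_out_dim` channels the output head expects.
class Embedding(nn.Module):
    def __init__(self, embed_dim, embed_out_dim):
        super(Embedding, self).__init__()
        self.projection = nn.Sequential(
            nn.Linear(embed_dim, embed_out_dim),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, x, text):
        # Project the text embedding: (batch, embed_dim) -> (batch, embed_out_dim)
        projected = self.projection(text)
        # Replicate spatially: (batch, embed_out_dim) -> (batch, embed_out_dim, H, W)
        replicated = projected[:, :, None, None].expand(-1, -1, x.size(2), x.size(3))
        # Concatenate text features with image features along the channel dimension
        return torch.cat([x, replicated], dim=1)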
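

# Quick smoke test with dummy data (the 128x128 input size and 1024-dim text
# embeddings are assumptions consistent with the layer shapes above): five
# stride-2 convolutions reduce 128x128 to 4x4, and the final 4x4 convolution
# collapses that to one real/fake score per image.
if __name__ == '__main__':
    disc = Discriminator(channels=3)
    images = torch.randn(8, 3, 128, 128)  # batch of 8 RGB images
    text = torch.randn(8, 1024)           # batch of 8 sentence embeddings
    score, features = disc(images, text)
    print(score.shape)     # torch.Size([8])
    print(features.shape)  # torch.Size([8, 512, 4, 4])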