# Source: cuhksz-text2image / generatorModel.py
# Commit c76f6d4 by bryandts: "Rename generator.py to generatorModel.py" (1.95 kB)
import torch
import torch.nn as nn
# The Generator model
class Generator(nn.Module):
    """Text-conditioned, DCGAN-style image generator.

    A text embedding is projected to ``embed_out_dim`` features, concatenated
    with a noise vector along the channel axis, and upsampled by a stack of
    transposed convolutions from 1x1 to 128x128
    (1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128). The output layer uses Tanh, so
    generated values lie in [-1, 1].

    Args:
        channels: Number of channels in the generated image.
        noise_dim: Size of the noise vector.
        embed_dim: Size of the incoming text embedding.
        embed_out_dim: Size of the projected text embedding that is
            concatenated with the noise.
    """

    def __init__(self, channels, noise_dim=100, embed_dim=1024, embed_out_dim=128):
        super().__init__()
        self.channels = channels
        self.noise_dim = noise_dim
        self.embed_dim = embed_dim
        self.embed_out_dim = embed_out_dim

        # Text embedding projection. forward() feeds it a (N, 1, embed_dim)
        # tensor, so BatchNorm1d(1) sees a single channel and treats the
        # projected features as its length axis: one mean/variance is shared
        # across all embed_out_dim features. Kept as-is because changing the
        # channel count would change the state_dict and break checkpoints.
        self.text_embedding = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_out_dim),
            nn.BatchNorm1d(1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Upsampling stack: the first layer maps the 1x1 input to 4x4, and each
        # following stride-2 layer doubles the spatial size. The order and
        # composition of modules define the state_dict keys (model.0 ... model.16).
        model = []
        model += self._create_layer(
            self.noise_dim + self.embed_out_dim, 512, 4, stride=1, padding=0
        )
        model += self._create_layer(512, 256, 4, stride=2, padding=1)
        model += self._create_layer(256, 128, 4, stride=2, padding=1)
        model += self._create_layer(128, 64, 4, stride=2, padding=1)
        model += self._create_layer(64, 32, 4, stride=2, padding=1)
        model += self._create_layer(32, self.channels, 4, stride=2, padding=1, output=True)
        self.model = nn.Sequential(*model)

    def _create_layer(self, size_in, size_out, kernel_size=4, stride=2, padding=1, output=False):
        """Return the modules for one upsampling stage.

        Args:
            size_in: Input channel count.
            size_out: Output channel count.
            kernel_size: Transposed-convolution kernel size.
            stride: Transposed-convolution stride.
            padding: Transposed-convolution padding.
            output: If True, build the final stage (conv + Tanh); otherwise
                conv + BatchNorm2d + ReLU.

        Returns:
            A list of ``nn.Module`` objects to splice into an ``nn.Sequential``.
        """
        # No conv bias: it would be cancelled by the BatchNorm that follows
        # (and the output layer is defined the same way for weight compatibility).
        layers = [
            nn.ConvTranspose2d(
                size_in, size_out, kernel_size, stride=stride, padding=padding, bias=False
            )
        ]
        if output:
            layers.append(nn.Tanh())  # squash the output image to [-1, 1]
        else:
            layers += [nn.BatchNorm2d(size_out), nn.ReLU(True)]
        return layers

    def forward(self, noise, text):
        """Generate images from noise and text embeddings.

        Args:
            noise: Noise tensor of shape (N, noise_dim, 1, 1) or (N, noise_dim).
            text: Text embedding of shape (N, 1, embed_dim) or (N, embed_dim).

        Returns:
            Tensor of shape (N, channels, 128, 128) with values in [-1, 1].
        """
        # The embedding branch needs a singleton channel axis for
        # BatchNorm1d(1); add it when a plain (N, embed_dim) tensor is given.
        if text.dim() == 2:
            text = text.unsqueeze(1)
        text = self.text_embedding(text)  # (N, 1, embed_out_dim)
        # Turn the projected embedding into a (N, embed_out_dim, 1, 1) feature map.
        text = text.view(text.shape[0], text.shape[2], 1, 1)
        # Accept flat (N, noise_dim) noise by adding the two spatial axes.
        if noise.dim() == 2:
            noise = noise[:, :, None, None]
        # Text channels first, then noise: this channel order is baked into the
        # first layer's weight layout, so it must not change.
        z = torch.cat([text, noise], 1)
        return self.model(z)