import torch
import torch.nn as nn
import torchvision.models as models
from typing import Tuple

from source.config import Config


class Encoder(nn.Module):

    def __init__(self, image_emb_dim: int, device: torch.device):
        """
        Image encoder that extracts features from images. Wraps a pretrained ResNet-50 with its last
        (classification) layer removed, followed by a linear layer with output shape (BATCH, image_emb_dim).
        """
        super(Encoder, self).__init__()
        self.image_emb_dim = image_emb_dim
        self.device = device

        # pretrained ResNet-50 with frozen parameters
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        for param in resnet.parameters():
            param.requires_grad_(False)

        # remove the last (fully connected) layer
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

        # final projection from ResNet features to the image embedding dimension
        self.fc = nn.Linear(in_features=resnet.fc.in_features, out_features=self.image_emb_dim)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the encoder: images go through the ResNet backbone and then the linear layer.

        Args:
            > images (torch.Tensor): (BATCH, 3, 224, 224)

        Returns:
            > features (torch.Tensor): (BATCH, IMAGE_EMB_DIM)
        """
        features = self.resnet(images)
        # features: (BATCH, 2048, 1, 1)

        features = features.reshape(features.size(0), -1).to(self.device)
        # features: (BATCH, 2048)

        features = self.fc(features).to(self.device)
        # features: (BATCH, IMAGE_EMB_DIM)

        return features
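

# A minimal shape-check sketch for the Encoder (assumption: image_emb_dim=256 and the batch size
# below are illustrative values, not taken from source.config.Config). The helper is not called
# at import time; invoke it manually if a quick sanity check is needed.
def _encoder_shape_check() -> None:
    device = torch.device("cpu")
    encoder = Encoder(image_emb_dim=256, device=device)
    images = torch.randn(4, 3, 224, 224)   # (BATCH, 3, 224, 224)
    features = encoder(images)             # (BATCH, IMAGE_EMB_DIM)
    assert features.shape == (4, 256)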


class Decoder(nn.Module):

    def __init__(self,
                 image_emb_dim: int,
                 word_emb_dim: int,
                 hidden_dim: int,
                 num_layers: int,
                 vocab_size: int,
                 device: torch.device):
        """
        Decoder whose LSTM input is the concatenation of the features obtained from the encoder
        and the embedded captions obtained from the embedding layer. The initial hidden and cell
        states are learnable parameters initialized to zeros. The final classifier is a linear
        layer whose output dimension is the vocabulary size.
        """
        super(Decoder, self).__init__()
        self.config = Config()
        self.image_emb_dim = image_emb_dim
        self.word_emb_dim = word_emb_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.device = device

        # learnable initial hidden and cell states; the caller expands them to the batch size
        self.hidden_state_0 = nn.Parameter(torch.zeros((self.num_layers, 1, self.hidden_dim)))
        self.cell_state_0 = nn.Parameter(torch.zeros((self.num_layers, 1, self.hidden_dim)))

        self.lstm = nn.LSTM(input_size=self.image_emb_dim + self.word_emb_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            bidirectional=False)

        self.fc = nn.Sequential(
            nn.Linear(in_features=self.hidden_dim, out_features=self.vocab_size),
            nn.LogSoftmax(dim=2)
        )

    def forward(self,
                embedded_captions: torch.Tensor,
                features: torch.Tensor,
                hidden: torch.Tensor,
                cell: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass of the (word-by-word) decoder.
        The LSTM input (the concatenation of embedded_captions and features) is passed through
        the LSTM and then the final classifier.

        Args:
            > embedded_captions (torch.Tensor): (SEQ_LENGTH = 1, BATCH, WORD_EMB_DIM)
            > features (torch.Tensor): (1, BATCH, IMAGE_EMB_DIM)
            > hidden (torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
            > cell (torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)

        Returns:
            > output (torch.Tensor): (1, BATCH, VOCAB_SIZE)
            > (hidden, cell) (torch.Tensor, torch.Tensor): both (NUM_LAYER, BATCH, HIDDEN_DIM)
        """
        lstm_input = torch.cat((embedded_captions, features), dim=2)

        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        # output: (SEQ_LENGTH, BATCH, HIDDEN_DIM)
        # hidden: (NUM_LAYER, BATCH, HIDDEN_DIM)

        output = output.to(self.device)
        output = self.fc(output)
        # output: (SEQ_LENGTH, BATCH, VOCAB_SIZE)

        return output, (hidden, cell)
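

# A minimal end-to-end sketch of a single decoding step (assumption: the batch size, dimensions,
# and the <SOS> token index used below are illustrative, not taken from source.config.Config;
# the word-embedding layer is assumed to live outside the Decoder, as its forward signature suggests).
if __name__ == "__main__":
    device = torch.device("cpu")
    batch_size, image_emb_dim, word_emb_dim = 4, 256, 300
    hidden_dim, num_layers, vocab_size = 512, 1, 5000

    encoder = Encoder(image_emb_dim=image_emb_dim, device=device)
    decoder = Decoder(image_emb_dim=image_emb_dim,
                      word_emb_dim=word_emb_dim,
                      hidden_dim=hidden_dim,
                      num_layers=num_layers,
                      vocab_size=vocab_size,
                      device=device)
    embedding = nn.Embedding(vocab_size, word_emb_dim)

    images = torch.randn(batch_size, 3, 224, 224)
    features = encoder(images).unsqueeze(0)                  # (1, BATCH, IMAGE_EMB_DIM)

    # expand the learnable initial states from batch dimension 1 to the actual batch size
    hidden = decoder.hidden_state_0.expand(-1, batch_size, -1).contiguous()
    cell = decoder.cell_state_0.expand(-1, batch_size, -1).contiguous()

    sos = torch.ones(1, batch_size, dtype=torch.long)        # hypothetical <SOS> index = 1
    embedded = embedding(sos)                                 # (1, BATCH, WORD_EMB_DIM)

    output, (hidden, cell) = decoder(embedded, features, hidden, cell)
    print(output.shape)                                       # torch.Size([1, BATCH, VOCAB_SIZE])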