import torch
import torch.nn as nn
import torchvision.models as models
from typing import Tuple


class Encoder(nn.Module):
    """
    Image encoder that extracts features from images using a pretrained ResNet-50.
    The final classification layer of ResNet-50 is removed, and a linear layer is
    added to project the output to the desired feature dimension.

    Args:
        image_emb_dim (int): Final output dimension of image features.
        device (torch.device): Device to run the model on (CPU or GPU).
    """

    def __init__(self, image_emb_dim: int, device: torch.device):
        super(Encoder, self).__init__()
        self.image_emb_dim = image_emb_dim
        self.device = device

        # Load a pretrained ResNet-50 and freeze its parameters
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        for param in resnet.parameters():
            param.requires_grad_(False)

        # Remove the final classification layer of ResNet-50, keeping
        # everything up to (and including) global average pooling
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

        # Linear layer projecting the 2048-dim ResNet features
        # to the desired image embedding dimension
        self.fc = nn.Linear(resnet.fc.in_features, self.image_emb_dim)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the encoder.

        Args:
            images (torch.Tensor): Input images of shape (BATCH, 3, 224, 224).

        Returns:
            torch.Tensor: Image features of shape (BATCH, IMAGE_EMB_DIM).
        """
        features = self.resnet(images)
        # Flatten pooled features to (BATCH, 2048)
        features = features.reshape(features.size(0), -1).to(self.device)
        # Project to (BATCH, IMAGE_EMB_DIM)
        features = self.fc(features).to(self.device)
        return features


class Decoder(nn.Module):
    """
    Decoder that uses an LSTM to generate captions from embedded words and
    encoded image features. The hidden and cell states of the LSTM are
    initialized using the encoded image features.

    Args:
        word_emb_dim (int): Dimension of word embeddings.
        hidden_dim (int): Dimension of the LSTM hidden state.
        num_layers (int): Number of LSTM layers.
        vocab_size (int): Size of the vocabulary (output dimension of the
            final linear layer).
        device (torch.device): Device to run the model on (CPU or GPU).
    """

    def __init__(self,
                 word_emb_dim: int,
                 hidden_dim: int,
                 num_layers: int,
                 vocab_size: int,
                 device: torch.device):
        super(Decoder, self).__init__()
        self.word_emb_dim = word_emb_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.device = device

        # Learned initial hidden and cell states
        # (broadcast over the batch dimension at run time)
        self.hidden_state_0 = nn.Parameter(
            torch.zeros((self.num_layers, 1, self.hidden_dim)))
        self.cell_state_0 = nn.Parameter(
            torch.zeros((self.num_layers, 1, self.hidden_dim)))

        # LSTM over word embeddings; input is (SEQ_LEN, BATCH, WORD_EMB_DIM)
        # since batch_first is left at its default of False
        self.lstm = nn.LSTM(self.word_emb_dim,
                            self.hidden_dim,
                            num_layers=self.num_layers,
                            bidirectional=False)

        # Final linear layer followed by LogSoftmax over the vocabulary
        self.fc = nn.Sequential(
            nn.Linear(self.hidden_dim, self.vocab_size),
            nn.LogSoftmax(dim=2)
        )

    def forward(self,
                embedded_captions: torch.Tensor,
                hidden: torch.Tensor,
                cell: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass through the decoder.

        Args:
            embedded_captions (torch.Tensor): Embedded captions of shape
                (SEQ_LEN, BATCH, WORD_EMB_DIM).
            hidden (torch.Tensor): LSTM hidden state of shape
                (NUM_LAYERS, BATCH, HIDDEN_DIM).
            cell (torch.Tensor): LSTM cell state of shape
                (NUM_LAYERS, BATCH, HIDDEN_DIM).

        Returns:
            Tuple:
                - output (torch.Tensor): Log-probabilities over the vocabulary,
                  of shape (SEQ_LEN, BATCH, VOCAB_SIZE).
                - (hidden, cell) (Tuple[torch.Tensor, torch.Tensor]): Updated
                  hidden and cell states.
""" # Pass through LSTM output, (hidden, cell) = self.lstm(embedded_captions, (hidden, cell)) # Pass through final linear layer output = self.fc(output) return output, (hidden, cell)