File size: 4,453 Bytes
cb7427c
 
 
 
 
 
 
9a90e40
 
 
 
 
 
 
 
 
cb7427c
9a90e40
cb7427c
 
 
 
9a90e40
cb7427c
 
 
 
9a90e40
cb7427c
 
 
9a90e40
 
cb7427c
 
9a90e40
 
cb7427c
 
9a90e40
cb7427c
 
9a90e40
cb7427c
 
9a90e40
cb7427c
9a90e40
cb7427c
 
 
 
 
9a90e40
 
 
 
 
 
 
 
 
 
 
 
cb7427c
 
 
 
 
 
 
 
 
 
9a90e40
cb7427c
 
 
9a90e40
 
 
cb7427c
9a90e40
 
 
 
cb7427c
 
9a90e40
cb7427c
9a90e40
cb7427c
 
 
 
 
 
 
 
9a90e40
cb7427c
 
9a90e40
 
 
cb7427c
 
9a90e40
 
 
cb7427c
9a90e40
 
 
cb7427c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
import torch.nn as nn
import torchvision.models as models
from typing import Tuple


class Encoder(nn.Module):
    """
    Image encoder that extracts features from images with a pretrained ResNet-50.

    The final classification layer of ResNet-50 is dropped and replaced with a
    trainable linear projection to the desired feature dimension. All ResNet
    parameters are frozen, so only the projection layer learns.

    Args:
        image_emb_dim (int): Final output dimension of image features.
        device (torch.device): Device to run the model on (CPU or GPU).
    """

    def __init__(self, image_emb_dim: int, device: torch.device):
        super().__init__()
        self.image_emb_dim = image_emb_dim
        self.device = device

        # Load pretrained ResNet-50 and freeze its parameters so the backbone
        # acts as a fixed feature extractor.
        resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
        for param in resnet.parameters():
            param.requires_grad_(False)
        # NOTE(review): freezing weights does not stop BatchNorm layers from
        # updating their running statistics while the module is in train()
        # mode — call .eval() on the backbone during training if that is not
        # intended; confirm against the training loop.

        # Drop the final fully-connected classification layer; the remaining
        # stack ends with global average pooling -> (BATCH, 2048, 1, 1).
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)

        # Trainable projection from the ResNet feature width (resnet.fc.in_features,
        # 2048 for ResNet-50) down to the requested embedding dimension.
        self.fc = nn.Linear(resnet.fc.in_features, self.image_emb_dim)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of images.

        Args:
            images (torch.Tensor): Input images of shape (BATCH, 3, 224, 224),
                expected to already be on the same device as the model.

        Returns:
            torch.Tensor: Image features of shape (BATCH, IMAGE_EMB_DIM).
        """
        features = self.resnet(images)
        # Collapse the trailing (1, 1) spatial dimensions to get (BATCH, 2048).
        # The redundant .to(self.device) transfers were removed: self.resnet's
        # output is already on the module's device, and moving activations
        # mid-forward only hides caller-side device mismatches.
        features = torch.flatten(features, start_dim=1)
        return self.fc(features)


class Decoder(nn.Module):
    """
    LSTM-based caption decoder.

    Consumes embedded caption tokens together with externally supplied hidden
    and cell states and produces per-step log-probabilities over the
    vocabulary. Learnable initial states (``hidden_state_0`` /
    ``cell_state_0``) are exposed as parameters for callers that seed the
    recurrence.

    Args:
        word_emb_dim (int): Dimension of word embeddings.
        hidden_dim (int): Dimension of the LSTM hidden state.
        num_layers (int): Number of stacked LSTM layers.
        vocab_size (int): Vocabulary size (output dimension of the classifier).
        device (torch.device): Device to run the model on (CPU or GPU).
    """

    def __init__(self,
                 word_emb_dim: int,
                 hidden_dim: int,
                 num_layers: int,
                 vocab_size: int,
                 device: torch.device):
        super().__init__()

        self.word_emb_dim = word_emb_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.device = device

        # Learnable initial states, shaped (NUM_LAYERS, 1, HIDDEN_DIM) so a
        # caller can broadcast them across the batch dimension.
        state_shape = (self.num_layers, 1, self.hidden_dim)
        self.hidden_state_0 = nn.Parameter(torch.zeros(state_shape))
        self.cell_state_0 = nn.Parameter(torch.zeros(state_shape))

        # Unidirectional, sequence-first LSTM over the embedded captions.
        self.lstm = nn.LSTM(self.word_emb_dim,
                            self.hidden_dim,
                            num_layers=self.num_layers,
                            bidirectional=False)

        # Classifier head: hidden state -> log-probabilities over the vocabulary.
        self.fc = nn.Sequential(
            nn.Linear(self.hidden_dim, self.vocab_size),
            nn.LogSoftmax(dim=2)
        )

    def forward(self,
                embedded_captions: torch.Tensor,
                hidden: torch.Tensor,
                cell: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Run the decoder over a full caption sequence.

        Args:
            embedded_captions (torch.Tensor): Embedded captions of shape
                (SEQ_LEN, BATCH, WORD_EMB_DIM).
            hidden (torch.Tensor): LSTM hidden state of shape
                (NUM_LAYERS, BATCH, HIDDEN_DIM).
            cell (torch.Tensor): LSTM cell state of shape
                (NUM_LAYERS, BATCH, HIDDEN_DIM).

        Returns:
            Tuple:
                - output (torch.Tensor): Log-probabilities of shape
                  (SEQ_LEN, BATCH, VOCAB_SIZE).
                - (hidden, cell): Updated LSTM hidden and cell states.
        """
        lstm_out, (hidden, cell) = self.lstm(embedded_captions, (hidden, cell))
        return self.fc(lstm_out), (hidden, cell)