nssharmaofficial committed
Commit 9a90e40 • 1 Parent(s): e8a30b2

Update code and weights

source/config.py CHANGED
@@ -12,8 +12,8 @@ class Config(object):
         self.VOCAB_SIZE = 5000
 
         self.NUM_LAYER = 1
-        self.IMAGE_EMB_DIM = 256
-        self.WORD_EMB_DIM = 256
+        self.IMAGE_EMB_DIM = 512
+        self.WORD_EMB_DIM = 512
         self.HIDDEN_DIM = 512
 
         self.EMBEDDING_WEIGHT_FILE = 'source/weights/embeddings-32B-512H-1L-e5.pt'
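
As a rough sanity check on the doubled dimensions (not part of the commit): an embedding table of VOCAB_SIZE x WORD_EMB_DIM float32 weights lines up with the size of the new embeddings checkpoint renamed further down in this commit.

# Sketch only; assumes float32 parameters and the Config values above.
vocab_size = 5000        # Config.VOCAB_SIZE
word_emb_dim = 512       # Config.WORD_EMB_DIM after this change (was 256)
approx_bytes = vocab_size * word_emb_dim * 4
print(approx_bytes)      # 10_240_000, close to the ~10.2 MB embeddings-32B-512H-1L-e10.pt below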
source/model.py CHANGED
@@ -1,124 +1,117 @@
 import torch
-import torch._utils
 import torch.nn as nn
 import torchvision.models as models
 from typing import Tuple
-from source.config import Config
 
 
 class Encoder(nn.Module):
-    def __init__(self, image_emb_dim: int, device: torch.device):
-        """ Image encoder to obtain features from images. Contains pretrained Resnet50 with last layer removed
-        and a linear layer with the output dimension of (BATCH, image_emb_dim)
-        """
+    """
+    Image encoder to obtain features from images using a pretrained ResNet-50 model.
+    The last layer of ResNet-50 is removed, and a linear layer is added to transform
+    the output to the desired feature dimension.
+
+    Args:
+        image_emb_dim (int): Final output dimension of image features.
+        device (torch.device): Device to run the model on (CPU or GPU).
+    """
 
+    def __init__(self, image_emb_dim: int, device: torch.device):
         super(Encoder, self).__init__()
         self.image_emb_dim = image_emb_dim
         self.device = device
 
-        # pretrained Resnet50 model with freezed parameters
+        # Load pretrained ResNet-50 model and freeze its parameters
         resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
         for param in resnet.parameters():
             param.requires_grad_(False)
 
-        # remove last layer
+        # Remove the last layer of ResNet-50
         modules = list(resnet.children())[:-1]
         self.resnet = nn.Sequential(*modules)
 
-        # define a final classifier
-        self.fc = nn.Linear(in_features=resnet.fc.in_features, out_features=self.image_emb_dim)
+        # Define a final classifier
+        self.fc = nn.Linear(resnet.fc.in_features, self.image_emb_dim)
 
     def forward(self, images: torch.Tensor) -> torch.Tensor:
-        """ Forward operation of encoder, passing images through resnet and then linear layer.
+        """
+        Forward pass through the encoder.
 
         Args:
-        > images (torch.Tensor): (BATCH, 3, 224, 224)
+            images (torch.Tensor): Input images of shape (BATCH, 3, 224, 224).
 
         Returns:
-        > features (torch.Tensor): (BATCH, IMAGE_EMB_DIM)
+            torch.Tensor: Image features of shape (BATCH, IMAGE_EMB_DIM).
         """
-
         features = self.resnet(images)
-        # features: (BATCH, 2048, 1, 1)
-
+        # Reshape features to (BATCH, 2048)
         features = features.reshape(features.size(0), -1).to(self.device)
-        # features: (BATCH, 2048)
-
+        # Pass features through final linear layer
         features = self.fc(features).to(self.device)
-        # features: (BATCH, IMAGE_EMB_DIM)
-
         return features
 
 
 class Decoder(nn.Module):
+    """
+    Decoder that uses an LSTM to generate captions from embedded words and encoded image features.
+    The hidden and cell states of the LSTM are initialized using the encoded image features.
+
+    Args:
+        word_emb_dim (int): Dimension of word embeddings.
+        hidden_dim (int): Dimension of the LSTM hidden state.
+        num_layers (int): Number of LSTM layers.
+        vocab_size (int): Size of the vocabulary (output dimension of the final linear layer).
+        device (torch.device): Device to run the model on (CPU or GPU).
+    """
+
     def __init__(self,
-                 image_emb_dim: int,
                  word_emb_dim: int,
                  hidden_dim: int,
                  num_layers: int,
                  vocab_size: int,
                  device: torch.device):
-        """
-        Decoder taking as input for the LSTM layer the concatenation of features obtained from the encoder
-        and embedded captions obtained from the embedding layer. Hidden and cell states are randomly initialized.
-        Final classifier is a linear layer with output dimension of the size of a vocabulary.
-        """
-
         super(Decoder, self).__init__()
 
-        self.config = Config()
-
-        self.image_emd_dim = image_emb_dim
         self.word_emb_dim = word_emb_dim
         self.hidden_dim = hidden_dim
-        self.num_layer = num_layers
+        self.num_layers = num_layers
         self.vocab_size = vocab_size
         self.device = device
 
-        self.hidden_state_0 = nn.Parameter(torch.zeros((self.num_layer, 1, self.hidden_dim)))
-        self.cell_state_0 = nn.Parameter(torch.zeros((self.num_layer, 1, self.hidden_dim)))
+        # Initialize hidden and cell states
+        self.hidden_state_0 = nn.Parameter(torch.zeros((self.num_layers, 1, self.hidden_dim)))
+        self.cell_state_0 = nn.Parameter(torch.zeros((self.num_layers, 1, self.hidden_dim)))
 
-        self.lstm = nn.LSTM(input_size=self.image_emd_dim + self.word_emb_dim,
-                            hidden_size=self.hidden_dim,
-                            num_layers=self.num_layer,
+        # Define LSTM layer
+        self.lstm = nn.LSTM(self.word_emb_dim,
+                            self.hidden_dim,
+                            num_layers=self.num_layers,
                             bidirectional=False)
 
+        # Define final linear layer with LogSoftmax activation
         self.fc = nn.Sequential(
-            nn.Linear(in_features=self.hidden_dim, out_features=self.vocab_size),
+            nn.Linear(self.hidden_dim, self.vocab_size),
             nn.LogSoftmax(dim=2)
         )
 
     def forward(self,
                 embedded_captions: torch.Tensor,
-                features: torch.Tensor,
                 hidden: torch.Tensor,
                 cell: torch.Tensor) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
-        Forward operation of (word-by-word) decoder.
-        The LSTM input (concatenation of embedded_captions and features) is passed through LSTM and then linear layer.
+        Forward pass through the decoder.
 
         Args:
-
-        > embedded_captions(torch.Tensor): (SEQ_LENGTH = 1, BATCH, WORD_EMB_DIM)
-        > features (torch.Tensor): (1, BATCH, IMAGE_EMB_DIM)
-        > hidden (torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
-        > cell (torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
+            embedded_captions (torch.Tensor): Embedded captions of shape (SEQ_LEN, BATCH, WORD_EMB_DIM).
+            hidden (torch.Tensor): LSTM hidden state of shape (NUM_LAYER, BATCH, HIDDEN_DIM).
+            cell (torch.Tensor): LSTM cell state of shape (NUM_LAYER, BATCH, HIDDEN_DIM).
 
         Returns:
-
-        > output (torch.Tensor): (1, BATCH, VOCAB_SIZE)
-        > (hidden, cell) (torch.Tensor, torch.Tensor): (NUM_LAYER, BATCH, HIDDEN_DIM)
+            Tuple:
+                - output (torch.Tensor): Output logits of shape (SEQ_LEN, BATCH, VOCAB_SIZE).
+                - (hidden, cell) (Tuple[torch.Tensor, torch.Tensor]): Updated hidden and cell states.
         """
-
-        lstm_input = torch.cat((embedded_captions, features), dim=2)
-
-        output, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
-        # output : (SEQ_LENGTH, BATCH, HIDDEN_DIM)
-        # hidden : (NUM_LAYER, BATCH, HIDDEN_DIM)
-
-        output = output.to(self.device)
-
+        # Pass through LSTM
+        output, (hidden, cell) = self.lstm(embedded_captions, (hidden, cell))
+        # Pass through final linear layer
         output = self.fc(output)
-        # output : (SEQ_LENGTH, BATCH, VOCAB_SIZE)
-
         return output, (hidden, cell)
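
Taken together, the new signatures change how the two modules are called: Decoder no longer receives image_emb_dim, and its forward no longer takes the encoder features. A minimal usage sketch under the new API (the batch size, the expansion of the learned initial states across the batch, and the dummy inputs are illustrative assumptions, not code from this repository):

import torch
from source.config import Config
from source.model import Encoder, Decoder

config = Config()
device = torch.device('cpu')          # assumption: CPU for the sketch
batch = 4                             # assumption: arbitrary batch size

encoder = Encoder(image_emb_dim=config.IMAGE_EMB_DIM, device=device)
decoder = Decoder(word_emb_dim=config.WORD_EMB_DIM,
                  hidden_dim=config.HIDDEN_DIM,
                  num_layers=config.NUM_LAYER,
                  vocab_size=config.VOCAB_SIZE,
                  device=device)

images = torch.randn(batch, 3, 224, 224)
features = encoder(images)            # (BATCH, IMAGE_EMB_DIM)

# One word-by-word decoding step: expand the learned initial states across the
# batch (how the image features feed into these states is not shown in this
# diff, so the sketch simply starts from the learned initial parameters).
hidden = decoder.hidden_state_0.expand(-1, batch, -1).contiguous()
cell = decoder.cell_state_0.expand(-1, batch, -1).contiguous()
embedded_word = torch.randn(1, batch, config.WORD_EMB_DIM)   # (SEQ_LEN=1, BATCH, WORD_EMB_DIM)

output, (hidden, cell) = decoder(embedded_word, hidden, cell)
print(output.shape)                   # torch.Size([1, 4, 5000]) -> (SEQ_LEN, BATCH, VOCAB_SIZE)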
source/predict_sample.py CHANGED
@@ -104,8 +104,7 @@ def main_caption(image):
     emb_layer = torch.nn.Embedding(num_embeddings=config.VOCAB_SIZE,
                                    embedding_dim=config.WORD_EMB_DIM,
                                    padding_idx=vocab.PADDING_INDEX)
-    image_decoder = Decoder(image_emb_dim=config.IMAGE_EMB_DIM,
-                            word_emb_dim=config.WORD_EMB_DIM,
+    image_decoder = Decoder(word_emb_dim=config.WORD_EMB_DIM,
                             hidden_dim=config.HIDDEN_DIM,
                             num_layers=config.NUM_LAYER,
                             vocab_size=config.VOCAB_SIZE,
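
For context, the updated constructor call pairs naturally with the renamed epoch-10 decoder checkpoint below. A hypothetical completion of the call (the trailing device argument and the state-dict loading sit outside the visible hunk and are assumptions):

image_decoder = Decoder(word_emb_dim=config.WORD_EMB_DIM,
                        hidden_dim=config.HIDDEN_DIM,
                        num_layers=config.NUM_LAYER,
                        vocab_size=config.VOCAB_SIZE,
                        device=device)                      # assumption: same device as the encoder
state = torch.load('source/weights/decoder-32B-512H-1L-e10.pt', map_location=device)
image_decoder.load_state_dict(state)                        # assumption: checkpoint stores the state dict
image_decoder.eval()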
source/weights/{decoder-32B-512H-1L-e2.pt → decoder-32B-512H-1L-e10.pt} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:c14f313a1fea17eb147567a456f418355e666858a0f0fa4f5dfa8f8a561e076a
-size 18671955
+oid sha256:435a74d3029be0e1bce2dd451cbb58ec84a2e9ee2e3d685fd9e151c5a2123139
+size 18671964
source/weights/decoder-32B-512H-1L-e6.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:1914f17b249f0819e2680740e1bed990e38cde2fd5db916e3f33b2e106f6c2fc
-size 18671955
source/weights/{decoder-32B-512H-1L-e4.pt → embeddings-32B-512H-1L-e10.pt} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:92a9289a063a101f4f3214cc7b67990d62b9054dfe917cb40492a7bde5440c60
-size 18671955
+oid sha256:1e6b0a7b05ab93d06da4fcc93dff769d02fc3ff48963b6979d3faa00de6f62a9
+size 10241467
source/weights/embeddings-32B-512H-1L-e2.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:efb8031895da44da642975ba1a1997a214437ca61113edbbfa31f30a26c2ad9e
-size 5121462
source/weights/embeddings-32B-512H-1L-e4.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9a0faba2400080ae7acf50c38b214f389a763c95f2f587d1d664110b5d9978cf
-size 5121462
source/weights/embeddings-32B-512H-1L-e5.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:3ad73c03e1547417874d7d154213a893ac38adb24d74386a2055fc4d1fd46884
-size 5121041
source/weights/embeddings-32B-512H-1L-e6.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:76efafb94073fa60b15cfce698f78072465067e428968a104d174b8a3adabd32
-size 5121462
source/weights/{decoder-32B-512H-1L-e5.pt → encoder-32B-512H-1L-e10.pt} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5dd4bb9bd858cc8518a4af612df8721ca67d40a3428f53d34c50baef4ee87371
-size 18671739
+oid sha256:ea482f42ec88705fef214bfa92acd4ee535e331110eaeda32198e63a8a9c108c
+size 98552306
source/weights/encoder-32B-512H-1L-e2.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:0dff5de9d9ad9ea43fc5f67798b610f0bb92224590eba264766921b418a0d7a6
-size 96453806
source/weights/encoder-32B-512H-1L-e4.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:c51a0ffb6eccb3fc2163b7c3214bdb9e32972a14b12d2be210289865bec4d7f7
-size 96453806
source/weights/encoder-32B-512H-1L-e5.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:92c5362bce19b36b330c58455985136d546821404d31477947544af70dbeab83
-size 96458817
source/weights/encoder-32B-512H-1L-e6.pt DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:9d59e51481084ee51bd810e7b0b87fa89577cfdcc8cfd76d5495f45beaff9feb
-size 96453806
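
The weight entries above are Git LFS pointer files, so only their hashes and sizes appear in the diff. A hypothetical way to confirm the renamed embeddings checkpoint matches the new 512-dimensional configuration (assumes the file stores an nn.Embedding state dict, which the diff does not show):

import torch

state = torch.load('source/weights/embeddings-32B-512H-1L-e10.pt', map_location='cpu')
for name, tensor in state.items():
    print(name, tuple(tensor.shape))   # expected roughly: weight (5000, 512)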