from torchvision import transforms
import torch
from PIL import Image

from source.vocab import Vocab
from source.model import Decoder, Encoder
from source.config import Config


@torch.no_grad()
def generate_caption(image: torch.Tensor,
                     image_encoder: Encoder,
                     emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder,
                     vocab: Vocab,
                     device: torch.device) -> list[str]:
    """
    Generate caption of a single image of size (3, 224, 224).
    Generating of caption starts with <sos>, and each next predicted word ID
    is appended for the next LSTM input until the sentence reaches MAX_LENGTH or <eos>.

    Returns:
        list[str]: caption for given image
    """

    image = image.to(device)
    # image: (3, 224, 224)
    image = image.unsqueeze(0)
    # image: (1, 3, 224, 224)

    hidden = image_decoder.hidden_state_0
    cell = image_decoder.cell_state_0
    # hidden, cell : (NUM_LAYER, 1, HIDDEN_DIM)

    sentence = []

    # initialize LSTM input with the SOS token (ID 1)
    input_words = [vocab.SOS]

    MAX_LENGTH = 20

    # encode the image once: its features do not change between decoding steps
    features = image_encoder(image)
    # features: (1, IMAGE_EMB_DIM)
    features = features.to(device)
    features = features.unsqueeze(0)
    # features: (1, 1, IMAGE_EMB_DIM)

    for _ in range(MAX_LENGTH):

        input_words_tensor = torch.tensor([input_words], device=device)
        # input_words_tensor : (B=1, SEQ_LENGTH)

        lstm_input = emb_layer(input_words_tensor)
        # lstm_input : (B=1, SEQ_LENGTH, WORD_EMB_DIM)

        lstm_input = lstm_input.permute(1, 0, 2)
        # lstm_input : (SEQ_LENGTH, B=1, WORD_EMB_DIM)
        SEQ_LENGTH = lstm_input.shape[0]

        step_features = features.repeat(SEQ_LENGTH, 1, 1)
        # step_features : (SEQ_LENGTH, B=1, IMAGE_EMB_DIM)

        next_id_pred, (hidden, cell) = image_decoder(lstm_input, step_features, hidden, cell)
        # next_id_pred : (SEQ_LENGTH, 1, VOCAB_SIZE)

        next_id_pred = next_id_pred[-1, 0, :]
        # next_id_pred : (VOCAB_SIZE)
        next_id_pred = torch.argmax(next_id_pred)

        # append the predicted ID to input_words, which is fed back to the LSTM at the next step
        input_words.append(next_id_pred.item())

        # id --> word
        next_word_pred = vocab.index_to_word(int(next_id_pred.item()))

        if next_word_pred == vocab.index2word[vocab.EOS]:
            break

        sentence.append(next_word_pred)

    return sentence


def main_caption(image) -> str:
    """
    Generate a caption for an RGB image given as a (H, W, 3) numpy array,
    using the vocabulary, weights and hyperparameters referenced in Config.
    """

    config = Config()

    vocab = Vocab()
    vocab.load_vocab(config.VOCAB_FILE)

    image = Image.fromarray(image.astype('uint8'), 'RGB')

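    # standard ImageNet preprocessing: resize to 256, center-crop to 224x224,
    # and normalize with the usual ImageNet mean/std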
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    image = transform(image)

    image_encoder = Encoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            device=config.DEVICE)
    emb_layer = torch.nn.Embedding(num_embeddings=config.VOCAB_SIZE,
                                   embedding_dim=config.WORD_EMB_DIM,
                                   padding_idx=vocab.PADDING_INDEX)
    image_decoder = Decoder(word_emb_dim=config.WORD_EMB_DIM,
                            hidden_dim=config.HIDDEN_DIM,
                            num_layers=config.NUM_LAYER,
                            vocab_size=config.VOCAB_SIZE,
                            device=config.DEVICE)

    emb_layer.eval()
    image_encoder.eval()
    image_decoder.eval()

    emb_layer.load_state_dict(torch.load(f=config.EMBEDDING_WEIGHT_FILE, map_location=config.DEVICE))
    image_encoder.load_state_dict(torch.load(f=config.ENCODER_WEIGHT_FILE, map_location=config.DEVICE))
    image_decoder.load_state_dict(torch.load(f=config.DECODER_WEIGHT_FILE, map_location=config.DEVICE))

    emb_layer = emb_layer.to(config.DEVICE)
    image_encoder = image_encoder.to(config.DEVICE)
    image_decoder = image_decoder.to(config.DEVICE)
    image = image.to(config.DEVICE)

    sentence = generate_caption(image, image_encoder, emb_layer, image_decoder, vocab, device=config.DEVICE)
    description = ' '.join(sentence)

    return description
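

# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: the module is run directly, the weight
# files referenced in Config exist, and "example.jpg" is a hypothetical test
# image path, not part of the original project). main_caption expects an RGB
# image as a (H, W, 3) uint8 numpy array.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    # load the hypothetical test image and convert it to a (H, W, 3) array
    test_image = np.array(Image.open("example.jpg").convert("RGB"))
    print(main_caption(test_image))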