import torch
import torch.utils.data
from PIL import Image
from torchvision import transforms

from source.config import Config
from source.model import Decoder, Encoder
from source.vocab import Vocab


def generate_caption(image: torch.Tensor,
                     image_encoder: Encoder,
                     emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder,
                     vocab: Vocab,
                     device: torch.device) -> list[str]:
"""
Generate caption of a single image of size (3, 224, 224).
Generating of caption starts with <sos>, and each next predicted word ID
is appended for the next LSTM input until the sentence reaches MAX_LENGTH or <eos>.
Returns:
list[str]: caption for given image
"""
    image = image.to(device)
    # image: (3, 224, 224) -> add a batch dimension
    image = image.unsqueeze(0)
    # image: (1, 3, 224, 224)

    hidden = image_decoder.hidden_state_0
    cell = image_decoder.cell_state_0
    # hidden, cell: (NUM_LAYER, 1, HIDDEN_DIM)

    # Encode the image once up front; it does not change between decoding steps.
    features = image_encoder(image).to(device)
    # features: (1, IMAGE_EMB_DIM)
    features = features.unsqueeze(0)
    # features: (1, 1, IMAGE_EMB_DIM)

    sentence = []
    # initialize the LSTM input with the <sos> token ID
    input_words = [vocab.SOS]
    MAX_LENGTH = 20

    for _ in range(MAX_LENGTH):
        input_words_tensor = torch.tensor([input_words], device=device)
        # input_words_tensor: (B=1, SEQ_LENGTH)
        lstm_input = emb_layer(input_words_tensor)
        # lstm_input: (B=1, SEQ_LENGTH, WORD_EMB_DIM)
        lstm_input = lstm_input.permute(1, 0, 2)
        # lstm_input: (SEQ_LENGTH, B=1, WORD_EMB_DIM)

        seq_length = lstm_input.shape[0]
        step_features = features.repeat(seq_length, 1, 1)
        # step_features: (SEQ_LENGTH, B=1, IMAGE_EMB_DIM)

        next_id_pred, (hidden, cell) = image_decoder(lstm_input, step_features, hidden, cell)
        # next_id_pred: (SEQ_LENGTH, 1, VOCAB_SIZE); keep only the last time step
        next_id_pred = next_id_pred[-1, 0, :]
        # next_id_pred: (VOCAB_SIZE,)
        next_id_pred = torch.argmax(next_id_pred)

        # feed the predicted ID back in as part of the next LSTM input
        input_words.append(next_id_pred.item())

        # ID -> word
        next_word_pred = vocab.index_to_word(int(next_id_pred.item()))
        if next_word_pred == vocab.index2word[vocab.EOS]:
            break
        sentence.append(next_word_pred)

    return sentence


def main_caption(image):
    """Generate a caption for a raw RGB image array and return it as a string."""
    config = Config()

    vocab = Vocab()
    vocab.load_vocab(config.VOCAB_FILE)

    # preprocess: resize, center-crop to 224x224, and normalize with ImageNet statistics
    image = Image.fromarray(image.astype('uint8'), 'RGB')
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    image = transform(image)

    image_encoder = Encoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            device=config.DEVICE)
    emb_layer = torch.nn.Embedding(num_embeddings=config.VOCAB_SIZE,
                                   embedding_dim=config.WORD_EMB_DIM,
                                   padding_idx=vocab.PADDING_INDEX)
    image_decoder = Decoder(word_emb_dim=config.WORD_EMB_DIM,
                            hidden_dim=config.HIDDEN_DIM,
                            num_layers=config.NUM_LAYER,
                            vocab_size=config.VOCAB_SIZE,
                            device=config.DEVICE)

    # load trained weights, then switch to evaluation mode for inference
    emb_layer.load_state_dict(torch.load(f=config.EMBEDDING_WEIGHT_FILE, map_location=config.DEVICE))
    image_encoder.load_state_dict(torch.load(f=config.ENCODER_WEIGHT_FILE, map_location=config.DEVICE))
    image_decoder.load_state_dict(torch.load(f=config.DECODER_WEIGHT_FILE, map_location=config.DEVICE))
    emb_layer.eval()
    image_encoder.eval()
    image_decoder.eval()

    emb_layer = emb_layer.to(config.DEVICE)
    image_encoder = image_encoder.to(config.DEVICE)
    image_decoder = image_decoder.to(config.DEVICE)
    image = image.to(config.DEVICE)

    # disable gradient tracking during generation
    with torch.no_grad():
        sentence = generate_caption(image, image_encoder, emb_layer, image_decoder, vocab,
                                    device=config.DEVICE)
    return ' '.join(sentence)
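

# --- Usage sketch (not part of the original entry points) ---
# A minimal example of how main_caption might be invoked directly. It assumes
# numpy is installed and that a local file "example.jpg" exists; the filename
# is a hypothetical placeholder, not a path shipped with this repo.
if __name__ == "__main__":
    import numpy as np

    # main_caption expects an RGB image as a (H, W, 3) uint8 array
    array = np.asarray(Image.open("example.jpg").convert("RGB"))
    print(main_caption(array))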