Spaces:
Running
Running
import functools

import torch
import torch.utils.data
from PIL import Image
from torchvision import transforms

from source.config import Config
from source.model import Decoder, Encoder
from source.vocab import Vocab
def generate_caption(image: torch.Tensor,
                     image_encoder: Encoder,
                     emb_layer: torch.nn.Embedding,
                     image_decoder: Decoder,
                     vocab: Vocab,
                     device: torch.device,
                     max_length: int = 20) -> list[str]:
    """
    Greedily generate a caption for a single image of size (3, 224, 224).

    Decoding starts from ``vocab.SOS``. At every step the whole prefix of
    word IDs generated so far is embedded and run through the decoder,
    starting from the decoder's initial state (``hidden_state_0`` /
    ``cell_state_0``). The arg-max of the last position's output becomes the
    next word. Decoding stops when the EOS word is predicted or after
    ``max_length`` steps.

    Args:
        image: single image tensor; a batch dimension of 1 is added here.
        image_encoder: maps the (1, 3, 224, 224) batch to (1, IMAGE_EMB_DIM)
            image features.
        emb_layer: word-embedding layer applied to the token-ID prefix.
        image_decoder: decoder taking (word embeddings, image features,
            hidden, cell) and returning (logits, (hidden, cell)).
        vocab: vocabulary providing ``SOS``, ``EOS``, ``index2word`` and
            ``index_to_word``.
        device: device the image and token IDs are moved to.
        max_length: maximum number of decoding steps. Defaults to 20, the
            previously hard-coded limit.

    Returns:
        list[str]: predicted words. The leading SOS and the terminating EOS
        are not included.
    """
    eos_word = vocab.index2word[vocab.EOS]

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # The image embedding does not depend on the words generated so far,
        # so encode it once instead of on every decoding step.
        image_batch = image.to(device).unsqueeze(0)
        # image_batch: (1, 3, 224, 224)
        features = image_encoder(image_batch).to(device).unsqueeze(0)
        # features: (1, 1, IMAGE_EMB_DIM)

        # Initial LSTM state: (NUM_LAYER, 1, HIDDEN_DIM).
        # Because the full prefix (starting at <sos>) is re-fed on every step,
        # each pass must start from the initial state. Continuing from the
        # state returned by the previous pass would make the LSTM consume the
        # prefix again on top of itself.
        hidden = image_decoder.hidden_state_0
        cell = image_decoder.cell_state_0

        input_words = [vocab.SOS]
        sentence = []
        for _ in range(max_length):
            input_words_tensor = torch.tensor([input_words], device=device)
            # input_words_tensor: (B=1, SEQ_LENGTH)
            lstm_input = emb_layer(input_words_tensor).permute(1, 0, 2)
            # lstm_input: (SEQ_LENGTH, B=1, WORD_EMB_DIM)
            seq_length = lstm_input.shape[0]
            step_features = features.repeat(seq_length, 1, 1)
            # step_features: (SEQ_LENGTH, B=1, IMAGE_EMB_DIM)

            logits, _ = image_decoder(lstm_input, step_features, hidden, cell)
            # logits: (SEQ_LENGTH, 1, VOCAB_SIZE); the last position predicts
            # the next word.
            next_id = int(torch.argmax(logits[-1, 0, :]).item())

            next_word = vocab.index_to_word(next_id)
            if next_word == eos_word:
                break
            input_words.append(next_id)
            sentence.append(next_word)

    return sentence
@functools.lru_cache(maxsize=1)
def _load_caption_resources():
    """Build the config, vocab, preprocessing transform and trained models once.

    Reading the vocabulary file and three weight files is expensive, so the
    result is cached and reused by every call to ``main_caption``. The cache
    means weight files changed on disk afterwards are not picked up until the
    process restarts.

    Returns:
        tuple: (config, vocab, transform, image_encoder, emb_layer, image_decoder)
    """
    config = Config()
    vocab = Vocab()
    vocab.load_vocab(config.VOCAB_FILE)

    # Resize, then center-crop to the (3, 224, 224) tensor generate_caption
    # expects. The mean/std values are the common ImageNet normalization
    # constants, presumably matching how the encoder was trained.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    image_encoder = Encoder(image_emb_dim=config.IMAGE_EMB_DIM,
                            device=config.DEVICE)
    emb_layer = torch.nn.Embedding(num_embeddings=config.VOCAB_SIZE,
                                   embedding_dim=config.WORD_EMB_DIM,
                                   padding_idx=vocab.PADDING_INDEX)
    image_decoder = Decoder(word_emb_dim=config.WORD_EMB_DIM,
                            hidden_dim=config.HIDDEN_DIM,
                            num_layers=config.NUM_LAYER,
                            vocab_size=config.VOCAB_SIZE,
                            device=config.DEVICE)

    # NOTE(review): torch.load unpickles its input; only point these config
    # paths at trusted weight files.
    for module, weight_file in ((emb_layer, config.EMBEDDING_WEIGHT_FILE),
                                (image_encoder, config.ENCODER_WEIGHT_FILE),
                                (image_decoder, config.DECODER_WEIGHT_FILE)):
        module.eval()
        module.load_state_dict(torch.load(f=weight_file, map_location=config.DEVICE))

    emb_layer = emb_layer.to(config.DEVICE)
    image_encoder = image_encoder.to(config.DEVICE)
    image_decoder = image_decoder.to(config.DEVICE)

    return config, vocab, transform, image_encoder, emb_layer, image_decoder


def main_caption(image):
    """Return a caption string for an image given as an array.

    Args:
        image: array-like image (presumably a NumPy array as delivered by the
            UI, shape (H, W, C) or (H, W)); it is cast to uint8 and converted
            to RGB.

    Returns:
        str: predicted caption words joined by single spaces.
    """
    (config, vocab, transform,
     image_encoder, emb_layer, image_decoder) = _load_caption_resources()

    # Let PIL infer the mode from the array, then convert. Forcing mode='RGB'
    # misreads or rejects grayscale/RGBA arrays, and the mode argument of
    # fromarray is deprecated in recent Pillow releases. For an (H, W, 3)
    # uint8 array the result is identical.
    pil_image = Image.fromarray(image.astype('uint8')).convert('RGB')
    image_tensor = transform(pil_image).to(config.DEVICE)

    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        sentence = generate_caption(image_tensor, image_encoder, emb_layer,
                                    image_decoder, vocab, device=config.DEVICE)
    return ' '.join(sentence)