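# Gradio demo: clothing image captioning.
# A pretrained ResNet-18 (layer4 feature map, 1x512x7x7) encodes the uploaded image
# and a Transformer decoder generates a caption token by token, using the
# vocabulary stored in index_to_word.pkl / word_to_index.pkl.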
import math
import pickle
import random

import pandas as pd
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm.auto import tqdm

import gradio as gr

device = "cpu"
max_seq_len = 67

# Vocabulary built during training: integer index <-> word lookup tables.
with open('index_to_word.pkl', 'rb') as handle:
    index_to_word = pickle.load(handle)
with open('word_to_index.pkl', 'rb') as handle:
    word_to_index = pickle.load(handle)

# Pretrained ResNet-18; its layer4 activations serve as the image encoding.
resnet18 = torchvision.models.resnet18(pretrained=True).to(device)
resnet18.eval()
resNet18Layer4 = resnet18._modules.get('layer4').to(device)

def create_df(img):
    df = pd.DataFrame({"image": [img]})
    return df

def get_vector(t_img):
    # Capture the layer4 feature map (1, 512, 7, 7) via a forward hook.
    my_embedding = torch.zeros(1, 512, 7, 7)

    def copy_data(m, i, o):
        my_embedding.copy_(o.data)

    h = resNet18Layer4.register_forward_hook(copy_data)
    resnet18(t_img)
    h.remove()
    return my_embedding

class extractImageFeatureResNetDataSet(Dataset):
    def __init__(self, data):
        self.data = data
        self.scaler = transforms.Resize([224, 224])
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                              std=[0.229, 0.224, 0.225])
        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        image_name = self.data.iloc[idx]['image']
        img_loc = str(image_name)
        # Force RGB so grayscale or RGBA uploads still match the 3-channel normalization.
        img = Image.open(img_loc).convert("RGB")
        t_img = self.normalize(self.to_tensor(self.scaler(img)))
        return image_name, t_img

def feature_exctractor(df):
    # Run every image in the dataframe through ResNet-18 and collect its layer4 features.
    extract_imgFtr_ResNet_input = {}
    input_ImageDataset_ResNet = extractImageFeatureResNetDataSet(df[['image']])
    input_ImageDataloader_ResNet = DataLoader(input_ImageDataset_ResNet, batch_size=1, shuffle=False)
    for image_name, t_img in tqdm(input_ImageDataloader_ResNet):
        t_img = t_img.to(device)
        embdg = get_vector(t_img)
        extract_imgFtr_ResNet_input[image_name[0]] = embdg
    return extract_imgFtr_ResNet_input

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=max_seq_len):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Standard sinusoidal positional encodings, stored as a (1, max_len, d_model) buffer.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model); add the encodings for the first seq_len positions.
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class ImageCaptionModel(nn.Module):
    def __init__(self, n_head, n_decoder_layer, vocab_size, embedding_size):
        super(ImageCaptionModel, self).__init__()
        self.pos_encoder = PositionalEncoding(embedding_size, 0.1)
        self.TransformerDecoderLayer = nn.TransformerDecoderLayer(d_model=embedding_size, nhead=n_head)
        self.TransformerDecoder = nn.TransformerDecoder(decoder_layer=self.TransformerDecoderLayer, num_layers=n_decoder_layer)
        self.embedding_size = embedding_size
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.last_linear_layer = nn.Linear(embedding_size, vocab_size)
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.last_linear_layer.bias.data.zero_()
        self.last_linear_layer.weight.data.uniform_(-initrange, initrange)

    def generate_Mask(self, size, decoder_inp):
        # Causal mask: position i may only attend to positions <= i.
        decoder_input_mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        decoder_input_mask = decoder_input_mask.float().masked_fill(decoder_input_mask == 0, float('-inf')).masked_fill(decoder_input_mask == 1, float(0.0))
        # Padding masks: token index 0 is treated as padding.
        decoder_input_pad_mask = decoder_inp.float().masked_fill(decoder_inp == 0, float(0.0)).masked_fill(decoder_inp > 0, float(1.0))
        decoder_input_pad_mask_bool = decoder_inp == 0
        return decoder_input_mask, decoder_input_pad_mask, decoder_input_pad_mask_bool

    def forward(self, encoded_image, decoder_inp):
        # encoded_image: (batch, 49, embedding_size) -> (49, batch, embedding_size) for the decoder memory.
        encoded_image = encoded_image.permute(1, 0, 2)
        decoder_inp_embed = self.embedding(decoder_inp) * math.sqrt(self.embedding_size)
        decoder_inp_embed = self.pos_encoder(decoder_inp_embed)
        decoder_inp_embed = decoder_inp_embed.permute(1, 0, 2)
        decoder_input_mask, decoder_input_pad_mask, decoder_input_pad_mask_bool = self.generate_Mask(decoder_inp.size(1), decoder_inp)
        decoder_input_mask = decoder_input_mask.to(device)
        decoder_input_pad_mask = decoder_input_pad_mask.to(device)
        decoder_input_pad_mask_bool = decoder_input_pad_mask_bool.to(device)
        decoder_output = self.TransformerDecoder(tgt=decoder_inp_embed, memory=encoded_image, tgt_mask=decoder_input_mask, tgt_key_padding_mask=decoder_input_pad_mask_bool)
        final_output = self.last_linear_layer(decoder_output)
        return final_output, decoder_input_pad_mask
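
# Caption generation: starting from <start>, the decoder is run once per position
# and the next token is sampled from the top-K candidates (after softmax) until
# <end> is produced or the maximum sequence length is reached.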
def generate_caption(K, img_nm, extract_imgFtr_ResNet_input):
    model.eval()
    # Flatten the (1, 512, 7, 7) feature map into a (1, 49, 512) sequence of patch vectors.
    img_embed = extract_imgFtr_ResNet_input[img_nm].to(device)
    img_embed = img_embed.permute(0, 2, 3, 1)
    img_embed = img_embed.view(img_embed.size(0), -1, img_embed.size(3))

    # Start from a sequence of <pad> tokens with <start> in the first position.
    input_seq = [pad_token] * max_seq_len
    input_seq[0] = start_token
    input_seq = torch.tensor(input_seq).unsqueeze(0).to(device)

    predicted_sentence = []
    with torch.no_grad():
        for eval_iter in range(0, max_seq_len - 1):
            output, _ = model(img_embed, input_seq)
            output = output[eval_iter, 0, :]
            # Sample the next token from the top-K candidates, weighted by probability.
            probs = torch.softmax(output, dim=-1)
            values = torch.topk(probs, K).values.tolist()
            indices = torch.topk(probs, K).indices.tolist()
            next_word_index = random.choices(indices, values, k=1)[0]
            next_word = index_to_word[next_word_index]
            input_seq[:, eval_iter + 1] = next_word_index
            if next_word == '<end>':
                break
            predicted_sentence.append(next_word)
    return " ".join(predicted_sentence) + "."

# Load the trained captioning model and look up the special token indices.
model = torch.load('./BestModel_20000_Datos', map_location=device)
start_token = word_to_index['<start>']
end_token = word_to_index['<end>']
pad_token = word_to_index['<pad>']

def predict(inp):
    # The vocabulary, ResNet encoder and captioning model are already loaded above,
    # so a prediction only needs to extract features for the uploaded image and decode.
    df = create_df(inp)
    extract_imgFtr_ResNet_input = feature_exctractor(df)
    prediction = generate_caption(1, inp, extract_imgFtr_ResNet_input)
    return prediction

gr.Interface(fn=predict,
             inputs=gr.Image(type="filepath"),
             outputs=gr.Text(),
             title="Clothing captioning model",
             description="A clothing image captioning model that describes the garments in a photo.\nTake a picture of your clothes with your phone, upload it, and you are ready to go.").launch()
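
# Note: the script expects BestModel_20000_Datos, index_to_word.pkl and
# word_to_index.pkl in the working directory (the paths used above). Assuming the
# file is saved as app.py, the demo can also be launched locally with `python app.py`.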