Spaces:
Sleeping
Sleeping
import torch | |
# usersim_path_shoes = "http://www.dcs.gla.ac.uk/~craigm/fcrs/model_checkpoints/caption_model_shoes" | |
# usersim_path_dresses = "http://www.dcs.gla.ac.uk/~craigm/fcrs/captioners/dresses_cap_caption_models" | |
# Root directory holding the user-simulator resources (captioner checkpoints
# and ImageNet weights).
drive_path = 'mmir_usersim_resources/'

# Fashion categories served by the demo. Order matters: entry i of
# `usersim_path` below is the checkpoint directory for category i.
data_type = ["shoes", "dresses", "shirts", "tops&tees"]

# Per-category captioner checkpoint directories (the tops&tees folder name
# drops the "&").
usersim_path_shoes = f"{drive_path}checkpoints_usersim/shoes"
usersim_path_dresses = f"{drive_path}checkpoints_usersim/dresses"
usersim_path_shirts = f"{drive_path}checkpoints_usersim/shirts"
usersim_path_topstees = f"{drive_path}checkpoints_usersim/topstees"
usersim_path = [
    usersim_path_shoes,
    usersim_path_dresses,
    usersim_path_shirts,
    usersim_path_topstees,
]
import captioning.captioner as captioner | |
# Feature-extractor settings shared by all four relative captioners.
image_feat_params = {
    'model': 'resnet101',
    'model_root': drive_path + 'imagenet_weights',
    'att_size': 7,
}
# Alternative config with an empty model_root (kept for reference):
# image_feat_params = {'model':'resnet101','model_root':'','att_size':7}

# One relative captioner per category, constructed in the order of
# `usersim_path` / `data_type` (shoes, dresses, shirts, tops&tees).
# NOTE(review): each instance is created with load_resnet=True, so the
# backbone is presumably loaded four times — confirm whether it can be shared.
(
    captioner_relative_shoes,
    captioner_relative_dresses,
    captioner_relative_shirts,
    captioner_relative_topstees,
) = (
    captioner.Captioner(
        is_relative=True,
        model_path=checkpoint_dir,
        image_feat_params=image_feat_params,
        data_type=category,
        load_resnet=True,
    )
    for checkpoint_dir, category in zip(usersim_path, data_type)
)
def _generate_relative_sentence(model, image_path_1, image_path_2):
    """Generate a relative caption describing image 1 with respect to image 2.

    Args:
        model: A ``captioner.Captioner`` built with ``is_relative=True``.
        image_path_1: File path of the target image.
        image_path_2: File path of the reference (candidate) image.

    Returns:
        The first sentence produced by ``model.gen_caption_from_feat``.

    Raises:
        gr.Error: If either path is missing (e.g. "Generate" was pressed
            before both images were uploaded).
    """
    if not image_path_1 or not image_path_2:
        # gr.Image(type="filepath") hands over None for an empty slot; show a
        # readable message in the UI instead of an opaque model-side crash.
        raise gr.Error("Please provide both a target image and a candidate image.")

    # Pure inference: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        fc_feat, att_feat = model.get_img_feat(image_path_1)
        fc_feat_ref, att_feat_ref = model.get_img_feat(image_path_2)
        # Add a leading batch axis of size 1; the features appear to be
        # unbatched as returned by get_img_feat (TODO confirm exact shapes).
        target_feats = (torch.unsqueeze(fc_feat, dim=0), torch.unsqueeze(att_feat, dim=0))
        reference_feats = (
            torch.unsqueeze(fc_feat_ref, dim=0),
            torch.unsqueeze(att_feat_ref, dim=0),
        )
        _, sents = model.gen_caption_from_feat(target_feats, reference_feats)
    return sents[0]


def generate_sentence_shoes(image_path_1, image_path_2):
    """Return a relative caption for two shoe images (target vs. candidate)."""
    return _generate_relative_sentence(captioner_relative_shoes, image_path_1, image_path_2)


def generate_sentence_dresses(image_path_1, image_path_2):
    """Return a relative caption for two dress images (target vs. candidate)."""
    return _generate_relative_sentence(captioner_relative_dresses, image_path_1, image_path_2)


def generate_sentence_shirts(image_path_1, image_path_2):
    """Return a relative caption for two shirt images (target vs. candidate)."""
    return _generate_relative_sentence(captioner_relative_shirts, image_path_1, image_path_2)


def generate_sentence_topstees(image_path_1, image_path_2):
    """Return a relative caption for two tops&tees images (target vs. candidate)."""
    return _generate_relative_sentence(captioner_relative_topstees, image_path_1, image_path_2)
import numpy as np | |
import gradio as gr | |
# Example [target image, candidate image] pairs offered under each tab; each
# inner list fills that tab's two gr.Image inputs, in order.
examples_shoes = [
    ["images/shoes/" + target, "images/shoes/" + candidate]
    for target, candidate in (
        ("img_womens_athletic_shoes_1223.jpg", "img_womens_athletic_shoes_830.jpg"),
        ("img_womens_athletic_shoes_830.jpg", "img_womens_athletic_shoes_1223.jpg"),
        ("img_womens_high_heels_559.jpg", "img_womens_high_heels_690.jpg"),
        ("img_womens_high_heels_690.jpg", "img_womens_high_heels_559.jpg"),
    )
]
examples_dresses = [
    ["images/dresses/" + target, "images/dresses/" + candidate]
    for target, candidate in (
        ("B007UZSPC8.jpg", "B006MPVW4U.jpg"),
        ("B005KMQQFQ.jpg", "B005QYY5W4.jpg"),
        ("B005OBAGD6.jpg", "B006U07GW4.jpg"),
        ("B0047Y0K0U.jpg", "B006TAM4CW.jpg"),
    )
]
examples_shirts = [
    ["images/shirts/" + target, "images/shirts/" + candidate]
    for target, candidate in (
        ("B00305G9I4.jpg", "B005BLUUJY.jpg"),
        ("B004WSVYX8.jpg", "B008TP27PY.jpg"),
        ("B003INE0Q6.jpg", "B0051D0X2Q.jpg"),
        ("B00EZUKCCM.jpg", "B00B88ZKXA.jpg"),
    )
]
examples_topstees = [
    ["images/topstees/" + target, "images/topstees/" + candidate]
    for target, candidate in (
        ("B0082993AO.jpg", "B008293HO2.jpg"),
        ("B006YN4J2C.jpg", "B0035EPUBW.jpg"),
        ("B00B5SKOMU.jpg", "B004H3XMYM.jpg"),
        ("B008DVXGO0.jpg", "B008JYNN30.jpg"),
    )
]
def _build_category_tab(title, examples):
    """Lay out one category tab and return its interactive components.

    Must be called inside an active ``gr.Blocks`` context.

    Args:
        title: Tab label shown in the UI.
        examples: List of [target_path, candidate_path] example pairs.

    Returns:
        Tuple ``(target_image, candidate_image, output_textbox, generate_button)``.
    """
    with gr.Tab(title):
        with gr.Row():
            target_image = gr.Image(source="upload", type="filepath", label="Target Image")
            candidate_image = gr.Image(source="upload", type="filepath", label="Candidate Image")
        sentence_box = gr.Textbox(label="Generated Sentence")
        generate_button = gr.Button("Generate")
        # No fn is given, so clicking an example only fills the two image inputs.
        gr.Examples(examples, inputs=[target_image, candidate_image])
    return target_image, candidate_image, sentence_box, generate_button


with gr.Blocks() as demo:
    gr.Markdown("Relative Captioning for Fashion.")

    target_shoes, candidate_shoes, output_text_shoes, shoes_btn = _build_category_tab(
        "Shoes", examples_shoes
    )
    target_dresses, candidate_dresses, output_text_dresses, dresses_btn = _build_category_tab(
        "Dresses", examples_dresses
    )
    target_shirts, candidate_shirts, output_text_shirts, shirts_btn = _build_category_tab(
        "Shirts", examples_shirts
    )
    target_topstees, candidate_topstees, output_text_topstees, topstees_btn = _build_category_tab(
        "Tops&Tees", examples_topstees
    )

    # Route each tab's button to its category-specific captioning function.
    shoes_btn.click(generate_sentence_shoes, inputs=[target_shoes, candidate_shoes], outputs=output_text_shoes)
    dresses_btn.click(generate_sentence_dresses, inputs=[target_dresses, candidate_dresses], outputs=output_text_dresses)
    shirts_btn.click(generate_sentence_shirts, inputs=[target_shirts, candidate_shirts], outputs=output_text_shirts)
    topstees_btn.click(generate_sentence_topstees, inputs=[target_topstees, candidate_topstees], outputs=output_text_topstees)

# NOTE(review): `concurrency_count` and gr.Image(source=...) are Gradio 3.x
# APIs; both changed in Gradio 4 — pin the gradio version accordingly.
demo.queue(concurrency_count=3)
demo.launch()