"""Gradio demo: text-to-image vector search over precomputed CLIP image embeddings."""

import os
import pickle

import gradio as gr
import sentence_transformers
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer

# Shapes recorded by the original author during a debugging session:
#   query_emb.shape -> torch.Size([1, 512])
#   img_emb.shape   -> (24996, 512)
# NOTE(review): these were captured for one particular embedding file; the
# row count of img_emb depends on which pickle is loaded below.

## Define model
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

# Directory that the file names stored in the embeddings pickle are relative to.
IMAGE_DIR = "lvphotos/"

# Open the precomputed embeddings.
# Alternative dataset used previously: 'unsplash-25k-photos-embeddings.pkl'
# (with images under "photos/").
emb_filename = "lv-handbags.pkl"
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load embedding files that were produced by a trusted source.
with open(emb_filename, "rb") as fIn:
    img_names, img_emb = pickle.load(fIn)


def search_text(query: str, top_k: int = 4) -> list:
    """Search the image collection for pictures matching a text query.

    The query is embedded with the CLIP text encoder and compared against the
    precomputed image embeddings using ``util.semantic_search``.

    Args:
        query (str): Free-text description of what to look for.
        top_k (int, optional): Number of images to return. Defaults to 4.

    Returns:
        list: PIL images for the best-matching entries, in the order
        ``util.semantic_search`` ranks them (highest score first).
    """
    # First, we encode the query. Gradients are never needed for a search, so
    # disable autograd to avoid building a graph on every request.
    inputs = tokenizer([query], padding=True, return_tensors="pt")
    with torch.no_grad():
        query_emb = model.get_text_features(**inputs)

    # util.semantic_search computes the cosine similarity between the query
    # embedding and all image embeddings and returns, per query, the top_k
    # highest-ranked hits. There is a single query, hence the [0].
    hits = util.semantic_search(query_emb, img_emb, top_k=top_k)[0]

    # Each hit's "corpus_id" indexes into img_names, which holds the file name
    # of the corresponding image.
    return [
        Image.open(os.path.join(IMAGE_DIR, img_names[hit["corpus_id"]]))
        for hit in hits
    ]


# Build the interface first so that `iface` refers to the Interface object
# itself (chaining .launch() would bind the launch() return value instead).
iface = gr.Interface(
    title="Hushh Vibe Search Model on Louis Vuitton API",
    description="Quick demo of using text to perform vector search on an image collection",
    article="TBD",
    fn=search_text,
    inputs=[
        gr.Textbox(
            lines=4,
            label="Write what you are looking for in an image...",
            placeholder="Text Here...",
        )
    ],
    outputs=[
        gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=2,
        )
    ],
    examples=[
        ["Vacation"],
        ["Rock Star"],
        ["Barbie"],
        ["Small Purse"],
        ["Big Bag"],
        ["Shoes that won't make me look fat"],
    ],
)
iface.launch(debug=True)