import gradio as gr
import pickle
import os
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset  # used only by the commented-out loading below
from PIL import Image  # PIL must be importable to unpickle the stored images

# FashionCLIP: a CLIP model fine-tuned on fashion product data.
model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")
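
# Sketch (commented out; not the code that produced the pickles): the image
# encodings loaded below were presumably precomputed offline along these
# lines, with `pil_batch` as a placeholder for a batch of catalog images:
# inputs = processor(images=pil_batch, return_tensors="pt")
# image_feats = model.get_image_features(**inputs).detach().numpy()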

# Original dataset loading, kept disabled; the Space runs off the pickles below.
# hf_token = os.environ.get("HF_API_TOKEN")
# dataset = load_dataset('pratyush19/cyborg', use_auth_token=hf_token, split='train')
# dir_path = "train/"
# print(dataset)
# print(dataset[0].keys())
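
# Precomputed artifacts: image identifiers, their FashionCLIP image
# embeddings, and the PIL images that are served back to the UI.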
with open('valid_images_sample.pkl', 'rb') as f:
    valid_images = pickle.load(f)
with open('image_encodings_sample.pkl', 'rb') as f:
    image_encodings = pickle.load(f)
valid_images = np.array(valid_images)
with open('PIL_images.pkl', 'rb') as f:
    PIL_images = pickle.load(f)


def softmax(x):
    # Numerically stable softmax over each row: subtracting the row max
    # before exponentiating avoids overflow.
    e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
    return e_x / e_x.sum(axis=1, keepdims=True)


def find_similar_images(caption, image_encodings):
    # Embed the query text and score it against every precomputed image
    # embedding. The features are not L2-normalized, so this is a softmax
    # over raw dot products rather than cosine similarities.
    inputs = processor(text=[caption], return_tensors="pt")
    text_features = model.get_text_features(**inputs).detach().numpy()
    return softmax(np.dot(text_features, image_encodings.T))


def find_relevant_images(caption):
    # Return the 16 best-matching images, most similar first.
    similarity_scores = find_similar_images(caption, image_encodings)[0]
    top_indices = np.argsort(similarity_scores)[::-1][:16]
    # top_path = valid_images[top_indices]
    return [PIL_images[idx] for idx in top_indices]
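
# Example (not executed here): find_relevant_images("floral shirt") returns a
# list of 16 PIL images ordered from most to least similar to the query.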


def gradio_interface(input_text):
    # Optional query logging, kept disabled:
    # with open("user_inputs.txt", "a") as file:
    #     file.write(input_text + "\n")
    return find_relevant_images(input_text)


def clear_inputs():
    # One None per output component being cleared: the textbox plus the
    # 16 image slots.
    return [None] * 17


outputs = [None] * 16

with gr.Blocks(title="MirrAI") as demo:
    gr.Markdown("<h1 style='text-align: center;'>MirrAI: GenAI-based Fashion Search</h1>")
    gr.Markdown("Enter a text description to find the most relevant images from our dataset.")
    text_input = gr.Textbox(lines=1, label="Input Text", placeholder="Enter your text here...")
    with gr.Row():
        cancel_button = gr.Button("Cancel")
        submit_button = gr.Button("Submit")
    examples = gr.Examples(["high-rise flare jean",
                            "a-line dress with floral",
                            "men colorful blazers",
                            "jumpsuit with puffed sleeve",
                            "sleeveless sweater",
                            "floral shirt",
                            "blue asymmetrical wedding dress with one sleeve",
                            "women long coat",
                            "cardigan sweater"], inputs=[text_input])
    # 4x4 grid of result images.
    for row in range(4):
        with gr.Row():
            for col in range(4):
                outputs[4 * row + col] = gr.Image()
    submit_button.click(
        fn=gradio_interface,
        inputs=text_input,
        outputs=outputs
    )
    cancel_button.click(
        fn=clear_inputs,
        inputs=None,
        outputs=[text_input] + outputs
    )

demo.launch(share=True)