# app.py
import os
import pickle

import gradio as gr
import numpy as np
from transformers import CLIPProcessor, CLIPModel

# Pillow must be installed: PIL_images.pkl unpickles to PIL.Image objects.
# FashionCLIP: a CLIP checkpoint fine-tuned on fashion product data.
model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
processor = CLIPProcessor.from_pretrained("patrickjohncyh/fashion-clip")
# Optional: load the full dataset from the Hub instead of the sampled pickles.
# from datasets import load_dataset
# hf_token = os.environ.get("HF_API_TOKEN")
# dataset = load_dataset('pratyush19/cyborg', use_auth_token=hf_token, split='train')
# dir_path = "train/"
# Precomputed artifacts: image identifiers, their FashionCLIP embeddings,
# and the PIL images shown in the results grid.
with open('valid_images_sample.pkl', 'rb') as f:
    valid_images = pickle.load(f)
with open('image_encodings_sample.pkl', 'rb') as f:
    image_encodings = pickle.load(f)
valid_images = np.array(valid_images)
with open('PIL_images.pkl', 'rb') as f:
    PIL_images = pickle.load(f)
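
# A minimal sketch (never called by the app) of how the cached encodings could
# be regenerated, assuming `image_encodings_sample.pkl` holds FashionCLIP image
# features for `PIL_images`. `encode_images` and its batch size are
# illustrative, not part of the original pipeline.
def encode_images(pil_images, batch_size=32):
    """Return an (N, D) numpy array of CLIP image features."""
    feats = []
    for i in range(0, len(pil_images), batch_size):
        batch = pil_images[i:i + batch_size]
        inputs = processor(images=batch, return_tensors="pt")
        feats.append(model.get_image_features(**inputs).detach().numpy())
    return np.concatenate(feats, axis=0)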

def softmax(x):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    e_x = np.exp(x - np.max(x, axis=1, keepdims=True))
    return e_x / e_x.sum(axis=1, keepdims=True)

def find_similar_images(caption, image_encodings):
    """Score every image against the caption using FashionCLIP text features."""
    inputs = processor(text=[caption], return_tensors="pt")
    text_features = model.get_text_features(**inputs).detach().numpy()
    # (1, N) text-to-image similarity scores, squashed into probabilities.
    return softmax(np.dot(text_features, image_encodings.T))
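
# CLIP retrieval typically uses cosine similarity (the dot product of
# L2-normalized features). If the cached `image_encodings` are not already
# unit-norm, a variant like the sketch below ranks by angle instead of raw
# magnitude. `cosine_scores` is an illustrative helper, not used by the app.
def cosine_scores(text_features, image_encodings):
    t = text_features / np.linalg.norm(text_features, axis=1, keepdims=True)
    im = image_encodings / np.linalg.norm(image_encodings, axis=1, keepdims=True)
    return t @ im.T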

def find_relevant_images(caption):
    """Return the 16 images most similar to the caption."""
    similarity_scores = find_similar_images(caption, image_encodings)[0]
    top_indices = np.argsort(similarity_scores)[::-1][:16]
    # top_path = valid_images[top_indices]
    return [PIL_images[idx] for idx in top_indices]
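
# Hedged alternative: np.argsort above orders all N scores just to pick 16.
# For a large gallery, np.argpartition finds the top-k without a full sort;
# `top_k_indices` is an illustrative sketch, not wired into the app.
def top_k_indices(scores, k=16):
    top = np.argpartition(scores, -k)[-k:]        # unordered top-k
    return top[np.argsort(scores[top])[::-1]]     # order just those k, descending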

def gradio_interface(input_text):
    # with open("user_inputs.txt", "a") as file:
    #     file.write(input_text + "\n")
    return find_relevant_images(input_text)

def clear_inputs():
    # One None per output component: the textbox plus the 16 result images.
    return [None] * 17

outputs = [None] * 16  # placeholders for the 4x4 grid of gr.Image components

with gr.Blocks(title="MirrAI") as demo:
    gr.Markdown("<h1 style='text-align: center;'>MirrAI: GenAI-based Fashion Search</h1>")
    gr.Markdown("Enter a text query to find the most relevant images from our dataset.")
    text_input = gr.Textbox(lines=1, label="Input Text", placeholder="Enter your text here...")
    with gr.Row():
        cancel_button = gr.Button("Cancel")
        submit_button = gr.Button("Submit")
    examples = gr.Examples(
        [
            "high-rise flare jean",
            "a-line dress with floral",
            "men colorful blazers",
            "jumpsuit with puffed sleeve",
            "sleeveless sweater",
            "floral shirt",
            "blue asymmetrical wedding dress with one sleeve",
            "women long coat",
            "cardigan sweater",
        ],
        inputs=[text_input],
    )
    # 4x4 grid of result images.
    for row in range(4):
        with gr.Row():
            for col in range(4):
                outputs[row * 4 + col] = gr.Image()
    submit_button.click(
        fn=gradio_interface,
        inputs=text_input,
        outputs=outputs,
    )
    cancel_button.click(
        fn=clear_inputs,
        inputs=None,
        outputs=[text_input] + outputs,
    )

# share=True requests a public gradio.live link; on Hugging Face Spaces the flag is ignored.
demo.launch(share=True)