MO3ALIMI / pages /utils.py
mouadenna's picture
Upload 34 files
1da5a00 verified
raw
history blame
2.06 kB
import os
import cv2
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline
import torch
# Models are loaded once at import time and shared by every function below.
# Sentence encoder used to embed both the query text and image file names.
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')

# Choose device and precision together: half precision is only requested
# when CUDA is present, and the pipeline is no longer moved to "cuda"
# unconditionally (which raises on hosts without a GPU).
_USE_CUDA = torch.cuda.is_available()
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if _USE_CUDA else torch.float32,
)
pipeline.to("cuda" if _USE_CUDA else "cpu")
# Step 3: Function to get the embedding of the input sentence
def get_sentence_embedding(sentence):
    """Encode ``sentence`` with the module-level ``sentence_model``.

    Args:
        sentence: Text to embed; forwarded unchanged to
            ``sentence_model.encode``.

    Returns:
        Whatever ``sentence_model.encode`` returns for the given input.
    """
    embedding = sentence_model.encode(sentence)
    return embedding
# Step 4: Generate image using Stable Diffusion if needed
def generate_image(prompt):
    """Generate an image for ``prompt`` and save it to disk.

    The module-level ``pipeline`` is moved to CUDA when available and to
    the CPU otherwise, then invoked with the prompt. The first image of
    the result is written to a fixed file name in the current working
    directory, so each call overwrites the previous output.

    Args:
        prompt: Text prompt passed to the pipeline.

    Returns:
        The path of the saved image (``"generated_image.png"``).
    """
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    pipeline.to(target_device)

    result = pipeline(prompt)
    output_path = "generated_image.png"
    result.images[0].save(output_path)
    return output_path
# Step 5: Find the most reliable image
def find_most_reliable_image(folder_path, input_sentence, threshold=0.5):
    """Return the image whose file name best matches ``input_sentence``.

    Every image file in ``folder_path`` is scored by the cosine similarity
    between the embedding of its file name (without extension) and the
    embedding of ``input_sentence``. If the best score is below
    ``threshold`` (or the folder holds no images), a new image is
    generated from the sentence instead.

    Args:
        folder_path: Directory to search for candidate images.
        input_sentence: Text describing the wanted image.
        threshold: Minimum similarity for an existing image to be accepted.

    Returns:
        Path to the chosen existing image, or the path returned by
        ``generate_image`` when no candidate reaches ``threshold``.

    Raises:
        OSError: If ``folder_path`` cannot be listed.
    """
    # Match on real, case-insensitive extensions so "CAT.JPG" is included
    # and a file merely ending in the letters "png" is not.
    image_extensions = ('.jpg', '.jpeg', '.png')
    # Sorted so that ties are broken the same way on every platform.
    image_files = sorted(
        name for name in os.listdir(folder_path)
        if name.lower().endswith(image_extensions)
    )

    best_similarity = -1
    best_image = None
    if image_files:
        sentence_embedding = get_sentence_embedding(input_sentence)
        stems = [os.path.splitext(name)[0] for name in image_files]
        # Encode all file names in one batched call instead of one model
        # invocation per file; the encoder accepts a list of sentences.
        filename_embeddings = get_sentence_embedding(stems)
        similarities = cosine_similarity(
            [sentence_embedding], filename_embeddings
        )[0]
        # argmax returns the first maximum, so the earliest best match wins.
        best_index = int(np.argmax(similarities))
        best_similarity = similarities[best_index]
        best_image = os.path.join(folder_path, image_files[best_index])

    if best_similarity < threshold:
        best_image = generate_image(input_sentence)
    return best_image
def findImg(input_sentence):
    """Find (or generate) the image that best fits ``input_sentence``.

    Searches the ``images_collection`` folder with a similarity threshold
    of 0.5 by delegating to ``find_most_reliable_image``.

    Args:
        input_sentence: Text describing the wanted image.

    Returns:
        The path returned by ``find_most_reliable_image``.
    """
    return find_most_reliable_image('images_collection', input_sentence, 0.5)