File size: 2,057 Bytes
1da5a00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import cv2
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline
import torch


# Choose the inference device once. Half-precision inference is poorly
# supported on CPU, so only use float16 when CUDA is actually available;
# otherwise load float32 weights and stay on CPU instead of crashing at import.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_DTYPE = torch.float16 if _DEVICE == "cuda" else torch.float32

# Shared models, loaded once at import time and reused by the helpers below.
sentence_model = SentenceTransformer('all-MiniLM-L6-v2')
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=_DTYPE
)
pipeline.to(_DEVICE)

# Step 3: Embed text with the shared sentence model
def get_sentence_embedding(sentence):
    """Encode *sentence* with the module-level SentenceTransformer.

    Returns exactly what ``sentence_model.encode`` returns for the input.
    """
    embedding = sentence_model.encode(sentence)
    return embedding
# Step 4: Generate an image with Stable Diffusion when no stored image matches
def generate_image(prompt, output_path="generated_image.png"):
    """Generate an image for *prompt* with the shared pipeline and save it.

    Args:
        prompt: Text prompt passed to the Stable Diffusion pipeline.
        output_path: File the generated image is saved to. Defaults to
            ``"generated_image.png"`` (relative to the current working
            directory), so repeated calls with the default overwrite the
            same file.

    Returns:
        ``output_path``, the location of the saved image.
    """
    # Make sure the shared pipeline sits on the best available device before
    # running inference (a no-op when it is already there).
    pipeline.to("cuda" if torch.cuda.is_available() else "cpu")
    # The pipeline output exposes a list of images; only the first is used.
    generated_image = pipeline(prompt).images[0]
    generated_image.save(output_path)
    return output_path

# Step 5: Find the most reliable image
# Recognised image extensions, compared case-insensitively and including the dot.
_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')


def _is_image_file(filename):
    """Return True if *filename* ends with a supported image extension."""
    return filename.lower().endswith(_IMAGE_EXTENSIONS)


def find_most_reliable_image(folder_path, input_sentence, threshold=0.5):
    """Return the image in *folder_path* whose filename best matches a sentence.

    Matching is purely textual: each image's filename (without extension) is
    embedded and compared to *input_sentence* by cosine similarity; image
    content is not inspected. If the folder holds no images, or the best
    similarity is below *threshold*, a new image is generated from
    *input_sentence* instead.

    Args:
        folder_path: Directory to scan (non-recursively) for .jpg/.jpeg/.png
            files; the extension check is case-insensitive.
        input_sentence: Text to match against the image filenames.
        threshold: Minimum cosine similarity for a stored image to be used.

    Returns:
        Path of the best-matching stored image, or the path returned by
        ``generate_image`` when no stored image is good enough.

    Raises:
        FileNotFoundError: If *folder_path* does not exist (from os.listdir).
    """
    # Sort so ties are broken deterministically (os.listdir order is
    # arbitrary), and skip directories that merely look like image files.
    image_files = sorted(
        name
        for name in os.listdir(folder_path)
        if _is_image_file(name) and os.path.isfile(os.path.join(folder_path, name))
    )
    if not image_files:
        return generate_image(input_sentence)

    sentence_embedding = get_sentence_embedding(input_sentence)
    # Encode all filename stems in one batch: the underlying encode() accepts
    # a list of strings and returns one row per string.
    stems = [os.path.splitext(name)[0] for name in image_files]
    filename_embeddings = get_sentence_embedding(stems)
    similarities = cosine_similarity([sentence_embedding], filename_embeddings)[0]

    # np.argmax returns the first maximum, matching the original strict ">" scan.
    best_index = int(np.argmax(similarities))
    if similarities[best_index] < threshold:
        return generate_image(input_sentence)
    return os.path.join(folder_path, image_files[best_index])

def findImg(input_sentence, folder_path='images_collection', threshold=0.5):
    """Return the path of the best image for *input_sentence*.

    Convenience wrapper around ``find_most_reliable_image``. The folder and
    threshold used to be hard-coded; they are now keyword parameters with the
    same defaults, so existing ``findImg(sentence)`` calls behave as before.

    Args:
        input_sentence: Text describing the desired image.
        folder_path: Directory of candidate images (relative paths resolve
            against the current working directory).
        threshold: Minimum cosine similarity for a stored image to be used
            instead of generating a new one.

    Returns:
        Path of the chosen stored image or of a newly generated one.
    """
    return find_most_reliable_image(folder_path, input_sentence, threshold)