newer_project / modules /similar_search.py
YaTharThShaRma999's picture
Create similar_search.py
9747ab6 verified
raw
history blame
No virus
1.02 kB
from sentence_transformers import SentenceTransformer, util
class Similarity:
    """Semantic nearest-neighbour lookup over a text corpus.

    The sentence-embedding model is stored on the class, so it is loaded
    at most once per process and shared by every instance.
    """

    # Class variable to store the loaded model (shared by all instances).
    embedder = None

    def __init__(self, model_path="all-MiniLM-L6-v2", device="cuda:0"):
        """Remember the target device and load the shared model on first use.

        Args:
            model_path: Model name or path handed to ``SentenceTransformer``.
            device: Device string the model and the embeddings are moved to.

        Note:
            The model is only loaded while ``Similarity.embedder`` is
            ``None``. Later instances reuse the already-loaded model, so
            their ``model_path`` is ignored and the model is not moved to
            their ``device``.
        """
        self.device = device
        # Load the model if not already loaded
        if Similarity.embedder is None:
            Similarity.embedder = SentenceTransformer(model_path).to(device)

    def infer(self, queries, corpus, top=1):
        """Return the corpus entries most similar to each query.

        Args:
            queries: A single query string or an iterable of query strings.
            corpus: Sequence of candidate strings to search.
            top: Maximum number of matches per query; capped at
                ``len(corpus)``.

        Returns:
            A flat list of corpus entries: the best matches for the first
            query (most similar first), then those for the second query,
            and so on. Empty when there are no queries, the corpus is
            empty, or ``top`` is not positive.
        """
        # A bare string would otherwise be iterated character by character.
        if isinstance(queries, str):
            queries = [queries]
        else:
            queries = list(queries)

        top_k = min(top, len(corpus))
        # Nothing to rank: return before touching the model at all.
        if not queries or top_k <= 0:
            return []

        embedder = Similarity.embedder
        corpus_embeddings = embedder.encode(
            corpus, convert_to_tensor=True
        ).to(self.device)
        # Encode every query in one batch instead of one call per query.
        query_embeddings = embedder.encode(
            queries, convert_to_tensor=True
        ).to(self.device)

        # One row of similarity scores per query, one column per corpus entry.
        cos_scores = util.cos_sim(query_embeddings, corpus_embeddings)
        # Tensor.topk avoids needing the ``torch`` module name, which this
        # file never imports.
        top_indices = cos_scores.topk(k=top_k, dim=1).indices

        # tolist() yields plain ints, so any indexable corpus works.
        return [corpus[idx] for row in top_indices.tolist() for idx in row]