YaTharThShaRma999 commited on
Commit
9747ab6
1 Parent(s): 75ccb56

Create similar_search.py

Browse files
Files changed (1) hide show
  1. modules/similar_search.py +23 -0
modules/similar_search.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sentence_transformers import SentenceTransformer, util
2
+
3
+ class Similarity:
4
+ # Class variable to store the loaded model
5
+ embedder = None
6
+
7
+ def __init__(self, model_path="all-MiniLM-L6-v2", device="cuda:0"):
8
+ self.device = device
9
+ # Load the model if not already loaded
10
+ if Similarity.embedder is None:
11
+ Similarity.embedder = SentenceTransformer(model_path).to(device)
12
+
13
+ def infer(self, queries, corpus, top=1):
14
+ corpus_embeddings = Similarity.embedder.encode(corpus, convert_to_tensor=True).to(self.device)
15
+ top_k = min(top, len(corpus))
16
+ results=[]
17
+ for query in queries:
18
+ query_embedding = Similarity.embedder.encode(query, convert_to_tensor=True).to(self.device)
19
+ cos_scores = util.cos_sim(query_embedding, corpus_embeddings)[0]
20
+ top_results = torch.topk(cos_scores, k=top_k)
21
+ for score, idx in zip(top_results[0], top_results[1]):
22
+ results.append(corpus[idx])
23
+ return results