File size: 1,020 Bytes
9747ab6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
from sentence_transformers import SentenceTransformer, util

class Similarity:
    """Return the corpus entries most similar to each query.

    Wraps a ``SentenceTransformer`` embedder and ranks corpus entries by the
    cosine similarity of their embeddings to each query's embedding.
    """

    # Embedder loaded by the first instance; still populated so code that
    # reads ``Similarity.embedder`` directly keeps working.
    embedder = None
    # Deliberately shared (class-level) cache of loaded embedders keyed by
    # (model_path, device), so each model is loaded at most once per device.
    _embedders = {}

    def __init__(self, model_path="all-MiniLM-L6-v2", device="cuda:0"):
        """Load, or reuse from the cache, the embedder for ``model_path``.

        Args:
            model_path: Model identifier or path passed to ``SentenceTransformer``.
            device: Torch device string the model and embeddings are moved to.
        """
        self.device = device
        key = (model_path, device)
        embedder = Similarity._embedders.get(key)
        if embedder is None:
            embedder = SentenceTransformer(model_path).to(device)
            Similarity._embedders[key] = embedder
        # Per-instance handle: a single class-wide slot would make every
        # instance use whichever model was loaded first, regardless of the
        # model_path/device it was constructed with.
        self.embedder = embedder
        if Similarity.embedder is None:
            Similarity.embedder = embedder

    def infer(self, queries, corpus, top=1):
        """Return the ``top`` most similar corpus entries for every query.

        Args:
            queries: Iterable of query strings.
            corpus: Sequence of candidate strings to rank.
            top: Number of matches wanted per query; capped at ``len(corpus)``.

        Returns:
            A flat list containing, for each query in order, its best-matching
            corpus entries in descending similarity order. Returns ``[]``
            without encoding anything when ``queries`` or ``corpus`` is empty
            or ``top`` < 1.
        """
        queries = list(queries)
        top_k = min(top, len(corpus))
        if not queries or top_k <= 0:
            return []
        corpus_embeddings = self.embedder.encode(corpus, convert_to_tensor=True).to(self.device)
        # One batched encode call for all queries instead of one per query.
        query_embeddings = self.embedder.encode(queries, convert_to_tensor=True).to(self.device)
        # Row i holds the similarities of query i against every corpus entry.
        cos_scores = util.cos_sim(query_embeddings, corpus_embeddings)
        top_indices = torch.topk(cos_scores, k=top_k, dim=1).indices
        return [corpus[idx] for row in top_indices.tolist() for idx in row]