|
import torch
from sentence_transformers import SentenceTransformer, util
|
|
|
class Similarity:
    """Semantic similarity search over a text corpus via sentence embeddings.

    The SentenceTransformer model is loaded once and shared by all instances
    through the class-level ``embedder`` attribute (lazy singleton).
    """

    # Shared, lazily-initialized SentenceTransformer model (loaded on first
    # instantiation; subsequent instances reuse it, ignoring their model_path).
    embedder = None

    def __init__(self, model_path="all-MiniLM-L6-v2", device="cuda:0"):
        """Remember the target device and load the shared model if needed.

        Args:
            model_path: Model name or local path for SentenceTransformer.
                Only honored by the first instance ever constructed.
            device: Torch device string, e.g. ``"cuda:0"`` or ``"cpu"``.
        """
        self.device = device

        if Similarity.embedder is None:
            Similarity.embedder = SentenceTransformer(model_path).to(device)

    def infer(self, queries, corpus, top=1):
        """Return the ``top`` most similar corpus entries for each query.

        Args:
            queries: Iterable of query strings.
            corpus: List of candidate strings to search.
            top: Number of matches kept per query (clamped to ``len(corpus)``).

        Returns:
            A flat list of matched corpus strings: the first ``top_k`` entries
            belong to the first query, the next ``top_k`` to the second, etc.
        """
        # Guard: nothing to match against (also avoids encoding an empty batch).
        if not corpus:
            return []

        corpus_embeddings = Similarity.embedder.encode(
            corpus, convert_to_tensor=True
        ).to(self.device)

        top_k = min(top, len(corpus))
        results = []

        # Encode every query in one batched call rather than one model
        # invocation per query inside the loop.
        query_embeddings = Similarity.embedder.encode(
            list(queries), convert_to_tensor=True
        ).to(self.device)

        # cos_sim yields a (num_queries, len(corpus)) similarity matrix;
        # each row holds one query's scores against the whole corpus.
        cos_scores = util.cos_sim(query_embeddings, corpus_embeddings)
        for row in cos_scores:
            top_results = torch.topk(row, k=top_k)
            results.extend(corpus[idx] for idx in top_results.indices)

        return results