Spaces:
Runtime error
Runtime error
### Imports | |
from sentence_transformers import SentenceTransformer, util | |
### Classes and functions | |
##========================================================================================================== | |
class SentTransfUtilities:
    """Convenience wrapper around a sentence-transformers model.

    Loads a ``SentenceTransformer`` by name and exposes helpers to embed text
    and to score embeddings (cosine similarity, dot product, semantic search),
    all delegated to ``sentence_transformers.util``.
    """

    # Model used by an instance. The class-level value stays None; __init__
    # assigns the loaded model to the instance attribute of the same name.
    model = None
    # Name of the model this instance was created with (name-mangled to
    # _SentTransfUtilities__model_name).
    __model_name = None
    # Models already loaded in this process, keyed by model name. Shared by all
    # instances on purpose so that creating several wrappers for the same model
    # reuses the loaded weights instead of reloading them.
    _model_cache = {}

    ##==========================================================================================================
    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Load (or reuse an already loaded) sentence-transformers model.

        Args:
            model_name: Name passed to ``SentenceTransformer``. Options used
                with this class:
                - 'all-MiniLM-L6-v2'
                - 'nq-distilbert-base-v1'
                - 'paraphrase-multilingual-MiniLM-L12-v2'
        """
        self.__model_name = model_name
        # Look the model up by name in the shared cache: the model itself lives
        # on the instance, so a per-instance "is it loaded yet?" check could
        # never see a previously loaded model.
        loaded_model = self._model_cache.get(model_name)
        if loaded_model is None:
            print("Initializing the Sentence Transformer model")
            loaded_model = SentenceTransformer(model_name)
            self._model_cache[model_name] = loaded_model
        self.model = loaded_model

    ##==========================================================================================================
    def get_embeddings(self, src_data, device="cpu"):
        """Encode ``src_data`` with the loaded model.

        Args:
            src_data: Text (or list of texts) forwarded to ``model.encode``.
            device: Device used for encoding; defaults to 'cpu', the value
                previously hard-coded here.

        Returns:
            The result of ``model.encode(..., convert_to_tensor=True)``, i.e.
            the embeddings as a tensor.
        """
        return self.model.encode(src_data, convert_to_tensor=True, device=device)

    ##==========================================================================================================
    def compute_cosine_similarity(self, query_embeddings, passage_embeddings):
        """Return the cosine similarities between two sets of embeddings.

        Delegates to ``util.cos_sim`` (per the sentence-transformers docs, a
        matrix of pairwise scores between query and passage embeddings).
        """
        return util.cos_sim(query_embeddings, passage_embeddings)

    ##==========================================================================================================
    def compute_dot_similarity(self, query_embeddings, passage_embeddings):
        """Return the dot-product scores between two sets of embeddings.

        Args:
            query_embeddings: Embeddings of the queries.
            passage_embeddings: Embeddings of the passages.

        Returns:
            The result of ``util.dot_score(query_embeddings, passage_embeddings)``.
        """
        return util.dot_score(query_embeddings, passage_embeddings)

    ##==========================================================================================================
    def compute_semantic_search(self, query_embeddings, corpus_embeddings, top_k=10):
        """Rank corpus entries for each query with ``util.semantic_search``.

        Args:
            query_embeddings: Embeddings of the queries.
            corpus_embeddings: Embeddings of the corpus to search.
            top_k: Number of hits returned per query. Defaults to 10, the same
                default ``util.semantic_search`` uses, so existing callers get
                unchanged results.

        Returns:
            The hit lists produced by ``util.semantic_search`` (per the
            sentence-transformers docs: one list per query of dicts with
            'corpus_id' and 'score').
        """
        return util.semantic_search(query_embeddings, corpus_embeddings, top_k=top_k)

    ##==========================================================================================================
    def compute_sentences_similarity(self, sentence_1, sentence_2, sim_func="cosine"):
        """Embed two sentences (or lists of sentences) and score them.

        Args:
            sentence_1: First text (or list of texts) to embed.
            sentence_2: Second text (or list of texts) to embed.
            sim_func: Scoring function, one of {"cosine", "dot"}.

        Returns:
            The scores from ``compute_cosine_similarity`` or
            ``compute_dot_similarity``.

        Raises:
            ValueError: If ``sim_func`` is not a supported name.
        """
        scorers = {
            "cosine": self.compute_cosine_similarity,
            "dot": self.compute_dot_similarity,
        }
        # Validate before encoding so a typo fails fast instead of silently
        # returning None after two (expensive) encode passes.
        if sim_func not in scorers:
            raise ValueError(
                f"Unsupported sim_func {sim_func!r}; expected one of {sorted(scorers)}"
            )
        embeddings_1 = self.get_embeddings(sentence_1)
        embeddings_2 = self.get_embeddings(sentence_2)
        return scorers[sim_func](embeddings_1, embeddings_2)

##==========================================================================================================
##==========================================================================================================