import torch import torch.nn.functional as F from sentence_transformers.util import cos_sim from transformers import AutoTokenizer, AutoModel class OutcomeSimilarity: """ similarity detector between outcomes statements""" ID2LABEL = ["different", "similar"] def __init__(self, model_path: str): self.tokenizer = AutoTokenizer.from_pretrained(model_path) self.model = AutoModel.from_pretrained(model_path) def _mean_pooling(self, model_output, attention_mask: torch.Tensor): """ Mean Pooling - Take attention mask into account for correct averaging""" # First element of model_output contains all token embeddings token_embeddings = model_output[0] input_mask_expanded = attention_mask.unsqueeze( -1).expand(token_embeddings.size()).float() return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9) def _encode(self, outcomes_lot: list[tuple[str,str]]): # Parse sentences sentences = [] if len(outcomes_lot) > 0: _, sentences = zip(*outcomes_lot) # Tokenize sentences encoded_input = self.tokenizer( sentences, padding=True, truncation=True, return_tensors='pt') # Compute token embeddings with torch.no_grad(): model_output = self.model(**encoded_input) # Perform pooling sentence_embeddings = self._mean_pooling( model_output, encoded_input['attention_mask']) # Normalize embeddings sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1) return sentence_embeddings def get_similarity( self, registry_outcomes:list[tuple[str,str]], article_outcomes:list[tuple[str,str]] ) -> list[tuple[int,int,float]]: """For each outcome in true_dict, find the most similar outcome in compared_dict and return a mapping of all matchs , for each tuple : registry is the first index (at i=0); article is the second index (at i=1) and the third index (i=3) is the cosine similarity score""" connections = set() rembs = self._encode(registry_outcomes) aembs = self._encode(article_outcomes) cosines_scores = cos_sim(rembs, aembs) lines_max = torch.argmax(cosines_scores, dim=1) col_max = torch.argmax(cosines_scores, dim=0) remaining_cols = set(range(len(col_max))) for i in range(len(lines_max)): connection = (i, lines_max[i].item(), cosines_scores[i, lines_max[i]].item()) remaining_cols.discard(lines_max[i].item()) connections.add(connection) for j in remaining_cols: connection = (col_max[j].item(), j, cosines_scores[col_max[j], j].item()) connections.add(connection) return connections