|
from typing import List, Optional |
|
|
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig |
|
import torch |
|
from haystack.nodes.base import BaseComponent |
|
from haystack.modeling.utils import initialize_device_settings |
|
from haystack.schema import Document |
|
|
|
|
|
class EntailmentChecker(BaseComponent): |
|
""" |
|
This node checks the entailment between every document content and the query. |
|
It enrichs the documents metadata with entailment informations. |
|
It also returns aggregate entailment information. |
|
""" |
|
|
|
outgoing_edges = 1 |
|
|
|
def __init__( |
|
self, |
|
model_name_or_path: str = "roberta-large-mnli", |
|
model_version: Optional[str] = None, |
|
tokenizer: Optional[str] = None, |
|
use_gpu: bool = True, |
|
batch_size: int = 16, |
|
entailment_contradiction_threshold: float = 0.5, |
|
): |
|
""" |
|
Load a Natural Language Inference model from Transformers. |
|
|
|
:param model_name_or_path: Directory of a saved model or the name of a public model. |
|
See https://huggingface.co/models for full list of available models. |
|
:param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash. |
|
:param tokenizer: Name of the tokenizer (usually the same as model) |
|
:param use_gpu: Whether to use GPU (if available). |
|
:param batch_size: Number of Documents to be processed at a time. |
|
:param entailment_contradiction_threshold: if in the first N documents there is a strong evidence of entailment/contradiction |
|
(aggregate entailment or contradiction are greater than the threshold), the less relevant documents are not taken into account |
|
""" |
|
super().__init__() |
|
|
|
self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False) |
|
|
|
tokenizer = tokenizer or model_name_or_path |
|
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) |
|
self.model = AutoModelForSequenceClassification.from_pretrained( |
|
pretrained_model_name_or_path=model_name_or_path, revision=model_version |
|
) |
|
self.batch_size = batch_size |
|
self.entailment_contradiction_threshold = entailment_contradiction_threshold |
|
self.model.to(str(self.devices[0])) |
|
|
|
id2label = AutoConfig.from_pretrained(model_name_or_path).id2label |
|
self.labels = [id2label[k].lower() for k in sorted(id2label)] |
|
if "entailment" not in self.labels: |
|
raise ValueError( |
|
"The model config must contain entailment value in the id2label dict." |
|
) |
|
|
|
def run(self, query: str, documents: List[Document]): |
|
|
|
scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0 |
|
for i, doc in enumerate(documents): |
|
entailment_info = self.get_entailment(premise=doc.content, hypotesis=query) |
|
doc.meta["entailment_info"] = entailment_info |
|
|
|
scores += doc.score |
|
con, neu, ent = ( |
|
entailment_info["contradiction"], |
|
entailment_info["neutral"], |
|
entailment_info["entailment"], |
|
) |
|
agg_con += con * doc.score |
|
agg_neu += neu * doc.score |
|
agg_ent += ent * doc.score |
|
|
|
|
|
|
|
if max(agg_con, agg_ent) / scores > self.entailment_contradiction_threshold: |
|
break |
|
|
|
aggregate_entailment_info = { |
|
"contradiction": round(agg_con / scores, 2), |
|
"neutral": round(agg_neu / scores, 2), |
|
"entailment": round(agg_ent / scores, 2), |
|
} |
|
|
|
entailment_checker_result = { |
|
"documents": documents[: i + 1], |
|
"aggregate_entailment_info": aggregate_entailment_info, |
|
} |
|
|
|
return entailment_checker_result, "output_1" |
|
|
|
def run_batch(self, queries: List[str], documents: List[Document]): |
|
pass |
|
|
|
def get_entailment(self, premise, hypotesis): |
|
with torch.inference_mode(): |
|
inputs = self.tokenizer( |
|
f"{premise}{self.tokenizer.sep_token}{hypotesis}", return_tensors="pt" |
|
).to(self.devices[0]) |
|
out = self.model(**inputs) |
|
logits = out.logits |
|
probs = ( |
|
torch.nn.functional.softmax(logits, dim=-1)[0, :].detach().cpu().numpy() |
|
) |
|
entailment_dict = {k.lower(): v for k, v in zip(self.labels, probs)} |
|
return entailment_dict |
|
|