File size: 4,610 Bytes
6923ebd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from typing import List, Optional

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
import torch
from haystack.nodes.base import BaseComponent
from haystack.modeling.utils import initialize_device_settings
from haystack.schema import Document


class EntailmentChecker(BaseComponent):
    """
    This node checks the entailment between every document content and the query.
    It enriches each document's metadata with entailment information
    (stored under ``doc.meta["entailment_info"]``).
    It also returns aggregate, score-weighted entailment information.
    """

    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 16,
        entailment_contradiction_threshold: float = 0.5,
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
        See https://huggingface.co/models for full list of available models.
        :param model_version: The version of model to use from the HuggingFace model hub. Can be tag name, branch name, or commit hash.
        :param tokenizer: Name of the tokenizer (usually the same as model)
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Number of Documents to be processed at a time.
        :param entailment_contradiction_threshold: if in the first N documents there is a strong evidence of entailment/contradiction
        (aggregate entailment or contradiction are greater than the threshold), the less relevant documents are not taken into account
        :raises ValueError: if the model's ``id2label`` mapping has no ``entailment`` label.
        """
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)

        tokenizer_name = tokenizer or model_name_or_path
        # A tokenizer loaded from the model repo itself should match the requested model revision.
        tokenizer_revision = model_version if tokenizer_name == model_name_or_path else None
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        # NOTE: batch_size is stored but not used yet; documents are scored one at a time in run().
        self.batch_size = batch_size
        self.entailment_contradiction_threshold = entailment_contradiction_threshold
        self.model.to(str(self.devices[0]))

        # Read the labels from the already-loaded model: this honours `model_version`
        # and avoids fetching the config a second time.
        id2label = self.model.config.id2label
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        if "entailment" not in self.labels:
            raise ValueError(
                "The model config must contain entailment value in the id2label dict."
            )

    def run(self, query: str, documents: List[Document]):
        """
        Check whether each document entails or contradicts `query` and aggregate the results.

        Documents are processed in the given order (presumably relevance order -- confirm with
        the upstream retriever). Each processed document gets ``doc.meta["entailment_info"]``
        set to its per-label probabilities. The aggregates are averages weighted by ``doc.score``.
        Processing stops early once the weighted contradiction or entailment exceeds
        ``entailment_contradiction_threshold``.

        :param query: The hypothesis to check.
        :param documents: The premises; their ``meta`` is updated in place.
        :return: ``(result, "output_1")``. ``result["documents"]`` holds the documents actually
            considered. ``result["aggregate_entailment_info"]`` maps "contradiction", "neutral"
            and "entailment" to weighted averages rounded to 2 decimals (all 0.0 when there is
            nothing to aggregate).
        """
        total_weight, agg_con, agg_neu, agg_ent = 0.0, 0.0, 0.0, 0.0
        num_considered = 0
        for doc in documents:
            entailment_info = self.get_entailment(premise=doc.content, hypotesis=query)
            doc.meta["entailment_info"] = entailment_info
            num_considered += 1

            # Documents without a relevance score are weighted uniformly instead of crashing.
            weight = 1.0 if doc.score is None else doc.score
            total_weight += weight
            # Labels the model does not predict (e.g. no "neutral" class) contribute nothing.
            agg_con += entailment_info.get("contradiction", 0.0) * weight
            agg_neu += entailment_info.get("neutral", 0.0) * weight
            agg_ent += entailment_info.get("entailment", 0.0) * weight

            # If the first documents already give strong evidence of entailment/contradiction,
            # there is no need to consider less relevant documents.
            if (
                total_weight != 0
                and max(agg_con, agg_ent) / total_weight > self.entailment_contradiction_threshold
            ):
                break

        def _weighted_average(value: float) -> float:
            # Guards against empty input or all-zero scores (0 / 0).
            return round(value / total_weight, 2) if total_weight != 0 else 0.0

        aggregate_entailment_info = {
            "contradiction": _weighted_average(agg_con),
            "neutral": _weighted_average(agg_neu),
            "entailment": _weighted_average(agg_ent),
        }

        entailment_checker_result = {
            "documents": documents[:num_considered],
            "aggregate_entailment_info": aggregate_entailment_info,
        }

        return entailment_checker_result, "output_1"

    def run_batch(self, queries: List[str], documents: List[Document]):
        """
        Batch mode is not implemented.

        :raises NotImplementedError: always. This replaces the previous silent ``None`` return,
            which broke callers that unpack the result.
        """
        raise NotImplementedError(
            "EntailmentChecker does not support run_batch yet; call run() for each query instead."
        )

    def get_entailment(self, premise, hypotesis):
        """
        Return the model's label probabilities for `premise` entailing `hypotesis`.

        :param premise: Text that may support or contradict the hypothesis (a document's content).
        :param hypotesis: Text to check against the premise (the query).
        :return: Dict mapping each lowercased model label to its probability as a Python float.
        """
        with torch.inference_mode():
            # Encode as a text pair so the tokenizer inserts the model's own separator tokens
            # (and token_type_ids where the model uses them). Truncate so that long documents
            # do not exceed the model's maximum input length.
            inputs = self.tokenizer(
                premise, hypotesis, truncation=True, return_tensors="pt"
            ).to(self.devices[0])
            logits = self.model(**inputs).logits
            # .tolist() yields plain Python floats, which are JSON-serializable.
            probs = torch.nn.functional.softmax(logits, dim=-1)[0].cpu().tolist()
        # self.labels is already lowercased in __init__.
        return dict(zip(self.labels, probs))