import torch

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from transformers.tools import PipelineTool


class TextPairClassificationTool(PipelineTool):
    """Tool that labels a pair of English texts as 'equivalent' or 'not_equivalent'.

    The pair is tokenized together in `encode`, and `decode` maps the
    highest-scoring logit to its label name through the model config's
    `id2label` table.
    """

    # Checkpoint and loader classes; presumably consumed by the PipelineTool
    # base class when it instantiates the tokenizer and model -- verify there.
    default_checkpoint = "sgugger/bert-finetuned-mrpc"
    pre_processor_class = AutoTokenizer
    model_class = AutoModelForSequenceClassification

    # Two text inputs (`text`, `second_text`) and one text output (the label).
    inputs = ["text", "text"]
    outputs = ["text"]

    description = (
        "This is a tool that classifies if two texts in English are similar or not using the labels 'equivalent' and "
        "'not_equivalent'. It takes two inputs named `text` and `second_text` which should be in English and returns "
        "the predicted label."
    )

    def encode(self, text, second_text):
        """Tokenize the two texts together as a single sentence-pair input.

        Args:
            text: The first text of the pair.
            second_text: The second text of the pair.

        Returns:
            Whatever `self.pre_processor` returns for the pair when called
            with `return_tensors="pt"` (PyTorch tensors requested).
        """
        return self.pre_processor(text, second_text, return_tensors="pt")

    def decode(self, outputs):
        """Turn model outputs into the predicted label string.

        Args:
            outputs: Model output object exposing a `logits` attribute.

        Returns:
            The entry of `self.model.config.id2label` for the arg-max logit.
        """
        logits = outputs.logits
        # Only the first row is inspected -- assumes a batch of one pair
        # (TODO confirm against the caller).
        label_id = torch.argmax(logits[0]).item()
        return self.model.config.id2label[label_id]