|
import torch |
|
|
|
|
|
from transformers import PreTrainedModel, AutoConfig, T5ForTokenClassification, AutoModel, AutoTokenizer, AutoModelForTokenClassification |
|
|
|
from .configuration_hhem_v2 import HHEMv2Config |
|
|
|
class HHEMv2Model(PreTrainedModel):
    """Bare HHEM-v2 model wrapper.

    NOTE(review): this class only binds ``HHEMv2Config`` to the
    ``PreTrainedModel`` machinery and adds no layers or behavior of its
    own — presumably a placeholder kept for AutoModel registration;
    confirm whether it is still needed.
    """

    # Tells the transformers Auto* machinery which config class to use
    # when loading/saving this model.
    config_class = HHEMv2Config

    def __init__(self, config):
        """Initialize from an ``HHEMv2Config``; no extra state is added."""
        super().__init__(config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HHEMv2ForSequenceClassification(PreTrainedModel):
    """HHEM-v2 sequence-pair scorer built on a T5 token-classification backbone.

    ``predict`` formats each ``(text1, text2)`` pair with the config's prompt
    template and returns, per pair, the softmax probability of class index 1
    (presumably the "consistent / not hallucinated" class — TODO confirm
    against the training labels).
    """

    config_class = HHEMv2Config

    def __init__(self, config=None):
        """Build the T5 backbone and tokenizer named by ``config.foundation``.

        Args:
            config: an ``HHEMv2Config``. Defaults to a freshly constructed
                one. (The original signature used ``config=HHEMv2Config()``,
                which is evaluated once at class-definition time and shared
                by every no-arg instantiation — a mutable-default bug.)
        """
        if config is None:
            config = HHEMv2Config()
        super().__init__(config)
        # Backbone is constructed from config alone (random weights);
        # real weights arrive via `populate` or `from_pretrained`.
        self.t5 = T5ForTokenClassification(
            AutoConfig.from_pretrained(config.foundation)
        )
        # Prompt template with `{text1}` / `{text2}` slots, filled in `predict`.
        self.prompt = config.prompt
        self.tokenizer = AutoTokenizer.from_pretrained(config.foundation)
        # Backward-compatible alias: earlier revisions exposed this
        # misspelled attribute name, so keep it pointing at the same object.
        self.tokenzier = self.tokenizer

    def populate(self, model: AutoModel):
        """Initiate the model with the pretrained model

        This method should only be called by Vectara employee who prepares the model for publishing. Users do not need to call this method.
        """
        self.t5 = model

    def forward(self, **kwargs):
        """Run the backbone in inference mode.

        Keeps only the logits at the first token position of each sequence
        (that position acts as the classification head) and writes them back
        onto the backbone's output object before returning it.
        """
        self.t5.eval()
        with torch.no_grad():
            outputs = self.t5(**kwargs)
        logits = outputs.logits
        # Classification decision is read from the first token position only.
        logits = logits[:, 0, :]
        outputs.logits = logits
        return outputs

    def predict(self, text_pairs):
        """Score text pairs for consistency.

        Args:
            text_pairs: iterable of 2-tuples ``(text1, text2)``.

        Returns:
            1-D tensor (one entry per pair) of softmax probabilities for
            class index 1, on the backbone's device.
        """
        tokenizer = self.tokenzier
        pair_dict = [{'text1': pair[0], 'text2': pair[1]} for pair in text_pairs]
        inputs = tokenizer(
            [self.prompt.format(**pair) for pair in pair_dict], return_tensors='pt', padding=True).to(self.t5.device)
        self.t5.eval()
        with torch.no_grad():
            outputs = self.t5(**inputs)
        logits = outputs.logits
        # Same first-token readout as `forward`.
        logits = logits[:, 0, :]
        transformed_probs = torch.softmax(logits, dim=-1)
        raw_scores = transformed_probs[:, 1]
        return raw_scores
|
|
|
|
|
|