import torch.nn as nn

from transformers import ElectraPreTrainedModel, ElectraModel, AutoTokenizer


class ElectraReranker(ElectraPreTrainedModel):
    """
    Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level.

    This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly.
    """
    _keys_to_ignore_on_load_unexpected = [r"cls"]

    def __init__(self, config):
        super().__init__(config)

        self.electra = ElectraModel(config)
        self.linear = nn.Linear(config.hidden_size, 1)
        self.raw_tokenizer = AutoTokenizer.from_pretrained('google/electra-large-discriminator')

        self.init_weights()

    def forward(self, encoding):
        outputs = self.electra(encoding.input_ids,
                               attention_mask=encoding.attention_mask,
                               token_type_ids=encoding.token_type_ids)[0]

        # The relevance score is a linear projection of the [CLS] token's last hidden state.
        scores = self.linear(outputs[:, 0]).squeeze(-1)

        return scores

    def save(self, path):
        assert not path.endswith('.dnn'), f"{path}: We reserve *.dnn names for the deprecated checkpoint format."

        self.save_pretrained(path)
        self.raw_tokenizer.save_pretrained(path)
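

# --- Usage sketch (illustrative, not part of the original module) ---
# Assumes a fine-tuned reranker checkpoint directory previously written by `save()`;
# the checkpoint path, query, and passages below are hypothetical placeholders.
if __name__ == '__main__':
    import torch

    checkpoint = 'path/to/electra-reranker'  # hypothetical directory produced by save()

    model = ElectraReranker.from_pretrained(checkpoint).eval()
    tokenizer = model.raw_tokenizer

    query = 'what does a shallow wrapper do?'
    passages = [
        'A shallow wrapper adds a thin layer over an existing model.',
        'ELECTRA is a pretrained discriminator-style language model.',
    ]

    # Encode each (query, passage) pair as a sentence pair; padding/truncation keep
    # the batch rectangular and within the model's maximum input length.
    encoding = tokenizer([query] * len(passages), passages,
                         padding=True, truncation=True, return_tensors='pt')

    with torch.no_grad():
        scores = model(encoding)  # one relevance score per (query, passage) pair

    for score, passage in sorted(zip(scores.tolist(), passages), reverse=True):
        print(f'{score:.4f}\t{passage}')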