Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
from transformers import BertPreTrainedModel, BertModel, AutoTokenizer | |
from colbert.utils.utils import torch_load_dnn | |
class HF_ColBERT(BertPreTrainedModel):
    """
    Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level.

    This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly.

    Attributes:
        dim: Output dimension of the projection layer (taken from `colbert_config.dim`).
        bert: The underlying `BertModel` encoder built from `config`.
        linear: Bias-free projection from `config.hidden_size` to `dim`.
        base: Set by `from_pretrained`; the model name or path the weights/config were resolved from.
    """

    # Checkpoint keys matching these patterns (e.g. BERT's pre-training `cls.*` heads) are ignored by the
    # HuggingFace loader instead of being reported as unexpected.
    _keys_to_ignore_on_load_unexpected = [r"cls"]

    def __init__(self, config, colbert_config):
        """
        Build the BERT encoder and the ColBERT projection layer.

        Args:
            config: HuggingFace BERT configuration; `config.hidden_size` sets the projection's input width.
            colbert_config: ColBERT configuration object; only `colbert_config.dim` is read here.
        """
        super().__init__(config)

        self.dim = colbert_config.dim
        self.bert = BertModel(config)
        self.linear = nn.Linear(config.hidden_size, colbert_config.dim, bias=False)

        # if colbert_config.relu:
        #     self.score_scaler = nn.Linear(1, 1)

        # Must run after every new parameter is defined so that they are all initialized.
        self.init_weights()

        # if colbert_config.relu:
        #     self.score_scaler.weight.data.fill_(1.0)
        #     self.score_scaler.bias.data.fill_(-8.0)

    @classmethod
    def from_pretrained(cls, name_or_path, colbert_config):
        """
        Load a model either from a legacy `.dnn` checkpoint or from a HuggingFace name/path.

        For a `.dnn` checkpoint, the base architecture is read from the checkpoint's
        `arguments['model']` entry (defaulting to 'bert-base-uncased') and the weights come from its
        `model_state_dict` entry.

        Args:
            name_or_path: Path ending in `.dnn`, or anything accepted by the parent `from_pretrained`.
            colbert_config: ColBERT configuration, passed through to the parent loader as a keyword
                argument (which is expected to forward it to `__init__`).

        Returns:
            The loaded model, with `.base` set to the resolved base model name or path.
        """
        if name_or_path.endswith('.dnn'):
            dnn = torch_load_dnn(name_or_path)
            base = dnn.get('arguments', {}).get('model', 'bert-base-uncased')

            obj = super().from_pretrained(base, state_dict=dnn['model_state_dict'], colbert_config=colbert_config)
            obj.base = base

            return obj

        obj = super().from_pretrained(name_or_path, colbert_config=colbert_config)
        obj.base = name_or_path

        return obj

    @staticmethod
    def raw_tokenizer_from_pretrained(name_or_path):
        """
        Load the tokenizer matching a model given as a `.dnn` checkpoint or a HuggingFace name/path.

        Args:
            name_or_path: Path ending in `.dnn`, or anything accepted by `AutoTokenizer.from_pretrained`.

        Returns:
            The tokenizer, with `.base` set to the resolved base model name or path.
        """
        if name_or_path.endswith('.dnn'):
            dnn = torch_load_dnn(name_or_path)
            base = dnn.get('arguments', {}).get('model', 'bert-base-uncased')

            obj = AutoTokenizer.from_pretrained(base)
            obj.base = base

            return obj

        obj = AutoTokenizer.from_pretrained(name_or_path)
        obj.base = name_or_path

        return obj
""" | |
TODO: It's easy to write a class generator that takes "name_or_path" and loads AutoConfig to check the Architecture's | |
name, finds that name's *PreTrainedModel and *Model in dir(transformers), and then basically repeats the above. | |
It's easy for the BaseColBERT class to instantiate things from there. | |
""" | |