# File: test/configuration_tsp.py
# Uploaded to the Hugging Face Hub by RichardWang in commit 09cc84d ("add model").
from transformers import PretrainedConfig
class TSPConfig(PretrainedConfig):
    """Configuration container for TSP models.

    Stores the hyperparameters below as instance attributes and forwards
    ``pad_token_id`` plus any extra keyword arguments to
    ``PretrainedConfig.__init__``. How each value is consumed is defined by the
    modeling code referenced in ``auto_map`` (``modeling_tsp``), which is not
    part of this file. The per-argument meanings marked "presumably" are
    inferred from the names and should be confirmed against that module.

    Args:
        embedding_size: Stored as ``self.embedding_size``; presumably the width
            of the token embeddings (kept separate from ``hidden_size``).
        hidden_size: Stored as ``self.hidden_size``; must be an exact multiple
            of ``num_attention_heads``.
        num_hidden_layers: Stored as ``self.num_hidden_layers``; presumably the
            number of stacked layers.
        num_attention_heads: Stored as ``self.num_attention_heads``.
        intermediate_size: Stored as ``self.intermediate_size``; presumably the
            inner width of the feed-forward sublayer.
        dropout_prob: Stored as ``self.dropout_prob``.
        max_sequence_length: Stored as ``self.max_sequence_length``.
        position_embedding_type: Stored as ``self.position_embedding_type``;
            must be ``"absolute"`` or ``"rotary"``.
        pad_token_id: Forwarded to ``PretrainedConfig.__init__``.
        vocab_size: Stored as ``self.vocab_size``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads``, or if ``position_embedding_type`` is not
            one of the supported values.
    """

    model_type = "tsp"

    # Map the auto classes to the modeling code by hand instead of calling
    # `register_for_auto_class`, because that only registers the single model
    # class that was used to call `push_to_hub`.
    auto_map = {
        "AutoModel": "modeling_tsp.TSPModel",
        "AutoModelForPreTraining": "modeling_tsp.TSPModelForPreTraining",
        "AutoModelForTokenClassification": "modeling_tsp.TSPModelForTokenClassification",
        "AutoModelForSequenceClassification": "modeling_tsp.TSPModelForSequenceClassification",
        "AutoModelForQuestionAnswering": "modeling_tsp.TSPModelForQuestionAnswering",
    }

    def __init__(
        self,
        embedding_size: int = 128,
        hidden_size: int = 256,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 4,
        intermediate_size: int = 1024,
        dropout_prob: float = 0.1,
        max_sequence_length: int = 128,
        position_embedding_type: str = "absolute",
        pad_token_id: int = 0,
        vocab_size: int = 30522,
        **kwargs,
    ):
        # Validate with explicit exceptions rather than `assert`: asserts are
        # removed under `python -O`, which would let invalid configs through.
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be a multiple of "
                f"num_attention_heads ({num_attention_heads})"
            )
        supported_position_embeddings = ("absolute", "rotary")
        if position_embedding_type not in supported_position_embeddings:
            raise ValueError(
                f"position_embedding_type must be one of "
                f"{supported_position_embeddings}, got {position_embedding_type!r}"
            )

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.dropout_prob = dropout_prob
        self.max_sequence_length = max_sequence_length
        self.position_embedding_type = position_embedding_type
        super().__init__(pad_token_id=pad_token_id, **kwargs)