File size: 1,378 Bytes
aff5ec5 be023c1 aff5ec5 be023c1 aff5ec5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
from transformers import PretrainedConfig, BertConfig
from typing import List
class VGCNConfig(BertConfig):
    """Configuration for the VGCN model, extending ``BertConfig``.

    The VGCN-specific arguments are validated and stored as attributes. All
    remaining keyword arguments are forwarded unchanged to
    ``BertConfig.__init__``.

    Args:
        bert_model: Identifier of the underlying BERT checkpoint. The default
            looks like a Hugging Face Hub model id (assumption — confirm with
            the model code that consumes it).
        gcn_adj_matrix: Stored verbatim. Presumably a path to, or identifier
            of, the GCN adjacency matrix (TODO confirm). Defaults to ``''``.
        max_seq_len: Maximum sequence length; must lie in ``[1, 512]``.
        npmi_threshold: Threshold value; must lie in ``[0.0, 1.0]``.
        tf_threshold: Threshold value; must lie in ``[0.0, 1.0]``.
        vocab_type: One of ``'all'``, ``'pmi'`` or ``'tf'``.
        gcn_embedding_dim: GCN embedding size; must be non-negative.
        **kwargs: Forwarded to ``BertConfig.__init__``.

    Raises:
        ValueError: If ``vocab_type`` is not an allowed value, or if a numeric
            argument is outside its allowed range (NaN included).
    """

    model_type = "vgcn"

    def __init__(
        self,
        bert_model: str = 'readerbench/RoBERT-base',
        gcn_adj_matrix: str = '',
        max_seq_len: int = 256,
        npmi_threshold: float = 0.2,
        tf_threshold: float = 0.0,
        vocab_type: str = "all",
        gcn_embedding_dim: int = 32,
        **kwargs,
    ):
        if vocab_type not in ("all", "pmi", "tf"):
            raise ValueError(f"`vocab_type` must be 'all', 'pmi' or 'tf', got {vocab_type}.")
        # The range checks are written as "not (lo <= x <= hi)" rather than
        # "x < lo or x > hi". Every comparison with NaN is False, so the latter
        # form would silently accept NaN.
        if not (1 <= max_seq_len <= 512):
            raise ValueError(f"`max_seq_len` must be between 1 and 512, got {max_seq_len}.")
        if not (0.0 <= npmi_threshold <= 1.0):
            raise ValueError(f"`npmi_threshold` must be between 0.0 and 1.0, got {npmi_threshold}.")
        if not (0.0 <= tf_threshold <= 1.0):
            raise ValueError(f"`tf_threshold` must be between 0.0 and 1.0, got {tf_threshold}.")
        # A negative embedding size is never meaningful. Zero is still
        # accepted, since it may be used deliberately.
        if not (gcn_embedding_dim >= 0):
            raise ValueError(f"`gcn_embedding_dim` must be non-negative, got {gcn_embedding_dim}.")

        self.gcn_adj_matrix = gcn_adj_matrix
        self.max_seq_len = max_seq_len
        self.npmi_threshold = npmi_threshold
        self.tf_threshold = tf_threshold
        self.vocab_type = vocab_type
        self.gcn_embedding_dim = gcn_embedding_dim
        self.bert_model = bert_model
        # Custom attributes are set first and the base initializer is called
        # last, as in the original. Remaining kwargs go to BertConfig.
        super().__init__(**kwargs)