from transformers import DebertaV2Config | |
class MultiHeadDebertaV2Config(DebertaV2Config):
    """``DebertaV2Config`` extended with a ``num_heads`` setting.

    All standard DeBERTa-v2 options are accepted through ``**kwargs`` and
    handled by ``DebertaV2Config``; this subclass only adds ``num_heads``.

    NOTE(review): ``num_heads`` is stored separately from the inherited
    ``num_attention_heads``. Judging by ``model_type`` it presumably counts
    output/classification heads of a companion model -- confirm against
    the model class that consumes this config.
    """

    # Identifier written to / matched against ``model_type`` in config.json.
    model_type = "multi-head-deberta-for-sequence-classification"

    def __init__(self, num_heads: int = 5, **kwargs) -> None:
        """Create the config.

        Args:
            num_heads: Value stored on ``self.num_heads`` (defaults to 5).
            **kwargs: Forwarded unchanged to ``DebertaV2Config.__init__``.
        """
        # The custom field is assigned before the parent constructor runs,
        # following the usual pattern for ``PretrainedConfig`` subclasses.
        self.num_heads = num_heads
        super().__init__(**kwargs)