|
import torch
import torch.nn as nn

from transformers import BertForMaskedLM

from RBFLayer import RBFLayer
|
|
|
class CustomBertForMaskedLM(BertForMaskedLM):
    """BERT masked-LM model whose feed-forward sublayers are replaced by RBF layers.

    For every encoder layer, both feed-forward projections
    (``intermediate.dense``: hidden -> intermediate, and
    ``output.dense``: intermediate -> hidden) are swapped for ``RBFLayer``
    instances using a Gaussian radial function and a Euclidean norm.
    """

    def __init__(self, config):
        """Build the model from a BertConfig, then patch the encoder layers.

        Args:
            config: a ``BertConfig``; ``hidden_size`` and ``intermediate_size``
                determine the RBF layer dimensions.
        """
        super().__init__(config)

        # Read the dimensions from the config instead of hard-coding
        # 768 / 3072, so any BERT size (base, large, custom) works.
        # The previous constants were exactly bert-base's defaults,
        # so this is backward-compatible.
        hidden_dim = config.hidden_size
        intermediate_dim = config.intermediate_size

        for layer in self.bert.encoder.layer:
            # Up-projection: hidden -> intermediate.
            layer.intermediate.dense = RBFLayer(
                in_features_dim=hidden_dim,
                num_kernels=2,
                out_features_dim=intermediate_dim,
                radial_function=gaussian_rbf,
                norm_function=euclidean_norm,
            )

            # Down-projection: intermediate -> hidden.
            layer.output.dense = RBFLayer(
                in_features_dim=intermediate_dim,
                num_kernels=2,
                out_features_dim=hidden_dim,
                radial_function=gaussian_rbf,
                norm_function=euclidean_norm,
            )
|
|
|
|
|
def gaussian_rbf(x):
    """Gaussian radial basis function, applied element-wise: exp(-x^2)."""
    squared = x * x
    return (-squared).exp()
|
|
|
def euclidean_norm(x):
    """L2 (Euclidean) norm of ``x`` taken along its last dimension."""
    return x.norm(p=2, dim=-1)
|
|