File size: 1,396 Bytes
046d995
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import torch
import torch.nn as nn
from transformers import BertForMaskedLM

from RBFLayer import RBFLayer  # Assuming RBFLayer is your custom RBF implementation

class CustomBertForMaskedLM(BertForMaskedLM):
    """BertForMaskedLM variant whose encoder feed-forward (MLP) sublayers are
    replaced by RBF layers.

    For every transformer layer, both dense projections of the feed-forward
    block — ``intermediate.dense`` (hidden -> intermediate) and
    ``output.dense`` (intermediate -> hidden) — are swapped for ``RBFLayer``
    instances using a Gaussian radial function and a Euclidean norm.

    The replacement layers are created *after* ``super().__init__`` has built
    (and possibly loaded) the standard BERT modules, so any pretrained
    weights of the original dense layers are discarded.
    """

    def __init__(self, config):
        super().__init__(config)

        # Read the dimensions from the config instead of hard-coding
        # 768/3072, so this also works for bert-large etc.  For
        # bert-base configs these are exactly 768 and 3072.
        hidden_dim = config.hidden_size
        intermediate_dim = config.intermediate_size

        for layer in self.bert.encoder.layer:
            # Replace the intermediate dense layer (hidden -> intermediate) with RBF.
            layer.intermediate.dense = RBFLayer(
                in_features_dim=hidden_dim,
                num_kernels=2,  # Number of kernels in the RBF layer
                out_features_dim=intermediate_dim,
                radial_function=gaussian_rbf,
                norm_function=euclidean_norm,
            )

            # Replace the output dense layer (intermediate -> hidden) with RBF.
            layer.output.dense = RBFLayer(
                in_features_dim=intermediate_dim,
                num_kernels=2,
                out_features_dim=hidden_dim,
                radial_function=gaussian_rbf,
                norm_function=euclidean_norm,
            )

# Radial basis function used by the RBF layers above.
def gaussian_rbf(x):
    """Gaussian radial basis: exp(-x^2), applied element-wise."""
    return (-x.pow(2)).exp()

def euclidean_norm(x):
    """L2 (Euclidean) norm of ``x`` along its last dimension."""
    return x.norm(p=2, dim=-1)