import torch
import numpy as np

from . import utils
from torch import nn

class CustomModule(nn.Module):
    """A simple two-layer type I MLP block.

    Note that the ``w2_bias`` argument is assigned as the bias of the first
    linear layer, and the second layer's bias is zeroed.
    """

    def __init__(self, w1_weight=None, w2_bias=None, w2_weight=None, act='gelu'):
        super().__init__()
        # Layer shapes are derived from the supplied first-layer weight matrix.
        self.linear1 = nn.Linear(w1_weight.shape[1], w1_weight.shape[0])
        self.linear2 = nn.Linear(w1_weight.shape[0], w1_weight.shape[1])
        self.act = utils.load_activation(act)
        self.linear1.weight = nn.Parameter(w1_weight.float())
        self.linear1.bias = nn.Parameter(w2_bias.float())
        self.linear2.weight = nn.Parameter(w2_weight.T.float())
        self.linear2.bias = nn.Parameter(torch.zeros_like(self.linear2.bias))

    def forward(self, x):
        # Two-layer MLP: project up, apply the activation, project back down.
        return self.linear2(self.act(self.linear1(x)))
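

# Usage sketch (not from the original repo): build a CustomModule from stand-in
# tensors. The dimensions and random weights below are illustrative assumptions;
# in practice the weights would come from a trained transformer MLP.
def _example_custom_module():
    d_model, d_hidden = 16, 64
    w1 = torch.randn(d_hidden, d_model)   # first projection, shape (hidden, in)
    b1 = torch.randn(d_hidden)            # bias applied after the first projection
    w2 = torch.randn(d_hidden, d_model)   # second projection, transposed inside the module
    module = CustomModule(w1_weight=w1, w2_bias=b1, w2_weight=w2, act='gelu')
    x = torch.randn(2, 3, d_model)
    return module(x)                      # -> shape (2, 3, d_model)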

class CustomNormModule(nn.Module):
    """A two-layer type I MLP block applied to a normalised, centred input.

    The input is first re-standardised with the supplied LayerNorm statistics,
    centred on ``centroid`` and (optionally) L2-normalised before the MLP runs.
    """

    def __init__(self,
                 w1_weight=None,
                 w1_bias=None,
                 w2_weight=None,
                 centroid=None,
                 norm_weight=None,
                 norm_bias=None,
                 add_norm=True,
                 return_w1=False,
                 act='relu'):
        super().__init__()
        self.linear1 = nn.Linear(w1_weight.shape[1], w1_weight.shape[0])
        self.linear2 = nn.Linear(w1_weight.shape[0], w1_weight.shape[1])
        self.act = utils.load_activation(act)
        self.centroid = centroid
        self.norm_weight = norm_weight
        self.norm_bias = norm_bias
        if self.norm_bias is None:
            self.norm_bias = 0
        self.add_norm = add_norm
        self.return_w1 = return_w1
        self.linear1.weight = nn.Parameter(w1_weight)
        if w1_bias is not None:
            self.linear1.bias = nn.Parameter(w1_bias)
        self.linear2.weight = nn.Parameter(w2_weight.T)
        # Zero the second bias, matching the dtype and device of the supplied weights.
        self.linear2.bias = nn.Parameter(
            torch.zeros_like(self.linear2.bias).to(w1_weight.dtype).to(w1_weight.device))

    def forward(self, x):
        # Normalisation (part I): undo the LayerNorm affine transform and rescale.
        x = (x - self.norm_bias) / self.norm_weight / np.sqrt(self.centroid.shape[0])
        # Normalisation (part II): centre on the stored centroid and optionally L2-normalise.
        x = x - self.centroid
        if self.add_norm:
            x = x / torch.norm(x, dim=-1, keepdim=True)
        w1_output = self.act(self.linear1(x))
        if self.return_w1:
            return w1_output
        w2_output = self.linear2(w1_output)
        return w2_output
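

# Usage sketch (not from the original repo): construct a CustomNormModule from
# stand-in LayerNorm statistics and weights. All shapes and values here are
# illustrative assumptions; `centroid`, `norm_weight` and `norm_bias` would
# normally be extracted from the model being patched.
def _example_custom_norm_module():
    d_model, d_hidden = 16, 64
    module = CustomNormModule(
        w1_weight=torch.randn(d_hidden, d_model),
        w2_weight=torch.randn(d_hidden, d_model),
        centroid=torch.randn(d_model),
        norm_weight=torch.ones(d_model),
        norm_bias=torch.zeros(d_model),
        add_norm=True,
        act='relu',
    )
    x = torch.randn(2, 3, d_model)        # runs on CPU here; move to GPU as needed
    return module(x)                      # -> shape (2, 3, d_model)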

class ModifiedMLP(nn.Module):
    """Modified MLP block: the original MLP output plus a residual CustomModule patch."""

    def __init__(self, original_mlp, custom_module):
        super().__init__()
        self.original_mlp = original_mlp
        self.custom_module = custom_module

    def forward(self, x):
        # Output of the original MLP ...
        o = self.original_mlp(x)
        # ... plus the output of the custom module on the same input.
        return o + self.custom_module(x)
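

# Usage sketch (not from the original repo): wrap a plain MLP so its output is
# augmented by a CustomModule patch. The stand-in MLP and dimensions below are
# illustrative assumptions.
def _example_modified_mlp():
    d_model, d_hidden = 16, 64
    original = nn.Sequential(
        nn.Linear(d_model, d_hidden),
        nn.GELU(),
        nn.Linear(d_hidden, d_model),
    )
    patch = CustomModule(
        w1_weight=torch.randn(d_hidden, d_model),
        w2_bias=torch.randn(d_hidden),
        w2_weight=torch.randn(d_hidden, d_model),
        act='gelu',
    )
    wrapped = ModifiedMLP(original, patch)
    x = torch.randn(2, 3, d_model)
    return wrapped(x)                     # original(x) + patch(x)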

class ModifieMambadMLP(nn.Module):
    """Modified Mamba-style block: as ModifiedMLP, but the wrapped module also takes cache_params."""

    def __init__(self, original_mlp, custom_module):
        super().__init__()
        self.original_mlp = original_mlp
        self.custom_module = custom_module

    def forward(self, x, cache_params=None):
        # Output of the original block, forwarding the cache ...
        o = self.original_mlp(x, cache_params=cache_params)
        # ... plus the output of the custom module on the same input.
        return o + self.custom_module(x)
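

# Usage sketch (not from the original repo): ModifieMambadMLP expects a block whose
# forward accepts `cache_params` (e.g. a Mamba mixer). A trivial stand-in block is
# used here instead of a real Mamba layer; it is purely an illustrative assumption.
class _StandInMambaBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, cache_params=None):
        # A real Mamba mixer would also read/update cache_params here.
        return self.proj(x)


def _example_modified_mamba_mlp():
    d_model, d_hidden = 16, 64
    patch = CustomModule(
        w1_weight=torch.randn(d_hidden, d_model),
        w2_bias=torch.randn(d_hidden),
        w2_weight=torch.randn(d_hidden, d_model),
    )
    wrapped = ModifieMambadMLP(_StandInMambaBlock(d_model), patch)
    x = torch.randn(2, 3, d_model)
    return wrapped(x, cache_params=None)  # original(x, cache) + patch(x)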