File size: 3,932 Bytes
ad5b231 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import os
# Environment setup. These statements sit before `import bitblas` and
# `import torch`, so both libraries see the adjusted environment at import time.
# Put the CUDA toolkit binaries first on PATH (presumably so the CUDA compiler
# is found when kernels are built -- TODO confirm). PATH is read with a
# fallback so a missing variable no longer raises KeyError, and no trailing
# ":" is produced when PATH is empty (an empty PATH entry means "current
# directory").
_cuda_bin_dir = "/usr/local/cuda/bin"
_inherited_path = os.environ.get("PATH", "")
os.environ["PATH"] = (
    _cuda_bin_dir + ":" + _inherited_path if _inherited_path else _cuda_bin_dir
)
# Enumerate CUDA devices in PCI bus order.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import bitblas
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, PreTrainedModel, PretrainedConfig,AutoModel,AutoConfig,BertPreTrainedModel
class bitlinear(bitblas.Linear):
    """``bitblas.Linear`` with an extra output scale and an optional bias.

    The matmul itself is delegated to ``bitblas.Linear``. This subclass then
    multiplies the result by ``alpha``, adds ``b`` when one is registered,
    and returns the result converted to float32.

    Args:
        in_features, out_features, bias, A_dtype, W_dtype, accum_dtype,
        out_dtype, group_size, with_scaling, with_zeros, zeros_mode,
        fast_decoding: Forwarded unchanged to ``bitblas.Linear.__init__``.
        opt_M: Forwarded to ``bitblas.Linear.__init__``. ``None`` selects
            ``[1, 16, 32, 64, 128, 256, 512]``; a fresh list is built for
            every instance so no default list is shared between layers.
        alpha: Scale multiplied into the matmul output. ``None`` means a
            scalar ``1.0`` in float16. A ``torch.dtype`` is also accepted
            (the previous default was ``torch.float16``, which could not be
            wrapped in a parameter) and yields a scalar ``1.0`` of that dtype.
        b: Bias added to the scaled output, or ``None`` for no bias.
    """

    # Immutable template for the default ``opt_M`` value.
    _DEFAULT_OPT_M = (1, 16, 32, 64, 128, 256, 512)

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        A_dtype: str = "float16",
        W_dtype: str = "int2",
        accum_dtype: str = "float16",
        out_dtype: str = "float16",
        group_size: int = -1,
        with_scaling: bool = False,
        with_zeros: bool = False,
        zeros_mode: str = None,
        opt_M: list = None,
        fast_decoding: bool = True,
        alpha: torch.Tensor = None,
        b: torch.Tensor = None,
    ):
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            A_dtype=A_dtype,
            W_dtype=W_dtype,
            accum_dtype=accum_dtype,
            out_dtype=out_dtype,
            group_size=group_size,
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            zeros_mode=zeros_mode,
            opt_M=list(self._DEFAULT_OPT_M if opt_M is None else opt_M),
            fast_decoding=fast_decoding,
        )

        if alpha is None:
            alpha = torch.tensor(1.0, dtype=torch.float16)
        elif isinstance(alpha, torch.dtype):
            # Backward compatibility with the old dtype-valued default.
            alpha = torch.tensor(1.0, dtype=alpha)
        self.alpha = nn.Parameter(alpha, requires_grad=False)

        if b is None:
            # nn.Parameter(None) would silently create an *empty* tensor, which
            # passes the ``is not None`` check in forward() and then fails to
            # broadcast. Register an explicit "no bias" slot instead.
            self.register_parameter("b", None)
        else:
            self.b = nn.Parameter(b, requires_grad=False)

    def forward(self, A: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        """Run the BitBLAS matmul, then scale, add the bias and cast to float32.

        Args:
            A: Input activations, passed straight to ``bitblas.Linear.forward``.
            out: Optional output tensor, passed straight to
                ``bitblas.Linear.forward``.

        Returns:
            The scaled (and biased, if ``b`` is set) result as a float32 tensor.
        """
        out = super().forward(A, out)
        # Scale and bias are applied in place on the tensor returned by the
        # parent; if the parent hands back the caller's ``out`` buffer, that
        # buffer is modified too.
        out *= self.alpha
        if self.b is not None:
            # Broadcasts the bias over every leading dimension of ``out``.
            out += self.b
        return out.to(torch.float32)
class TernaryBertConfig(BertConfig):
    """``BertConfig`` subclass identified by the model type ``"ternarybert"``.

    It adds no fields of its own; every keyword argument is handed to
    ``BertConfig``.
    """

    model_type = "ternarybert"

    def __init__(self, **kwargs):
        """Accept keyword arguments only and pass them all to ``BertConfig``."""
        super(TernaryBertConfig, self).__init__(**kwargs)
class TernaryBert(PreTrainedModel):
    """BERT model whose ``nn.Linear`` sub-modules are replaced by ``bitlinear``.

    A regular ``BertModel`` is built from the config, then every ``nn.Linear``
    found anywhere in its module tree is swapped for a ``bitlinear`` layer of
    the same shape.
    """

    # NOTE(review): TernaryBertConfig is defined in this file, but the model is
    # bound to the stock BertConfig -- confirm this is intentional.
    config_class = BertConfig

    def __init__(self, config):
        """Build the wrapped ``BertModel`` and convert its linear layers."""
        super().__init__(config)
        self.bert = BertModel(config)
        self.replace_linear2bitblas(self.bert)

    def forward(self, *args, **kwargs):
        """Delegate to the wrapped ``BertModel``.

        Positional and keyword arguments are passed through unchanged, and
        the ``BertModel`` output is returned as-is.
        """
        return self.bert(*args, **kwargs)

    def convert_to_bitlinear(self, layer):
        """Create a ``bitlinear`` layer shaped like ``layer``.

        Only the shape and the bias are taken from ``layer``. Its weight is
        not copied here (NOTE(review): presumably the quantized weights and
        ``alpha`` are loaded later from a checkpoint -- confirm).

        Args:
            layer: The ``nn.Linear`` being replaced.

        Returns:
            A new ``bitlinear`` with ``alpha`` set to 1.0 and ``b`` set to the
            layer's bias in float16 (zeros when the layer has no bias).
        """
        if layer.bias is None:
            # Bias-less layers used to crash on ``layer.bias.data``; a zero
            # bias keeps the arithmetic identical to "no bias".
            bias_fp16 = torch.zeros(
                layer.out_features,
                dtype=torch.float16,
                device=layer.weight.device,
            )
        else:
            bias_fp16 = layer.bias.data.to(torch.float16)

        return bitlinear(
            in_features=layer.in_features,
            out_features=layer.out_features,
            bias=False,
            A_dtype="float16",  # activation A dtype
            W_dtype="int2",  # weight W dtype
            accum_dtype="float16",  # accumulation dtype
            out_dtype="float16",  # output dtype
            # configs for weight-only quantization
            group_size=-1,  # setting for grouped quantization
            with_scaling=False,  # setting for scaling factor
            with_zeros=False,  # setting for zeros
            zeros_mode=None,  # setting for how zeros are calculated
            # Target optimization var for dynamic symbolic.
            # For detailed information see docs/PythonAPI.md in BitBLAS.
            # By default, the optimization var is [1, 16, 32, 64, 128, 256, 512]
            opt_M=[1, 16, 32, 64, 128, 256, 512],
            fast_decoding=True,
            alpha=torch.tensor(1.0).to(torch.float16),
            b=bias_fp16,
        )

    def replace_linear2bitblas(self, model):
        """Recursively replace each ``nn.Linear`` under ``model`` in place.

        Args:
            model: Module whose direct and nested children are scanned.
        """
        # Snapshot the children first: they are reassigned inside the loop.
        for name, module in list(model.named_children()):
            if isinstance(module, nn.Linear):
                setattr(model, name, self.convert_to_bitlinear(module))
            else:
                # Recursing into a leaf module is a no-op (it has no children).
                self.replace_linear2bitblas(module)
|