import os

# Make the CUDA toolchain visible to bitblas and pin GPU ordering to the PCI bus.
os.environ["PATH"] = "/usr/local/cuda/bin:" + os.environ["PATH"]
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import bitblas
import torch
import torch.nn as nn
from transformers import BertConfig, BertModel, PreTrainedModel

class bitlinear(bitblas.Linear):
    """bitblas int2 linear layer with a frozen per-layer scale ``alpha`` and an optional fp16 bias ``b``."""

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = False,
            A_dtype: str = "float16",
            W_dtype: str = "int2",
            accum_dtype: str = "float16",
            out_dtype: str = "float16",
            group_size: int = -1,
            with_scaling: bool = False,
            with_zeros: bool = False,
            zeros_mode: str = None,
            opt_M: list = [1, 16, 32, 64, 128, 256, 512],
            fast_decoding: bool = True,
            alpha: torch.Tensor = None,
            b: torch.Tensor = None,
    ):
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
            A_dtype=A_dtype,
            W_dtype=W_dtype,
            accum_dtype=accum_dtype,
            out_dtype=out_dtype,
            group_size=group_size,
            with_scaling=with_scaling,
            with_zeros=with_zeros,
            zeros_mode=zeros_mode,
            opt_M=opt_M,
            fast_decoding=fast_decoding,
        )
        if alpha is None:
            alpha = torch.tensor(1.0, dtype=torch.float16)
        # Keep the dequantization scale and bias as frozen (non-trainable) parameters.
        self.alpha = nn.Parameter(alpha, requires_grad=False)
        self.b = nn.Parameter(b, requires_grad=False) if b is not None else None

    def forward(self, A: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        # Run the int2 x fp16 matmul in the parent kernel, then apply the scale and optional bias.
        out = super().forward(A, out)
        out *= self.alpha
        if self.b is not None:
            out += self.b.view(1, -1).expand_as(out)
        return out.to(torch.float32)


class TernaryBertConfig(BertConfig):
    model_type = "ternarybert"

class TernaryBert(PreTrainedModel):
    """BertModel whose nn.Linear sublayers are replaced with int2 bitblas ``bitlinear`` layers."""

    # config_class = TernaryBertConfig
    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.replace_linear2bitblas(self.bert)

    def forward(self, **kwargs):
        return self.bert(**kwargs)
    
    def convert_to_bitlinear(self, layer):
        # Build a bitlinear drop-in with the same shape as the given nn.Linear;
        # the fp16 bias is carried over via ``b``, while the int2 weights themselves
        # are not populated here.
        bitlayer = bitlinear(
            in_features=layer.in_features,
            out_features=layer.out_features,
            bias=False,
            A_dtype="float16",  # activation A dtype
            W_dtype="int2",  # weight W dtype
            accum_dtype="float16",  # accumulation dtype
            out_dtype="float16",  # output dtype
            # configs for weight-only quantization
            group_size=-1,  # setting for grouped quantization
            with_scaling=False,  # setting for scaling factor
            with_zeros=False,  # setting for zeros
            zeros_mode=None,  # setting for how to calculate zeros
            # Target optimization vars for dynamic symbolic shapes.
            # For detailed information please check out docs/PythonAPI.md
            # By default, the optimization vars are [1, 16, 32, 64, 128, 256, 512]
            opt_M=[1, 16, 32, 64, 128, 256, 512],
            fast_decoding=True,
            alpha=torch.tensor(1.0, dtype=torch.float16),
            b=layer.bias.data.to(torch.float16) if layer.bias is not None else None,
        )
        return bitlayer

    def replace_linear2bitblas(self, model):
        # Recursively swap every nn.Linear in the module tree for a bitlinear layer.
        for name, module in model.named_children():
            if isinstance(module, nn.Linear):
                new_layer = self.convert_to_bitlinear(module)
                setattr(model, name, new_layer)
            elif len(list(module.children())) > 0:
                self.replace_linear2bitblas(module)
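

# Minimal usage sketch (not part of the original module): it assumes a CUDA device with
# bitblas installed, uses an arbitrary small BertConfig purely for illustration, and only
# shows that the linear sublayers get replaced; loading actual ternary weights into the
# int2 layers is out of scope here.
if __name__ == "__main__":
    config = BertConfig(hidden_size=256, num_hidden_layers=2,
                        num_attention_heads=4, intermediate_size=512)
    model = TernaryBert(config)
    for name, module in model.named_modules():
        if isinstance(module, bitlinear):
            print(name, "->", type(module).__name__)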