import torch
import torch.nn as nn
import transformers
from .gptq import *
from .modelutils import *
from .quant import *
from transformers import BloomForCausalLM as LM
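
# BloomForCausalLM subclass that builds the model without running the usual weight
# initialization (the weights are presumably loaded later from a quantized checkpoint)
# and then replaces its Linear layers with quantized equivalents.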
class SakuraForCausalLM(LM):
    def __init__(self, *args, **kwargs):
        def noop(*args, **kwargs):
            pass
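        # Stub out PyTorch's weight initializers so constructing the model does not
        # spend time on random initialization that would be overwritten anyway.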
        torch.nn.init.kaiming_uniform_ = noop
        torch.nn.init.uniform_ = noop
        torch.nn.init.normal_ = noop
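
        # Build the underlying BLOOM model in half precision with Hugging Face's
        # own weight initialization disabled, then restore the default float dtype.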
        torch.set_default_dtype(torch.half)
        transformers.modeling_utils._init_weights = False
        super().__init__(*args, **kwargs)
        torch.set_default_dtype(torch.float)
        self.eval()
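
        # Gather the quantizable Linear layers, keeping the output head (lm_head)
        # in full precision.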
        layers = find_layers(self)
        for name in ['lm_head']:
            if name in layers:
                del layers[name]
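        # Replace the remaining Linear layers with their quantized counterparts;
        # the trailing arguments are assumed to be the weight bit width (8) and
        # the quantization group size (128).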
        make_quant(self, layers, 8, 128)