# chatsakura-3b-int8 / modeling_chatsakura.py
import torch
import torch.nn as nn
import transformers
from transformers import BloomForCausalLM as LM

from .gptq import *
from .modelutils import *
from .quant import *
class SakuraForCausalLM(LM):
    """Bloom-based causal LM whose linear layers are swapped for 8-bit
    quantized equivalents (group size 128) at construction time.

    During construction, weight-initialization routines are stubbed out and
    the default dtype is temporarily set to half precision, so the model is
    built quickly in fp16 without wasting time on random initialization
    (the real weights are presumably loaded from a checkpoint afterwards —
    TODO confirm against the loading code).
    """

    def __init__(self, *args, **kwargs):
        # Stub out the common init functions: initializing weights that
        # will be overwritten by a checkpoint load is wasted work.
        def noop(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = noop
        torch.nn.init.uniform_ = noop
        torch.nn.init.normal_ = noop

        # Also ask transformers itself to skip its weight init pass.
        transformers.modeling_utils._init_weights = False

        # Build the model with fp16 parameters, then restore fp32 as the
        # default dtype for any tensors created afterwards.
        # (The original called set_default_dtype(torch.half) twice in a
        # row; the duplicate has been removed.)
        torch.set_default_dtype(torch.half)
        super().__init__(*args, **kwargs)
        torch.set_default_dtype(torch.float)

        self.eval()

        # Collect the model's linear layers and quantize all of them to
        # 8 bits with group size 128 — except the LM head, which stays in
        # full precision.
        layers = find_layers(self)
        for name in ('lm_head',):
            layers.pop(name, None)
        make_quant(self, layers, 8, 128)