# Source: sdxl-quant-int8 / math_model.py (Hugging Face Hub file view)
# Uploaded by GiusFra with huggingface_hub; commit d5dfd96 (verified); 3.39 kB.
# (Page-navigation residue from the web view: "raw", "history", "blame".)
import torch.nn as nn
import torch
def quantize(tensor, scale, zero_point, is_asym=False):
    """Map a float tensor onto an 8-bit integer grid (values stay floating point).

    Computes ``clamp(round(tensor / scale + zero_point), qmin, qmax)``. The
    range is ``[0, 255]`` when ``is_asym`` is true and ``[-128, 127]``
    otherwise.

    Args:
        tensor: Values to quantize.
        scale: Quantization step; must broadcast against ``tensor``.
        zero_point: Offset on the integer grid; must broadcast against
            ``tensor``.
        is_asym: Select the unsigned ``[0, 255]`` range instead of the signed
            ``[-128, 127]`` range.

    Returns:
        A tensor of quantized values, each one inside the selected range.
    """
    qmin, qmax = (0.0, 255.0) if is_asym else (-128.0, 127.0)
    # The zero point must be applied BEFORE clamping. Adding it afterwards
    # lets results leave [qmin, qmax] and, once dequantize() subtracts it
    # again, clips the real values to [qmin * scale, qmax * scale] as if no
    # zero point existed (e.g. all negative inputs become 0 in asym mode).
    shifted = torch.round(tensor / scale + zero_point)
    # Plain Python bounds avoid allocating CPU tensors on every call.
    return torch.clamp(shifted, qmin, qmax)
def dequantize(tensor, scale, zero_point):
    """Return ``(tensor - zero_point) * scale``.

    Args:
        tensor: Quantized values.
        scale: Quantization step; must broadcast against ``tensor``.
        zero_point: Offset to remove before scaling; must broadcast against
            ``tensor``.

    Returns:
        The rescaled tensor.
    """
    centered = tensor - zero_point
    return centered * scale
class QuantLinear(nn.Module):
    """Floating-point reference model of a quantized linear layer.

    The input is multiplied by a per-layer factor, then both the weight and
    the scaled input are quantized and immediately dequantized before a
    regular ``linear`` is applied.

    Args:
        quant_param: Mapping that provides, for each of ``smoothquant_mul``,
            ``weight_scale``, ``weight_zp``, ``input_scale`` and ``input_zp``,
            the values under that key and the target shape under
            ``<key>_shape``. (Presumably loaded from an exported parameter
            file -- confirm against the caller.)
        in_features: Input size of the wrapped ``nn.Linear`` (default 128,
            the previously hard-coded value).
        out_features: Output size of the wrapped ``nn.Linear`` (default 128,
            the previously hard-coded value).
    """

    # (buffer name, key in ``quant_param``), in registration order.
    _BUFFER_KEYS = (
        ('weight_scale', 'weight_scale'),
        ('weight_zp', 'weight_zp'),
        ('input_scale', 'input_scale'),
        ('input_zp', 'input_zp'),
    )

    def __init__(self, quant_param, in_features=128, out_features=128):
        super().__init__()
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(
            quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.linear = nn.Linear(in_features, out_features)
        for buffer_name, key in self._BUFFER_KEYS:
            values = torch.tensor(quant_param[key]).view(quant_param[key + '_shape'])
            self.register_buffer(buffer_name, values)

    def forward(self, x):
        """Apply the fake-quantized linear layer to ``x``.

        Args:
            x: Input tensor; must broadcast against ``mul_factor`` and have a
                last dimension accepted by the wrapped linear layer.

        Returns:
            Output of ``linear`` evaluated on the quantize/dequantize
            round-tripped input and weight, plus the layer bias.
        """
        scaled_x = x * self.mul_factor
        # Weights use the unsigned range, activations the signed range.
        quant_weight = quantize(self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        return torch.nn.functional.linear(dequantized_input, dequantized_weight, self.linear.bias)
class QuantConv2d(nn.Module):
    """Floating-point reference model of a quantized 2-D convolution.

    The input is multiplied by a per-layer factor, then both the weight and
    the scaled input are quantized and immediately dequantized before a
    regular ``conv2d`` is applied.

    Args:
        quant_param: Mapping that provides, for each of ``smoothquant_mul``,
            ``weight_scale``, ``weight_zp``, ``input_scale`` and ``input_zp``,
            the values under that key and the target shape under
            ``<key>_shape``. (Presumably loaded from an exported parameter
            file -- confirm against the caller.)
        in_channels: Input channels of the wrapped ``nn.Conv2d`` (default
            128, the previously hard-coded value).
        out_channels: Output channels of the wrapped ``nn.Conv2d`` (default
            128, the previously hard-coded value).
        kernel_size: Kernel size of the wrapped ``nn.Conv2d`` (default 3, the
            previously hard-coded value).
    """

    # (buffer name, key in ``quant_param``), in registration order.
    _BUFFER_KEYS = (
        ('weight_scale', 'weight_scale'),
        ('weight_zp', 'weight_zp'),
        ('input_scale', 'input_scale'),
        ('input_zp', 'input_zp'),
    )

    def __init__(self, quant_param, in_channels=128, out_channels=128, kernel_size=3):
        super().__init__()
        mul_factor = torch.tensor(quant_param['smoothquant_mul']).view(
            quant_param['smoothquant_mul_shape'])
        self.register_buffer('mul_factor', mul_factor)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size)
        for buffer_name, key in self._BUFFER_KEYS:
            values = torch.tensor(quant_param[key]).view(quant_param[key + '_shape'])
            self.register_buffer(buffer_name, values)

    def forward(self, x):
        """Apply the fake-quantized convolution to ``x``.

        Args:
            x: Input tensor; must broadcast against ``mul_factor`` and be a
                valid input for the wrapped convolution.

        Returns:
            Output of ``conv2d`` evaluated on the quantize/dequantize
            round-tripped input and weight, plus the layer bias.
        """
        conv = self.conv2d
        scaled_x = x * self.mul_factor
        # The weight lives on ``self.conv2d``; this module has no ``linear``
        # attribute, so reading it from there raised AttributeError.
        quant_weight = quantize(conv.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        # Take the geometry from the wrapped module so both stay in sync; for
        # the default nn.Conv2d settings these equal F.conv2d's own defaults.
        return torch.nn.functional.conv2d(
            dequantized_input,
            dequantized_weight,
            conv.bias,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
        )