|
import torch.nn as nn |
|
import torch |
|
|
|
def quantize(tensor, scale, zero_point, is_asym=False):
    """Project ``tensor`` onto an 8-bit affine quantization grid.

    The input is divided by ``scale``, rounded, shifted by ``zero_point`` and
    then clamped, so the returned values always lie inside the 8-bit range.
    No integer cast is performed; the result keeps the dtype produced by the
    arithmetic on the inputs.

    Args:
        tensor: Values to quantize.
        scale: Quantization step; must broadcast against ``tensor``.
        zero_point: Offset added to the rounded values; must broadcast
            against ``tensor``.
        is_asym: When True, clamp to ``[0, 255]``; otherwise clamp to
            ``[-128, 127]``.

    Returns:
        The quantized tensor, i.e.
        ``clamp(round(tensor / scale) + zero_point, min, max)``.
    """
    clamp_min, clamp_max = (0., 255.) if is_asym else (-128., 127.)
    # The zero-point must be applied *before* clamping: shifting afterwards
    # pushes values outside the 8-bit range and makes the zero-point cancel
    # out completely once `dequantize` subtracts it again.
    shifted = torch.round(tensor / scale) + zero_point
    return torch.clamp(shifted, clamp_min, clamp_max)
|
|
|
def dequantize(tensor, scale, zero_point):
    """Map quantized values back to real values.

    Args:
        tensor: Quantized values.
        scale: Quantization step to multiply by.
        zero_point: Offset to subtract before scaling.

    Returns:
        ``(tensor - zero_point) * scale``.
    """
    centered = tensor - zero_point
    return centered * scale
|
|
|
|
|
class QuantLinear(nn.Module):
    """Linear layer that runs on quantize-then-dequantize weights and inputs.

    Quantization parameters are read from ``quant_param`` and stored as
    buffers. For every entry listed below, ``quant_param[key]`` holds the
    values and ``quant_param[key + '_shape']`` the shape they are viewed as:

    * ``smoothquant_mul`` -> buffer ``mul_factor``
    * ``weight_scale``    -> buffer ``weight_scale``
    * ``weight_zp``       -> buffer ``weight_zp``
    * ``input_scale``     -> buffer ``input_scale``
    * ``input_zp``        -> buffer ``input_zp``

    Args:
        quant_param: Mapping with the keys described above.
        in_features: Input size of the wrapped ``nn.Linear`` (default 128).
        out_features: Output size of the wrapped ``nn.Linear`` (default 128).
    """

    # (buffer name, key in ``quant_param``), in registration order.
    _BUFFER_KEYS = (
        ('mul_factor', 'smoothquant_mul'),
        ('weight_scale', 'weight_scale'),
        ('weight_zp', 'weight_zp'),
        ('input_scale', 'input_scale'),
        ('input_zp', 'input_zp'),
    )

    def __init__(self, quant_param, in_features=128, out_features=128):
        super().__init__()
        for buffer_name, key in self._BUFFER_KEYS:
            values = torch.tensor(quant_param[key]).view(quant_param[key + '_shape'])
            self.register_buffer(buffer_name, values)
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the layer to ``x`` using fake-quantized weight and input.

        ``x`` is first multiplied by ``mul_factor``. The weight is quantized
        with ``is_asym=True`` and the scaled input with ``is_asym=False``;
        both are dequantized again before the linear transform, which uses
        the bias of the wrapped ``nn.Linear``.
        """
        scaled_x = x * self.mul_factor
        quant_weight = quantize(
            self.linear.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(
            scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        return torch.nn.functional.linear(
            dequantized_input, dequantized_weight, self.linear.bias)
|
|
|
class QuantConv2d(nn.Module):
    """2D convolution that runs on quantize-then-dequantize weights and inputs.

    Quantization parameters are read from ``quant_param`` and stored as
    buffers. For every entry listed below, ``quant_param[key]`` holds the
    values and ``quant_param[key + '_shape']`` the shape they are viewed as:

    * ``smoothquant_mul`` -> buffer ``mul_factor``
    * ``weight_scale``    -> buffer ``weight_scale``
    * ``weight_zp``       -> buffer ``weight_zp``
    * ``input_scale``     -> buffer ``input_scale``
    * ``input_zp``        -> buffer ``input_zp``

    Args:
        quant_param: Mapping with the keys described above.
        in_channels: Input channels of the wrapped ``nn.Conv2d`` (default 128).
        out_channels: Output channels of the wrapped ``nn.Conv2d`` (default 128).
        kernel_size: Kernel size of the wrapped ``nn.Conv2d`` (default 3).
    """

    # (buffer name, key in ``quant_param``), in registration order.
    _BUFFER_KEYS = (
        ('mul_factor', 'smoothquant_mul'),
        ('weight_scale', 'weight_scale'),
        ('weight_zp', 'weight_zp'),
        ('input_scale', 'input_scale'),
        ('input_zp', 'input_zp'),
    )

    def __init__(self, quant_param, in_channels=128, out_channels=128, kernel_size=3):
        super().__init__()
        for buffer_name, key in self._BUFFER_KEYS:
            values = torch.tensor(quant_param[key]).view(quant_param[key + '_shape'])
            self.register_buffer(buffer_name, values)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size)

    def forward(self, x):
        """Apply the convolution to ``x`` using fake-quantized weight and input.

        ``x`` is first multiplied by ``mul_factor``. The weight is quantized
        with ``is_asym=True`` and the scaled input with ``is_asym=False``;
        both are dequantized again before the convolution, which uses the
        bias of the wrapped ``nn.Conv2d``.
        """
        scaled_x = x * self.mul_factor
        # The weight lives on ``self.conv2d``; this module has no ``linear``
        # attribute, so reading ``self.linear.weight`` raised AttributeError.
        quant_weight = quantize(
            self.conv2d.weight, self.weight_scale, self.weight_zp, is_asym=True)
        quant_input = quantize(
            scaled_x, self.input_scale, self.input_zp, is_asym=False)
        dequantized_weight = dequantize(quant_weight, self.weight_scale, self.weight_zp)
        dequantized_input = dequantize(quant_input, self.input_scale, self.input_zp)
        # Forward the module's own hyper-parameters so the functional call
        # stays in sync with ``self.conv2d`` (these equal the functional
        # defaults for the layer built in ``__init__``).
        return torch.nn.functional.conv2d(
            dequantized_input,
            dequantized_weight,
            self.conv2d.bias,
            stride=self.conv2d.stride,
            padding=self.conv2d.padding,
            dilation=self.conv2d.dilation,
            groups=self.conv2d.groups,
        )
|
|