import math import torch import torch.nn as nn def quantize_fp8(tensor: torch.Tensor, scale: torch.Tensor): dtype = tensor.dtype clamp_min, clamp_max = torch.tensor(-240., dtype=dtype), torch.tensor(240., dtype=dtype) quant_tensor = torch.clamp((tensor/scale), clamp_min, clamp_max).to(torch.float8_e4m3fnuz).to(dtype) return quant_tensor def dequantize_fp8(tensor: torch.Tensor, scale: torch.Tensor): return tensor * scale # Based on: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html def qdq_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): query = dequantize_fp8(quantize_fp8(query, query_scale), query_scale) key = dequantize_fp8(quantize_fp8(key, key_scale), key_scale) value = dequantize_fp8(quantize_fp8(value, value_scale), value_scale) L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale attn_bias = torch.zeros(L, S, dtype=query.dtype) if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: attn_bias += attn_mask attn_weight = query @ key.transpose(-2, -1) * scale_factor attn_weight += attn_bias attn_weight = dequantize_fp8(quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale), softmax_scale) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return attn_weight @ value def qop_scaled_dot_product_attention(query, key, value, query_scale, key_scale, value_scale, softmax_scale, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): query = quantize_fp8(query, query_scale) key = quantize_fp8(key, key_scale) value = quantize_fp8(value, value_scale) # Your quantized kernel starts here L, S = query.size(-2), key.size(-2) scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale scale_factor *= (query_scale * key_scale) attn_bias = torch.zeros(L, S, dtype=query.dtype) if is_causal: assert attn_mask is None temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) if attn_mask is not None: if attn_mask.dtype == torch.bool: attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) else: attn_bias += attn_mask attn_weight = (query @ key.transpose(-2, -1)) * scale_factor # or, attn_weight = dequantize_fp8(query @ key.transpose(-2, -1), scale_factor) attn_weight += attn_bias attn_weight = quantize_fp8(torch.softmax(attn_weight, dim=-1), softmax_scale) attn_weight = torch.dropout(attn_weight, dropout_p, train=True) return (attn_weight @ value) * (softmax_scale * value_scale) # or, return dequantize_fp8(attn_weight @ value, softmax_scale * value_scale) # Module that implements `torch.nn.functional.scaled_dot_product_attention` class QuantScaledDotProductAttention(nn.Module): def __init__(self, quant_param): super().__init__() q_scale = torch.tensor(quant_param['out_q']['act_scale']).view(quant_param['out_q']['act_scale_shape']) k_scale = torch.tensor(quant_param['out_k']['act_scale']).view(quant_param['out_k']['act_scale_shape']) v_scale = torch.tensor(quant_param['out_v']['act_scale']).view(quant_param['out_v']['act_scale_shape']) sm_scale = torch.tensor(quant_param['output_softmax_quant']['act_scale']).view(quant_param['output_softmax_quant']['act_scale_shape']) # Not used, included in model. Kept because we use the zp_dtype as a type hint #q_zp = torch.tensor(quant_param['out_q']['act_zp']).view(quant_param['out_q']['act_zp_shape']) #k_zp = torch.tensor(quant_param['out_k']['act_zp']).view(quant_param['out_k']['act_zp_shape']) #v_zp = torch.tensor(quant_param['out_v']['act_zp']).view(quant_param['out_v']['act_zp_shape']) assert quant_param['out_q']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"Q Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_q']['act_zp_dtype']}" assert quant_param['out_k']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"K Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_k']['act_zp_dtype']}" assert quant_param['out_v']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"V Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['out_v']['act_zp_dtype']}" assert quant_param['output_softmax_quant']['act_zp_dtype'] == 'torch.float8_e4m3fnuz', f"SoftMax Zero-Point dtype should be 'torch.float8_e4m3fnuz', found: {quant_param['output_softmax_quant']['act_zp_dtype']}" self.register_buffer('q_scale', q_scale) self.register_buffer('k_scale', k_scale) self.register_buffer('v_scale', v_scale) self.register_buffer('sm_scale', sm_scale) # I.e., "fake quantization" def qdq_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): return qdq_scaled_dot_product_attention(query, key, value, self.q_scale, self.k_scale, self.v_scale, self.sm_scale, attn_mask, dropout_p, is_causal, scale) # Accelerated version def qop_forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None): return qop_scaled_dot_product_attention(query, key, value, self.q_scale, self.k_scale, self.v_scale, self.sm_scale, attn_mask, dropout_p, is_causal, scale) def forward(self, query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None, qop=False): if qop: return self.qop_forward(query, key, value, attn_mask, dropout_p, is_causal, scale) else: return self.qdq_forward(query, key, value, attn_mask, dropout_p, is_causal, scale)