[FEAT] Add custom CUDA tinygemm unpacker by jeromeku · Pull Request #415 · pytorch/ao (original) (raw)

Helps explain why tinygemm kernel is able to use a single fma to dequantize (using integer zero-point would require a sub then mul unless zeros are stored as scales * zeros, which is not the case).

yes, the motivation for tinygemm to have zero_point in floating point domain is exactly to use a single fma (I talked to Jeff about this).

This is good to know as the original motivation for this PR was to help answer.ai / hqq developers who are using tinygemm as a quantized matmul backend. However, I believe hqq is using the integer zeropoint derivation (but keeping the zeropoint in the original floating-point dtype), which will result in incorrect results when using tinygemm kernel, which is dequantizing based on the floating-point zeropoint calculation.

for hqq yeah we need to make sure this detail is correct since they are using tinygemm kernels. cc @HDCharles @mobicham

What is the mathematical derivation of the float dequantization method vs the more common integer quantization scheme? Are there any papers / blogs that explain the reasoning for this difference?

I'm not aware of any formal papers or blogs. so the differences are shown in our quant_primitive ops in these two flags:

preserve_zero (bool): a flag to indicate whether we need zero to be exactly
representable or not, this is typically required for ops that needs zero padding, like convolution
it's less important for ops that doesn't have zero padding in the op itself, like linear.
For example, given a floating point Tensor [1.2, 0.1, 3.0, 4.0, 0.4, 0], if `preserve_zero` is True,
we'll make sure there is a integer value corresponding to the floating point 0, e.g. [-3, -8, 3, 7, -7, -8], 0 will be mapped to `-8` without loss. But if `preserve_zero` is not True, there won't be such
gurantee.
If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT

traditional integer quantization:

  1. preserve_zero is True: this is because traditionally we use quantization on conv, and it can have zero_padding, so there is a domain specific requirement of floating point zero has to be exactly representable: https://github.com/google/gemmlowp/blob/master/doc/quantization.md#domain-specific-constraint-the-real-value-0-must-be-exactly-representable
  2. zero_point is in integer domain
    This is probably for static quantization where there are hardwares that only support integer compute

tinygemm:

  1. preserve_zero is False because mainly we care about linear, this will also help improve accuracy in some cases since we don't need to always include zero during quantization
  2. zero_point is in floating point domain
    this is because of fma I think