[FEAT] Add custom CUDA tinygemm unpacker by jeromeku · Pull Request #415 · pytorch/ao
Helps explain why the `tinygemm` kernel is able to use a single `fma` to dequantize (using an `integer` zero point would require a `sub` then a `mul`, unless zeros are stored as `scales * zeros`, which is not the case).
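To make the instruction-count argument concrete, here is a minimal sketch in plain PyTorch (illustrative only, not the actual tinygemm kernel; all variable names are mine) of dequantizing one code under the two zero-point conventions:

```python
import torch

# Illustrative only: dequantize a 4-bit code q with scale s under the two
# zero-point conventions discussed above.
q = torch.tensor([5.0])      # quantized code (loaded as float, as a kernel would)
s = torch.tensor([0.1])      # scale
z_int = torch.tensor([8.0])  # integer-domain zero point
z_float = -z_int * s         # equivalent float-domain zero point

# Integer-domain zero point: a subtract, then a multiply (two instructions).
x_a = (q - z_int) * s

# Float-domain zero point: a single fused multiply-add.
x_b = torch.addcmul(z_float, q, s)  # z_float + q * s, i.e. one fma

assert torch.allclose(x_a, x_b)
```

The two conventions agree only when `z_float = -z_int * s`, which is exactly the "zeros stored as scales * zeros" case mentioned above.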
yes, the motivation for tinygemm to have zero_point in floating point domain is exactly to use a single fma (I talked to Jeff about this).
This is good to know, as the original motivation for this PR was to help answer.ai / `hqq` developers who are using `tinygemm` as a quantized matmul backend. However, I believe `hqq` is using the `integer` zero-point derivation (but keeping the zero point in the original floating-point dtype), which will result in incorrect results when using the `tinygemm` kernel, which dequantizes based on the `floating-point` zero-point calculation.
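For a concrete picture of the mismatch (a hypothetical sketch, not actual `hqq` code): if the zero point z is derived integer-style, so that x ≈ (q - z) * s, but the kernel dequantizes as q * s + z, the results diverge unless z is first remapped to -z * s:

```python
q, s, z = 5.0, 0.1, 8.0      # integer-style qparams, z kept as a float
expected = (q - z) * s       # -0.3: what the qparams actually encode
kernel_out = q * s + z       #  8.5: what a float-zero-point kernel computes
fixed = q * s + (-z * s)     # -0.3: correct after remapping z -> -z * s
```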
> for hqq

yeah we need to make sure this detail is correct since they are using tinygemm kernels. cc @HDCharles @mobicham
What is the mathematical derivation of the `float` dequantization method vs the more common `integer` quantization scheme? Are there any papers / blogs that explain the reasoning for this difference?
I'm not aware of any formal papers or blogs. The differences are shown in our quant_primitive ops via these two flags:
> preserve_zero (bool): a flag to indicate whether we need zero to be exactly
> representable or not, this is typically required for ops that need zero padding, like convolution;
> it's less important for ops that don't have zero padding in the op itself, like linear.
> For example, given a floating point Tensor [1.2, 0.1, 3.0, 4.0, 0.4, 0], if `preserve_zero` is True,
> we'll make sure there is an integer value corresponding to the floating point 0, e.g. [-3, -8, 3, 7, -7, -8], where 0 is mapped to `-8` without loss. But if `preserve_zero` is not True, there won't be such a
> guarantee.
> If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point.
>
> zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float.
> If zero_point is in the integer domain, zero point is added to the quantized integer value during
> quantization.
> If zero_point is in the floating point domain, zero point is subtracted from the floating point (unquantized)
> value during quantization.
> Default is ZeroPointDomain.INT.
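A small sketch of what `preserve_zero=True` means for the zero point, using the docstring's example tensor (illustrative asymmetric int4 quantization, not the torchao implementation):

```python
import torch

x = torch.tensor([1.2, 0.1, 3.0, 4.0, 0.4, 0.0])
qmin, qmax = -8, 7  # int4 range

scale = (x.max() - x.min()) / (qmax - qmin)

# preserve_zero=True: round and clamp the zero point so that the float value
# 0.0 lands exactly on an integer code in [qmin, qmax].
zero_point = torch.clamp(qmin - torch.round(x.min() / scale), qmin, qmax)

q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)

# The code for 0.0 dequantizes back to exactly 0.0, with no loss.
assert (q[-1] - zero_point) * scale == 0.0
```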
traditional integer quantization:
- preserve_zero is True: this is because traditionally we use quantization on conv, which can have zero padding, so there is a domain-specific requirement that the floating point zero be exactly representable: https://github.com/google/gemmlowp/blob/master/doc/quantization.md#domain-specific-constraint-the-real-value-0-must-be-exactly-representable
- zero_point is in integer domain
This is probably for static quantization, where there is hardware that only supports integer compute.
tinygemm:
- preserve_zero is False because we mainly care about linear; this also helps improve accuracy in some cases since we don't need to always include zero during quantization
- zero_point is in floating point domain
this is because of `fma`, I think
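Putting the two tinygemm choices together, here is a minimal end-to-end sketch (my reading of the scheme under the stated assumptions, not the kernel code; the exact tinygemm parameter layout may differ):

```python
import torch

x = torch.tensor([1.2, 0.1, 3.0, 4.0, 0.4, 0.0])
qmin, qmax = 0, 15  # uint4 codes

scale = (x.max() - x.min()) / (qmax - qmin)

# preserve_zero=False: the zero point stays in the floating point domain and
# is neither rounded nor clamped, so float 0.0 need not be exactly representable.
zero_point = x.min()

# Quantize: the zero point is subtracted from the unquantized float value.
q = torch.clamp(torch.round((x - zero_point) / scale), qmin, qmax)

# Dequantize: a single fma per element, no sub-then-mul.
x_hat = torch.addcmul(zero_point, q, scale)  # q * scale + zero_point
```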