This PR implements a dequantize instruction in Triton:
```python
def dequantize(input, scale, shift, nbit, dst_ty=float16, _builder=None):
```
input holds nbit-wide (8, 4, or 2) integers packed into int16s or int32s; scale and shift are float16 scalars. For example, for nbit = 8 the input has type int32, and the instruction converts [{int8_0, int8_1, int8_2, int8_3}, {int8_4, int8_5, int8_6, int8_7}, ...] (every four int8s packed into one int32) into scale * [int8_0, int8_1, int8_2, int8_3, int8_4, int8_5, int8_6, int8_7, ...] + shift in float16, so an input of size N produces an output of size 4 * N. Similarly, eight int4s are packed into one int32 and eight int2s are packed into one int16. See the test file https://github.com/yuguo68/triton/blob/dequantize_inst/python/test/unit/language/test_dequantize.py for code examples.
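For concreteness, here is a small NumPy sketch of the 8-bit semantics (illustrative only, not code from this PR; the quantized values are treated as unsigned here):

```python
import numpy as np

# Two int32 words, each packing four 8-bit values (lowest byte first on a little-endian host).
packed = np.array([0x04030201, 0x08070605], dtype=np.int32)
scale, shift = np.float16(0.1), np.float16(2.0)

unpacked = packed.view(np.uint8)                    # [1, 2, 3, 4, 5, 6, 7, 8]
out = scale * unpacked.astype(np.float16) + shift   # float16, size == 4 * packed.size
```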
For our use case at Meta, the scale and shift are usually concatenated with the quantized integers:

input in memory: scale (16 bits), shift (16 bits), int8_0, int8_1, int8_2, ...
output = scale * [int8_0, int8_1, int8_2, ...] + shift

and similarly for int4 and int2.
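A NumPy sketch of decoding this fused layout (again purely illustrative; names and values are made up):

```python
import numpy as np

# Build a hypothetical fused row: scale and shift as fp16, followed by the quantized bytes.
scale, shift = np.float16(0.05), np.float16(-1.0)
qvals = np.arange(8, dtype=np.uint8)
row = np.concatenate([np.array([scale, shift], dtype=np.float16).view(np.uint8), qvals])

# Decode: the first 32 bits hold the two fp16 scalars, the rest are the quantized ints.
s, b = row[:4].view(np.float16)
out = s * row[4:].astype(np.float16) + b            # == scale * [0, 1, ..., 7] + shift
```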
We find that unpacking the quantized integers with existing Triton instructions (bit masks, bitwise casts, etc.) is slow, so we implement an algorithm similar to https://github.com/pytorch/FBGEMM/blob/6a59bb6621ba9ec7d650ccb78b78ea24d62a3904/fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh#L1566-L1619. We observe a 2X speedup for our Meta use case.
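For reference, the core of that style of kernel is a bit trick along the following lines (shown in NumPy; this is only the general flavor, not the exact FBGEMM or PR code). OR-ing the quantized bits into the mantissa of the fp16 constant 1024.0 decodes them to 1024 + value without any int-to-float conversion instruction:

```python
import numpy as np

# fp16 1024.0 has bit pattern 0x6400; for values below 1024 the mantissa LSB is worth
# exactly 1.0, so OR-ing an 8-bit value v into the low bits yields the fp16 number 1024 + v.
vals = np.arange(256, dtype=np.uint16)
halves = (vals | 0x6400).view(np.float16)           # 1024.0 + vals, via bit manipulation only
recovered = halves - np.float16(1024.0)             # exact, since fp16 spacing at 1024 is 1.0
assert np.array_equal(recovered, vals.astype(np.float16))
```

The 1024 offset can be folded into the shift, so the dequantization reduces to a few bitwise ops plus a fused multiply-add in fp16.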
During the implementation, we find it critical to keep the nano tile size (nts_, https://github.com/openai/triton/blob/09cc2d454b442301e88d1df153214732bd8714d8/include/triton/codegen/analysis/layout.h#L232-L233) consistent between the input and the output. For example, for 8-bit quantization with an input size of 64 (output size 256), the output layout
[0, 0, 0, 0, 1, 1, 1, 1, …, 31, 31, 31, 31, 0, 0, 0, 0, 1, 1, 1, 1, …, 31, 31, 31, 31]
does not work with the input layout [0, 0, 1, 1, …, 31, 31], but does work with the input layout [0, 1, …, 31; 0, 1, …, 31]. The input layout [0, 0, 1, 1, …, 31, 31] works with the output layout
[0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, …, 31, 31, 31, 31, 31, 31, 31, 31].
In general, if size(output) / size(input) = N, we require nts_(output) = N * nts_(input).
Currently we use tl.multiple_of hints (https://github.com/yuguo68/triton/blob/2b3ba853a6f641584b0fb4c4ed8e15b772f7549c/python/test/unit/language/test_dequantize.py#L32-L38) to enforce the nano tile size consistency; a rough sketch of the idiom is shown below. We would love to hear about better ways to enforce it, for example in populate_starting_multiple_dequantize and populate_max_contiguous_dequantize.
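The sketch below shows how such hints and the new instruction could fit together for nbit = 8. It is not the PR's test code: tl.dequantize is assumed to follow the signature above, the separate scale/shift and packed pointers are a simplification of the fused layout, and the exact hint values used in the real test may differ.

```python
import triton
import triton.language as tl


@triton.jit
def dequantize_int8_kernel(out_ptr, scale_shift_ptr, packed_ptr, BLOCK: tl.constexpr):
    # scale and shift are stored as two fp16 values in front of the packed payload.
    scale = tl.load(scale_shift_ptr)
    shift = tl.load(scale_shift_ptr + 1)
    in_offs = tl.arange(0, BLOCK)
    # Alignment/contiguity hints so the layout analysis picks an input nano tile size
    # compatible with the 4x larger output (nts_(output) = 4 * nts_(input)).
    in_offs = tl.max_contiguous(tl.multiple_of(in_offs, BLOCK), BLOCK)
    packed = tl.load(packed_ptr + in_offs)            # BLOCK int32 words
    out = tl.dequantize(packed, scale, shift, 8)      # 4 * BLOCK float16 values
    tl.store(out_ptr + tl.arange(0, 4 * BLOCK), out)
```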
The PR author is new to the Triton backend and would appreciate feedback/comments for improvement, especially on the changes in lib/codegen/analysis/align.cc and lib/codegen/analysis/axes.cc. We are aware of the new MLIR backend and would love to implement this instruction there as well; comments on the feasibility in the new backend are appreciated. Thank you!
@ngimel @jianyuh @ajtulloch