Llama3.1 and kv_cache quantization by HDCharles · Pull Request #738 · pytorch/ao
This PR adds support for Llama 3.1 along with some improvements to kv_cache quantization and general peak-memory performance for llama.
At a high level, we can now do inference with a 130k context length in 18.9 GB of peak memory if we apply kv_cache quantization, a linear causal mask, and int4 weight-only quantization.
Summary of changes
- add 3.1 support for llama
- change quantized_kv_cache init so it doesn't create a full precision peak: see below
- reorder causal mask init: see below
- add option for linear causal mask: see below
- add option for cache_size: the default generate.py behavior requires you to generate 32k tokens if you want a size-32k kv_cache/causal mask; the cache_size option lets you set the cache size directly while generating a smaller number of tokens, which makes benchmarking easier (see the example after this list)
- add option to generate memory profile: used to generate the images below
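For example (paths are placeholders), this lets you benchmark a 32k cache without having to generate 32k tokens:

```
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --cache_size 32768 --kv_cache_quantization
```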
context length (tokens) | normal peak (GB) | kv_quant peak (GB) | kv quant+causal fix peak (GB) |
---|---|---|---|
8192 | 17.86 | 17.52 | 17.47 |
16384 | 19.81 | 18.75 | 18.48 |
32768 | 23.83 | 21.72 | 20.64 |
65536 | 33.5 | 29.54 | 25.24 |
131072 | 59.27 | 52.62 | 34.18 |
Change to quantized kv_cache init
The first change avoids creating the full-precision kv_cache: previously we would initialize the kv_cache and then convert it to the quantized form, as seen in this memory profile:
The horizontal lines from ~16.1 GB to 16.6 GB are the normal kv_cache; you can see them being deallocated on the right side of the image as the quantized kv_caches are instantiated. This created an unnecessary increase in peak memory whenever initialization was the peak (which was the case for very long context lengths).
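As a minimal sketch of the difference (the shapes, dtypes, and per-token scale layout here are illustrative, not the PR's actual quantized kv_cache implementation):

```python
import torch

max_batch, n_heads, max_seq, head_dim = 1, 8, 32768, 128

# before: a full-precision cache is allocated first and then converted,
# so both representations briefly coexist and raise the peak
k_cache_fp = torch.zeros(max_batch, n_heads, max_seq, head_dim, dtype=torch.bfloat16)
k_cache_q = k_cache_fp.to(torch.int8)  # stand-in for the real affine quantization step
k_scales = torch.ones(max_batch, n_heads, max_seq, 1, dtype=torch.bfloat16)
del k_cache_fp  # the bf16 cache is only freed after the quantized copy exists

# after: allocate the quantized storage and scales directly, no bf16 intermediate
k_cache_q = torch.zeros(max_batch, n_heads, max_seq, head_dim, dtype=torch.int8)
k_scales = torch.ones(max_batch, n_heads, max_seq, 1, dtype=torch.bfloat16)
```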
Change to causal mask
This is a memory profile for 32k context length without kv_cache quantization or any other changes, compared to one with kv_cache quantization:
The horizontal bands that run from 16 GB to 20.5 GB in the top image and to 18.5 GB in the bottom are the kv_cache. With quantization it is 2 GB smaller, which shows the technique is performing as expected. However, there is a large blue (top) or green (bottom) blob, with a spike on the left side, that appears in the memory profile: this is the causal mask.
Normally the causal mask is handled by creating a (context_length x context_length) tensor of ones, then creating a lower-triangular copy of it and taking slices from it throughout the model runs. Notice the sharp peak right at the start: copying a tensor of ones into a lower-triangular matrix requires holding two instances of it in memory for a moment, doubling its impact on top of the O(context_length^2) memory it already takes. The doubling issue was solved by creating the causal mask before the kv_cache; done that way, the momentary doubling spike doesn't affect peak memory since the kv_cache allocation ends up higher than the spike.
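A minimal sketch of the ordering fix (buffer names and shapes are illustrative, not the exact model code):

```python
import torch

S = 32768  # context length; each S x S bool mask is ~1 GB

def setup_caches_old(buffers):
    # kv_cache first...
    buffers["k_cache"] = torch.zeros(1, 8, S, 128, dtype=torch.bfloat16)
    buffers["v_cache"] = torch.zeros(1, 8, S, 128, dtype=torch.bfloat16)
    # ...then the mask: torch.tril allocates a second S x S tensor before the
    # ones tensor is freed, so the momentary doubling lands on top of the kv_cache
    buffers["causal_mask"] = torch.tril(torch.ones(S, S, dtype=torch.bool))

def setup_caches_new(buffers):
    # mask first: the doubling happens while little else is allocated,
    # so the spike stays below the eventual peak set by the kv_cache
    buffers["causal_mask"] = torch.tril(torch.ones(S, S, dtype=torch.bool))
    buffers["k_cache"] = torch.zeros(1, 8, S, 128, dtype=torch.bfloat16)
    buffers["v_cache"] = torch.zeros(1, 8, S, 128, dtype=torch.bfloat16)
```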
Although the earlier instantiation of the causal mask helps (it is the red blob now), it still takes up a lot of space, especially at even higher context lengths, which eats into the gains we expect from kv_cache quantization. Why do we need to store the full causal mask at all, though? A slice of the causal mask is essentially just a sequence of n ones followed by context_length - n zeros, where n is the index of the token currently being generated. Each slice differs from the next by only a single value, so we can just store the slice and update it each iteration instead (see the sketch below). Result:
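To make this concrete, here is a minimal sketch of a linear causal mask during token-by-token decoding (names are illustrative; prefill over the prompt still needs a proper causal mask):

```python
import torch

context_length = 32768

# O(context_length) storage instead of an O(context_length^2) mask:
# ones up to the current position, zeros after it
linear_causal_mask = torch.zeros(1, context_length, dtype=torch.bool)

def mask_for_token(input_pos: int) -> torch.Tensor:
    # the slice for position n differs from the one for n-1 by a single value,
    # so just flip that one entry instead of slicing a full S x S mask
    linear_causal_mask[0, input_pos] = True
    return linear_causal_mask
```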
Tests
see benchmarks.sh
The 18.9 GB number came from:
python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --write_result benchmark_results.txt --cache_size 131072 --kv_cache_quantization --linear_causal_mask --quantization int4wo-64