feat(router): Dynamic batch sizing by njhill · Pull Request #210 · huggingface/text-generation-inference
Motivation
Currently to avoid OOM you must set a "worst case" max batch size based on the desired max sequence length. This means that (a) throughput is unnecessarily limited when there are many shorter sequences and (b) you have to be pretty conservative about the max context length offered.
These changes introduce a maximum batch "weight" parameter which in the flash attention case corresponds to a maximum total number of tokens in the batch. The idea is that this is roughly proportional to the memory requirement.
- Batches are filled taking both the input lengths and max new tokens of existing and new requests into account
- A "projection" is done when evaluating each next request in the queue for admission into the batch, to determine whether the max batch weight will ever be exceeded in future assuming the worst case of all requests running to their
max_new_tokensvalues - As long as they did not arrive too far behind, smaller requests can jump ahead of larger ones if the larger one doesn't fit in the batch but the later smaller one does
- You can optionally set a separate "max prefill weight" to limit how many tokens can be prefilled at once. This is to help avoid long delays where no tokens are produced.
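To make the projection concrete, here is a simplified sketch of the kind of admission check involved (names are illustrative, not taken from this branch, and the real logic is more refined - e.g. it also tracks state in the queue to avoid repeating this analysis):

```rust
/// Illustrative only: a conservative admission check. Would the batch's total
/// token count ever exceed the weight limit if every request (existing and
/// candidate) ran all the way to its max_new_tokens?
struct EntryTokens {
    input_length: u32,
    max_new_tokens: u32,
}

fn fits_in_batch(
    existing: &[EntryTokens],
    candidate: &EntryTokens,
    max_batch_weight: u32,
) -> bool {
    // Worst-case weight of a request = its prompt tokens plus every token
    // it is allowed to generate.
    let worst_case = |e: &EntryTokens| e.input_length + e.max_new_tokens;

    let projected: u32 =
        existing.iter().map(worst_case).sum::<u32>() + worst_case(candidate);
    projected <= max_batch_weight
}
```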
If `max_batch_weight` is not set, it is just inferred from the `max_batch_size` and `max_total_tokens` args. In this case it should behave roughly the same as it does now, so this could hopefully be a "non-breaking" change for existing configurations.
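Conceptually, that fallback amounts to something like the following (illustrative only; the exact expression in the branch may differ):

```rust
/// Illustrative only: derive a default weight when none is configured,
/// assuming the worst case of a full batch of maximum-length sequences.
fn default_max_batch_weight(
    max_batch_weight: Option<usize>,
    max_batch_size: usize,
    max_total_tokens: usize,
) -> usize {
    max_batch_weight.unwrap_or(max_batch_size * max_total_tokens)
}
```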
It turns out to be simpler to configure for a particular model/GPU. The precise values for `max_batch_size` and `max_sequence_length` no longer matter much; they can both be set quite high. You just need to determine one number (the max weight / total tokens), which is easy to do with minimal experimentation.
We have been using this successfully for a while now and it means we can support a much higher throughput / volume of users with the same hardware while offering larger context lengths. For example, we have a deployment of GPT-NeoX 20B on one 80GB A100 with the max batch size set to 256 and the max sequence length (`max_total_tokens`) set to 8192. The actual batch size flexes automatically as needed. Our `max_batch_weight` setting for this is 10k.
Details/caveats
- I have ported this from my internal fork and have not yet tested this branch
- I've only included the implementation for the flash attention case so far. The additions to generalize to the regular attention case aren't very big (we run non-flash-attention models with this too), but I thought this was probably complicated enough to start with. It will of course need to support the general case before actually being included.
- Some of the logic (e.g. related to the extra fields in the queue state) is there to cut down on the overhead of repeated analysis - most calls to `next_batch` should return immediately before getting into the more complex logic.
- Since the queue's `next_batch` function now takes the current entries map instead of a min and max, the tests in `queue.rs` would need updating, so I just removed them for now (see the rough sketch of the new shape below).
- Though it should be fully functional, please consider it still WIP - if you are interested, more cleanup/rework can be done.
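For reviewers who want the shape of that queue change without opening the diff, it is roughly along these lines (illustrative types and signature only - the real `Entry`/`Queue` types in the router differ, and the ordering/window constraints and the separate prefill weight limit are omitted here):

```rust
use std::collections::BTreeMap;

// Illustrative only: the real Entry carries the request, response channel,
// tracing span, etc.
#[derive(Clone)]
struct Entry {
    input_length: u32,
    max_new_tokens: u32,
}

fn worst_case_weight(e: &Entry) -> u32 {
    e.input_length + e.max_new_tokens
}

struct Queue {
    pending: Vec<(u64, Entry)>, // queued requests in arrival order
}

impl Queue {
    /// Previously took a (min_size, max_size) pair; now it inspects the
    /// entries already running so it can project the combined batch weight.
    fn next_batch(
        &mut self,
        running: &BTreeMap<u64, Entry>,
        max_batch_weight: u32,
    ) -> Vec<(u64, Entry)> {
        // Weight already committed by the in-flight batch, projected to the
        // worst case of every request using all of its max_new_tokens.
        let mut projected: u32 = running.values().map(worst_case_weight).sum();

        let mut admitted = Vec::new();
        // Walk the queue in order; a smaller request may still be admitted
        // when an earlier, larger one does not fit.
        self.pending.retain(|(id, entry)| {
            let w = worst_case_weight(entry);
            if projected + w <= max_batch_weight {
                projected += w;
                admitted.push((*id, entry.clone()));
                false // remove from the queue
            } else {
                true // keep queued
            }
        });
        admitted
    }
}
```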