feat(router): Dynamic batch sizing by njhill · Pull Request #210 · huggingface/text-generation-inference

Motivation

Currently, to avoid OOM you must set a "worst case" max batch size based on the desired max sequence length. This means that (a) throughput is unnecessarily limited when there are many shorter sequences, and (b) you have to be pretty conservative about the max context length offered.

These changes introduce a maximum batch "weight" parameter, which in the flash attention case corresponds to a maximum total number of tokens in the batch. The idea is that this is roughly proportional to the memory requirement.

If max_batch_weight is not set, the router infers it from the max_batch_size and max_total_tokens args. In this case it should behave roughly the same as it does now, so this could hopefully be a "non-breaking" change for existing configurations.

This turns out to be simpler to configure for a particular model/GPU. The precise values for max_batch_size and max_sequence_length no longer matter much; they can both be set quite high. You just need to determine one number (the max weight / total tokens), which is easy to do with minimal experimentation.
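
To make the idea concrete, here is a minimal sketch of the kind of admission check a token-weight budget enables. The type and function names below are illustrative only, not the router's actual code; only the config names (max_batch_size, max_total_tokens, max_batch_weight) come from this PR.

```rust
/// Illustrative sketch only: field and function names are hypothetical.
struct BatchingConfig {
    max_batch_size: usize,
    max_total_tokens: usize,
    /// Upper bound on the total number of tokens across the batch
    /// (roughly proportional to memory use with flash attention).
    max_batch_weight: Option<usize>,
}

impl BatchingConfig {
    /// Effective weight budget: if unset, fall back to the worst case
    /// implied by the existing args, preserving current behaviour.
    fn weight_budget(&self) -> usize {
        self.max_batch_weight
            .unwrap_or(self.max_batch_size * self.max_total_tokens)
    }

    /// Decide whether a new request fits in the in-flight batch.
    /// `current_weight` is the token total of requests already batched;
    /// `request_tokens` is input length plus requested new tokens.
    fn fits(&self, current_batch_size: usize, current_weight: usize, request_tokens: usize) -> bool {
        current_batch_size < self.max_batch_size
            && current_weight + request_tokens <= self.weight_budget()
    }
}

fn main() {
    // Numbers from the GPT-NeoX 20B deployment described below.
    let config = BatchingConfig {
        max_batch_size: 256,
        max_total_tokens: 8192,
        max_batch_weight: Some(10_000),
    };
    // A batch currently holding 12 requests / 9,000 tokens cannot accept
    // a 1,500-token request, but a 900-token request still fits.
    assert!(!config.fits(12, 9_000, 1_500));
    assert!(config.fits(12, 9_000, 900));
}
```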

We have been using this successfully for a while now, and it means we can support a much higher throughput / volume of users with the same hardware while offering larger context lengths. For example, we have a deployment of GPT-NeoX 20B on a single 80GB A100 with the max batch size set to 256 and the max sequence length (max_total_tokens) set to 8192. The actual batch size flexes automatically as needed. Our max_batch_weight setting for this is 10k.

Details/caveats