How to get best performance on DeepSeek-R1 in TensorRT-LLM

NVIDIA has announced world-record DeepSeek-R1 inference performance at NVIDIA GTC 2025. A single NVIDIA DGX system with eight NVIDIA Blackwell GPUs can achieve over 250 tokens per second per user or a maximum throughput of over 30,000 tokens per second on the massive, state-of-the-art 671-billion-parameter DeepSeek-R1 model. See the announcement: NVIDIA Blackwell Delivers World-Record DeepSeek-R1 Inference Performance.

In this blog, we share the configurations and procedures to reproduce these numbers on both B200 and H200 with the PyTorch workflow.


Prerequisites: Install TensorRT-LLM and download models#

This section can be skipped if you already have TensorRT-LLM installed and have already downloaded the DeepSeek R1 model checkpoint.

1. Download TensorRT-LLM#

More comprehensive installation instructions are available in the TensorRT-LLM installation guide; refer to it for common issues if you encounter any here.

```bash
# Prerequisites
apt-get update && apt-get -y install git git-lfs
git lfs install

# Replace with your actual path
YOUR_WORK_PATH=

# Clone the TensorRT-LLM repository
cd $YOUR_WORK_PATH
git clone https://github.com/NVIDIA/TensorRT-LLM.git
cd TensorRT-LLM
git submodule update --init --recursive
git lfs pull
```

Note: Replace <*_PATH> with your actual path.

2. Download the DeepSeek R1 models#

For NVIDIA Blackwell GPUs, it’s recommended to use the FP4 quantized version of DeepSeek R1 to get the best performance. For NVIDIA Hopper GPUs, it’s recommended to use the FP8 version of the DeepSeek R1 model.
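If you are scripting the download and are not sure which GPU generation a machine has, a small helper like the one below can pick the recommended checkpoint. This is a sketch that is not part of the original instructions; it assumes a recent nvidia-smi that supports the compute_cap query field (9.x for Hopper, 10.x for Blackwell).

```bash
# Sketch: pick the recommended checkpoint from the GPU compute capability.
# Assumes nvidia-smi supports the compute_cap query field.
COMPUTE_CAP_MAJOR=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n 1 | cut -d. -f1)
if [ "${COMPUTE_CAP_MAJOR}" -ge 10 ]; then
    MODEL_REPO=https://huggingface.co/nvidia/DeepSeek-R1-FP4     # FP4 for Blackwell
else
    MODEL_REPO=https://huggingface.co/deepseek-ai/DeepSeek-R1    # FP8 for Hopper
fi
echo "Recommended checkpoint: ${MODEL_REPO}"
```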

```bash
# Replace with your actual path
YOUR_MODEL_PATH=
cd $YOUR_MODEL_PATH

# Download FP4 model for Blackwell GPUs
git clone https://huggingface.co/nvidia/DeepSeek-R1-FP4

# Download FP8 model for Hopper GPUs
# FP8 model also works for Blackwell, but FP4 has the best performance on Blackwell.
git clone https://huggingface.co/deepseek-ai/DeepSeek-R1
```

3. Build and run TensorRT-LLM container#

```bash
cd TensorRT-LLM
make -C docker run LOCAL_USER=1 DOCKER_RUN_ARGS="-v $YOUR_MODEL_PATH:$YOUR_MODEL_PATH:ro -v $YOUR_WORK_PATH:$YOUR_WORK_PATH"
```

The LOCAL_USER=1 argument sets up a local user account inside the container instead of running as root; you can omit it if running as root inside the container is acceptable.

4. Compile and Install TensorRT-LLM#

Here we compile the source inside the container:

```bash
python3 ./scripts/build_wheel.py --trt_root /usr/local/tensorrt --benchmarks --cuda_architectures "90-real;100-real" --python_bindings --clean
```

You can set --cuda_architectures to "100-real" to target Blackwell only, or to "90-real" to target Hopper only, which saves some build time.
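For example, a Hopper-only build passes just the "90-real" architecture; swap in "100-real" for a Blackwell-only build. Everything else stays the same as the command above.

```bash
# Hopper-only build; use "100-real" instead to target Blackwell only.
python3 ./scripts/build_wheel.py --trt_root /usr/local/tensorrt --benchmarks --cuda_architectures "90-real" --python_bindings --clean
```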

Install and set environment variables:

```bash
pip install --user build/tensorrt_llm*.whl
export PATH=${HOME}/.local/bin:${PATH}
export PYTHONPATH=$(pwd)
```
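Before benchmarking, a quick check (not part of the original steps) can confirm that the wheel is importable and that trtllm-bench resolves from ~/.local/bin:

```bash
# The import should print the installed TensorRT-LLM version,
# and `which` should point at ~/.local/bin/trtllm-bench.
python3 -c "import tensorrt_llm; print(tensorrt_llm.__version__)"
which trtllm-bench
```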

5. Optional: Tune GPU clocks#

```bash
sudo nvidia-smi -pm 0; sudo nvidia-smi -pm 1; sudo nvidia-smi boost-slider --vboost 4
```

The boost-slider option tunes the GPU clocks and can yield a slight performance increase; for B200 min-latency scenarios it is worth about 8 TPS/user. This step is not required; it is provided so that the performance numbers in this doc can be reproduced more closely to our internal runs.
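If you want to confirm that the clock settings took effect, nvidia-smi can report the current and maximum SM clocks; this check is optional and not part of the original steps.

```bash
# Show current vs. maximum SM clocks for each GPU.
nvidia-smi --query-gpu=index,clocks.sm,clocks.max.sm --format=csv
```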

6. Dataset preparation#

The trtllm-bench tool requires a dataset file to read the prompts and the output sequence length of each prompt. The format of this dataset file is described in preparing a dataset.

For min-latency benchmarking, a real dataset is required, since the MTP acceptance rate depends on the data and therefore affects performance. You can use your own dataset following the format described in the link above.

For max-throughput benchmarking, a synthetic dataset is representative enough, since MTP is not used. The command to generate the synthetic dataset is included in the max-throughput sections below.

Reproducing steps#

This section provides the steps to reproduce the numbers on NVIDIA B200 (Blackwell) and H200 (Hopper) GPUs, for both min-latency and max-throughput scenarios.

All benchmarking is done with the trtllm-bench command-line tool provided by the TensorRT-LLM installation; see TensorRT-LLM Benchmarking for details of this tool.

For brevity, we only provide the commands to reproduce the perf numbers without detailed explanation of the tools and options in this doc.

All commands here are assumed to be run inside the container started by the make -C docker run ... command mentioned in the Build and run TensorRT-LLM container section.

B200 min-latency#

Our benchmark results are based on Batch = 1, ISL = 1K, OSL = 2K, num_requests = 10, using a real dataset.

To do the benchmark, run the following command:

```bash
YOUR_DATA_PATH=

cat >./extra-llm-api-config.yml <<EOF
use_cuda_graph: true
moe_backend: TRTLLM
speculative_config:
  decoding_type: MTP
  num_nextn_predict_layers: 3
EOF

export TRTLLM_ENABLE_PDL=1

trtllm-bench --model nvidia/DeepSeek-R1-FP4 \
    throughput \
    --dataset $YOUR_DATA_PATH \
    --backend pytorch \
    --num_requests 10 \
    --concurrency 1 \
    --max_batch_size 1 \
    --tp 8 \
    --ep 2 \
    --extra_llm_api_options ./extra-llm-api-config.yml
```

Explanation:

- TRTLLM_ENABLE_PDL=1 enables Programmatic Dependent Launch for supported kernels.
- --concurrency 1 and --max_batch_size 1 keep a single request in flight, which is the min-latency setup.
- --tp 8 --ep 2 runs with tensor-parallel size 8 and expert-parallel size 2 across the 8 GPUs.
- The extra-llm-api-config.yml enables CUDA graphs, selects the TRTLLM MoE backend, and turns on MTP speculative decoding with 3 next-N predict layers.

Expected Results#

Performance can vary with different datasets and machines.

```
===========================================================
= PERFORMANCE OVERVIEW
===========================================================
Request Throughput (req/sec):                     0.1341
Total Output Throughput (tokens/sec):             274.4168
Per User Output Throughput (tokens/sec/user):     274.7188
Per GPU Output Throughput (tokens/sec/gpu):       34.3021
Total Token Throughput (tokens/sec):              414.0461
Total Latency (ms):                               74561.7520
Average request latency (ms):                     7456.1219
```

B200 max-throughput#

Our benchmark results are based on Batch = 3072, ISL = 1K, OSL = 2K, num_requests = 49152, using a synthetic dataset.

Benchmark#

To do the benchmark, run the following command:

```bash
# generate synthetic dataset
python ${YOUR_WORK_PATH}/benchmarks/cpp/prepare_dataset.py \
    --stdout \
    --tokenizer nvidia/DeepSeek-R1-FP4 \
    token-norm-dist \
    --input-mean 1024 --output-mean 2048 \
    --input-stdev 0 --output-stdev 0 \
    --num-requests 49152 > dataset.txt

YOUR_DATA_PATH=./dataset.txt
```
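As an optional sanity check (assuming prepare_dataset.py writes one JSON record per line, following the dataset format linked in the prerequisites), you can confirm the request count and inspect the first record:

```bash
# Expect 49152 lines, one synthetic request per line.
wc -l dataset.txt
# Pretty-print the first record to eyeball the fields.
head -n 1 dataset.txt | python3 -m json.tool
```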

```bash
cat >./extra-llm-api-config.yml <<EOF
use_cuda_graph: true
cuda_graph_padding_enabled: true
cuda_graph_batch_sizes:
```

```bash
trtllm-bench -m nvidia/DeepSeek-R1-FP4 \
    throughput \
    --tp 8 \
    --ep 8 \
    --warmup 0 \
    --dataset ${YOUR_DATA_PATH} \
    --backend pytorch \
    --max_batch_size 384 \
    --max_num_tokens 1536 \
    --num_requests 49152 \
    --concurrency 3072 \
    --kv_cache_free_gpu_mem_fraction 0.85 \
    --extra_llm_api_options ./extra-llm-api-config.yml
```

Expected Result Format#

Performance may vary with different datasets and machines.

```
===========================================================
= PERFORMANCE OVERVIEW
===========================================================
Request Throughput (req/sec):                     17.3885
Total Output Throughput (tokens/sec):             35611.5942
Per User Output Throughput (tokens/sec/user):     11.6701
Per GPU Output Throughput (tokens/sec/gpu):       4451.4493
Total Latency (ms):                               2826700.0758
Average request latency (ms):                     176064.1921
```

H200 min-latency#

Our benchmark results are based on Batch = 1, ISL = 1K, OSL = 2K, num_requests = 10, using a real dataset.

To do the benchmark, run the following command:

```bash
YOUR_DATA_PATH=

cat >./extra-llm-api-config.yml <<EOF
use_cuda_graph: true
speculative_config:
  decoding_type: MTP
  num_nextn_predict_layers: 3
EOF

trtllm-bench --model deepseek-ai/DeepSeek-R1 \
    throughput \
    --dataset $YOUR_DATA_PATH \
    --backend pytorch \
    --num_requests 10 \
    --max_batch_size 1 \
    --tp 8 \
    --ep 4 \
    --concurrency 1 \
    --extra_llm_api_options ./extra-llm-api-config.yml
```

Expected Result Format#

Performance may vary with different datasets and machines.

```
===========================================================
= PERFORMANCE OVERVIEW
===========================================================
Request Throughput (req/sec):                     0.0772
Total Output Throughput (tokens/sec):             158.0669
Per User Output Throughput (tokens/sec/user):     158.1196
Per GPU Output Throughput (tokens/sec/gpu):       19.7584
Total Latency (ms):                               129498.2168
Average request latency (ms):                     12945.9379
```

H200 max-throughput#

Our benchmark results are based on Batch = 1024, ISL = 1K, OSL = 2K, num_requests = 5120, using a synthetic dataset.

To do the benchmark, run the following command:

```bash
# generate synthetic dataset
python ${YOUR_WORK_PATH}/benchmarks/cpp/prepare_dataset.py \
    --stdout \
    --tokenizer deepseek-ai/DeepSeek-R1 \
    token-norm-dist \
    --input-mean 1024 --output-mean 2048 \
    --input-stdev 0 --output-stdev 0 \
    --num-requests 5120 > dataset.txt

YOUR_DATA_PATH=./dataset.txt
```

```bash
cat >./extra-llm-api-config.yml <<EOF
use_cuda_graph: true
cuda_graph_batch_sizes:
```

```bash
# Use NVCC for DeepGEMM JIT compilation
export TRTLLM_DG_JIT_USE_NVCC=1

trtllm-bench -m deepseek-ai/DeepSeek-R1 \
    throughput \
    --tp 8 \
    --ep 8 \
    --warmup 0 \
    --dataset $YOUR_DATA_PATH \
    --backend pytorch \
    --max_batch_size 128 \
    --max_num_tokens 1151 \
    --num_requests 5120 \
    --concurrency 1024 \
    --kv_cache_free_gpu_mem_fraction 0.8 \
    --extra_llm_api_options ./extra-llm-api-config.yml
```

Expected Result Format#

Performance may vary with different datasets and machines.

```
===========================================================
= PERFORMANCE OVERVIEW
===========================================================
Request Throughput (req/sec):                     5.6100
Total Output Throughput (tokens/sec):             11489.2671
Per User Output Throughput (tokens/sec/user):     11.3476
Per GPU Output Throughput (tokens/sec/gpu):       1436.1584
Total Token Throughput (tokens/sec):              17233.9007
Total Latency (ms):                               912656.9938
Average request latency (ms):                     181540.5739
```

Exploring more ISL/OSL combinations#

To benchmark TensorRT-LLM on DeepSeek models with more ISL/OSL combinations, you can use prepare_dataset.py to generate the dataset and reuse commands similar to those in the previous sections. TensorRT-LLM is working on enhancements to make the benchmarking process smoother.
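As an illustrative sketch (the ISL/OSL and request count below are arbitrary examples, not numbers from this doc), generating a 4K-input / 1K-output synthetic dataset reuses the same prepare_dataset.py flags shown earlier:

```bash
# Hypothetical ISL=4K / OSL=1K synthetic dataset with 8192 requests.
python ${YOUR_WORK_PATH}/benchmarks/cpp/prepare_dataset.py \
    --stdout \
    --tokenizer nvidia/DeepSeek-R1-FP4 \
    token-norm-dist \
    --input-mean 4096 --output-mean 1024 \
    --input-stdev 0 --output-stdev 0 \
    --num-requests 8192 > dataset_4k_1k.txt
```

When benchmarking with longer inputs like this, remember to raise max_num_tokens above the maximum input length, per the chunked-context limitation described below.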

WIP: Enable more features by default#

Currently, some features need to be enabled through a user-defined file extra-llm-api-config.yml, such as CUDA graphs, the overlap scheduler, and attention DP. We are working on enabling these features by default so that users get good out-of-the-box performance on DeepSeek models.

Note that max_batch_size and max_num_tokens can easily affect performance. Their default values are carefully chosen and should deliver good performance in most cases; however, you may still need to tune them for peak performance.

Generally, make sure that max_batch_size is not so low that it bottlenecks throughput, and that max_num_tokens is large enough to cover the maximum input sequence length of the samples in the dataset, as mentioned in the section "WIP: Chunked context support on DeepSeek models" below.

For more details on max_batch_size and max_num_tokens, refer to Tuning Max Batch Size and Max Num Tokens.
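If you do need to tune, a simple sweep is often enough. The sketch below reuses the H200 max-throughput command and varies only max_batch_size; the values are illustrative, not recommendations.

```bash
# Illustrative max_batch_size sweep; all other flags match the H200 max-throughput run.
for MBS in 64 128 256; do
    trtllm-bench -m deepseek-ai/DeepSeek-R1 throughput \
        --tp 8 --ep 8 --warmup 0 \
        --dataset $YOUR_DATA_PATH --backend pytorch \
        --max_batch_size ${MBS} --max_num_tokens 1151 \
        --num_requests 5120 --concurrency 1024 \
        --kv_cache_free_gpu_mem_fraction 0.8 \
        --extra_llm_api_options ./extra-llm-api-config.yml \
        | tee bench_mbs_${MBS}.log
done
```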

WIP: Chunked context support on DeepSeek models#

The TensorRT-LLM team is actively working on chunked context support for DeepSeek models. Because this feature is missing, there is currently a limitation that max_num_tokens has to be larger than the maximum input sequence length of the samples in the dataset. For example, the synthetic datasets above use --input-mean 1024 with --input-stdev 0, so the maximum input length is 1024 tokens and the max_num_tokens values of 1536 (B200) and 1151 (H200) used above satisfy this constraint. For more details on max_num_tokens, refer to Tuning Max Batch Size and Max Num Tokens.

Out of memory issues#

It’s possible seeing OOM issues on some cases. Considering reducing kv_cache_free_gpu_mem_fraction to a smaller value as a workaround. We’re working on the investigation and addressing the problem.