Nemotron-3-Nano on Jetson Thor vLLM: ITL degrades 4.7x with concurrency, MTP rejected
Hello NVIDIA team,
On Jetson AGX Thor with vLLM 0.20.2, we ran the same benchmark
on four NVFP4 models. NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4 is the
fastest model at single concurrency (65.67 tok/s output @ conc=1,
ITL 14.18 ms) thanks to its MoE 3B-active-params design. But its inter-token latency (ITL) degrades 4.7x from conc=1 to conc=32, while
dense Qwen-8B and dense Mistral-24B keep ITL stable. We also could not
enable MTP speculative decoding on Nano (both 'mtp' and 'nemotron_h_mtp'
rejected with NotImplementedError in vLLM 0.20.2). We suspect the
missing Thor-tuned fused_moe config is the root cause of the ITL degradation. We need help.
1. Hardware
| Item | Value | Source |
|---|---|---|
| Device | NVIDIA Jetson AGX Thor Developer Kit | /proc/device-tree/model |
| GPU architecture | Blackwell, compute capability sm_110 (a.k.a. sm_110a) | NVIDIA spec |
| CUDA toolkit | 13.0 | container env |
| OS | Ubuntu 24.04.4 LTS (Noble Numbat) | /etc/os-release |
| Kernel | Linux 6.8.12-tegra | uname -r |
| Unified memory (RAM + VRAM) | 122 GiB | free -h |
| Model cache filesystem | EXT4 (local NVMe) | vLLM startup log |
2. Software / Docker setup
| Item | Value |
|---|---|
| Docker image | vllm/vllm-openai:v0.20.2 (generic ARM64 build) |
| Orchestration | Docker Compose, runtime: nvidia, CUDA_VISIBLE_DEVICES=0 |
| vLLM | 0.20.2 |
| FlashInfer | bundled with image, JIT autotuner enabled |
| Network | internal Docker bridge only, no host port |
| HF cache volume | baillia_vllm_models:/root/.cache/huggingface |
The same image is used for all four models so the runtime stays comparable.
We are aware NVIDIA AI-IoT publishes a Jetson-Thor-specific container
(ghcr.io/nvidia-ai-iot/vllm:latest-jetson-thor). We deliberately benchmark
generic vLLM 0.20.2 here because we want a like-for-like, reproducible setup, but we
are open to switching (see Q5 below).
3. Per-model configuration tested
Baseline flags shared by every model:
--async-scheduling
--enable-chunked-prefill
--kv-cache-dtype fp8_e4m3
--max-num-batched-tokens 16384
--tensor-parallel-size 1
Differences per model:
E1 — nvidia/Qwen3-8B-NVFP4
--gpu-memory-utilization 0.70
--max-model-len 8192
--max-num-seqs 1024
--enable-prefix-caching
No special env vars. No reasoning parser. No MTP.
vLLM auto-selected:
- quantization=modelopt_fp4
- FlashInferCutlassNvFp4LinearKernel for NVFP4 GEMM
- FLASHINFER attention backend (chosen from ['FLASHINFER', 'TRITON_ATTN'])
- FlashInfer JIT autotuner ran successfully for fp4_gemm
E2 — RedHatAI/Mistral-Small-3.2-24B-Instruct-2506-NVFP4
--gpu-memory-utilization 0.70
--max-model-len 8192
--max-num-seqs 1024
--enable-prefix-caching
Flags identical to E1. We initially tested without any
Mistral-native flags to keep a fair like-for-like comparison.
Side experiment: we later re-ran E2 with the recommended Mistral
flags (--tokenizer-mode mistral + --limit-mm-per-prompt '{"image":0}'
+ --tool-call-parser mistral + --enable-auto-tool-choice). Result:
output throughput dropped 5-42%, but the “aggregate” throughput
(input + output) jumped 200-261%. We believe this is because the Mistral
tokenizer counts input tokens differently (~9-13x more tokens for the
same source text). The headline “Mistral aggregate throughput” numbers
in vLLM blog posts likely include this tokenizer-counting effect, not
a real acceleration.
vLLM auto-selected: same NVFP4 + FLASHINFER stack as E1. Quantization
detected as compressed-tensors (RedHatAI format).
E3 — nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4
--gpu-memory-utilization 0.70
--max-model-len 8192
--max-num-seqs 1024
--enable-prefix-caching
--trust-remote-code
Env vars (from the Nemotron cookbook):
VLLM_USE_FLASHINFER_MOE_FP4=1
VLLM_FLASHINFER_MOE_BACKEND=throughput
We tried to enable MTP speculative decoding (the model card mentions
native MTP layers):
--speculative-config '{"method":"mtp","num_speculative_tokens":3}'
vLLM 0.20.2 fails at startup:
NotImplementedError: Unsupported speculative method: 'mtp'
We also tried 'nemotron_h_mtp' → same error. The string is not in the SpeculativeMethod Literal in vLLM 0.20.2 (only ngram, medusa, mlp_speculator, draft_model, suffix, plus EagleModelTypes are
accepted). So E3 was ultimately benchmarked without MTP.
We also did NOT enable --reasoning-parser nano_v3 + --reasoning-parser-plugin nano_v3_reasoning_parser.py
+ --tool-call-parser qwen3_coder in our benchmark. We later discovered
these are recommended in the vLLM Nemotron-3-Nano-30B-A3B recipe.
We have not yet re-benchmarked with them (see Q3).
Startup logs we can confirm:
- vLLM printed
WARNING fused_moe.py: Using default MoE config. Performance might be sub-optimal! Config file not found at .../E=128,N=768,device_name=NVIDIA_Thor.json
- Using triton Mamba SSU backend (correct)
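For anyone reproducing E3, here is a minimal Python sketch of how the launch line can be assembled from the baseline flags plus the E3-specific flags and env vars listed above. It assumes the server is started with `vllm serve <model>`; the helper names (`build_serve_command`, `render_shell`) are illustrative and not part of vLLM. It only prints the command, it does not start anything.

```python
import shlex

# Baseline flags shared by every model (section 3).
BASELINE_FLAGS = [
    "--async-scheduling",
    "--enable-chunked-prefill",
    "--kv-cache-dtype", "fp8_e4m3",
    "--max-num-batched-tokens", "16384",
    "--tensor-parallel-size", "1",
]

# E3-specific model, flags and env vars, copied from the lists above.
E3_MODEL = "nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4"
E3_FLAGS = [
    "--gpu-memory-utilization", "0.70",
    "--max-model-len", "8192",
    "--max-num-seqs", "1024",
    "--enable-prefix-caching",
    "--trust-remote-code",
]
E3_ENV = {
    "VLLM_USE_FLASHINFER_MOE_FP4": "1",
    "VLLM_FLASHINFER_MOE_BACKEND": "throughput",
}


def build_serve_command(model, extra_flags):
    """Return the argv list: assumed `vllm serve` entrypoint + baseline + per-model flags."""
    return ["vllm", "serve", model, *BASELINE_FLAGS, *extra_flags]


def render_shell(env, argv):
    """Render env vars and argv as a single copy-pasteable shell line."""
    env_part = " ".join(f"{k}={shlex.quote(v)}" for k, v in sorted(env.items()))
    return f"{env_part} {shlex.join(argv)}".strip()


e3_cmd = build_serve_command(E3_MODEL, E3_FLAGS)
print(render_shell(E3_ENV, e3_cmd))
```

Keeping the baseline in one list makes it easy to check that E1 to E3 differ only in the per-model part.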
E4 — nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 (with MTP)
--gpu-memory-utilization 0.80
--max-model-len 8192
--max-num-seqs 32
--max-cudagraph-capture-size 32
--attention-backend TRITON_ATTN
--moe-backend marlin
--trust-remote-code
--mamba-ssm-cache-dtype float32
--speculative-config '{"method":"mtp","num_speculative_tokens":2,"moe_backend":"triton"}'
Env vars:
VLLM_NVFP4_GEMM_BACKEND=marlin
VLLM_USE_FLASHINFER_MOE_FP4=0
VLLM_ALLOW_LONG_MAX_MODEL_LEN=1
Notes:
- 75 GiB model + util=0.80 leaves only 9.76 GiB for KV cache → max
physical concurrency 16.88x, confirmed by vLLM startup log.
- We manually supplied a MoE config file
E=512,N=2688,device_name=NVIDIA_Thor.json
(copied from E=512,N=1344,device_name=NVIDIA_B200.json because Thor TP=1
needs N=2688 but only B200 N=1344 ships upstream). Without this file vLLM
prints the same Using default MoE config warning.
- For Super, the 'mtp' method is accepted by vLLM 0.20.2 (unlike
Nano, where it is rejected; see Q1).
4. Benchmark results
Workload (identical for every model and concurrency):
vllm bench serve \
--dataset-name random \
--random-input-len 200 \
--random-output-len 100 \
--request-rate inf \
--ignore-eos \
--num-prompts $((CONC * 10)) \
--max-concurrency $CONC
All four benchmarks were run on the same machine, same Docker image, same workload. Numbers below are output token throughput (server-side) and ITL/TTFT as reported by vllm bench serve.
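The sweep over concurrency levels can be scripted. The sketch below only builds and prints the command for each level used in section 4.1; the function name is illustrative, and any connection or model-selection flags are left out here exactly as in the command above.

```python
import shlex

# Concurrency levels reported in section 4.1.
CONCURRENCY_LEVELS = [1, 4, 8, 16, 32]


def build_bench_command(conc):
    """Return the `vllm bench serve` argv for one concurrency level."""
    return [
        "vllm", "bench", "serve",
        "--dataset-name", "random",
        "--random-input-len", "200",
        "--random-output-len", "100",
        "--request-rate", "inf",
        "--ignore-eos",
        "--num-prompts", str(conc * 10),   # 10 prompts per concurrent slot
        "--max-concurrency", str(conc),
    ]


for conc in CONCURRENCY_LEVELS:
    # Print only; to execute, pass the list to subprocess.run(..., check=True).
    print(shlex.join(build_bench_command(conc)))
```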
4.1 Full comparison (E1 / E2 / E3 / E4)
Conc | Metric (unit) | E1 | E2 | E3 | E4 | Best
-----|-----------------------|--------|--------|--------|---------|--------
1 | Output (tok/s) | 41.12 | 17.20 | 19.04 | 13.11 | E1 ⭐
1 | TPOT mean (ms) | 23.25 | 57.82 | 14.32 | 101.26 | E3 ⭐
1 | TTFT median (ms) | 43.79 | 85.83 | 100.87 | 423.18 | E1 ⭐
-----|-----------------------|--------|--------|--------|---------|--------
4 | Output (tok/s) | 162.72 | 66.82 | 147.68 | 25.48 | E1 ⭐
4 | TPOT mean (ms) | 23.89 | 58.21 | 23.86 | 164.84 | E1/E3 ⭐
4 | TTFT median (ms) | 99.82 | 246.39 | 276.71 | 673.36 | E1 ⭐
-----|-----------------------|--------|--------|--------|---------|--------
8 | Output (tok/s) | 229.87 | 130.16 | 224.41 | 42.20 | E1 ⭐
8 | TPOT mean (ms) | 28.22 | 59.28 | 32.53 | 201.31 | E1 ⭐
8 | TTFT median (ms) | 156.49 | 202.59 | 332.44 | 830.94 | E1 ⭐
-----|-----------------------|--------|--------|--------|---------|--------
16 | Output (tok/s) | 599.14 | 247.53 | 311.65 | 62.30 | E1 ⭐
16 | TPOT mean (ms) | 25.14 | 60.47 | 47.08 | 273.00 | E1 ⭐
16 | TTFT median (ms) | 156.90 | 446.38 | 442.25 | 989.43 | E1 ⭐
-----|-----------------------|--------|--------|--------|---------|--------
32 | Output (tok/s) |1065.22 | 457.12 | 429.57 | 65.13 | E1 ⭐
32 | TPOT mean (ms) | 28.45 | 63.91 | 65.92 | 273.43 | E1 ⭐
32 | TTFT median (ms) | 159.14 | 371.39 | 741.70 |25228.36 | E1 ⭐
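One way to read the last two blocks of the table is to look at how much output throughput each model gains when concurrency doubles from 16 to 32 (ideal scaling would be 2.0x). The short sketch below does that arithmetic on the table values; the function name is illustrative.

```python
# Output throughput (tok/s) at conc=16 and conc=32, copied from the table above.
OUTPUT_TOK_S = {
    "E1": {16: 599.14, 32: 1065.22},
    "E2": {16: 247.53, 32: 457.12},
    "E3": {16: 311.65, 32: 429.57},
    "E4": {16: 62.30, 32: 65.13},
}


def doubling_gain(series, low=16, high=32):
    """Throughput at `high` concurrency divided by throughput at `low`."""
    return series[high] / series[low]


for name, series in OUTPUT_TOK_S.items():
    print(f"{name}: x{doubling_gain(series):.2f} from conc=16 to conc=32")
```

On these numbers the two dense models stay closest to 2.0x, E3 gains noticeably less, and E4 is essentially flat.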
4.2 Focus E2 vs E3
Conc | Metric (unit) | E2 | E3 | Best
-----|-----------------------|--------|--------|--------------------
1 | Output (tok/s) | 17.20 | 19.04 | E3 ⭐ (+11%)
1 | TPOT mean (ms) | 57.82 | 14.32 | E3 ⭐ (-75%)
1 | TTFT median (ms) | 85.83 | 100.87 | E2 ⭐ (-15%)
-----|-----------------------|--------|--------|--------------------
4 | Output (tok/s) | 66.82 | 147.68 | E3 ⭐ (+121%)
4 | TPOT mean (ms) | 58.21 | 23.86 | E3 ⭐ (-59%)
4 | TTFT median (ms) | 246.39 | 276.71 | E2 ⭐ (-11%)
-----|-----------------------|--------|--------|--------------------
8 | Output (tok/s) | 130.16 | 224.41 | E3 ⭐ (+72%)
8 | TPOT mean (ms) | 59.28 | 32.53 | E3 ⭐ (-45%)
8 | TTFT median (ms) | 202.59 | 332.44 | E2 ⭐ (-39%)
-----|-----------------------|--------|--------|--------------------
16 | Output (tok/s) | 247.53 | 311.65 | E3 ⭐ (+26%)
16 | TPOT mean (ms) | 60.47 | 47.08 | E3 ⭐ (-22%)
16 | TTFT median (ms) | 446.38 | 442.25 | E3 ⭐ (-1%)
-----|-----------------------|--------|--------|--------------------
32 | Output (tok/s) | 457.12 | 429.57 | E2 ⭐ (+6%)
32 | TPOT mean (ms) | 63.91 | 65.92 | E2 ⭐ (-3%)
32 | TTFT median (ms) | 371.39 | 741.70 | E2 ⭐ (-50%)
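The percentages in the "Best" column are the winner's value relative to the other model. A minimal sketch of that calculation, using the conc=32 rows as input (the function name is ours):

```python
def percent_vs_other(winner, other):
    """Signed percent difference of the winner's value relative to the other model."""
    return (winner - other) / other * 100.0


# conc=32 rows from the table above: E2 wins all three metrics.
print(f"Output: {percent_vs_other(457.12, 429.57):+.0f}%")
print(f"TPOT:   {percent_vs_other(63.91, 65.92):+.0f}%")
print(f"TTFT:   {percent_vs_other(371.39, 741.70):+.0f}%")
```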
4.3 VRAM measurements
Model                  | Model VRAM (GiB) | KV cache (GiB) | Max conc | Notes
-----------------------|------------------|----------------|----------|-------------------------
E1 Qwen3-8B-NVFP4      | 5.98             | 72.74          | 129.32x  | dense, util=0.70
E2 Mistral-24B-NVFP4   | 15.05            | 31.69          | 50.70x   | dense, util=0.70
E3 Nemotron-Nano-30B   | 18.65            | 54.52          | 451.10x  | MoE 3B active, util=0.70
E4 Nemotron-Super-120B | 75.03            | 9.76           | 16.88x   | MoE+MTP num=2, util=0.80
All “Model VRAM”, “KV cache” and “Max conc” values are taken verbatim from
vLLM startup logs (Model loading took X GiB memory and Available KV cache memory: Y GiB / Maximum concurrency for 8192 tokens per request: Zx).
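As a sanity check on these log values, the sketch below derives the KV-cache footprint per token that they imply. It assumes "Maximum concurrency" is simply the KV token capacity divided by the 8192-token max-model-len. For the hybrid Mamba models (E3, E4) the result is at best an effective average, since part of their state does not grow with sequence length.

```python
GIB = 1024 ** 3
MAX_MODEL_LEN = 8192  # --max-model-len used for every model

# (available KV cache in GiB, reported max concurrency) from the table above.
STARTUP_LOG = {
    "E1": (72.74, 129.32),
    "E2": (31.69, 50.70),
    "E3": (54.52, 451.10),
    "E4": (9.76, 16.88),
}


def implied_token_capacity(max_conc, max_model_len=MAX_MODEL_LEN):
    """Tokens the KV cache can hold, assuming max_conc = capacity / max_model_len."""
    return max_conc * max_model_len


def implied_kib_per_token(kv_gib, max_conc):
    """Average KV-cache footprint per token in KiB implied by the two log values."""
    return kv_gib * GIB / implied_token_capacity(max_conc) / 1024


for name, (kv_gib, max_conc) in STARTUP_LOG.items():
    tokens = implied_token_capacity(max_conc)
    kib = implied_kib_per_token(kv_gib, max_conc)
    print(f"{name}: ~{tokens:,.0f} tokens, ~{kib:.1f} KiB per token")
```

For E1 this comes out at roughly 72 KiB per token.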
4.4 Legend
- Concurrency: number of parallel client requests sent in a burst to the
server (--request-rate inf). Each run uses --num-prompts = 10 * conc
prompts.
- Output (tok/s): generated token throughput across all concurrent
requests, measured server-side by vllm bench serve.
- TPOT mean (ms): Time Per Output Token, mean of per-request ITL
(Inter-Token Latency). Lower is better. This is the field reported by vllm bench serve as Mean ITL.
- TTFT median (ms): Time To First Token, median across the
concurrency window. Lower is better. Reported as Median TTFT by the
bench.
- Best: lowest TPOT/TTFT or highest Output throughput in the row.
“⭐” marks the winner; percentage in parentheses is relative to the
other model when relevant.
- Model VRAM: GiB taken by model weights only, from vLLM
Model loading took X GiB memory and Y seconds.
- KV cache: GiB available for the KV cache after model load and
CUDA-graph capture, from vLLM Available KV cache memory: X GiB.
- Max conc:
Maximum concurrency for 8192 tokens per request: Xx
from vLLM startup. This is the physical ceiling imposed by the KV
cache size; beyond this, requests get queued.
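To make the latency metrics concrete, here is a small sketch of how TTFT, ITL and TPOT can be derived from per-token arrival timestamps for one request. These are the usual definitions, and we assume they match what the bench tool computes. The timestamps are made up for illustration.

```python
def request_metrics(start, token_times):
    """Compute TTFT, the list of inter-token latencies, and TPOT for one request.

    start       -- time the request was sent (seconds)
    token_times -- arrival time of each output token (seconds, ascending)
    """
    ttft = token_times[0] - start
    # ITL: gap between consecutive output tokens.
    itl = [later - earlier for earlier, later in zip(token_times, token_times[1:])]
    n_out = len(token_times)
    # TPOT: decode time after the first token, averaged over the remaining tokens.
    tpot = (token_times[-1] - token_times[0]) / (n_out - 1) if n_out > 1 else 0.0
    return {"ttft": ttft, "itl": itl, "tpot": tpot}


# Hypothetical request: first token after 100 ms, then one token every 20 ms.
example = request_metrics(0.0, [0.10, 0.12, 0.14, 0.16])
print(f"TTFT = {example['ttft'] * 1000:.1f} ms")
print(f"TPOT = {example['tpot'] * 1000:.1f} ms")
print(f"ITL samples = {len(example['itl'])}")
```

When every streamed chunk carries exactly one token, the mean of a request's ITL samples equals its TPOT. The two can differ slightly once they are averaged across requests.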
4.5 Observations
- At conc=1, E3 has the lowest ITL of the four models (14.18 ms). This is consistent with the MoE design: only 3B params
are active per token, so single-stream decoding is faster than dense 8B Qwen.
- E3 inter-token latency degrades 4.65x between conc=1 and conc=32:
  - conc=1 → 14.18 ms ITL (lowest of all 4 models at conc=1)
  - conc=4 → 23.86 ms (+68%)
  - conc=8 → 32.61 ms (+130%)
  - conc=16 → 47.22 ms (+233%)
  - conc=32 → 65.90 ms (+365% vs conc=1)

  For comparison over the same conc=1 → conc=32 range:
  - E1 Qwen-8B (dense) : 23.07 → 27.28 ms ITL (+18%)
  - E2 Mistral-24B (dense): 57.67 → 62.97 ms ITL (+9%)
  - E3 Nemotron-Nano (MoE): 14.18 → 65.90 ms ITL (+365%)
  - E4 Super-120B (MoE) : 100.88 → 248.94 ms ITL (+147%, but
    saturated by KV cache limit at conc=16)

  Only E3 shows steep degradation that takes it from “fastest model at
  conc=1” to “tied with Mistral at conc=16” to “slower than Mistral at
  conc=32”.
- E3 crosses over with E2 (Mistral) between conc=16 and conc=32:
  - At conc=16, E3 is still faster on output (311.65 vs 247.53 tok/s)
    and TPOT (47.08 vs 60.47 ms).
  - At conc=32, E2 takes the lead on all three metrics (output +6%,
    TPOT -3%, TTFT median -50%).
- MTP could not be enabled on E3 with vLLM 0.20.2. The same MTP
speculative-config that works on E4 Super ('mtp', num_speculative_tokens=2) fails on E3 Nano with NotImplementedError: Unsupported speculative method: 'mtp'. We also
tried 'nemotron_h_mtp' → same error.
- The fused_moe “Using default MoE config” warning appears at startup
for both MoE models (E3 Nano and E4 Super) on Thor SM_110. vLLM ships
zero device_name=NVIDIA_Thor configs in vllm/model_executor/layers/fused_moe/configs/ (only B200, GB200, H100,
H200, RTX_PRO_6000 configs are present in v0.20.2).
- E4 saturates at conc=16-32 (output 62.30 → 65.13 tok/s = +4.5%
only) because the 9.76 GiB KV cache physically caps concurrency at
16.88x. Median TTFT explodes to 25 seconds at conc=32 due to request
queuing past the physical limit. This is a hardware constraint on
Thor for 75 GiB models, not a software issue.
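The ITL growth figures quoted in these observations can be recomputed from the conc=1 and conc=32 values with a few lines of Python. The function name is illustrative.

```python
# Mean ITL in ms at conc=1 and conc=32, copied from the observations above.
ITL_MS = {
    "E1 Qwen-8B (dense)": (23.07, 27.28),
    "E2 Mistral-24B (dense)": (57.67, 62.97),
    "E3 Nemotron-Nano (MoE)": (14.18, 65.90),
    "E4 Super-120B (MoE)": (100.88, 248.94),
}


def itl_growth(itl_conc1, itl_conc32):
    """Return (ratio, percent increase) of ITL between conc=1 and conc=32."""
    ratio = itl_conc32 / itl_conc1
    return ratio, (ratio - 1.0) * 100.0


for name, (low, high) in ITL_MS.items():
    ratio, pct = itl_growth(low, high)
    print(f"{name}: {low:.2f} -> {high:.2f} ms  (x{ratio:.2f}, {pct:+.0f}%)")
```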
5. Why is E3 ITL degrading 4.65x with concurrency?
This is the part where we want your input. Our two observations:
- vLLM startup prints
WARNING fused_moe.py: Using default MoE config. Performance might be sub-optimal! for E3 because E=128,N=768,device_name=NVIDIA_Thor.json does not exist in vLLM
0.20.2.
- Dense models (E1 Qwen, E2 Mistral) on the same machine, same vLLM
build, same workload, show ITL growth between +9% and +18% across
conc=1->32, while the MoE E3 shows +365%.
The simplest hypothesis is that the fused MoE kernel runs with
non-tuned BLOCK_SIZE_M/N/K, GROUP_SIZE_M, num_warps, num_stages
on Thor SM_110, and this becomes the dominant cost as batched expert
dispatch grows with concurrency. Dense models do not hit this kernel
at all, which would explain why only E3 (and E4 to a lesser extent,
because we worked around it with a hand-copied B200 config) is
affected.
We have not yet run benchmark_moe.py --tune on Thor to generate a
proper config — we wanted to ask first before publishing potentially
sub-optimal numbers (see Q2).
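Before and after any tuning run, it helps to check which config files a given directory actually contains. The sketch below builds the file name following the pattern quoted in the warning (E=...,N=...,device_name=...json) and lists the shapes that are missing. The helper names and the temporary directory are illustrative; point `config_dir` at the real configs directory of your install to use it.

```python
import tempfile
from pathlib import Path

# (num_experts, N) shapes mentioned in this post: E3 Nano and E4 Super at TP=1.
SHAPES = [(128, 768), (512, 2688)]


def moe_config_name(num_experts, n, device_name):
    """File name following the pattern shown in the fused_moe warning."""
    return f"E={num_experts},N={n},device_name={device_name}.json"


def missing_configs(config_dir, shapes, device_name):
    """Return the expected config file names that are absent from config_dir."""
    config_dir = Path(config_dir)
    missing = []
    for num_experts, n in shapes:
        name = moe_config_name(num_experts, n, device_name)
        if not (config_dir / name).exists():
            missing.append(name)
    return missing


# Demo in a throwaway directory: only the E4 file is present (as in our setup,
# where it was hand-copied), so the E3 file is reported as missing.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / moe_config_name(512, 2688, "NVIDIA_Thor")).write_text("{}")
    demo_missing = missing_configs(tmp, SHAPES, "NVIDIA_Thor")
    print(demo_missing)
```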
6. Why MTP could not be enabled on E3
The vLLM 0.20.2 SpeculativeMethod Literal in vllm/config/speculative.py contains only: "ngram", "medusa", "mlp_speculator", "draft_model", "suffix" plus
EagleModelTypes. Neither "mtp" nor "nemotron_h_mtp" is in this
list. Passing them at startup hits the else branch:
raise NotImplementedError(f"Unsupported speculative method: {method!r}")
The same 'mtp' method is accepted for E4 Super (no error at startup,
acceptance rate 7-20% reported by the bench). Internally, when the
config is loaded, vLLM rewrites hf_config.model_type if num_nextn_predict_layers > 0; this code path apparently exists for
Super but not for Nano in v0.20.2.
We do not have a clear answer on whether this is intentional
(Nano MTP not yet supported by vLLM) or a config-side issue we
could fix from outside.
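One thing that can be checked from outside is what each checkpoint's config.json declares for the fields involved in that rewrite. The sketch below reads a config.json and reports `model_type`, `architectures` and `num_nextn_predict_layers`. The demo file and its values are placeholders, not the real Nano or Super configs; run it against the files in your HF cache to compare the two models.

```python
import json
import tempfile
from pathlib import Path


def mtp_fields(config_path):
    """Return the config.json fields relevant to MTP detection (None if absent)."""
    cfg = json.loads(Path(config_path).read_text())
    return {
        "model_type": cfg.get("model_type"),
        "architectures": cfg.get("architectures"),
        "num_nextn_predict_layers": cfg.get("num_nextn_predict_layers"),
    }


# Demo with a stand-in config; every value here is a placeholder.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "config.json"
    demo_path.write_text(json.dumps({
        "model_type": "example_model",
        "architectures": ["ExampleForCausalLM"],
        "num_nextn_predict_layers": 1,
    }))
    demo_fields = mtp_fields(demo_path)
    print(demo_fields)
```

If the field is absent or 0 in the Nano config while present in the Super config, that would point to a config-side difference and not only a vLLM-side one.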
7. Sources we already consulted
- vLLM recipe for Nemotron-3-Nano-30B-A3B:
NVIDIA Nemotron-3-Nano-30B-A3B User Guide - vLLM Recipes
(mentions --reasoning-parser nano_v3, --reasoning-parser-plugin nano_v3_reasoning_parser.py, --tool-call-parser qwen3_coder)
- NVIDIA Nemotron cookbook (vLLM, Nano & Super):
GitHub - NVIDIA-NeMo/Nemotron: Developer Asset Hub for NVIDIA Nemotron
- NVIDIA Nemotron-3-Super DGX Spark deployment guide:
Nemotron 3 Super — DGX Spark Deployment Guide — Nemotron
- vLLM blog “Nemotron 3 Super”: Run Highly Efficient and Accurate Multi-Agent AI with NVIDIA Nemotron 3 Super Using vLLM | vLLM Blog
- vLLM blog “Nemotron 3 Nano”: Run Highly Efficient and Accurate AI Agents with NVIDIA Nemotron 3 Nano on vLLM | vLLM Blog
- HF discussion #9 on MTP OOM:
nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-NVFP4 · vLLM MTP unusable on RTX 6000 Pro, as spec decoding consumes 20GB+ VRAM at start-up, causing OOM
- HF discussion on Nano tool calling + reasoning parsing broken:
nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 · Tool calling with reasoning parsing broken
- NVIDIA Developer Forum:
CUDA illegal memory access with MTP speculative decoding on Nemotron-3-Super-120B-NVFP4 (vLLM cu130-nightly, single DGX Spark GB10)
- NVIDIA Developer Forum:
Jetson AGX Thor + vLLM (26.02): MoE performance significantly below reference — missing fused MoE config?
- NVIDIA AI-IoT vLLM Thor container:
Package vllm · GitHub
- vLLM source vllm/model_executor/layers/fused_moe/configs/: only
B200 / GB200 / H100 / H200 / RTX_PRO_6000 configs ship in v0.20.2;
no device_name=NVIDIA_Thor config at all.
- vLLM source SpeculativeMethod Literal: only ngram, medusa, mlp_speculator, draft_model, suffix + EagleModelTypes are
accepted; no mtp / nemotron_h_mtp strings.
8. Our questions for you
Each question lists what we already found publicly so you can focus on
what is still ambiguous.
Q1. MTP on Nemotron-3-Nano with vLLM 0.20.2
What we found: the Nemotron-3-Nano vLLM recipe
does NOT include any --speculative-config flag in its example command,
while the Nemotron-3-Super recipe
does. The Nano model card mentions native MTP layers, but the deployment
guide does not show how to use them.
Is MTP officially supported for Nano-30B-A3B with vLLM 0.20.2? If
yes, what is the exact --speculative-config JSON? (We tried 'mtp'
and 'nemotron_h_mtp'; both were rejected.) If no, is it scheduled for a
specific vLLM version, or only available in a separate branch / fork?
Q2. fused_moe tuning for Thor SM_110
What we found: benchmark_moe.py --tune --save-dir <path> is the
standard way to generate the JSON for a new device. NVIDIA AI-IoT does
not publish pre-tuned configs for Thor. For E4 Super we hand-copied E=512,N=1344,NVIDIA_B200.json → E=512,N=2688,NVIDIA_Thor.json
(workaround, not a real tune).
Is there an officially-blessed tuning command for Thor, including
batch sizes to sweep, expected duration, and confirmation that the
generated configs are picked up by vLLM at runtime? Specifically for
E3 Nano (E=128, N=768) and E4 Super (E=512, N=2688). And will
shipping the Thor configs in a future vLLM release be part of the
Jetson Thor support roadmap?
Q3. Reasoning parser and tool-call parser for Nano
What we found: the Nemotron-3-Nano vLLM recipe
recommends --reasoning-parser nano_v3 --reasoning-parser-plugin nano_v3_reasoning_parser.py --tool-call-parser qwen3_coder. We did NOT
enable these in our benchmark.
Will adding these flags change the benchmark numbers? We have two
specific concerns:
(a) Does the reasoning parser cause the model to emit fewer “thinking”
tokens, reducing throughput-counted output?
(b) Does it add latency on the critical path (we already see 4.65x ITL
growth — would it get worse)?
We would re-run our benchmark with these flags if you confirm they are
the recommended setup for production inference and not just for
chat/tool-use scenarios.
Q4. Backend choice for Nano-30B-A3B on Thor
What we found: two conflicting recipes.
- Nano cookbook env vars:
VLLM_USE_FLASHINFER_MOE_FP4=1 + VLLM_FLASHINFER_MOE_BACKEND=throughput
- Super DGX Spark guide env vars:
VLLM_NVFP4_GEMM_BACKEND=marlin + VLLM_USE_FLASHINFER_MOE_FP4=0, with CLI --moe-backend marlin --attention-backend TRITON_ATTN
We used the Nano cookbook env vars on E3, and the Super-style flags on
E4. For E4 we saw Using 'MARLIN' NvFp4 MoE backend out of potential backends: ['FLASHINFER_TRTLLM', 'FLASHINFER_CUTEDSL', 'FLASHINFER_CUTEDSL_BATCHED', 'FLASHINFER_CUTLASS', 'VLLM_CUTLASS', 'MARLIN', 'EMULATION'] and the warning Your GPU does not have native support for FP4 computation — confusing on Thor since Blackwell does
have native NVFP4 hardware (~2070 TFLOPS NVFP4 advertised).
For Nano on Thor specifically, which combination do you recommend?
And is the Marlin “no native FP4” warning a real performance penalty on
Thor, or just a hard-coded check that misses SM_110?
Q5. Generic vllm/vllm-openai:v0.20.2 vs ghcr.io/nvidia-ai-iot/vllm:latest-jetson-thor
What we found: NVIDIA AI-IoT publishes a Jetson-Thor-specific container
(ghcr.io/nvidia-ai-iot/vllm:r38.2.arm64-sbsa-cu130-24.04, also tagged latest-jetson-thor). The
NGC catalog page
indicates it includes Thor-optimized builds.
Does the AI-IoT container include:
(a) Pre-tuned fused_moe configs for device_name=NVIDIA_Thor?
(b) MTP support for Nemotron-3-Nano (i.e., 'mtp' / 'nemotron_h_mtp' accepted in --speculative-config)?
(c) The Nano reasoning-parser plugin pre-installed?
If yes to any of these, we would switch immediately. If no, is the
container essentially just an ARM build of the same upstream vLLM that
we are already running?
Thank you very much.