Benchmark script for fp8 vs bf16 gemm by mgoin · Pull Request #17126 · vllm-project/vllm (original) (raw)
Added a new benchmark script to compare BF16 torch gemm vs FP8 gemm in vllm, with and without quantization overhead (no-quant)
python benchmarks/kernels/bench_fp8_gemm.py --models meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1
INFO 05-29 22:29:32 [__init__.py:243] Automatically detected platform cuda.
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
/home/mgoin/venvs/vllm/lib/python3.12/site-packages/triton/testing.py:366: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown
plt.show()
BF16 vs FP8 GEMMs:
batch_size torch-bf16 fp8-tensor-w-tensor-a fp8-channel-w-token-a fp8-tensor-w-tensor-a-noquant fp8-channel-w-token-a-noquant
0 1.0 2.513023 4.342494 4.059088 5.445039 5.396452
1 16.0 39.110705 70.214162 66.707845 87.459173 88.895575
2 64.0 160.586644 307.791427 287.609822 396.962910 405.436237
3 128.0 307.854920 533.529615 506.953838 693.312524 680.167694
4 256.0 544.327503 760.587796 742.650816 956.002107 968.299157
5 512.0 659.818085 871.316668 841.789937 1016.919282 1090.208643
6 1024.0 679.870773 1074.918844 1041.177030 1245.489842 1323.486919
7 2048.0 715.146193 1138.206768 1095.479869 1270.047415 1300.250343
8 4096.0 688.488074 1115.499749 1098.683793 1271.604106 1281.346376
9 8192.0 693.667044 1139.277410 1156.750194 1297.870193 1407.382366
10 16384.0 738.364767 1152.441292 1164.380682 1303.004870 1379.475530
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
/home/mgoin/venvs/vllm/lib/python3.12/site-packages/triton/testing.py:366: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown
plt.show()
BF16 vs FP8 GEMMs:
batch_size torch-bf16 fp8-tensor-w-tensor-a fp8-channel-w-token-a fp8-tensor-w-tensor-a-noquant fp8-channel-w-token-a-noquant
0 1.0 3.022849 3.738120 3.454632 5.029108 5.011478
1 16.0 57.997303 60.181618 56.355188 81.457602 83.997588
2 64.0 246.859062 240.507280 224.540611 327.559158 336.949841
3 128.0 305.704131 367.782719 345.825568 465.246361 471.016790
4 256.0 444.487112 525.910212 498.997757 644.668348 658.720996
5 512.0 660.629686 914.257608 890.507412 1176.315226 1274.190498
6 1024.0 695.209664 996.459173 918.298155 1236.218150 1252.457273
7 2048.0 716.502344 1056.981523 994.623697 1266.633314 1329.249553
8 4096.0 723.462180 1115.875903 1060.574284 1395.043331 1531.164294
9 8192.0 720.793173 1148.234440 1065.220958 1388.929933 1439.672401
10 16384.0 703.359030 1130.539672 1084.622860 1319.794222 1458.651195
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
/home/mgoin/venvs/vllm/lib/python3.12/site-packages/triton/testing.py:366: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown
plt.show()
BF16 vs FP8 GEMMs:
batch_size torch-bf16 fp8-tensor-w-tensor-a fp8-channel-w-token-a fp8-tensor-w-tensor-a-noquant fp8-channel-w-token-a-noquant
0 1.0 3.040167 5.305363 5.192819 5.666133 5.590927
1 16.0 47.639857 81.203396 79.568668 86.687591 85.659827
2 64.0 179.276992 322.470350 314.634598 346.134781 341.709870
3 128.0 353.182903 599.326413 584.849064 637.398127 634.405193
4 256.0 572.841213 962.250753 997.621909 1023.170474 1083.889599
5 512.0 706.133317 1174.717913 1316.892471 1261.189899 1407.735561
6 1024.0 703.879969 1241.477126 1273.706221 1321.068095 1353.264014
7 2048.0 723.412752 1283.466702 1263.461985 1281.973444 1307.706341
8 4096.0 722.899287 1244.501696 1262.635446 1255.003517 1314.737949
9 8192.0 758.432702 1150.943252 1271.087402 1164.841908 1313.115819
10 16384.0 707.356986 1097.374597 1085.444272 1088.785555 1081.725312
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs TFLOP/s:
/home/mgoin/venvs/vllm/lib/python3.12/site-packages/triton/testing.py:366: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown
plt.show()
BF16 vs FP8 GEMMs:
batch_size torch-bf16 fp8-tensor-w-tensor-a fp8-channel-w-token-a fp8-tensor-w-tensor-a-noquant fp8-channel-w-token-a-noquant
0 1.0 2.752944 4.047596 3.732006 5.064539 4.994591
1 16.0 43.177335 62.902398 58.386002 79.964621 79.562602
2 64.0 165.294967 245.108078 227.268942 310.961558 309.570065
3 128.0 292.545743 424.859790 397.218176 534.932955 534.235440
4 256.0 514.383299 631.829479 594.578070 766.593431 768.526847
5 512.0 725.617627 1040.409996 1022.839331 1239.861994 1479.940763
6 1024.0 751.169134 1077.981151 1025.406015 1281.749855 1389.644284
7 2048.0 750.779056 1075.997395 1054.874047 1291.808245 1354.043586
8 4096.0 758.474556 1092.443969 1066.314752 1271.869110 1360.628962
9 8192.0 717.395317 1086.953195 1066.911216 1309.764548 1321.988758
10 16384.0 733.439720 1100.848401 1075.886839 1298.541038 1325.329581
Benchmark finished!