Benchmark script for fp8 vs bf16 gemm by mgoin · Pull Request #17126 · vllm-project/vllm (original) (raw)

Added a new benchmark script to compare BF16 torch gemm vs FP8 gemm in vllm, with and without quantization overhead (no-quant)

python benchmarks/kernels/bench_fp8_gemm.py --models meta-llama/Llama-3.1-8B-Instruct --tp-sizes 1
INFO 05-29 22:29:32 [__init__.py:243] Automatically detected platform cuda.
meta-llama/Llama-3.1-8B-Instruct, N=6144 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
/home/mgoin/venvs/vllm/lib/python3.12/site-packages/triton/testing.py:366: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown
  plt.show()
BF16 vs FP8 GEMMs:
    batch_size  torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0    2.513023               4.342494               4.059088                       5.445039                       5.396452
1         16.0   39.110705              70.214162              66.707845                      87.459173                      88.895575
2         64.0  160.586644             307.791427             287.609822                     396.962910                     405.436237
3        128.0  307.854920             533.529615             506.953838                     693.312524                     680.167694
4        256.0  544.327503             760.587796             742.650816                     956.002107                     968.299157
5        512.0  659.818085             871.316668             841.789937                    1016.919282                    1090.208643
6       1024.0  679.870773            1074.918844            1041.177030                    1245.489842                    1323.486919
7       2048.0  715.146193            1138.206768            1095.479869                    1270.047415                    1300.250343
8       4096.0  688.488074            1115.499749            1098.683793                    1271.604106                    1281.346376
9       8192.0  693.667044            1139.277410            1156.750194                    1297.870193                    1407.382366
10     16384.0  738.364767            1152.441292            1164.380682                    1303.004870                    1379.475530
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
/home/mgoin/venvs/vllm/lib/python3.12/site-packages/triton/testing.py:366: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown
  plt.show()
BF16 vs FP8 GEMMs:
    batch_size  torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0    3.022849               3.738120               3.454632                       5.029108                       5.011478
1         16.0   57.997303              60.181618              56.355188                      81.457602                      83.997588
2         64.0  246.859062             240.507280             224.540611                     327.559158                     336.949841
3        128.0  305.704131             367.782719             345.825568                     465.246361                     471.016790
4        256.0  444.487112             525.910212             498.997757                     644.668348                     658.720996
5        512.0  660.629686             914.257608             890.507412                    1176.315226                    1274.190498
6       1024.0  695.209664             996.459173             918.298155                    1236.218150                    1252.457273
7       2048.0  716.502344            1056.981523             994.623697                    1266.633314                    1329.249553
8       4096.0  723.462180            1115.875903            1060.574284                    1395.043331                    1531.164294
9       8192.0  720.793173            1148.234440            1065.220958                    1388.929933                    1439.672401
10     16384.0  703.359030            1130.539672            1084.622860                    1319.794222                    1458.651195
meta-llama/Llama-3.1-8B-Instruct, N=28672 K=4096, BF16 vs FP8 GEMMs TFLOP/s:
/home/mgoin/venvs/vllm/lib/python3.12/site-packages/triton/testing.py:366: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown
  plt.show()
BF16 vs FP8 GEMMs:
    batch_size  torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0    3.040167               5.305363               5.192819                       5.666133                       5.590927
1         16.0   47.639857              81.203396              79.568668                      86.687591                      85.659827
2         64.0  179.276992             322.470350             314.634598                     346.134781                     341.709870
3        128.0  353.182903             599.326413             584.849064                     637.398127                     634.405193
4        256.0  572.841213             962.250753             997.621909                    1023.170474                    1083.889599
5        512.0  706.133317            1174.717913            1316.892471                    1261.189899                    1407.735561
6       1024.0  703.879969            1241.477126            1273.706221                    1321.068095                    1353.264014
7       2048.0  723.412752            1283.466702            1263.461985                    1281.973444                    1307.706341
8       4096.0  722.899287            1244.501696            1262.635446                    1255.003517                    1314.737949
9       8192.0  758.432702            1150.943252            1271.087402                    1164.841908                    1313.115819
10     16384.0  707.356986            1097.374597            1085.444272                    1088.785555                    1081.725312
meta-llama/Llama-3.1-8B-Instruct, N=4096 K=14336, BF16 vs FP8 GEMMs TFLOP/s:
/home/mgoin/venvs/vllm/lib/python3.12/site-packages/triton/testing.py:366: UserWarning: FigureCanvasAgg is non-interactive, and thus cannot be shown
  plt.show()
BF16 vs FP8 GEMMs:
    batch_size  torch-bf16  fp8-tensor-w-tensor-a  fp8-channel-w-token-a  fp8-tensor-w-tensor-a-noquant  fp8-channel-w-token-a-noquant
0          1.0    2.752944               4.047596               3.732006                       5.064539                       4.994591
1         16.0   43.177335              62.902398              58.386002                      79.964621                      79.562602
2         64.0  165.294967             245.108078             227.268942                     310.961558                     309.570065
3        128.0  292.545743             424.859790             397.218176                     534.932955                     534.235440
4        256.0  514.383299             631.829479             594.578070                     766.593431                     768.526847
5        512.0  725.617627            1040.409996            1022.839331                    1239.861994                    1479.940763
6       1024.0  751.169134            1077.981151            1025.406015                    1281.749855                    1389.644284
7       2048.0  750.779056            1075.997395            1054.874047                    1291.808245                    1354.043586
8       4096.0  758.474556            1092.443969            1066.314752                    1271.869110                    1360.628962
9       8192.0  717.395317            1086.953195            1066.911216                    1309.764548                    1321.988758
10     16384.0  733.439720            1100.848401            1075.886839                    1298.541038                    1325.329581
Benchmark finished!