Optimization: EPV1 dispatch & combine kernel by TianDi101 · Pull Request #128 · ROCm/mori

@TianDi101

Dispatch optimization result:
token=32, 1.88x
token=64, 1.48x
token=128, 1.25x

Need to check E2E accuracy and stability.

fp8, token=32, previous
image
fp8, token=32, optimized
image

fp8, token=64, previous
image
fp8, token=64, optimized
image

fp8, token=128, previous
image
fp8, token=128, optimized
image

@jhchouuu

It seems that the main improvement was made to the recv phase of the dispatch LL kernel.
Is there data for 64 tokens, and also for CX7 with MI300?

@TianDi101

It seems that the main improvement was made to the recv phase of the dispatch LL kernel. Is there data for 64 tokens, and also for CX7 with MI300?

Yes, the perf gains come from the recv phase, mainly the XGMI part. I expect it to improve CX7+MI300X as well, but I have not measured that.
I just attached the token=64 perf improvement for your reference.

@jhchouuu

@TianDi101

Combine optimization result:
token=32, 1.46x
token=64, 1.41x
token=128, 1.41x

Need to check E2E accuracy and stability.

bf16, token=32, previous
image
bf16, token=32, optimized
image

bf16, token=64, previous
image
bf16, token=64, optimized
image

bf16, token=128, previous
image
bf16, token=128, optimized
image

@TianDi101

Dispatch & Combine Staging buffer copy (accum) optimization results compared to the last post:

Average perf
Token=32 => dispatch 1.125x, combine 1.25x
Token=64 => dispatch 1.10x, combine 1.26x
Token=128 => dispatch 1.05x, combine 1.18x

Best Perf
Token=32 => dispatch 1.24x, combine 1.25x
Token=64 => dispatch 1.25x, combine 1.33x
Token=128 => dispatch 1.08x, combine 1.18x

Need to check E2E accuracy and stability. The optimizations before this one have been tested.
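Since the staging-copy numbers are stated relative to the last post, the total gain over the original kernels is the product of the two rounds. A minimal sketch of that arithmetic, assuming both rounds were measured on the same configuration (FP8 dispatch, BF16 combine) and that the earlier figures are comparable averages; the dictionary names and the `cumulative_speedup` helper are illustrative only and not part of the PR:

```python
# Speedups quoted in this thread, keyed by kernel and token count.
# first_round: the earlier dispatch (fp8) and combine (bf16) posts.
first_round = {
    "dispatch": {32: 1.88, 64: 1.48, 128: 1.25},
    "combine": {32: 1.46, 64: 1.41, 128: 1.41},
}
# staging_round_avg: "Average perf" of the staging buffer copy optimization,
# reported relative to the previous post.
staging_round_avg = {
    "dispatch": {32: 1.125, 64: 1.10, 128: 1.05},
    "combine": {32: 1.25, 64: 1.26, 128: 1.18},
}


def cumulative_speedup(base, incremental):
    """Multiply per-configuration speedups from two successive rounds.

    Assumption: each incremental factor was measured against the result
    of the base round on the same hardware and message size.
    """
    return {
        kernel: {
            tokens: base[kernel][tokens] * incremental[kernel][tokens]
            for tokens in base[kernel]
        }
        for kernel in base
    }


total = cumulative_speedup(first_round, staging_round_avg)
for kernel, per_token in total.items():
    for tokens, factor in per_token.items():
        print(f"{kernel:8s} token={tokens:<4d} cumulative ~{factor:.3f}x")
```

These products are only an estimate; a direct before/after measurement on one machine is the reliable number.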

FP8 Dispatch, 32/64/128 tokens respectively
image
image
image

BF16 Combine, 32/64/128 tokens respectively
image
image
image

@TianDi101

EP16 E2E test OK.

EP32 perf optimization results

FP8 average perf
Token=32 => dispatch 1.31x, combine 1.48x
Token=64 => dispatch 1.15x, combine 1.62x
Token=128 => dispatch 1.03x, combine 1.23x

BF16 average perf
Token=32 => dispatch 1.28x, combine 1.52x
Token=64 => dispatch 1.11x, combine 1.48x
Token=128 => dispatch 1.03x, combine 1.26x

@TianDi101

Optimization result for high-bandwidth case:
Previous: EP16 bf16 token=4096
image
Optimized: EP16 bf16 token=4096
image

@TianDi101

Pull request overview

This PR optimizes the EPV1 (Expert Parallelism V1) dispatch and combine kernels by refactoring the data copy operations into separate kernels and improving the parallelization strategy. The changes include extracting staging buffer copy logic into a dedicated EpDispatchCopyToStaging kernel, introducing a separate EpCombineAll kernel for final combination, and adding low-latency variants for both dispatch and combine operations.

Changes:

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 17 comments.

| File | Description |
| --- | --- |
| src/ops/dispatch_combine/internode_v1.hpp | Added kernel function declarations for new optimization kernels |
| src/ops/dispatch_combine/internode_v1.cpp | Refactored dispatch/combine logic, extracted staging copy, added low-latency variants, includes large commented code blocks |
| src/ops/dispatch_combine/dispatch_combine.cpp | Updated kernel launches to use new separate kernels and added multiprocessor count initialization |
| include/mori/utils/hip_helper.hpp | New utility header for querying GPU multiprocessor count (has critical linking issue) |
| include/mori/ops/dispatch_combine/dispatch_combine.hpp | Added cuCount member variable to store multiprocessor count |
| examples/ops/dispatch_combine/test_dispatch_combine_internode.py | Added sweep benchmark functionality and updated token count handling |

@TianDi101

[max - min for max, min in zip(comb_lat_max_list, comb_lat_min_list)],
label="Combine Max-Min",
)
plt.xticks([i * 16 for i in range(max_tokens // 16)])

The xticks calculation range(max_tokens // 16) may not align with the actual max_token_list values. Consider using range(max_tokens // 16 + 1) or calculating ticks based on the actual max_token_list to ensure proper tick placement on the x-axis.

Suggested change:
- plt.xticks([i * 16 for i in range(max_tokens // 16)])
+ max_x = max(max_token_list) if max_token_list else 0
+ plt.xticks([i * 16 for i in range(max_x // 16 + 1)])
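
The off-by-one described in this comment can be checked without plotting, by computing the tick positions directly. A small sketch, assuming a sweep whose largest token count is a multiple of 16; `tick_positions` and the sample `max_token_list` values are hypothetical and not taken from the test script:

```python
def tick_positions(max_x, step=16, inclusive=True):
    """Return tick positions 0, step, 2*step, ...

    With inclusive=False this mirrors range(max_x // step), which stops one
    step short of max_x when max_x is a multiple of step. With inclusive=True
    it mirrors range(max_x // step + 1), which reaches max_x in that case.
    """
    count = max_x // step + (1 if inclusive else 0)
    return [i * step for i in range(count)]


max_token_list = [32, 64, 128]  # hypothetical sweep values
max_x = max(max_token_list) if max_token_list else 0

original = tick_positions(max_x, inclusive=False)
suggested = tick_positions(max_x, inclusive=True)

print("original last tick: ", original[-1])   # 112, so the point at x=128 has no tick
print("suggested last tick:", suggested[-1])  # 128
```

The resulting list can be passed straight to `plt.xticks(...)`. Deriving the upper bound from `max_token_list` rather than a separate `max_tokens` value keeps the axis consistent with the data actually plotted.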
