Optimization: EPV1 dispatch & combine kernel by TianDi101 · Pull Request #128 · ROCm/mori
Dispatch optimization result:
token=32, 1.88x
token=64, 1.48x
token=128, 1.25x
Need to check E2E accuracy and stability.
fp8, token=32, previous
fp8, token=32, optimized
fp8, token=64, previous
fp8, token=64, optimized
fp8, token=128, previous
fp8, token=128, optimized
It seems that the main improvement was made to the recv phase of the dispatch LL kernel.
Is there data for 64 tokens, and also for CX7 with MI300?
Yes, the perf gains come from the recv phase, mainly the XGMI part. I guess it should also improve CX7+MI300X.
I just attached the token=64 perf improvement for your reference.
Combine optimization result:
token=32, 1.46x
token=64, 1.41x
token=128, 1.41x
Need to check E2E accuracy and stability.
bf16, token=32, previous
bf16, token=32, optimized
bf16, token=64, previous
bf16, token=64, optimized
bf16, token=128, previous
bf16, token=128, optimized
Dispatch & Combine Staging buffer copy (accum) optimization results compared to the last post:
Average perf
Token=32 => dispatch 1.125x, combine 1.25x
Token=64 => dispatch 1.10x, combine 1.26x
Token=128 => dispatch 1.05x, combine 1.18x
Best Perf
Token=32 => dispatch 1.24x, combine 1.25x
Token=64 => dispatch 1.25x, combine 1.33x
Token=128 => dispatch 1.08x, combine 1.18x
Need to check E2E accuracy and stability. The optimizations before this one have been tested.
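The figures in this comment are relative to the previous post, so a rough cumulative speedup over the earlier kernels can be estimated by multiplying the two sets of ratios. A minimal Python sketch, assuming both posts used the same benchmark setup and that the earlier numbers are relative to the pre-PR kernels (neither is stated explicitly in the thread); the dictionary and function names are hypothetical:

```python
# Speedups reported in the earlier posts (assumed: relative to the pre-PR kernels).
first_pass = {
    "dispatch": {32: 1.88, 64: 1.48, 128: 1.25},
    "combine": {32: 1.46, 64: 1.41, 128: 1.41},
}
# Average speedups of the staging buffer copy optimization (relative to the last post).
staging_copy = {
    "dispatch": {32: 1.125, 64: 1.10, 128: 1.05},
    "combine": {32: 1.25, 64: 1.26, 128: 1.18},
}

def compound(a, b):
    """Multiply per-token speedups; only meaningful if both share a measurement setup."""
    return {op: {tok: a[op][tok] * b[op][tok] for tok in a[op]} for op in a}

cumulative = compound(first_pass, staging_copy)
for op, per_token in cumulative.items():
    for tok, ratio in per_token.items():
        print(f"{op} token={tok}: ~{ratio:.2f}x (estimate)")
```

These products are estimates only; measured end-to-end numbers should take precedence where they exist.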
FP8 Dispatch, 32/64/128 tokens respectively


BF16 Combine, 32/64/128 tokens respectively


EP16 E2E test OK.
EP32 perf optimization results
FP8 average perf
Token=32 => dispatch 1.31x, combine 1.48x
Token=64 => dispatch 1.15x, combine 1.62x
Token=128 => dispatch 1.03x, combine 1.23x
BF16 average perf
Token=32 => dispatch 1.28x, combine 1.52x
Token=64 => dispatch 1.11x, combine 1.48x
Token=128 => dispatch 1.03x, combine 1.26x
Optimization result for high-bandwidth case:
Previous: EP16 bf16 token=4096
Optimized: EP16 bf16 token=4096
Pull request overview
This PR optimizes the EPV1 (Expert Parallelism V1) dispatch and combine kernels by refactoring the data copy operations into separate kernels and improving the parallelization strategy. The changes include extracting staging buffer copy logic into a dedicated EpDispatchCopyToStaging kernel, introducing a separate EpCombineAll kernel for final combination, and adding low-latency variants for both dispatch and combine operations.
Changes:
- Separated staging buffer copy operations into a dedicated `EpDispatchCopyToStaging` kernel for better parallelism
- Added `EpCombineAll` kernel to separate the final token combination from inter-node combine operations
- Introduced `EpCombineInterNodeV1KernelLowLatency` with new low-latency implementations (`DispatchInterNodeLLRecv`, `CombineInterNodeLL`, `CombineIntraNodeLL`)
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 17 comments.
| File | Description |
|---|---|
| src/ops/dispatch_combine/internode_v1.hpp | Added kernel function declarations for new optimization kernels |
| src/ops/dispatch_combine/internode_v1.cpp | Refactored dispatch/combine logic, extracted staging copy, added low-latency variants, includes large commented code blocks |
| src/ops/dispatch_combine/dispatch_combine.cpp | Updated kernel launches to use new separate kernels and added multiprocessor count initialization |
| include/mori/utils/hip_helper.hpp | New utility header for querying GPU multiprocessor count (has critical linking issue) |
| include/mori/ops/dispatch_combine/dispatch_combine.hpp | Added cuCount member variable to store multiprocessor count |
| examples/ops/dispatch_combine/test_dispatch_combine_internode.py | Added sweep benchmark functionality and updated token count handling |

    [max - min for max, min in zip(comb_lat_max_list, comb_lat_min_list)],
    label="Combine Max-Min",
    )
    plt.xticks([i * 16 for i in range(max_tokens // 16)])

The xticks calculation range(max_tokens // 16) may not align with the actual max_token_list values. Consider using range(max_tokens // 16 + 1) or calculating ticks based on the actual max_token_list to ensure proper tick placement on the x-axis.
Suggested change:

    - plt.xticks([i * 16 for i in range(max_tokens // 16)])
    + max_x = max(max_token_list) if max_token_list else 0
    + plt.xticks([i * 16 for i in range(max_x // 16 + 1)])

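To illustrate the off-by-one that the review comment describes, here is a small sketch that computes both tick lists without plotting. The helper names and the sample `max_token_list` values are hypothetical and not taken from the test script:

```python
def ticks_original(max_tokens, step=16):
    # Mirrors the original expression: range() excludes the upper bound,
    # so the tick at max_tokens itself is never produced.
    return [i * step for i in range(max_tokens // step)]

def ticks_suggested(max_token_list, step=16):
    # Mirrors the suggested change: derive the bound from the plotted
    # values and add 1 so the last tick covers the largest one.
    max_x = max(max_token_list) if max_token_list else 0
    return [i * step for i in range(max_x // step + 1)]

max_token_list = [16, 32, 64, 128]  # hypothetical sweep values
print(ticks_original(128))           # last tick is 112
print(ticks_suggested(max_token_list))  # last tick is 128
```

With a sweep that ends at 128 tokens, the original expression stops at 112, while the suggested version includes 128 and also handles an empty list.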