[Core] Fix bug when selecting tuned RMSNorm kernels (#983) · NVIDIA/TransformerEngine@7669bf3 (original) (raw)
File tree
1 file changed
lines changed
- transformer_engine/common/rmsnorm
1 file changed
lines changed
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -89,7 +89,7 @@ BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype | ||
89 | 89 | if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) && |
90 | 90 | is_aligned(params.gamma) && is_aligned(params.dz) && is_aligned(params.dx) && |
91 | 91 | is_aligned(params.dgamma) && is_aligned(params.dgamma_part) && |
92 | -layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) { | |
92 | + BWD_TUNED_FUNCS.count(tuned_key) > 0) { | |
93 | 93 | return BWD_TUNED_FUNCS.at(tuned_key); |
94 | 94 | } |
95 | 95 |