[Core] Fix bug when selecting tuned RMSNorm kernels (#983) · NVIDIA/TransformerEngine@7669bf3 (original) (raw)

File tree

1 file changed

lines changed

1 file changed

lines changed

Original file line number Diff line number Diff line change
@@ -89,7 +89,7 @@ BwdFunction &get_bwd_launcher(DType wtype, DType itype, DType otype, DType ctype
89 89 if (params.rows % 4 == 0 && is_aligned(params.x) && is_aligned(params.rs) &&
90 90 is_aligned(params.gamma) && is_aligned(params.dz) && is_aligned(params.dx) &&
91 91 is_aligned(params.dgamma) && is_aligned(params.dgamma_part) &&
92 -layer_norm::BWD_TUNED_FUNCS.count(tuned_key) > 0) {
92 + BWD_TUNED_FUNCS.count(tuned_key) > 0) {
93 93 return BWD_TUNED_FUNCS.at(tuned_key);
94 94 }
95 95