Update the test checking for cooperative kernels in conditional nodes. by galv · Pull Request #9869 · NVIDIA-NeMo/NeMo
Yeah, that's an interesting possibility.
One of the big challenges is that the error returned by torch is not very precise. It's just a RuntimeError corresponding to "invalid argument" (cudaErrorInvalidValue), which doesn't tell us that the problem specifically is a cooperative kernel being used within a conditional node's body graph. And unfortunately we cannot check whether this is the case, because the conditional node API does not currently expose a way to get the body graph(s) of a conditional node...
Anyway, I suppose if the error were caused by something other than a cooperative kernel, there is a good chance it would get thrown again by the partial graphs implementation. But it's still not a guarantee!
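
To make the trade-off concrete, here is a rough sketch of the fallback logic being discussed. It is not the code in this PR: `capture_with_fallback` and its two callable arguments are hypothetical names, and it assumes the RuntimeError message contains the text "invalid argument", which may vary between torch/CUDA versions.

```python
def capture_with_fallback(capture_full_graph, capture_partial_graphs):
    """Try the conditional-node capture first, then fall back to partial graphs.

    Both arguments are hypothetical zero-argument callables standing in for
    the two capture strategies. Returns a (strategy_name, result) tuple.
    """
    try:
        return "full", capture_full_graph()
    except RuntimeError as err:
        # Assumption: cudaErrorInvalidValue surfaces as a RuntimeError whose
        # message mentions "invalid argument". We cannot tell from this alone
        # whether a cooperative kernel in the body graph was the cause.
        if "invalid argument" not in str(err):
            raise
        # If the real cause was something else, the partial graphs path will
        # likely raise as well, but that is not guaranteed.
        return "partial", capture_partial_graphs()
```

Unrelated RuntimeErrors are re-raised untouched, so only the ambiguous "invalid argument" case is routed to the fallback.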