support argmax converter by bowang007 · Pull Request #2291 · pytorch/TensorRT (original) (raw)

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/argmax.py 2023-09-05 22:31:02.244529+00:00 +++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/argmax.py 2023-09-05 22:33:23.441716+00:00 @@ -23,18 +23,15 @@ dim: int = 0, keep_dim: bool = False, ) -> TRTTensor: if not isinstance(input, TRTTensor): raise RuntimeError(

def test_argmax(self, _, input_shape, dim, keep_dim):
    class ArgMax(nn.Module):
        def __init__(self):
            super().__init__()

if name == "main": - run_tests()