support argmax converter by bowang007 · Pull Request #2291 · pytorch/TensorRT
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/argmax.py	2023-09-05 22:31:02.244529+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/argmax.py	2023-09-05 22:33:23.441716+00:00
@@ -23,18 +23,15 @@
     dim: int = 0,
     keep_dim: bool = False,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
         raise RuntimeError(
-            f"argmax received input {input} that is not part "
-            "of the TensorRT region!"
+            f"argmax received input {input} that is not part " "of the TensorRT region!"
         )
 
     if dim < 0:
         dim = len(tuple(input.shape)) + dim
 
     reduce_mask = 1 << dim
     topk_layer = network.add_topk(input, trt.TopKOperation.MAX, 1, reduce_mask)
 
     set_layer_name(topk_layer, target, name)
 
     return topk_layer.get_output(1)
-
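The converter's approach is compact: TensorRT has no dedicated argmax layer, so the PR expresses argmax as a TopK query with k=1 over a bitmask-encoded reduction axis and returns the layer's second output (the indices). Below is a minimal pure-PyTorch sketch of that equivalence, independent of TensorRT; the tensor `x`, the `dim` value, and the asserts are illustrative, not code from this PR.

# Minimal sketch (plain PyTorch, no TensorRT required) of the technique the
# converter uses: argmax == topk with k=1, taking the *indices* output.
import torch

x = torch.randn(3, 4)
dim = 0

# The converter normalizes a negative dim the same way before building the mask.
if dim < 0:
    dim = len(x.shape) + dim

# TensorRT's TopK layer takes its reduction axes as a bitmask: axis d -> bit d.
reduce_mask = 1 << dim

# topk returns (values, indices); the indices tensor corresponds to
# topk_layer.get_output(1) in the converter. Note the reduced dimension is
# kept with size 1, i.e. keepdim=True semantics.
values, indices = torch.topk(x, k=1, dim=dim)
assert torch.equal(indices, torch.argmax(x, dim=dim, keepdim=True))

One detail worth noting for review: because TopK keeps the reduced dimension, the converter as written appears to return a keepdim=True shape regardless of the `keep_dim` argument, and the test below only exercises `dim=0` with `keep_dim=False`.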
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_argmax.py	2023-09-05 22:31:02.264529+00:00
+++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/converters/test_argmax.py	2023-09-05 22:33:26.764451+00:00
@@ -2,33 +2,23 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 
 from harness import DispatchTestCase
 
 
 class TestArgmaxConverter(DispatchTestCase):
-    @parameterized.expand(
-        [
-            ("dim_0_keep_dim_false", (3, 4), 0, False)
-        ]
-    )
+    @parameterized.expand([("dim_0_keep_dim_false", (3, 4), 0, False)])
     def test_argmax(self, _, input_shape, dim, keep_dim):
         class ArgMax(nn.Module):
             def __init__(self):
                 super().__init__()
 
             def forward(self, input):
                 return torch.argmax(input, dim, keep_dim)
 
         input = [torch.randn(*input_shape)]
-        self.run_test(
-            ArgMax(),
-            input,
-            expected_ops={torch.ops.aten.argmax.default}
-        )
+        self.run_test(ArgMax(), input, expected_ops={torch.ops.aten.argmax.default})
 
 
 if __name__ == "__main__":
     run_tests()
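For reference, the `expected_ops` entry names the ATen overload that `torch.argmax` dispatches to under Dynamo lowering. A tiny standalone check of that correspondence (stock PyTorch only; the tensor and arguments are illustrative, not part of the test harness):

# Standalone sketch: torch.argmax(input, dim, keep_dim) corresponds to the
# aten.argmax.default overload listed in expected_ops above.
import torch

x = torch.randn(3, 4)
assert torch.equal(
    torch.ops.aten.argmax.default(x, 0, False),
    torch.argmax(x, dim=0, keepdim=False),
)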