Numpy changes for aten::index converter by apbose · Pull Request #2396 · pytorch/TensorRT (original) (raw)
--- /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py 2023-10-30 19:32:26.833431+00:00 +++ /home/runner/work/TensorRT/TensorRT/tests/py/dynamo/conversion/test_index_aten.py 2023-10-30 19:35:28.665002+00:00 @@ -24,11 +24,11 @@ input = [torch.randn(2, 2)] self.run_test( TestModule(), input, )
- def test_index_zero_two_dim_ITensor(self): class TestModule(nn.Module): def forward(self, x, index0): indices = [None, index0] out = torch.ops.aten.index.Tensor(x, indices)
@@ -56,25 +56,22 @@ input = [torch.randn(2, 2, 2)] self.run_test( TestModule(), input, )
- def test_index_zero_index_three_dim_ITensor(self): class TestModule(nn.Module): def forward(self, x, index0): indices = [None, index0, None] out = torch.ops.aten.index.Tensor(x, indices) return out input = torch.randn(2, 2, 2) index0 = torch.randint(0, 1, (1, 1)) index0 = index0.to(torch.int32)
self.run_test(
TestModule(),
[input, index0]
)
def test_index_zero_index_one_index_two_three_dim(self): class TestModule(nn.Module): def init(self): self.index0 = torch.randint(0, 1, (1, 1))self.run_test(TestModule(), [input, index0])