Removing grid lowering by apbose · Pull Request #2686 · pytorch/TensorRT (original) (raw)
def test_grid(self, _, op_name, input_shape, dim_shape, padding_mode, interpolation_mode, align_corners):
class TestModule(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
return grid_sampler_aten_ops[op_name](x, grid, padding_mode, interpolation_mode, align_corners)
inputs = [torch.randn(input_shape, dtype=torch.float32)]
grid_model = TestModule()
self.run_test(grid_model, inputs)
It would be code design choice to either encapsulate it in the lambda function in the parameters, or declare it in the above way. In my opinion both should be good.
You could let me know if you think otherwise.