feat: support aten._cdist_forward converter by chohk88 · Pull Request #2726 · pytorch/TensorRT

Thank you for your valuable comment. In this update, I've implemented a matrix-multiplication-based approach to compute the distance for p=2. In tests on my local PC with inputs x1=(150, 100, 50, 50) and x2=(150, 100, 30, 50), the TRT run time improved by roughly 6.3x over the previous element-wise implementation (0.00265 s vs. 0.01683 s). Here are the relevant logs from those tests:
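
For reference, here is a minimal PyTorch sketch of the identity the matmul-based path relies on. This is illustrative only, not the converter code itself, and cdist_p2_matmul is a hypothetical helper name: for p=2, the squared pairwise distances expand as ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a.b, so the computation reduces to one batched matmul plus two row-wise reductions instead of materializing the full difference tensor.

import torch

def cdist_p2_matmul(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * <a, b>, computed with a single batched matmul
    x1_sq = x1.pow(2).sum(dim=-1, keepdim=True)   # (..., M, 1)
    x2_sq = x2.pow(2).sum(dim=-1, keepdim=True)   # (..., N, 1)
    sq_dist = x1_sq + x2_sq.transpose(-2, -1) - 2.0 * torch.matmul(x1, x2.transpose(-2, -1))
    # clamp guards against small negative values introduced by floating-point rounding
    return sq_dist.clamp_min(0.0).sqrt()

# Sanity check against torch.cdist using the shapes from the benchmark below
x1 = torch.randn(150, 100, 50, 50)
x2 = torch.randn(150, 100, 30, 50)
print((cdist_p2_matmul(x1, x2) - torch.cdist(x1, x2, p=2)).abs().max())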

INFO:harness:FX graph= graph():
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %_cdist_forward_default : [num_users=1] = call_function[target=torch.ops.aten._cdist_forward.default](args = (%x1, %x2, 2, 0), kwargs = {})
    return _cdist_forward_default
===============based on matmul : ==================
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.007999
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:255: DeprecationWarning: Use build_serialized_network instead.
  engine = self.builder.build_engine(self.ctx.net, builder_config)
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:33.710717
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 244801024 bytes of Memory
INFO:harness:Interpreter run time(s): 33.719337898997765
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:63: DeprecationWarning: Use get_tensor_name instead.
  self.engine.get_binding_index(name) for name in self.input_names
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:66: DeprecationWarning: Use get_tensor_name instead.
  self.engine.get_binding_index(name) for name in self.output_names
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:88: DeprecationWarning: Use get_tensor_dtype instead.
  self.engine.get_binding_dtype(idx), Frameworks.TORCH
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:93: DeprecationWarning: Use get_tensor_shape instead.
  tuple(self.engine.get_binding_shape(idx))
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:98: DeprecationWarning: Use get_tensor_dtype instead.
  self.engine.get_binding_dtype(idx), Frameworks.TORCH
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:247: DeprecationWarning: Use set_input_shape instead.
  self.context.set_binding_shape(
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:262: DeprecationWarning: Use get_tensor_shape instead.
  shape = tuple(self.context.get_binding_shape(idx))
INFO:harness:TRT run time(s)= 0.0026535038948059084
.INFO:harness:FX graph= graph():
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %_cdist_forward_default : [num_users=1] = call_function[target=torch.ops.aten._cdist_forward.default](args = (%x1, %x2, 2, 1), kwargs = {})
    return _cdist_forward_default
===============based on elementwise pow for diff : ==================
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.003120
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py:255: DeprecationWarning: Use build_serialized_network instead.
  engine = self.builder.build_engine(self.ctx.net, builder_config)
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:43.902498
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 4590000640 bytes of Memory
INFO:harness:Interpreter run time(s): 43.90606521600421
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:63: DeprecationWarning: Use get_tensor_name instead.
  self.engine.get_binding_index(name) for name in self.input_names
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:66: DeprecationWarning: Use get_tensor_name instead.
  self.engine.get_binding_index(name) for name in self.output_names
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:88: DeprecationWarning: Use get_tensor_dtype instead.
  self.engine.get_binding_dtype(idx), Frameworks.TORCH
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:93: DeprecationWarning: Use get_tensor_shape instead.
  tuple(self.engine.get_binding_shape(idx))
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:98: DeprecationWarning: Use get_tensor_dtype instead.
  self.engine.get_binding_dtype(idx), Frameworks.TORCH
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:247: DeprecationWarning: Use set_input_shape instead.
  self.context.set_binding_shape(
/home/hoonkyungc/workspace/TorchTRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py:262: DeprecationWarning: Use get_tensor_shape instead.
  shape = tuple(self.context.get_binding_shape(idx))
INFO:harness:TRT run time(s)= 0.016830976486206056
.
----------------------------------------------------------------------