slice_scatter decomposition by apbose · Pull Request #2519 · pytorch/TensorRT
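For context on the PR title: `aten.slice_scatter(input, src, dim, start, end, step)` writes `src` into the slice of `input` selected by `(dim, start, end, step)` and returns the result. A minimal sketch of an equivalent decomposition in terms of ops the converters already handle (illustrative only, not necessarily the exact lowering this PR adds):

```python
from typing import Optional

import torch


def slice_scatter_decomposition(
    input: torch.Tensor,
    src: torch.Tensor,
    dim: int = 0,
    start: int = 0,
    end: Optional[int] = None,
    step: int = 1,
) -> torch.Tensor:
    # Positions along `dim` that the slice selects.
    end = input.shape[dim] if end is None else end
    index = torch.arange(start, end, step, device=input.device)
    # Broadcast the index to src's shape so scatter places each src
    # element at its slice position along `dim`.
    view_shape = [1] * input.dim()
    view_shape[dim] = index.numel()
    index = index.view(view_shape).expand_as(src)
    return torch.scatter(input, dim, index, src)


# Sanity check against the reference op:
x, s = torch.zeros(4, 4), torch.ones(2, 4)
assert torch.equal(
    slice_scatter_decomposition(x, s, dim=0, start=1, end=4, step=2),
    torch.slice_scatter(x, s, dim=0, start=1, end=4, step=2),
)
```

The lint diffs below are from the formatting CI run on this PR.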

--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-02-20 19:59:59.374321+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py	2024-02-20 20:01:49.660284+00:00
@@ -1,10 +1,11 @@
 """

Reference

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-02-20 19:59:59.382321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py	2024-02-20 20:01:49.759276+00:00
@@ -30,16 +30,18 @@
         gpu_id (int): Device ID for target GPU
         dla_core (int): Core ID for target DLA core
         allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
     """

     gpu_id: int = -1  #: Device ID for target GPU
     dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

     def __init__(self, *args: Any, **kwargs: Any):
         """__init__ Method for torch_tensorrt.Device

    Device accepts one of a few construction patterns
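For reference, the construction patterns the docstring goes on to describe look like this (a usage sketch based on the class docs; the exact set of accepted strings can vary by version):

```python
import torch_tensorrt

dev_str = torch_tensorrt.Device("gpu:0")                           # from a device string
dev_gpu = torch_tensorrt.Device(gpu_id=0)                          # explicit GPU id
dev_dla = torch_tensorrt.Device("dla:0", allow_gpu_fallback=True)  # DLA core with GPU fallback
```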

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-02-20 19:59:59.382321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py	2024-02-20 20:01:49.959821+00:00
@@ -26,16 +26,16 @@

     class _ShapeMode(Enum):
         STATIC = 0
         DYNAMIC = 1

-    dtype: _enums.dtype = _enums.dtype.unknown  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
+    dtype: _enums.dtype = (
+        _enums.dtype.unknown
+    )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
     _explicit_set_dtype: bool = False
+    format: _enums.TensorFormat = (
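The `_ShapeMode` enum above is what separates static and dynamic inputs. A usage sketch of both modes (keyword names follow the public `Input` docs):

```python
import torch
import torch_tensorrt

# STATIC: one fixed shape.
static_in = torch_tensorrt.Input(shape=(1, 3, 224, 224), dtype=torch.float32)

# DYNAMIC: a min/opt/max shape range for a dynamic-batch engine.
dynamic_in = torch_tensorrt.Input(
    min_shape=(1, 3, 224, 224),
    opt_shape=(8, 3, 224, 224),
    max_shape=(16, 3, 224, 224),
)
```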

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-02-20 19:59:59.382321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py	2024-02-20 20:01:50.013227+00:00
@@ -212,13 +212,13 @@
         "precision": precision,
         "debug": debug,
         "device": device,
         "workspace_size": workspace_size,
         "min_block_size": min_block_size,

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-02-20 19:59:59.382321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py	2024-02-20 20:01:50.235895+00:00
@@ -26,13 +26,13 @@

from packaging import version

 _LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


 class UnsupportedOperatorException(RuntimeError):
     pass

@@ -90,13 +90,13 @@
         self.input_specs_iter = 0
         self._cur_node_name: Optional[str] = None
         self._cur_node: Optional[torch.fx.Node] = None
         self._input_names: List[str] = []
         self._output_names: List[str] = []
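The reformatted `Observer` global above is a simple hook registry: callers subscribe callbacks that fire before the interpreter converts a graph. A minimal self-contained sketch of that pattern (simplified; not the exact `fx.observer` API):

```python
from typing import Callable, Generic, List, TypeVar

import torch.fx

TCallable = TypeVar("TCallable", bound=Callable)


class Observer(Generic[TCallable]):
    """A named registry of callbacks fired at a hook point."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._callbacks: List[TCallable] = []

    def add(self, callback: TCallable) -> None:
        self._callbacks.append(callback)

    def observe(self, *args, **kwargs) -> None:
        # Invoke every subscribed callback with the hook's payload.
        for callback in self._callbacks:
            callback(*args, **kwargs)


TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
)
TRT_INTERPRETER_CALL_PRE_OBSERVER.add(lambda gm: print(f"About to convert:\n{gm.graph}"))
```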

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-02-20 19:59:59.382321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py	2024-02-20 20:01:50.278485+00:00
@@ -322,17 +322,15 @@
     else:
         raise AssertionError(f"Cannot convert {input_val} to TRT constant")

 @overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


 @overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


 def get_positive_dim(
     dim: Union[int, Sequence[int]], dim_size: int
 ) -> Union[int, Tuple[int, ...]]:
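For readers unfamiliar with this helper: `get_positive_dim` canonicalizes negative dimension indices. A sketch of the scalar case (the real implementation also handles sequences, per the overloads above):

```python
def get_positive_dim_sketch(dim: int, dim_size: int) -> int:
    # Negative dims count from the end: -1 on a rank-4 tensor -> 3.
    return dim + dim_size if dim < 0 else dim


assert get_positive_dim_sketch(-1, 4) == 3
assert get_positive_dim_sketch(2, 4) == 2
```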

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-02-20 19:59:59.386321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py	2024-02-20 20:01:50.623768+00:00
@@ -5,13 +5,13 @@
 from torch._decomp import get_decompositions as get_torch_decompositions
 from torch._ops import OpOverload, OpOverloadPacket

aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
 torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
     aten.adaptive_avg_pool2d_backward,
     aten.addcdiv,
     aten.addcdiv_,
     aten.addcmul,
@@ -179,13 +179,13 @@
 torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
     aten._softmax.default,
 }

-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
 TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}
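These tables are what the lowering step feeds to tracing so that composite ATen ops are rewritten into primitives before conversion. A small standalone demonstration of the mechanism, using the public `make_fx` entry point rather than torch_tensorrt's internal plumbing:

```python
import torch
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx


def f(x: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.silu(x)


# Tracing through the decomposition table rewrites aten.silu into its
# primitive form (sigmoid + mul) in the captured graph.
gm = make_fx(f, decomposition_table=core_aten_decompositions())(torch.randn(4))
print(gm.graph)
```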

 def check_decomp_set_invariants() -> None:
     """Validates no overlap between enabled and disabled decomposition sets"""
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-02-20 19:59:59.386321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py	2024-02-20 20:01:50.628829+00:00
@@ -20,16 +20,14 @@
     logger.debug(f"Graph after lowering linear:\n{gm.graph}")

return gm

-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
     """Constructs the original and replacement functions for linear"""

# Original graph
def orig(
    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
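These `(orig, replacement)` pairs are consumed by torch.fx's subgraph rewriter. A self-contained sketch of the same mechanism (the pattern and replacement bodies here are simplified stand-ins, not this pass's exact graphs):

```python
import torch
from torch.fx import subgraph_rewriter, symbolic_trace


# Pattern: a decomposed linear (bias + input @ weight^T via addmm).
def pattern(inp: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
    return torch.addmm(bias, inp, weight.t())


# Replacement: the single fused op a converter can match directly.
def replacement(inp: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
    return torch.nn.functional.linear(inp, weight, bias)


def model(inp, weight, bias):
    return torch.addmm(bias, inp, weight.t())


gm = symbolic_trace(model)
subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(gm.graph)  # the addmm + t() pair is now a single linear call
```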

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-02-20 19:59:59.386321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py	2024-02-20 20:01:50.665412+00:00
@@ -20,16 +20,14 @@
     logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

return gm

-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
     """Constructs the original and replacement functions for view"""

# Original graph
def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
    return torch.ops.aten.view.default(input, shape)
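The motivation for this pass in one example: `view` requires a stride-compatible layout, while `reshape` falls back to a copy, which makes it the safer target op for lowering:

```python
import torch

x = torch.randn(2, 3).t()   # transposing makes the tensor non-contiguous
# x.view(6) would raise a RuntimeError here: no stride arrangement can
# express the flattened layout without copying the data.
y = x.reshape(6)            # reshape silently copies when needed
```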

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py	2024-02-20 19:59:59.386321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py	2024-02-20 20:01:50.681914+00:00
@@ -58,16 +58,14 @@
     logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")

return gm

-def scaled_dot_product_attention_replacement() -> (
-    Tuple[
-        Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def scaled_dot_product_attention_replacement() -> Tuple[
+    Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
     """Constructs the original and replacement functions for efficient attention"""

# Efficient Attention original graph
def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
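The replacement this pass swaps in is the plain attention math, which TensorRT can ingest as matmul/softmax layers. A quick numerical check of that equivalence (no mask and no dropout assumed):

```python
import torch


def attention_reference(q, k, v):
    # softmax(Q @ K^T / sqrt(d)) @ V -- the replacement graph's math.
    scale = q.shape[-1] ** -0.5
    return torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1) @ v


q, k, v = (torch.randn(2, 4, 8) for _ in range(3))
fused = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert torch.allclose(fused, attention_reference(q, k, v), atol=1e-5)
```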

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-02-20 19:59:59.386321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py	2024-02-20 20:01:50.959439+00:00
@@ -99,25 +99,29 @@
                 self.engine.get_binding_dtype(idx), Frameworks.TORCH
             )
             for idx in self.output_binding_indices_in_order
         ]
         self.output_shapes = [

@@ -165,13 +169,15 @@
         self.__dict__.update(state)
         if self.engine:
             self.context = self.engine.create_execution_context()

def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:

@@ -198,13 +204,17 @@
                     torch.cuda.set_device(device_id)

                inputs = tuple([tensor.to(device) for tensor in inputs])
                logger.warning(f"Moved all input Tensors to cuda:{device_id}")

@@ -237,13 +247,17 @@

                self.context.set_binding_shape(
                    idx, tuple(contiguous_inputs[i].shape)
                )

@@ -264,13 +278,17 @@
                     dtype=self.hidden_output_dtypes[i],
                     device=torch.cuda.current_device(),
                 )
                 bindings[idx] = output.data_ptr()
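For orientation, the hunks above all touch the engine-execution path: allocate output tensors, record every tensor's `data_ptr()` in a bindings list, then launch the engine. A heavily simplified sketch of that flow using the pre-TRT-10 bindings API (the `run_engine` name is illustrative, and dtype handling is omitted):

```python
import torch


def run_engine(engine, context, inputs):
    # One slot per I/O binding; each holds a raw device pointer.
    bindings = [0] * engine.num_bindings
    outputs = []
    for idx in range(engine.num_bindings):
        if engine.binding_is_input(idx):
            tensor = inputs[idx].contiguous()
        else:
            shape = tuple(context.get_binding_shape(idx))
            tensor = torch.empty(shape, device="cuda")  # dtype omitted for brevity
            outputs.append(tensor)
        bindings[idx] = tensor.data_ptr()
    # Enqueue the engine on the current CUDA stream.
    context.execute_async_v2(bindings, torch.cuda.current_stream().cuda_stream)
    return outputs
```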

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-02-20 19:59:59.390321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py	2024-02-20 20:01:51.233651+00:00
@@ -315,25 +315,21 @@
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     kwargs_new = {
         "input": args[0],
         "kernel_size": args[1],

@@ -295,14 +299,12 @@
         module.half()
         # A custom conversion function can be passed to the lowerer to
         # handle inputs with custom types. By default, just handle
         # tensors and NoneType.
         if fp16_conversion_fn is None:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-02-20 19:59:59.390321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py	2024-02-20 20:01:51.328023+00:00
@@ -19,13 +19,13 @@
 from .observer import Observer
 from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks

 _LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)

 class TRTInterpreterResult(NamedTuple):
     engine: Any
     input_names: Sequence[str]
@@ -73,13 +73,13 @@
         self.input_specs_iter = 0
         self.validate_input_specs()
         self._cur_node_name: Optional[str] = None
         self._input_names: List[str] = []
         self._output_names: List[str] = []

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-02-20 19:59:59.390321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py	2024-02-20 20:01:51.545029+00:00
@@ -194,13 +194,15 @@
                 lowering_start_time = datetime.datetime.now()

                self.lower_setting.input_specs = generate_input_specs(
                    submod_inputs,
                    self.lower_setting,

@@ -234,13 +236,15 @@
             if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                 _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                 lowering_start_time = datetime.datetime.now()

                self.lower_setting.additional_inputs = (

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-02-20 19:59:59.390321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py	2024-02-20 20:01:51.722875+00:00
@@ -193,13 +193,11 @@
         kwargs2 = {"equal_nan": True}
         if rtol:
             kwargs2["rtol"] = rtol
         if atol:
             kwargs2["atol"] = atol
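Those kwargs feed a `torch.allclose`-style comparison: `equal_nan=True` makes NaNs compare equal, and `rtol`/`atol` widen the elementwise tolerance:

```python
import torch

a = torch.tensor([1.0, float("nan")])
b = torch.tensor([1.0 + 1e-6, float("nan")])
# Passes iff |a - b| <= atol + rtol * |b| elementwise, with NaN == NaN allowed.
assert torch.allclose(a, b, rtol=1e-5, atol=1e-5, equal_nan=True)
```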

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-02-20 19:59:59.390321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py	2024-02-20 20:01:51.782883+00:00
@@ -536,13 +536,13 @@
         reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
             maybe_reshape
         )
         if not reshape_batch_size:
             continue

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-02-20 19:59:59.394321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py	2024-02-20 20:01:52.206806+00:00
@@ -21,13 +21,15 @@
         inputs = [torch.randn(1, 10)]
         self.run_test(
             Split(),
             inputs,
             expected_ops={

@@ -68,13 +70,15 @@
         ]
         self.run_test_with_dynamic_shape(
             Split(),
             input_specs,
             expected_ops={

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-02-20 19:59:59.394321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py	2024-02-20 20:01:52.903172+00:00
@@ -152,13 +152,13 @@
         mod.eval()
         if len(expected_ops):
             self.assert_has_op(mod, expected_ops)

        interpreter_result = interpreter.run(

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-02-20 19:59:59.398321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py	2024-02-20 20:01:53.269384+00:00
@@ -67,25 +67,29 @@
                 self.engine.get_binding_dtype(idx), Frameworks.TORCH
             )
             for idx in self.output_binding_indices_in_order
         ]
         self.output_shapes = [

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-02-20 19:59:59.398321+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py	2024-02-20 20:01:53.546949+00:00
@@ -404,13 +404,13 @@
         "inputs": inputs if inputs is not None else [],
         # "input_signature": input_signature,
         "device": device,
         "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
         "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
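These compile-spec keys surface directly as keyword arguments on the TorchScript path. A hedged usage sketch (flag names follow the dict keys in the hunk above; `my_scripted_module` is an assumed `torch.jit.ScriptModule`, and defaults may differ by version):

```python
import torch_tensorrt

trt_mod = torch_tensorrt.compile(
    my_scripted_module,
    ir="ts",
    inputs=[torch_tensorrt.Input((1, 3, 224, 224))],
    disable_tf32=True,     # keep FP32 layers in strict IEEE FP32 (no TF32 rounding)
    sparse_weights=False,  # sparsity for convolution and fully connected layers
)
```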