select_scatter decomp by apbose · Pull Request #2515 · pytorch/TensorRT
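The PR title indicates this change adds a decomposition for `aten.select_scatter`. The text below appears to be formatter (black) diff output from the PR's CI lint job, i.e. formatting-only rewrites across the repository rather than the decomposition itself. For context, `select_scatter` is the out-of-place counterpart of writing into a selected slice; a minimal illustration using the public PyTorch API (not code from this PR):

    import torch

    x = torch.zeros(2, 3)
    src = torch.ones(3)

    # Returns a copy of x with src embedded at index 0 along dim 0,
    # i.e. the functional form of `x[0] = src`.
    y = torch.select_scatter(x, src, dim=0, index=0)
    assert torch.equal(y[0], src)
    assert torch.equal(y[1], torch.zeros(3))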
--- /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py 2024-02-16 00:01:27.167252+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/int8/training/vgg16/vgg16.py 2024-02-16 00:03:16.025977+00:00
@@ -1,10 +1,11 @@
 """
 Reference
 - ...
 """
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from functools import reduce
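The hunk above adds a single line: black 24's stable style enforces exactly one blank line after a module docstring. A minimal illustration with a hypothetical module:

    """Module docstring."""

    import torch  # the blank line above the import is now mandatory under black 24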
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py 2024-02-16 00:01:27.175252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Device.py 2024-02-16 00:03:16.122237+00:00
@@ -30,16 +30,18 @@
         gpu_id (int): Device ID for target GPU
         dla_core (int): Core ID for target DLA core
         allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
     """

-    device_type: Optional[
-        trt.DeviceType
-    ] = None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    device_type: Optional[trt.DeviceType] = (
+        None  #: Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+    )
     gpu_id: int = -1  #: Device ID for target GPU
     dla_core: int = -1  #: Core ID for target DLA core
-    allow_gpu_fallback: bool = False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    allow_gpu_fallback: bool = (
+        False  #: Whether falling back to GPU if DLA cannot support an op should be allowed
+    )

     def __init__(self, *args: Any, **kwargs: Any):
         """__init__ Method for torch_tensorrt.Device

         Device accepts one of a few construction patterns
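This file shows the core transformation repeated throughout the diff: when an annotated assignment with a long trailing comment overflows the line limit, black 24 parenthesizes the right-hand side instead of splitting the annotation's subscript, so the comment stays attached to the value. A minimal standalone sketch with a hypothetical field:

    from typing import Optional

    # Before: the `Optional[int]` subscript was split across lines.
    # After: the value is wrapped instead, keeping annotation and comment together.
    gpu_id: Optional[int] = (
        None  # hypothetical doc-comment long enough to push the line over the limit
    )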
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py 2024-02-16 00:01:27.175252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/_Input.py 2024-02-16 00:03:16.328118+00:00
@@ -26,16 +26,16 @@
     class _ShapeMode(Enum):
         STATIC = 0
         DYNAMIC = 1

-    shape_mode: Optional[
-        _ShapeMode
-    ] = None  #: Is input statically or dynamically shaped
-    shape: Optional[
-        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
-    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    shape_mode: Optional[_ShapeMode] = (
+        None  #: Is input statically or dynamically shaped
+    )
+    shape: Optional[Tuple[int, ...] | Dict[str, Tuple[int, ...]]] = (
+        None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    )
     dtype: _enums.dtype = (
         _enums.dtype.unknown
     )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
     _explicit_set_dtype: bool = False
     format: _enums.TensorFormat = (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py 2024-02-16 00:01:27.175252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/_compiler.py 2024-02-16 00:03:16.375728+00:00
@@ -212,13 +212,13 @@
         "precision": precision,
         "debug": debug,
         "device": device,
         "workspace_size": workspace_size,
         "min_block_size": min_block_size,
-        "torch_executed_ops": torch_executed_ops
-        if torch_executed_ops is not None
-        else set(),
+        "torch_executed_ops": (
+            torch_executed_ops if torch_executed_ops is not None else set()
+        ),
         "pass_through_build_failures": pass_through_build_failures,
         "max_aux_streams": max_aux_streams,
         "version_compatible": version_compatible,
         "optimization_level": optimization_level,
         "use_python_runtime": use_python_runtime,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-02-16 00:01:27.175252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-02-16 00:03:16.569143+00:00
@@ -26,13 +26,13 @@
 from packaging import version

 _LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


 class UnsupportedOperatorException(RuntimeError):
     pass

@@ -90,13 +90,13 @@
         self.input_specs_iter = 0
         self._cur_node_name: Optional[str] = None
         self._cur_node: Optional[torch.fx.Node] = None
         self._input_names: List[str] = []
         self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )
         self.compilation_settings = compilation_settings

         # Data types for TRT Module output Tensors
         self.output_dtypes = output_dtypes
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py 2024-02-16 00:03:16.647465+00:00
@@ -322,17 +322,15 @@
     else:
         raise AssertionError(f"Cannot convert {input_val} to TRT constant")


 @overload
-def get_positive_dim(dim: int, dim_size: int) -> int:
-    ...
+def get_positive_dim(dim: int, dim_size: int) -> int: ...


 @overload
-def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]:
-    ...
+def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


 def get_positive_dim(
     dim: Union[int, Sequence[int]], dim_size: int
 ) -> Union[int, Tuple[int, ...]]:
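black 24 also collapses dummy bodies consisting only of `...` onto the `def` line, which is why each `@overload` stub above shrinks by one line. A self-contained sketch of the pattern; the real body of `get_positive_dim` is not shown in this diff, so the modulo implementation below is an assumption for illustration:

    from typing import Sequence, Tuple, Union, overload


    @overload
    def get_positive_dim(dim: int, dim_size: int) -> int: ...


    @overload
    def get_positive_dim(dim: Sequence[int], dim_size: int) -> Tuple[int, ...]: ...


    def get_positive_dim(
        dim: Union[int, Sequence[int]], dim_size: int
    ) -> Union[int, Tuple[int, ...]]:
        # Map negative dims to their positive equivalents, e.g. -1 -> dim_size - 1.
        if isinstance(dim, int):
            return dim % dim_size
        return tuple(d % dim_size for d in dim)


    assert get_positive_dim(-1, 4) == 3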
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py 2024-02-16 00:03:17.010073+00:00
@@ -5,13 +5,13 @@
 from torch._decomp import get_decompositions as get_torch_decompositions
 from torch._ops import OpOverload, OpOverloadPacket

 aten = torch.ops.aten

-_core_aten_decompositions: Dict[
-    OpOverload, Callable[[Any], Any]
-] = core_aten_decompositions()
+_core_aten_decompositions: Dict[OpOverload, Callable[[Any], Any]] = (
+    core_aten_decompositions()
+)
 torch_enabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
     aten.adaptive_avg_pool2d_backward,
     aten.addcdiv,
     aten.addcdiv_,
     aten.addcmul,
@@ -179,13 +179,13 @@
 torch_disabled_decompositions: Set[Union[OpOverload, OpOverloadPacket]] = {
     aten._softmax.default,
 }


-ENABLED_TORCH_DECOMPOSITIONS: Dict[
-    OpOverload, Callable[[Any], Any]
-] = get_torch_decompositions(torch_enabled_decompositions)
+ENABLED_TORCH_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = (
+    get_torch_decompositions(torch_enabled_decompositions)
+)
 TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


 def check_decomp_set_invariants() -> None:
     """Validates no overlap between enabled and disabled decomposition sets"""
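_decomposition_groups.py is the registry a new decomposition would plug into: ops in `torch_enabled_decompositions` are expanded by upstream PyTorch decompositions, while `TORCH_TRT_DECOMPOSITIONS` holds Torch-TensorRT's own. The select_scatter change itself is not visible in this formatting diff; purely as an illustration of what a select_scatter-style decomposition computes (a hypothetical standalone function, not the PR's implementation):

    import torch


    def select_scatter_decomp(
        input: torch.Tensor, src: torch.Tensor, dim: int, index: int
    ) -> torch.Tensor:
        # Build a boolean mask that is True only at `index` along `dim`,
        # then blend `src` (unsqueezed so it broadcasts) into `input`.
        mask = (torch.arange(input.size(dim)) == index).reshape(
            [-1 if d == dim else 1 for d in range(input.dim())]
        )
        return torch.where(mask, src.unsqueeze(dim), input)


    x = torch.zeros(2, 3)
    out = select_scatter_decomp(x, torch.ones(3), dim=0, index=0)
    assert torch.equal(out, torch.select_scatter(x, torch.ones(3), 0, 0))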
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py 2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py 2024-02-16 00:03:17.018926+00:00
@@ -20,16 +20,14 @@
         logger.debug(f"Graph after lowering linear:\n{gm.graph}")

     return gm


-def linear_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def linear_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
     """Constructs the original and replacement functions for linear"""

     # Original graph
     def orig(
         input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
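Return-type annotations get the same treatment: black 24 drops the defensive outer parentheses around a long return annotation and splits inside the subscript instead. A minimal sketch with a hypothetical stub:

    from typing import Callable, Tuple

    import torch.fx


    def linear_replacement_stub() -> Tuple[
        torch.fx.GraphModule,
        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
    ]: ...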
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py 2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py 2024-02-16 00:03:17.052927+00:00
@@ -20,16 +20,14 @@
         logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

     return gm


-def view_replacement() -> (
-    Tuple[
-        torch.fx.GraphModule,
-        Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
-    ]
-):
+def view_replacement() -> Tuple[
+    torch.fx.GraphModule,
+    Callable[[torch.Tensor, List[torch.SymInt]], torch.Tensor],
+]:
     """Constructs the original and replacement functions for view"""

     # Original graph
     def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
         return torch.ops.aten.view.default(input, shape)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py 2024-02-16 00:03:17.057189+00:00
@@ -58,16 +58,14 @@
         logger.debug(f"Graph after lowering scaled dot product attention:\n{gm.graph}")

     return gm


-def scaled_dot_product_attention_replacement() -> (
-    Tuple[
-        Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
-        Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
-    ]
-):
+def scaled_dot_product_attention_replacement() -> Tuple[
+    Sequence[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]],
+    Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor],
+]:
     """Constructs the original and replacement functions for efficient attention"""

     # Efficient Attention original graph
     def efficient(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
         outputs = torch.ops.aten._scaled_dot_product_efficient_attention.default(
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 2024-02-16 00:01:27.179252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py 2024-02-16 00:03:17.277499+00:00
@@ -99,25 +99,29 @@
                 self.engine.get_binding_dtype(idx), Frameworks.TORCH
             )
             for idx in self.output_binding_indices_in_order
         ]
         self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
             for idx in self.output_binding_indices_in_order
         ]
         self.hidden_output_dtypes = [
             unified_dtype_converter(
                 self.engine.get_binding_dtype(idx), Frameworks.TORCH
             )
             for idx in self.hidden_output_binding_indices_in_order
         ]
         self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
             for idx in self.hidden_output_binding_indices_in_order
         ]

     def _check_initialized(self) -> None:
         if not self.initialized:
@@ -165,13 +169,15 @@
         self.__dict__.update(state)
         if self.engine:
             self.context = self.engine.create_execution_context()

     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        with torch.autograd.profiler.record_function(
-            "PythonTorchTensorRTModule:Forward"
-        ) if self.profiling_enabled else nullcontext():
+        with (
+            torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward")
+            if self.profiling_enabled
+            else nullcontext()
+        ):
             self._check_initialized()

             # If in safe mode, check at each iteration for for whether a switch is required
             if (
                 torch_tensorrt.runtime.multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
@@ -198,13 +204,17 @@
                     torch.cuda.set_device(device_id)

                 inputs = tuple([tensor.to(device) for tensor in inputs])
                 logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessInputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 assert len(inputs) == len(
                     self.input_names
                 ), f"Wrong number of inputs, expect {len(self.input_names)} get {len(inputs)}."

                 contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
@@ -237,13 +247,17 @@
                         self.context.set_binding_shape(
                             idx, tuple(contiguous_inputs[i].shape)
                         )

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:ProcessOutputs"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 # create output tensors
                 outputs: List[torch.Tensor] = []

                 for i, idx in enumerate(self.output_binding_indices_in_order):
                     shape = tuple(self.context.get_binding_shape(idx))
@@ -264,13 +278,17 @@
                         dtype=self.hidden_output_dtypes[i],
                         device=torch.cuda.current_device(),
                     )

                     bindings[idx] = output.data_ptr()

-            with torch.autograd.profiler.record_function(
-                "PythonTorchTensorRTModule:TensorRTRuntime"
-            ) if self.profiling_enabled else nullcontext():
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 self.context.execute_async_v2(
                     bindings, torch.cuda.current_stream().cuda_stream
                 )

                 if len(outputs) == 1:
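All four hunks in this file reshape the same construct: a context manager chosen by a ternary. The expression is now parenthesized as a block after `with` instead of dangling `if ... else nullcontext()` off the call's closing parenthesis. A standalone sketch:

    from contextlib import nullcontext

    import torch

    profiling_enabled = False

    with (
        torch.autograd.profiler.record_function("Forward")
        if profiling_enabled
        else nullcontext()
    ):
        pass  # the body runs under the profiler range only when profiling is enabled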
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py 2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/converters/aten_ops_converters.py 2024-02-16 00:03:17.622507+00:00
@@ -315,25 +315,21 @@
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     kwargs_new = {
         "input": args[0],
         "kernel_size": args[1],
-        "stride": args[2]
-        if len(args) > 2
-        else (None, None)
-        if len(args[1]) == 2
-        else (None, None, None),
-        "padding": args[3]
-        if len(args) > 3
-        else (0, 0)
-        if len(args[1]) == 2
-        else (0, 0, 0),
-        "dilation": args[4]
-        if len(args) > 4
-        else (1, 1)
-        if len(args[1]) == 2
-        else (1, 1, 1),
+        "stride": (
+            args[2]
+            if len(args) > 2
+            else (None, None) if len(args[1]) == 2 else (None, None, None)
+        ),
+        "padding": (
+            args[3] if len(args) > 3 else (0, 0) if len(args[1]) == 2 else (0, 0, 0)
+        ),
+        "dilation": (
+            args[4] if len(args) > 4 else (1, 1) if len(args[1]) == 2 else (1, 1, 1)
+        ),
         "ceil_mode": args[5] if len(args) > 5 else False,
     }
     return acc_ops_converters.acc_ops_max_poolnd(
         network, target, None, kwargs_new, name
     )
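When the conditional is itself nested, black 24 parenthesizes the outer expression and keeps the inner `a if cond else b` on one line where it fits. A standalone sketch of the stride-defaulting logic above, with a stub `args` tuple:

    args = ("input", [3, 3])  # stub: (input, kernel_size) with no explicit stride

    stride = (
        args[2]
        if len(args) > 2
        else (None, None) if len(args[1]) == 2 else (None, None, None)
    )
    assert stride == (None, None)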
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py 2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/fx2trt.py 2024-02-16 00:03:17.675873+00:00
@@ -19,13 +19,13 @@
 from .observer import Observer
 from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks

 _LOGGER: logging.Logger = logging.getLogger(__name__)

-TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[
-    Callable[[torch.fx.GraphModule], None]
-] = Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
+    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
+)


 class TRTInterpreterResult(NamedTuple):
     engine: Any
     input_names: Sequence[str]
@@ -73,13 +73,13 @@
         self.input_specs_iter = 0
         self.validate_input_specs()
         self._cur_node_name: Optional[str] = None
         self._input_names: List[str] = []
         self._output_names: List[str] = []
-        self._itensor_to_tensor_meta: Dict[
-            trt.tensorrt.ITensor, TensorMetadata
-        ] = dict()
+        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
+            dict()
+        )

     def validate_input_specs(self):
         for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
             if not self.network.has_implicit_batch_dimension:
                 assert (
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py 2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/lower.py 2024-02-16 00:03:17.684340+00:00
@@ -124,25 +124,29 @@
         interpreter = TRTInterpreter(
             mod,
             input_specs=self.lower_setting.input_specs,
             explicit_batch_dimension=self.lower_setting.explicit_batch_dimension,
             explicit_precision=self.lower_setting.explicit_precision,
-            logger_level=trt.Logger.VERBOSE
-            if self.lower_setting.verbose_log
-            else trt.Logger.WARNING,
+            logger_level=(
+                trt.Logger.VERBOSE
+                if self.lower_setting.verbose_log
+                else trt.Logger.WARNING
+            ),
         )

         interp_result: TRTInterpreterResult = interpreter.run(
             max_batch_size=self.lower_setting.max_batch_size,
             max_workspace_size=self.lower_setting.max_workspace_size,
             lower_precision=self.lower_setting.lower_precision,
             strict_type_constraints=self.lower_setting.strict_type_constraints,
             algorithm_selector=algo_selector,
             timing_cache=cache_data,
-            profiling_verbosity=trt.ProfilingVerbosity.DETAILED
-            if self.lower_setting.verbose_profile
-            else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
+            profiling_verbosity=(
+                trt.ProfilingVerbosity.DETAILED
+                if self.lower_setting.verbose_profile
+                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
+            ),
             tactic_sources=self.lower_setting.tactic_sources,
         )

         # Update timing cache file if needed
         timing_cache = interp_result.serialized_cache
@@ -295,14 +299,12 @@
             module.half()
             # A custom conversion function can be passed to the lowerer to
             # handle inputs with custom types. By default, just handle
             # tensors and NoneType.
             if fp16_conversion_fn is None:
-                conversion_fn = (
-                    lambda x: x.half()
-                    if x is not None and x.dtype == torch.float32
-                    else x
-                )
+                conversion_fn = lambda x: (
+                    x.half() if x is not None and x.dtype == torch.float32 else x
+                )
             else:
                 conversion_fn = fp16_conversion_fn

             inputs = tuple(conversion_fn(x) for x in inputs)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py 2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py 2024-02-16 00:03:17.896029+00:00
@@ -194,13 +194,15 @@
                 lowering_start_time = datetime.datetime.now()

                 self.lower_setting.input_specs = generate_input_specs(
                     submod_inputs,
                     self.lower_setting,
-                    additional_submodule_inputs[submod_name]
-                    if additional_submodule_inputs
-                    else None,
+                    (
+                        additional_submodule_inputs[submod_name]
+                        if additional_submodule_inputs
+                        else None
+                    ),
                 )
                 lowered_module = self._lower_func(
                     submod, submod_inputs, self.lower_setting, submod_name
                 )
                 setattr(split_result.split_module, submod_name, lowered_module)
@@ -234,13 +236,15 @@
             if not submod_name.startswith(split_result.non_acc_submodule_prefix):
                 _LOGGER.info(f"ACC submodule graph: {submod.graph}")
                 lowering_start_time = datetime.datetime.now()

                 self.lower_setting.additional_inputs = (
-                    additional_submodule_inputs[submod_name]
-                    if additional_submodule_inputs
-                    else None,
+                    (
+                        additional_submodule_inputs[submod_name]
+                        if additional_submodule_inputs
+                        else None
+                    ),
                 )

                 lowered_module = self._lower_func(
                     submod, submod_inputs, self.lower_setting, submod_name
                 )
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py 2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/pass_utils.py 2024-02-16 00:03:18.123580+00:00
@@ -193,13 +193,11 @@
                 kwargs2 = {"equal_nan": True}
                 if rtol:
                     kwargs2["rtol"] = rtol
                 if atol:
                     kwargs2["atol"] = atol
-                kwargs2[
-                    "msg"
-                ] = (
-                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
-                )
+                kwargs2["msg"] = (
+                    lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
+                )
                 # If tensors are on different devices, make sure to compare
                 # their copies that are on the same device.
                 if x.get_device() != y.get_device():
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py 2024-02-16 00:01:27.183252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/passes/lower_basic_pass.py 2024-02-16 00:03:18.166433+00:00
@@ -536,13 +536,13 @@
         reshape_batch_size: Optional[fx.Node] = get_reshape_batch_size_as_node(
             maybe_reshape
         )
         if not reshape_batch_size:
             continue
-        reshape_batch_size_inferred_source: Optional[
-            fx.Node
-        ] = get_reshape_batch_size_inferred_source(reshape_batch_size)
+        reshape_batch_size_inferred_source: Optional[fx.Node] = (
+            get_reshape_batch_size_inferred_source(reshape_batch_size)
+        )
         if not reshape_batch_size_inferred_source:
             continue

         reshape_input: fx.Node = maybe_reshape.kwargs["input"]
         if reshape_input == reshape_batch_size_inferred_source:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py 2024-02-16 00:01:27.187252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/test/converters/acc_op/test_split.py 2024-02-16 00:03:18.592970+00:00
@@ -21,13 +21,15 @@
         inputs = [torch.randn(1, 10)]
         self.run_test(
             Split(),
             inputs,
             expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
             },
             test_explicit_batch_dim=False,
         )

     @parameterized.expand(
@@ -68,13 +70,15 @@
         ]
         self.run_test_with_dynamic_shape(
             Split(),
             input_specs,
             expected_ops={
-                acc_ops.split
-                if isinstance(split_size_or_sections, int)
-                else acc_ops.slice_tensor
+                (
+                    acc_ops.split
+                    if isinstance(split_size_or_sections, int)
+                    else acc_ops.slice_tensor
+                )
             },
         )

     # Testing with (-1, -1, -1) results into following error:
     # AssertionError: Can't chunk on dynamic shape dimension!
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py 2024-02-16 00:01:27.191252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/tools/common_fx2trt.py 2024-02-16 00:03:19.259154+00:00
@@ -152,13 +152,13 @@
         mod.eval()
         if len(expected_ops):
             self.assert_has_op(mod, expected_ops)

         interpreter_result = interpreter.run(
-            lower_precision=LowerPrecision.FP16
-            if fp16_mode
-            else LowerPrecision.FP32
+            lower_precision=(
+                LowerPrecision.FP16 if fp16_mode else LowerPrecision.FP32
+            )
         )
         trt_mod = TRTModule(
             interpreter_result.engine,
             interpreter_result.input_names,
             interpreter_result.output_names,
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py 2024-02-16 00:01:27.191252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/fx/trt_module.py 2024-02-16 00:03:19.609670+00:00
@@ -67,25 +67,29 @@
                 self.engine.get_binding_dtype(idx), Frameworks.TORCH
             )
             for idx in self.output_binding_indices_in_order
         ]
         self.output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
             for idx in self.output_binding_indices_in_order
         ]
         self.hidden_output_dtypes: Sequence[torch.dtype] = [
             unified_dtype_converter(
                 self.engine.get_binding_dtype(idx), Frameworks.TORCH
             )
             for idx in self.hidden_output_binding_indices_in_order
         ]
         self.hidden_output_shapes = [
-            tuple(self.engine.get_binding_shape(idx))
-            if self.engine.has_implicit_batch_dimension
-            else tuple()
+            (
+                tuple(self.engine.get_binding_shape(idx))
+                if self.engine.has_implicit_batch_dimension
+                else tuple()
+            )
             for idx in self.hidden_output_binding_indices_in_order
         ]

     def _check_initialized(self):
         if not self.initialized:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py 2024-02-16 00:01:27.191252+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/ts/_compile_spec.py 2024-02-16 00:03:19.911816+00:00
@@ -404,13 +404,13 @@
         "inputs": inputs if inputs is not None else [],
         # "input_signature": input_signature,
         "device": device,
         "disable_tf32": disable_tf32,  # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
         "sparse_weights": sparse_weights,  # Enable sparsity for convolution and fully connected layers.
-        "enabled_precisions": enabled_precisions
-        if enabled_precisions is not None
-        else set(),  # Enabling FP16 kernels
+        "enabled_precisions": (
+            enabled_precisions if enabled_precisions is not None else set()
+        ),  # Enabling FP16 kernels
         "refit": refit,  # enable refit
         "debug": debug,  # enable debuggable engine
         "capability": capability,  # Restrict kernel selection to safe gpu kernels or safe dla kernels
         "num_avg_timing_iters": num_avg_timing_iters,  # Number of averaging timing iterations used to select kernels
         "workspace_size": workspace_size,  # Maximum size of workspace given to TensorRT