Segmentation Fault During TensorRT Engine Build with ScatterElements (FP16) on Thor

DRIVE OS Version: Provide DRIVE OS version. Example: 7.0.3

Issue Description:

I am encountering a segmentation fault when converting a very simple ONNX model to a TensorRT engine using trtexec on NVIDIA DRIVE Thor.

The model consists of a single ScatterElements node (axis=-1, reduction="add"). Its only input tensor is indices (INT64, shape [1, 4]), and data and updates are constant FP32 initializers. The output shape is fixed to [1, 4].

The segmentation fault occurs only when both conditions hold: indices is provided as a network input, and the engine is built with --fp16.
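To check the output of an engine that does build (for example the FP32 one), it helps to know what this node should produce. Below is a minimal NumPy reference for the case used in test.onnx only (2-D tensors, axis=-1, reduction="add"). The function name is hypothetical and is not part of ONNX or TensorRT.

```python
import numpy as np


def scatter_elements_add_last_axis(data: np.ndarray, indices: np.ndarray,
                                   updates: np.ndarray) -> np.ndarray:
    """Reference for ONNX ScatterElements with axis=-1 and reduction="add" on 2-D tensors."""
    out = data.copy()  # ScatterElements does not modify its data input
    rows = np.arange(indices.shape[0])[:, None]  # row index, broadcast against every column
    # np.add.at is unbuffered, so duplicate (row, col) targets accumulate instead of
    # overwriting each other; negative column indices wrap the same way ONNX specifies.
    np.add.at(out, (rows, indices), updates)
    return out


# The constant initializers used in test.onnx
data = np.array([[0.5, 0.5, 0.5, 0.5]], dtype=np.float32)
updates = np.array([[0, 1, 2, 3]], dtype=np.float32)

print(scatter_elements_add_last_axis(data, np.array([[0, 1, 2, 3]]), updates))
print(scatter_elements_add_last_axis(data, np.array([[0, 0, 0, 0]]), updates))
```

With the model's constants, indices [[0, 1, 2, 3]] should give [[0.5, 1.5, 2.5, 3.5]]. Indices [[0, 0, 0, 0]] should give [[6.5, 0.5, 0.5, 0.5]], because every update accumulates into column 0. All of these values, including the intermediate sums, are exactly representable in FP16.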

I generated the ONNX model using the following Python script:

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper


def build_model() -> onnx.ModelProto:
    # One ScatterElements node; with axis=-1 and reduction="add" it computes
    # output = data, then output[i][indices[i][j]] += updates[i][j]
    scatter_node = helper.make_node(
        "ScatterElements",
        inputs=["data", "indices", "updates"],
        outputs=["scatter_output"],
        name="ScatterElementsOnly",
        axis=-1,
        reduction="add",  # the reduction attribute requires opset >= 16
    )

    # indices is the only graph input
    inputs = [
        helper.make_tensor_value_info("indices", TensorProto.INT64, [1, 4])
    ]

    outputs = [
        helper.make_tensor_value_info("scatter_output", TensorProto.FLOAT, [1, 4])
    ]

    # data and updates are constant FP32 initializers
    initializers = [
        numpy_helper.from_array(
            np.array([[0, 1, 2, 3]], dtype=np.float32),
            "updates",
        ),
        numpy_helper.from_array(
            np.array([[0.5, 0.5, 0.5, 0.5]], dtype=np.float32),
            "data",
        ),
    ]

    graph = helper.make_graph(
        name="graph",
        nodes=[scatter_node],
        inputs=inputs,
        outputs=outputs,
        initializer=initializers,
    )

    opset_imports = [helper.make_opsetid("", 17)]
    model = helper.make_model(graph, opset_imports=opset_imports, producer_name="scatter_debug")
    model.ir_version = 8
    return model


if __name__ == "__main__":
    model = build_model()
    onnx.save(model, "test.onnx")
```

I attempted to build the TensorRT engine using the following command:

$ ./trtexec --onnx=./test.onnx --fp16 --saveEngine=out.engine --verbose
&&&& RUNNING TensorRT.trtexec [TensorRT v101010] [b3] # ./trtexec --onnx=./test.onnx --fp16 --saveEngine=out.engine --verbose
[12/16/2025-07:41:23] [I] === Model Options ===
[12/16/2025-07:41:23] [I] Format: ONNX
[12/16/2025-07:41:23] [I] Model: ./test.onnx
[12/16/2025-07:41:23] [I] Output:
[12/16/2025-07:41:23] [I] === Build Options ===
[12/16/2025-07:41:23] [I] Memory Pools: workspace: default, dlaSRAM: default, dlaLocalDRAM: default, dlaGlobalDRAM: default, tacticSharedMem: default
[12/16/2025-07:41:23] [I] avgTiming: 8
[12/16/2025-07:41:23] [I] Precision: FP32+FP16
[12/16/2025-07:41:23] [I] LayerPrecisions:
[12/16/2025-07:41:23] [I] Layer Device Types:
[12/16/2025-07:41:23] [I] Calibration:
[12/16/2025-07:41:23] [I] Refit: Disabled
[12/16/2025-07:41:23] [I] Strip weights: Disabled
[12/16/2025-07:41:23] [I] Version Compatible: Disabled
[12/16/2025-07:41:23] [I] ONNX Plugin InstanceNorm: Disabled
[12/16/2025-07:41:23] [I] TensorRT runtime: full
[12/16/2025-07:41:23] [I] Lean DLL Path:
[12/16/2025-07:41:23] [I] Tempfile Controls: { in_memory: allow, temporary: allow }
[12/16/2025-07:41:23] [I] Exclude Lean Runtime: Disabled
[12/16/2025-07:41:23] [I] Sparsity: Disabled
[12/16/2025-07:41:23] [I] Safe mode: Disabled
[12/16/2025-07:41:23] [I] Build DLA standalone loadable: Disabled
[12/16/2025-07:41:23] [I] Allow GPU fallback for DLA: Disabled
[12/16/2025-07:41:23] [I] DirectIO mode: Disabled
[12/16/2025-07:41:23] [I] Restricted mode: Disabled
[12/16/2025-07:41:23] [I] Skip inference: Disabled
[12/16/2025-07:41:23] [I] Save engine: out.engine
[12/16/2025-07:41:23] [I] Load engine:
[12/16/2025-07:41:23] [I] Profiling verbosity: 0
[12/16/2025-07:41:23] [I] Tactic sources: Using default tactic sources
[12/16/2025-07:41:23] [I] timingCacheMode: local
[12/16/2025-07:41:23] [I] timingCacheFile:
[12/16/2025-07:41:23] [I] Enable Compilation Cache: Enabled
[12/16/2025-07:41:23] [I] Enable Monitor Memory: Disabled
[12/16/2025-07:41:23] [I] errorOnTimingCacheMiss: Disabled
[12/16/2025-07:41:23] [I] Preview Features: Use default preview flags.
[12/16/2025-07:41:23] [I] MaxAuxStreams: -1
[12/16/2025-07:41:23] [I] BuilderOptimizationLevel: -1
[12/16/2025-07:41:23] [I] MaxTactics: -1
[12/16/2025-07:41:23] [I] Calibration Profile Index: 0
[12/16/2025-07:41:23] [I] Weight Streaming: Disabled
[12/16/2025-07:41:23] [I] Runtime Platform: Same As Build
[12/16/2025-07:41:23] [I] Debug Tensors:
[12/16/2025-07:41:23] [I] Input(s)s format: fp32:CHW
[12/16/2025-07:41:23] [I] Output(s)s format: fp32:CHW
[12/16/2025-07:41:23] [I] Input build shapes: model
[12/16/2025-07:41:23] [I] Input calibration shapes: model
[12/16/2025-07:41:23] [I] === System Options ===
[12/16/2025-07:41:23] [I] Device: 0
[12/16/2025-07:41:23] [I] DLACore:
[12/16/2025-07:41:23] [I] Plugins:
[12/16/2025-07:41:23] [I] setPluginsToSerialize:
[12/16/2025-07:41:23] [I] dynamicPlugins:
[12/16/2025-07:41:23] [I] ignoreParsedPluginLibs: 0
[12/16/2025-07:41:23] [I]
[12/16/2025-07:41:23] [I] === Inference Options ===
[12/16/2025-07:41:23] [I] Batch: Explicit
[12/16/2025-07:41:23] [I] Input inference shapes: model
[12/16/2025-07:41:23] [I] Iterations: 10
[12/16/2025-07:41:23] [I] Duration: 3s (+ 200ms warm up)
[12/16/2025-07:41:23] [I] Sleep time: 0ms
[12/16/2025-07:41:23] [I] Idle time: 0ms
[12/16/2025-07:41:23] [I] Inference Streams: 1
[12/16/2025-07:41:23] [I] ExposeDMA: Disabled
[12/16/2025-07:41:23] [I] Data transfers: Enabled
[12/16/2025-07:41:23] [I] Spin-wait: Disabled
[12/16/2025-07:41:23] [I] Multithreading: Disabled
[12/16/2025-07:41:23] [I] CUDA Graph: Disabled
[12/16/2025-07:41:23] [I] Separate profiling: Disabled
[12/16/2025-07:41:23] [I] Time Deserialize: Disabled
[12/16/2025-07:41:23] [I] Time Refit: Disabled
[12/16/2025-07:41:23] [I] NVTX verbosity: 0
[12/16/2025-07:41:23] [I] Persistent Cache Ratio: 0
[12/16/2025-07:41:23] [I] Optimization Profile Index: 0
[12/16/2025-07:41:23] [I] Weight Streaming Budget: 100.000000%
[12/16/2025-07:41:23] [I] Inputs:
[12/16/2025-07:41:23] [I] Debug Tensor Save Destinations:
[12/16/2025-07:41:23] [I] === Reporting Options ===
[12/16/2025-07:41:23] [I] Verbose: Enabled
[12/16/2025-07:41:23] [I] Averages: 10 inferences
[12/16/2025-07:41:23] [I] Percentiles: 90,95,99
[12/16/2025-07:41:23] [I] Dump refittable layers:Disabled
[12/16/2025-07:41:23] [I] Dump output: Disabled
[12/16/2025-07:41:23] [I] Profile: Disabled
[12/16/2025-07:41:23] [I] Export timing to JSON file:
[12/16/2025-07:41:23] [I] Export output to JSON file:
[12/16/2025-07:41:23] [I] Export profile to JSON file:
[12/16/2025-07:41:23] [I]
[12/16/2025-07:41:23] [I] === Device Information ===
[12/16/2025-07:41:23] [I] Available Devices:
[12/16/2025-07:41:23] [I]   Device 0: "Thor" UUID: GPU-d8871e85-29c4-5705-bb76-cd3b6774cd82
[12/16/2025-07:41:23] [I] Selected Device: Thor
[12/16/2025-07:41:23] [I] Selected Device ID: 0
[12/16/2025-07:41:23] [I] Selected Device UUID: GPU-d8871e85-29c4-5705-bb76-cd3b6774cd82
[12/16/2025-07:41:23] [I] Compute Capability: 10.1
[12/16/2025-07:41:23] [I] SMs: 20
[12/16/2025-07:41:23] [I] Device Global Memory: 15020 MiB
[12/16/2025-07:41:23] [I] Shared Memory per SM: 228 KiB
[12/16/2025-07:41:23] [I] Memory Bus Width: 256 bits (ECC disabled)
[12/16/2025-07:41:23] [I] Application Compute Clock Rate: 1.575 GHz
[12/16/2025-07:41:23] [I] Application Memory Clock Rate: 1.575 GHz
[12/16/2025-07:41:23] [I]
[12/16/2025-07:41:23] [I] Note: The application clock rates do not reflect the actual clock rates that the GPU is currently running at.
[12/16/2025-07:41:23] [I]
[12/16/2025-07:41:23] [I] TensorRT version: 10.10.10
[12/16/2025-07:41:23] [I] Loading standard plugins
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::ROIAlign_TRT version 2
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::BatchedNMSDynamic_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::BatchedNMS_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::BatchTilePlugin_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::Clip_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::CoordConvAC version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::CropAndResizeDynamic version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::CropAndResize version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::DecodeBbox3DPlugin version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::DetectionLayer_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::EfficientNMS_Explicit_TF_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::EfficientNMS_Implicit_TF_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::EfficientNMS_ONNX_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::EfficientNMS_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::FlattenConcat_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::GenerateDetection_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::GridAnchor_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::GridAnchorRect_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::InstanceNormalization_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::InstanceNormalization_TRT version 2
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::InstanceNormalization_TRT version 3
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::LReLU_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::ModulatedDeformConv2d version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::MultilevelCropAndResize_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::MultilevelProposeROI_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::MultiscaleDeformableAttnPlugin_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::NMSDynamic_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::NMS_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::Normalize_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::PillarScatterPlugin version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::PriorBox_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::ProposalDynamic version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::ProposalLayer_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::Proposal version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::PyramidROIAlign_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::Region_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::Reorg_TRT version 2
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::Reorg_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::ResizeNearest_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::ROIAlign_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::RPROI_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::ScatterElements version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::ScatterElements version 2
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::ScatterND version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::SpecialSlice_TRT version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::Split version 1
[12/16/2025-07:41:23] [V] [TRT] Registered plugin creator - ::VoxelGeneratorPlugin version 1
[12/16/2025-07:41:23] [I] [TRT] [MemUsageChange] Init CUDA: CPU +1, GPU +0, now: CPU 25, GPU 98 (MiB)
[12/16/2025-07:41:23] [V] [TRT] Trying to load shared library libnvinfer_builder_resource.so.10.10.10
[12/16/2025-07:41:23] [V] [TRT] Loaded shared library libnvinfer_builder_resource.so.10.10.10
[12/16/2025-07:41:23] [I] [TRT] [MemUsageChange] Init builder kernel library: CPU +480, GPU +2, now: CPU 572, GPU 100 (MiB)
[12/16/2025-07:41:23] [V] [TRT] CUDA lazy loading is enabled.
[12/16/2025-07:41:23] [I] Start parsing network model.
[12/16/2025-07:41:23] [I] [TRT] ----------------------------------------------------------------
[12/16/2025-07:41:23] [I] [TRT] Input filename:   ./test.onnx
[12/16/2025-07:41:23] [I] [TRT] ONNX IR version:  0.0.8
[12/16/2025-07:41:23] [I] [TRT] Opset version:    17
[12/16/2025-07:41:23] [I] [TRT] Producer name:    scatter_debug
[12/16/2025-07:41:23] [I] [TRT] Producer version:
[12/16/2025-07:41:23] [I] [TRT] Domain:
[12/16/2025-07:41:23] [I] [TRT] Model version:    0
[12/16/2025-07:41:23] [I] [TRT] Doc string:
[12/16/2025-07:41:23] [I] [TRT] ----------------------------------------------------------------
[12/16/2025-07:41:23] [V] [TRT] Adding network input: indices with dtype: int64, dimensions: (1, 4)
[12/16/2025-07:41:23] [W] [TRT] ModelImporter.cpp:503: Make sure input indices has Int64 binding.
[12/16/2025-07:41:23] [V] [TRT] Registering tensor: indices for ONNX tensor: indices
[12/16/2025-07:41:23] [V] [TRT] Importing initializer: updates
[12/16/2025-07:41:23] [V] [TRT] Importing initializer: data
[12/16/2025-07:41:23] [V] [TRT] Static check for parsing node: ScatterElementsOnly [ScatterElements]
[12/16/2025-07:41:23] [V] [TRT] Parsing node: ScatterElementsOnly [ScatterElements]
[12/16/2025-07:41:23] [V] [TRT] Searching for input: data
[12/16/2025-07:41:23] [V] [TRT] Searching for input: indices
[12/16/2025-07:41:23] [V] [TRT] Searching for input: updates
[12/16/2025-07:41:23] [V] [TRT] ScatterElementsOnly [ScatterElements] inputs: [data -> (1, 4)[HALF]], [indices -> (1, 4)[INT64]], [updates -> (1, 4)[HALF]],
[12/16/2025-07:41:23] [V] [TRT] Registering layer: data required by ONNX-TRT
[12/16/2025-07:41:23] [V] [TRT] Registering layer: updates required by ONNX-TRT
[12/16/2025-07:41:23] [V] [TRT] Registering layer: ScatterElementsOnly for ONNX node: ScatterElementsOnly
[12/16/2025-07:41:23] [V] [TRT] Registering tensor: scatter_output_0 for ONNX tensor: scatter_output
[12/16/2025-07:41:23] [V] [TRT] ScatterElementsOnly [ScatterElements] outputs: [scatter_output -> (1, 4)[HALF]],
[12/16/2025-07:41:23] [V] [TRT] Marking scatter_output_0 as output: scatter_output
[12/16/2025-07:41:23] [I] Finished parsing network model. Parse time: 0.00104468
[12/16/2025-07:41:23] [V] [TRT] could not open /sys/fs/cgroup/memory/memory.limit_in_bytes or /sys/fs/cgroup/memory.max
[12/16/2025-07:41:23] [V] [TRT] Original: 3 layers
[12/16/2025-07:41:23] [V] [TRT] After dead-layer removal: 3 layers
[12/16/2025-07:41:23] [V] [TRT] SYMBOLIC CHECKS
[12/16/2025-07:41:23] [V] [TRT] GRAPH NODES
[12/16/2025-07:41:23] [V] [TRT] CONSTANT data
[12/16/2025-07:41:23] [V] [TRT]     Output 0
[12/16/2025-07:41:23] [V] [TRT]         1       1
[12/16/2025-07:41:23] [V] [TRT]         4       4
[12/16/2025-07:41:23] [V] [TRT] CONSTANT updates
[12/16/2025-07:41:23] [V] [TRT]     Output 0
[12/16/2025-07:41:23] [V] [TRT]         1       1
[12/16/2025-07:41:23] [V] [TRT]         4       4
[12/16/2025-07:41:23] [V] [TRT] PLUGIN_V3 ScatterElementsOnly
[12/16/2025-07:41:23] [V] [TRT]     Input 0
[12/16/2025-07:41:23] [V] [TRT]         1       1
[12/16/2025-07:41:23] [V] [TRT]         4       4
[12/16/2025-07:41:23] [V] [TRT]     Input 1
[12/16/2025-07:41:23] [V] [TRT]         1       1
[12/16/2025-07:41:23] [V] [TRT]         4       4
[12/16/2025-07:41:23] [V] [TRT]     Input 2
[12/16/2025-07:41:23] [V] [TRT]         1       1
[12/16/2025-07:41:23] [V] [TRT]         4       4
[12/16/2025-07:41:23] [V] [TRT]     Output 0
[12/16/2025-07:41:23] [V] [TRT]         1       1
[12/16/2025-07:41:23] [V] [TRT]         4       4
[12/16/2025-07:41:25] [V] [TRT] Graph construction completed in 1.67384 seconds.
[12/16/2025-07:41:25] [V] [TRT] After adding DebugOutput nodes: 3 layers
[12/16/2025-07:41:25] [V] [TRT] After Myelin optimization: 1 layers
[12/16/2025-07:41:25] [V] [TRT] Applying ScaleNodes fusions.
[12/16/2025-07:41:25] [V] [TRT] After scale fusion: 1 layers
[12/16/2025-07:41:25] [V] [TRT] After dupe layer removal: 1 layers
[12/16/2025-07:41:25] [V] [TRT] After final dead-layer removal: 1 layers
[12/16/2025-07:41:25] [V] [TRT] After tensor merging: 1 layers
[12/16/2025-07:41:25] [V] [TRT] After vertical fusions: 1 layers
[12/16/2025-07:41:25] [V] [TRT] After dupe layer removal: 1 layers
[12/16/2025-07:41:25] [V] [TRT] After final dead-layer removal: 1 layers
[12/16/2025-07:41:25] [V] [TRT] After tensor merging: 1 layers
[12/16/2025-07:41:25] [V] [TRT] After slice removal: 1 layers
[12/16/2025-07:41:25] [V] [TRT] After concat removal: 1 layers
[12/16/2025-07:41:25] [V] [TRT] Trying to split Reshape and strided tensor
[12/16/2025-07:41:25] [V] [TRT] Graph optimization time: 0.000288231 seconds.
[12/16/2025-07:41:25] [V] [TRT] Building graph using backend strategy 2
[12/16/2025-07:41:25] [I] [TRT] Local timing cache in use. Profiling results in this builder pass will not be stored.
[12/16/2025-07:41:25] [V] [TRT] Constructing optimization profile number 0 [1/1].
[12/16/2025-07:41:25] [V] [TRT] Applying generic optimizations to the graph for inference.
[12/16/2025-07:41:25] [V] [TRT] Reserving memory for host IO tensors. Host: 0 bytes
[12/16/2025-07:41:25] [V] [TRT] =============== Computing costs for {ForeignNode[data...ScatterElementsOnly]}
[12/16/2025-07:41:25] [V] [TRT] ForeignNode {ForeignNode[data...ScatterElementsOnly]} metadata: [ONNX Layer: ScatterElementsOnly]
[12/16/2025-07:41:25] [V] [TRT] *************** Autotuning format combination: Int64(4,1) -> Half(4,1) ***************
[12/16/2025-07:41:25] [V] [TRT] --------------- Timing Runner: {ForeignNode[data...ScatterElementsOnly]} (Myelin[0x80000023])
[12/16/2025-07:41:25] [V] [TRT] Serialized custom layer data size: 1226
[12/16/2025-07:41:25] [I] [TRT] Compiler backend is used during engine build.
Segmentation fault (core dumped)
Running the same command a second time produced the same log and crashed at the same point.
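The claim that the crash needs both conditions can be rechecked as a small matrix of builds. A minimal sketch, assuming trtexec sits in the current directory and that a control model test_const_indices.onnx (hypothetical; indices stored as an initializer instead of a graph input) exists next to test.onnx. describe_exit and run_matrix are illustrative helpers, not part of TensorRT. On POSIX systems, subprocess reports a child killed by a signal as a negative return code, so a segfault shows up as -SIGSEGV.

```python
import signal
import subprocess


def describe_exit(returncode: int) -> str:
    """Map a subprocess return code to a short label (POSIX: -N means killed by signal N)."""
    if returncode == 0:
        return "ok"
    if returncode == -signal.SIGSEGV:
        return "segfault"
    if returncode < 0:
        return f"killed by signal {-returncode}"
    return f"exit code {returncode}"


def run_matrix(trtexec: str = "./trtexec") -> dict:
    """Build every (model, precision) combination and record how trtexec exits."""
    # test_const_indices.onnx is a hypothetical control model with indices as an initializer
    models = ["./test.onnx", "./test_const_indices.onnx"]
    precisions = {"fp32": [], "fp16": ["--fp16"]}
    results = {}
    for model in models:
        for name, flags in precisions.items():
            cmd = [trtexec, f"--onnx={model}", *flags]
            proc = subprocess.run(cmd, capture_output=True, text=True)
            results[(model, name)] = describe_exit(proc.returncode)
    return results
```

Calling run_matrix("./trtexec") should, per the observations in this post, report "segfault" only for ("./test.onnx", "fp16") and "ok" for the other three combinations.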

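The ONNX model parses successfully. The ScatterElements node is recognized and imported as a TensorRT plugin layer (PLUGIN_V3 ScatterElementsOnly), and the data and updates initializers are imported as HALF. The builder then fuses the whole graph into a single Myelin ForeignNode, {ForeignNode[data...ScatterElementsOnly]}, and the process crashes with a segmentation fault while timing that node for the format combination Int64(4,1) -> Half(4,1), immediately after the builder logs "Compiler backend is used during engine build".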
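When comparing logs from several runs or variants, it can help to pull out the last format combination and timing runner the builder reached before dying. A minimal sketch, assuming the verbose log format shown above; last_build_step is a hypothetical helper, and the sample lines are copied from the first run's log.

```python
import re

# Patterns for the two verbose-log lines printed just before the crash
COMBO_RE = re.compile(r"Autotuning format combination: (.+?) \*+")
RUNNER_RE = re.compile(r"Timing Runner: (.+)$")


def last_build_step(log_text: str) -> dict:
    """Return the last format combination and timing runner seen in a trtexec --verbose log."""
    combo = runner = None
    for line in log_text.splitlines():
        m = COMBO_RE.search(line)
        if m:
            combo = m.group(1)
        m = RUNNER_RE.search(line)
        if m:
            runner = m.group(1).strip()
    return {"format_combination": combo, "timing_runner": runner}


sample = """\
[12/16/2025-07:41:25] [V] [TRT] *************** Autotuning format combination: Int64(4,1) -> Half(4,1) ***************
[12/16/2025-07:41:25] [V] [TRT] --------------- Timing Runner: {ForeignNode[data...ScatterElementsOnly]} (Myelin[0x80000023])
[12/16/2025-07:41:25] [V] [TRT] Serialized custom layer data size: 1226
"""
print(last_build_step(sample))
```

For a saved log, pass the file contents instead of sample (for example from a hypothetical build.log).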

Is this a known issue with the ScatterElements plugin in FP16 mode on Thor? Is there a recommended workaround or configuration to avoid this segmentation fault when indices is an input tensor?
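Since the crash appears only with --fp16, one direction to try is keeping the scatter out of FP16. The flag sets below are untested guesses, not a confirmed workaround. --layerPrecisions, --precisionConstraints and --stronglyTyped are existing trtexec options in TensorRT 10, and with --stronglyTyped tensor types should follow the ONNX model (FP32 for data and updates). Whether any of them changes the Myelin path that crashes here on Thor is unknown. The layer name ScatterElementsOnly is taken from the log; CANDIDATES and candidate_commands are hypothetical names.

```python
import shlex

# Hypothetical, untested trtexec flag sets aimed at keeping ScatterElements out of FP16
CANDIDATES = {
    "fp32-only": [],
    "fp16-but-scatter-fp32": [
        "--fp16",
        "--layerPrecisions=ScatterElementsOnly:fp32",
        "--precisionConstraints=obey",
    ],
    "strongly-typed": ["--stronglyTyped"],
}


def candidate_commands(onnx_path: str = "./test.onnx", trtexec: str = "./trtexec") -> dict:
    """Return one shell-quoted trtexec command line per candidate flag set."""
    return {
        name: shlex.join([trtexec, f"--onnx={onnx_path}", *flags, "--verbose"])
        for name, flags in CANDIDATES.items()
    }


for name, cmd in candidate_commands().items():
    print(f"{name}: {cmd}")
```

In this one-node model all three variants amount to running the scatter in FP32; the per-layer form is the one that would matter in a larger network where the rest should stay in FP16.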

Any guidance or confirmation would be greatly appreciated.