using nccl ops from TRT-LLM namespace by apbose · Pull Request #3250 · pytorch/TensorRT

--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py	2024-10-21 20:25:45.697459+00:00
+++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py	2024-10-21 20:26:10.941910+00:00
@@ -26,44 +26,51 @@
)
import tensorrt as trt
import tensorrt_llm
import ctypes
import logging

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""

plugin_lib_path = "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
try:
    # Load the TensorRT-LLM plugin library so its plugins register with TensorRT
    ctypes.CDLL(plugin_lib_path)
    print("plugin loaded successfully")

except OSError as e:
    print(f"unsuccessful load : {e}")

logger = trt.Logger(trt.Logger.VERBOSE)
trt.init_libnvinfer_plugins(None, "")

# Iterate over all registered plugin creators and report what is available
plugin_registry = trt.get_plugin_registry()
for plugin_creator in plugin_registry.plugin_creator_list:
    print(
        f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
    )

@dynamo_tensorrt_converter(torch.ops._c10d_functional.all_gather_into_tensor.default)
def insert_gather_op(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    plug_inputs = [args[0]]
    allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        "AllGather", "1", "tensorrt_llm"
    )
    assert allgather_plg_creator is not None
    world_size = dist.get_world_size()
    group = list(range(world_size))

    group = trt.PluginField(
        "group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
    )
    p_dtype = trt.float16
    pf_type = trt.PluginField(
        "type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
    )
    pfc = trt.PluginFieldCollection([group, pf_type])
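For readers following the converter above, here is a minimal sketch of how a plugin-based converter like this typically finishes, assuming the standard TensorRT plugin flow; the `"allgather"` layer name and the use of torch_tensorrt's `set_layer_name` helper are assumptions, not taken verbatim from this PR.

```python
    # Hedged continuation of insert_gather_op (assumed, not copied from the PR):
    # instantiate the TRT-LLM AllGather plugin from the field collection and
    # attach it to the network over the converter's inputs.
    allgather = allgather_plg_creator.create_plugin("allgather", pfc)
    layer = ctx.net.add_plugin_v2(plug_inputs, allgather)
    set_layer_name(layer, target, name)  # torch_tensorrt converter_utils helper (assumed import)
    return layer.get_output(0)
```

The plugin layer's output is what the converter returns, so downstream converters see an ordinary TRTTensor in place of the original `all_gather_into_tensor` call.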
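Separately, the hardcoded pyenv path used when loading the plugin library is fragile across environments. Below is a minimal sketch, assuming the tensorrt_llm wheel ships the plugin library under its `libs/` directory (this layout may vary by version), of deriving the path from the installed package instead; `load_tensorrt_llm_plugins` is a hypothetical helper, not part of this PR.

```python
import ctypes
import os

import tensorrt_llm


def load_tensorrt_llm_plugins() -> bool:
    # Resolve the plugin library relative to the installed tensorrt_llm package
    # (assumes the wheel places it under tensorrt_llm/libs/).
    plugin_lib_path = os.path.join(
        os.path.dirname(tensorrt_llm.__file__),
        "libs",
        "libnvinfer_plugin_tensorrt_llm.so",
    )
    try:
        ctypes.CDLL(plugin_lib_path)
        print("plugin loaded successfully")
        return True
    except OSError as e:
        print(f"unsuccessful load : {e}")
        return False
```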