using nccl ops from TRT-LLM namespace by apbose · Pull Request #3250 · pytorch/TensorRT (original) (raw)
--- /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py 2024-10-21 20:25:45.697459+00:00 +++ /home/runner/work/TensorRT/TensorRT/examples/distributed_inference/tensor_parallel_simple_example.py 2024-10-21 20:26:10.941910+00:00 @@ -26,44 +26,51 @@ ) import tensorrt as trt import tensorrt_llm import ctypes import logging + """ This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py """
plugin_lib_path = "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so" try:
- ctypes.CDLL("/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so")
- ctypes.CDLL(
"/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
- )
print("plugin loaded sucessfully")
except OSError as e: print(f"unsuccessful load : {e}") logger = trt.Logger(trt.Logger.VERBOSE) -trt.init_libnvinfer_plugins(None, '') -#-[p;Iterate over all registered plugin creators +trt.init_libnvinfer_plugins(None, "") +# -[p;Iterate over all registered plugin creators plugin_registry = trt.get_plugin_registry() for plugin_creator in plugin_registry.plugin_creator_list:
- print(f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}")
- print(
f"Plugin Name: {plugin_creator.name}, Namespace: {plugin_creator.plugin_namespace}, Version: {plugin_creator.plugin_version}"
- )
@dynamo_tensorrt_converter(torch.ops._c10d_functional.all_gather_into_tensor.default) def insert_gather_op( ctx: ConversionContext, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Argument],
- name: str,
- name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]: plug_inputs = [args[0]] allgather_plg_creator = trt.get_plugin_registry().get_plugin_creator( "AllGather", "1", "tensorrt_llm" ) assert allgather_plg_creator is not None world_size = dist.get_world_size() group = list(range(world_size))
- group = trt.PluginField("group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32)
- group = trt.PluginField(
"group", np.array(group, dtype=np.int32), trt.PluginFieldType.INT32
- )
p_dtype = trt.float16
pf_type = trt.PluginField(
"type_id", np.array([int(p_dtype)], np.int32), trt.PluginFieldType.INT32
)
pfc = trt.PluginFieldCollection([group, pf_type])