
PyTorch (LibTorch) Backend

The Triton backend for PyTorch. You can learn more about Triton backends in the backend repo. Ask questions or report problems on the issues page. This backend is designed to run TorchScript models using the PyTorch C++ API. All models created in PyTorch using the Python API must be traced or scripted to produce a TorchScript model.
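
For example, a model defined with the Python API can be converted to TorchScript roughly as follows. This is a minimal sketch; the module, tensor shapes, and file names are placeholders.

import torch

# A toy module defined with the PyTorch Python API (placeholder model).
class AddSub(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor):
        return x + y, x - y

model = AddSub().eval()

# Convert to TorchScript, either by tracing with example inputs ...
example_inputs = (torch.rand(1, 4), torch.rand(1, 4))
traced = torch.jit.trace(model, example_inputs)

# ... or by scripting, which also captures data-dependent control flow.
scripted = torch.jit.script(model)

# Save the TorchScript model; Triton expects it as model.pt inside the
# model version directory, e.g. model_repository/model_directory/1/model.pt.
traced.save("model.pt")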

Where can I ask general questions about Triton and Triton backends? Be sure to read all the information below as well as the general Triton documentation available in the main server repo. If you don't find your answer there, you can ask questions on the main Triton issues page.

Build the PyTorch Backend

Use a recent cmake to build. First install the required dependencies.

$ apt-get install rapidjson-dev python3-dev python3-pip
$ pip3 install patchelf==0.17.2

An appropriate PyTorch container from NGC must be used. For example, to build a backend that uses the 23.04 version of the PyTorch container from NGC:

$ mkdir build
$ cd build
$ cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DTRITON_PYTORCH_DOCKER_IMAGE="nvcr.io/nvidia/pytorch:23.04-py3" ..
$ make install

The following required Triton repositories will be pulled and used in the build. By default, the "main" branch/tag will be used for each repo, but the listed CMake argument can be used to override it.

Build the PyTorch Backend With Custom PyTorch

Currently, Triton requires that a specially patched version of PyTorch be used with the PyTorch backend. The full source for these PyTorch versions is available as Docker images from NGC. For example, the PyTorch version compatible with the 22.12 release of Triton is available as nvcr.io/nvidia/pytorch:22.12-py3.

Copy over the LibTorch and Torchvision headers and libraries from the PyTorch NGC container into local directories. You can see which headers and libraries are needed/copied from the Dockerfile.

$ mkdir build
$ cd build
$ cmake -DCMAKE_INSTALL_PREFIX:PATH=`pwd`/install -DTRITON_PYTORCH_INCLUDE_PATHS="<PATH_PREFIX>/torch;<PATH_PREFIX>/torch/torch/csrc/api/include;<PATH_PREFIX>/torchvision" -DTRITON_PYTORCH_LIB_PATHS="<LIB_PATH_PREFIX>" ..
$ make install

Using the PyTorch Backend

Parameters

Triton exposes some flags to control the execution mode of the TorchScript models through the Parameters section of the model's config.pbtxt file.

By default, optimized execution of TorchScript models is enabled, but the initial calls to a loaded TorchScript model can take extremely long. Because of this model warmup issue, Triton also allows execution of models without these optimizations via the DISABLE_OPTIMIZED_EXECUTION parameter. In some models, optimized execution does not benefit performance, as seen here, and in other cases it impacts performance negatively, as seen here.

The section of the model config file specifying this parameter will look like:

parameters: {
    key: "DISABLE_OPTIMIZED_EXECUTION"
    value: {
        string_value: "true"
    }
}

InferenceMode is an RAII guard analogous to NoGradMode, to be used when you are certain your operations will have no interactions with autograd. Compared to NoGradMode, code run under this mode gets better performance by additionally disabling view tracking and version counter bumps. This guard is controlled by the INFERENCE_MODE parameter.

Please note that in some models, InferenceMode might not benefit performance, and in fewer cases it might impact performance negatively.

The section of the model config file specifying this parameter will look like:

parameters: {
    key: "INFERENCE_MODE"
    value: {
        string_value: "true"
    }
}
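
For reference, the guard controlled by this parameter corresponds to PyTorch's Python-level torch.inference_mode context manager, illustrated below; this is a standalone sketch, not backend code.

import torch

model = torch.nn.Linear(8, 2).eval()
x = torch.rand(1, 8)

# Regular no-grad execution.
with torch.no_grad():
    y1 = model(x)

# InferenceMode: like no_grad, but additionally skips view tracking and
# version counter bumps; tensors created here cannot later be used in
# code that records gradients.
with torch.inference_mode():
    y2 = model(x)

print(y1.requires_grad, y2.requires_grad)  # False False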

cuDNN is a GPU-accelerated library of primitives for deep neural networks that provides highly tuned implementations for standard routines.

Typically, models run with cuDNN enabled are faster. However, there are some exceptions where using cuDNN can be slower, cause higher memory usage, or result in errors. In those cases, cuDNN can be disabled for a model with the DISABLE_CUDNN parameter.

The section of the model config file specifying this parameter will look like:

parameters: {
    key: "DISABLE_CUDNN"
    value: {
        string_value: "true"
    }
}
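
For comparison, disabling cuDNN in a standalone PyTorch script looks like the following; within Triton, the DISABLE_CUDNN parameter above is the equivalent switch. This is an illustrative sketch only and requires a CUDA-capable GPU.

import torch

# Globally disable cuDNN kernels; PyTorch falls back to its native CUDA
# implementations for convolutions, RNNs, and other cuDNN-backed ops.
torch.backends.cudnn.enabled = False

conv = torch.nn.Conv2d(3, 16, kernel_size=3).cuda()
out = conv(torch.rand(1, 3, 32, 32, device="cuda"))
print(out.shape)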

The ENABLE_WEIGHT_SHARING parameter allows model instances on the same device to share weights, reducing memory usage. The section of the model config file specifying this parameter will look like:

parameters: {
    key: "ENABLE_WEIGHT_SHARING"
    value: {
        string_value: "true"
    }
}

The ENABLE_CACHE_CLEANING parameter enables CUDA cache cleaning after each model execution. The section of the model config file specifying this parameter will look like:

parameters: {
    key: "ENABLE_CACHE_CLEANING"
    value: {
        string_value: "true"
    }
}
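
The effect of this parameter is conceptually similar to releasing PyTorch's cached GPU memory between executions, which in a standalone script would be done as follows. This is an illustration only; the backend performs the cleanup internally through the C++ API.

import torch

if torch.cuda.is_available():
    # Release cached, unused GPU memory held by PyTorch's caching
    # allocator back to the driver so that other libraries or processes
    # can use it. This adds overhead if performed after every execution.
    torch.cuda.empty_cache()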

PyTorch allows using multiple CPU threads during TorchScript model inference. One or more inference threads execute a model's forward pass on the given inputs. Each inference thread invokes a JIT interpreter that executes the ops of a model inline, one by one. The INTER_OP_THREAD_COUNT parameter sets the size of this thread pool. The default value of this setting is the number of CPU cores. Please refer to this document on how to set this parameter properly.

The section of the model config file specifying this parameter will look like:

parameters: {
    key: "INTER_OP_THREAD_COUNT"
    value: {
        string_value: "1"
    }
}
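
In a standalone PyTorch program the corresponding control is torch.set_num_interop_threads, shown below purely as an illustration of the knob this parameter tunes; the backend applies the config.pbtxt value for you.

import torch

# Inter-op parallelism: size of the thread pool used to run independent
# parts of the TorchScript graph concurrently. Must be set before any
# inter-op parallel work is started.
torch.set_num_interop_threads(1)
print(torch.get_num_interop_threads())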

In addition to the inter-op parallelism, PyTorch can also utilize multiple threads within individual ops (intra-op parallelism). This can be useful in many cases, including element-wise ops on large tensors, convolutions, GEMMs, embedding lookups, and others. The INTRA_OP_THREAD_COUNT parameter sets the size of this thread pool, and its default value is the number of CPU cores. Please refer to this document on how to set this parameter properly.

The section of the model config file specifying this parameter will look like:

parameters: {
    key: "INTRA_OP_THREAD_COUNT"
    value: {
        string_value: "1"
    }
}
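
Similarly, the standalone equivalent of this parameter is torch.set_num_threads, illustrated below.

import torch

# Intra-op parallelism: number of threads used inside individual ops
# such as GEMMs, convolutions, and element-wise kernels on large tensors.
torch.set_num_threads(4)
print(torch.get_num_threads())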

Support

Model Instance Group Kind

The PyTorch backend supports the following kinds of Model Instance Groups, where the input tensors are placed as follows:

Important Notes

PyTorch 2.0 Backend [Experimental]

Warning

This feature is subject to change and removal.

Starting from 24.01, PyTorch models can be served directly via the Python runtime. By default, Triton will use the LibTorch runtime for PyTorch models. To use the Python runtime, provide the following runtime setting in the model configuration:

Dependencies

Python backend dependency

This feature depends on the Python backend; see Python-based Backends for more details.

PyTorch dependency

This feature takes advantage of the torch.compile optimization. Make sure the PyTorch 2.0+ pip package is available in the same Python environment.

Alternatively, a Python Execution Environment with the PyTorch dependency may be used. It can be created with the provided script. The resulting pb_exec_env_model.py.tar.gz file should be placed in the same backend shared library directory as the Python runtime.

Model Layout

PyTorch 2.0 models

The model repository should look like:

model_repository/
`-- model_directory
    |-- 1
    |   |-- model.py
    |   `-- [model.pt]
    `-- config.pbtxt

The model.py file contains the class definition of the PyTorch model. The class should extend torch.nn.Module. The model.pt file may optionally be provided and contains the saved state_dict of the model.
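
A minimal model.py might look like the following. This is only a sketch; the class name and layer are placeholders, and model.pt is needed only if you want the backend to load a saved state_dict into the module.

# model_repository/model_directory/1/model.py
import torch


class MyModel(torch.nn.Module):
    # The file defines the PyTorch model as a torch.nn.Module subclass.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        return self.linear(input_tensor)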

TorchScript models

The model repository should look like:

model_repository/
`-- model_directory
    |-- 1
    |   `-- model.pt
    `-- config.pbtxt

The model.pt is the TorchScript model file.

Customization

The following PyTorch settings may be customized by setting parameters in the config.pbtxt.

torch.set_num_threads(int)

torch.set_num_interop_threads(int)

torch.compile() parameters

For example:

parameters: {
    key: "NUM_THREADS"
    value: { string_value: "4" }
}
parameters: {
    key: "TORCH_COMPILE_OPTIONAL_PARAMETERS"
    value: { string_value: "{\"disable\": true}" }
}
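
For reference, the JSON supplied through TORCH_COMPILE_OPTIONAL_PARAMETERS corresponds to keyword arguments of torch.compile; in plain PyTorch the example above would map roughly to the call below (illustrative sketch).

import torch

model = torch.nn.Linear(8, 2)

# {"disable": true} in the config corresponds to disable=True here,
# which turns torch.compile into a no-op; other keyword arguments such
# as "mode" or "backend" can be passed the same way.
compiled = torch.compile(model, disable=True)
print(compiled(torch.rand(1, 8)))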

Limitations

The following are a few known limitations of this feature: