Exporting LLMs

Instead of having to manually write code to call torch.export(), use ExecuTorch's assortment of lowering APIs, or interact with the TorchAO quantize_ APIs for quantization, we provide an out-of-the-box experience that performantly exports a selection of supported models to ExecuTorch.

Prerequisites#

The LLM export functionality requires the pytorch_tokenizers package. If you encounter a ModuleNotFoundError: No module named 'pytorch_tokenizers' error, install it from the ExecuTorch source code:

pip install -e ./extension/llm/tokenizers/

Supported Models#

The up-to-date list of supported LLMs can be found in the code here.

Note: If you need to export models that are not on this list, or other model architectures (such as Gemma, Mistral, BERT, T5, Whisper, etc.), see Exporting LLMs with Optimum, which supports a much wider variety of models from Hugging Face Hub.

The export_llm API#

export_llm is ExecuTorch’s high-level export API for LLMs. In this tutorial, we will focus on exporting Llama 3.2 1B using this API. export_llm’s arguments are specified either through CLI args or through a yaml configuration whose fields are defined in LlmConfig. To call export_llm:

python -m executorch.extension.llm.export.export_llm \
  --config <path/to/config.yaml> \
  +base.<additional CLI overrides>

Basic export#

To perform a basic export of Llama3.2, we will first need to download the checkpoint file (consolidated.00.pth) and params file (params.json). You can find these from the Llama website or Hugging Face.

Then, we specify the model_class, checkpoint (path to the checkpoint file), and params (path to the params file) as arguments. Additionally, when we later run the exported .pte with our runner APIs, the runner needs to know the bos and eos token ids for this model so it knows when to terminate. These are exposed through bos and eos getter methods in the .pte, which we add by specifying the bos and eos ids in the metadata argument. The values for these tokens can usually be found in the model's tokenizer_config.json on Hugging Face.

path/to/config.yaml

base:
  model_class: llama3_2
  checkpoint: path/to/consolidated.00.pth
  params: path/to/params.json
  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'

export_llm

python -m extension.llm.export.export_llm \
  --config path/to/config.yaml

We only require manually specifying a checkpoint path for the Llama model family, since it is our most optimized model family and has more advanced optimizations, such as SpinQuant, that require custom checkpoints.

For the other supported LLMs, the checkpoint will be downloaded from HuggingFace automatically, and the param files can be found in their respective directories under executorch/examples/models, for instance executorch/examples/models/qwen3/config/0_6b_config.json.
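For example, a minimal config for one of these models might look like the following. This is a hedged sketch: the exact model_class identifier for the Qwen3 0.6B variant is an assumption here, so take the real name from the supported-models list linked above.

base:
  model_class: qwen3_0_6b  # assumed identifier; check the supported-models list for the exact name
  params: executorch/examples/models/qwen3/config/0_6b_config.json
  # No checkpoint entry is needed; the weights are downloaded from Hugging Face automatically.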

Export settings#

ExportConfig contains settings for the exported .pte, such as max_seq_length (max length of the prompt) and max_context_length (max length of the model’s memory/cache).
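For example, here is a minimal sketch of these settings, assuming they live under an export section mirroring ExportConfig (the values shown are arbitrary examples):

...
export:
  max_seq_length: 512       # max length of the prompt
  max_context_length: 1024  # max length of the model's memory/cache
...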

Adding optimizations#

export_llm applies a variety of optimizations to the model before export, during export, and during lowering. Quantization and delegation to accelerator backends are the main ones and will be covered in the next two sections. All other optimizations can be found under ModelConfig. We will go ahead and add a few of them.

path/to/config.yaml

base:
  model_class: llama3_2
  checkpoint: path/to/consolidated.00.pth
  params: path/to/params.json
  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
model:
  use_kv_cache: True
  use_sdpa_with_kv_cache: True

export_llm

python -m extension.llm.export.export_llm \
  --config path/to/config.yaml

use_kv_cache and use_sdpa_with_kv_cache are recommended when exporting any LLM, while the other options in ModelConfig are useful situationally.
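For example, here is a hedged sketch of a model section that enables a couple of situational options; the quantize_kv_cache and enable_dynamic_shape field names are assumptions and should be checked against ModelConfig:

...
model:
  use_kv_cache: True
  use_sdpa_with_kv_cache: True
  quantize_kv_cache: True     # assumed field; quantizes the KV cache to reduce memory use
  enable_dynamic_shape: True  # assumed field; allows variable-length prompts
...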

Quantization#

Quantization options are defined by QuantizationConfig. ExecuTorch does quantization in two ways:

  1. TorchAO quantize_ API
  2. pt2e quantization

TorchAO (XNNPACK)#

TorchAO quantizes at the source code level, swapping out Linear modules for quantized Linear modules. **To quantize on the XNNPACK backend, this is the quantization path to follow.** The quantization modes are defined here.

A common one to use is 8da4w (8-bit dynamic activation, 4-bit weight quantization). The weight quantization group size is specified with the group_size option.

For Arm CPUs, there are also low-bit kernels for int8 dynamic activation + int[1-8] weight quantization. Note that these should not be used alongside XNNPACK; experimentally, we have found that their performance can sometimes even be better than the equivalent 8da4w. To use them, set qmode to the corresponding torchao-prefixed mode, for example torchao:8da4w.

To quantize embeddings, specify either embedding_quantize: <bitwidth>,<groupsize> (bitwidth here must be 2, 4, or 8), or for low-bit kernels use embedding_quantize: torchao:<bitwidth>,<groupsize> (bitwidth can be from 1-8).
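For instance, a quantization section targeting the Arm low-bit kernels could look roughly like the following; this is a sketch that assumes the torchao-prefixed forms described above and a group_size field in QuantizationConfig:

...
quantization:
  qmode: torchao:8da4w              # low-bit kernels: int8 dynamic activation, int4 weight
  group_size: 32                    # assumed field name for the weight group size
  embedding_quantize: torchao:4,32  # low-bit embedding quantization: 4-bit weights, group size 32
...

The full example below instead uses the standard XNNPACK 8da4w path: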

path/to/config.yaml

base:
  model_class: llama3_2
  checkpoint: path/to/consolidated.00.pth
  params: path/to/params.json
  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
model:
  use_kv_cache: True
  use_sdpa_with_kv_cache: True
quantization:
  embedding_quantize: 4,32
  qmode: 8da4w

export_llm

python -m extension.llm.export.export_llm \
  --config path/to/config.yaml

pt2e (QNN, CoreML, and Vulkan)#

pt2e quantizes at the post-export graph level, swapping nodes and injecting quant/dequant nodes. **To quantize on non-CPU backends (QNN, CoreML, Vulkan), this is the quantization path to follow.** Read more about pt2e here, and how ExecuTorch uses pt2e here.

CoreML and Vulkan support for export_llm is currently experimental and limited. To read more about QNN export, please read Running on Android (Qualcomm).
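As a rough, hedged sketch (the field names and values below are assumptions to verify against QuantizationConfig and BackendConfig), a pt2e-quantized QNN export might combine something like:

...
quantization:
  pt2e_quantize: qnn_16a4w  # assumed value; pt2e quantization with 16-bit activations, 4-bit weights
backend:
  qnn:
    enabled: True
...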

Backend support#

Backend options are defined by BackendConfig. Each backend has its own configuration options. Here is an example of lowering the LLM to XNNPACK for CPU acceleration:

path/to/config.yaml

base:
  model_class: llama3_2
  checkpoint: path/to/consolidated.00.pth
  params: path/to/params.json
  metadata: '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}'
model:
  use_kv_cache: True
  use_sdpa_with_kv_cache: True
quantization:
  embedding_quantize: 4,32
  qmode: 8da4w
backend:
  xnnpack:
    enabled: True
    extended_ops: True  # Expand the selection of ops delegated to XNNPACK.

export_llm

python -m extension.llm.export.export_llm \
  --config path/to/config.yaml

Profiling and Debugging#

To see which ops got delegated to the backend and which didn’t, specify verbose: True:

path/to/config.yaml

...
debug:
  verbose: True
...

export_llm

python -m extension.llm.export.export_llm \
  --config path/to/config.yaml

In the logs, there will be a table of all ops in the graph, and which ones were and were not delegated.

Here is an example:


Total delegated subgraphs: 368
Number of delegated nodes: 2588
Number of non-delegated nodes: 2513

To do further performance analysis, you may opt to use ExecuTorch's Developer Tools to do things such as trace individual operator performance back to source code, view memory planning, and debug intermediate activations. To generate the ETRecord that links the .pte program back to source code, you can use:

path/to/config.yaml

...
debug:
  generate_etrecord: True
...

export_llm

python -m extension.llm.export.export_llm \
  --config path/to/config.yaml

Other debug and profiling options can be found in DebugConfig.

To learn more about ExecuTorch’s Developer Tools, see the Introduction to the ExecuTorch Developer Tools.