
Custom Quantization#

Overview#

This document gives an overview of the customizable quantization feature in NxD Inference. Users can specify which modules should not be converted during quantization, enabling custom quantized model inference. Users can take an un-quantized model and apply selective quantization to specific layers while keeping others in full precision.

The document also explains how to use external libraries like llmcompressor, including how to set up the quantization config and apply the necessary patches. It also covers running inference with quantized models and specifying unconverted modules through either command-line arguments or NeuronConfig kwargs.

Quantization#

Custom quantization gives users fine-grained control over which layers of the model are quantized. This can be particularly useful for maintaining model accuracy while still benefiting from the reduced memory footprint of quantization. For more detailed information on quantization techniques and implementation, please refer to the quantization feature guide.
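In practice, the exclusion is expressed as a list of module names passed as modules_to_not_convert. A minimal sketch is shown below; the module names are illustrative and should match the module names in your model.

# Illustrative only: keep the output head and the first attention block in
# full precision while the remaining layers are quantized.
modules_to_not_convert = ["lm_head", "layers.0.self_attn"]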

Quantize Using NxD#

Quantization can significantly reduce the model size and inference time, making it more suitable for deployment of large models that typically cannot fit on a single instance. However, not all layers of the model benefit equally from quantization.

To leverage the customizable quantization feature in NxD, follow the steps below. This process involves importing the necessary libraries, defining the model and output paths, specifying the modules not to convert, and applying a quantization function to produce a quantized state dict.

import torch
from typing import Optional, List

from transformers import AutoModelForCausalLM, AutoTokenizer

from neuronx_distributed_inference.modules.checkpoint import (
    prune_state_dict,
    save_state_dict_safetensors,
)
from neuronx_distributed.quantization.quantization_utils import (
    quantize_pytorch_model_per_channel_symmetric,
    convert_qint8_to_int8_state_dict,
)

model_path = "/<model_path>/llama-3.1-405b-instruct-4layers/"
output_path = "<output_path>"

modules_to_not_convert = [
    "lm_head",
    "layers.0.self_attn",
    "layers.1.self_attn",
    "layers.2.self_attn",
    "layers.1.mlp",
]


def quantize(
    model: torch.nn.Module,
    dtype=torch.qint8,
    modules_to_not_convert: Optional[List[str]] = None,
) -> dict:
    quant_model = quantize_pytorch_model_per_channel_symmetric(
        model, dtype=dtype, modules_to_not_convert=modules_to_not_convert
    )
    model_quant_sd = quant_model.state_dict()
    convert_qint8_to_int8_state_dict(model_quant_sd)
    quantized_state_dict = prune_state_dict(model_quant_sd)
    return quantized_state_dict


model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

state_dict = quantize(model, torch.float8_e4m3fn, modules_to_not_convert)

save_state_dict_safetensors(state_dict=state_dict, state_dict_dir=output_path)
tokenizer.save_pretrained(output_path)
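As a quick sanity check, you can inspect the state dict produced above before saving it. This snippet is not part of NxD; it only prints the dtype of each tensor so you can confirm that the excluded modules remain in full precision.

# Sketch: tensors belonging to modules listed in modules_to_not_convert should
# keep their original dtype, while the rest carry quantized weights and scales.
for name, tensor in state_dict.items():
    excluded = any(module in name for module in modules_to_not_convert)
    print(f"{name}: dtype={tensor.dtype}, excluded={excluded}")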

Quantize Using External Libraries#

In addition to the built-in quantization features of NxD, users can also leverage external libraries for more flexible and advanced quantization options. One such library is llmcompressor, which offers a robust set of tools for quantizing models. To use the llmcompressor library for quantization, follow the steps below.

This process involves importing the necessary libraries, specifying the modules not to convert, setting up a quantization recipe, and applying the quantization to create a quantized model. llmcompressor quantizes over the full float8 range of -/+448, so it is important to restrict the scale range to -/+240 if you need to run inference on the quantized model later using NxD Inference: values outside the -/+240 range result in NaNs on Neuron devices.
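For reference, the -/+448 limit comes from the float8_e4m3fn format itself. The quick check below is plain PyTorch and is not specific to llmcompressor.

import torch

# float8_e4m3fn can represent values up to +/-448, but Neuron devices expect
# quantized values within +/-240, hence the rescaling patch shown below.
print(torch.finfo(torch.float8_e4m3fn).max)   # 448.0
print(torch.finfo(torch.float8_e4m3fn).min)   # -448.0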

The LLaMA model is an example where not all layers are quantized.

import torch
from llmcompressor.transformers import oneshot, SparseAutoModelForCausalLM
from transformers import AutoTokenizer
from compressed_tensors.quantization.utils.helpers import calculate_range
from compressed_tensors.quantization.quant_args import QuantizationType
import compressed_tensors.quantization.utils.helpers as helpers

model_path = "/<model_path>/llama-3.1-405b-instruct-4layers/"
output_path = "<output_path>"

modules_to_not_convert = [
    "lm_head",
    "model.layers.0.mlp.down_proj",
    "model.layers.0.mlp.gate_proj",
    "model.layers.0.mlp.up_proj",
    "model.layers.3.mlp.down_proj",
    "model.layers.3.mlp.gate_proj",
    "model.layers.3.mlp.up_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.o_proj",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.1.self_attn.k_proj",
    "model.layers.1.self_attn.o_proj",
    "model.layers.1.self_attn.q_proj",
    "model.layers.1.self_attn.v_proj",
    "model.layers.2.self_attn.k_proj",
    "model.layers.2.self_attn.o_proj",
    "model.layers.2.self_attn.q_proj",
    "model.layers.2.self_attn.v_proj",
    "model.layers.3.self_attn.k_proj",
    "model.layers.3.self_attn.o_proj",
    "model.layers.3.self_attn.q_proj",
    "model.layers.3.self_attn.v_proj",
]

recipe = f"""
quant_stage:
    quant_modifiers:
        QuantizationModifier:
            ignore: {modules_to_not_convert}
            config_groups:
                group_0:
                    weights:
                        num_bits: 8
                        type: float
                        strategy: channel
                        dynamic: false
                        symmetric: true
                    input_activations:
                        num_bits: 8
                        type: float
                        strategy: token
                        dynamic: true
                        symmetric: true
                    targets: ["Linear"]
"""

model = SparseAutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto")

# Monkey patch to rescale weights from -/+448 to -/+240
original_calculate_range = helpers.calculate_range


def calculate_range(*args, **kwargs):
    q_min, q_max = original_calculate_range(*args, **kwargs)
    if args[0].type == QuantizationType.FLOAT and args[0].num_bits == 8:
        return torch.tensor(-240.0, device=args[1]), torch.tensor(240.0, device=args[1])
    return q_min, q_max


# Patch it
helpers.calculate_range = calculate_range

oneshot(model=model, recipe=recipe)

# Cast weight scales to float32
for name, module in model.named_modules():
    if hasattr(module, "weight_scale"):
        module.weight_scale.data = module.weight_scale.data.to(torch.float32)

tokenizer = AutoTokenizer.from_pretrained(model_path)

model.save_pretrained(output_path)
tokenizer.save_pretrained(output_path)
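After exporting, you can optionally confirm that the ignore list was recorded in the checkpoint's quantization config. This is a sketch; the field names assume the compressed-tensors config format written by llmcompressor.

import json
import os

# Sketch: read back the exported config and print the quantization ignore list.
with open(os.path.join(output_path, "config.json")) as f:
    config = json.load(f)
print(config.get("quantization_config", {}).get("ignore", []))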

Quantization Commands#

To utilize the quantization commands in NxD Inference, users can follow the instructions below. These commands cover the required flags to enable running inference with quantized models.

First Quantize, Then Run Inference#

If you have a model in full precision and need to quantize it on the CPU first before using it for inference, you can set the following flags to enable quantization during inference:

inference_demo --model-type llama --task-type causal-lm run \
    --model-path /your_model_path/ \
    --compiled-model-path /save_to_path/ \
    --torch-dtype bfloat16 \
    --tp-degree 32 \
    --batch-size 1 \
    --max-context-length 1024 \
    --quantized \
    --quantization-dtype f8e4m3 \
    --quantization-type per_channel_symmetric \
    --quantized-checkpoints-path /save_to_path/ \
    --seq-len 2048 \
    --fused-qkv \
    --pad-token-id 2 \
    --on-device-sampling \
    --sequence-parallel-enabled \
    --attn-kernel-enabled \
    --prompt "I believe the meaning of life is" \
    --is-continuous-batching \
    --enable-fused-speculation \
    --enable-eagle-speculation \
    --speculation-length 4 \
    --draft-model-path /your_draft_model_path \
    --modules-to-not-convert-file /path/modules_to_not_convert.json

Inference Using an Already Quantized Checkpoint#

To run inference with an already quantized checkpoint, set the flags shown below. The modules-to-not-convert-file flag lets you specify the list of modules that should not be quantized, which is useful for models that explicitly require some modules to remain in their original precision.

How to Use#

inference_demo --model-type llama --task-type causal-lm run \
    --model-path <model_path> \
    --compiled-model-path <compiled_model_path> \
    --torch-dtype bfloat16 \
    --tp-degree <tp_degree> \
    --batch-size <batch_size> \
    --max-context-length <max_context_length> \
    --seq-len <seq_len> \
    --on-device-sampling \
    --mlp-kernel-enabled \
    --quantized-mlp-kernel-enabled \
    --quantization-dtype <quantization_dtype> \
    --quantization-type <quantization_type> \
    --prompt "I believe the meaning of life is" \
    --modules-to-not-convert-file /<path>/modules_to_not_convert.json

You can also specify the unconverted modules through NeuronConfig kwargs:

neuron_config = NeuronConfig(
    tp_degree=32,
    batch_size=2,
    max_context_length=32,
    seq_len=64,
    on_device_sampling_config=OnDeviceSamplingConfig(top_k=1),
    enable_bucketing=True,
    flash_decoding_enabled=False,
    modules_to_not_convert=["lm_head", "layers.0.self_attn", "layers.1.mlp", ...],
    draft_model_modules_to_not_convert=["lm_head", "layers.0.self_attn", "layers.1.mlp", ..., "fc"],
)

Note: If you are creating separate NeuronConfig objects for the draft and target models, you only need to pass the modules_to_not_convert list in each config.
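For example, a minimal sketch with separate configs for the target and draft models (values and module names are illustrative; other kwargs omitted):

# Illustrative sketch: each config only needs modules_to_not_convert.
target_config = NeuronConfig(
    tp_degree=32,
    batch_size=2,
    seq_len=64,
    modules_to_not_convert=["lm_head", "layers.0.self_attn", "layers.1.mlp"],
)
draft_config = NeuronConfig(
    tp_degree=32,
    batch_size=2,
    seq_len=64,
    modules_to_not_convert=["lm_head", "layers.0.self_attn", "layers.0.mlp", "fc"],
)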

JSON File Structure#

The JSON file specifies which modules should not be converted during quantization when you use the inference demo. This section provides examples of how to format the file. The structure depends on whether fused speculation is used.

  1. Basic Structure

For simple cases:

{ "modules_to_not_convert": [ "lm_head", "layers.0.self_attn", "layers.1.self_attn", "layers.2.self_attn", "layers.3.self_attn", "layers.0.mlp", "layers.3.mlp" ]}

OR#

{ "model": { "modules_to_not_convert": [ "lm_head", "layers.0.self_attn", "layers.1.self_attn", "layers.2.self_attn", "layers.3.self_attn", "layers.0.mlp", "layers.3.mlp" ] }}

  2. With Fused Speculation

{ "model": { "modules_to_not_convert": [ "lm_head", "layers.0.self_attn", "layers.1.self_attn", "layers.2.self_attn", "layers.3.self_attn", "layers.0.mlp", "layers.3.mlp" ] }, "draft_model": { "modules_to_not_convert": [ "lm_head", "layers.0.self_attn", "layers.0.mlp", "fc" ] }}

Important Notes#

Backward Incompatible Changes:#