model_quant — Model Optimizer 0.27.1

TensorRT Model Optimizer

User-facing quantization API.

Functions

calibrate – Adjusts weights and scaling factors based on selected algorithms.
postprocess_amax – Experimental API to postprocess the amax values after calibration.
quantize – Quantizes and calibrates the model in-place.
auto_quantize – API for AutoQuantize, which quantizes a model by searching for the best quantization format per layer.
disable_quantizer – Disable quantizer by wildcard or filter function.
enable_quantizer – Enable quantizer by wildcard or filter function.
print_quant_summary – Print a summary of all quantizer modules in the model.
fold_weight – Fold weight quantizer for fast evaluation.

auto_quantize(model, constraints={'effective_bits': 4.8}, quantization_formats=['NVFP4_DEFAULT_CFG', 'FP8_DEFAULT_CFG', None], data_loader=None, forward_step=None, loss_func=None, forward_backward_step=None, disabled_layers=None, num_calib_steps=512, num_score_steps=128, verbose=False)

API for AutoQuantize, which quantizes a model by searching for the best quantization format per layer.

auto_quantize uses a gradient-based sensitivity score to rank the candidate quantization formats and searches for the best format for each layer.

Parameters:

constraints – The search constraint. For example, for 4.8 effective quantization bits:

    constraints = {"effective_bits": 4.8}

quantization_formats – The quantization format candidates to search over, where None stands for no quantization. For example, to search for the best per-layer quantization among FP8, W4A8_AWQ, or no quantization:

    quantization_formats = ["FP8_DEFAULT_CFG", "W4A8_AWQ", None]

forward_step – Takes the model and a batch of data as input and returns the model output:

    def forward_step(model, batch) -> torch.Tensor:
        output = model(batch)
        return output

loss_func – Takes the model output and a batch of data as input and returns the loss. The loss should be a scalar tensor such that loss.backward() can be called:

    def loss_func(output, batch) -> torch.Tensor:
        ...
        return loss

    loss = loss_func(output, batch)
    loss.backward()

If this argument is not provided, forward_backward_step should be provided.

forward_backward_step – Takes the model and a batch of data as input and runs the forward and backward pass:

    def forward_backward_step(model, batch) -> None:
        output = model(batch)
        loss = my_loss_func(output, batch)
        run_custom_backward(loss)

If this argument is not provided, loss_func should be provided.

Returns: A tuple (model, state_dict) where model is the searched and quantized model and state_dict contains the history and detailed stats of the search procedure.

Note

auto_quantize groups certain layers and restricts the quantization formats for them to be the same. For example, Q, K, V linear layers belonging to the same transformer layer will have the same quantization format. This is to ensure compatibility with TensorRT-LLM, which fuses these three linear layers into a single linear layer.

A list of regex pattern rules as defined in rules are used to specify the group of layers. The first captured group in the regex pattern (i.e., pattern.match(name).group(1)) is used to group the layers. All the layers that share the same first captured group will have the same quantization format.

For example, the rule r"^(.*?)\.(q_proj|k_proj|v_proj)$" groups the q_proj, k_proj, v_proj linear layers belonging to the same transformer layer.

You may modify the rules to group the layers as per your requirement.

    from modelopt.torch.quantization.algorithms import AutoQuantizeSearcher

    # To additionally group the layers belonging to the same mlp layer,
    # add the following rule
    AutoQuantizeSearcher.rules.append(r"^(.*?)\.mlp")

    # Perform auto_quantize
    model, state_dict = auto_quantize(model, ...)

Note

The auto_quantize API and algorithm are experimental and subject to change. Models searched with auto_quantize might not be readily deployable to TensorRT-LLM yet.
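
For orientation, here is a minimal end-to-end sketch of an auto_quantize call assembled from the pieces above. The model and the calibration data_loader are hypothetical placeholders, and loss_func assumes an output object exposing a scalar .loss; adapt both to your setup:

    import modelopt.torch.quantization as mtq

    model = ...  # your PyTorch model (placeholder)
    data_loader = ...  # your calibration data loader (placeholder)

    def forward_step(model, batch):
        # Forward the batch through the model and return the output.
        return model(batch)

    def loss_func(output, batch):
        # Assumes an output object with a scalar .loss attribute; adapt as needed.
        return output.loss

    model, state_dict = mtq.auto_quantize(
        model,
        constraints={"effective_bits": 4.8},
        quantization_formats=["FP8_DEFAULT_CFG", "W4A8_AWQ", None],
        data_loader=data_loader,
        forward_step=forward_step,
        loss_func=loss_func,
        num_calib_steps=512,
        num_score_steps=128,
        verbose=True,
    )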

calibrate(model, algorithm='max', forward_loop=None)

Adjusts weights and scaling factors based on selected algorithms.

Parameters:

Return type:

Module

Returns: The calibrated PyTorch model.
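
A minimal usage sketch with the default "max" algorithm; calib_loader is a hypothetical stand-in for your calibration data:

    import modelopt.torch.quantization as mtq

    def forward_loop(model):
        # Forward a small calibration set through the model to gather statistics.
        for batch in calib_loader:  # hypothetical calibration data loader
            model(batch)

    model = mtq.calibrate(model, algorithm="max", forward_loop=forward_loop)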

disable_quantizer(model, wildcard_or_filter_func)

Disable quantizer by wildcard or filter function.

Parameters:

enable_quantizer(model, wildcard_or_filter_func)

Enable quantizer by wildcard or filter function.

Parameters:
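
A short sketch of toggling quantizers with a wildcard pattern; the "*lm_head*" pattern is purely illustrative:

    import modelopt.torch.quantization as mtq

    # Disable all quantizers whose names match the wildcard,
    # e.g. to keep the LM head in higher precision.
    mtq.disable_quantizer(model, "*lm_head*")

    # Re-enable the same quantizers later if needed.
    mtq.enable_quantizer(model, "*lm_head*")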

fold_weight(model)

Fold weight quantizer for fast evaluation.

Parameters:

model (Module) –
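
A usage sketch, assuming model has already been quantized and calibrated; per the summary above, folding the weight quantizers is intended for fast evaluation:

    import modelopt.torch.quantization as mtq

    # model is assumed to be quantized and calibrated already.
    mtq.fold_weight(model)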

postprocess_amax(model, key, post_process_fn)

Experimental API to postprocess the amax values after calibration.

Parameters:

Return type:

Module
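
A hedged sketch of postprocessing amax values after calibration; the "*input_quantizer" key pattern and the clamping function are illustrative assumptions, not prescribed values:

    import torch
    import modelopt.torch.quantization as mtq

    # Clamp very small amax values (illustrative post-processing).
    model = mtq.postprocess_amax(
        model, "*input_quantizer", lambda amax: torch.clamp(amax, min=1e-4)
    )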

print_quant_summary(model)

Print summary of all quantizer modules in the model.

Parameters:

model (Module) –
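
A one-line usage sketch, assuming a model that has already been quantized:

    import modelopt.torch.quantization as mtq

    # Prints a summary of all quantizer modules in the model.
    mtq.print_quant_summary(model)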

quantize(model, config, forward_loop=None)

Quantizes and calibrates the model in-place.

This method performs replacement of modules with their quantized counterparts and performs calibration as specified by quant_cfg. forward_loop is used to forward data through the model and gather statistics for calibration.

Parameters:

Example usages of forward_loop:

Example 2:

    def forward_loop(model) -> float:
        # evaluate the model on the task
        return evaluate(model, task, ....)

Example 3:

    def forward_loop(model) -> None:
        # run evaluation pipeline
        evaluator.model = model
        evaluator.evaluate()

Note

Calibration does not require forwarding the entire dataset through the model. Please subsample the dataset or reduce the number of batches if needed.

Return type:

Module

Returns: A PyTorch model which has been quantized and calibrated.
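
Putting the pieces together, a minimal calibration-style sketch of quantize; FP8_DEFAULT_CFG is one of the config names mentioned above, while calib_loader is a hypothetical data loader:

    import modelopt.torch.quantization as mtq

    def forward_loop(model):
        # Forward a subsampled calibration set through the model (see Note above).
        for batch in calib_loader:  # hypothetical calibration data loader
            model(batch)

    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)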