utils — Model Optimizer 0.29.0

Utility functions for prune-related and search-space related tasks.

Note

Generally, methods in the modelopt.torch.nas module should use these utility functions directly instead of accessing the SearchSpace class. This is to ensure that potentially required pre- and post-processing operations are performed correctly.

Classes

enable_modelopt_patches – Context manager to enable modelopt patches such as those for autonas/fastnas.
no_modelopt_patches – Context manager to disable modelopt patches to the model.
set_modelopt_patches_enabled – Context manager that sets patches to on or off.

Functions

inference_flops – Get the inference FLOPs of a PyTorch model.
print_search_space_summary – Print the search space summary.
get_subnet_config – Return the config dict of all hyperparameters.
sample – Sample searchable hparams using the provided sample_func and return the resulting config.
select – Select the sub-net according to the provided config dict.
is_modelopt_patches_enabled – Check if modelopt patches for the model are enabled.
replace_forward – Context manager to temporarily replace the forward method of the underlying type of a model.

class enable_modelopt_patches

Bases: _DecoratorContextManager

Context manager to enable modelopt patches such as those for autonas/fastnas.

It can also be used as a decorator (make sure to instantiate with parentheses).

For example:

    modelopt_model.train()
    modelopt_model(inputs)  # architecture changes

    with mtn.no_modelopt():
        with mtn.enable_modelopt():
            modelopt_model(inputs)  # architecture changes

    @mtn.enable_modelopt()
    def forward(model, inputs):
        return model(inputs)

    with mtn.no_modelopt():
        forward(modelopt_model, inputs)  # architecture changes because of decorator on forward

__init__()

Constructor.

get_subnet_config(model, configurable=None)

Return the config dict of all hyperparameters.

Parameters:

Returns:

A dict of (parameter_name, choice) that specifies an active subnet.

Return type:

dict[str, Any]

inference_flops(network, dummy_input=None, data_shape=None, unit=1000000.0, return_str=False)

Get the inference FLOPs of a PyTorch model.

Parameters:

Returns:

The number of inference FLOPs in the given unit as either string or float.

Return type:

float | str
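The unit and return_str parameters control how the raw FLOP count is reported. The sketch below is illustrative only (format_flops is a hypothetical helper, not part of the library): it shows how a count might be normalized by unit and optionally rendered as a string, using the rough estimate of 2 * in_features * out_features FLOPs per sample for a single Linear layer.

```python
# Illustrative sketch (not the library implementation): normalizing a FLOP
# count by `unit` and optionally formatting it as a string, mirroring the
# inference_flops(..., unit=1e6, return_str=False) signature above.
def format_flops(total_flops, unit=1e6, return_str=False):
    value = total_flops / unit
    if return_str:
        suffix = {1e3: "K", 1e6: "M", 1e9: "G"}.get(unit, "")
        return f"{value:.2f}{suffix}"
    return value

# A Linear layer with in_features=128 and out_features=64 performs roughly
# 2 * 128 * 64 multiply-accumulate FLOPs per sample.
flops = 2 * 128 * 64
print(format_flops(flops, unit=1e6))                    # float, in MFLOPs
print(format_flops(flops, unit=1e3, return_str=True))   # "16.38K"
```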

is_modelopt_patches_enabled()

Check if modelopt patches for model are enabled.

Return type:

bool

class no_modelopt_patches

Bases: _DecoratorContextManager

Context manager to disable modelopt patches to the model.

Disabling modelopt patches is useful when you want to use the model’s original behavior. For example, you can use this to perform a forward pass without NAS operations.

It can also be used as a decorator (make sure to instantiate with parentheses).

For example:

    modelopt_model.train()
    modelopt_model(inputs)  # architecture changes

    with mtn.no_modelopt():
        modelopt_model(inputs)  # architecture does not change

    @mtn.no_modelopt()
    def forward(model, inputs):
        return model(inputs)

    forward(modelopt_model, inputs)  # architecture does not change

__init__()

Constructor.

print_search_space_summary(model, skipped_hparams=['kernel_size'])

Print the search space summary.

Parameters:

Return type:

None

replace_forward(model, new_forward)

Context manager to temporarily replace the forward method of the underlying type of a model.

The original forward function is temporarily accessible via model.forward_original.

Parameters:

Return type:

Iterator[None]

For example:

    fake_forward = lambda _: None

    with replace_forward(model, fake_forward):
        out = model(inputs)  # this output is None

    out_original = model(inputs)  # this output is the original output
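The behavior above can be sketched with contextlib in pure Python (replace_forward_sketch and TinyModel are hypothetical names, not the library code), assuming forward is swapped on type(model) and the original is saved as forward_original for the duration of the context:

```python
from contextlib import contextmanager

# Minimal sketch of the replace_forward semantics described above:
# swap `forward` on the model's type, expose the original as
# `forward_original`, and restore everything on exit.
@contextmanager
def replace_forward_sketch(model, new_forward):
    cls = type(model)
    original = cls.forward
    cls.forward_original = original
    cls.forward = new_forward
    try:
        yield
    finally:
        cls.forward = original
        del cls.forward_original

class TinyModel:
    def forward(self, x):
        return x + 1
    def __call__(self, x):
        return self.forward(x)

model = TinyModel()
with replace_forward_sketch(model, lambda self, x: None):
    assert model(3) is None                 # patched forward
    assert model.forward_original(3) == 4   # original still reachable
assert model(3) == 4                        # restored afterwards
```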

sample(model, sample_func=)

Sample searchable hparams using the provided sample_func and return resulting config.

Parameters:

Returns:

A dict of (parameter_name, choice) that specifies an active subnet.

Return type:

dict[str, Any]

select(model, config, strict=True)

Select the sub-net according to the provided config dict.

Parameters:

Return type:

None
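The sample/select contract above can be illustrated in pure Python (all names below are hypothetical; the real functions operate on a converted model, not a plain dict): sample draws one choice per searchable hyperparameter via sample_func and returns a (parameter_name, choice) dict, and select applies such a dict, with strict=True rejecting unknown names.

```python
import random

# Pure-Python sketch of the sample/select round trip described above
# (not the library implementation).
search_space = {
    "layers.0.out_channels": [16, 32, 64],
    "layers.1.kernel_size": [3, 5],
}

def sample_sketch(space, sample_func=random.choice):
    # One choice per searchable hyperparameter.
    return {name: sample_func(choices) for name, choices in space.items()}

def select_sketch(space, config, strict=True):
    # Apply a (parameter_name, choice) dict; strict rejects unknown names.
    active = {}
    for name, choice in config.items():
        if name not in space:
            if strict:
                raise KeyError(f"unknown hyperparameter: {name}")
            continue
        assert choice in space[name], f"invalid choice for {name}"
        active[name] = choice
    return active

config = sample_sketch(search_space, sample_func=min)  # deterministic: smallest choice
active = select_sketch(search_space, config)
```

Using min as sample_func yields the smallest subnet, a common trick for deterministic sanity checks.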

class set_modelopt_patches_enabled

Bases: _DecoratorContextManager

Context manager that sets patches to on or off.

It can be used as a context manager or as a function. When used as a function, the setting is applied globally (thread-local).

Parameters:

enabled – whether to enable (True) or disable (False) patched methods.

For example:

    modelopt_model.train()
    modelopt_model(inputs)  # architecture changes

    mtn.set_modelopt_enabled(False)
    modelopt_model(inputs)  # architecture does not change

    with mtn.set_modelopt_enabled(True):
        modelopt_model(inputs)  # architecture changes

    modelopt_model(inputs)  # architecture does not change

__init__(enabled)

Constructor.

Parameters:

enabled (bool) – Whether to enable (True) or disable (False) patched methods.

clone()

Clone the context manager.
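All three patch-control classes share the _DecoratorContextManager pattern (familiar from torch.no_grad): one thread-local flag, toggled by a class that works both as a context manager and as a decorator. The sketch below (all names hypothetical) shows a simplified version of that pattern, not the library code:

```python
import functools
import threading

# Thread-local flag mirroring the "globally (thread local)" behavior
# described for set_modelopt_patches_enabled.
_state = threading.local()

def patches_enabled():
    return getattr(_state, "enabled", True)

class set_enabled_sketch:
    """Usable as `with set_enabled_sketch(False): ...` or as @set_enabled_sketch(False)."""

    def __init__(self, enabled):
        self.enabled = enabled

    def __enter__(self):
        self.prev = patches_enabled()
        _state.enabled = self.enabled

    def __exit__(self, *exc):
        _state.enabled = self.prev

    def __call__(self, fn):
        # Decorator support: wrap the call in a fresh context manager each
        # time, analogous to _DecoratorContextManager.clone().
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            with type(self)(self.enabled):
                return fn(*args, **kwargs)
        return wrapper

assert patches_enabled()
with set_enabled_sketch(False):
    assert not patches_enabled()
assert patches_enabled()  # restored on exit

@set_enabled_sketch(False)
def check():
    return patches_enabled()

assert check() is False
```

Creating a fresh instance inside the wrapper (rather than reusing self) is what makes the decorator safe to call repeatedly and from multiple threads, which is why the real class exposes clone().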