network — Model Optimizer 0.27.1
Utility functions for PyTorch models.
Functions
Function | Description |
---|---|
compare_dict | Compare two dictionaries and return keys with unmatched values. |
get_model_attributes | Get the key attributes of a PyTorch model. |
get_module_device | Get the device of a PyTorch module. |
get_same_padding | Get the same padding for a given kernel size. |
init_model_from_model_like | Initialize a model from a model-like object. |
is_channels_last | Check if the model is using channels last memory format. |
is_parallel | Check if a PyTorch model is parallelized. |
make_divisible | Function taken from the original tf repo. |
model_to | Convert model to the same device, dtype and memory layout as the target_model. |
param_num | Get the number of parameters of a PyTorch model. |
param_num_from_forward | Get the number of parameters of a PyTorch model from a forward pass. |
remove_bn | Remove all batch normalization layers in the network. |
run_forward_loop | Run multiple forward passes with a model according to the provided data loader. |
set_submodule | The set function that complements nn.Module.get_submodule(). |
standardize_model_args | Standardize model arguments according to torch.onnx.export. |
standardize_model_like_tuple | Standardize a model-like tuple. |
standardize_named_model_args | Standardize model arguments according to torch.onnx.export and give them a name. |
standardize_constructor_args | Standardize a constructor-like tuple. |
unwrap_model | Unwrap a model that is wrapped by a supported wrapper module or return the original model. |
zero_grad | Set any gradients in the model's parameters to None. |
create_param_grad_clear_hook | Create a hook to clear gradients for a parameter. |
get_unwrapped_name | Get the cleaned module name (i.e., the name before wrapping with sharded modules). |
compare_dict(dict1, dict2)
Compare two dictionaries and return keys with unmatched values.
Parameters:
- dict1 (dict [ str , Any ]) –
- dict2 (dict [ str , Any ]) –
Return type:
_tuple_[str, …]
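The comparison can be illustrated with a minimal sketch. This is not the library's implementation: the name `compare_dict_sketch` is hypothetical, and the sketch assumes that a key present in only one dictionary counts as unmatched and that values can be compared with `!=` (which would not hold for tensor values).

```python
from __future__ import annotations

from typing import Any


def compare_dict_sketch(dict1: dict[str, Any], dict2: dict[str, Any]) -> tuple[str, ...]:
    """Return the keys whose values differ between the two dictionaries."""
    missing = object()  # sentinel for keys absent from one side (assumption)
    keys = sorted(set(dict1) | set(dict2))
    return tuple(k for k in keys if dict1.get(k, missing) != dict2.get(k, missing))


if __name__ == "__main__":
    print(compare_dict_sketch({"a": 1, "b": 2}, {"a": 1, "b": 3}))
```

Sorting the keys is a choice made here only to keep the result deterministic.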
create_param_grad_clear_hook(param)
Create a hook to clear gradients for a parameter.
The hook will be fired after the gradient is accumulated for the parameter. Important: For this to work, accum_grad should be kept alive as long as this utility is needed.
get_model_attributes(model)
Get the key attributes of a PyTorch model.
Parameters:
model (Module) –
Return type:
_dict_[str, _Any_]
get_module_device(module)
Get the device of a PyTorch module.
Parameters:
module (Module) –
Return type:
device
get_same_padding(kernel_size)
Get the same padding for a given kernel size.
Parameters:
kernel_size (int | tuple [ int , int ]) –
Return type:
int | tuple
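As a rough illustration of what "same" padding means: for stride 1 and dilation 1, an odd kernel size `k` needs `k // 2` padding on each side to preserve the spatial size. The sketch below assumes exactly that case; the name `same_padding_sketch` is hypothetical and the library function may validate or handle inputs differently.

```python
def same_padding_sketch(kernel_size):
    """Return per-side padding that preserves spatial size (odd kernel, stride 1)."""
    if isinstance(kernel_size, tuple):
        # Handle each spatial dimension independently.
        return tuple(same_padding_sketch(k) for k in kernel_size)
    if not isinstance(kernel_size, int) or kernel_size <= 0 or kernel_size % 2 == 0:
        raise ValueError("this sketch only handles positive odd integer kernel sizes")
    return kernel_size // 2


if __name__ == "__main__":
    print(same_padding_sketch(3), same_padding_sketch((3, 5)))
```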
get_unwrapped_name(name)
Get the cleaned module name (i.e., the name before wrapping with sharded modules).
Parameters:
name (str) –
Return type:
str
init_model_from_model_like(model)
Initialize a model from a model-like object.
Parameters:
model (Module | type [ Module ] | tuple | Callable) – A model-like object. Can be an nn.Module (returned as is), a model class or callable, or a tuple. If a tuple, it must be of the form (model_cls_or_callable,) or (model_cls_or_callable, args) or (model_cls_or_callable, args, kwargs). The model will be initialized as model_cls_or_callable(*args, **kwargs).
Return type:
Module
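The accepted tuple forms can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: `standardize_model_like_sketch`, `init_from_model_like_sketch`, and `TinyModel` are hypothetical names, and `TinyModel` stands in for an nn.Module subclass so that the example runs without PyTorch.

```python
def standardize_model_like_sketch(model_like):
    """Normalize a model-like object to (callable, args, kwargs)."""
    if not isinstance(model_like, tuple):
        model_like = (model_like,)
    if not 1 <= len(model_like) <= 3:
        raise ValueError("expected a tuple of length 1, 2 or 3")
    cls_or_callable = model_like[0]
    args = tuple(model_like[1]) if len(model_like) > 1 else ()
    kwargs = dict(model_like[2]) if len(model_like) > 2 else {}
    return cls_or_callable, args, kwargs


def init_from_model_like_sketch(model_like, instance_type):
    """Return instances unchanged; otherwise call model_cls_or_callable(*args, **kwargs)."""
    if isinstance(model_like, instance_type):
        return model_like
    cls_or_callable, args, kwargs = standardize_model_like_sketch(model_like)
    return cls_or_callable(*args, **kwargs)


class TinyModel:  # hypothetical stand-in for an nn.Module subclass
    def __init__(self, width=8, depth=1):
        self.width, self.depth = width, depth


if __name__ == "__main__":
    model = init_from_model_like_sketch((TinyModel, (16,), {"depth": 3}), TinyModel)
    print(model.width, model.depth)
```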
is_channels_last(model)
Check if the model is using channels last memory format.
Parameters:
model (Module) –
is_parallel(model)
Check if a PyTorch model is parallelized.
Parameters:
model (Module) –
Return type:
bool
make_divisible(v, divisor, min_val=None)
Function taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8. It can be seen here: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
Parameters:
- v (int | float) –
- divisor (int | None) –
Return type:
int | float
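For orientation, the helper in the linked TensorFlow MobileNet code rounds a channel count to the nearest multiple of the divisor and then bumps it up if rounding lost more than 10%. The sketch below follows that logic as a reconstruction; the name `make_divisible_sketch` is hypothetical, it assumes an integer `divisor`, and the library function may differ in details such as the return type.

```python
def make_divisible_sketch(v, divisor, min_val=None):
    """Round `v` to a multiple of `divisor`, never dropping below 90% of `v`."""
    if min_val is None:
        min_val = divisor
    # Round to the nearest multiple of `divisor`, but not below `min_val`.
    new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not reduce the value by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


if __name__ == "__main__":
    print(make_divisible_sketch(37, 8), make_divisible_sketch(10, 8))
```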
model_to(model, target_model)
Convert model to the same device, dtype and memory layout as the target_model.
Parameters:
- model (Module) –
- target_model (Module) –
param_num(network, trainable_only=False, unit=1000000.0)
Get the number of parameters of a PyTorch model.
Parameters:
- network (Module) – The PyTorch model.
- trainable_only (bool) – Whether to only count trainable parameters. Default is False.
- unit – The unit to return the number of parameters in. Default is 1e6 (million).
Returns:
The number of parameters in the model in the given unit.
Return type:
float
param_num_from_forward(model, trainable_only=False, args=None, unit=1000000.0)
Get the number of parameters of a PyTorch model from a forward pass.
Parameters:
- model (Module) – The PyTorch model.
- trainable_only (bool) – Whether to only count trainable parameters. Default is False.
- args (Tensor | tuple | None) –
- unit (float) – The unit to return the number of parameters in. Default is 1e6 (million).
Returns:
The number of parameters from the model’s forward pass in the given unit.
This can be helpful for dynamic modules, where the state dict might contain extra parameters that are not actively used in the model, e.g., because of a DynamicModule that is deactivated for the forward pass. We circumvent this issue by counting only the parameters of modules that appear in a forward pass.
remove_bn(model)
Remove all batch normalization layers in the network.
Parameters:
model (Module) –
run_forward_loop(model, data_loader, max_iters=None, collect_func=None, progress_bar=None, post_process=None)
Run multiple forward passes with a model according to the provided data loader.
Parameters:
- model – The model with which we run forward.
- data_loader (Iterable) – An iterator with data samples.
- max_iters (int | None) – Number of batches to run; by default it is infinite or until data_loader is exhausted.
- collect_func (Callable [ [ Any ] , Any | tuple ] | None) – A Callable that takes a batch of data from the data_loader as input and returns the input to model.forward() such that the return value (input) is either:
  - a single argument (type(input) != tuple) corresponding to
    model.forward(input)
  - a tuple of arguments corresponding to
    model.forward(*input)
  - a tuple of arguments such that type(input[-1]) == dict corresponding to
    model.forward(*input[:-1], **input[-1])
  Note
  In order to pass a dict as last non-keyword argument, you need to use a tuple as input and add an empty dict as the last element, e.g., input = (x, {"y": y, "z": z}, {}). The empty dict at the end will then be interpreted as the keyword args.
  See the args argument of torch.onnx.export for more info on the format of the return value of collect_func (input).
  The default collect_func assumes that the data loader returns a tuple, e.g., (images, labels, ...), and returns the first element of the tuple.
- progress_bar (str | None) – Set to a description string to see the progress bar.
- post_process (Callable | None) – A callable that takes the model outputs and the data as input and can be used to run any post-processing or operations such as backward pass.
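The way the return value of collect_func is dispatched to the forward call, together with the max_iters cutoff, can be sketched as follows. This is a simplified illustration, not the library's implementation: `call_forward_sketch`, `run_forward_loop_sketch`, and the stand-in `forward` function are hypothetical, the sketch uses `isinstance` rather than an exact type check, and it returns the collected outputs purely so the behavior can be inspected.

```python
def call_forward_sketch(forward, model_input):
    """Dispatch `model_input` to `forward` following the three documented forms."""
    if not isinstance(model_input, tuple):
        return forward(model_input)  # single argument
    if model_input and isinstance(model_input[-1], dict):
        # Trailing dict is interpreted as keyword arguments.
        return forward(*model_input[:-1], **model_input[-1])
    return forward(*model_input)  # tuple of positional arguments


def run_forward_loop_sketch(forward, data_loader, max_iters=None, collect_func=None):
    """Run `forward` on each batch; stop after `max_iters` batches if given."""
    # Default: assume batches look like (images, labels, ...) and take the first element.
    collect = collect_func or (lambda batch: batch[0])
    outputs = []
    for i, batch in enumerate(data_loader):
        if max_iters is not None and i >= max_iters:
            break
        outputs.append(call_forward_sketch(forward, collect(batch)))
    return outputs


def forward(x, y=None, z=None):  # hypothetical stand-in for model.forward
    return x, y, z


if __name__ == "__main__":
    # A dict passed as the last positional argument needs a trailing empty dict.
    print(call_forward_sketch(forward, (1, {"a": 0}, {})))
```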
set_submodule(model, target, target_submodule)
The set function that complements nn.Module.get_submodule().
Parameters:
- model (Module) –
- target (str) –
- target_submodule (Module) –
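Conceptually, setting a submodule means walking a dotted path to the parent and replacing the final attribute. The sketch below shows this on plain attribute containers so that it runs without PyTorch; `set_submodule_sketch` is a hypothetical name, and it assumes `target` is a dotted path of the kind accepted by nn.Module.get_submodule().

```python
from types import SimpleNamespace


def set_submodule_sketch(root, target, new_child):
    """Replace the object at dotted path `target` under `root` with `new_child`."""
    *parent_path, leaf = target.split(".")
    parent = root
    for name in parent_path:
        parent = getattr(parent, name)  # walk down to the direct parent
    setattr(parent, leaf, new_child)


if __name__ == "__main__":
    root = SimpleNamespace(backbone=SimpleNamespace(conv="old_layer"))
    set_submodule_sketch(root, "backbone.conv", "new_layer")
    print(root.backbone.conv)
```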
standardize_constructor_args(constructor_args)
Standardize a constructor-like tuple.
Parameters:
constructor_args (Callable | tuple) –
Return type:
_tuple_[Callable, tuple, _dict_]
standardize_model_args(model_or_fw_or_sig, args, use_kwargs=False)
Standardize model arguments according to torch.onnx.export.
Parameters:
- model_or_fw_or_sig (Module | Callable | Signature) – A nn.Module, its forward method, or its forward method’s signature.
- args (Any | tuple) – Refer to the dummy_input parameter in mtn.profile().
- use_kwargs – Affects the return value, see below. For use_kwargs == False, the returned args are also compatible with torch.onnx.export.
Returns:
Standardized model args that can be used in model.forward() in the same standardized way no matter how they were provided, see below for more info.
Return type:
tuple
- If use_kwargs == False, the returned args can be used as
  args = standardize_model_args(model, args, use_kwargs=False)
  model(*args)
- If use_kwargs == True, the returned args can be used as
  args = standardize_model_args(model, args, use_kwargs=True)
  model.forward(*args[:-1], **args[-1])
Warning
If use_kwargs == False, the model’s forward() method cannot contain keyword-only arguments (e.g. forward(..., *, kw_only_args)) without default values and you must not provide them in args.
Warning
If use_kwargs == False, you must not provide variable keyword arguments in args that are processed via variable keyword arguments in the model’s forward() method (e.g. forward(..., **kwargs)).
standardize_model_like_tuple(model)
Standardize a model-like tuple.
Parameters:
model (Module | type [ Module ] | tuple | Callable) –
Return type:
_tuple_[_type_[_Module_], tuple, _dict_]
standardize_named_model_args(model_or_fw_or_sig, args)
Standardize model arguments according to torch.onnx.export and give them a name.
Parameters:
- model_or_fw_or_sig (Module | Callable | Signature) – A nn.Module, its forward method, or its forward method’s signature.
- args (Any | tuple) – A tuple of args/kwargs or a torch.Tensor fed into the model’s forward() method.
Return type:
_tuple_[_dict_[str, _Any_], _set_[_str_]]
Returns:
A tuple (args_normalized, args_with_default) where
- args_normalized is a dictionary of ordered model args where the key represents a unique serialized string based on the argument’s name in the function signature and the value contains the actual argument,
- args_with_default is a set indicating whether the argument was retrieved from the default value in the function signature of the model’s forward() method or whether the argument exactly corresponds to the default value.
Note
See standardize_model_args() for more info as well.
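The idea of naming arguments and tracking defaults can be illustrated with Python's inspect module. This is only a sketch under assumptions, not the library's implementation: `named_args_sketch` and the stand-in `forward` are hypothetical, the keys are plain parameter names rather than the serialized strings described above, and only arguments filled in from the signature's defaults are reported (arguments explicitly passed with a value equal to the default are not detected).

```python
import inspect


def named_args_sketch(forward, args):
    """Bind args to `forward` by name and report which names used defaults."""
    if not isinstance(args, tuple):
        args = (args,)
    kwargs = {}
    if args and isinstance(args[-1], dict):
        # Trailing dict is interpreted as keyword arguments.
        args, kwargs = args[:-1], args[-1]
    bound = inspect.signature(forward).bind(*args, **kwargs)
    provided = set(bound.arguments)  # names supplied by the caller
    bound.apply_defaults()  # fill in the remaining names from the signature
    args_normalized = dict(bound.arguments)
    args_with_default = set(args_normalized) - provided
    return args_normalized, args_with_default


def forward(x, scale=1.0, bias=None):  # hypothetical stand-in for model.forward
    return x


if __name__ == "__main__":
    print(named_args_sketch(forward, (3, {"bias": 2})))
```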
unwrap_model(model, warn=False, raise_error=False, msg='', force_unwrap=False)
Unwrap a model that is wrapped by a supported wrapper module or return the original model.
Parameters:
- model (Module) –
- warn (bool) –
- raise_error (bool) –
- msg (str) –
- force_unwrap (bool) –
Return type:
Module
zero_grad(model)
Set any gradients in the model’s parameters to None.
Parameters:
model (Module) –
Return type:
None