apex.optimizers — Apex 0.1.0 documentation
class apex.optimizers.FusedAdam(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, adam_w_mode=True, weight_decay=0.0, amsgrad=False, set_grad_none=True)[source]¶
Implements Adam algorithm.
Currently GPU-only. Requires Apex to be installed via:
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
This version of fused Adam implements two fusions:
- Fusion of the Adam update’s elementwise operations
- A multi-tensor apply launch that batches the elementwise updates applied to all the model’s parameters into one or a few kernel launches.
apex.optimizers.FusedAdam may be used as a drop-in replacement for torch.optim.Adam:
opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
...
opt.step()
apex.optimizers.FusedAdam may be used with or without Amp. If you wish to use FusedAdam with Amp, you may choose any opt_level:
opt = apex.optimizers.FusedAdam(model.parameters(), lr = ....)
model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
...
opt.step()
In general, opt_level="O1" is recommended.
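For reference, a minimal mixed-precision training-loop sketch might look as follows. This is an illustrative sketch, not part of the original documentation: the model, loss function, and data loader (loader) are assumed placeholders.

import torch
import apex
from apex import amp

model = torch.nn.Linear(10, 10).cuda()                      # placeholder model
opt = apex.optimizers.FusedAdam(model.parameters(), lr=1e-3)

# Patch the model and optimizer for mixed precision ("O1" is the recommended level).
model, opt = amp.initialize(model, opt, opt_level="O1")

for inputs, targets in loader:                              # `loader` is an assumed data source
    opt.zero_grad()
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    # Scale the loss so backward() produces gradients in a numerically safe range.
    with amp.scale_loss(loss, opt) as scaled_loss:
        scaled_loss.backward()
    opt.step()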
Warning
A previous version of FusedAdam allowed a number of additional arguments to step. These additional arguments are now deprecated and unnecessary.
Adam has been proposed in Adam: A Method for Stochastic Optimization.
Parameters
- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.
- lr (float, optional) – learning rate. (default: 1e-3)
- betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its square. (default: (0.9, 0.999))
- eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)
- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
- amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond. (default: False) NOT SUPPORTED in FusedAdam!
- adam_w_mode (boolean, optional) – whether to apply decoupled weight decay (also known as AdamW) instead of L2 regularization; see the sketch following this list. (default: True)
- set_grad_none (bool, optional) – whether to set gradients to None when zero_grad() is called. (default: True)
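To illustrate the adam_w_mode flag above, here is a hedged sketch (the hyperparameter values are placeholders, not recommendations; model is assumed to be an existing torch.nn.Module on the GPU):

import apex

# Decoupled weight decay (AdamW behaviour), the default adam_w_mode=True.
opt_adamw = apex.optimizers.FusedAdam(model.parameters(), lr=1e-3,
                                      betas=(0.9, 0.999), eps=1e-8,
                                      weight_decay=0.01, adam_w_mode=True)

# Classic L2 regularization folded into the gradient instead.
opt_l2 = apex.optimizers.FusedAdam(model.parameters(), lr=1e-3,
                                   weight_decay=0.01, adam_w_mode=False)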
step(closure=None, grads=None, output_params=None, scale=None, grad_norms=None)[source]¶
Performs a single optimization step.
Parameters
closure (callable, optional) – A closure that reevaluates the model and returns the loss.
The remaining arguments are deprecated, and are only retained (for the moment) for error-checking purposes.
zero_grad()[source]¶
Clears the gradients of all optimized torch.Tensors.
class apex.optimizers.FusedLAMB(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-06, weight_decay=0.01, amsgrad=False, adam_w_mode=True, grad_averaging=True, set_grad_none=True, max_grad_norm=1.0)[source]¶
Implements LAMB algorithm.
Currently GPU-only. Requires Apex to be installed via:
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
This version of fused LAMB implements two fusions:
- Fusion of the LAMB update’s elementwise operations
- A multi-tensor apply launch that batches the elementwise updates applied to all the model’s parameters into one or a few kernel launches.
apex.optimizers.FusedLAMB’s usage is identical to any ordinary PyTorch optimizer:
opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
...
opt.step()
apex.optimizers.FusedLAMB may be used with or without Amp. If you wish to use FusedLAMB with Amp, you may choose any opt_level:
opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
...
opt.step()
In general, opt_level="O1" is recommended.
LAMB was proposed in Large Batch Optimization for Deep Learning: Training BERT in 76 minutes.
Parameters
- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.
- lr (float, optional) – learning rate. (default: 1e-3)
- betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its norm. (default: (0.9, 0.999))
- eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-6)
- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0.01)
- amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond. NOT SUPPORTED now! (default: False)
- adam_w_mode (boolean, optional) – whether to apply decoupled weight decay (also known as AdamW) instead of L2 regularization. (default: True)
- grad_averaging (bool, optional) – whether to apply (1-beta2) to the gradient when calculating running averages of the gradient. (default: True)
- set_grad_none (bool, optional) – whether to set gradients to None when zero_grad() is called. (default: True)
- max_grad_norm (float, optional) – value used to clip the global gradient norm; see the sketch following this list. (default: 1.0)
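As a hedged illustration of the parameters above (the values shown are placeholders, not tuned recommendations; model is assumed to exist), a FusedLAMB instance for large-batch training might be constructed like this:

import apex

opt = apex.optimizers.FusedLAMB(model.parameters(), lr=1e-3,
                                betas=(0.9, 0.999), eps=1e-6,
                                weight_decay=0.01,
                                max_grad_norm=1.0)   # global gradient norm is clipped inside step()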
step(closure=None)[source]¶
Performs a single optimization step.
Parameters
closure (callable, optional) – A closure that reevaluates the model and returns the loss.
zero_grad()[source]¶
Clears the gradients of all optimized torch.Tensors.
class apex.optimizers.FusedNovoGrad(params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0, amsgrad=False, reg_inside_moment=False, grad_averaging=True, norm_type=2, init_zero=False, set_grad_none=True)[source]¶
Implements NovoGrad algorithm.
Currently GPU-only. Requires Apex to be installed via:
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
This version of fused NovoGrad implements two fusions:
- Fusion of the NovoGrad update’s elementwise operations
- A multi-tensor apply launch that batches the elementwise updates applied to all the model’s parameters into one or a few kernel launches.
apex.optimizers.FusedNovoGrad’s usage is identical to any PyTorch optimizer:
opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....)
...
opt.step()
apex.optimizers.FusedNovoGrad may be used with or without Amp. If you wish to use FusedNovoGrad with Amp, you may choose any opt_level:
opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr = ....)
model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
...
opt.step()
In general, opt_level="O1" is recommended.
It has been proposed in Jasper: An End-to-End Convolutional Neural Acoustic Model. More info: https://nvidia.github.io/OpenSeq2Seq/html/optimizers.html#novograd
Parameters
- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups.
- lr (float, optional) – learning rate. (default: 1e-3)
- betas (Tuple[float, float], optional) – coefficients used for computing running averages of gradient and its norm. (default: (0.9, 0.999))
- eps (float, optional) – term added to the denominator to improve numerical stability. (default: 1e-8)
- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
- amsgrad (boolean, optional) – whether to use the AMSGrad variant of this algorithm from the paper On the Convergence of Adam and Beyond. NOT SUPPORTED now! (default: False)
- reg_inside_moment (bool, optional) – whether to apply regularization (norm and L2) inside the momentum calculation. True to include it; False to apply it only to the update term. (default: False)
- grad_averaging (bool, optional) – whether to apply (1-beta2) to the gradient when calculating running averages of the gradient. (default: True)
- norm_type (int, optional) – which norm to compute for each layer: 2 for the L2 norm, 0 for the infinity norm. These are the only supported types for now; see the sketch following this list. (default: 2)
- init_zero (bool, optional) – whether to initialize the norm with 0 (start averaging on the 1st step) or with the first step’s norm (start averaging on the 2nd step). True to initialize with 0. (default: False)
- set_grad_none (bool, optional) – whether to set gradients to None when zero_grad() is called. (default: True)
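As an illustrative sketch of the options above (the beta values and other hyperparameters are placeholders, not recommendations; model is assumed to exist):

import apex

opt = apex.optimizers.FusedNovoGrad(model.parameters(), lr=1e-3,
                                    betas=(0.95, 0.98),   # placeholder values
                                    weight_decay=0.0,
                                    norm_type=2,          # per-layer L2 norm
                                    grad_averaging=True,
                                    init_zero=False)      # start averaging from the first step's norm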
load_state_dict(state_dict)[source]¶
Loads the optimizer state.
Parameters
state_dict (dict) – optimizer state. Should be an object returned from a call to state_dict().
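A minimal save/restore sketch using the standard PyTorch state_dict()/load_state_dict() round trip (the checkpoint path and the model/opt objects are assumed placeholders):

import torch

# Save model and optimizer state together.
torch.save({"model": model.state_dict(), "opt": opt.state_dict()}, "checkpoint.pt")

# Later, after re-creating `model` and `opt` the same way:
ckpt = torch.load("checkpoint.pt")
model.load_state_dict(ckpt["model"])
opt.load_state_dict(ckpt["opt"])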
step(closure=None)[source]¶
Performs a single optimization step.
Parameters
closure (callable, optional) – A closure that reevaluates the model and returns the loss.
zero_grad()[source]¶
Clears the gradients of all optimized torch.Tensors.
class apex.optimizers.FusedSGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, wd_after_momentum=False, materialize_master_grads=True)[source]¶
Implements stochastic gradient descent (optionally with momentum).
Currently GPU-only. Requires Apex to be installed via:
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
This version of fused SGD implements two fusions:
- Fusion of the SGD update’s elementwise operations
- A multi-tensor apply launch that batches the elementwise updates applied to all the model’s parameters into one or a few kernel launches.
apex.optimizers.FusedSGD may be used as a drop-in replacement for torch.optim.SGD:
opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....)
...
opt.step()
apex.optimizers.FusedSGD may be used with or without Amp. If you wish to use FusedSGD with Amp, you may choose any opt_level:
opt = apex.optimizers.FusedSGD(model.parameters(), lr = ....)
model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
...
opt.step()
In general, opt_level="O1" is recommended.
Nesterov momentum is based on the formula from On the importance of initialization and momentum in deep learning.
Parameters
- params (iterable) – iterable of parameters to optimize or dicts defining parameter groups
- lr (float) – learning rate
- momentum (float, optional) – momentum factor (default: 0)
- weight_decay (float, optional) – weight decay (L2 penalty) (default: 0)
- dampening (float, optional) – dampening for momentum (default: 0)
- nesterov (bool, optional) – enables Nesterov momentum (default: False)
Example
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()
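Since FusedSGD is a drop-in replacement for torch.optim.SGD, the same example can be written with the fused optimizer (a sketch, assuming model, loss_fn, input, and target are defined as in the example above):

import apex

optimizer = apex.optimizers.FusedSGD(model.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()
loss_fn(model(input), target).backward()
optimizer.step()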
Note
The implementation of SGD with Momentum/Nesterov subtly differs from Sutskever et al. and implementations in some other frameworks.
Considering the specific case of Momentum, the update can be written as
\[\begin{split}v = \rho * v + g \\ p = p - lr * v\end{split}\]
where p, g, v and \(\rho\) denote the parameters, gradient, velocity, and momentum respectively.
This is in contrast to Sutskever et al. and other frameworks which employ an update of the form
\[\begin{split}v = \rho * v + lr * g \\ p = p - v\end{split}\]
The Nesterov version is analogously modified.
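A small numeric sketch of the two conventions (pure illustration, not library code): with a constant learning rate the two forms produce identical parameter trajectories, because the Sutskever-style velocity is just the PyTorch-style velocity scaled by lr; they diverge once the learning rate changes during training.

rho, lr = 0.9, 0.1
p1 = p2 = 1.0                     # same starting parameter
v1 = v2 = 0.0
for g in [0.5, 0.3, 0.2]:         # a made-up gradient sequence
    # PyTorch / Apex convention: lr multiplies the accumulated velocity.
    v1 = rho * v1 + g
    p1 = p1 - lr * v1
    # Sutskever et al. convention: lr enters the velocity update itself.
    v2 = rho * v2 + lr * g
    p2 = p2 - v2
print(p1, p2)                     # identical here because lr is constant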
step(closure=None)[source]¶
Performs a single optimization step.
Parameters
closure (callable, optional) – A closure that reevaluates the model and returns the loss.