sgdmupdate - Update parameters using stochastic gradient descent with momentum (SGDM)
MATLAB ([original](https://in.mathworks.com/help/deeplearning/ref/sgdmupdate.html))
Syntax
Description
Update the network learnable parameters in a custom training loop using the stochastic gradient descent with momentum (SGDM) algorithm.
Note
This function applies the SGDM optimization algorithm to update network parameters in custom training loops. To train a neural network with the trainnet function using the SGDM solver, use the trainingOptions function and set the solver to "sgdm".
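For example, a minimal sketch of that built-in route (the option values shown are illustrative, not required defaults):

```matlab
% Configure the built-in SGDM solver for use with trainnet.
options = trainingOptions("sgdm", ...
    InitialLearnRate=0.01, ...  % global learning rate
    Momentum=0.9, ...           % contribution of the previous update step
    MaxEpochs=20, ...
    MiniBatchSize=128);
```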
[netUpdated,vel] = sgdmupdate(net,grad,vel) updates the learnable parameters of the network net using the SGDM algorithm. Use this syntax in a training loop to iteratively update a network defined as a dlnetwork object.
[params,vel] = sgdmupdate(params,grad,vel) updates the learnable parameters in params using the SGDM algorithm. Use this syntax in a training loop to iteratively update the learnable parameters of a network defined using functions.
[___] = sgdmupdate(___,learnRate,momentum) also specifies values to use for the global learning rate and momentum, in addition to the input arguments in previous syntaxes.
Examples
Perform a single SGDM update step with a global learning rate of 0.05 and momentum of 0.95.
Create the parameters and parameter gradients as numeric arrays.
params = rand(3,3,4);
grad = ones(3,3,4);
Initialize the parameter velocities for the first iteration.
vel = [];
Specify custom values for the global learning rate and momentum.
learnRate = 0.05;
momentum = 0.95;
Update the learnable parameters using sgdmupdate.
[params,vel] = sgdmupdate(params,grad,vel,learnRate,momentum);
Use sgdmupdate to train a network using the SGDM algorithm.
Load Training Data
Load the digits training data.
[XTrain,TTrain] = digitTrain4DArrayData;
classes = categories(TTrain);
numClasses = numel(classes);
Define Network
Define the network architecture and specify the average image value using the Mean option in the image input layer.
layers = [
    imageInputLayer([28 28 1],'Mean',mean(XTrain,4))
    convolution2dLayer(5,20)
    reluLayer
    convolution2dLayer(3,20,'Padding',1)
    reluLayer
    convolution2dLayer(3,20,'Padding',1)
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];
Create a dlnetwork object from the layer array.
net = dlnetwork(layers);
Define Model Loss Function
Create the helper function modelLoss, listed at the end of the example. The function takes a dlnetwork object and a mini-batch of input data with corresponding labels, and returns the loss and the gradients of the loss with respect to the learnable parameters.
Specify Training Options
Specify the options to use during training.
miniBatchSize = 128;
numEpochs = 20;
numObservations = numel(TTrain);
numIterationsPerEpoch = floor(numObservations./miniBatchSize);
Train Network
Initialize the velocity parameter.
vel = [];
Calculate the total number of iterations for the training progress monitor.
numIterations = numEpochs * numIterationsPerEpoch;
Initialize the TrainingProgressMonitor object. Because the timer starts when you create the monitor object, make sure that you create the object close to the training loop.
monitor = trainingProgressMonitor(Metrics="Loss",Info="Epoch",XLabel="Iteration");
Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. Update the network parameters using the sgdmupdate function. At the end of each iteration, display the training progress.
Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).
iteration = 0;
epoch = 0;

while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Shuffle data.
    idx = randperm(numel(TTrain));
    XTrain = XTrain(:,:,:,idx);
    TTrain = TTrain(idx);

    i = 0;
    while i < numIterationsPerEpoch && ~monitor.Stop
        i = i + 1;
        iteration = iteration + 1;

        % Read mini-batch of data and convert the labels to dummy
        % variables.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);

        T = zeros(numClasses, miniBatchSize,"single");
        for c = 1:numClasses
            T(c,TTrain(idx)==classes(c)) = 1;
        end

        % Convert mini-batch of data to a dlarray.
        X = dlarray(single(X),"SSCB");

        % If training on a GPU, then convert data to a gpuArray.
        if canUseGPU
            X = gpuArray(X);
        end

        % Evaluate the model loss and gradients using dlfeval and the
        % modelLoss function.
        [loss,gradients] = dlfeval(@modelLoss,net,X,T);

        % Update the network parameters using the SGDM optimizer.
        [net,vel] = sgdmupdate(net,gradients,vel);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
        monitor.Progress = 100 * iteration/numIterations;
    end
end
Test Network
Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.
[XTest,TTest] = digitTest4DArrayData;
Convert the data to a dlarray with the dimension format "SSCB" (spatial, spatial, channel, batch). For GPU prediction, also convert the data to a gpuArray.
XTest = dlarray(XTest,"SSCB"); if canUseGPU XTest = gpuArray(XTest); end
To classify images using a dlnetwork object, use the predict function and find the classes with the highest scores.
YTest = predict(net,XTest);
[~,idx] = max(extractdata(YTest),[],1);
YTest = classes(idx);
Evaluate the classification accuracy.
accuracy = mean(YTest==TTest)
Model Loss Function
The modelLoss function takes a dlnetwork object net and a mini-batch of input data X with corresponding labels T, and returns the loss and the gradients of the loss with respect to the learnable parameters in net. To compute the gradients automatically, use the dlgradient function.
function [loss,gradients] = modelLoss(net,X,T)
Y = forward(net,X);
loss = crossentropy(Y,T);
gradients = dlgradient(loss,net.Learnables);
end
Input Arguments
Network, specified as a dlnetwork object.

The function updates the Learnables property of the dlnetwork object. net.Learnables is a table with three variables:

- Layer — Layer name, specified as a string scalar.
- Parameter — Parameter name, specified as a string scalar.
- Value — Value of parameter, specified as a cell array containing a dlarray.

The input argument grad must be a table of the same form as net.Learnables.
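For illustration, a minimal sketch of inspecting this table on a small network (the layer and parameter names are whatever MATLAB assigns; the sizes are arbitrary):

```matlab
% Build a tiny dlnetwork and look at its learnable parameters.
net = dlnetwork([featureInputLayer(4) fullyConnectedLayer(2)]);
net.Learnables          % table with Layer, Parameter, and Value variables
net.Learnables.Value    % cell array containing each parameter as a dlarray
```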
Network learnable parameters, specified as a dlarray, a numeric array, a cell array, a structure, or a table.

If you specify params as a table, it must contain the following three variables:

- Layer — Layer name, specified as a string scalar.
- Parameter — Parameter name, specified as a string scalar.
- Value — Value of parameter, specified as a cell array containing a dlarray.

You can specify params as a container of learnable parameters for your network using a cell array, structure, or table, or nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray or numeric values of data type double or single.

The input argument grad must be provided with exactly the same data type, ordering, and fields (for structures) or variables (for tables) as params.

The learnables can be complex-valued (since R2024a). Ensure that the corresponding operations support complex-valued learnables.

Before R2024a: The learnables must not be complex-valued. If your model involves complex learnables, then convert the learnables to real values before calculating the gradients.

Data Types: single | double | struct | table | cell
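For example, a minimal sketch of updating parameters held in a cell array (the values are synthetic; in practice the gradients come from dlfeval and dlgradient, as described below):

```matlab
% Learnable parameters stored in a cell array of dlarrays.
params = {dlarray(randn(10,5,"single")), dlarray(zeros(10,1,"single"))};

% Gradients must mirror params: same types, structure, and ordering.
grad = {dlarray(ones(10,5,"single")), dlarray(ones(10,1,"single"))};

vel = [];   % empty array: no previous velocities on the first update
[params,vel] = sgdmupdate(params,grad,vel);
```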
Gradients of the loss, specified as a dlarray, a numeric array, a cell array, a structure, or a table.

The exact form of grad depends on the input network or learnable parameters. The following table shows the required format for grad for possible inputs to sgdmupdate.
Input | Learnable Parameters | Gradients |
---|---|---|
net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter. |
params | dlarray | dlarray with the same data type and ordering as params |
params | Numeric array | Numeric array with the same data type and ordering as params |
params | Cell array | Cell array with the same data types, structure, and ordering as params |
params | Structure | Structure with the same data types, fields, and ordering as params |
params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter. |
You can obtain grad from a call to dlfeval that evaluates a function that contains a call to dlgradient. For more information, see Use Automatic Differentiation In Deep Learning Toolbox.

The gradients can be complex-valued (since R2024a). Using complex-valued gradients can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.

Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.
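As a self-contained sketch of that workflow for parameters stored in a structure (the model inside modelLossFcn and the data are illustrative placeholders, not part of the sgdmupdate API):

```matlab
% Parameters stored in a structure of dlarrays.
params.W = dlarray(randn(2,4,"single"));
params.b = dlarray(zeros(2,1,"single"));

% Toy mini-batch: 4 features, batch of 8, one-hot targets for 2 classes.
X = dlarray(randn(4,8,"single"),"CB");
T = dlarray(single([ones(1,8); zeros(1,8)]),"CB");

% Evaluate the loss and gradients; grad mirrors the fields of params.
[loss,grad] = dlfeval(@modelLossFcn,params,X,T);
[params,vel] = sgdmupdate(params,grad,[]);

function [loss,grad] = modelLossFcn(params,X,T)
    Y = softmax(fullyconnect(X,params.W,params.b));  % simple linear model
    loss = crossentropy(Y,T);
    grad = dlgradient(loss,params);                  % same fields as params
end
```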
Parameter velocities, specified as an empty array, a dlarray, a numeric array, a cell array, a structure, or a table.

The exact form of vel depends on the input network or learnable parameters. The following table shows the required format for vel for possible inputs to sgdmupdate.
Input | Learnable Parameters | Velocities |
---|---|---|
net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. vel must have a Value variable consisting of cell arrays that contain the velocity of each learnable parameter. |
params | dlarray | dlarray with the same data type and ordering as params |
params | Numeric array | Numeric array with the same data type and ordering as params |
params | Cell array | Cell array with the same data types, structure, and ordering as params |
params | Structure | Structure with the same data types, fields, and ordering as params |
params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. vel must have a Value variable consisting of cell arrays that contain the velocity of each learnable parameter. |
If you specify vel as an empty array, the function assumes no previous velocities and runs in the same way as for the first update in a series of iterations. To update the learnable parameters iteratively, use the vel output of a previous call to sgdmupdate as the vel input.
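For example, a minimal sketch of carrying the velocity across iterations (the gradient here is a fixed placeholder rather than a real loss gradient):

```matlab
params = dlarray(rand(3,3));
vel = [];                                        % first update: no previous velocities
for iteration = 1:5
    grad = dlarray(ones(3,3));                   % placeholder gradient
    [params,vel] = sgdmupdate(params,grad,vel);  % pass the previous vel back in
end
```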
The velocity can be complex-valued (since R2024a). Using complex-valued gradients and velocities can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.
Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.
Learning rate, specified as a positive scalar. The default value of learnRate is 0.01.

If you specify the network parameters as a dlnetwork object, the learning rate for each parameter is the global learning rate multiplied by the corresponding learning rate factor property defined in the network layers.
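For instance, a minimal sketch of setting such a factor on a layer before assembling the network (assuming the setLearnRateFactor function; the factor value is illustrative):

```matlab
% The Weights of this layer are then updated with an effective rate of 2*learnRate.
layer = convolution2dLayer(3,20);
layer = setLearnRateFactor(layer,'Weights',2);
```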
Momentum, specified as a positive scalar between 0 and 1. The default value of momentum is 0.9.
Output Arguments
Updated network, returned as a dlnetwork object.
The function updates the Learnables property of the dlnetwork object.
Updated network learnable parameters, returned as a dlarray, a numeric array, a cell array, a structure, or a table with a Value variable containing the updated learnable parameters of the network.

The learnables can be complex-valued (since R2024a). Ensure that the corresponding operations support complex-valued learnables.

Before R2024a: The learnables must not be complex-valued. If your model involves complex learnables, then convert the learnables to real values before calculating the gradients.
Updated parameter velocities, returned as a dlarray, a numeric array, a cell array, a structure, or a table.
Algorithms
The standard gradient descent algorithm updates the network parameters (weights and biases) to minimize the loss function by taking small steps at each iteration in the direction of the negative gradient of the loss,

$$\theta_{\ell+1} = \theta_{\ell} - \alpha \nabla E(\theta_{\ell}),$$

where ℓ is the iteration number, α > 0 is the learning rate, θ is the parameter vector, and E(θ) is the loss function. In the standard gradient descent algorithm, the gradient of the loss function, ∇E(θ), is evaluated using the entire training set, and the standard gradient descent algorithm uses the entire data set at once.
By contrast, at each iteration the stochastic gradient descent algorithm evaluates the gradient and updates the parameters using a subset of the training data. A different subset, called a mini-batch, is used at each iteration. The full pass of the training algorithm over the entire training set using mini-batches is one *epoch*. Stochastic gradient descent is stochastic because the parameter update computed using a mini-batch is a noisy estimate of the parameter update that would result from using the full data set.
The stochastic gradient descent algorithm can oscillate along the path of steepest descent towards the optimum. Adding a momentum term to the parameter update is one way to reduce this oscillation [1]. The stochastic gradient descent with momentum (SGDM) update is

$$\theta_{\ell+1} = \theta_{\ell} - \alpha \nabla E(\theta_{\ell}) + \gamma \, (\theta_{\ell} - \theta_{\ell-1}),$$

where the learning rate α and the momentum value γ determine the contribution of the previous gradient step to the current iteration.
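Written in the equivalent velocity form, a minimal numeric sketch of one SGDM step (not using sgdmupdate; all values are illustrative):

```matlab
learnRate = 0.01;
momentum  = 0.9;

theta = rand(3,3);           % parameters
gradE = ones(3,3);           % gradient of the loss at theta
step  = zeros(size(theta));  % previous update step, theta_l - theta_(l-1)

step  = momentum*step - learnRate*gradE;  % combine previous step and current gradient
theta = theta + step;                     % apply the SGDM update
```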
References
[1] Murphy, K. P. Machine Learning: A Probabilistic Perspective. The MIT Press, Cambridge, Massachusetts, 2012.
Extended Capabilities
The sgdmupdate function supports GPU array input with these usage notes and limitations:

- When at least one of the following input arguments is a gpuArray or a dlarray with underlying data of type gpuArray, this function runs on the GPU:
  - grad
  - params

For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
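For example, a minimal sketch of running the update on the GPU (requires Parallel Computing Toolbox and a supported GPU; the values are illustrative):

```matlab
if canUseGPU
    params = dlarray(gpuArray(rand(3,3,"single")));
    grad   = dlarray(gpuArray(ones(3,3,"single")));
    [params,vel] = sgdmupdate(params,grad,[]);   % runs on the GPU
end
```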
Version History
Introduced in R2019b
R2024a: The learnable parameters, gradients, and velocity can be complex-valued. When the updated learnable parameters are complex-valued, ensure that the corresponding operations support complex-valued parameters.