adamupdate - Update parameters using adaptive moment estimation (Adam) - MATLAB

Update parameters using adaptive moment estimation (Adam)

Syntax

Description

Update the network learnable parameters in a custom training loop using the adaptive moment estimation (Adam) algorithm.

Note

This function applies the Adam optimization algorithm to update network parameters in custom training loops. To train a neural network using the trainnet function with the Adam solver, use the trainingOptions function and set the solver to "adam".
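
For example, a minimal sketch of that built-in workflow, where XTrain, TTrain, and layers are placeholders for your own training data, targets, and layer array:

options = trainingOptions("adam", ...
    MaxEpochs=20, ...
    MiniBatchSize=128);
net = trainnet(XTrain,TTrain,layers,"crossentropy",options);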

[netUpdated,averageGrad,averageSqGrad] = adamupdate(net,grad,averageGrad,averageSqGrad,iteration) updates the learnable parameters of the network net using the Adam algorithm. Use this syntax in a training loop to iteratively update a network defined as a dlnetwork object.


[params,averageGrad,averageSqGrad] = adamupdate(params,grad,averageGrad,averageSqGrad,iteration) updates the learnable parameters in params using the Adam algorithm. Use this syntax in a training loop to iteratively update the learnable parameters of a network defined using functions.


[___] = adamupdate(___,learnRate,gradDecay,sqGradDecay,epsilon) also specifies values to use for the global learning rate, gradient decay factor, squared gradient decay factor, and small constant epsilon, in addition to the input arguments in previous syntaxes.


Examples


Update Learnable Parameters Using adamupdate

Perform a single adaptive moment estimation update step with a global learning rate of 0.05, gradient decay factor of 0.75, and squared gradient decay factor of 0.95.

Create the parameters and parameter gradients as numeric arrays.

params = rand(3,3,4); grad = ones(3,3,4);

Initialize the iteration counter, average gradient, and average squared gradient for the first iteration.

iteration = 1; averageGrad = []; averageSqGrad = [];

Specify custom values for the global learning rate, gradient decay factor, and squared gradient decay factor.

learnRate = 0.05; gradDecay = 0.75; sqGradDecay = 0.95;

Update the learnable parameters using adamupdate.

[params,averageGrad,averageSqGrad] = adamupdate(params,grad,averageGrad,averageSqGrad,iteration,learnRate,gradDecay,sqGradDecay);

Update the iteration counter.

iteration = iteration + 1;

Train Network Using adamupdate

Use adamupdate to train a network using the Adam algorithm.

Load Training Data

Load the digits training data.

[XTrain,TTrain] = digitTrain4DArrayData; classes = categories(TTrain); numClasses = numel(classes);

Define Network

Define the network and specify the average image value using the Mean option in the image input layer.

layers = [
    imageInputLayer([28 28 1],'Mean',mean(XTrain,4))
    convolution2dLayer(5,20)
    reluLayer
    convolution2dLayer(3,20,'Padding',1)
    reluLayer
    convolution2dLayer(3,20,'Padding',1)
    reluLayer
    fullyConnectedLayer(numClasses)
    softmaxLayer];

Create a dlnetwork object from the layer array.
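
net = dlnetwork(layers);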

Define Model Loss Function

Create the helper function modelLoss, listed at the end of the example. The function takes a dlnetwork object and a mini-batch of input data with corresponding labels, and returns the loss and the gradients of the loss with respect to the learnable parameters.

Specify Training Options

Specify the options to use during training.

miniBatchSize = 128; numEpochs = 20; numObservations = numel(TTrain); numIterationsPerEpoch = floor(numObservations./miniBatchSize);

Train Network

Initialize the average gradients and squared average gradients.

averageGrad = []; averageSqGrad = [];

Calculate the total number of iterations for the training progress monitor.

numIterations = numEpochs * numIterationsPerEpoch;

Initialize the TrainingProgressMonitor object. Because the timer starts when you create the monitor object, make sure that you create the object close to the training loop.

monitor = trainingProgressMonitor(Metrics="Loss",Info="Epoch",XLabel="Iteration");

Train the model using a custom training loop. For each epoch, shuffle the data and loop over mini-batches of data. Update the network parameters using the adamupdate function. At the end of each iteration, display the training progress.

Train on a GPU, if one is available. Using a GPU requires Parallel Computing Toolbox™ and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox).

iteration = 0; epoch = 0;

while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

% Shuffle data.
idx = randperm(numel(TTrain));
XTrain = XTrain(:,:,:,idx);
TTrain = TTrain(idx);

i = 0;
while i < numIterationsPerEpoch && ~monitor.Stop
    i = i + 1;
    iteration = iteration + 1;

    % Read mini-batch of data and convert the labels to dummy
    % variables.
    idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
    X = XTrain(:,:,:,idx);

    T = zeros(numClasses, miniBatchSize,"single");
    for c = 1:numClasses
        T(c,TTrain(idx)==classes(c)) = 1;
    end

    % Convert mini-batch of data to a dlarray.
    X = dlarray(single(X),"SSCB");

    % If training on a GPU, then convert data to a gpuArray.
    if canUseGPU
        X = gpuArray(X);
    end

    % Evaluate the model loss and gradients using dlfeval and the
    % modelLoss function.
    [loss,gradients] = dlfeval(@modelLoss,net,X,T);

    % Update the network parameters using the Adam optimizer.
    [net,averageGrad,averageSqGrad] = adamupdate(net,gradients,averageGrad,averageSqGrad,iteration);

    % Update the training progress monitor.
    recordMetrics(monitor,iteration,Loss=loss);
    updateInfo(monitor,Epoch=epoch + " of " + numEpochs);
    monitor.Progress = 100 * iteration/numIterations;
end

end

Test Network

Test the classification accuracy of the model by comparing the predictions on a test set with the true labels.

[XTest,TTest] = digitTest4DArrayData;

Convert the data to a dlarray with the dimension format "SSCB" (spatial, spatial, channel, batch). For GPU prediction, also convert the data to a gpuArray.

XTest = dlarray(XTest,"SSCB");
if canUseGPU
    XTest = gpuArray(XTest);
end

To classify images using a dlnetwork object, use the predict function and find the classes with the highest scores.

YTest = predict(net,XTest); [~,idx] = max(extractdata(YTest),[],1); YTest = classes(idx);

Evaluate the classification accuracy.

accuracy = mean(YTest==TTest)

Model Loss Function

The modelLoss helper function takes a dlnetwork object net and a mini-batch of input data X with corresponding labels T, and returns the loss and the gradients of the loss with respect to the learnable parameters in net. To compute the gradients automatically, use the dlgradient function.

function [loss,gradients] = modelLoss(net,X,T)

Y = forward(net,X); loss = crossentropy(Y,T); gradients = dlgradient(loss,net.Learnables);

end

Input Arguments


net — Network

dlnetwork object

Network, specified as a dlnetwork object.

The function updates the Learnables property of the dlnetwork object. net.Learnables is a table with three variables: Layer, Parameter, and Value.

The input argument grad must be a table of the same form as net.Learnables.

params — Network learnable parameters

dlarray | numeric array | cell array | structure | table

Network learnable parameters, specified as a dlarray, a numeric array, a cell array, a structure, or a table.

If you specify params as a table, it must contain the three variables Layer, Parameter, and Value.

You can specify params as a container of learnable parameters for your network using a cell array, structure, or table, or using nested cell arrays or structures. The learnable parameters inside the cell array, structure, or table must be dlarray or numeric values of data type double or single.

The input argument grad must be provided with exactly the same data type, ordering, and fields (for structures) or variables (for tables) as params.
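
For example, a minimal sketch using a structure of parameters. The field names Weights and Bias are arbitrary placeholders, and in practice you would obtain grad from dlgradient rather than constructing it by hand:

params.Weights = dlarray(randn(10,5)); params.Bias = dlarray(zeros(10,1));
grad.Weights = dlarray(ones(10,5)); grad.Bias = dlarray(ones(10,1));  % same fields and types as params
averageGrad = []; averageSqGrad = []; iteration = 1;
[params,averageGrad,averageSqGrad] = adamupdate(params,grad,averageGrad,averageSqGrad,iteration);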

Since R2024a, the learnables can be complex-valued. Ensure that the corresponding operations support complex-valued learnables.

Before R2024a: The learnables must not be complex-valued. If your model involves complex learnables, then convert the learnables to real values before calculating the gradients.

grad — Gradients of the loss

dlarray | numeric array | cell array | structure | table

Gradients of the loss, specified as a dlarray, a numeric array, a cell array, a structure, or a table.

The exact form of grad depends on the input network or learnable parameters. The following table shows the required format for grad for possible inputs to adamupdate.

| Input | Learnable Parameters | Gradients |
| --- | --- | --- |
| net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params |
| params | Numeric array | Numeric array with the same data type and ordering as params |
| params | Cell array | Cell array with the same data types, structure, and ordering as params |
| params | Structure | Structure with the same data types, fields, and ordering as params |
| params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. grad must have a Value variable consisting of cell arrays that contain the gradient of each learnable parameter. |

You can obtain grad from a call to dlfeval that evaluates a function that contains a call to dlgradient. For more information, see Use Automatic Differentiation In Deep Learning Toolbox.

Since R2024a, the gradients can be complex-valued. Using complex-valued gradients can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.

Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.

averageGrad — Moving average of parameter gradients

[] | dlarray | numeric array | cell array | structure | table

Moving average of parameter gradients, specified as an empty array, a dlarray, a numeric array, a cell array, a structure, or a table.

The exact form of averageGrad depends on the input network or learnable parameters. The following table shows the required format for averageGrad for possible inputs to adamupdate.

| Input | Learnable Parameters | Average Gradients |
| --- | --- | --- |
| net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. averageGrad must have a Value variable consisting of cell arrays that contain the average gradient of each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params |
| params | Numeric array | Numeric array with the same data type and ordering as params |
| params | Cell array | Cell array with the same data types, structure, and ordering as params |
| params | Structure | Structure with the same data types, fields, and ordering as params |
| params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. averageGrad must have a Value variable consisting of cell arrays that contain the average gradient of each learnable parameter. |

If you specify averageGrad and averageSqGrad as empty arrays, the function assumes no previous gradients and runs in the same way as for the first update in a series of iterations. To update the learnable parameters iteratively, use the averageGrad output of a previous call to adamupdate as the averageGrad input.

Since R2024a, the gradients can be complex-valued. Using complex-valued gradients can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.

Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.

averageSqGrad — Moving average of squared parameter gradients

[] | dlarray | numeric array | cell array | structure | table

Moving average of squared parameter gradients, specified as an empty array, a dlarray, a numeric array, a cell array, a structure, or a table.

The exact form of averageSqGrad depends on the input network or learnable parameters. The following table shows the required format for averageSqGrad for possible inputs to adamupdate.

| Input | Learnable Parameters | Average Squared Gradients |
| --- | --- | --- |
| net | Table net.Learnables containing Layer, Parameter, and Value variables. The Value variable consists of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data type, variables, and ordering as net.Learnables. averageSqGrad must have a Value variable consisting of cell arrays that contain the average squared gradient of each learnable parameter. |
| params | dlarray | dlarray with the same data type and ordering as params |
| params | Numeric array | Numeric array with the same data type and ordering as params |
| params | Cell array | Cell array with the same data types, structure, and ordering as params |
| params | Structure | Structure with the same data types, fields, and ordering as params |
| params | Table with Layer, Parameter, and Value variables. The Value variable must consist of cell arrays that contain each learnable parameter as a dlarray. | Table with the same data types, variables, and ordering as params. averageSqGrad must have a Value variable consisting of cell arrays that contain the average squared gradient of each learnable parameter. |

If you specify averageGrad and averageSqGrad as empty arrays, the function assumes no previous gradients and runs in the same way as for the first update in a series of iterations. To update the learnable parameters iteratively, use the averageSqGrad output of a previous call to adamupdate as the averageSqGrad input.

Since R2024a, the gradients can be complex-valued. Using complex-valued gradients can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.

Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.

iteration — Iteration number

positive integer

Iteration number, specified as a positive integer. For the first call to adamupdate, use a value of 1. You must increment iteration by 1 for each subsequent call in a series of calls to adamupdate. The Adam algorithm uses this value to correct for bias in the moving averages at the beginning of a set of iterations.

learnRate — Global learning rate

0.001 (default) | positive scalar

Global learning rate, specified as a positive scalar. The default value of learnRate is 0.001.

If you specify the network parameters as a dlnetwork, the learning rate for each parameter is the global learning rate multiplied by the corresponding learning rate factor property defined in the network layers.
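
For example, a hedged sketch (the layer name "conv_1" and parameter name "Weights" are placeholders for names in your own network) that doubles the effective learning rate of one parameter before training:

net = setLearnRateFactor(net,"conv_1","Weights",2);
% adamupdate then updates that parameter with an effective learning rate of 2*learnRate.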

gradDecay — Gradient decay factor

0.9 (default) | positive scalar between 0 and 1

Gradient decay factor, specified as a positive scalar between 0 and 1. The default value of gradDecay is 0.9.

sqGradDecay — Squared gradient decay factor

0.999 (default) | positive scalar between 0 and 1

Squared gradient decay factor, specified as a positive scalar between 0 and 1. The default value of sqGradDecay is 0.999.

epsilon — Small constant

1e-8 (default) | positive scalar

Small constant for preventing divide-by-zero errors, specified as a positive scalar. The default value of epsilon is 1e-8.

Output Arguments


netUpdated — Updated network

dlnetwork object

Updated network, returned as a dlnetwork object.

The function updates the Learnables property of the dlnetwork object.

params — Updated network learnable parameters

dlarray | numeric array | cell array | structure | table

Updated network learnable parameters, returned as a dlarray, a numeric array, a cell array, a structure, or a table with a Value variable containing the updated learnable parameters of the network.

Since R2024a, the learnables can be complex-valued. Ensure that the corresponding operations support complex-valued learnables.

Before R2024a: The learnables must not be complex-valued. If your model involves complex learnables, then convert the learnables to real values before calculating the gradients.

averageGrad — Updated moving average of parameter gradients

dlarray | numeric array | cell array | structure | table

Updated moving average of parameter gradients, returned as a dlarray, a numeric array, a cell array, a structure, or a table.

Since R2024a, the gradients can be complex-valued. Using complex-valued gradients can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.

Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.

averageSqGrad — Updated moving average of squared parameter gradients

dlarray | numeric array | cell array | structure | table

Updated moving average of squared parameter gradients, returned as a dlarray, a numeric array, a cell array, a structure, or a table.

Since R2024a, the gradients can be complex-valued. Using complex-valued gradients can lead to complex-valued learnable parameters. Ensure that the corresponding operations support complex-valued learnables.

Before R2024a: The gradients must not be complex-valued. If your model involves complex numbers, then convert all outputs to real values before calculating the gradients.

Algorithms


Adaptive Moment Estimation

Adaptive moment estimation (Adam) [1] uses a parameter update that is similar to RMSProp, but with an added momentum term. It keeps an element-wise moving average of both the parameter gradients and their squared values, where ℓ is the iteration number, θℓ denotes the parameters, and ∇E(θℓ) is the gradient of the loss:

$$m_\ell = \beta_1 m_{\ell-1} + (1 - \beta_1)\,\nabla E(\theta_\ell)$$

$$v_\ell = \beta_2 v_{\ell-1} + (1 - \beta_2)\,[\nabla E(\theta_\ell)]^2$$

The β1 and β2 decay rates are the gradient decay and squared gradient decay factors, respectively. Adam uses the moving averages to update the network parameters as

$$\theta_{\ell+1} = \theta_\ell - \frac{\alpha\, m_\ell}{\sqrt{v_\ell} + \epsilon}$$

The value α is the learning rate. If gradients over many iterations are similar, then using a moving average of the gradient enables the parameter updates to pick up momentum in a certain direction. If the gradients contain mostly noise, then the moving average of the gradient becomes smaller, and so the parameter updates become smaller too. The full Adam update also includes a mechanism to correct a bias that appears in the beginning of training. For more information, see [1].
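
As an illustrative sketch only (not the internal implementation of adamupdate), one bias-corrected Adam step on a numeric array mirrors the equations above:

theta = rand(3,3);                          % parameter
g = ones(3,3);                              % gradient of the loss with respect to theta
m = zeros(size(theta)); v = zeros(size(theta));
beta1 = 0.9; beta2 = 0.999; alpha = 0.001; epsilon = 1e-8; t = 1;
m = beta1*m + (1 - beta1)*g;                % moving average of gradients
v = beta2*v + (1 - beta2)*g.^2;             % moving average of squared gradients
mHat = m ./ (1 - beta1^t);                  % bias correction for iteration t
vHat = v ./ (1 - beta2^t);
theta = theta - alpha*mHat ./ (sqrt(vHat) + epsilon);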

References

[1] Kingma, Diederik, and Jimmy Ba. "Adam: A method for stochastic optimization." arXiv preprint arXiv:1412.6980 (2014).

Extended Capabilities

GPU Arrays

Accelerate code by running on a graphics processing unit (GPU) using Parallel Computing Toolbox™.

The adamupdate function supports GPU array input with these usage notes and limitations:

For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).

Version History

Introduced in R2019b


R2024a: Complex-valued learnable parameters and gradients

The learnable parameters, gradients, moving average of gradients, and moving average of squared gradients can be complex-valued. When the updated learnable parameters are complex-valued, ensure that the corresponding operations support complex-valued parameters.