Define Model Loss Function for Custom Training Loop - MATLAB & Simulink
When you train a deep learning model with a custom training loop, the software minimizes the loss with respect to the learnable parameters. To minimize the loss, the software uses the gradients of the loss with respect to the learnable parameters. To calculate these gradients using automatic differentiation, you must define a model loss function.
For an example showing how to train a deep learning model with a `dlnetwork` object, see Train Network Using Custom Training Loop. For an example showing how to train a deep learning model defined as a function, see Train Network Using Model Function.
Create Model Loss Function for Model Defined as dlnetwork Object
If you have a deep learning model defined as a `dlnetwork` object, then create a model loss function that takes the `dlnetwork` object as input.
For a model specified as a `dlnetwork` object, create a function of the form `[loss,gradients] = modelLoss(net,X,T)`, where `net` is the network, `X` is the network input, `T` contains the targets, and `loss` and `gradients` are the returned loss and gradients, respectively. Optionally, you can pass extra arguments to the model loss function (for example, if the loss function requires extra information), or return extra arguments (for example, the updated network state).
For example, this function returns the cross-entropy loss and the gradients of the loss with respect to the learnable parameters in the specified `dlnetwork` object `net`, given input data `X` and targets `T`.
```matlab
function [loss,gradients] = modelLoss(net,X,T)

% Forward data through the dlnetwork object.
Y = forward(net,X);

% Compute loss.
loss = crossentropy(Y,T);

% Compute gradients.
gradients = dlgradient(loss,net.Learnables);

end
```
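If the network has layers that maintain state, such as batch normalization layers, you can also return the updated network state from the model loss function. As a minimal sketch (assuming the same cross-entropy setup as above), use the two-output syntax of `forward` to obtain the updated state:

```matlab
function [loss,gradients,state] = modelLoss(net,X,T)

% Forward data through the dlnetwork object and also
% return the updated network state.
[Y,state] = forward(net,X);

% Compute loss.
loss = crossentropy(Y,T);

% Compute gradients.
gradients = dlgradient(loss,net.Learnables);

end
```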
Create Model Loss Function for Model Defined as Function
If you have a deep learning model defined as a function, then create a model loss function that takes the model learnable parameters as input.
For a model specified as a function, create a function of the form `[loss,gradients] = modelLoss(parameters,X,T)`, where `parameters` contains the learnable parameters, `X` is the model input, `T` contains the targets, and `loss` and `gradients` are the returned loss and gradients, respectively. Optionally, you can pass extra arguments to the model loss function (for example, if the loss function requires extra information), or return extra arguments (for example, the updated model state).
For example, this function returns the cross-entropy loss and the gradients of the loss with respect to the learnable parameters `parameters`, given input data `X` and targets `T`.
```matlab
function [loss,gradients] = modelLoss(parameters,X,T)

% Forward data through the model function.
Y = model(parameters,X);

% Compute loss.
loss = crossentropy(Y,T);

% Compute gradients.
gradients = dlgradient(loss,parameters);

end
```
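Similarly, if the model maintains state, you can pass the state in and return the updated state. This is a sketch only; it assumes a hypothetical `model` function that accepts and returns a state argument, which is not part of the example above:

```matlab
function [loss,gradients,state] = modelLoss(parameters,X,T,state)

% Forward data through the model function. This assumes the
% model function accepts and returns a state argument.
[Y,state] = model(parameters,X,state);

% Compute loss.
loss = crossentropy(Y,T);

% Compute gradients.
gradients = dlgradient(loss,parameters);

end
```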
Evaluate Model Loss Function
To evaluate the model loss function using automatic differentiation, use the `dlfeval` function, which evaluates a function with automatic differentiation enabled. For the first input of `dlfeval`, pass the model loss function specified as a function handle. For the following inputs, pass the required variables for the model loss function. For the outputs of the `dlfeval` function, specify the same outputs as the model loss function.
For example, evaluate the model loss function `modelLoss` with a `dlnetwork` object `net`, input data `X`, and targets `T`, and return the model loss and gradients.
```matlab
[loss,gradients] = dlfeval(@modelLoss,net,X,T);
```
Similarly, evaluate the model loss function `modelLoss` using a model function with learnable parameters specified by the structure `parameters`, input data `X`, and targets `T`, and return the model loss and gradients.
```matlab
[loss,gradients] = dlfeval(@modelLoss,parameters,X,T);
```
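If your model loss function returns extra outputs, such as the updated network state, request the same outputs from `dlfeval`. For example, for a loss function with the three-output signature sketched earlier, you can then assign the state back to the network:

```matlab
[loss,gradients,state] = dlfeval(@modelLoss,net,X,T);

% Update the network state.
net.State = state;
```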
Update Learnable Parameters Using Gradients
To update the learnable parameters, you can use these functions.
| Function | Description |
| --- | --- |
| `adamupdate` | Update parameters using adaptive moment estimation (Adam) |
| `rmspropupdate` | Update parameters using root mean squared propagation (RMSProp) |
| `sgdmupdate` | Update parameters using stochastic gradient descent with momentum (SGDM) |
| `lbfgsupdate` | Update parameters using limited-memory BFGS (L-BFGS) |
| `dlupdate` | Update parameters using a custom function |
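The `dlupdate` function supports custom update rules. This sketch applies plain stochastic gradient descent (without momentum) using an anonymous function; the learning rate `learnRate` is a value you choose and is not part of the table above:

```matlab
learnRate = 0.01;

% Subtract the scaled gradient from each learnable parameter.
net = dlupdate(@(p,g) p - learnRate*g,net,gradients);
```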
For example, update the learnable parameters of a `dlnetwork` object `net` using the `adamupdate` function.
```matlab
[net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients, ...
    trailingAvg,trailingAvgSq,iteration);
```
Here, `gradients` contains the gradients of the loss with respect to the learnable parameters, and `trailingAvg`, `trailingAvgSq`, and `iteration` are the hyperparameters required by the `adamupdate` function.
Similarly, update the learnable parameters `parameters` of a model defined as a function using the `adamupdate` function.
```matlab
[parameters,trailingAvg,trailingAvgSq] = adamupdate(parameters,gradients, ...
    trailingAvg,trailingAvgSq,iteration);
```
Here, `gradients` contains the gradients of the loss with respect to the learnable parameters, and `trailingAvg`, `trailingAvgSq`, and `iteration` are the hyperparameters required by the `adamupdate` function.
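Before the training loop, initialize the trailing average variables to empty arrays and the iteration counter to zero; `adamupdate` then initializes the moving averages on the first call:

```matlab
trailingAvg = [];
trailingAvgSq = [];
iteration = 0;
```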
Use Model Loss Function in Custom Training Loop
When training a deep learning model using a custom training loop, evaluate the model loss and gradients and update the learnable parameters for each mini-batch.
This code snippet shows an example of using the `dlfeval` and `adamupdate` functions in a custom training loop.
```matlab
iteration = 0;

% Loop over epochs.
for epoch = 1:numEpochs

    % Loop over mini-batches.
    for i = 1:numIterationsPerEpoch
        iteration = iteration + 1;

        % Prepare mini-batch.
        % ...

        % Evaluate model loss and gradients.
        [loss,gradients] = dlfeval(@modelLoss,net,X,T);

        % Update learnable parameters.
        [net,trailingAvg,trailingAvgSq] = adamupdate(net,gradients, ...
            trailingAvg,trailingAvgSq,iteration);
    end
end
```
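One way to prepare mini-batches is with a `minibatchqueue` object, which handles batching, preprocessing, and conversion to `dlarray`. This is a sketch only, assuming an input datastore `ds` and a custom preprocessing function `preprocessMiniBatch`, neither of which appears in the snippet above:

```matlab
mbq = minibatchqueue(ds, ...
    MiniBatchSize=128, ...
    MiniBatchFcn=@preprocessMiniBatch, ...
    MiniBatchFormat=["SSCB" "CB"]);

% Inside the epoch loop, read one mini-batch per iteration.
[X,T] = next(mbq);
```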
For an example showing how to train a deep learning model with a `dlnetwork` object, see Train Network Using Custom Training Loop. For an example showing how to train a deep learning model defined as a function, see Train Network Using Model Function.
Debug Model Loss Functions
If the implementation of the model loss function has an issue, then the call to `dlfeval` can throw an error. Sometimes, when you use the `dlfeval` function, it is not clear which line of code is throwing the error. To help locate the error, you can try the following.
Call Model Loss Function Directly
Try calling the model loss function directly (that is, without using the `dlfeval` function) with generated inputs of the expected sizes. If any of the lines of code throw an error, then the error message provides extra detail. Note that when you do not use the `dlfeval` function, any calls to the `dlgradient` function throw an error.
```matlab
% Generate image input data.
X = rand([28 28 1 100],'single');
X = dlarray(X);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

[loss,gradients] = modelLoss(net,X,T);
```
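Depending on the network, the forward pass can require a formatted `dlarray` input. If so, specify the data format (here, `"SSCB"` for spatial, spatial, channel, batch image data) when you create the input:

```matlab
X = dlarray(X,"SSCB");
```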
Run Model Loss Code Manually
Run the code inside the model loss function manually with generated inputs of the expected sizes and inspect the output and any thrown error messages.
For example, consider the following model loss function.
```matlab
function [loss,gradients] = modelLoss(net,X,T)

% Forward data through the dlnetwork object.
Y = forward(net,X);

% Compute loss.
loss = crossentropy(Y,T);

% Compute gradients.
gradients = dlgradient(loss,net.Learnables);

end
```
Check the model loss function by running the following code.
```matlab
% Generate image input data.
X = rand([28 28 1 100],'single');
X = dlarray(X);

% Generate one-hot encoded target data.
T = repmat(eye(10,'single'),[1 10]);

% Check forward pass.
Y = forward(net,X);

% Check loss calculation.
loss = crossentropy(Y,T)
```
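To inspect intermediate values such as the loss, extract the underlying data from the `dlarray`:

```matlab
% Check that the loss is a finite scalar.
lossValue = extractdata(loss)
```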
Related Topics
- Train Network Using Custom Training Loop
- Train Network Using Model Function
- Define Custom Training Loops, Loss Functions, and Networks
- Specify Training Options in Custom Training Loop
- Update Batch Normalization Statistics in Custom Training Loop
- Make Predictions Using dlnetwork Object
- List of Functions with dlarray Support