Specify Custom Layer Backward Function
If Deep Learning Toolbox™ does not provide the layer you require for your task, then you can define your own custom layer. For a list of built-in layers, see List of Deep Learning Layers.
The example Define Custom Deep Learning Layer with Learnable Parameters shows how to create a custom SReLU layer and goes through the following steps:
- Name the layer — Give the layer a name so that you can use it in MATLAB®.
- Declare the layer properties — Specify the properties of the layer, including learnable parameters and state parameters.
- Create the constructor function (optional) — Specify how to construct the layer and initialize its properties. If you do not specify a constructor function, then at creation, the software initializes the Name, Description, and Type properties with [] and sets the number of layer inputs and outputs to 1.
- Create initialize function (optional) — Specify how to initialize the learnable and state parameters when the software initializes the network. If you do not specify an initialize function, then the software does not initialize parameters when it initializes the network.
- Create forward functions — Specify how data passes forward through the layer (forward propagation) at prediction time and at training time.
- Create reset state function (optional) — Specify how to reset state parameters.
- Create a backward function (optional) — Specify the derivatives of the loss with respect to the input data and the learnable parameters (backward propagation). If you do not specify a backward function, then the forward functions must support dlarray objects.

If the forward function only uses functions that support dlarray objects, then creating a backward function is optional. In this case, the software determines the derivatives automatically using automatic differentiation. For a list of functions that support dlarray objects, see List of Functions with dlarray Support. If you want to use functions that do not support dlarray objects, or want to use a specific algorithm for the backward function, then you can define a custom backward function using this example as a guide.
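For example, the following sketch (not part of the original example) shows automatic differentiation supplying the derivatives for a layer with no backward function. It assumes the sreluLayer class defined in the next section is saved on the path; the helper function sreluLoss, the input sizes, and the scalar loss are illustrative choices.

layer = sreluLayer;
layer = initialize(layer,networkDataLayout([8 8 2 NaN],"SSCB"));
X = dlarray(randn(8,8,2,4),"SSCB");

% Trace the forward pass and compute dL/dX using automatic differentiation.
[loss,gradX] = dlfeval(@sreluLoss,layer,X);

function [loss,gradX] = sreluLoss(layer,X)
    Y = predict(layer,X);
    loss = sum(Y,"all");        % illustrative scalar loss
    gradX = dlgradient(loss,X); % derivative of the loss with respect to X
end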
Create Custom Layer
The example Define Custom Deep Learning Layer with Learnable Parameters shows how to create a SReLU layer. A SReLU layer performs a thresholding operation, where for each channel, the layer scales values outside an interval. The interval thresholds and scaling factors are learnable parameters.[1]
The SReLU operation is given by

$$f(x_i) = \begin{cases} t_i^l + a_i^l\,(x_i - t_i^l) & x_i \le t_i^l \\ x_i & t_i^l < x_i < t_i^r \\ t_i^r + a_i^r\,(x_i - t_i^r) & x_i \ge t_i^r, \end{cases}$$

where $x_i$ is the input on channel $i$, $t_i^l$ and $t_i^r$ are the left and right thresholds on channel $i$, respectively, and $a_i^l$ and $a_i^r$ are the left and right scaling factors on channel $i$, respectively. These threshold values and scaling factors are learnable parameters, which the layer learns during training.
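As a quick numeric illustration (with parameter values chosen for illustration only), the operation leaves values between the thresholds unchanged and rescales values outside them:

% One channel with tl = -1, al = 0.1, tr = 1, ar = 2.
x = [-3 0.5 2];
tl = -1; al = 0.1; tr = 1; ar = 2;
y = (x <= tl).*(tl + al.*(x - tl)) + (x > tl & x < tr).*x + (x >= tr).*(tr + ar.*(x - tr))
% y = [-1.2  0.5  3.0]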
View the layer created in the example Define Custom Deep Learning Layer with Learnable Parameters. This layer does not have a backward function.
classdef sreluLayer < nnet.layer.Layer ...
& nnet.layer.Acceleratable
% Example custom SReLU layer.
properties (Learnable)
% Layer learnable parameters
LeftSlope
RightSlope
LeftThreshold
RightThreshold
end
methods
function layer = sreluLayer(args)
% layer = sreluLayer creates a SReLU layer.
%
% layer = sreluLayer(Name=name) also specifies the
% layer name.
arguments
args.Name = "";
end
% Set layer name.
layer.Name = args.Name;
% Set layer description.
layer.Description = "SReLU";
end
function layer = initialize(layer,layout)
% layer = initialize(layer,layout) initializes the layer
% learnable parameters using the specified input layout.
% Find number of channels.
idx = finddim(layout,"C");
numChannels = layout.Size(idx);
% Initialize empty learnable parameters.
sz = ones(1,numel(layout.Size));
sz(idx) = numChannels;
if isempty(layer.LeftSlope)
layer.LeftSlope = rand(sz);
end
if isempty(layer.RightSlope)
layer.RightSlope = rand(sz);
end
if isempty(layer.LeftThreshold)
layer.LeftThreshold = rand(sz);
end
if isempty(layer.RightThreshold)
layer.RightThreshold = rand(sz);
end
end
function Y = predict(layer, X)
% Y = predict(layer, X) forwards the input data X through the
% layer and outputs the result Y.
tl = layer.LeftThreshold;
al = layer.LeftSlope;
tr = layer.RightThreshold;
ar = layer.RightSlope;
Y = (X <= tl) .* (tl + al.*(X-tl)) ...
+ ((tl < X) & (X < tr)) .* X ...
+ (tr <= X) .* (tr + ar.*(X-tr));
end
end
end
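As a quick check of the forward pass, this sketch (not part of the original example) constructs the layer, initializes its learnable parameters from a data layout, and calls predict on a formatted dlarray. The layout and mini-batch sizes are illustrative.

layer = sreluLayer(Name="srelu");

% Initialize the learnable parameters from an example "SSCB" image layout.
layout = networkDataLayout([28 28 3 NaN],"SSCB");
layer = initialize(layer,layout);

% Forward a random mini-batch of 16 observations through the layer.
X = dlarray(rand(28,28,3,16,"single"),"SSCB");
Y = predict(layer,X);
size(Y)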
Create Backward Function
Implement the backward function that returns the derivatives of the loss with respect to the input data and the learnable parameters.

The backward function syntax depends on the type of layer.

- dLdX = backward(layer,X,Y,dLdY,memory) returns the derivative dLdX of the loss with respect to the layer input, where layer has a single input and a single output. Y corresponds to the forward function output and dLdY corresponds to the derivative of the loss with respect to Y. The function input memory corresponds to the memory output of the forward function.
- [dLdX,dLdW] = backward(layer,X,Y,dLdY,memory) also returns the derivative dLdW of the loss with respect to the learnable parameter, where layer has a single learnable parameter.
- [dLdX,dLdSin] = backward(layer,X,Y,dLdY,dLdSout,memory) also returns the derivative dLdSin of the loss with respect to the state input, where layer has a single state parameter and dLdSout corresponds to the derivative of the loss with respect to the layer state output.
- [dLdX,dLdW,dLdSin] = backward(layer,X,Y,dLdY,dLdSout,memory) also returns the derivative dLdW of the loss with respect to the learnable parameter and returns the derivative dLdSin of the loss with respect to the layer state input, where layer has a single state parameter and a single learnable parameter.
You can adjust the syntaxes for layers with multiple inputs, multiple outputs, multiple learnable parameters, or multiple state parameters (for an illustrative multi-input signature, see the sketch below):

- For layers with multiple inputs, replace X and dLdX with X1,...,XN and dLdX1,...,dLdXN, respectively, where N is the number of inputs.
- For layers with multiple outputs, replace Y and dLdY with Y1,...,YM and dLdY1,...,dLdYM, respectively, where M is the number of outputs.
- For layers with multiple learnable parameters, replace dLdW with dLdW1,...,dLdWP, where P is the number of learnable parameters.
- For layers with multiple state parameters, replace dLdSin and dLdSout with dLdSin1,...,dLdSinK and dLdSout1,...,dLdSoutK, respectively, where K is the number of state parameters.

To reduce memory usage by preventing unused variables from being saved between the forward and backward pass, replace the corresponding input arguments with ~.
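To make the replacements concrete, here is a hedged skeleton (not from this example) of a backward function for a hypothetical layer with two inputs and two scalar learnable parameters, W1 and W2, whose forward operation is Y = W1.*X1 + W2.*X2. The unused Y and memory arguments are replaced with ~, as described above.

function [dLdX1,dLdX2,dLdW1,dLdW2] = backward(layer,X1,X2,~,dLdY,~)
    % Derivatives of the loss with respect to the two inputs.
    dLdX1 = dLdY .* layer.W1;
    dLdX2 = dLdY .* layer.W2;

    % Derivatives of the loss with respect to the two scalar learnable parameters.
    dLdW1 = sum(dLdY .* X1,"all");
    dLdW2 = sum(dLdY .* X2,"all");
end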
Tip

If the number of inputs to backward can vary, then use varargin instead of the input arguments after layer. In this case, varargin is a cell array of the inputs, where the first N elements correspond to the N layer inputs, the next M elements correspond to the M layer outputs, the next M elements correspond to the derivatives of the loss with respect to the M layer outputs, the next K elements correspond to the derivatives of the loss with respect to the K state outputs, and the last element corresponds to memory.

If the number of outputs can vary, then use varargout instead of the output arguments. In this case, varargout is a cell array of the outputs, where the first N elements correspond to the derivatives of the loss with respect to the N layer inputs, the next P elements correspond to the derivatives of the loss with respect to the P learnable parameters, and the next K elements correspond to the derivatives of the loss with respect to the K state inputs.
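As a hedged illustration of the varargin form (not from this example), consider a hypothetical addition layer that computes Y = X1 + ... + XN from a variable number of inputs and has no learnable or state parameters. Its backward function can unpack varargin by position as described in the tip.

function varargout = backward(layer,varargin)
    % varargin = {X1,...,XN, Y, dLdY, memory} for a layer with one output,
    % no state parameters, and N = layer.NumInputs inputs.
    N = layer.NumInputs;
    dLdY = varargin{N+2};

    % For a sum, the derivative of the loss with respect to each input is dLdY.
    varargout = cell(1,N);
    for k = 1:N
        varargout{k} = dLdY;
    end
end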
Because a SReLU layer has only one input, one output, and four learnable parameters, and does not require the outputs of the layer forward function or a memory value, the syntax for backward for a SReLU layer is [dLdX,dLdLS,dLdRS,dLdLT,dLdRT] = backward(layer,X,~,dLdY,~). The dimensions of X are the same as in the forward function. The dimensions of dLdY are the same as the dimensions of the output Y of the forward function. The dimensions and data type of dLdX are the same as the dimensions and data type of X. The dimensions and data types of dLdLS, dLdRS, dLdLT, and dLdRT are the same as the dimensions and data types of the learnable parameters LeftSlope, RightSlope, LeftThreshold, and RightThreshold, respectively.
During training, the software automatically updates the learnable parameters using the corresponding derivatives.
To include a custom layer in a network, the layer forward functions must accept the outputs of the previous layer and forward propagate arrays with the size expected by the next layer. Similarly, when the layer specifies a backward function, it must accept inputs with the same size as the corresponding output of the forward function and backward propagate derivatives with the same size.
The derivative of the loss with respect to the input data is

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial f(x_i)}\,\frac{\partial f(x_i)}{\partial x_i},$$

where $\partial L/\partial f(x_i)$ is the gradient propagated from the next layer, and the derivative of the activation is

$$\frac{\partial f(x_i)}{\partial x_i} = \begin{cases} a_i^l & x_i \le t_i^l \\ 1 & t_i^l < x_i < t_i^r \\ a_i^r & x_i \ge t_i^r. \end{cases}$$

The derivative of the loss with respect to the learnable parameter $t_i^l$ is

$$\frac{\partial L}{\partial t_i^l} = \sum_j \frac{\partial L}{\partial f(x_{ij})}\,\frac{\partial f(x_{ij})}{\partial t_i^l},$$

where $i$ indexes the channels, $j$ indexes the remaining elements, and the gradient of the activation is

$$\frac{\partial f(x_{ij})}{\partial t_i^l} = \begin{cases} 1 - a_i^l & x_{ij} \le t_i^l \\ 0 & \text{otherwise.} \end{cases}$$

Similarly, for the other learnable parameters, the gradients are

$$\frac{\partial L}{\partial t_i^r} = \sum_{j:\,x_{ij} \ge t_i^r} \frac{\partial L}{\partial f(x_{ij})}\,(1 - a_i^r), \qquad \frac{\partial L}{\partial a_i^l} = \sum_{j:\,x_{ij} \le t_i^l} \frac{\partial L}{\partial f(x_{ij})}\,(x_{ij} - t_i^l), \qquad \frac{\partial L}{\partial a_i^r} = \sum_{j:\,x_{ij} \ge t_i^r} \frac{\partial L}{\partial f(x_{ij})}\,(x_{ij} - t_i^r).$$
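A quick finite-difference check of the first of these derivatives (with the same illustrative parameter values as before) confirms that the slope of the operation to the left of the left threshold is al:

% Central difference of the SReLU operation at x = -3 with tl = -1, al = 0.1,
% tr = 1, ar = 2. The result should be approximately al = 0.1.
f = @(x) (x <= -1).*(-1 + 0.1*(x + 1)) + (x > -1 & x < 1).*x + (x >= 1).*(1 + 2*(x - 1));
h = 1e-6;
(f(-3 + h) - f(-3 - h)) / (2*h)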
Create the backward function that returns these derivatives.
function [dLdX,dLdLS,dLdRS,dLdLT,dLdRT] = backward(layer,X,~,dLdY,~)
% [dLdX,dLdLS,dLdRS,dLdLT,dLdRT] = backward(layer,X,~,dLdY,~)
% backward propagates the derivative of the loss function
% through the layer.
% Inputs:
% layer - Layer to backward propagate through
% X - Input data
% dLdY - Gradient propagated from the deeper layer
% Outputs:
% dLdX - Derivative of the loss with respect to the
% input data
% dLdLS - Derivative of the loss with respect to the
% learnable parameter LeftSlope
% dLdRS - Derivative of the loss with respect to the
% learnable parameter RightSlope
% dLdLT - Derivative of the loss with respect to the
% learnable parameter LeftThreshold
% dLdRT - Derivative of the loss with respect to the
% learnable parameter RightThreshold
ndims = numel(dims(X));
idxC = finddim(X,"C");
X = stripdims(X);
dLdY = stripdims(dLdY);
tl = layer.LeftThreshold;
al = layer.LeftSlope;
tr = layer.RightThreshold;
ar = layer.RightSlope;
dYdX = (X <= tl) .* al ...
+ (X > tl & X < tr) ...
+ (X >= tr) .* ar;
dLdX = dLdY .* dYdX;
idx = setdiff(1:ndims,idxC);
dLdLT = sum(dLdY .* (X <= tl) .* (1 - al),idx);
dLdRT = sum(dLdY .* (tr <= X) .* (1 - ar),idx);
dLdLS = sum(dLdY .* (X <= tl) .* (X - tl),idx);
dLdRS = sum(dLdY .* (tr <= X) .* (X - tr),idx);
end
Complete Layer
View the completed layer class file.
classdef sreluLayer < nnet.layer.Layer
% Example custom SReLU layer.
properties (Learnable)
% Layer learnable parameters
LeftSlope
RightSlope
LeftThreshold
RightThreshold
end
methods
function layer = sreluLayer(args)
% layer = sreluLayer creates a SReLU layer.
%
% layer = sreluLayer(Name=name) also specifies the
% layer name.
arguments
args.Name = "";
end
% Set layer name.
layer.Name = args.Name;
% Set layer description.
layer.Description = "SReLU";
end
function layer = initialize(layer,layout)
% layer = initialize(layer,layout) initializes the layer
% learnable parameters using the specified input layout.
% Find number of channels.
idx = finddim(layout,"C");
numChannels = layout.Size(idx);
% Initialize empty learnable parameters.
sz = ones(1,numel(layout.Size));
sz(idx) = numChannels;
if isempty(layer.LeftSlope)
layer.LeftSlope = rand(sz);
end
if isempty(layer.RightSlope)
layer.RightSlope = rand(sz);
end
if isempty(layer.LeftThreshold)
layer.LeftThreshold = rand(sz);
end
if isempty(layer.RightThreshold)
layer.RightThreshold = rand(sz);
end
end
function Y = predict(layer, X)
% Y = predict(layer, X) forwards the input data X through the
% layer and outputs the result Y.
tl = layer.LeftThreshold;
al = layer.LeftSlope;
tr = layer.RightThreshold;
ar = layer.RightSlope;
Y = (X <= tl) .* (tl + al.*(X-tl)) ...
+ ((tl < X) & (X < tr)) .* X ...
+ (tr <= X) .* (tr + ar.*(X-tr));
end
function [dLdX,dLdLS,dLdRS,dLdLT,dLdRT] = backward(layer,X,~,dLdY,~)
% [dLdX,dLdLS,dLdRS,dLdLT,dLdRT] = backward(layer,X,~,dLdY,~)
% backward propagates the derivative of the loss function
% through the layer.
% Inputs:
% layer - Layer to backward propagate through
% X - Input data
% dLdY - Gradient propagated from the deeper layer
% Outputs:
% dLdX - Derivative of the loss with respect to the
% input data
% dLdLS - Derivative of the loss with respect to the
% learnable parameter LeftSlope
% dLdRS - Derivative of the loss with respect to the
% learnable parameter RightSlope
% dLdLT - Derivative of the loss with respect to the
% learnable parameter LeftThreshold
% dLdRT - Derivative of the loss with respect to the
% learnable parameter RightThreshold
ndims = numel(dims(X));
idxC = finddim(X,"C");
X = stripdims(X);
dLdY = stripdims(dLdY);
tl = layer.LeftThreshold;
al = layer.LeftSlope;
tr = layer.RightThreshold;
ar = layer.RightSlope;
dYdX = (X <= tl) .* al ...
+ (X > tl & X < tr) ...
+ (X >= tr) .* ar;
dLdX = dLdY .* dYdX;
idx = setdiff(1:ndims,idxC);
dLdLT = sum(dLdY .* (X <= tl) .* (1 - al),idx);
dLdRT = sum(dLdY .* (tr <= X) .* (1 - ar),idx);
dLdLS = sum(dLdY .* (X <= tl) .* (X - tl),idx);
dLdRS = sum(dLdY .* (tr <= X) .* (X - tr),idx);
end
end
end
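After completing the class, you can validate it, including checking that the backward function is consistent with numerically computed gradients, using checkLayer. This sketch assumes an image-like input of size 24-by-24 with 20 channels; the input size and observation dimension are illustrative.

layer = sreluLayer;
validInputSize = [24 24 20];
checkLayer(layer,validInputSize,ObservationDimension=4)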
GPU Compatibility
If the layer forward functions fully support dlarray objects, then the layer is GPU compatible. Otherwise, to be GPU compatible, the layer functions must support inputs and return outputs of type gpuArray (Parallel Computing Toolbox).

Many MATLAB built-in functions support gpuArray (Parallel Computing Toolbox) and dlarray input arguments. For a list of functions that support dlarray objects, see List of Functions with dlarray Support. For a list of functions that execute on a GPU, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox). To use a GPU for deep learning, you must also have a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). For more information on working with GPUs in MATLAB, see GPU Computing in MATLAB (Parallel Computing Toolbox).
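For example, this sketch (assuming Parallel Computing Toolbox and a supported GPU are available, and using illustrative input sizes) runs the layer forward pass on the GPU by placing the input data in a gpuArray.

layer = sreluLayer;
layer = initialize(layer,networkDataLayout([28 28 3 NaN],"SSCB"));

% Move the input data to the GPU, then call predict as usual.
X = dlarray(gpuArray(rand(28,28,3,16,"single")),"SSCB");
Y = predict(layer,X);
class(extractdata(Y))   % expected: gpuArray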
See Also
trainnet | trainingOptions | dlnetwork | functionLayer | checkLayer | setLearnRateFactor | setL2Factor | getLearnRateFactor | getL2Factor | findPlaceholderLayers | replaceLayer | PlaceholderLayer | networkDataLayout
Related Topics
- Define Custom Deep Learning Layers
- Define Custom Deep Learning Layer with Learnable Parameters
- Define Custom Deep Learning Layer with Multiple Inputs
- Define Custom Deep Learning Layer with Formatted Inputs
- Define Custom Recurrent Deep Learning Layer
- Define Custom Deep Learning Layer for Code Generation
- Define Nested Deep Learning Layer Using Network Composition
- Check Custom Layer Validity