Specify Custom Operation Backward Function - MATLAB & Simulink

When you define a custom loss function, a custom layer forward function, or a deep learning model as a function, and the software does not provide the deep learning operation that you require for your task, you can define your own function using dlarray objects.

Most deep learning workflows use gradients to train the model. If your function uses only functions that support dlarray objects, then you can use those functions directly and the software determines the gradients automatically using automatic differentiation. For example, you can pass a dlarray object function such as crossentropy as a loss function to the trainnet function, or use dlarray object functions such as dlconv in custom layer functions. For a list of functions that support dlarray objects, see List of Functions with dlarray Support.
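For example, a minimal sketch (assuming an existing training datastore dsTrain and an initialized dlnetwork object net, both hypothetical here) of passing a custom loss built entirely from dlarray functions to trainnet; because the loss uses only functions that support dlarray objects, no backward function is needed:

lossFcn = @(Y,T) crossentropy(Y,T) + 1e-4*mean(Y.^2,"all");   % custom loss from dlarray functions

options = trainingOptions("adam",MaxEpochs=5);
net = trainnet(dsTrain,net,lossFcn,options);                   % dsTrain and net are assumed to exist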

If you want to use functions that do not support dlarray objects, or want to use a specific algorithm to compute the gradients, then you can define a custom deep learning operation as a differentiable function object.

To define a custom deep learning operation as a differentiable function, you can use the template provided in this example, which takes you through these steps:

1. Name the function.
2. Save the function in a class file.
3. Declare the operation properties and learnable parameters.
4. Create the constructor function.
5. Create the forward function.
6. Create the backward function.
7. Create an interface function that uses the operation in deep learning models.

When you define the functions, you can use dlarray objects. Using dlarray objects makes working with high-dimensional data easier by allowing you to label the dimensions. For example, you can label which dimensions correspond to spatial, time, channel, and batch dimensions using the "S", "T", "C", and "B" labels, respectively. For unspecified and other dimensions, use the "U" label. For dlarray object functions that operate over particular dimensions, you can specify the dimension labels by formatting the dlarray object directly, or by using the DataFormat option.
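For example, this sketch shows both approaches: labeling a batch of 2-D image data directly, and leaving the data unformatted and passing the labels per call with the DataFormat option (the sizes, and the use of avgpool, are illustrative):

X = dlarray(rand(28,28,3,16),"SSCB");   % two spatial, one channel, one batch dimension
dims(X)                                 % returns 'SSCB'

Xu = dlarray(rand(28,28,3,16));         % unformatted dlarray
Y = avgpool(Xu,2,DataFormat="SSCB");    % specify the labels per call instead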

This example shows how to create an SReLU function, which is an operation with five inputs. An SReLU operation performs a thresholding operation, where, for each channel, the operation scales values that lie outside an interval. The interval thresholds and scaling factors are learnable parameters [1].

The SReLU operation is given by

$$f(x_i) = \begin{cases} t_l^i + a_l^i\,(x_i - t_l^i), & x_i \le t_l^i \\ x_i, & t_l^i < x_i < t_r^i \\ t_r^i + a_r^i\,(x_i - t_r^i), & t_r^i \le x_i, \end{cases}$$

where $x_i$ is the input on channel $i$, $t_l^i$ and $t_r^i$ are the left and right thresholds on channel $i$, respectively, and $a_l^i$ and $a_r^i$ are the left and right scaling factors on channel $i$, respectively.

Custom Operation Template

To define a custom deep learning operation, use this class definition template. This template gives the structure of a custom operation class definition. It outlines the optional properties block, the constructor function, and the forward and backward functions.

classdef myFunction < deep.DifferentiableFunction

properties
    % (Optional) Operation properties.

    % Declare operation properties here.
end

methods
    function fcn = myFunction
        % Create a myFunction. 
        % This function must have the same name as the class.

        fcn@deep.DifferentiableFunction(numOutputs, ...
            SaveInputsForBackward=tf, ...
            SaveOutputsForBackward=tf, ...
            NumMemoryValues=K);
    end

    function [Y,memory] = forward(fcn,X)
        % Forward input data through the function and output the result
        % and a memory value.
        %
        % Inputs:
        %         fcn - Function object to forward propagate through 
        %         X   - Function input data
        % Outputs:
        %         Y      - Output of function forward function 
        %         memory - (Optional) Memory value for backward
        %                  function
        %
        %  - For functions with multiple inputs, replace X with 
        %    X1,...,XN, where N is the number of inputs.
        %  - For functions with multiple outputs, replace Y with
        %    Y1,...,YM, where M is the number of outputs.
        %  - For functions with multiple memory outputs, replace
        %    memory with memory1,...,memoryK, where K is the
        %    number of memory outputs.

        % Define forward function here.
    end

    function dLdX = backward(fcn,dLdY,computeGradients,X,Y,memory)
        % Backward propagate the derivative of the loss function 
        % through the function.
        %
        % Inputs:
        %         fcn              - Function object to backward 
        %                            propagate through 
        %         dLdY             - Derivative of loss with respect to
        %                            function output
        %         computeGradients - Logical flag indicating whether to
        %                            compute gradients
        %         X                - (Optional) Function input data 
        %         Y                - (Optional) Function output data
        %         memory           - (Optional) Memory value from 
        %                            forward function
        % Outputs:
        %         dLdX   - Derivative of loss with respect to function
        %                  input 
        %
        %  - For functions with multiple inputs, replace X and dLdX 
        %    with X1,...,XN and dLdX1,...,dLdXN, respectively, where N 
        %    is the number of inputs. In this case, computeGradients is
        %    a logical vector of size N, where non-zero elements 
        %    indicate to compute gradients for the corresponding input.
        %  - For functions with multiple outputs, replace Y and dLdY 
        %    with Y1,...,YM and dLdY1,...,dLdYM, respectively, where M 
        %    is the number of outputs.

        % Define backward function here.
    end
end

end

Name Function

First, give the operation a name. In the first line of the class file, replace the existing name myFunction with sreluFunction.

classdef sreluFunction < deep.DifferentiableFunction

    ...

end

Next, rename the myFunction constructor function (the first function in the methods section) so that it has the same name as the class, and update the header comment.

methods
    function fcn = sreluFunction           
        % Create a sreluFunction. 

        ...
    end

    ...
end

Save the Function

Save the class file in a new file named sreluFunction.m. The file name must match the class name. To use the function, you must save the file in the current folder or in a folder on the MATLAB path.

Declare Properties and Learnable Parameters

Declare the operation properties in the properties section.

Tip

The forward and backward functions receive data specified as numeric arrays. For formatted dlarray workflows, you can create a property Format that stores the format of the input data. You can then use this property value in the forward and backward functions.

The SReLU backward operation requires the input data format information, so declare the property Format that stores the input data format.
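For the SReLU operation, the properties block declares this single property:

properties
    % Format of the operation input data, for example "SSCB".
    Format
end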

Create Constructor Function

Create the constructor function that constructs the function object and specifies the number of outputs. Specify any variables required to create the function as inputs to the constructor function.

To construct an instance of the object, use the command fcn@deep.DifferentiableFunction(numOutputs), where numOutputs is the number of outputs of the operation. This command instantiates the function object with the specified number of outputs using the constructor function of the superclass deep.DifferentiableFunction. The deep.DifferentiableFunction constructor has these additional optional name-value arguments:

* SaveInputsForBackward - Flag indicating whether to pass the operation input data to the backward function
* SaveOutputsForBackward - Flag indicating whether to pass the operation output data to the backward function
* NumMemoryValues - Number of memory values that the forward function outputs and the backward function receives

The SReLU operation has one output. The SReLU backward operation requires the input data format information, so specify the format as an input argument and store it in the Format property. The backward function also requires the operation input data, so set SaveInputsForBackward to true.

    function fcn = sreluFunction(format)
        % Create a sreluFunction.
        % 
        % fcn = sreluFunction(format) creates a sreluFunction object
        % that operates on data with the specified format.

        fcn@deep.DifferentiableFunction(1,SaveInputsForBackward=true); 
        fcn.Format = string(format);
    end
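For example, a short sketch of constructing the function object for data with two spatial dimensions, one channel dimension, and one batch dimension (the format string is illustrative):

fcn = sreluFunction("SSCB");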

Create Forward Function

Create the forward function of the operation, named forward, which propagates the data forward through the operation and outputs the result.

The forward function defines the deep learning forward pass. It has the syntax [Y,memory] = forward(fcn,X). The function has these inputs and outputs:

* fcn - Function object to forward propagate through
* X - Function input data
* Y - Output of the function forward pass
* memory - (Optional) Memory value for the backward function

You can adjust the syntax for operations with multiple inputs, outputs, and memory values:

* For operations with multiple inputs, replace X with X1,...,XN, where N is the number of inputs.
* For operations with multiple outputs, replace Y with Y1,...,YM, where M is the number of outputs.
* For operations with multiple memory values, replace memory with memory1,...,memoryK, where K is the number of memory values.

Tip

If the number of inputs to the operation can vary, then use varargin instead of X1,…,XN. In this case, varargin is a cell array of the inputs, where varargin{i} corresponds to Xi.

Because the SReLU operation has five inputs, one output, and no memory values, the syntax for forward for the SReLU operation is Y = forward(~,X,tl,al,tr,ar), where:

* X - Function input data
* tl - Left threshold
* al - Left slope (scaling factor)
* tr - Right threshold
* ar - Right slope (scaling factor)

Recall that the SReLU operation is given by

$$f(x_i) = \begin{cases} t_l^i + a_l^i\,(x_i - t_l^i), & x_i \le t_l^i \\ x_i, & t_l^i < x_i < t_r^i \\ t_r^i + a_r^i\,(x_i - t_r^i), & t_r^i \le x_i, \end{cases}$$

where $x_i$ is the input on channel $i$, $t_l^i$ and $t_r^i$ are the left and right thresholds on channel $i$, respectively, and $a_l^i$ and $a_r^i$ are the left and right scaling factors on channel $i$, respectively.

Implement this operation in forward. Add a comment to the top of the function that explains the syntax of the function.

    function Y = forward(~,X,tl,al,tr,ar)
        % Forward input data through the function and output the result.
        %
        % Inputs:
        %         X  - Function input data
        %         tl - Left threshold
        %         al - Left slope
        %         tr - Right threshold
        %         ar - Right slope
        % Outputs:
        %         Y - Output of function forward function

        Y = (X <= tl) .* (tl + al.*(X-tl)) ...
            + ((tl < X) & (X < tr)) .* X ...
            + (tr <= X) .* (tr + ar.*(X-tr));
    end
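As a quick numeric sketch of the thresholding behavior, apply the same formula to a plain vector with scalar thresholds and slopes (the values are illustrative):

X = [-2 -0.5 0.5 2];
tl = -1; al = 0.1; tr = 1; ar = 0.2;

Y = (X <= tl).*(tl + al.*(X-tl)) ...
    + ((tl < X) & (X < tr)).*X ...
    + (tr <= X).*(tr + ar.*(X-tr))
% Y = [-1.1000 -0.5000 0.5000 1.2000]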

Create Backward Function

Implement the backward function that returns the derivatives of the loss with respect to the input data.

The backward function defines the operation backward pass. It has the syntax dLdX = backward(fcn,dLdY,computeGradients,X,Y,memory). The function has these inputs and output:

* fcn - Function object to backward propagate through
* dLdY - Derivative of the loss with respect to the function output
* computeGradients - Logical flag indicating whether to compute gradients
* X - (Optional) Function input data
* Y - (Optional) Function output data
* memory - (Optional) Memory value from the forward function
* dLdX - Derivative of the loss with respect to the function input

The values of X and Y are the same as in the forward function. The dimensions of dLdY are the same as the dimensions of Y.

The dimensions and data type of dLdX are the same as the dimensions and data type of X.

You can adjust the syntax for operations with multiple inputs and multiple outputs:

* For operations with multiple inputs, replace X and dLdX with X1,...,XN and dLdX1,...,dLdXN, respectively, where N is the number of inputs. In this case, computeGradients is a logical vector of size N, where nonzero elements indicate to compute gradients for the corresponding input.
* For operations with multiple outputs, replace Y and dLdY with Y1,...,YM and dLdY1,...,dLdYM, respectively, where M is the number of outputs.

To reduce memory usage by preventing unused variables from being saved between the forward and backward passes, replace the corresponding input arguments with ~.

Tip

If the number of inputs to backward can vary, then use varargin instead of the input arguments after fcn. In this case, varargin is a cell array of the inputs.

Because the SReLU operation has five inputs, one output, requires the operation input data, and does not require the output of the operation forward function or a memory value, the syntax for backward for the SReLU operation is [dLdX,dLdtl,dLdal,dLdtr,dLdar] = backward(fcn,dLdY,computeGradients,X,tl,al,tr,ar). The gradient outputs appear in the same order as the corresponding operation inputs.

The values of X, tl, al, tr, and ar are the same as in the forward function. The dimensions of dLdY are the same as the dimensions of the output Y of the forward function. The dimensions and data types of dLdX, dLdtl, dLdal, dLdtr, and dLdar are the same as the dimensions and data types of X, tl, al, tr, and ar, respectively. The input computeGradients is a logical vector with five elements, where nonzero elements indicate to compute gradients for the corresponding input.

The derivative of the loss with respect to the input data is

$$\frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial f(x_i)}\,\frac{\partial f(x_i)}{\partial x_i},$$

where $\partial L/\partial f(x_i)$ is the gradient propagated from the next operation, and the derivative of the activation is

$$\frac{\partial f(x_i)}{\partial x_i} = \begin{cases} a_l^i, & x_i \le t_l^i \\ 1, & t_l^i < x_i < t_r^i \\ a_r^i, & t_r^i \le x_i. \end{cases}$$

The derivative of the loss with respect to the parameter $t_l^i$ is

$$\frac{\partial L}{\partial t_l^i} = \sum_j \frac{\partial L}{\partial f(x_{ij})}\,\frac{\partial f(x_{ij})}{\partial t_l^i},$$

where $i$ indexes the channels, $j$ indexes the remaining elements, and the gradient of the activation is

$$\frac{\partial f(x_{ij})}{\partial t_l^i} = \begin{cases} 1 - a_l^i, & x_{ij} \le t_l^i \\ 0, & \text{otherwise.} \end{cases}$$

Similarly, for the other parameters, the gradients of the activation are

$$\frac{\partial f(x_{ij})}{\partial t_r^i} = \begin{cases} 1 - a_r^i, & t_r^i \le x_{ij} \\ 0, & \text{otherwise,} \end{cases} \qquad
\frac{\partial f(x_{ij})}{\partial a_l^i} = \begin{cases} x_{ij} - t_l^i, & x_{ij} \le t_l^i \\ 0, & \text{otherwise,} \end{cases} \qquad
\frac{\partial f(x_{ij})}{\partial a_r^i} = \begin{cases} x_{ij} - t_r^i, & t_r^i \le x_{ij} \\ 0, & \text{otherwise.} \end{cases}$$

Create the backward function that returns these derivatives. For each input, compute the gradient only when the corresponding entry in the computeGradients argument is 1 (true). For the gradients where computeGradients is 0 (false), return [].

    function [dLdX,dLdtl,dLdal,dLdtr,dLdar] = backward(fcn,dLdY,computeGradients,X,tl,al,tr,ar)
        % Backward propagate the derivative of the loss function
        % through the function.
        %
        % Inputs:
        %         fcn              - Function object to backward
        %                            propagate through
        %         dLdY             - Derivative of loss with respect to
        %                            function output
        %         computeGradients - Logical vector indicating which
        %                            gradients to compute
        %         X                - Function input data
        %         tl               - Left threshold
        %         al               - Left slope
        %         tr               - Right threshold
        %         ar               - Right slope
        % Outputs:
        %         dLdX  - Derivative of loss with respect to function
        %                 input
        %         dLdtl - Derivative of loss with respect to left
        %                 threshold
        %         dLdal - Derivative of loss with respect to left 
        %                 slope
        %         dLdtr - Derivative of loss with respect to right
        %                 threshold
        %         dLdar - Derivative of loss with respect to right
        %                 slope

        % Identify the channel dimension and the dimensions to sum over
        % for the parameter gradients.
        numDims = strlength(fcn.Format);
        idxC = strfind(fcn.Format,"C");
        idx = setdiff(1:numDims,idxC);

        dLdX = [];
        if computeGradients(1)
            % Derivative of the activation with respect to the input.
            dYdX = (X <= tl).*al + ((tl < X) & (X < tr)) + (tr <= X).*ar;
            dLdX = dLdY .* dYdX;
        end

        dLdtl = [];
        if computeGradients(2)
            dLdtl = sum(dLdY .* (X <= tl) .* (1 - al),idx);
        end

        dLdal = [];
        if computeGradients(3)
            dLdal = sum(dLdY .* (X <= tl) .* (X - tl),idx);
        end

        dLdtr = [];
        if computeGradients(4)
            dLdtr = sum(dLdY .* (tr <= X) .* (1 - ar),idx);
        end

        dLdar = [];
        if computeGradients(5)
            dLdar = sum(dLdY .* (tr <= X) .* (X - tr),idx);
        end
    end

Completed Function

View the completed class file.

classdef sreluFunction < deep.DifferentiableFunction

properties
    Format
end

methods
    function fcn = sreluFunction(format)
        % Create a sreluFunction.
        % 
        % fcn = sreluFunction(format) creates a sreluFunction object
        % that operates on data with the specified format.

        fcn@deep.DifferentiableFunction(1,SaveInputsForBackward=true); 
        fcn.Format = string(format);
    end

    function Y = forward(~,X,tl,al,tr,ar)
        % Forward input data through the function and output the result.
        %
        % Inputs:
        %         X  - Function input data
        %         tl - Left threshold
        %         al - Left slope
        %         tr - Right threshold
        %         ar - Right slope
        % Outputs:
        %         Y - Output of function forward function

        Y = (X <= tl) .* (tl + al.*(X-tl)) ...
            + ((tl < X) & (X < tr)) .* X ...
            + (tr <= X) .* (tr + ar.*(X-tr));
    end

    function [dLdX,dLdtl,dLdal,dLdtr,dLdar] = backward(fcn,dLdY,computeGradients,X,tl,al,tr,ar)
        % Backward propagate the derivative of the loss function
        % through the function.
        %
        % Inputs:
        %         fcn              - Function object to backward
        %                            propagate through
        %         dLdY             - Derivative of loss with respect to
        %                            function output
        %         computeGradients - Logical vector indicating which
        %                            gradients to compute
        %         X                - Function input data
        %         tl               - Left threshold
        %         al               - Left slope
        %         tr               - Right threshold
        %         ar               - Right slope
        % Outputs:
        %         dLdX  - Derivative of loss with respect to function
        %                 input
        %         dLdtl - Derivative of loss with respect to left
        %                 threshold
        %         dLdal - Derivative of loss with respect to left 
        %                 slope
        %         dLdtr - Derivative of loss with respect to right
        %                 threshold
        %         dLdar - Derivative of loss with respect to right
        %                 slope

        % Identify the channel dimension and the dimensions to sum over
        % for the parameter gradients.
        numDims = strlength(fcn.Format);
        idxC = strfind(fcn.Format,"C");
        idx = setdiff(1:numDims,idxC);

        dLdX = [];
        if computeGradients(1)
            % Derivative of the activation with respect to the input.
            dYdX = (X <= tl).*al + ((tl < X) & (X < tr)) + (tr <= X).*ar;
            dLdX = dLdY .* dYdX;
        end

        dLdtl = [];
        if computeGradients(2)
            dLdtl = sum(dLdY .* (X <= tl) .* (1 - al),idx);
        end

        dLdal = [];
        if computeGradients(3)
            dLdal = sum(dLdY .* (X <= tl) .* (X - tl),idx);
        end

        dLdtr = [];
        if computeGradients(4)
            dLdtr = sum(dLdY .* (tr <= X) .* (1 - ar),idx);
        end

        dLdar = [];
        if computeGradients(5)
            dLdar = sum(dLdY .* (tr <= X) .* (X - tr),idx);
        end
    end
end

end

Create Interface Function

To use the function in a deep learning model, create a function that takes input data, creates and configures a differentiable function object, evaluates the operation, and returns the result.

Because the differentiable function object strips the dlarray formats from the input data, convert the output to a formatted dlarray.

Create the interface function for the differentiable function object.

function Y = srelu(X,tl,al,tr,ar)

format = dims(X);

fcn = sreluFunction(format);
Y = fcn(X,tl,al,tr,ar);

Y = dlarray(Y,format);

end
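As a quick check, this sketch calls the interface function inside dlfeval and takes a gradient through the custom backward function (the input sizes and parameter values are illustrative, and sreluLoss is a hypothetical helper defined below):

X  = dlarray(randn(8,8,3,4,"single"),"SSCB");   % random "SSCB" input
tl = dlarray(-ones(1,1,3,"single"));            % per-channel left thresholds
al = dlarray(0.1*ones(1,1,3,"single"));         % per-channel left slopes
tr = dlarray(ones(1,1,3,"single"));             % per-channel right thresholds
ar = dlarray(0.1*ones(1,1,3,"single"));         % per-channel right slopes

[loss,gradTl] = dlfeval(@sreluLoss,X,tl,al,tr,ar);

function [loss,gradTl] = sreluLoss(X,tl,al,tr,ar)
    % Scalar loss so that dlgradient can be called. The gradient with
    % respect to tl flows through the custom backward function.
    Y = srelu(X,tl,al,tr,ar);
    loss = sum(Y,"all");
    gradTl = dlgradient(loss,tl);
end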

For an example that shows how to train a deep learning model defined as a function that uses a custom SReLU operation with a custom backward function, see Train Model Using Custom Backward Function.

See Also

trainnet | trainingOptions | dlnetwork | functionLayer