Deep Learning Function Acceleration for Custom Training Loops - MATLAB & Simulink

When using the dlfeval function in a custom training loop, the software traces each input dlarray object of the model loss function to determine the computation graph used for automatic differentiation. This tracing process can take some time and can end up recomputing the same trace. By optimizing, caching, and reusing the traces, you can speed up gradient computation in deep learning functions. You can also optimize, cache, and reuse traces to accelerate other deep learning functions that do not require automatic differentiation, for example, model functions and functions used for prediction.

To speed up calls to deep learning functions, you can use the dlaccelerate function to create an AcceleratedFunction object that automatically optimizes, caches, and reuses the traces. You can use the dlaccelerate function to accelerate model functions and model loss functions directly.

The returned AcceleratedFunction object caches the traces of calls to the underlying function and reuses the cached result when the same input pattern reoccurs.

Try using dlaccelerate for deep learning function calls that you evaluate many times with the same input pattern, such as the model loss function inside a training loop.

Invoke the accelerated function as you would invoke the underlying function. Note that the accelerated function is not a function handle.

Note

When using the dlfeval function, the software automatically accelerates the forward and predict functions for dlnetwork input. If you accelerate a deep learning function where the majority of the computation takes place in calls to the forward or predict functions for dlnetwork input, then you might not see an improvement in training time.

You can nest and recursively call accelerated functions. However, it is usually more efficient to have a single accelerated function.

Accelerate Deep Learning Function Directly

In most cases, you can accelerate deep learning functions directly. For example, you can accelerate the model loss function by replacing calls to it with calls to the corresponding accelerated function.

Consider the following use of the dlfeval function in a custom training loop.

[loss,gradients,state] = dlfeval(@modelLoss,parameters,X,T,state)

To accelerate the model loss function and evaluate the accelerated function, use the dlaccelerate function and evaluate the returned AcceleratedFunction object:

accfun = dlaccelerate(@modelLoss);
[loss,gradients,state] = dlfeval(accfun,parameters,X,T,state)

Because the cached traces are not directly attached to the AcceleratedFunction object and are instead shared between AcceleratedFunction objects that use the same underlying function, you can create the AcceleratedFunction object either inside or before the custom training loop body.
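For example, you can create the accelerated function once before the loop and reuse it for every iteration. This sketch assumes a simple loop structure; the variable names (numEpochs and the mini-batch variables X and T) are illustrative.

accfun = dlaccelerate(@modelLoss);

for epoch = 1:numEpochs
    % Loop over mini-batches, producing X and T for each iteration.
    [loss,gradients,state] = dlfeval(accfun,parameters,X,T,state);
    % Update the learnable parameters using the gradients.
end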

Accelerate Parts of Deep Learning Function

If a deep learning function does not fully support acceleration, for example, functions that require an if statement with a condition that depends on the value of a dlarray object, then you can accelerate parts of a deep learning function by creating a separate function that contains the supported function calls you want to accelerate.

For example, consider the following code snippet that calls different functions depending on whether the sum of the dlarray object X is negative or nonnegative.

if sum(X,"all") < 0 Y = negFun1(parameters,X); Y = negFun2(parameters,Y); else Y = posFun1(parameters,X); Y = posFun2(parameters,Y); end

Because the if statement depends on the value of a dlarray object, a function that contains this code snippet does not support acceleration. However, if the blocks of code used inside the body of the if statement support acceleration, then you can accelerate these parts separately by creating a new function containing those blocks and accelerating the new functions instead.

For example, create the functions negFunAll and posFunAll that contain the blocks of code used in the body of the if statement.

function Y = negFunAll(parameters,X)
    Y = negFun1(parameters,X);
    Y = negFun2(parameters,Y);
end

function Y = posFunAll(parameters,X)
    Y = posFun1(parameters,X);
    Y = posFun2(parameters,Y);
end

Then, accelerate these functions and use them in the body of the if statement instead.

accfunNeg = dlaccelerate(@negFunAll);
accfunPos = dlaccelerate(@posFunAll);

if sum(X,"all") < 0 Y = accfunNeg(parameters,X); else Y = accfunPos(parameters,X); end

Reusing Caches

Reusing a cached trace depends on the function inputs and outputs. For dlarray object inputs, the trace depends on properties such as the size and format of the dlarray; dlarray inputs that differ only in value reuse the cached trace. For other types of input, the trace depends on the value of the input, so previously unseen values trigger new traces.

When necessary, the software caches any new traces by evaluating the underlying function and caching the resulting trace in the AcceleratedFunction object.
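For example, this sketch shows how an accelerated model loss function reuses or creates traces for different inputs. The variables X1, X2, and X3 are illustrative: X1 and X2 are dlarray objects with the same size and format but different values, and X3 has a different size.

accfun = dlaccelerate(@modelLoss);

[loss,gradients,state] = dlfeval(accfun,parameters,X1,T,state); % creates a new trace
[loss,gradients,state] = dlfeval(accfun,parameters,X2,T,state); % reuses the cached trace
[loss,gradients,state] = dlfeval(accfun,parameters,X3,T,state); % new size, creates a new trace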

Caution

An AcceleratedFunction object is not aware of updates to the underlying function. If you modify the function associated with the accelerated function, then clear the cache using the clearCache object function or, alternatively, use the command clear functions.
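For example, after editing the file that defines the underlying function, clear the stale traces before calling the accelerated function again.

% Clear the cache of the accelerated function after modifying modelLoss.
clearCache(accfun)

% Alternatively, clear the caches of all accelerated functions in the session.
clear functions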

Storing and Clearing Caches

AcceleratedFunction objects store the cache in a queue; when the cache is full, the software discards older cached traces to make room for new ones.

The AcceleratedFunction objects do not directly hold the cache. This means that accelerated functions that have the same underlying function share the same cache.

To clear the cache of an accelerated function, use the clearCache object function. Alternatively, you can clear all functions in the current MATLABĀ® session using the commands clear functions or clear all.

Note

Clearing the AcceleratedFunction variable does not clear the cache associated with the input function. To clear the cache for an AcceleratedFunction object that no longer exists in the workspace, create a new AcceleratedFunction object for the same function, and use the clearCache function on the new object. Alternatively, you can clear all functions in the current MATLAB session using the commands clear functions or clear all.
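For example, this sketch clears the cache left behind by a previously deleted accelerated function of modelLoss by creating a new AcceleratedFunction object for the same underlying function.

% Accelerated functions with the same underlying function share a cache,
% so clearing this new object also clears the leftover cache.
accfun = dlaccelerate(@modelLoss);
clearCache(accfun)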

Acceleration Considerations

Because of the nature of caching traces, not all functions support acceleration.

The caching process can cache values that you might expect to change or that depend on external factors. You must take care when you accelerate functions that take random or frequently changing values as input, generate random numbers, use if statements or while loops with conditions that depend on the values of dlarray objects, or take objects that depend on handles as input. The following sections describe these scenarios in more detail.

Because the caching process requires extra computation, acceleration can lead to longer running code in some cases. This scenario can happen when the software spends time creating new caches that do not get reused often. For example, when you pass multiple mini-batches of different sequence lengths to the function, the software triggers a new trace for each unique sequence length.

Accelerated functions perform some operations, such as generating random values that are not dlarray objects, only when calculating a new trace; these operations do not run again when the software reuses a cached trace.

When you use accelerated functions in parallel, such as inside a parfor loop, each worker maintains its own cache. The cache is not transferred to the host.

Functions and custom layers used in accelerated functions must also support acceleration.

Function Inputs with Random or Frequently Changing Values

You must take care when you accelerate functions that take random or frequently changing values as input, such as a model loss function that takes random noise as input and adds it to the input data. If any random or frequently changing inputs to an accelerated function are not dlarray objects, then the function triggers a new trace for each previously unseen value.

You can check for scenarios like this by inspecting the Occupancy and HitRate properties of the AcceleratedFunction object. If the Occupancy property is high and the HitRate is low, then this can indicate that the AcceleratedFunction object creates many new traces that it does not reuse.
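For example, after running several training iterations, you can inspect these properties; the variable names below follow the earlier examples.

accfun = dlaccelerate(@modelLoss);
[loss,gradients,state] = dlfeval(accfun,parameters,X,T,state,noise);

occupancy = accfun.Occupancy   % percentage of the cache in use
hitRate = accfun.HitRate       % percentage of calls that reused a cached trace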

For dlarray object input, changes in value do not trigger new traces. To prevent frequently changing input from triggering new traces for each evaluation, refactor your code such that the random inputs are dlarray objects.

For example, consider the model loss function that accepts a random array of noise values:

function [loss,gradients,state] = modelLoss(parameters,X,T,state,noise)

X = X + noise;
[Y,state] = model(parameters,X,state);
loss = crossentropy(Y,T);
gradients = dlgradient(loss,parameters);

end

To accelerate this model loss function, convert the input noise to dlarray before evaluating the accelerated function. Because the modelLoss function also supports dlarray input for noise, you do not need to make changes to the function.

noise = dlarray(noise,"SSCB");
accfun = dlaccelerate(@modelLoss);
[loss,gradients,state] = dlfeval(accfun,parameters,X,T,state,noise);

Alternatively, you can accelerate the parts of the model loss function that do not require the random input.

Functions with Random Number Generation

You must take care when you accelerate functions that use random number generation, such as functions that generate random noise to add to the input. When the software caches the trace of a function that generates random numbers that are not dlarray objects, the software caches the resulting random samples in the trace. When reusing the trace, the accelerated function uses the cached random sample. The accelerated function does not generate new random values.

Random number generation using the "like" option of the rand function with a dlarray object supports acceleration. To use random number generation in an accelerated function, ensure that the function uses the rand function with the "like" option set to a traced dlarray object (a dlarray object that depends on an input dlarray object).

For example, consider the following model loss function.

function [loss,gradients,state] = modelLoss(parameters,X,T,state)

sz = size(X);
noise = rand(sz);
X = X + noise;

[Y,state] = model(parameters,X,state);
loss = crossentropy(Y,T);
gradients = dlgradient(loss,parameters);

end

To ensure that the rand function generates a new value for each evaluation, use the "like" option with the traced dlarray object X.

function [loss,gradients,state] = modelLoss(parameters,X,T,state)

sz = size(X);
noise = rand(sz,"like",X);
X = X + noise;

[Y,state] = model(parameters,X,state);
loss = crossentropy(Y,T);
gradients = dlgradient(loss,parameters);

end

Alternatively, you can accelerate the parts of the model loss function that do not require random number generation.

Using if Statements and while Loops

You must take care when you accelerate functions that use if statements and while loops. In particular, you can get unexpected results when you accelerate functions with if statements or while loops that yield different code paths for function inputs of the same size and format.

Accelerating functions with if statement or while loop conditions that depend on the values of the function input or on values from external sources (for example, results of random number generation) can lead to unexpected behavior. When the accelerated function caches a new trace, if the function contains an if statement or while loop, then the software caches the trace of the code path that the if statement or while loop condition selects for that particular trace. Because changes in the value of the dlarray input do not trigger a new trace, when reusing the trace with different values, the software uses the same cached trace (which contains the same cached code path) even when a difference in value should result in a different code path.

Usually, accelerating functions that contain if statements or while loops with conditions that do not depend on the values of the function input or external factors (for example, while loops that iterate over elements in an array) does not result in unexpected behavior. For example, because changes in the size of a dlarray input trigger a new trace, when reusing the trace with inputs of the same size, the cached code path for inputs of that size remains consistent, even when there are differences in values.

To avoid unexpected behavior from caching code paths of if statements, you can refactor your code so that it determines the correct result by combining the results of all branches and extracting the desired solution.

For example, consider this code.

if tf
    Y = funcA(X);
else
    Y = funcB(X);
end

To support acceleration, you can replace it with code of the following form.

Y = tf*funcA(X) + ~tf*funcB(X);

Alternatively, to avoid unnecessary multiply operations, you can also use this replacement.

Y = cat(3,funcA(X),funcB(X));
Y = Y(:,:,[tf ~tf]);

Note that these techniques can result in longer running code because they require executing the code used in both branches of the if statement.

To use if statements and while loops that depend on dlarray object values, accelerate the body of the if statement or while loop only.
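For example, this sketch accelerates only the body of a while loop whose condition depends on the value of a dlarray object. The helper function stepFun is hypothetical and stands for the supported computation performed on each iteration.

accfunStep = dlaccelerate(@stepFun);

Y = X;
% The loop condition depends on the value of the dlarray Y, so the loop
% itself is not accelerated. Only the body of the loop is accelerated.
while extractdata(sum(Y,"all")) > 1
    Y = accfunStep(parameters,Y);
end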

Function Inputs that Depend on Handles

You must take care when you accelerate functions that take objects that depend on handles as input, such as a minibatchqueue object that has a preprocessing function specified as a function handle. The AcceleratedFunction object throws an error when you evaluate the function with inputs that depend on handles.

Instead, you can accelerate the parts of the model loss function that do not require inputs that depend on handles.
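For example, rather than accelerating a function that takes a minibatchqueue object as input, you can read the mini-batches outside the accelerated call and accelerate the model loss function only. The variable mbq and the modelLoss signature below are illustrative.

accfun = dlaccelerate(@modelLoss);

while hasdata(mbq)
    [X,T] = next(mbq);
    [loss,gradients,state] = dlfeval(accfun,parameters,X,T,state);
end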

Debugging

You must take care when you debug accelerated functions. Cached traces do not support breakpoints. When you use accelerated functions, the software reaches breakpoints in the underlying function only during the tracing process.

To debug the code in the underlying function using breakpoints, disable the acceleration by setting the Enabled property to false.

To debug the cached traces, you can compare the outputs of the accelerated function with the outputs of the underlying function by setting the CheckMode property to "tolerance".
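For example, this sketch toggles the relevant properties of an accelerated model loss function during debugging.

accfun = dlaccelerate(@modelLoss);

% Disable acceleration so that breakpoints in modelLoss are always reached.
accfun.Enabled = false;

% Re-enable acceleration and compare cached traces against the underlying
% function within tolerance.
accfun.Enabled = true;
accfun.CheckMode = "tolerance";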

dlode45 Does Not Support Acceleration When GradientMode Is "direct"

The dlaccelerate function does not support accelerating the dlode45 function when the GradientMode option is "direct". To accelerate the code that calls the dlode45 function, set the GradientMode option to "adjoint" or accelerate parts of your code that do not call the dlode45 function with the GradientMode option set to "direct".

See Also

dlaccelerate | AcceleratedFunction | clearCache | dlarray | dlgradient | dlfeval