Define Custom Deep Learning Metric Object

Note

This topic explains how to define custom deep learning metric objects for your tasks. For a list of built-in metrics in Deep Learning Toolbox™, see Metrics. You can also specify custom metrics using a function handle. For more information, see Define Custom Metric Function.

In deep learning, a metric is a numerical value that evaluates the performance of a deep learning network. You can use metrics to monitor how well a model is performing by comparing the model predictions to the ground truth. Common deep learning metrics are accuracy, F-score, precision, recall, and root mean squared error.

How To Decide Which Metric Type To Use

You can use metrics in different ways when you train a deep learning network. You can use built-in metrics by specifying a string and using the default options or by customizing the metric using a built-in metric object. You can also define custom metrics using a function handle or a custom class definition. When you select which method to use, you must trade off flexibility against complexity.

Another important consideration is how the software returns the metric value for the validation data. When you have validation data and specify your metric using a function handle, the software computes the validation metric for each mini-batch and then returns the average of those values. For some metrics, such as precision and F-score, this behavior can produce a different value from the one you get by computing the metric over the whole validation set at once. In most cases, the values are similar.
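To see why per-batch averaging can differ from a whole-set computation, consider precision, which is a ratio of counts. This sketch uses made-up true positive and false positive counts for two equal-size validation mini-batches and assumes a plain mean of the per-batch values; the variable names are illustrative only.

```matlab
% Hypothetical true positive (TP) and false positive (FP) counts for two
% equal-size validation mini-batches.
tp = [1 1];
fp = [0 3];

% Precision computed per mini-batch, then averaged.
batchPrecision = tp ./ (tp + fp);            % [1.00 0.25]
averagedPrecision = mean(batchPrecision);    % 0.625

% Precision computed once over the whole validation set.
wholeSetPrecision = sum(tp) / (sum(tp) + sum(fp));   % 2/5 = 0.4

fprintf("Averaged over mini-batches: %.3f\n",averagedPrecision)
fprintf("Whole validation set:       %.3f\n",wholeSetPrecision)
```

A custom metric object avoids this difference because it can accumulate counts across the validation mini-batches and compute the value once at the end.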

In most cases, you do not need to write a custom metric class. Use this flowchart to decide whether you need to create a custom metric class.

[Figure: Flowchart showing the decision process for choosing when to use a custom metric.]

If Deep Learning Toolbox does not provide the metric that you need for your task and you cannot use a function handle, then you can define your own custom metric object using this topic as a guide. After you define the custom metric, you can specify the metric as the Metrics name-value argument in the trainingOptions function.
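As a sketch, assuming you have completed the myMetric class template from the Metric Template section and saved it as myMetric.m on the MATLAB path, you might create an instance and pass it to trainingOptions together with a built-in metric. The "adam" solver, the cell array that mixes a metric name with a metric object, and the commented trainnet call with XTrain, TTrain, and net are illustrative assumptions; check the trainingOptions reference page for the forms your release accepts.

```matlab
% Assumption: myMetric.m, completed from the template, is on the MATLAB
% path and its constructor sets the required Name property.
metric = myMetric();

% Monitor the custom metric alongside the built-in accuracy metric.
options = trainingOptions("adam", ...
    Metrics={"accuracy",metric});

% Train as usual, for example with hypothetical training data:
% net = trainnet(XTrain,TTrain,net,"crossentropy",options);
```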

Metric Template

To define a custom metric, use this class definition template as a starting point. For an example that shows how to use this template to create a custom metric, see Define Custom Metric Object.

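The template outlines how to specify these aspects of the class definition: the metric properties, the constructor function, and the initialize (optional), reset, update, aggregate, and evaluate functions.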

For information about when the software calls each function, see Function Call Order.

classdef myMetric < deep.Metric

    properties
        % (Required) Metric name.
        Name

        % Declare public metric properties here.

        % Any code can access these properties. Include here any properties
        % that you want to access or edit outside of the class.
    end

    properties (Access = private)
        % (Optional) Metric properties.

        % Declare private metric properties here.

        % Only members of the defining class can access these properties.
        % Include here properties that you do not want to edit outside
        % the class.
    end

    methods
        function metric = myMetric(args)
            % Create a myMetric object.
            % This function must have the same name as the class.

            % Define metric construction function here.
        end

        function metric = initialize(metric,batchY,batchT)
            % (Optional) Initialize metric.
            %
            % Use this function to initialize variables and run validation
            % checks.
            %
            % Inputs:
            %           metric - Metric to initialize
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Initialized metric
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric initialization function here.
        end

        function metric = reset(metric)
            % Reset metric properties.
            %
            % Use this function to reset the metric properties between
            % iterations.
            %
            % Input:
            %           metric - Metric containing properties to reset
            %
            % Output:
            %           metric - Metric with reset properties

            % Define metric reset function here.
        end

        function metric = update(metric,batchY,batchT)
            % Update metric properties.
            %
            % Use this function to update metric properties that you use to
            % compute the final metric value.
            %
            % Inputs:
            %           metric - Metric containing properties to update
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Metric with updated properties
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric update function here.
        end

        function metric = aggregate(metric,metric2)
            % Aggregate metric properties.
            %
            % Use this function to define how to aggregate properties from
            % multiple instances of the same metric object during parallel
            % training.
            %
            % Inputs:
            %           metric  - Metric containing properties to aggregate
            %           metric2 - Metric containing properties to aggregate
            %
            % Output:
            %           metric - Metric with aggregated properties

            % Define metric aggregation function here.
        end

        function val = evaluate(metric)
            % Evaluate metric properties.
            %
            % Use this function to define how to use the metric properties
            % to compute the final metric value.
            %
            % Input:
            %           metric - Metric containing properties to use to
            %           evaluate the metric value
            %
            % Output:
            %           val - Evaluated metric value
            %
            % To return multiple metric values, replace val with
            % val1,...,valN.

            % Define metric evaluation function here.
        end
    end

end

Metric Properties

Declare the metric properties in the property sections. You can specify attributes in the class definition to customize the behavior of properties for specific purposes. This template defines two property types by setting their Access attribute. Use the Access attribute to control access to specific class properties.

Public Properties

Declare public metric properties in the properties section of the class definition. These properties have public access, which means any code can access the values. By default, custom metrics have the NetworkOutput public property with the default value [] and the Maximize public property with the default value []. The NetworkOutput property defines which network output to apply the metric to. The Maximize property is a flag that defines whether the optimal value for the metric occurs when the metric is maximized (1 or true) or when the metric is minimized (0 or false).

You must define the Name property in this block. The Name property controls the name of the metric in any plots or command-line output.

Private Properties

Declare private metric properties in the properties (Access = private) section of the class definition. These properties have private access, which means only members of the defining class can access these properties. For example, the class functions can access private properties. If the metric has no private properties, then you can omit this properties section.

Constructor Function

The constructor function creates the metric and initializes the metric properties. The constructor function must take as input any variables that you need to compute the metric. This function must have the same name as the class.

To use any properties as name-value arguments, you must set them in the constructor function. Every metric must accept the optional Name name-value argument.

Tip

To use the NetworkOutput property as a name-value argument, you must set the property in the constructor function.

Initialization Function

The initialize function is an optional function that the software calls after reading the first batch of data. You can use this function to initialize variables and run validation checks.

The initialize function must have this syntax, where the batchY and batchT inputs represent the mini-batch predictions and targets, respectively. For networks with multiple outputs, replace batchY with batchY1,...,batchYN and batchT with batchT1,...,batchTN, where N is the number of network outputs. To create a metric that supports any number of network outputs, replace batchY and batchT with varargin.

metric = initialize(metric,batchY,batchT)

Example initialize Function

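This code shows an example of an initialize function that checks that you are using the metric for a network with a single output, and therefore with only one set of mini-batch predictions and targets. The function receives the predictions and targets through varargin, so a call from a multi-output network reaches the check instead of failing with a generic error about too many input arguments.

    function metric = initialize(metric,varargin)
        % varargin holds batchY1,...,batchYN followed by batchT1,...,batchTN.
        % A single-output network passes exactly one of each.
        if numel(varargin) ~= 2
            error("Metric not supported for networks with multiple outputs.")
        end
    end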

Reset Function

The reset function resets the metric properties. The software calls this function before each iteration. For more information, see Function Call Order.

The reset function must have this syntax.

metric = reset(metric)

Update Function

The update function updates the metric properties that you use to compute the metric value. The software calls update during each training and validation mini-batch. For more information, see Function Call Order.

The update function must have this syntax, where the batchY and batchT inputs represent the mini-batch predictions and targets, respectively. For networks with multiple outputs, replace batchY with batchY1,...,batchYN and batchT with batchT1,...,batchTN, where N is the number of network outputs. To create a metric that supports any number of network outputs, replace batchY and batchT with varargin.

metric = update(metric,batchY,batchT)

For categorical targets, the layout of the targets that the software passes to the metric depends on which function you want to use the metric with.

Aggregation Function

The aggregate function specifies how to combine properties from multiple instances of the same metric object during parallel training. When you train a network in parallel, the software divides each training mini-batch into smaller subsets. For each subset, the software then calls update to update the metric properties, and then calls aggregate to consolidate the results for the whole mini-batch. For more information, see Function Call Order.

The aggregate function must have this syntax, where the metric2 input is another instance of the metric. To ensure that your function produces the same result regardless of how the software groups the partial results, make sure that aggregate is an associative function.

metric = aggregate(metric,metric2)
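A rule that adds running counts is associative, so the combined result does not depend on how partial results are grouped, while a rule that averages per-subset ratios is not. This standalone sketch uses plain structures instead of metric objects; the field names TruePositives and FalsePositives and the helper addCounts are hypothetical. Because it ends with a local function, save it as a script file.

```matlab
% Hypothetical partial counts produced by three workers.
a = struct('TruePositives',3,'FalsePositives',1);
b = struct('TruePositives',0,'FalsePositives',2);
c = struct('TruePositives',5,'FalsePositives',0);

% Summing counts is associative: both groupings give the same totals
% (8 true positives and 3 false positives).
left  = addCounts(addCounts(a,b),c);
right = addCounts(a,addCounts(b,c));
disp(isequal(left,right))                % 1

% Averaging per-worker precision values is not associative: the result
% depends on the grouping.
precision = @(s) s.TruePositives/(s.TruePositives + s.FalsePositives);
avgLeft  = mean([mean([precision(a) precision(b)]) precision(c)]);  % 0.6875
avgRight = mean([precision(a) mean([precision(b) precision(c)])]);  % 0.6250

function s = addCounts(s1,s2)
    % Candidate aggregate rule: add the counts field by field.
    s.TruePositives  = s1.TruePositives  + s2.TruePositives;
    s.FalsePositives = s1.FalsePositives + s2.FalsePositives;
end
```

In a metric class, the counting rule corresponds to an aggregate method that adds the count properties of metric2 to those of metric.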

Evaluation Function

The evaluate function specifies how to compute the metric value. In most cases, the final metric value is a function of the metric properties.

For the training data, the software calls evaluate at the end of each mini-batch. For the validation data, the software calls evaluate once, after all of the validation data passes through the network. Therefore, the software reports a training metric value for each mini-batch, but a single validation metric value computed over all of the validation data. For more information, see Function Call Order.

The evaluate function must have this syntax, where M is the number of metrics to return.

[val1,...,valM] = evaluate(metric)
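To see how the functions fit together, here is a sketch of a complete class for a hypothetical mean absolute error metric, maeMetric, built from the template. It keeps running totals in private properties so that update and aggregate only add numbers, which keeps aggregate associative, and evaluate divides once at the end. The class name, the property names, and the default name "MAE" are assumptions for illustration, and the code assumes predictions and targets of matching size. Save it as maeMetric.m.

```matlab
classdef maeMetric < deep.Metric
    % Hypothetical mean absolute error (MAE) metric built from the template.

    properties
        % (Required) Metric name.
        Name
    end

    properties (Access = private)
        % Running totals accumulated across mini-batches.
        SumAbsError
        NumElements
    end

    methods
        function metric = maeMetric(args)
            % Create a maeMetric object with optional name-value arguments.
            arguments
                args.Name = "MAE"
                args.NetworkOutput = []
            end
            metric.Name = args.Name;
            % Setting NetworkOutput here lets callers pass it as a
            % name-value argument.
            metric.NetworkOutput = args.NetworkOutput;
            % Lower error is better, so the optimal value is a minimum.
            metric.Maximize = false;
        end

        function metric = reset(metric)
            % Clear the running totals.
            metric.SumAbsError = 0;
            metric.NumElements = 0;
        end

        function metric = update(metric,batchY,batchT)
            % Accumulate the absolute errors for this mini-batch.
            absErr = abs(batchY - batchT);
            metric.SumAbsError = metric.SumAbsError + sum(absErr,"all");
            metric.NumElements = metric.NumElements + numel(absErr);
        end

        function metric = aggregate(metric,metric2)
            % Add the totals from another instance (an associative rule).
            metric.SumAbsError = metric.SumAbsError + metric2.SumAbsError;
            metric.NumElements = metric.NumElements + metric2.NumElements;
        end

        function val = evaluate(metric)
            % Final value: mean of all accumulated absolute errors.
            val = metric.SumAbsError / metric.NumElements;
        end
    end
end
```

For a quick check outside of training, you can call the methods directly on numeric arrays:

```matlab
m = reset(maeMetric);
m = update(m,[1 2 3],[1 1 1]);   % absolute errors 0, 1, and 2
val = evaluate(m)                % 1
```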

Function Call Order

The order in which the software calls the initialize, reset, update, aggregate, and evaluate functions depends on where in the training loop the software is. The first function the software calls is initialize. The software calls initialize after it reads the first batch of data.

The order in which the software calls the remaining functions depends on whether the data is training or validation data.

This diagram illustrates the difference between how the software computes the metric for the training and validation data.

[Figure: Diagram showing the difference between batching for training and validation data. For the training data, the software calls initialize in the first batch and calls reset, update, and evaluate in every batch. For the validation data, the software calls reset before the first batch, update in each batch, and evaluate after the last batch.]
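The call sequence described for the diagram can be mimicked with a short script that only records function names. This is a hypothetical illustration of the documented order for three training mini-batches and two validation mini-batches, not the toolbox's internal code; parallel training, where aggregate also runs, is left out.

```matlab
% Record the documented call order as lists of function names.
numTrainBatches = 3;
numValBatches = 2;

% Training data: initialize first, then reset, update, and evaluate for
% every mini-batch, giving one metric value per mini-batch.
trainCalls = "initialize";
for i = 1:numTrainBatches
    trainCalls = [trainCalls "reset" "update" "evaluate"]; %#ok<AGROW>
end

% Validation data: reset once, update for every mini-batch, then evaluate
% once after the last mini-batch, giving a single metric value.
valCalls = "reset";
for i = 1:numValBatches
    valCalls = [valCalls "update"]; %#ok<AGROW>
end
valCalls = [valCalls "evaluate"];

disp(join(trainCalls," -> "))
disp(join(valCalls," -> "))
```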

Note

When you train a network using the L-BFGS solver, the software processes all of the data in a single batch. This behavior is equivalent to a single mini-batch with all of the observations.

Aggregate Data

The aggregate function defines how to aggregate properties from multiple instances of the same metric object during parallel training. When you train a network in parallel, the software divides each training mini-batch into smaller subsets. For each subset, the software then calls update to update the metric properties, and then calls aggregate to consolidate the results for the whole mini-batch. Finally, the software calls evaluate to obtain the metric value for the whole training mini-batch.
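The following plain-MATLAB sketch mirrors that flow for a mean absolute error computed from running totals. It splits one hypothetical mini-batch into two subsets, accumulates totals per subset as update would, combines them as aggregate would, and evaluates once. Because the totals are sums, the result matches evaluating the whole mini-batch at once. The data, the split, and the struct fields are assumptions for illustration.

```matlab
% One hypothetical training mini-batch of predictions and targets.
Y = [2.0 0.5 3.0 1.0 4.0 2.5];
T = [1.0 0.5 2.0 2.0 4.0 1.5];

% Divide the mini-batch into smaller subsets.
subsets = {1:3, 4:6};

% update: accumulate totals separately for each subset.
partial = cell(size(subsets));
for k = 1:numel(subsets)
    idx = subsets{k};
    partial{k} = struct('SumAbsError',sum(abs(Y(idx) - T(idx))), ...
                        'NumElements',numel(idx));
end

% aggregate: combine the partial totals by summing them.
total = partial{1};
for k = 2:numel(partial)
    total.SumAbsError = total.SumAbsError + partial{k}.SumAbsError;
    total.NumElements = total.NumElements + partial{k}.NumElements;
end

% evaluate: one value for the whole mini-batch.
parallelMAE = total.SumAbsError / total.NumElements;   % 4/6
wholeBatchMAE = mean(abs(Y - T));                      % 4/6
```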

See Also

trainingOptions | trainnet | dlnetwork