Define Custom Deep Learning Metric Object - MATLAB & Simulink
Note
This topic explains how to define custom deep learning metric objects for your tasks. For a list of built-in metrics in Deep Learning Toolbox™, see Metrics. You can also specify custom metrics using a function handle. For more information, see Define Custom Metric Function.
In deep learning, a metric is a numerical value that evaluates the performance of a deep learning network. You can use metrics to monitor how well a model is performing by comparing the model predictions to the ground truth. Common deep learning metrics are accuracy, F-score, precision, recall, and root mean squared error.
How To Decide Which Metric Type To Use
You can use metrics in different ways when you train a deep learning network. You can use built-in metrics by specifying a string and using the default options or by customizing the metric using a built-in metric object. You can also define custom metrics using a function handle or a custom class definition. When you select which method to use, you must trade off flexibility against complexity.
Another important consideration is how the software returns the metric value for the validation data. When you have validation data and specify your metric using a function handle, the software computes the validation metric for each mini-batch and then returns the average of those values. For some metrics, this behavior can result in a different metric value than if you compute the metric using the whole validation set at once. In most cases, the values are similar.
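To make the difference concrete, here is a minimal sketch using root mean squared error. The predictions, targets, and mini-batch size are made-up values chosen for illustration; the script only mimics the two ways of combining results and does not call any Deep Learning Toolbox function.

```matlab
% Hypothetical predictions and targets for four validation observations.
Y = [2 3 7 8];
T = [1 2 4 5];

% Root mean squared error of a set of predictions.
rmse = @(y,t) sqrt(mean((y - t).^2));

% Function-handle behavior: compute the metric for each mini-batch
% (here, two mini-batches of size 2), then average the per-batch values.
perBatch = [rmse(Y(1:2),T(1:2)), rmse(Y(3:4),T(3:4))];
averaged = mean(perBatch);

% Whole-set behavior: compute the metric once over all observations.
wholeSet = rmse(Y,T);

fprintf("Averaged over mini-batches: %.4f\n",averaged)   % 2.0000
fprintf("Whole validation set:       %.4f\n",wholeSet)   % 2.2361
```

The two values differ because the square root is applied before averaging in the first case and after pooling all squared errors in the second.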
For most cases, you do not need to write a custom metric class. Use this flowchart to decide whether you need to create a custom metric class.
If Deep Learning Toolbox does not provide the metric that you need for your task and you cannot use a function handle, then you can define your own custom metric object using this topic as a guide. After you define the custom metric, you can specify the metric as the Metrics name-value argument in the trainingOptions function.
Metric Template
To define a custom metric, use this class definition template as a starting point. For an example that shows how to use this template to create a custom metric, see Define Custom Metric Object.
The template outlines how to specify these aspects of the class definition:
- The properties block for public metric properties. This block must contain the Name property.
- The properties block for private metric properties. This block is optional.
- The metric constructor function.
- The optional initialize function.
- The required reset, update, aggregate, and evaluate functions.
For information about when the software calls each function, see Function Call Order.
classdef myMetric < deep.Metric

    properties
        % (Required) Metric name.
        Name

        % Declare public metric properties here.

        % Any code can access these properties. Include here any properties
        % that you want to access or edit outside of the class.
    end

    properties (Access = private)
        % (Optional) Metric properties.

        % Declare private metric properties here.

        % Only members of the defining class can access these properties.
        % Include here properties that you do not want to edit outside
        % the class.
    end

    methods
        function metric = myMetric(args)
            % Create a myMetric object.
            % This function must have the same name as the class.

            % Define metric construction function here.
        end

        function metric = initialize(metric,batchY,batchT)
            % (Optional) Initialize metric.
            %
            % Use this function to initialize variables and run validation
            % checks.
            %
            % Inputs:
            %           metric - Metric to initialize
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Initialized metric
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric initialization function here.
        end

        function metric = reset(metric)
            % Reset metric properties.
            %
            % Use this function to reset the metric properties between
            % iterations.
            %
            % Input:
            %           metric - Metric containing properties to reset
            %
            % Output:
            %           metric - Metric with reset properties

            % Define metric reset function here.
        end

        function metric = update(metric,batchY,batchT)
            % Update metric properties.
            %
            % Use this function to update metric properties that you use to
            % compute the final metric value.
            %
            % Inputs:
            %           metric - Metric containing properties to update
            %           batchY - Mini-batch of predictions
            %           batchT - Mini-batch of targets
            %
            % Output:
            %           metric - Metric with updated properties
            %
            % For networks with multiple outputs, replace batchY with
            % batchY1,...,batchYN and batchT with batchT1,...,batchTN,
            % where N is the number of network outputs. To create a metric
            % that supports any number of network outputs, replace batchY
            % and batchT with varargin.

            % Define metric update function here.
        end

        function metric = aggregate(metric,metric2)
            % Aggregate metric properties.
            %
            % Use this function to define how to aggregate properties from
            % multiple instances of the same metric object during parallel
            % training.
            %
            % Inputs:
            %           metric  - Metric containing properties to aggregate
            %           metric2 - Metric containing properties to aggregate
            %
            % Output:
            %           metric - Metric with aggregated properties

            % Define metric aggregation function here.
        end

        function val = evaluate(metric)
            % Evaluate metric properties.
            %
            % Use this function to define how to use the metric properties
            % to compute the final metric value.
            %
            % Input:
            %           metric - Metric containing properties to use to
            %                    evaluate the metric value
            %
            % Output:
            %           val - Evaluated metric value
            %
            % To return multiple metric values, replace val with
            % val1,...,valN.

            % Define metric evaluation function here.
        end
    end
end
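As an illustration of how the template might be filled in, the following is a minimal sketch of a mean absolute error metric. The class name maeMetric and the private properties TotalAbsError and NumElements are hypothetical names introduced for this example, not part of the toolbox. The sketch assumes a network with a single output whose predictions and targets have the same size and support abs, sum, and numel. It omits the optional initialize function.

```matlab
classdef maeMetric < deep.Metric
    % Hypothetical example: mean absolute error over all elements.

    properties
        % (Required) Metric name.
        Name
    end

    properties (Access = private)
        % Running totals that update accumulates (hypothetical names).
        TotalAbsError = 0
        NumElements = 0
    end

    methods
        function metric = maeMetric(args)
            % Set Name, NetworkOutput, and Maximize from optional
            % name-value arguments.
            arguments
                args.Name = "MAE"
                args.NetworkOutput = []
                args.Maximize = false   % a lower error is better
            end
            metric.Name = args.Name;
            metric.NetworkOutput = args.NetworkOutput;
            metric.Maximize = args.Maximize;
        end

        function metric = reset(metric)
            % Clear the running totals.
            metric.TotalAbsError = 0;
            metric.NumElements = 0;
        end

        function metric = update(metric,batchY,batchT)
            % Add the absolute errors of this mini-batch to the totals.
            metric.TotalAbsError = metric.TotalAbsError + ...
                sum(abs(batchY - batchT),"all");
            metric.NumElements = metric.NumElements + numel(batchT);
        end

        function metric = aggregate(metric,metric2)
            % Combine the totals from another instance of the metric.
            metric.TotalAbsError = metric.TotalAbsError + metric2.TotalAbsError;
            metric.NumElements = metric.NumElements + metric2.NumElements;
        end

        function val = evaluate(metric)
            % Compute the metric value from the running totals.
            val = metric.TotalAbsError / metric.NumElements;
        end
    end
end
```

Assuming the class is saved as maeMetric.m on the MATLAB path, it could then be passed to training with, for example, trainingOptions("adam",Metrics=maeMetric). Storing sums rather than a running mean keeps both update and aggregate simple additions.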
Metric Properties
Declare the metric properties in the property sections. You can specify attributes in the class definition to customize the behavior of properties for specific purposes. This template defines two property types by setting their Access attribute. Use the Access attribute to control access to specific class properties.
- properties: Any code can access these properties. This is the default properties block with the default property attributes. By default, the Access attribute is public.
- properties (Access = private): Only members of the defining class can access the property.
Public Properties
Declare public metric properties in the properties section of the class definition. These properties have public access, which means any code can access the values. By default, custom metrics have the NetworkOutput public property with the default value [] and the Maximize public property with the default value []. The NetworkOutput property defines which network output to apply the metric to. The Maximize property sets a flag that defines if the optimal value for the metric occurs when the metric is maximized (1 or true) or when the metric is minimized (0 or false).
You must define the Name property in this block. The Name property controls the name of the metric in any plots or command line output.
Private Properties
Declare private metric properties in the properties (Access = private) section of the class definition. These properties have private access, which means only members of the defining class can access these properties. For example, the class functions can access private properties. If the metric has no private properties, then you can omit this properties section.
Constructor Function
The constructor function creates the metric and initializes the metric properties. The constructor function must take as input any variables that you need to compute the metric. This function must have the same name as the class.
To use any properties as name-value arguments, you must set them in the constructor function. The constructor of every metric must accept the optional Name argument.
Tip
To use the NetworkOutput property as a name-value argument, you must set the property in the constructor function.
Initialization Function
The initialize function is an optional function that the software calls after reading the first batch of data. You can use this function to initialize variables and run validation checks.
The initialize function must have this syntax, where the batchY and batchT inputs represent the mini-batch predictions and targets, respectively. For networks with multiple outputs, replace batchY with batchY1,...,batchYN and batchT with batchT1,...,batchTN, where N is the number of network outputs. To create a metric that supports any number of network outputs, replace batchY and batchT with varargin.
metric = initialize(metric,batchY,batchT)
Example initialize Function
This code shows an example of an initialize function that checks that you are using the metric for a network with a single output and therefore only one set of batch predictions and targets.
function metric = initialize(metric,batchY,batchT)
    if nargin ~= 3
        error("Metric not supported for networks with multiple outputs.")
    end
end
Reset Function
The reset function resets the metric properties. The software calls this function before each iteration. For more information, see Function Call Order.
The reset function must have this syntax.
metric = reset(metric)
Update Function
The update function updates the metric properties that you use to compute the metric value. The software calls update during each training and validation mini-batch. For more information, see Function Call Order.
The update function must have this syntax, where the batchY and batchT inputs represent the mini-batch predictions and targets, respectively. For networks with multiple outputs, replace batchY with batchY1,...,batchYN and batchT with batchT1,...,batchTN, where N is the number of network outputs. To create a metric that supports any number of network outputs, replace batchY and batchT with varargin.
metric = update(metric,batchY,batchT)
For categorical targets, the layout of the targets that the software passes to the metric depends on which function you want to use the metric with.
- When using the metric with trainnet and the targets are categorical arrays, if the loss function is "index-crossentropy", then the software automatically converts the targets to numeric class indices and passes them to the metric. For other loss functions, the software converts the targets to one-hot encoded vectors and passes them to the metric.
- When using the metric with testnet and the targets are categorical arrays, if the specified metrics include "index-crossentropy" but do not include "crossentropy", then the software converts the targets to numeric class indices and passes them to the metric. Otherwise, the software converts the targets to one-hot encoded vectors and passes them to the metric.
Aggregation Function
The aggregate function specifies how to combine properties from multiple instances of the same metric object during parallel training. When you train a network in parallel, the software divides each training mini-batch into smaller subsets. For each subset, the software then calls update to update the metric properties, and then calls aggregate to consolidate the results for the whole mini-batch. For more information, see Function Call Order.
The aggregate function must have this syntax, where the metric2 input is another instance of the metric. To ensure that your function always produces the same results, make sure that aggregate is an associative function.
metric = aggregate(metric,metric2)
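The associativity requirement can be illustrated with a small sketch. Here a struct with the hypothetical fields NumCorrect and NumTotal stands in for the metric properties, and the values are made up. Summing running totals gives the same result however the instances are grouped, whereas averaging already-evaluated values does not.

```matlab
% Stand-in for the metric state: a struct with two running totals.
% (Hypothetical field names; a real metric stores these as properties.)
makeState = @(correct,total) struct("NumCorrect",correct,"NumTotal",total);

% Aggregation by summing totals, which is associative.
aggregate = @(a,b) makeState(a.NumCorrect + b.NumCorrect, ...
    a.NumTotal + b.NumTotal);

% State from three hypothetical subsets of a mini-batch.
s1 = makeState(3,4);
s2 = makeState(1,4);
s3 = makeState(4,4);

% Grouping does not change the result.
left  = aggregate(aggregate(s1,s2),s3);
right = aggregate(s1,aggregate(s2,s3));
assert(isequal(left,right))

% By contrast, pairwise averaging of evaluated values is not associative.
avg = @(a,b) (a + b)/2;
v = [3/4 1/4 4/4];
disp(avg(avg(v(1),v(2)),v(3)))   % 0.7500
disp(avg(v(1),avg(v(2),v(3))))   % 0.6875
```

This is one reason to store counts or sums in the metric properties and defer any division to evaluate.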
Evaluation Function
The evaluate function specifies how to compute the metric value. In most cases, the final metric value is a function of the metric properties.
For the training data, the software calls evaluate at the end of each mini-batch. For the validation data, the software calls evaluate after all of the data passes through the network. Therefore, the software computes the metric for each batch of training data but for all of the validation data. For more information, see Function Call Order.
The evaluate function must have this syntax, where M is the number of metrics to return.
[val,...,valM] = evaluate(metric)
Function Call Order
The order in which the software calls the initialize, reset, update, aggregate, and evaluate functions depends on where in the training loop the software is. The first function the software calls is initialize. The software calls initialize after it reads the first batch of data.
The order in which the software calls the remaining functions depends on whether the data is training or validation data.
- Training data: For each mini-batch, the software calls reset, then update, and then evaluate. Therefore, the software returns the metric value for each training mini-batch, where each batch is equivalent to a single training iteration.
- Validation data: For each mini-batch, the software calls update only. The software calls evaluate after all of the validation data passes through the network. Therefore, the software returns the metric value for the whole validation set (full-batch). This behavior is equivalent to a validation iteration. The software calls reset before the first validation mini-batch.
This diagram illustrates the difference between how the software computes the metric for the training and validation data.
Note
When you train a network using the L-BFGS solver, the software processes all of the data in a single batch. This behavior is equivalent to a single mini-batch with all of the observations.
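The call order for training and validation data can be mimicked with a short script. This is only a sketch of the sequence described above: the comments mark which metric function each step stands in for, the running totals are plain variables rather than metric properties, and the per-batch counts are made up. The same three mini-batches are used for both loops to make the comparison direct.

```matlab
% Hypothetical mini-batch results: [number correct, number of observations].
batches = [3 4; 1 4; 4 4];
numBatches = size(batches,1);

% Training data: reset, update, and evaluate for every mini-batch.
trainingValues = zeros(1,numBatches);
for i = 1:numBatches
    numCorrect = 0; numTotal = 0;                % reset
    numCorrect = numCorrect + batches(i,1);      % update
    numTotal = numTotal + batches(i,2);
    trainingValues(i) = numCorrect/numTotal;     % evaluate
end

% Validation data: reset once, update per mini-batch, evaluate at the end.
numCorrect = 0; numTotal = 0;                    % reset
for i = 1:numBatches
    numCorrect = numCorrect + batches(i,1);      % update
    numTotal = numTotal + batches(i,2);
end
validationValue = numCorrect/numTotal;           % evaluate

disp(trainingValues)     % 0.7500    0.2500    1.0000
disp(validationValue)    % 0.6667
```

The training loop produces one value per mini-batch, while the validation loop produces a single value for the whole set.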
Aggregate Data
The aggregate function defines how to aggregate properties from multiple instances of the same metric object during parallel training. When you train a network in parallel, the software divides each training mini-batch into smaller subsets. For each subset, the software then calls update to update the metric properties, and then calls aggregate to consolidate the results for the whole mini-batch. Finally, the software calls evaluate to obtain the metric value for the whole training mini-batch.
See Also
trainingOptions | trainnet | dlnetwork