importONNXFunction

Import pretrained ONNX network as a function

Since R2020b

Syntax

params = importONNXFunction(modelfile,NetworkFunctionName)

Description

params = importONNXFunction(modelfile,NetworkFunctionName) imports an Open Neural Network Exchange (ONNX™) network from the file modelfile and returns an ONNXParameters object (params) that contains the network parameters. The function also creates a model function with the name specified by NetworkFunctionName that contains the network architecture. For more information about the network function, see Imported ONNX Model Function.

Use the ONNXParameters object and the NetworkFunctionName model function to perform common deep learning tasks, such as image and sequence data classification, transfer learning, object detection, and image segmentation. importONNXFunction is useful when you cannot import the network using the importNetworkFromONNX function (for example, importONNXFunction can import YOLOv3) or if you want to define your own custom training loop (for more details, see Train Network Using Custom Training Loop).

This function requires the Deep Learning Toolbox™ Converter for ONNX Model Format support package. If this support package is not installed, then the function provides a download link.


Examples


Import ONNX Network as Function

Import an ONNX network as a function. You can use the imported model function for deep learning tasks, such as prediction and transfer learning.

Download and install the Deep Learning Toolbox Converter for ONNX Model Format support package. You can enter importONNXFunction at the command line to check if the support package is installed. If it is not installed, then the function provides a link to the required support package in the Add-On Explorer. To install the support package, click the link, and then click Install.

Specify the file to import as shufflenet with operator set 9 from the ONNX Model Zoo. shufflenet is a convolutional neural network that is trained on images from the ImageNet database.

modelfile = "shufflenet-9.onnx";

Import the network as a function to generate a model function that you can readily use for deep learning tasks.

params = importONNXFunction(modelfile,"shufflenetFcn")

Function containing the imported ONNX network architecture was saved to the file shufflenetFcn.m. To learn how to use this function, type: help shufflenetFcn.

params = ONNXParameters with properties:

         Learnables: [1x1 struct]
      Nonlearnables: [1x1 struct]
              State: [1x1 struct]
      NumDimensions: [1x1 struct]
NetworkFunctionName: 'shufflenetFcn'

importONNXFunction returns the ONNXParameters object params, which contains the network parameters, and creates the model function shufflenetFcn, which contains the network architecture. importONNXFunction saves shufflenetFcn in the current folder. You can open the model function to view or edit the network architecture by using open shufflenetFcn.

Deep Learning Toolbox Converter for ONNX Model Format also provides the importNetworkFromONNX function, which you can use to import a pretrained ONNX network.

Predict Using Imported ONNX Function

Import an ONNX network as a function, and use the pretrained network to predict the class label of an input image.

Specify the file to import as shufflenet with operator set 9 from the ONNX Model Zoo. shufflenet is a convolutional neural network that is trained on more than a million images from the ImageNet database. As a result, the network has learned rich feature representations for a wide range of images. The network can classify images into 1000 object categories, such as keyboard, mouse, pencil, and many animals.

modelfile = 'shufflenet-9.onnx';

Import the pretrained ONNX network as a function by using importONNXFunction, which returns the ONNXParameters object params. This object contains the network parameters. The function also creates a new model function in the current folder that contains the network architecture. Specify the name of the model function as shufflenetFcn.

params = importONNXFunction(modelfile,'shufflenetFcn');

A function containing the imported ONNX network has been saved to the file shufflenetFcn.m. To learn how to use this function, type: help shufflenetFcn.

Read the image you want to classify and display the size of the image. The image is 792-by-1056 pixels and has three color channels (RGB).

I = imread('peacock.jpg'); size(I)

Resize the image to the input size of the network. Show the image.

I = imresize(I,[224 224]); imshow(I)

The inputs to shufflenet require further preprocessing (for more details, see ShuffleNet in ONNX Model Zoo). Rescale the image. Normalize the image by subtracting the training images mean and dividing by the training images standard deviation.

I = rescale(I,0,1);

meanIm = [0.485 0.456 0.406]; stdIm = [0.229 0.224 0.225]; I = (I - reshape(meanIm,[1 1 3]))./reshape(stdIm,[1 1 3]);

imshow(I)
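The normalization step relies on implicit expansion. Reshaping the 1-by-3 mean and standard deviation vectors to 1-by-1-by-3 makes MATLAB apply each value to the matching color channel. The following minimal sketch uses a small synthetic image. The 4-by-4 size and the constant pixel value are arbitrary choices for illustration and are not part of the original example.

```matlab
% Synthetic 4-by-4 RGB image with every pixel equal to 0.5 (illustrative data only)
I = 0.5*ones(4,4,3);

% Per-channel ImageNet statistics, as used in the example above
meanIm = [0.485 0.456 0.406];
stdIm  = [0.229 0.224 0.225];

% Reshape to 1-by-1-by-3 so each value is expanded across the rows and
% columns of its own channel (implicit expansion)
Inorm = (I - reshape(meanIm,[1 1 3]))./reshape(stdIm,[1 1 3]);

% Every pixel in channel c now equals (0.5 - meanIm(c))/stdIm(c)
disp(squeeze(Inorm(1,1,:))')
```

Without the reshape, MATLAB would expand a 1-by-3 row vector along the columns of the image rather than across its channels.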

Import the class names from squeezenet, which is also trained with images from the ImageNet database.

net = squeezenet; ClassNames = net.Layers(end).ClassNames;

Calculate the class probabilities by specifying the image to classify I and the ONNXParameters object params as input arguments to the model function shufflenetFcn.

scores = shufflenetFcn(I,params);

Find the class index with the highest probability. Display the predicted class for the input image and the corresponding classification score.

indMax = find(scores==max(scores)); ClassNames(indMax)

ans = 1×1 cell array {'peacock'}

scoreMax = scores(indMax)
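find(scores==max(scores)) returns every index that ties for the maximum. A common alternative uses the second output of max, which returns a single index (the first maximum) together with its score. This sketch uses a hypothetical five-element score vector rather than real output from shufflenetFcn.

```matlab
% Hypothetical class scores for one image (not output from shufflenetFcn)
scores = [0.05 0.10 0.70 0.10 0.05];

% The second output of max is the index of the first maximum
[scoreMax,indMax] = max(scores);
fprintf('Predicted class index %d with score %.2f\n',indMax,scoreMax);
```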

Train Imported ONNX Function Using Custom Training Loop

Import the SqueezeNet convolutional neural network as a function and fine-tune the pretrained network with transfer learning to perform classification on a new collection of images.

This example uses several helper functions. To view the code for these functions, see Helper Functions.

Unzip and load the new images as an image datastore. imageDatastore automatically labels the images based on folder names and stores the data as an ImageDatastore object. An image datastore enables you to store large image data, including data that does not fit in memory, and efficiently read batches of images during training of a convolutional neural network. Specify the mini-batch size.

unzip("MerchData.zip");
miniBatchSize = 8;
imds = imageDatastore("MerchData", ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames", ...
    ReadSize=miniBatchSize);

This data set is small, containing 75 training images. Display some sample images.

numImages = numel(imds.Labels);
idx = randperm(numImages,16);
figure
for i = 1:16
    subplot(4,4,i)
    I = readimage(imds,idx(i));
    imshow(I)
end

Extract the training set and one-hot encode the categorical classification labels.

XTrain = readall(imds); XTrain = single(cat(4,XTrain{:})); YTrain_categ = categorical(imds.Labels); YTrain = onehotencode(YTrain_categ,2)';
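onehotencode(YTrain_categ,2)' produces a numClasses-by-numObservations matrix with a single 1 in each column. The following sketch builds the same layout in base MATLAB from integer class indices. It uses hypothetical labels 1 to 3 rather than the MerchData categories.

```matlab
% Hypothetical integer class labels for six observations (3 classes)
labels = [2 1 3 3 1 2];
numClasses = max(labels);
numObservations = numel(labels);

% Place a 1 in row labels(n) of column n; all other entries stay 0
Y = zeros(numClasses,numObservations);
Y(sub2ind(size(Y),labels,1:numObservations)) = 1;

disp(Y)
```

For a categorical array such as YTrain_categ, double(YTrain_categ) returns the category indices that this sketch takes as input.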

Determine the number of classes in the data.

classes = categories(YTrain_categ); numClasses = numel(classes)

SqueezeNet is a convolutional neural network that is trained on more than a million images from the ImageNet database. As a result, the network has learned rich feature representations for a wide range of images. The network can classify images into 1000 object categories, such as keyboard, mouse, pencil, and many animals.

Import the pretrained SqueezeNet network as a function.

squeezenetONNX()
params = importONNXFunction("squeezenet.onnx","squeezenetFcn")

Function containing the imported ONNX network architecture was saved to the file squeezenetFcn.m. To learn how to use this function, type: help squeezenetFcn.

params = ONNXParameters with properties:

         Learnables: [1×1 struct]
      Nonlearnables: [1×1 struct]
              State: [1×1 struct]
      NumDimensions: [1×1 struct]
NetworkFunctionName: 'squeezenetFcn'

params is an ONNXParameters object that contains the network parameters. squeezenetFcn is a model function that contains the network architecture. importONNXFunction saves squeezenetFcn in the current folder.

Calculate the classification accuracy of the pretrained network on the new training set.

accuracyBeforeTraining = getNetworkAccuracy(XTrain,YTrain,params); fprintf("%.2f accuracy before transfer learning\n",accuracyBeforeTraining);

0.01 accuracy before transfer learning

The accuracy is very low.

Display the learnable parameters of the network by typing params.Learnables. These parameters, such as the weights (W) and bias (B) of convolution and fully connected layers, are updated by the network during training. Nonlearnable parameters remain constant during training.

The last two learnable parameters of the pretrained network are configured for 1000 classes.

conv10_W: [1×1×512×1000 dlarray]

conv10_B: [1000×1 dlarray]

The parameters conv10_W and conv10_B must be fine-tuned for the new classification problem. Reinitialize them so that the network classifies five classes instead of 1000.

params.Learnables.conv10_W = rand(1,1,512,5); params.Learnables.conv10_B = rand(5,1);

Freeze all the parameters of the network to convert them to nonlearnable parameters. Because you do not need to compute the gradients of the frozen layers, freezing the weights of many initial layers can significantly speed up network training.

params = freezeParameters(params,"all");

Unfreeze the last two parameters of the network to convert them to learnable parameters.

params = unfreezeParameters(params,"conv10_W"); params = unfreezeParameters(params,"conv10_B");

The network is ready for training. Specify the training options.

velocity = []; numEpochs = 5; miniBatchSize = 16; initialLearnRate = 0.01; momentum = 0.9; decay = 0.01;
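The velocity and momentum options drive the stochastic gradient descent with momentum (SGDM) update that sgdmupdate applies in the training loop. The following minimal sketch assumes the standard velocity form of SGDM: first velocity = momentum*velocity - learnRate*g, then theta = theta + velocity. It applies this rule to the toy objective f(theta) = theta^2/2, whose gradient is theta. The toy objective stands in for the network loss and is not part of the original example.

```matlab
% Toy problem: minimize f(theta) = theta^2/2, whose gradient is theta.
% This stands in for the network loss; it is not part of the original example.
theta = 1;            % initial parameter value
velocity = 0;         % plays the role of velocity = [] in the example
learnRate = 0.1;      % illustrative fixed learning rate
momentum = 0.9;       % same momentum value as the training options

for k = 1:200
    g = theta;                                   % gradient of f at theta
    velocity = momentum*velocity - learnRate*g;  % accumulate momentum
    theta = theta + velocity;                    % parameter step
end

fprintf('theta after 200 SGDM steps: %.2e\n',theta);
```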

Calculate the total number of iterations for the training progress monitor.

numObservations = size(YTrain,2); numIterationsPerEpoch = floor(numObservations./miniBatchSize); numIterations = numEpochs*numIterationsPerEpoch;
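With the 75 MerchData training images and miniBatchSize = 16, the calculation works out as shown below. floor discards the final partial mini-batch, so each epoch skips 11 images. Because the data is reshuffled every epoch, a different set of 11 images is skipped each time.

```matlab
% Values from this example: 75 training images, 5 epochs, mini-batches of 16
numObservations = 75;
numEpochs = 5;
miniBatchSize = 16;

% Only complete mini-batches are used, so the remainder is dropped each epoch
numIterationsPerEpoch = floor(numObservations./miniBatchSize);       % 4
numIterations = numEpochs*numIterationsPerEpoch;                     % 20
numSkipped = numObservations - numIterationsPerEpoch*miniBatchSize;  % 11

fprintf('%d iterations per epoch, %d in total, %d images unused per epoch\n', ...
    numIterationsPerEpoch,numIterations,numSkipped);
```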

Initialize the TrainingProgressMonitor object. Because the timer starts when you create the monitor object, make sure that you create the object immediately before the training loop.

monitor = trainingProgressMonitor(Metrics="Loss",Info=["Epoch","LearnRate"],XLabel="Iteration");

Train the network.

epoch = 0;
iteration = 0;
executionEnvironment = "cpu"; % Change to "gpu" to train on a GPU.

% Loop over epochs.
while epoch < numEpochs && ~monitor.Stop
    epoch = epoch + 1;

    % Shuffle data.
    idx = randperm(numObservations);
    XTrain = XTrain(:,:,:,idx);
    YTrain = YTrain(:,idx);

    % Loop over mini-batches.
    i = 0;
    while i < numIterationsPerEpoch && ~monitor.Stop
        i = i + 1;
        iteration = iteration + 1;

        % Read mini-batch of data.
        idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
        X = XTrain(:,:,:,idx);
        Y = YTrain(:,idx);

        % If training on a GPU, then convert data to gpuArray.
        if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
            X = gpuArray(X);
        end

        % Evaluate the model gradients, loss, and updated state using dlfeval
        % and the modelGradients function.
        [gradients,loss,state] = dlfeval(@modelGradients,X,Y,params);
        params.State = state;

        % Determine the learning rate for the time-based decay learning rate schedule.
        learnRate = initialLearnRate/(1 + decay*iteration);

        % Update the network parameters using the SGDM optimizer.
        [params.Learnables,velocity] = sgdmupdate(params.Learnables,gradients,velocity,learnRate,momentum);

        % Update the training progress monitor.
        recordMetrics(monitor,iteration,Loss=loss);
        updateInfo(monitor,Epoch=epoch,LearnRate=learnRate);
        monitor.Progress = 100 * iteration/numIterations;
    end
end
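The time-based decay schedule in the training loop computes learnRate = initialLearnRate/(1 + decay*iteration). The following short sketch uses the example's values (initialLearnRate = 0.01, decay = 0.01, and 20 iterations) to show how much the rate changes over this short run.

```matlab
% Training options from this example
initialLearnRate = 0.01;
decay = 0.01;
numIterations = 20;   % 5 epochs x 4 iterations per epoch

% Learning rate used at each iteration
iterations = 1:numIterations;
learnRates = initialLearnRate./(1 + decay*iterations);

fprintf('First: %.6f  Last: %.6f\n',learnRates(1),learnRates(end));
```

After 20 iterations, the rate has dropped only to 1/1.2, or about 83%, of its initial value.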

Calculate the classification accuracy of the network after fine-tuning.

accuracyAfterTraining = getNetworkAccuracy(XTrain,YTrain,params); fprintf("%.2f accuracy after transfer learning\n",accuracyAfterTraining);

1.00 accuracy after transfer learning

Helper Functions

This section provides the code of the helper functions used in this example.

The getNetworkAccuracy function evaluates the network performance by calculating the classification accuracy.

function accuracy = getNetworkAccuracy(X,Y,onnxParams)

    N = size(X,4);
    Ypred = squeezenetFcn(X,onnxParams,Training=false);

    [~,YIdx] = max(Y,[],1);
    [~,YpredIdx] = max(Ypred,[],1);
    numIncorrect = sum(abs(YIdx-YpredIdx) > 0);
    accuracy = 1 - numIncorrect/N;

end
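getNetworkAccuracy finds the index of the largest entry in each column (one column per observation) of both the one-hot targets and the predictions, and then counts the matches. The following sketch applies the same logic to hypothetical 3-class scores for four observations.

```matlab
% Hypothetical one-hot targets (3 classes x 4 observations)
Y = [1 0 0 0;
     0 1 0 1;
     0 0 1 0];
% Hypothetical predicted scores; observation 4 is misclassified as class 3
Ypred = [0.8 0.1 0.2 0.1;
         0.1 0.7 0.2 0.3;
         0.1 0.2 0.6 0.6];

[~,YIdx] = max(Y,[],1);          % true class of each observation
[~,YpredIdx] = max(Ypred,[],1);  % predicted class of each observation
accuracy = mean(YIdx == YpredIdx);

fprintf('Accuracy: %.2f\n',accuracy);
```

mean(YIdx == YpredIdx) is equivalent to the helper's 1 - numIncorrect/N.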

The modelGradients function calculates the loss and gradients.

function [grad, loss, state] = modelGradients(X,Y,onnxParams)

[y,state] = squeezenetFcn(X,onnxParams,Training=true); loss = crossentropy(y,Y,DataFormat="CB"); grad = dlgradient(loss,onnxParams.Learnables);

end
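For one-hot targets, crossentropy with DataFormat="CB" computes the categorical cross-entropy over the batch. This sketch assumes that the default normalization divides the summed loss by the batch size. It uses hypothetical softmax outputs for two observations as plain numeric arrays, so it involves no dlarray or automatic differentiation.

```matlab
% Hypothetical softmax outputs: 2 classes (rows) x 2 observations (columns)
Ypred = [0.7 0.2;
         0.3 0.8];
% One-hot targets in the same "CB" layout
T = [1 0;
     0 1];

N = size(T,2);                         % batch size
loss = -sum(T(:).*log(Ypred(:)))/N;    % mean cross-entropy per observation

fprintf('Cross-entropy loss: %.4f\n',loss);
```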

The squeezenetONNX function generates an ONNX model of the SqueezeNet network.

function squeezenetONNX()

exportONNXNetwork(squeezenet,"squeezenet.onnx");

end

Sequence Classification Using Imported ONNX Function

Import an ONNX long short-term memory (LSTM) network as a function, and use the pretrained network to classify sequence data. An LSTM network enables you to input sequence data into a network and make predictions based on the individual time steps of the sequence data.

This example uses the helper function preparePermutationVector. To view the code for this function, see Helper Function.

lstmNet has a similar architecture to the LSTM network created in Sequence Classification Using Deep Learning. lstmNet is trained to recognize the speaker given time series data representing two Japanese vowels spoken in succession. The training data contains time series data for nine speakers. Each sequence has 12 features and varies in length.

Specify lstmNet as the model file.

modelfile = 'lstmNet.onnx';

Import the pretrained ONNX network as a function by using importONNXFunction, which returns the ONNXParameters object params containing the network parameters. The function also creates a new model function in the current folder that contains the network architecture. Specify the name of the model function as lstmnetFcn.

params = importONNXFunction(modelfile,'lstmnetFcn');

Function containing the imported ONNX network architecture was saved to the file lstmnetFcn.m. To learn how to use this function, type: help lstmnetFcn.

Load the Japanese Vowels test data. XTest is a cell array containing 370 sequences of dimension 12 and varying length. TTest is a categorical vector of labels "1","2",...,"9", which correspond to the nine speakers.

load JapaneseVowelsTestData;

lstmNet was trained using mini-batches with sequences of similar length. To organize the test data in the same way, sort the test data by sequence length.

numObservationsTest = numel(XTest);
sequenceLengthsTest = zeros(1,numObservationsTest);
for i = 1:numObservationsTest
    sequence = XTest{i};
    sequenceLengthsTest(i) = size(sequence,2);
end
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
TTest = TTest(idx);
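The loop records each sequence length before sorting. The following sketch shows an equivalent, more compact form that uses cellfun. It runs on three hypothetical 12-feature sequences rather than on the Japanese Vowels data.

```matlab
% Hypothetical sequences: 12 features by a varying number of time steps
XTest = {rand(12,5), rand(12,2), rand(12,9)};

% The sequence length is the number of columns of each cell
sequenceLengthsTest = cellfun(@(s) size(s,2), XTest);

% Sort the sequences by length; in the example, TTest is reordered with the same idx
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
```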

Use preparePermutationVector to compute the permutation vector inputPerm, which permutes the dimension ordering of the input sequence data to the dimension ordering of the imported LSTM network input. You can type help lstmnetFcn to view the dimension ordering of the network input SEQUENCEINPUT.

inputPerm = preparePermutationVector(["FeaturesLength","SequenceLength","BatchSize"], ...
    ["SequenceLength","BatchSize","FeaturesLength"]);

Calculate the class probabilities by specifying the sequence data to classify XTest and the ONNXParameters object params as input arguments to the model function lstmnetFcn. Customize the input dimension ordering by assigning the numeric vector inputPerm to the name-value argument 'InputDataPermutation'. Return scores in the dimension ordering of the network output by assigning 'none' to the name-value argument 'OutputDataPermutation'.

for i = 1:length(XTest)
    scores = lstmnetFcn(XTest{i},params,'InputDataPermutation',inputPerm,'OutputDataPermutation','none');
    YPred(i) = find(scores==max(scores));
end
YPred = categorical(YPred');

Calculate the classification accuracy of the predictions.

acc = sum(YPred == TTest)./numel(TTest)

Helper Function

This section provides the code of the helper function preparePermutationVector used in this example.

The preparePermutationVector function returns a permutation vector perm, which permutes the dimension ordering in fromDimOrder to the dimension ordering in toDimOrder. You can specify the input arguments fromDimOrder and toDimOrder as character vectors, string scalars, string arrays, cell arrays of character vectors, or numeric vectors. Both arguments must have the same type and the same unique elements. For example, if fromDimOrder is the character vector 'hwcn', toDimOrder can be the character vector 'nchw' (where h, w, and c correspond to the height, width, and number of channels of the image, respectively, and n is the number of observations).

function perm = preparePermutationVector(fromDimOrder, toDimOrder)

% Check if both fromDimOrder and toDimOrder are vectors.
if ~isvector(fromDimOrder) || ~isvector(toDimOrder)
    error(message('nnet_cnn_onnx:onnx:FPVtypes'));
end

% Convert fromDimOrder and toDimOrder to the appropriate type.
if isstring(fromDimOrder) && isscalar(fromDimOrder)
    fromDimOrder = char(fromDimOrder);
end
if isstring(toDimOrder) && isscalar(toDimOrder)
    toDimOrder = char(toDimOrder);
end

% Check if fromDimOrder and toDimOrder have unique elements.
[fromSorted, ifrom] = unique(fromDimOrder);
[toSorted, ~, iToInv] = unique(toDimOrder);
if numel(fromSorted) ~= numel(fromDimOrder)
    error(message('nnet_cnn_onnx:onnx:FPVfromunique'));
end
if numel(toSorted) ~= numel(toDimOrder)
    error(message('nnet_cnn_onnx:onnx:FPVtounique'));
end

% Check if fromDimOrder and toDimOrder have the same elements.
if ~isequal(fromSorted, toSorted)
    error(message('nnet_cnn_onnx:onnx:FPVsame'));
end

% Compute the permutation vector.
perm = ifrom(iToInv);
perm = perm(:)';

end
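To check the helper's logic, the following sketch repeats its unique-based computation inline, without the input validation. It uses the LSTM example's dimension labels as cell arrays of character vectors, which is one of the types the helper accepts, and applies the result to a hypothetical 12-by-7 sequence.

```matlab
fromDimOrder = {'FeaturesLength','SequenceLength','BatchSize'};
toDimOrder   = {'SequenceLength','BatchSize','FeaturesLength'};

% Same steps as preparePermutationVector (input validation omitted)
[~,ifrom] = unique(fromDimOrder);    % position in fromDimOrder of each sorted label
[~,~,iToInv] = unique(toDimOrder);   % sorted-label index of each entry of toDimOrder
perm = ifrom(iToInv);
perm = perm(:)';                     % row vector, as the helper returns

% A 12-feature sequence with 7 time steps and a batch size of 1
x = rand(12,7);
xp = permute(x,perm);
disp(perm)       % 2 3 1
disp(size(xp))   % 7 1 12
```

Each entry of toDimOrder is mapped to its sorted label and then back to that label's position in fromDimOrder. As a result, permute(x,perm) produces the ordering in toDimOrder.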

Input Arguments


modelfile — Name of ONNX model file

character vector | string scalar

Name of the ONNX model file containing the network, specified as a character vector or string scalar. The file must be in the current folder or a folder on the MATLAB® path, or you must include a full or relative path to the file.

Example: 'shufflenet.onnx'

NetworkFunctionName — Name of model function

character vector | string scalar

Name of the model function, specified as a character vector or string scalar. The function NetworkFunctionName contains the architecture of the imported ONNX network. The function is saved as an M-file in the current folder, or in the location given by a full or relative path, if you include one. The NetworkFunctionName file is required for using the network. For more information, see Imported ONNX Model Function.

Example: 'shufflenetFcn'

Output Arguments


params — Network parameters

ONNXParameters object

Network parameters, returned as an ONNXParameters object. params contains the network parameters of the imported ONNX model. Use dot notation to reference properties of params. For example, params.Learnables displays the network learnable parameters, such as the weights of the convolution layers.

Limitations

More About


Imported ONNX Model Function

importONNXFunction creates a model function that contains the network architecture of the imported ONNX model. Specify the name NetworkFunctionName as an input argument to importONNXFunction.

Syntax

Use the following syntaxes to interface with the imported ONNX model function (NetworkFunctionName):

[Y,state] = NetworkFunctionName(X,params)

[Y,state] = NetworkFunctionName(X,params,Name,Value)

Input Arguments

Argument Description
X Input data, specified as an array or dlarray.
params Network parameters, specified as an ONNXParameters object.

Name-Value Arguments

Argument name Description
'Training' Training option, specified as false (default) or true. Set the value to false to use NetworkFunctionName for prediction. For an example, see Predict Using Imported ONNX Function. Set the value to true to use NetworkFunctionName for training. For an example, see Train Imported ONNX Function Using Custom Training Loop.
'InputDataPermutation' Permutation applied to the dimension ordering of input X, specified as 'auto' (default), 'none', a numeric vector, or a cell array. Assign a value to the name-value argument 'InputDataPermutation' to permute the input data into the dimension ordering required by the imported ONNX model. Assign the value 'auto' to apply an automatic permutation based on assumptions about common input data X. For more details, see Automatic Input Data Permutation. Assign the value 'none' to pass X in the original ordering. Assign a numeric vector value to customize the input dimension ordering, for example, [4 3 1 2]. For an example, see Sequence Classification Using Imported ONNX Function. Assign a cell array value for multiple inputs, for example, {[3 2 1],'none'}.
'OutputDataPermutation' Permutation applied to the dimension ordering of output Y, specified as 'auto' (default), 'none', a numeric vector, or a cell array. Assign a value to the name-value argument 'OutputDataPermutation' to match the dimension ordering of the imported ONNX model. Assign the value 'auto' to return Y in Deep Learning Toolbox ordering. For more details, see Automatic Output Data Permutation. Assign the value 'none' to return Y in ONNX ordering. For an example, see Sequence Classification Using Imported ONNX Function. Assign a numeric vector value to customize the output dimension ordering, for example, [3 4 2 1]. Assign a cell array value for multiple outputs, for example, {[3 2 1],'none'}.

Output Arguments

Argument Description
Y Output data, returned as an array or dlarray. If X is an array or you use NetworkFunctionName for prediction, Y is an array. If X is a dlarray or you use NetworkFunctionName for training, Y is a dlarray.
state Updated network state, returned as a structure. The network state contains information that the network remembers between iterations and updates across multiple training batches.

The interpretation of input argument X and output argument Y can differ between models. For more information about the model input and output arguments, refer to help for the imported model function NetworkFunctionName, or refer to the ONNX documentation [1].

Automatic Permutation for Imported Model Function

By default, NetworkFunctionName automatically permutes input and output data to facilitate image classification tasks. Automatic permutation might be unsuitable for other tasks, such as object detection and time series classification.

Automatic Input Data Permutation

To automatically permute the input, NetworkFunctionName assumes the following based on the input dimensions specified by the imported ONNX network.

Number of ONNX Model Input Dimensions: 4
Interpretation of Input Data: 2-D image
ONNX Standard Dimension Ordering: NCHW
Deep Learning Toolbox Standard Dimension Ordering: HWCN
Automatic Permutation of Input: [4 3 1 2]

H, W, and C correspond to the height, width, and number of channels of the image, respectively, and N is the number of observations.

If the number of input dimensions is not 4, NetworkFunctionName sets the name-value argument 'InputDataPermutation' to 'none'.
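The automatic input permutation [4 3 1 2] reorders a batch stored in the Deep Learning Toolbox HWCN layout into the ONNX NCHW layout. The following sketch uses a small hypothetical batch whose elements have distinct values, so you can check where each element ends up.

```matlab
% Hypothetical batch in Deep Learning Toolbox ordering: H=2, W=3, C=4, N=5
X = reshape(1:2*3*4*5,[2 3 4 5]);

% Automatic input permutation for 4-D ONNX inputs: HWCN -> NCHW
Xonnx = permute(X,[4 3 1 2]);

disp(size(Xonnx))                    % 5 4 2 3
% Element (h,w,c,n) of X becomes element (n,c,h,w) of Xonnx
disp(Xonnx(5,4,2,3) == X(2,3,4,5))   % 1
```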

Automatic Output Data Permutation

To automatically permute the output, NetworkFunctionName assumes the following based on the output dimensions specified by the imported ONNX network.

Number of ONNX Model Output Dimensions: 2
Interpretation of Output Data: 2-D image classification scores
ONNX Standard Dimension Ordering: NK
Deep Learning Toolbox Standard Dimension Ordering: KN
Automatic Permutation of Output: [2 1]

K is the number of classes and N is the number of observations.

Number of ONNX Model Output Dimensions: 4
Interpretation of Output Data: 2-D image pixel classification scores
ONNX Standard Dimension Ordering: NKHW
Deep Learning Toolbox Standard Dimension Ordering: HWKN
Automatic Permutation of Output: [3 4 2 1]

H and W correspond to the height and width of the image, respectively, K is the number of classes, and N is the number of observations.

If the number of output dimensions is not 2 or 4, NetworkFunctionName sets the name-value argument 'OutputDataPermutation' to 'none'.
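On the output side, the automatic permutations convert ONNX orderings to Deep Learning Toolbox orderings. [2 1] turns N-by-K classification scores into K-by-N scores, and [3 4 2 1] turns N-by-K-by-H-by-W pixel scores into H-by-W-by-K-by-N scores. The following sketch uses hypothetical arrays to show both cases.

```matlab
% Hypothetical ONNX classification scores: N = 2 observations, K = 3 classes
scoresONNX = [0.1 0.7 0.2;
              0.6 0.3 0.1];
scoresDLT = permute(scoresONNX,[2 1]);    % K x N, same as the transpose

% Hypothetical ONNX pixel scores: N = 2, K = 3, H = 4, W = 5
pixelONNX = rand(2,3,4,5);
pixelDLT = permute(pixelONNX,[3 4 2 1]);  % H x W x K x N

disp(size(scoresDLT))   % 3 2
disp(size(pixelDLT))    % 4 5 3 2
```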

ONNX Operators Supported for Conversion into Built-In MATLAB or Custom Layers

importONNXFunction supports the following ONNX operators for conversion into built-in MATLAB layers or custom layers, with some limitations. For a list of equivalent built-in or custom layers obtained using importNetworkFromONNX, see ONNX Operators Supported for Conversion into Built-In MATLAB Layers.

ONNX Operator importNetworkFromONNX Support
Abs No
Add Yes
And No
ArgMax No
AveragePool Yes
BatchNormalization Yes
Bernoulli No
Cast No
Ceil No
Clip No
Compress No
Concat Yes
Constant Yes
ConstantOfShape No
Conv Yes
ConvTranspose Yes
Cos No
CumSum No
DepthToSpace Yes
Div Yes
Dropout Yes
Elu Yes
Equal No
Erf No
Exp No
Expand No
Flatten Yes
Floor No
Gather No
GatherElements No
GatherND No
Gemm Yes
GlobalAveragePool Yes
GlobalMaxPool Yes
Greater No
GridSample No
GRU Yes
HardSigmoid No
Hardmax No
Identity Yes
If No
InstanceNormalization Yes
LayerNormalization Yes
LeakyRelu Yes
Less No
LessOrEqual No
Log No
Loop No
LRN Yes
LSTM Yes
MatMul Yes
Max No
MaxPool Yes
Min No
Mod No
Mul Yes
Neg No
NonMaxSuppression No
NonZero No
Not No
OneHot No
Or No
Pad No
Pow No
PRelu Yes
RandomUniform No
Range No
Reciprocal No
ReduceL1 No
ReduceLogSum No
ReduceLogSumExp No
ReduceMax No
ReduceMean No
ReduceMin No
ReduceProd No
ReduceSum No
ReduceSumSquare No
Relu Yes
Reshape Yes
Resize Yes
RoiAlign No
Round No
Scan No
Scatter No
ScatterElements No
ScatterND No
SequenceAt No
Shape No
Sigmoid Yes
Sign No
Sin No
Slice No
Softmax Yes
SpaceToDepth Yes
Split No
SplitToSequence No
Sqrt No
Squeeze No
Sub Yes
Sum Yes
Tanh Yes
Tile No
TopK No
Transpose No
Unsqueeze No
Upsample Yes
Where No

Tips

Alternative Functionality

importONNXFunction is useful when you cannot import a pretrained ONNX network by using importNetworkFromONNX.

References

Version History

Introduced in R2020b


R2024b: Import networks that include multiple new operators

You can now import an ONNX network that includes the following operators:

R2024b: Updated support for ONNX intermediate representation and operator sets

importONNXFunction now supports ONNX intermediate representation version 9 and ONNX operator sets 6 to 18.