Custom Training with Multiple GPUs in Experiment Manager


This example shows how to configure multiple parallel workers to collaborate on each trial of a custom training experiment. In this example, parallel workers train on portions of the overall mini-batch in each trial of an image classification experiment. During training, a DataQueue object sends training progress information back to Experiment Manager. If you have a supported GPU, then training happens on the GPU. For more information, see GPU Computing Requirements (Parallel Computing Toolbox).
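As a minimal illustration of the underlying mechanism (a standalone sketch, not part of the experiment), a DataQueue object lets parallel workers push data to a callback function that runs on the client:

q = parallel.pool.DataQueue;
afterEach(q,@(x) fprintf("Received %g from a worker\n",x));  % callback runs on the client
parfor i = 1:4
    send(q,i^2);  % send executes on the workers
end

The training function in this experiment uses the same pattern: afterEach registers a callback that records metrics in Experiment Manager, and one worker calls send at the end of each epoch.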

As an alternative, you can set up a parallel custom training loop that runs a single trial of this experiment programmatically. For more information, see Train Network in Parallel with Custom Training Loop.

Open Experiment

First, open the example. Experiment Manager loads a project with a preconfigured experiment that you can inspect and run. To open the experiment, in the Experiment Browser pane, double-click ParallelCustomLoopExperiment.

Custom training experiments consist of a description, a table of hyperparameters, and a training function. For more information, see Train Network Using Custom Training Loop and Display Visualization.

The Description field contains a textual description of the experiment. For this example, the description is:

Use multiple parallel workers to train an image classification network. Each trial uses a different initial learning rate and momentum.

The Hyperparameters section specifies the strategy and hyperparameter values to use for the experiment. When you run the experiment, Experiment Manager trains the network using every combination of hyperparameter values specified in the hyperparameter table. This example uses two hyperparameters: InitialLearnRate, the initial learning rate of the solver, and Momentum, the momentum used by stochastic gradient descent with momentum (SGDM) optimization.

The Training Function section specifies a function that defines the training data, network architecture, training options, and training procedure used by the experiment. To open this function in MATLAB® Editor, click Edit. The code for the function also appears in Training Function. The inputs to the training function are a structure with fields from the hyperparameter table and an experiments.Monitor object that you can use to track the progress of the training, record values of the metrics used by the training, and produce training plots. The function returns a structure that contains the trained network, the training loss, and the validation accuracy. Experiment Manager saves this output so you can export it to the MATLAB workspace when the training is complete. The training function has these sections:

Initialize Output: Sets the network, loss, and accuracy fields of the output structure to empty arrays.
Load Training and Test Data: Loads the Digits data set as an image datastore and splits it into training and test sets.
Define Network Architecture: Defines the layers of the image classification network.
Set Up Parallel Environment: Starts a parallel pool, using one worker per available GPU if you have supported GPUs and running on the CPU otherwise.
Specify Training Options: Sets the number of epochs, the overall and per-worker mini-batch sizes, and the learning rate schedule.
Train Model: Trains the network with a custom parallel training loop inside an spmd block and uses a DataQueue object to send training progress information to Experiment Manager.
Plot Confusion Matrix: Compares the predicted and true labels for the test data in a confusion matrix.
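For reference, here is a rough sketch of how you might call the training function directly from the command line when debugging outside of the app. This sketch assumes that you can construct an experiments.Monitor object yourself, as in the Experiment Manager debugging workflow, and the hyperparameter values are placeholders rather than the values defined in the experiment:

params = struct('InitialLearnRate',0.01,'Momentum',0.9);  % placeholder hyperparameter values
monitor = experiments.Monitor;                            % assumption: monitor object created manually for debugging
output = ParallelCustomLoopExperiment_training(params,monitor);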

Run Experiment

When you run the experiment, Experiment Manager trains the network defined by the training function multiple times. Each trial uses a different combination of hyperparameter values.

Because this experiment uses the parallel pool for this MATLAB session, you cannot train multiple trials at the same time. On the Experiment Manager toolstrip, set Mode to Sequential and click Run. Alternatively, to offload the experiment as a batch job, set Mode to Batch Sequential, specify your Cluster and Pool Size, and click Run. For more information, see Offload Experiments as Batch Jobs to a Cluster.

A table of results displays the training loss and validation accuracy for each trial.

To display the training plot and track the progress of each trial while the experiment is running, under Review Results, click Training Plot.

Note that the training function for this experiment uses an spmd statement, which cannot contain break, continue, or return statements. As a result, you cannot interrupt a trial of the experiment while training is in progress. If you click Stop, Experiment Manager runs the current trial to completion before stopping the experiment.

Evaluate Results

To find the best result for your experiment, sort the table of results by validation accuracy:

  1. Point to the ValidationAccuracy column.
  2. Click the triangle icon.
  3. Select Sort in Descending Order.

The trial with the highest validation accuracy appears at the top of the results table.

To display the confusion matrix for this trial, select the top row in the results table and, under Review Results, click Confusion Matrix.

To record observations about the results of your experiment, add an annotation:

  1. In the results table, right-click the ValidationAccuracy cell of the best trial.
  2. Select Add Annotation.
  3. In the Annotations pane, enter your observations in the text box.

Close Experiment

In the Experiment Browser pane, right-click DigitClassificationInParallelProject and select Close Project. Experiment Manager closes the experiment and results contained in the project.

Training Function

This function configures the training data, network architecture, and training options for the experiment. To execute the code simultaneously on all the workers, the function uses an spmd block. Within the spmd block, spmdIndex gives the index of the worker currently executing the code. Before training, the function partitions the datastore for each worker by using the partition function, and sets ReadSize to the mini-batch size of the worker. For each epoch, the function resets and shuffles the datastore. For each iteration in the epoch, the function:

  1. Reads a mini-batch from the datastore, normalizes the pixel values, and converts the labels to one-hot encoded vectors.
  2. Converts the mini-batch to a dlarray object with the dimension format "SSCB" and, for GPU training, to a gpuArray object.
  3. Computes the loss, gradients, and network state on each worker by calling dlfeval on the modelLoss function.
  4. Aggregates the loss, gradients, and network state across all workers, weighting the contribution of each worker by the proportion of the overall mini-batch that it processes.
  5. Updates the network learnable parameters with the sgdmupdate function, using a learning rate that decays with the iteration number.

At the end of each epoch, the function uses only the worker with spmdIndex 1 to compute the accuracy of the network on the test data and send the training progress information back to the client.

function output = ParallelCustomLoopExperiment_training(params,monitor)

Initialize Output

output.network = [];
output.loss = [];
output.accuracy = [];

Load Training and Test Data

monitor.Status = "Loading Data";

dataFolder = fullfile(toolboxdir('nnet'), ...
    'nndemos','nndatasets','DigitDataset');
imds = imageDatastore(dataFolder, ...
    IncludeSubfolders=true, ...
    LabelSource="foldernames");
[imdsTrain,imdsTest] = splitEachLabel(imds,0.9,"randomized");

classes = categories(imdsTrain.Labels);
numClasses = numel(classes);

XTest = readall(imdsTest);
XTest = cat(4,XTest{:});
XTest = single(XTest) ./ 255;
trueLabels = imdsTest.Labels;

Define Network Architecture

monitor.Status = "Creating Network";

layers = [
    imageInputLayer([28 28 1],Normalization="none")
    convolution2dLayer(5,20)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    convolution2dLayer(3,20,Padding=1)
    batchNormalizationLayer
    reluLayer
    fullyConnectedLayer(numClasses)];

lgraph = layerGraph(layers);

net = dlnetwork(lgraph);

Set Up Parallel Environment

monitor.Status = "Starting Parallel Pool";

pool = gcp("nocreate");

if canUseGPU
    executionEnvironment = "gpu";
    if isempty(pool)
        numberOfGPUs = gpuDeviceCount("available");
        pool = parpool(numberOfGPUs);
    end
else
    executionEnvironment = "cpu";
    if isempty(pool)
        pool = parpool;
    end
end

N = pool.NumWorkers;

Specify Training Options

numEpochs = 20;
miniBatchSize = 128;
velocity = [];
initialLearnRate = params.InitialLearnRate;
momentum = params.Momentum;
decay = 0.01;

if executionEnvironment == "gpu"
    miniBatchSize = miniBatchSize .* N;
end

workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));
remainder = miniBatchSize - sum(workerMiniBatchSize);
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)];
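For example, a quick standalone check of this split (assuming a hypothetical pool of three CPU workers, so miniBatchSize stays at 128):

N = 3;
miniBatchSize = 128;
workerMiniBatchSize = floor(miniBatchSize ./ repmat(N,1,N));  % [42 42 42]
remainder = miniBatchSize - sum(workerMiniBatchSize);         % 2 observations left over
workerMiniBatchSize = workerMiniBatchSize + [ones(1,remainder) zeros(1,N-remainder)]
%   43    43    42

The first remainder workers each take one extra observation, so the per-worker mini-batch sizes always sum to miniBatchSize.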

Train Model

monitor.Metrics = ["TrainingLoss" "ValidationAccuracy"];
monitor.XLabel = "Iteration";
monitor.Status = "Training";

Q = parallel.pool.DataQueue;
updateFcn = @(x) updateTrainingProgress(x,monitor);
afterEach(Q,updateFcn);

spmd
workerImds = partition(imdsTrain,N,spmdIndex);
workerImds.ReadSize = workerMiniBatchSize(spmdIndex);

workerVelocity = velocity;

iteration = 0;
lossArray = [];
accuracyArray = [];

for epoch = 1:numEpochs
    reset(workerImds);
    workerImds = shuffle(workerImds);
    
    if ~monitor.Stop
        while spmdReduce(@and,hasdata(workerImds))
            iteration = iteration + 1;
            
            [workerXBatch,workerTBatch] = read(workerImds);
            workerXBatch = cat(4,workerXBatch{:});
            workerNumObservations = numel(workerTBatch.Label);

            workerXBatch =  single(workerXBatch) ./ 255;
            
            workerY = zeros(numClasses,workerNumObservations,"single");
            for c = 1:numClasses
                workerY(c,workerTBatch.Label==classes(c)) = 1;
            end
            
            workerX = dlarray(workerXBatch,"SSCB");
            
            if executionEnvironment == "gpu"
                workerX = gpuArray(workerX);
            end
            
            [workerLoss,workerGradients,workerState] = dlfeval(@modelLoss,net,workerX,workerY);
            
            workerNormalizationFactor = workerMiniBatchSize(spmdIndex)./miniBatchSize;
            loss = spmdPlus(workerNormalizationFactor*extractdata(workerLoss));
            
            net.State = aggregateState(workerState,workerNormalizationFactor);
            
            workerGradients.Value = dlupdate(@aggregateGradients,workerGradients.Value,{workerNormalizationFactor});
            
            learnRate = initialLearnRate/(1 + decay*iteration);
            
            [net.Learnables,workerVelocity] = sgdmupdate(net.Learnables,workerGradients,workerVelocity,learnRate,momentum);
        end             
        
        if spmdIndex == 1
            YPredScores = predict(net,dlarray(XTest,"SSCB"));
            [~,idx] = max(YPredScores,[],1);
            Ypred = classes(idx);
            accuracy = mean(Ypred==trueLabels);
            
            lossArray = [lossArray; iteration, loss];
            accuracyArray = [accuracyArray; iteration, accuracy];
            
            data = [numEpochs epoch iteration loss accuracy];
            send(Q,gather(data)); 
        end  
    end
end

end

output.network = net{1};
output.loss = lossArray{1};
output.accuracy = accuracyArray{1};
predictedLabels = categorical(Ypred{1});

delete(gcp("nocreate"));

Plot Confusion Matrix

figure(Name="Confusion Matrix")
confusionchart(trueLabels,predictedLabels, ...
    ColumnSummary="column-normalized", ...
    RowSummary="row-normalized", ...
    Title="Confusion Matrix for Validation Data");

Helper Functions

The modelLoss function takes a dlnetwork object net and a mini-batch of input data X with corresponding labels Y. The function returns the gradients of the loss with respect to the learnable parameters in net, the network state, and the loss. To compute the gradients automatically, the function calls the dlgradient function.

function [loss,gradients,state] = modelLoss(net,X,Y)
[YPred,state] = forward(net,X);
YPred = softmax(YPred);
loss = crossentropy(YPred,Y);
gradients = dlgradient(loss,net.Learnables);
end

The updateTrainingProgress function updates the training progress information that comes from the workers. In this example, the DataQueue object calls this function every time a worker sends data. The elements of the data vector are the total number of epochs, the current epoch, the iteration, the training loss, and the validation accuracy, in that order.

function updateTrainingProgress(data,monitor)
monitor.Progress = (data(2)/data(1))*100;
recordMetrics(monitor,data(3), ...
    TrainingLoss=data(4), ...
    ValidationAccuracy=data(5));
end

The aggregateGradients function aggregates the gradients on all workers by adding them together. spmdPlus adds together and replicates all the gradients on the workers. Before adding the gradients, this function normalizes them by multiplying by a factor that represents the proportion of the overall mini-batch that the worker is working on.

function gradients = aggregateGradients(gradients,factor)
gradients = spmdPlus(factor*gradients);
end

The aggregateState function aggregates the network state on all workers. The network state contains the trained batch normalization statistics of the data set. Because each worker only sees a portion of the mini-batch, this function aggregates the network state so that the statistics are representative of the statistics across all the data. For each mini-batch, this function calculates the combined mean as a weighted average of the mean across the workers for each iteration. This function computes the combined variance according to the formula

$$s_c^2 = \frac{1}{M} \sum_{j=1}^{N} m_j \left[ s_j^2 + (\bar{x}_j - \bar{x}_c)^2 \right],$$

where $N$ is the total number of workers, $M$ is the total number of observations in a mini-batch, $m_j$ is the number of observations processed on the $j$th worker, $\bar{x}_j$ and $s_j^2$ are the mean and variance statistics calculated on that worker, and $\bar{x}_c$ is the combined mean across all workers.
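As a quick numeric check of this formula (a standalone sketch, not part of the experiment), you can compare the combined statistics against the population mean and variance of the pooled data. The per-worker samples here are made up:

x = {randn(1,40), randn(1,60)+1, 2*randn(1,28)};  % hypothetical samples from three workers
m = cellfun(@numel,x);                            % observations per worker, m_j
M = sum(m);                                       % total observations in the mini-batch
xbar = cellfun(@mean,x);                          % per-worker means
s2 = cellfun(@(v) var(v,1),x);                    % per-worker (population) variances
combinedMean = sum(m.*xbar)/M;                    % weighted average of the worker means
combinedVar = sum(m.*(s2 + (xbar - combinedMean).^2))/M;
max(abs([combinedMean combinedVar] - [mean([x{:}]) var([x{:}],1)]))  % effectively zero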

function state = aggregateState(state,factor)
numrows = size(state,1);

for j = 1:numrows
    isBatchNormalizationState = state.Parameter(j) =="TrainedMean" ...
        && state.Parameter(j+1) =="TrainedVariance" ...
        && state.Layer(j) == state.Layer(j+1);

    if isBatchNormalizationState
        meanVal = state.Value{j};
        varVal = state.Value{j+1};
        combinedMean = spmdPlus(factor*meanVal);
        combinedVarTerm = factor.*(varVal + (meanVal - combinedMean).^2);
        state.Value(j) = {combinedMean};
        state.Value(j+1) = {spmdPlus(combinedVarTerm)};
    end
end
end
