Classify Text Data Using Convolutional Neural Network - MATLAB & Simulink

This example shows how to classify text data using a convolutional neural network.

To classify text data using convolutions, use 1-D convolutional layers that convolve over the time dimension of the input.

This example trains a network with 1-D convolutional filters of varying widths. The width of each filter corresponds to the number of words the filter can see (the n-gram length). The network has multiple branches of convolutional layers, so it can use different n-gram lengths.
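To make the n-gram idea concrete, the following base MATLAB sketch lists the windows of words that a width-3 filter would see if it slid over a short report without padding. The token array `words` and the width `N` are hypothetical values chosen for illustration, not data from the example.

```matlab
% Hypothetical tokenized report (not taken from factoryReports.csv).
words = ["a" "fuse" "is" "blown" "in" "the" "mixer"];
N = 3;                               % filter width = n-gram length
numWindows = numel(words) - N + 1;   % window positions without padding
ngrams = strings(numWindows,1);
for t = 1:numWindows
    % Join the N consecutive words covered by the filter at position t.
    ngrams(t) = join(words(t:t+N-1)," ");
end
disp(ngrams)
```

A width-2 filter over the same report would see bigrams such as "a fuse", and a width-5 filter would see longer phrases, which is why using branches with several widths lets the network combine short and long word patterns.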

Load Data

Load the data from factoryReports.csv into a table and view the first few reports.

data = readtable("factoryReports.csv");
head(data)

                              Description                                         Category            Urgency            Resolution          Cost 
_______________________________________________________________________    ______________________    __________    ______________________    _____

{'Items are occasionally getting stuck in the scanner spools.'        }    {'Mechanical Failure'}    {'Medium'}    {'Readjust Machine'  }       45
{'Loud rattling and banging sounds are coming from assembler pistons.'}    {'Mechanical Failure'}    {'Medium'}    {'Readjust Machine'  }       35
{'There are cuts to the power when starting the plant.'               }    {'Electronic Failure'}    {'High'  }    {'Full Replacement'  }    16200
{'Fried capacitors in the assembler.'                                 }    {'Electronic Failure'}    {'High'  }    {'Replace Components'}      352
{'Mixer tripped the fuses.'                                           }    {'Electronic Failure'}    {'Low'   }    {'Add to Watch List' }       55
{'Burst pipe in the constructing agent is spraying coolant.'          }    {'Leak'              }    {'High'  }    {'Replace Components'}      371
{'A fuse is blown in the mixer.'                                      }    {'Electronic Failure'}    {'Low'   }    {'Replace Components'}      441
{'Things continue to tumble off of the belt.'                         }    {'Mechanical Failure'}    {'Low'   }    {'Readjust Machine'  }       38

Partition the data into training and validation partitions. Use 80% of the data for training and the remaining data for validation.

cvp = cvpartition(data.Category,Holdout=0.2);
dataTrain = data(training(cvp),:);
dataValidation = data(test(cvp),:);
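Passing the labels as the first argument to cvpartition stratifies the holdout, so each class keeps roughly the same proportion in both partitions. This sketch checks that on a small, hypothetical label vector; the labels and the rng seed are assumptions for illustration, and cvpartition requires Statistics and Machine Learning Toolbox.

```matlab
rng(0)   % reproducible partition for this sketch
% Hypothetical labels: 60 "Leak" and 40 "Electronic Failure" observations.
labels = categorical([repmat("Leak",60,1); repmat("Electronic Failure",40,1)]);
cvp = cvpartition(labels,Holdout=0.2);

% Compare the class counts in each partition.
countsTrain = countcats(labels(training(cvp)))
countsTest  = countcats(labels(test(cvp)))
```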

Preprocess Text Data

Extract the text data from the "Description" column of the table and preprocess it using the preprocessText function, listed in the Preprocess Text Function section at the end of the example.

documentsTrain = preprocessText(dataTrain.Description);

Extract the labels from the "Category" column and convert them to categorical.

TTrain = categorical(dataTrain.Category);

View the class names and the number of observations.

classNames = unique(TTrain)

classNames = 4×1 categorical
     Electronic Failure 
     Leak 
     Mechanical Failure 
     Software Failure 

numObservations = numel(TTrain)
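Before training, it can help to check how balanced the classes are. A minimal sketch using countcats on a hypothetical label vector; in the example itself you would pass TTrain instead.

```matlab
% Hypothetical labels standing in for TTrain.
T = categorical(["Leak";"Electronic Failure";"Leak";"Mechanical Failure"]);
catNames = categories(T)    % category names, sorted alphabetically
counts = countcats(T)       % number of observations in each category
```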

Extract and preprocess the validation data using the same steps.

documentsValidation = preprocessText(dataValidation.Description);
TValidation = categorical(dataValidation.Category);

Convert Documents to Sequences

To input the documents into a neural network, use a word encoding to convert the documents into sequences of numeric indices.

Create a word encoding from the documents.

enc = wordEncoding(documentsTrain);

View the vocabulary size of the word encoding. The vocabulary size is the number of unique words in the word encoding.

numWords = enc.NumWords

Convert the documents to sequences of integers using the doc2sequence function.

XTrain = doc2sequence(enc,documentsTrain);
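Conceptually, a word encoding is a vocabulary plus a word-to-index lookup. The following base MATLAB sketch mimics that idea with unique and ismember on two hypothetical tokenized documents. It is not how wordEncoding or doc2sequence are implemented, and it does not reproduce any padding or index ordering those functions may apply.

```matlab
% Hypothetical tokenized documents, one string row per document.
docs = {["a" "fuse" "is" "blown"], ["the" "mixer" "blew" "a" "fuse"]};

% Vocabulary: unique words in order of first appearance.
vocab = unique([docs{:}],'stable');
vocabSize = numel(vocab)                % vocabulary size

% Map each document to a row vector of word indices.
seqs = cell(numel(docs),1);
for i = 1:numel(docs)
    [~,idx] = ismember(docs{i},vocab);
    seqs{i} = idx;
end
```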

Convert the validation documents to sequences using the word encoding created from the training data.

XValidation = doc2sequence(enc,documentsValidation);
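Because the validation documents are encoded with the training vocabulary, some validation words may not appear in it. In the hypothetical sketch below, ismember returns index 0 for such words and the sketch simply drops them. That is an assumption of this sketch only; see the doc2sequence documentation for how it handles unknown words.

```matlab
% Hypothetical training vocabulary and validation document.
vocab = ["a" "fuse" "is" "blown" "the" "mixer"];
docValidation = ["coolant" "leaks" "from" "the" "mixer"];

[isKnown,idx] = ismember(docValidation,vocab);
oovWords = docValidation(~isKnown)      % words missing from the training vocabulary
seq = idx(isKnown)                      % keep only the known-word indices
```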

Define Network Architecture

Define the network architecture for the classification task.

The following steps describe the network architecture.

Specify the network hyperparameters.

embeddingDimension = 100;
ngramLengths = [2 3 4 5];
numFilters = 200;
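As a rough check on model size, these hyperparameters determine the number of learnable parameters in the convolution layers and the length of the concatenated feature vector. This sketch assumes each 1-D convolution has a filterWidth-by-numChannels-by-numFilters weight array plus one bias per filter, and it ignores the embedding, batch normalization, and fully connected layers.

```matlab
embeddingDimension = 100;
ngramLengths = [2 3 4 5];
numFilters = 200;

% Weights per branch: width x input channels x filters; biases: one per filter.
convWeights = ngramLengths*embeddingDimension*numFilters;
convBiases = numFilters*ones(size(ngramLengths));
totalConvParams = sum(convWeights + convBiases)   % 280800

% Each branch's global max pooling outputs numFilters values, so the
% concatenated feature vector has this many channels.
numFeatures = numel(ngramLengths)*numFilters      % 800
```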

First, create a dlnetwork object containing the input layer and a word embedding layer of dimension 100. To help connect the word embedding layer to the convolution layers, set the word embedding layer name to "emb". To ensure that the convolution layers do not convolve the sequences to a length of zero during training, set the MinLength option of the sequence input layer to the length of the shortest sequence in the training data.

net = dlnetwork;
minLength = min(doclength(documentsTrain));
layers = [
    sequenceInputLayer(1,MinLength=minLength)
    wordEmbeddingLayer(embeddingDimension,numWords,Name="emb")];
net = addLayers(net,layers);
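The arithmetic behind the MinLength concern: with stride 1 and no padding, a width-N convolution over a sequence of length L produces max(L - N + 1, 0) outputs, so short reports can vanish entirely for the wider filters, whereas Padding="same" keeps the output length equal to L. A small sketch with a hypothetical shortest length of 3 words:

```matlab
L = 3;                          % hypothetical shortest sequence length
ngramLengths = [2 3 4 5];

% Output lengths with no padding (stride 1).
validLengths = max(L - ngramLengths + 1, 0)    % [2 1 0 0]

% Output lengths with Padding="same" (stride 1): unchanged.
sameLengths = L*ones(size(ngramLengths))       % [3 3 3 3]
```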

For each of the n-gram lengths, create a block of 1-D convolution, batch normalization, ReLU, dropout, and 1-D global max pooling layers. Connect each block to the word embedding layer.

numBlocks = numel(ngramLengths);
for j = 1:numBlocks
    N = ngramLengths(j);

    block = [
        convolution1dLayer(N,numFilters,Name="conv"+N,Padding="same")
        batchNormalizationLayer(Name="bn"+N)
        reluLayer(Name="relu"+N)
        dropoutLayer(0.2,Name="drop"+N)
        globalMaxPooling1dLayer(Name="max"+N)];

    net = addLayers(net,block);
    net = connectLayers(net,"emb","conv"+N);
end
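To see what one branch computes, here is a base MATLAB sketch of a single filter on a hypothetical embedded sequence. It slides the filter over time (deep learning convolution does not flip the kernel), applies ReLU, and then takes the maximum over time, as global max pooling does. It uses no padding and omits batch normalization and dropout; all sizes and values are assumptions for illustration.

```matlab
rng(0)
C = 4; T = 6; N = 3;              % channels (embedding dim), time steps, filter width
X = randn(C,T);                   % hypothetical embedded sequence, C-by-T
W = randn(C,N); b = 0.1;          % one hypothetical filter and its bias

numOut = T - N + 1;               % output length without padding
z = zeros(1,numOut);
for t = 1:numOut
    window = X(:,t:t+N-1);        % the N-gram of embeddings the filter sees
    z(t) = sum(W.*window,"all") + b;
end

a = max(z,0);                     % ReLU
feature = max(a)                  % global max pooling over time
```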

Add the concatenation layer, the fully connected layer, and the softmax layer.

numClasses = numel(classNames);

layers = [
    concatenationLayer(1,numBlocks,Name="cat")
    fullyConnectedLayer(numClasses,Name="fc")
    softmaxLayer(Name="soft")];

net = addLayers(net,layers);
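The head of the network stacks the pooled outputs of the branches along the channel dimension, then applies a fully connected layer and softmax. A base MATLAB sketch with small, hypothetical sizes and random weights:

```matlab
rng(0)
numFilters = 3; numBlocks = 4; numClasses = 4;   % small hypothetical sizes

% One pooled feature vector (numFilters-by-1) per branch.
pooled = cell(1,numBlocks);
for k = 1:numBlocks
    pooled{k} = rand(numFilters,1);
end

% Concatenate along dimension 1, like concatenationLayer(1,numBlocks).
features = cat(1,pooled{:});                     % 12-by-1

% Fully connected layer followed by a numerically stable softmax.
Wfc = randn(numClasses,numel(features));
bfc = zeros(numClasses,1);
logits = Wfc*features + bfc;
expLogits = exp(logits - max(logits));
scores = expLogits/sum(expLogits)                % sums to 1
```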

Connect the global max pooling layers to the concatenation layer and view the network architecture in a plot.

for j = 1:numBlocks
    N = ngramLengths(j);
    net = connectLayers(net,"max"+N,"cat/in"+j);
end

figure
plot(net)
title("Network Architecture")

Train Network

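Specify the training options: train using the Adam solver with a mini-batch size of 128, validate the network using the validation data, return the network that performs best on the validation data, display the training progress in a plot and monitor the accuracy, and suppress the verbose output.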

options = trainingOptions("adam", ...
    MiniBatchSize=128, ...
    ValidationData={XValidation,TValidation}, ...
    OutputNetwork="best-validation", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false, ...
    InputDataFormats="CTB");

Train the network using the trainnet function.

net = trainnet(XTrain,TTrain,net,"crossentropy",options);
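For reference, the "crossentropy" loss compares the softmax scores with the true classes. This sketch computes the mean negative log-probability of the true class for two hypothetical observations. Averaging over observations is an assumption of the sketch; see the trainnet documentation for its exact normalization.

```matlab
% Hypothetical softmax scores: classes-by-observations.
scores = [0.7 0.1
          0.2 0.6
          0.1 0.3];
% One-hot targets: observation 1 is class 1, observation 2 is class 2.
targets = [1 0
           0 1
           0 0];

numObs = size(scores,2);
loss = -sum(targets.*log(scores),"all")/numObs
```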

Test Network

Make predictions using the neural network. To make predictions with multiple observations, use the minibatchpredict function. To convert the prediction scores to labels, use the scores2label function. The minibatchpredict function automatically uses a GPU if one is available. Using a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information on supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the function uses the CPU.

Because the data has sequences with rows and columns corresponding to channels and time steps, respectively, specify the input data format "CTB" (channel, time, batch).

scores = minibatchpredict(net,XValidation,InputDataFormats="CTB");
YValidation = scores2label(scores,classNames);
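For single-label classification, scores2label picks the class with the highest score for each observation. A base MATLAB equivalent, assuming one row of scores per observation with columns ordered like classNames (the scores here are hypothetical):

```matlab
classNames = categorical(["Electronic Failure";"Leak";"Mechanical Failure";"Software Failure"]);
% Hypothetical scores: observations-by-classes.
scores = [0.1 0.7 0.1 0.1
          0.6 0.2 0.1 0.1];
[~,idx] = max(scores,[],2);    % column index of the top score in each row
Y = classNames(idx)            % predicted labels
```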

Visualize the predictions in a confusion chart.

figure
confusionchart(TValidation,YValidation)

Calculate the classification accuracy. The accuracy is the proportion of labels predicted correctly.

accuracy = mean(TValidation == YValidation)
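Overall accuracy can hide weak classes. This sketch uses hypothetical true and predicted labels to compute the overall accuracy in the same way and then the per-class accuracy (the fraction of each true class predicted correctly), which corresponds to the diagonal of a row-normalized confusion matrix.

```matlab
T = categorical(["Leak";"Leak";"Mechanical Failure";"Software Failure"]);
Y = categorical(["Leak";"Mechanical Failure";"Mechanical Failure";"Software Failure"]);

accuracy = mean(T == Y)                  % 0.75

% Per-class accuracy (recall) for each true class.
cats = categories(T);
perClass = zeros(numel(cats),1);
for k = 1:numel(cats)
    inClass = (T == cats{k});
    perClass(k) = mean(Y(inClass) == cats{k});
end
table(cats,perClass)
```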

Predict Using New Data

Classify the event type of three new reports. Create a string array containing the new reports.

reportsNew = [
    "Coolant is pooling underneath sorter."
    "Sorter blows fuses at start up."
    "There are some very loud rattling sounds coming from the assembler."];

Preprocess the text data using the same preprocessing steps as the training and validation documents.

documentsNew = preprocessText(reportsNew);
XNew = doc2sequence(enc,documentsNew);

Classify the new sequences using the trained network.

scores = minibatchpredict(net,XNew,InputDataFormats="CTB");
YNew = scores2label(scores,classNames)

YNew = 3×1 categorical
     Leak 
     Electronic Failure 
     Mechanical Failure 

Preprocess Text Function

The preprocessText function takes text data as input and performs these steps:

  1. Tokenize the text.
  2. Convert the text to lowercase.

function documents = preprocessText(textData)

documents = tokenizedDocument(textData);
documents = lower(documents);

end
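tokenizedDocument applies its own tokenization rules. Purely to illustrate the two preprocessing steps without Text Analytics Toolbox, this rough sketch lowercases a report and splits it into word and punctuation tokens with a regular expression. The pattern is an assumption and does not match tokenizedDocument in general.

```matlab
report = "Sorter blows fuses at start up.";

% Lowercase (step 2) and an approximation of tokenizing (step 1):
% runs of word characters, or single punctuation characters.
tokens = string(regexp(lower(report),'\w+|[^\w\s]','match'))
```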

See Also

trainnet | trainingOptions | dlnetwork | fastTextWordEmbedding (Text Analytics Toolbox) | wordcloud (Text Analytics Toolbox) | wordEmbedding (Text Analytics Toolbox) | convolution2dLayer | batchNormalizationLayer | doc2sequence (Text Analytics Toolbox) | tokenizedDocument (Text Analytics Toolbox) | transform