classify - Classify document using BERT document classifier - MATLAB


Since R2023b

Syntax

Y = classify(mdl,documents)
Y = classify(mdl,documents,Name=Value)

Description

Y = classify(mdl,documents) classifies the specified documents using the BERT document classifier mdl.


Y = classify(mdl,documents,Name=Value) specifies additional options using one or more name-value arguments.

Examples


Train BERT Document Classifier

Read the training data from the factoryReports CSV file. The file contains factory reports, including a text description and a category label for each report.

filename = "factoryReports.csv";
data = readtable(filename,TextType="string");
head(data)

                             Description                                       Category          Urgency          Resolution         Cost 
_____________________________________________________________________    ____________________    ________    ____________________    _____

"Items are occasionally getting stuck in the scanner spools."            "Mechanical Failure"    "Medium"    "Readjust Machine"         45
"Loud rattling and banging sounds are coming from assembler pistons."    "Mechanical Failure"    "Medium"    "Readjust Machine"         35
"There are cuts to the power when starting the plant."                   "Electronic Failure"    "High"      "Full Replacement"      16200
"Fried capacitors in the assembler."                                     "Electronic Failure"    "High"      "Replace Components"      352
"Mixer tripped the fuses."                                               "Electronic Failure"    "Low"       "Add to Watch List"        55
"Burst pipe in the constructing agent is spraying coolant."              "Leak"                  "High"      "Replace Components"      371
"A fuse is blown in the mixer."                                          "Electronic Failure"    "Low"       "Replace Components"      441
"Things continue to tumble off of the belt."                             "Mechanical Failure"    "Low"       "Readjust Machine"         38

Convert the labels in the Category column of the table to categorical values.

data.Category = categorical(data.Category);

Partition the data into a training set and a test set. Specify the holdout percentage as 10%.

cvp = cvpartition(data.Category,Holdout=0.1);
dataTrain = data(cvp.training,:);
dataTest = data(cvp.test,:);

Extract the text data and labels from the tables.

textDataTrain = dataTrain.Description;
textDataTest = dataTest.Description;
TTrain = dataTrain.Category;
TTest = dataTest.Category;

Load a pretrained BERT-Base document classifier using the bertDocumentClassifier function.

classNames = categories(data.Category);
mdl = bertDocumentClassifier(ClassNames=classNames)

mdl = 
  bertDocumentClassifier with properties:

       Network: [1×1 dlnetwork]
     Tokenizer: [1×1 bertTokenizer]
    ClassNames: ["Electronic Failure"    "Leak"    "Mechanical Failure"    "Software Failure"]

Specify the training options. Choosing among training options requires empirical analysis. To explore different training option configurations by running experiments, you can use the Experiment Manager (Deep Learning Toolbox) app.

options = trainingOptions("adam", ...
    MaxEpochs=8, ...
    InitialLearnRate=1e-4, ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Metrics="accuracy", ...
    Verbose=false);

Train the neural network using the trainBERTDocumentClassifier function. By default, the trainBERTDocumentClassifier function uses a GPU if one is available. Training on a GPU requires a Parallel Computing Toolbox™ license and a supported GPU device. For information about supported devices, see GPU Computing Requirements (Parallel Computing Toolbox). Otherwise, the trainBERTDocumentClassifier function uses the CPU. To specify the execution environment, use the ExecutionEnvironment training option.

mdl = trainBERTDocumentClassifier(textDataTrain,TTrain,mdl,options);

Make predictions using the test data.

YTest = classify(mdl,textDataTest);

Calculate the classification accuracy of the test predictions.

accuracy = mean(TTest == YTest)
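The accuracy computation relies on elementwise comparison of categorical arrays. As a minimal sketch with hypothetical toy labels (not taken from factoryReports.csv), TTest == YTest returns a logical array, and the mean of that array is the fraction of correct predictions.

```matlab
% Hypothetical toy labels, not taken from factoryReports.csv.
classNames = ["Electronic Failure" "Leak" "Mechanical Failure" "Software Failure"];
TTest = categorical(["Leak" "Mechanical Failure" "Leak" "Electronic Failure"],classNames);
YTest = categorical(["Leak" "Leak" "Leak" "Electronic Failure"],classNames);

% Elementwise comparison of two categorical arrays returns a logical array:
% [true false true true]
isCorrect = TTest == YTest;

% The mean of a logical array is the fraction of true elements: 3/4 = 0.75
accuracy = mean(isCorrect)
```

Because both arrays share the same category set, the comparison is made on category names, so the order of the rows is what matters, not the order of the categories.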

Input Arguments


mdl — BERT document classifier model

bertDocumentClassifier object

BERT document classifier model, specified as a bertDocumentClassifier object.

documents — Input documents

string array | cell array of character vectors | tokenizedDocument array

Input documents, specified as a string array, a cell array of character vectors, or a tokenizedDocument array.
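A minimal sketch of passing each accepted input type, assuming the toolboxes and pretrained BERT-Base model files that bertDocumentClassifier needs are installed. The report strings are hypothetical, and because the classification head is not fine-tuned here, the predicted labels are not meaningful; the sketch only shows that each input type yields a categorical array of predictions.

```matlab
% Assumes the required toolboxes and pretrained BERT-Base model files are
% installed. The classifier is not fine-tuned, so labels are arbitrary.
classNames = ["Electronic Failure" "Leak" "Mechanical Failure" "Software Failure"];
mdl = bertDocumentClassifier(ClassNames=classNames);

% Hypothetical example reports
str = ["Coolant is dripping from the mixer." ...
       "The scanner software crashes on startup."];

Y1 = classify(mdl,str);                    % string array
Y2 = classify(mdl,cellstr(str));           % cell array of character vectors
Y3 = classify(mdl,tokenizedDocument(str)); % tokenizedDocument array
```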

Name-Value Arguments

Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

Example: classify(mdl,documents,MiniBatchSize=64) classifies the specified documents using mini-batches of size 64.

MiniBatchSize — Mini-batch size

32 (default) | positive integer

Mini-batch size to use for prediction, specified as a positive integer. Larger mini-batch sizes require more memory, but can lead to faster predictions.

Data Types: single | double | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64
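To make the memory and speed trade-off concrete, assume that classify splits the documents into consecutive mini-batches of MiniBatchSize documents, with a smaller final mini-batch when the count does not divide evenly. Under that assumption, N documents need ceil(N/MiniBatchSize) forward passes. A quick check for a hypothetical 1000 documents:

```matlab
% Hypothetical document count. Assumes consecutive mini-batches where the
% final mini-batch may be smaller than MiniBatchSize.
numDocuments = 1000;
miniBatchSizes = [32 64 128];

% Number of mini-batches (forward passes) for each candidate size
numMiniBatches = ceil(numDocuments ./ miniBatchSizes)          % [32 16 8]

% Size of the final, possibly partial, mini-batch for each candidate size
lastBatchSize = numDocuments - (numMiniBatches - 1) .* miniBatchSizes   % [8 40 104]
```

Doubling MiniBatchSize roughly halves the number of forward passes, but each pass must hold twice as many documents in memory.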

Acceleration — Performance optimization

"auto" (default) | "mex" | "none"

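Performance optimization, specified as one of these values:

"auto" - Automatically apply a number of optimizations suitable for the input model and hardware resource.
"mex" - Compile and execute a MEX function. This option is available only when you use a GPU.
"none" - Disable all acceleration.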

When you use the "auto" or "mex" option, the software can offer performance benefits at the expense of an increased initial run time. Subsequent calls to the function are typically faster. Use performance optimization when you call the function multiple times using different input data.

When Acceleration is "mex", the software generates and executes a MEX function based on the model and parameters you specify in the function call. A single model can have several associated MEX functions at one time. Clearing the model variable also clears any MEX functions associated with that model.

When Acceleration is "auto", the software does not generate a MEX function.

The "mex" option is available only when you use a GPU. You must have a C/C++ compiler installed and the GPU Coder™ Interface for Deep Learning support package. Install the support package using the Add-On Explorer in MATLAB®. For setup instructions, see MEX Setup (GPU Coder). GPU Coder is not required.

MATLAB Compiler™ software does not support compiling models when you use the "mex" option.
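The repeated-call pattern that Acceleration targets can be sketched as follows. This assumes the required toolboxes and pretrained BERT-Base model files are installed; the batches of reports are hypothetical, and the untrained classification head makes the predicted labels arbitrary. The sketch uses "auto" because "mex" additionally needs a GPU, a C/C++ compiler, and the GPU Coder Interface for Deep Learning support package.

```matlab
% Assumes the required toolboxes and pretrained BERT-Base model files are
% installed. The classifier is not fine-tuned, so labels are arbitrary.
classNames = ["Electronic Failure" "Leak" "Mechanical Failure" "Software Failure"];
mdl = bertDocumentClassifier(ClassNames=classNames);

% Hypothetical batches of new reports arriving over time
batches = {["Coolant is leaking from the pipe." "The belt motor stalls."], ...
           ["The control software freezes." "Sparks come from the power supply."]};

predictions = cell(size(batches));
callTimes = zeros(1,numel(batches));
for i = 1:numel(batches)
    tStart = tic;
    % With Acceleration="auto", the first call can take longer; later calls
    % with different input data are typically faster.
    predictions{i} = classify(mdl,batches{i},Acceleration="auto");
    callTimes(i) = toc(tStart);
end
callTimes   % elapsed seconds per call; actual values depend on hardware
```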

ExecutionEnvironment — Hardware resource

"auto" (default) | "gpu" | "cpu"

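Hardware resource, specified as one of these values:

"auto" - Use a GPU if one is available. Otherwise, use the CPU.
"gpu" - Use the GPU. Using a GPU requires a Parallel Computing Toolbox license and a supported GPU device.
"cpu" - Use the CPU.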

Output Arguments


Y — Predicted classes

categorical array

Predicted classes, returned as a categorical array.
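Because Y is a categorical array, standard categorical functions work on it directly. A minimal sketch with a hypothetical Y standing in for real classify output, counting predictions per class with countcats:

```matlab
% Hypothetical predictions standing in for the output of classify
classNames = ["Electronic Failure" "Leak" "Mechanical Failure" "Software Failure"];
Y = categorical(["Leak"; "Leak"; "Mechanical Failure"; "Electronic Failure"; "Leak"],classNames);

% Number of predictions in each class, in the order given by categories(Y)
counts = countcats(Y)                      % [1; 3; 1; 0]

% Classes that were never predicted
unusedClasses = classNames(counts == 0)    % "Software Failure"
```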

Version History

Introduced in R2023b