predict

Predict labels using _k_-nearest neighbor classification model

Syntax

label = predict(mdl,X)
[label,score,cost] = predict(mdl,X)

Description

label = predict(mdl,X) returns a vector of predicted class labels for the predictor data in the table or matrix X, based on the trained _k_-nearest neighbor classification model mdl. See Predicted Class Label.

[label,score,cost] = predict(mdl,X) also returns the predicted class scores (posterior probabilities) and the expected classification costs for the observations in X. See Posterior Probability and Expected Cost.

Examples


Create a _k_-nearest neighbor classifier for Fisher's iris data, where k = 5. Evaluate some model predictions on new data.

Load the Fisher iris data set.

load fisheriris
X = meas;
Y = species;

Create a classifier for five nearest neighbors. Standardize the noncategorical predictor data.

mdl = fitcknn(X,Y,'NumNeighbors',5,'Standardize',1);

Predict the classifications for flowers with minimum, mean, and maximum characteristics.

Xnew = [min(X);mean(X);max(X)];
[label,score,cost] = predict(mdl,Xnew)

label = 3×1 cell
    {'versicolor'}
    {'versicolor'}
    {'virginica' }

score = 3×3

    0.4000    0.6000         0
         0    1.0000         0
         0         0    1.0000

cost = 3×3

    0.6000    0.4000    1.0000
    1.0000         0    1.0000
    1.0000    1.0000         0

The second and third rows of the score and cost matrices have binary values, which means all five nearest neighbors of the mean and maximum flower measurements have identical classifications.
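You can confirm this unanimity directly by inspecting the neighbor labels yourself. The following is a minimal sketch (not part of the original example) that standardizes the mean-flower query with mdl.Mu and mdl.Sigma, then finds its five nearest training neighbors with knnsearch:

% Sketch: check that all five nearest neighbors of the mean flower are
% versicolor. The model standardizes internally, so standardize both the
% stored training data and the query before searching.
Xs  = (mdl.X - mdl.Mu)./mdl.Sigma;       % standardized training data
xq  = (Xnew(2,:) - mdl.Mu)./mdl.Sigma;   % standardized query point
idx = knnsearch(Xs,xq,'K',mdl.NumNeighbors);
mdl.Y(idx)                               % expected: five 'versicolor' labels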

Train _k_-nearest neighbor classifiers for various k values, and compare the decision boundaries of the classifiers.

Load the fisheriris data set.

load fisheriris

The data set contains length and width measurements from the sepals and petals of three species of iris flowers. Remove the sepal lengths and widths, and all observed setosa irises.

inds = ~strcmp(species,'setosa');
X = meas(inds,3:4);
species = species(inds);

Create a binary label variable y. The label is 1 for a virginica iris and 0 for versicolor.

y = strcmp(species,'virginica');

Train the _k_-nearest neighbor classifier. Specify 5 as the number of nearest neighbors to find, and standardize the predictor data.

EstMdl = fitcknn(X,y,'NumNeighbors',5,'Standardize',1)

EstMdl = 

  ClassificationKNN
             ResponseName: 'Y'
    CategoricalPredictors: []
               ClassNames: [0 1]
           ScoreTransform: 'none'
          NumObservations: 100
                 Distance: 'euclidean'
             NumNeighbors: 5


EstMdl is a trained ClassificationKNN classifier. Some of its properties appear in the Command Window.

Plot the decision boundary, which is the line that distinguishes between the two iris species based on their features.

x1 = min(X(:,1)):0.01:max(X(:,1));
x2 = min(X(:,2)):0.01:max(X(:,2));
[x1G,x2G] = meshgrid(x1,x2);
XGrid = [x1G(:),x2G(:)];
pred = predict(EstMdl,XGrid);

figure
gscatter(XGrid(:,1),XGrid(:,2),pred,[1,0,0;0,0.5,1])
hold on
plot(X(y == 0,1),X(y == 0,2),'ko', ...
    X(y == 1,1),X(y == 1,2),'kx')
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')
title('{\bf 5-Nearest Neighbor Classifier Decision Boundary}')
legend('Versicolor Region','Virginica Region', ...
    'Sampled Versicolor','Sampled Virginica', ...
    'Location','best')
axis tight
hold off

Figure: 5-Nearest Neighbor Classifier Decision Boundary, plotting petal width (cm) against petal length (cm), with the versicolor and virginica regions and the sampled points.

The partition between the red and blue regions is the decision boundary. If you change the number of neighbors k, then the boundary changes.

Retrain the classifier using k = 1 (default value for NumNeighbors of fitcknn) and k = 20.

EstMdl1 = fitcknn(X,y);
pred1 = predict(EstMdl1,XGrid);

EstMdl20 = fitcknn(X,y,'NumNeighbors',20);
pred20 = predict(EstMdl20,XGrid);

figure
gscatter(XGrid(:,1),XGrid(:,2),pred1,[1,0,0;0,0.5,1])
hold on
plot(X(y == 0,1),X(y == 0,2),'ko', ...
    X(y == 1,1),X(y == 1,2),'kx')
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')
title('{\bf 1-Nearest Neighbor Classifier Decision Boundary}')
legend('Versicolor Region','Virginica Region', ...
    'Sampled Versicolor','Sampled Virginica', ...
    'Location','best')
axis tight
hold off

Figure: 1-Nearest Neighbor Classifier Decision Boundary, with the same regions, sampled points, and axes as above.

figure
gscatter(XGrid(:,1),XGrid(:,2),pred20,[1,0,0;0,0.5,1])
hold on
plot(X(y == 0,1),X(y == 0,2),'ko', ...
    X(y == 1,1),X(y == 1,2),'kx')
xlabel('Petal length (cm)')
ylabel('Petal width (cm)')
title('{\bf 20-Nearest Neighbor Classifier Decision Boundary}')
legend('Versicolor Region','Virginica Region', ...
    'Sampled Versicolor','Sampled Virginica', ...
    'Location','best')
axis tight
hold off

Figure: 20-Nearest Neighbor Classifier Decision Boundary, with the same regions, sampled points, and axes as above.

The decision boundary seems to linearize as k increases. This linearization happens because the algorithm down-weights the importance of each input with increasing k. When k = 1, the algorithm correctly predicts the species of almost all training samples. When k = 20, the algorithm has a higher misclassification rate within the training set. You can find an optimal value of k by using the OptimizeHyperparameters name-value argument of fitcknn. For an example, see Optimize Fitted KNN Classifier.
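To quantify this tradeoff, you can compare in-sample error rates. This sketch (not in the original example) uses resubLoss to compute the resubstitution misclassification rate of each fitted model:

% Sketch: resubstitution (training-set) misclassification rates.
% Expect loss1 to be near 0 and loss20 to be noticeably larger.
loss1  = resubLoss(EstMdl1)
loss20 = resubLoss(EstMdl20)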

Input Arguments


mdl - _k_-nearest neighbor classifier model, specified as a ClassificationKNN object.

X - Predictor data to be classified, specified as a numeric matrix or table.

Each row of X corresponds to one observation, and each column corresponds to one variable.

If you set 'Standardize',true in fitcknn to train mdl, then the software standardizes the columns of X using the corresponding means in mdl.Mu and standard deviations in mdl.Sigma.

Data Types: double | single | table
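As an illustration of this standardization behavior, the following sketch (assuming the mdl, X, and Y variables from the first example) trains a second model on manually standardized data and checks that the predictions agree:

% Sketch: standardizing manually with mdl.Mu and mdl.Sigma reproduces the
% distances that the standardized model computes internally, so the
% predicted labels match.
Xs = (X - mdl.Mu)./mdl.Sigma;
mdlManual = fitcknn(Xs,Y,'NumNeighbors',5);    % no 'Standardize' needed
isequal(predict(mdl,X),predict(mdlManual,Xs))  % expected: true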

Output Arguments


label - Predicted class labels for the observations (rows) in X, returned as a categorical array, character array, logical vector, vector of numeric values, or cell array of character vectors. label has length equal to the number of rows in X.

For each observation, the label is the class with minimal expected cost. For an observation with NaN scores, the function classifies the observation into the majority class, which makes up the largest proportion of the training labels.

See Predicted Class Label.
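For instance, you can recover label from the cost output, as in this sketch (assuming mdl and Xnew from the first example):

% Sketch: the predicted label is the class whose column of cost is minimal.
[label,~,cost] = predict(mdl,Xnew);
[~,minIdx] = min(cost,[],2);
isequal(label,mdl.ClassNames(minIdx))   % expected: true, barring ties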

score - Predicted class scores or posterior probabilities, returned as a numeric matrix of size _n_-by-_K_. _n_ is the number of observations (rows) in X, and _K_ is the number of classes (in mdl.ClassNames). score(i,j) is the posterior probability that observation i in X is of class j in mdl.ClassNames. See Posterior Probability.

Data Types: single | double
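Because each row of score is a posterior distribution over the classes in mdl.ClassNames, the rows sum to 1, which you can check with a short sketch (reusing mdl and Xnew from the first example):

% Sketch: posterior probabilities in each row of score sum to 1.
[~,score] = predict(mdl,Xnew);
sum(score,2)   % expected: a column vector of ones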

cost - Expected classification costs, returned as a numeric matrix of size _n_-by-_K_. _n_ is the number of observations (rows) in X, and _K_ is the number of classes (in mdl.ClassNames). cost(i,j) is the cost of classifying row i of X as class j in mdl.ClassNames. See Expected Cost.

Data Types: single | double

Algorithms


Predicted Class Label

predict classifies by minimizing the expected misclassification cost:

$$\hat{y} = \underset{y=1,\ldots,K}{\operatorname{argmin}} \sum_{j=1}^{K} \hat{P}(j \mid x)\, C(y \mid j),$$

where:

- $\hat{y}$ is the predicted classification.
- $K$ is the number of classes.
- $\hat{P}(j \mid x)$ is the posterior probability of class $j$ for observation $x$.
- $C(y \mid j)$ is the cost of classifying an observation as $y$ when its true class is $j$.

Posterior Probability

Consider a vector (single query point) xnew and a model mdl.

If the model contains a vector of prior probabilities, then the observation weights W are normalized by class to sum to the priors. This process might involve a calculation for the point xnew, because weights can depend on the distance from xnew to the points in mdl.X.

The posterior probability $p(j \mid x_{\mathrm{new}})$ is

$$p(j \mid x_{\mathrm{new}}) = \frac{\displaystyle\sum_{i \in \mathcal{N}} W(i)\, \mathbf{1}_{Y(X(i))=j}}{\displaystyle\sum_{i \in \mathcal{N}} W(i)},$$

where $\mathcal{N}$ is the set of the $k$ nearest neighbors of $x_{\mathrm{new}}$, and $W(i)$ is the weight of training observation $i$. Here, $\mathbf{1}_{Y(X(i))=j}$ is 1 when mdl.Y(i) = j, and 0 otherwise.

True Misclassification Cost

Two costs are associated with KNN classification: the true misclassification cost per class and the expected misclassification cost per observation.

You can set the true misclassification cost per class by using the 'Cost' name-value pair argument when you run fitcknn. The value Cost(i,j) is the cost of classifying an observation into class j if its true class is i. By default, Cost(i,j) = 1 if i ~= j, and Cost(i,j) = 0 if i = j. In other words, the cost is 0 for correct classification and 1 for incorrect classification.
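For example, the following sketch (the cost values are illustrative, not from the original text) uses the X, y, and XGrid variables from the decision-boundary example and penalizes misclassifying a true virginica five times more heavily than the reverse error:

% Sketch: asymmetric misclassification costs. Cost(i,j) is the cost of
% classifying into class j when the true class is i, with classes ordered
% as in ClassNames ([0 1] here).
C = [0 1; 5 0];
mdlCost = fitcknn(X,y,'NumNeighbors',5,'Cost',C);
predCost = predict(mdlCost,XGrid);   % predicts virginica more readily,
                                     % shifting the decision boundary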

Expected Cost

The third output of predict is the expected misclassification cost per observation.

Suppose you have Nobs observations that you want to classify with a trained classifier mdl, and you have K classes. You place the observations into a matrix Xnew with one observation per row. The command

[label,score,cost] = predict(mdl,Xnew)

returns a matrix cost of size Nobs-by-K, among other outputs. Each row of the cost matrix contains the expected (average) cost of classifying the observation into each of the K classes. cost(n,j) is

$$\mathrm{cost}(n,j) = \sum_{i=1}^{K} \hat{P}(i \mid X_{\mathrm{new}}(n))\, C(j \mid i),$$

where $K$ is the number of classes, $\hat{P}(i \mid X_{\mathrm{new}}(n))$ is the posterior probability of class $i$ for observation $X_{\mathrm{new}}(n)$, and $C(j \mid i)$ is the true misclassification cost of classifying an observation as $j$ when its true class is $i$.
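In matrix form, this is the posterior matrix times the stored cost matrix, which you can verify numerically with a short sketch (mdl and Xnew from the first example):

% Sketch: expected cost equals posterior probabilities times the
% true-cost matrix, that is, cost = score*mdl.Cost.
[~,score,cost] = predict(mdl,Xnew);
max(abs(cost - score*mdl.Cost),[],'all')   % expected: approximately 0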

Alternative Functionality

To integrate the prediction of a nearest neighbor classification model into Simulink®, you can use the ClassificationKNN Predict block in the Statistics and Machine Learning Toolbox™ library or a MATLAB® Function block with the predict function. For examples, see Predict Class Labels Using ClassificationKNN Predict Block and Predict Class Labels Using MATLAB Function Block.
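For the MATLAB Function block route, a common pattern (sketched below; the file and function names are hypothetical) is to save the trained model with saveLearnerForCoder beforehand and reload it once inside the block:

% Sketch of a MATLAB Function block body. Assumes the model was saved
% earlier with saveLearnerForCoder(mdl,'knnModel').
function label = predictKNN(X)  %#codegen
persistent mdl
if isempty(mdl)
    mdl = loadLearnerForCoder('knnModel');
end
label = predict(mdl,X);
end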

When deciding which approach to use, consider the following:

Extended Capabilities


Tall Arrays

The predict function fully supports tall arrays. For more information, see Tall Arrays.
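A minimal sketch of the tall workflow (here the tall array is created from in-memory data purely for illustration; in practice it would usually come from a datastore):

% Sketch: predict on a tall array, then gather the deferred result.
tX = tall(Xnew);
tLabel = predict(mdl,tX);
label = gather(tLabel);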

C/C++ Code Generation

Usage notes and limitations apply when you generate C/C++ code. For more information, see Introduction to Code Generation.
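A typical workflow (a sketch; the entry-point and MAT-file names are hypothetical) saves the model, wraps predict in an entry-point function, and generates code from that function:

% Sketch: code generation workflow for a trained KNN model.
saveLearnerForCoder(mdl,'knnModel')

% Contents of the hypothetical entry-point file predictEntry.m:
%   function label = predictEntry(X)  %#codegen
%   mdl = loadLearnerForCoder('knnModel');
%   label = predict(mdl,X);
%   end

% Generate MEX code, accepting any number of rows of four predictors.
codegen predictEntry -args {coder.typeof(0,[Inf 4],[1 0])}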

GPU Arrays

Usage notes and limitations apply when you run predict on a GPU. For more information, see Run MATLAB Functions on a GPU (Parallel Computing Toolbox).
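A minimal sketch (requires Parallel Computing Toolbox; mdl and Xnew are from the first example):

% Sketch: pass GPU-resident predictor data to predict.
Xg = gpuArray(Xnew);
labelG = predict(mdl,Xg);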

Version History

Introduced in R2012a