turicreate.nearest_neighbor_classifier.NearestNeighborClassifier.predict — Turi Create API 6.4.1 documentation

NearestNeighborClassifier.predict(dataset, max_neighbors=10, radius=None, output_type='class', verbose=True)

Return predicted class labels for the instances in dataset. Each prediction is based on that instance's closest neighbors among the reference data stored in the nearest neighbors classifier model.
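Predicting from the closest neighbors comes down to a majority vote over their class labels. The following is a minimal plain-Python sketch of that vote for one query point. It is not Turi Create's implementation: the helper vote is hypothetical, the neighbor labels are assumed to be already found and ordered closest first, and ties go to the label seen first (Turi Create's own tie-breaking may differ). It returns both the winning class and the fraction of neighbors that voted for it, matching the 'class' and 'probability' output types.

```python
from collections import Counter
from typing import Hashable, Optional, Sequence, Tuple


def vote(neighbor_labels: Sequence[Hashable]) -> Tuple[Optional[Hashable], Optional[float]]:
    """Majority vote over the labels of one query point's nearest neighbors.

    Hypothetical helper for illustration; returns (predicted class, vote fraction).
    """
    if not neighbor_labels:
        # Assumption: with no neighbors there is nothing to vote on.
        return None, None
    counts = Counter(neighbor_labels)
    # most_common(1) returns the first-inserted label among equal counts, so with
    # closest-first input a tie goes to the closest neighbor's label (an assumption).
    label, n_votes = counts.most_common(1)[0]
    # 'probability' output: votes for the winning class out of all neighbors.
    return label, n_votes / len(neighbor_labels)
```

For example, vote(['fossa', 'dog']) returns ('fossa', 0.5): a one-to-one tie between the two nearest neighbors, resolved toward the closer one.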

Parameters:
    dataset : SFrame
        Dataset of new observations. Must include the features used for model training, but does not require a target column. Additional columns are ignored.
    max_neighbors : int, optional
        Maximum number of neighbors to consider for each point. Defaults to 10.
    radius : float, optional
        Maximum distance from each point to a neighbor in the reference dataset. Defaults to None (no distance limit).
    output_type : {'class', 'probability'}, optional
        Type of prediction output:
        - class: Predicted class label, i.e. the class with the most votes among the nearest neighbors in the reference dataset.
        - probability: The fraction of the nearest neighbors in the reference dataset that voted for the predicted class (the maximum number of votes for any class divided by the number of neighbors).
    verbose : bool, optional
        If True, print progress updates. Defaults to True.
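To make the interaction of max_neighbors and radius concrete, here is a brute-force sketch of neighbor selection for a single query point. It is an illustration under stated assumptions, not Turi Create's search: the helper find_neighbor_labels is hypothetical, distance is assumed to be Euclidean over the listed features, and radius is treated as an inclusive cutoff. The returned labels, closest first, are what the class vote and the probability fraction would be computed from. The reference data is copied from the Examples section.

```python
import math
from typing import Dict, List, Optional, Sequence

# Training rows and labels from the Examples section.
REFERENCE = [
    {'height': 9, 'weight': 13},
    {'height': 25, 'weight': 28},
    {'height': 20, 'weight': 33},
    {'height': 23, 'weight': 22},
]
LABELS = ['cat', 'dog', 'fossa', 'dog']


def find_neighbor_labels(
    reference: Sequence[Dict[str, float]],
    labels: Sequence[str],
    query: Dict[str, float],
    features: Sequence[str],
    max_neighbors: int = 10,
    radius: Optional[float] = None,
) -> List[str]:
    """Return labels of up to max_neighbors reference rows, closest first.

    Hypothetical brute-force helper; assumes Euclidean distance on `features`.
    """
    scored = []
    for row, label in zip(reference, labels):
        dist = math.sqrt(sum((row[f] - query[f]) ** 2 for f in features))
        # radius=None means no distance limit; otherwise skip rows farther away.
        if radius is None or dist <= radius:
            scored.append((dist, label))
    # Closest first; the sort is stable, so equal distances keep input order.
    scored.sort(key=lambda pair: pair[0])
    return [label for _, label in scored[:max_neighbors]]
```

With max_neighbors=2, the query with height 26 and weight 25 gets ['dog', 'dog']. Setting radius=4.0 instead keeps only ['dog'], the single training row within that distance. If the radius is small enough, the list is empty and there is nothing to vote on.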
Returns:
    out : SArray
        An SArray with model predictions: class labels when output_type='class', or vote fractions when output_type='probability'.

Notes

Examples

>>> import turicreate
>>> sf_train = turicreate.SFrame({'species': ['cat', 'dog', 'fossa', 'dog'],
...                               'height': [9, 25, 20, 23],
...                               'weight': [13, 28, 33, 22]})
>>> sf_new = turicreate.SFrame({'height': [26, 19],
...                             'weight': [25, 35]})
>>> m = turicreate.nearest_neighbor_classifier.create(sf_train, target='species')
>>> ystar = m.predict(sf_new, max_neighbors=2, output_type='class')
>>> print(ystar)
['dog', 'fossa']