turicreate.nearest_neighbor_classifier.NearestNeighborClassifier.predict — Turi Create API 6.4.1 documentation
NearestNeighborClassifier.predict(dataset, max_neighbors=10, radius=None, output_type='class', verbose=True)
Return predicted class labels for instances in dataset. This model makes predictions based on the closest neighbors stored in the nearest neighbors classifier model.
Parameters:
    dataset : SFrame
        Dataset of new observations. Must include the features used for model training, but does not require a target column. Additional columns are ignored.
    max_neighbors : int, optional
        Maximum number of neighbors to consider for each point.
    radius : float, optional
        Maximum distance from each point to a neighbor in the reference dataset.
    output_type : {'class', 'probability'}, optional
        Type of prediction output:
        - class: Predicted class label. The class with the maximum number of votes among the nearest neighbors in the reference dataset.
        - probability: Maximum number of votes for any class out of all nearest neighbors in the reference dataset.
Returns:
    out : SArray
        An SArray with model predictions.
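The two output types can be illustrated with a minimal pure-Python sketch of majority voting. This is not Turi Create's implementation: `vote` is a hypothetical helper, and the sketch assumes that 'probability' is the winning class's share of the neighbor votes.

```python
from collections import Counter


def vote(neighbor_labels, output_type="class"):
    """Hypothetical helper: majority vote over the labels of a query's neighbors."""
    if not neighbor_labels:
        # No qualifying neighbors, so there is nothing to vote on.
        return None
    counts = Counter(neighbor_labels)
    # most_common(1) returns the label with the most votes and its vote count.
    label, votes = counts.most_common(1)[0]
    if output_type == "class":
        return label
    if output_type == "probability":
        # Assumption: report the winning class's fraction of all votes.
        return votes / len(neighbor_labels)
    raise ValueError("output_type must be 'class' or 'probability'")


print(vote(["dog", "dog", "cat"]))
print(vote(["dog", "dog", "cat"], output_type="probability"))
```

Unlike the model, this sketch resolves ties deterministically, in favor of the label that appears first in the input.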
Notes
- If the 'radius' parameter is small, a query point may have no qualifying neighbors in the training dataset. In this case, the result for that query is None in the SArray returned by this method. If the target column in the training dataset has missing values, these predictions are ambiguous, because None can then mean either "no neighbors found" or a predicted missing label.
- Ties between predicted classes are broken randomly.
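Both notes can be reproduced with a small stand-alone sketch. It assumes plain Euclidean distance over numeric features, which may differ from the distance the model actually selects. `predict_one` and `reference` are hypothetical names, not part of the Turi Create API.

```python
import math
import random
from collections import Counter

# Reference points as (features, label) pairs: (height, weight) -> species.
reference = [((9, 13), "cat"), ((25, 28), "dog"),
             ((20, 33), "fossa"), ((23, 22), "dog")]


def predict_one(query, reference, max_neighbors=10, radius=None, rng=random):
    """Hypothetical helper: predict one label by nearest-neighbor voting."""
    # Rank reference points by Euclidean distance (math.dist needs Python 3.8+).
    ranked = sorted((math.dist(query, x), label) for x, label in reference)
    if radius is not None:
        # Drop neighbors farther away than the radius.
        ranked = [(d, label) for d, label in ranked if d <= radius]
    nearest = ranked[:max_neighbors]
    if not nearest:
        # No qualifying neighbors: the prediction is missing.
        return None
    counts = Counter(label for _, label in nearest)
    top = max(counts.values())
    tied = sorted(label for label, c in counts.items() if c == top)
    # Break ties between equally voted classes at random.
    return rng.choice(tied)


print(predict_one((26, 25), reference, radius=1.0))       # no neighbor within 1.0
print(predict_one((19, 35), reference, max_neighbors=2))  # one vote each: a tie
```

With `radius=1.0`, the query `(26, 25)` has no reference point within range, so the result is `None`. With `max_neighbors=2`, the query `(19, 35)` gets one 'fossa' vote and one 'dog' vote, so repeated calls may return either label.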
Examples
>>> import turicreate
>>> sf_train = turicreate.SFrame({'species': ['cat', 'dog', 'fossa', 'dog'],
...                               'height': [9, 25, 20, 23],
...                               'weight': [13, 28, 33, 22]})
...
>>> sf_new = turicreate.SFrame({'height': [26, 19],
...                             'weight': [25, 35]})
...
>>> # Train on the labeled SFrame, then predict labels for the new rows.
>>> m = turicreate.nearest_neighbor_classifier.create(sf_train, target='species')
>>> ystar = m.predict(sf_new, max_neighbors=2, output_type='class')
>>> print(ystar)
['dog', 'fossa']