ClassificationModel (Spark 3.5.5 JavaDoc)
Object
- org.apache.spark.ml.PipelineStage
  - org.apache.spark.ml.Transformer
    - org.apache.spark.ml.Model
      - org.apache.spark.ml.PredictionModel<FeaturesType,M>
        - org.apache.spark.ml.classification.ClassificationModel<FeaturesType,M>
Type Parameters:
`FeaturesType` - Type of input features, e.g., `Vector`
`M` - Concrete Model type
All Implemented Interfaces:
java.io.Serializable, org.apache.spark.internal.Logging, ClassifierParams, Params, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasRawPredictionCol, PredictorParams, Identifiable
Direct Known Subclasses:
LinearSVCModel, ProbabilisticClassificationModel
public abstract class ClassificationModel<FeaturesType,M extends ClassificationModel<FeaturesType,M>>
extends PredictionModel<FeaturesType,M>
implements ClassifierParams
Model produced by a Classifier. Classes are indexed {0, 1, ..., numClasses - 1}.
See Also:
Serialized Form
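Since classes are indexed {0, 1, ..., numClasses - 1} while predicted labels are of type `Double`, a label corresponds to a class only when it is integer-valued and in range. The following is a minimal, Spark-free sketch of that check; `ClassIndexSketch` and `isValidClassIndex` are hypothetical names used for illustration and are not part of the Spark API.

```java
// Illustrative sketch only: does not use or reproduce any Spark class.
final class ClassIndexSketch {

    /**
     * Returns true if the label is one of the class indices
     * 0, 1, ..., numClasses - 1 (hypothetical helper, not Spark API).
     */
    static boolean isValidClassIndex(double label, int numClasses) {
        return label >= 0.0                   // no negative indices
            && label < numClasses             // strictly below numClasses
            && label == Math.floor(label);    // integer-valued; NaN fails every comparison
    }

    public static void main(String[] args) {
        int numClasses = 3;
        System.out.println(isValidClassIndex(2.0, numClasses)); // prints true
        System.out.println(isValidClassIndex(3.0, numClasses)); // prints false
        System.out.println(isValidClassIndex(1.5, numClasses)); // prints false
    }
}
```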
Nested Class Summary
* ### Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging `org.apache.spark.internal.Logging.SparkShellLoggingFilter`
Constructor Summary
Constructors
Constructor and Description
`ClassificationModel()`
Method Summary

| Modifier and Type | Method and Description |
| --- | --- |
| `abstract int` | `numClasses()` Number of classes (values which the label can take). |
| `double` | `predict(FeaturesType features)` Predict label for the given features. |
| `abstract Vector` | `predictRaw(FeaturesType features)` Raw prediction for each possible label. |
| `Param<String>` | `rawPredictionCol()` Param for raw prediction (a.k.a. confidence) column name. |
| `M` | `setRawPredictionCol(String value)` |
| `Dataset<Row>` | `transform(Dataset<?> dataset)` Transforms dataset by reading from `featuresCol`, and appending new columns as specified by parameters: predicted labels as `predictionCol` of type `Double`, and raw predictions (confidences) as `rawPredictionCol` of type `Vector`. |
| `Dataset<Row>` | `transformImpl(Dataset<?> dataset)` |
| `StructType` | `transformSchema(StructType schema)` Check transform validity and derive the output schema from the input schema. |

* ### Methods inherited from class org.apache.spark.ml.[PredictionModel](../../../../../org/apache/spark/ml/PredictionModel.html "class in org.apache.spark.ml") `[featuresCol](../../../../../org/apache/spark/ml/PredictionModel.html#featuresCol--), [labelCol](../../../../../org/apache/spark/ml/PredictionModel.html#labelCol--), [numFeatures](../../../../../org/apache/spark/ml/PredictionModel.html#numFeatures--), [predictionCol](../../../../../org/apache/spark/ml/PredictionModel.html#predictionCol--), [setFeaturesCol](../../../../../org/apache/spark/ml/PredictionModel.html#setFeaturesCol-java.lang.String-), [setPredictionCol](../../../../../org/apache/spark/ml/PredictionModel.html#setPredictionCol-java.lang.String-)`
* ### Methods inherited from class org.apache.spark.ml.[Model](../../../../../org/apache/spark/ml/Model.html "class in org.apache.spark.ml") `[copy](../../../../../org/apache/spark/ml/Model.html#copy-org.apache.spark.ml.param.ParamMap-), [hasParent](../../../../../org/apache/spark/ml/Model.html#hasParent--), [parent](../../../../../org/apache/spark/ml/Model.html#parent--), 
[setParent](../../../../../org/apache/spark/ml/Model.html#setParent-org.apache.spark.ml.Estimator-)`
* ### Methods inherited from class org.apache.spark.ml.[Transformer](../../../../../org/apache/spark/ml/Transformer.html "class in org.apache.spark.ml") `[transform](../../../../../org/apache/spark/ml/Transformer.html#transform-org.apache.spark.sql.Dataset-org.apache.spark.ml.param.ParamMap-), [transform](../../../../../org/apache/spark/ml/Transformer.html#transform-org.apache.spark.sql.Dataset-org.apache.spark.ml.param.ParamPair-org.apache.spark.ml.param.ParamPair...-), [transform](../../../../../org/apache/spark/ml/Transformer.html#transform-org.apache.spark.sql.Dataset-org.apache.spark.ml.param.ParamPair-scala.collection.Seq-)`
* ### Methods inherited from class org.apache.spark.ml.[PipelineStage](../../../../../org/apache/spark/ml/PipelineStage.html "class in org.apache.spark.ml") `[params](../../../../../org/apache/spark/ml/PipelineStage.html#params--)`
* ### Methods inherited from class Object `equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait`
* ### Methods inherited from interface org.apache.spark.ml.classification.[ClassifierParams](../../../../../org/apache/spark/ml/classification/ClassifierParams.html "interface in org.apache.spark.ml.classification") `[validateAndTransformSchema](../../../../../org/apache/spark/ml/classification/ClassifierParams.html#validateAndTransformSchema-org.apache.spark.sql.types.StructType-boolean-org.apache.spark.sql.types.DataType-)`
* ### Methods inherited from interface org.apache.spark.ml.param.shared.[HasLabelCol](../../../../../org/apache/spark/ml/param/shared/HasLabelCol.html "interface in org.apache.spark.ml.param.shared") `[getLabelCol](../../../../../org/apache/spark/ml/param/shared/HasLabelCol.html#getLabelCol--), [labelCol](../../../../../org/apache/spark/ml/param/shared/HasLabelCol.html#labelCol--)`
* ### Methods inherited from interface org.apache.spark.ml.param.shared.[HasFeaturesCol](../../../../../org/apache/spark/ml/param/shared/HasFeaturesCol.html "interface in org.apache.spark.ml.param.shared") `[featuresCol](../../../../../org/apache/spark/ml/param/shared/HasFeaturesCol.html#featuresCol--), [getFeaturesCol](../../../../../org/apache/spark/ml/param/shared/HasFeaturesCol.html#getFeaturesCol--)`
* ### Methods inherited from interface org.apache.spark.ml.param.shared.[HasPredictionCol](../../../../../org/apache/spark/ml/param/shared/HasPredictionCol.html "interface in org.apache.spark.ml.param.shared") `[getPredictionCol](../../../../../org/apache/spark/ml/param/shared/HasPredictionCol.html#getPredictionCol--), [predictionCol](../../../../../org/apache/spark/ml/param/shared/HasPredictionCol.html#predictionCol--)`
* ### Methods inherited from interface org.apache.spark.ml.param.[Params](../../../../../org/apache/spark/ml/param/Params.html "interface in org.apache.spark.ml.param") `[clear](../../../../../org/apache/spark/ml/param/Params.html#clear-org.apache.spark.ml.param.Param-), [copy](../../../../../org/apache/spark/ml/param/Params.html#copy-org.apache.spark.ml.param.ParamMap-), [copyValues](../../../../../org/apache/spark/ml/param/Params.html#copyValues-T-org.apache.spark.ml.param.ParamMap-), [defaultCopy](../../../../../org/apache/spark/ml/param/Params.html#defaultCopy-org.apache.spark.ml.param.ParamMap-), [defaultParamMap](../../../../../org/apache/spark/ml/param/Params.html#defaultParamMap--), [explainParam](../../../../../org/apache/spark/ml/param/Params.html#explainParam-org.apache.spark.ml.param.Param-), [explainParams](../../../../../org/apache/spark/ml/param/Params.html#explainParams--), [extractParamMap](../../../../../org/apache/spark/ml/param/Params.html#extractParamMap--), [extractParamMap](../../../../../org/apache/spark/ml/param/Params.html#extractParamMap-org.apache.spark.ml.param.ParamMap-), [get](../../../../../org/apache/spark/ml/param/Params.html#get-org.apache.spark.ml.param.Param-), [getDefault](../../../../../org/apache/spark/ml/param/Params.html#getDefault-org.apache.spark.ml.param.Param-), [getOrDefault](../../../../../org/apache/spark/ml/param/Params.html#getOrDefault-org.apache.spark.ml.param.Param-), [getParam](../../../../../org/apache/spark/ml/param/Params.html#getParam-java.lang.String-), [hasDefault](../../../../../org/apache/spark/ml/param/Params.html#hasDefault-org.apache.spark.ml.param.Param-), [hasParam](../../../../../org/apache/spark/ml/param/Params.html#hasParam-java.lang.String-), [isDefined](../../../../../org/apache/spark/ml/param/Params.html#isDefined-org.apache.spark.ml.param.Param-), [isSet](../../../../../org/apache/spark/ml/param/Params.html#isSet-org.apache.spark.ml.param.Param-), [onParamChange](../../../../../org/apache/spark/ml/param/Params.html#onParamChange-org.apache.spark.ml.param.Param-), [paramMap](../../../../../org/apache/spark/ml/param/Params.html#paramMap--), [params](../../../../../org/apache/spark/ml/param/Params.html#params--), [set](../../../../../org/apache/spark/ml/param/Params.html#set-org.apache.spark.ml.param.Param-T-), [set](../../../../../org/apache/spark/ml/param/Params.html#set-org.apache.spark.ml.param.ParamPair-), [set](../../../../../org/apache/spark/ml/param/Params.html#set-java.lang.String-java.lang.Object-), [setDefault](../../../../../org/apache/spark/ml/param/Params.html#setDefault-org.apache.spark.ml.param.Param-T-), [setDefault](../../../../../org/apache/spark/ml/param/Params.html#setDefault-scala.collection.Seq-), [shouldOwn](../../../../../org/apache/spark/ml/param/Params.html#shouldOwn-org.apache.spark.ml.param.Param-)`
* ### Methods inherited from interface org.apache.spark.ml.util.[Identifiable](../../../../../org/apache/spark/ml/util/Identifiable.html "interface in org.apache.spark.ml.util") `[toString](../../../../../org/apache/spark/ml/util/Identifiable.html#toString--), [uid](../../../../../org/apache/spark/ml/util/Identifiable.html#uid--)`
* ### Methods inherited from interface org.apache.spark.ml.param.shared.[HasRawPredictionCol](../../../../../org/apache/spark/ml/param/shared/HasRawPredictionCol.html "interface in org.apache.spark.ml.param.shared") `[getRawPredictionCol](../../../../../org/apache/spark/ml/param/shared/HasRawPredictionCol.html#getRawPredictionCol--)`
* ### Methods inherited from interface org.apache.spark.internal.Logging `$init$, initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, initLock, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log__$eq, org$apache$spark$internal$Logging$$log_, uninitialize`
Constructor Detail
* #### ClassificationModel
  public ClassificationModel()
Method Detail
* #### numClasses
  public abstract int numClasses()
  Number of classes (values which the label can take).
* #### predict
  public double predict([FeaturesType](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel") features)
  Predict label for the given features. This method is used to implement `transform()` and output `predictionCol`.
  This default implementation for classification predicts the index of the maximum value from `predictRaw()`.
  Specified by: `[predict](../../../../../org/apache/spark/ml/PredictionModel.html#predict-FeaturesType-)` in class `[PredictionModel](../../../../../org/apache/spark/ml/PredictionModel.html "class in org.apache.spark.ml")<[FeaturesType](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel"),[M](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel") extends [ClassificationModel](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "class in org.apache.spark.ml.classification")<[FeaturesType](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel"),[M](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel")>>`
  Parameters: `features` \- (undocumented)
  Returns: (undocumented)
* #### predictRaw
  public abstract [Vector](../../../../../org/apache/spark/ml/linalg/Vector.html "interface in org.apache.spark.ml.linalg") predictRaw([FeaturesType](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel") features)
  Raw prediction for each possible label. The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives a measure of confidence in each possible label (where larger = more confident).
  This internal method is used to implement `transform()` and output `rawPredictionCol`.
  Parameters: `features` \- (undocumented)
  Returns: vector where element i is the raw prediction for label i. This raw prediction may be any real number, where a larger value indicates greater confidence for that label.
* #### rawPredictionCol
  public final [Param](../../../../../org/apache/spark/ml/param/Param.html "class in org.apache.spark.ml.param")<String> rawPredictionCol()
  Param for raw prediction (a.k.a. confidence) column name.
  Specified by: `[rawPredictionCol](../../../../../org/apache/spark/ml/param/shared/HasRawPredictionCol.html#rawPredictionCol--)` in interface `[HasRawPredictionCol](../../../../../org/apache/spark/ml/param/shared/HasRawPredictionCol.html "interface in org.apache.spark.ml.param.shared")`
  Returns: (undocumented)
* #### setRawPredictionCol
  public [M](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel") setRawPredictionCol(String value)
* #### transform
  public [Dataset](../../../../../org/apache/spark/sql/Dataset.html "class in org.apache.spark.sql")<[Row](../../../../../org/apache/spark/sql/Row.html "interface in org.apache.spark.sql")> transform([Dataset](../../../../../org/apache/spark/sql/Dataset.html "class in org.apache.spark.sql")<?> dataset)
  Transforms dataset by reading from `featuresCol`, and appending new columns as specified by parameters:
  - predicted labels as `predictionCol` of type `Double`
  - raw predictions (confidences) as `rawPredictionCol` of type `Vector`
  Overrides: `[transform](../../../../../org/apache/spark/ml/PredictionModel.html#transform-org.apache.spark.sql.Dataset-)` in class `[PredictionModel](../../../../../org/apache/spark/ml/PredictionModel.html "class in org.apache.spark.ml")<[FeaturesType](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel"),[M](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel") extends [ClassificationModel](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "class in org.apache.spark.ml.classification")<[FeaturesType](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel"),[M](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel")>>`
  Parameters: `dataset` \- input dataset
  Returns: transformed dataset
* #### transformImpl
  public final [Dataset](../../../../../org/apache/spark/sql/Dataset.html "class in org.apache.spark.sql")<[Row](../../../../../org/apache/spark/sql/Row.html "interface in org.apache.spark.sql")> transformImpl([Dataset](../../../../../org/apache/spark/sql/Dataset.html "class in org.apache.spark.sql")<?> dataset)
* #### transformSchema
  public [StructType](../../../../../org/apache/spark/sql/types/StructType.html "class in org.apache.spark.sql.types") transformSchema([StructType](../../../../../org/apache/spark/sql/types/StructType.html "class in org.apache.spark.sql.types") schema)
  Check transform validity and derive the output schema from the input schema.
  We check validity for interactions between parameters during `transformSchema` and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by `Param.validate()`.
  A typical implementation should first verify schema changes and parameter validity, including complex parameter interaction checks.
  Overrides: `[transformSchema](../../../../../org/apache/spark/ml/PredictionModel.html#transformSchema-org.apache.spark.sql.types.StructType-)` in class `[PredictionModel](../../../../../org/apache/spark/ml/PredictionModel.html "class in org.apache.spark.ml")<[FeaturesType](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel"),[M](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel") extends [ClassificationModel](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "class in org.apache.spark.ml.classification")<[FeaturesType](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel"),[M](../../../../../org/apache/spark/ml/classification/ClassificationModel.html "type parameter in ClassificationModel")>>`
  Parameters: `schema` \- (undocumented)
  Returns: (undocumented)
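The relationship between `predictRaw()` and the default `predict()` described above (the predicted label is the index of the largest raw prediction) can be illustrated without Spark. The following is a minimal sketch that uses a plain `double[]` in place of `Vector`; `RawPredictionSketch` and `predictFromRaw` are hypothetical names, and resolving ties in favor of the lowest index is an assumption of this sketch rather than something stated in this page.

```java
// Illustrative sketch only: a plain-Java stand-in, not the Spark implementation.
final class RawPredictionSketch {

    /**
     * Returns the index of the largest raw prediction as a double label.
     * Element i of rawPrediction plays the role of the raw prediction for label i.
     * Assumption: ties are resolved in favor of the lowest index.
     */
    static double predictFromRaw(double[] rawPrediction) {
        if (rawPrediction.length == 0) {
            throw new IllegalArgumentException("rawPrediction needs one entry per class");
        }
        int best = 0;
        for (int i = 1; i < rawPrediction.length; i++) {
            // Strict comparison keeps the earliest index when values are equal.
            if (rawPrediction[i] > rawPrediction[best]) {
                best = i;
            }
        }
        return (double) best;
    }

    public static void main(String[] args) {
        // Three classes; larger raw value = more confident.
        double[] raw = {-1.5, 0.2, 3.0};
        System.out.println(predictFromRaw(raw)); // prints 2.0
    }
}
```

Raw predictions may be any real numbers, so the sketch compares them directly and does not assume they are probabilities or sum to one.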