PredictionModel (Spark 3.5.5 JavaDoc)
Object
- org.apache.spark.ml.PipelineStage
  - org.apache.spark.ml.Transformer
    - org.apache.spark.ml.Model
      - org.apache.spark.ml.PredictionModel<FeaturesType,M>
Type Parameters:
FeaturesType - Type of features. E.g., VectorUDT for vector features.
M - Specialization of PredictionModel. If you subclass this type, use this type parameter to specify the concrete type for the corresponding model.
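The recursive bound `M extends PredictionModel<FeaturesType,M>` is the usual self-type (F-bounded) pattern: it lets setters such as `setFeaturesCol` return the concrete model type instead of the base class. Below is a minimal plain-Java sketch of the idea with no Spark dependency; `ToyPredictionModel`, `ToyLinearModel`, the `double[]` feature type, and the map-backed params are hypothetical stand-ins, not Spark classes.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for PredictionModel<FeaturesType, M>; not the Spark class.
abstract class ToyPredictionModel<F, M extends ToyPredictionModel<F, M>> {
    private final Map<String, String> params = new HashMap<>();

    ToyPredictionModel() {
        params.put("featuresCol", "features");
        params.put("predictionCol", "prediction");
    }

    // Returning M (the concrete subclass) keeps chained calls typed as the subclass.
    @SuppressWarnings("unchecked")
    M setFeaturesCol(String value) {
        params.put("featuresCol", value);
        return (M) this;
    }

    @SuppressWarnings("unchecked")
    M setPredictionCol(String value) {
        params.put("predictionCol", value);
        return (M) this;
    }

    String getFeaturesCol() { return params.get("featuresCol"); }

    String getPredictionCol() { return params.get("predictionCol"); }

    // Each concrete model decides how to turn one feature value into a prediction.
    abstract double predict(F features);
}

// Hypothetical concrete model: F = double[], and it passes itself as M.
class ToyLinearModel extends ToyPredictionModel<double[], ToyLinearModel> {
    private final double[] weights;
    private final double intercept;

    ToyLinearModel(double[] weights, double intercept) {
        this.weights = weights.clone();
        this.intercept = intercept;
    }

    // Subclass-only method; still reachable after the chained setters.
    double[] weights() { return weights.clone(); }

    @Override
    double predict(double[] features) {
        double sum = intercept;
        for (int i = 0; i < weights.length; i++) {
            sum += weights[i] * features[i];
        }
        return sum;
    }
}
```

With the self type, `new ToyLinearModel(...).setFeaturesCol("x").setPredictionCol("yHat")` still has static type `ToyLinearModel`, so `weights()` remains callable on the result. The unchecked `(M) this` cast is only safe as long as every subclass passes itself as `M`, which is what the type parameter documentation asks of subclasses.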
All Implemented Interfaces:
java.io.Serializable, org.apache.spark.internal.Logging, Params, HasFeaturesCol, HasLabelCol, HasPredictionCol, PredictorParams, Identifiable
Direct Known Subclasses:
ClassificationModel, RegressionModel
public abstract class PredictionModel<FeaturesType,M extends PredictionModel<FeaturesType,M>>
extends Model<M>
implements PredictorParams
Abstraction for a model for prediction tasks (regression and classification).
See Also:
Serialized Form
Nested Class Summary
- Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging: `org.apache.spark.internal.Logging.SparkShellLoggingFilter`
Constructor Summary
Constructors
| Constructor and Description |
| --- |
| `PredictionModel()` |

Method Summary
| Modifier and Type | Method and Description |
| --- | --- |
| `Param<String>` | `featuresCol()` Param for features column name. |
| `Param<String>` | `labelCol()` Param for label column name. |
| `int` | `numFeatures()` Returns the number of features the model was trained on. |
| `abstract double` | `predict(FeaturesType features)` Predict label for the given features. |
| `Param<String>` | `predictionCol()` Param for prediction column name. |
| `M` | `setFeaturesCol(String value)` |
| `M` | `setPredictionCol(String value)` |
| `Dataset<Row>` | `transform(Dataset<?> dataset)` Transforms dataset by reading from `featuresCol`, calling `predict`, and storing the predictions as a new column `predictionCol`. |
| `StructType` | `transformSchema(StructType schema)` Check transform validity and derive the output schema from the input schema. |

- Methods inherited from class org.apache.spark.ml.Model: `copy(ParamMap)`, `hasParent()`, `parent()`, `setParent(Estimator)`
- Methods inherited from class org.apache.spark.ml.Transformer: `transform(Dataset, ParamMap)`, `transform(Dataset, ParamPair, ParamPair...)`, `transform(Dataset, ParamPair, Seq)`
- Methods inherited from class org.apache.spark.ml.PipelineStage: `params()`
- Methods inherited from class Object: `equals`, `getClass`, `hashCode`, `notify`, `notifyAll`, `toString`, `wait`, `wait`, `wait`
- Methods inherited from interface org.apache.spark.ml.PredictorParams: `validateAndTransformSchema(StructType, boolean, DataType)`
- Methods inherited from interface org.apache.spark.ml.param.shared.HasLabelCol: `getLabelCol()`
- Methods inherited from interface org.apache.spark.ml.param.shared.HasFeaturesCol: `getFeaturesCol()`
- Methods inherited from interface org.apache.spark.ml.param.shared.HasPredictionCol: `getPredictionCol()`
- Methods inherited from interface org.apache.spark.ml.param.Params: `clear(Param)`, `copy(ParamMap)`, `copyValues(T, ParamMap)`, `defaultCopy(ParamMap)`, `defaultParamMap()`, `explainParam(Param)`, `explainParams()`, `extractParamMap()`, `extractParamMap(ParamMap)`, `get(Param)`, `getDefault(Param)`, `getOrDefault(Param)`, `getParam(String)`, `hasDefault(Param)`, `hasParam(String)`, `isDefined(Param)`, `isSet(Param)`, `onParamChange(Param)`, `paramMap()`, `params()`, `set(Param, T)`, `set(ParamPair)`, `set(String, Object)`, `setDefault(Param, T)`, `setDefault(Seq)`, `shouldOwn(Param)`
- Methods inherited from interface org.apache.spark.ml.util.Identifiable: `toString()`, `uid()`
- Methods inherited from interface org.apache.spark.internal.Logging: `$init$`, `initializeForcefully`, `initializeLogIfNecessary`, `initializeLogIfNecessary`, `initializeLogIfNecessary$default$2`, `initLock`, `isTraceEnabled`, `log`, `logDebug`, `logDebug`, `logError`, `logError`, `logInfo`, `logInfo`, `logName`, `logTrace`, `logTrace`, `logWarning`, `logWarning`, `org$apache$spark$internal$Logging$$log__$eq`, `org$apache$spark$internal$Logging$$log_`, `uninitialize`
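The `transform` contract summarized above (read `featuresCol`, call `predict`, store the result as a new `Double` column `predictionCol`) can be illustrated row by row. This is a hypothetical plain-Java sketch: `ToyTransform`, the `List<Map<String, Object>>` row model, and `double[]` features stand in for `Dataset<Row>` and vector features. Spark applies the same per-row idea over a distributed `Dataset` rather than a Java list.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.ToDoubleFunction;

// Hypothetical in-memory "dataset": each row maps column name -> value.
// It only mimics the documented contract of PredictionModel.transform; it is not Spark code.
final class ToyTransform {
    private ToyTransform() {}

    static List<Map<String, Object>> transform(
            List<Map<String, Object>> rows,
            String featuresCol,
            String predictionCol,
            ToDoubleFunction<double[]> predict) {
        List<Map<String, Object>> out = new ArrayList<>(rows.size());
        for (Map<String, Object> row : rows) {
            // Read the features column (assumes every row holds a double[] there).
            double[] features = (double[]) row.get(featuresCol);
            // Copy the row so the input "dataset" is left unchanged.
            Map<String, Object> newRow = new LinkedHashMap<>(row);
            // Store the prediction (boxed to Double) in a new column.
            newRow.put(predictionCol, predict.applyAsDouble(features));
            out.add(newRow);
        }
        return out;
    }
}
```

Passing `predict` as a function keeps the sketch independent of any model class; in `PredictionModel` itself, the abstract `predict(FeaturesType)` plays that role.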
Constructor Detail
#### PredictionModel
`public PredictionModel()`
Method Detail
#### featuresCol
`public final Param<String> featuresCol()`
Param for features column name.
Specified by: `featuresCol` in interface `HasFeaturesCol`
Returns: (undocumented)
#### labelCol
`public final Param<String> labelCol()`
Description copied from interface: `HasLabelCol`
Param for label column name.
Specified by: `labelCol` in interface `HasLabelCol`
Returns: (undocumented)
#### numFeatures
`public int numFeatures()`
Returns the number of features the model was trained on. If unknown, returns -1.
#### predict
`public abstract double predict(FeaturesType features)`
Predict label for the given features. This method is used to implement `transform()` and output `predictionCol`.
Parameters: `features` - (undocumented)
Returns: (undocumented)
#### predictionCol
`public final Param<String> predictionCol()`
Param for prediction column name.
Specified by: `predictionCol` in interface `HasPredictionCol`
Returns: (undocumented)
#### setFeaturesCol
`public M setFeaturesCol(String value)`
#### setPredictionCol
`public M setPredictionCol(String value)`
#### transform
`public Dataset<Row> transform(Dataset<?> dataset)`
Transforms dataset by reading from `featuresCol`, calling `predict`, and storing the predictions as a new column `predictionCol`.
Specified by: `transform` in class `Transformer`
Parameters: `dataset` - input dataset
Returns: transformed dataset with `predictionCol` of type `Double`
#### transformSchema
`public StructType transformSchema(StructType schema)`
Check transform validity and derive the output schema from the input schema.
We check validity for interactions between parameters during `transformSchema` and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by `Param.validate()`.
A typical implementation should first verify the schema change and parameter validity, including complex parameter interaction checks.
Specified by: `transformSchema` in class `PipelineStage`
Parameters: `schema` - (undocumented)
Returns: (undocumented)
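The validate-then-derive shape described for `transformSchema` can be sketched as follows. Everything here is hypothetical: the schema is an ordered `Map` from column name to type name instead of a `StructType`, and the particular checks (features column present with the expected type, no existing column with the prediction column's name) are illustrative assumptions, not Spark's exact rules.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical schema model: ordered map of column name -> type name.
// Mirrors the validate-then-derive shape of transformSchema; not Spark's StructType or checks.
final class ToySchemaCheck {
    private ToySchemaCheck() {}

    static Map<String, String> transformSchema(
            Map<String, String> schema,
            String featuresCol,
            String expectedFeaturesType,
            String predictionCol) {
        // 1. Verify the input column the model reads from.
        String actual = schema.get(featuresCol);
        if (actual == null) {
            throw new IllegalArgumentException("Column " + featuresCol + " does not exist.");
        }
        if (!actual.equals(expectedFeaturesType)) {
            throw new IllegalArgumentException("Column " + featuresCol + " must be of type "
                    + expectedFeaturesType + " but was " + actual + ".");
        }
        // 2. Reject an output column name that collides with an existing column.
        if (schema.containsKey(predictionCol)) {
            throw new IllegalArgumentException("Column " + predictionCol + " already exists.");
        }
        // 3. Derive the output schema: all input columns plus a double prediction column.
        Map<String, String> out = new LinkedHashMap<>(schema);
        out.put(predictionCol, "double");
        return out;
    }
}
```

Failing fast on the schema, before any rows are touched, is the point of the pattern: a misconfigured column name surfaces when the stage is wired up rather than partway through processing data.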