TrainValidationSplit (Spark 3.5.5 JavaDoc)
Object
- org.apache.spark.ml.PipelineStage
- org.apache.spark.ml.Estimator<TrainValidationSplitModel>
- org.apache.spark.ml.tuning.TrainValidationSplit
All Implemented Interfaces:
java.io.Serializable, org.apache.spark.internal.Logging, Params, HasCollectSubModels, HasParallelism, HasSeed, TrainValidationSplitParams, ValidatorParams, Identifiable, MLWritable
public class TrainValidationSplit
extends Estimator<TrainValidationSplitModel>
implements TrainValidationSplitParams, HasParallelism, HasCollectSubModels, MLWritable, org.apache.spark.internal.Logging
Validation for hyper-parameter tuning. Randomly splits the input dataset into train and validation sets, and uses the evaluation metric on the validation set to select the best model. Similar to CrossValidator, but splits the dataset only once.
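To make the procedure concrete, the following is a minimal, Spark-free sketch in plain Java of the idea described above: shuffle the rows once with a seed, hold out a validation slice, score every candidate on that slice, and keep the candidate with the largest metric. It is not Spark's implementation. The class `SplitOnceTuner`, the `Scorer` interface, and the toy "model" (a constant prediction scored by negated mean squared error) are hypothetical names chosen for illustration, and the exact shuffle-and-cut split is a simplifying assumption. The 0.75 ratio in `main` mirrors the documented default of `trainRatio`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Hypothetical, Spark-free illustration of single-split hyper-parameter selection. */
public class SplitOnceTuner {

    /** Scores one candidate; a larger return value means a better candidate. */
    @FunctionalInterface
    public interface Scorer<P> {
        double score(P candidate, List<Double> train, List<Double> validation);
    }

    /** Index of the winning candidate plus the validation metric of every candidate. */
    public static final class Result {
        public final int bestIndex;
        public final double[] validationMetrics;

        Result(int bestIndex, double[] validationMetrics) {
            this.bestIndex = bestIndex;
            this.validationMetrics = validationMetrics;
        }
    }

    public static <P> Result tune(List<Double> rows, List<P> candidates,
                                  double trainRatio, long seed, Scorer<P> scorer) {
        // Shuffle a copy once; the seed makes the split reproducible.
        List<Double> shuffled = new ArrayList<>(rows);
        Collections.shuffle(shuffled, new Random(seed));

        // The first trainRatio fraction is used for training, the rest for validation.
        int cut = (int) Math.round(shuffled.size() * trainRatio);
        if (cut == 0 || cut == shuffled.size()) {
            throw new IllegalArgumentException("both splits must be non-empty");
        }
        List<Double> train = shuffled.subList(0, cut);
        List<Double> validation = shuffled.subList(cut, shuffled.size());

        // Score every candidate on the same split and remember the largest metric.
        double[] metrics = new double[candidates.size()];
        int best = 0;
        for (int i = 0; i < candidates.size(); i++) {
            metrics[i] = scorer.score(candidates.get(i), train, validation);
            if (metrics[i] > metrics[best]) {
                best = i;
            }
        }
        return new Result(best, metrics);
    }

    public static void main(String[] args) {
        List<Double> labels = new ArrayList<>();
        for (int i = 0; i < 100; i++) {
            labels.add(10.0); // toy data: every label is 10.0
        }
        // Candidates are constant predictions; the metric is negated mean squared error.
        List<Double> candidates = Arrays.asList(0.0, 10.0, 20.0);
        Result result = tune(labels, candidates, 0.75, 42L,
            (candidate, train, validation) -> -validation.stream()
                .mapToDouble(y -> (y - candidate) * (y - candidate))
                .average()
                .orElse(Double.NaN));
        System.out.println("best candidate = " + candidates.get(result.bestIndex));
    }
}
```

The toy scorer ignores the training rows because its candidates are already fixed predictions; a real estimator would first be fitted on `train` and only then scored on `validation`.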
See Also:
Serialized Form
Nested Class Summary
* ### Nested classes/interfaces inherited from interface org.apache.spark.internal.Logging
  `org.apache.spark.internal.Logging.SparkShellLoggingFilter`
Constructor Summary
Constructors
Constructor and Description
TrainValidationSplit()
TrainValidationSplit(String uid)
Method Summary
| Modifier and Type | Method and Description |
| --- | --- |
| `BooleanParam` | `collectSubModels()` Param for whether to collect a list of sub-models trained during tuning. |
| `TrainValidationSplit` | `copy(ParamMap extra)` Creates a copy of this instance with the same UID and some extra params. |
| `Param<Estimator<?>>` | `estimator()` param for the estimator to be validated |
| `Param<ParamMap[]>` | `estimatorParamMaps()` param for estimator param maps |
| `Param<Evaluator>` | `evaluator()` param for the evaluator used to select hyper-parameters that maximize the validated metric |
| `TrainValidationSplitModel` | `fit(Dataset<?> dataset)` Fits a model to the input data. |
| `static TrainValidationSplit` | `load(String path)` |
| `IntParam` | `parallelism()` The number of threads to use when running parallel algorithms. |
| `static MLReader<TrainValidationSplit>` | `read()` |
| `LongParam` | `seed()` Param for random seed. |
| `TrainValidationSplit` | `setCollectSubModels(boolean value)` Whether to collect submodels when fitting. |
| `TrainValidationSplit` | `setEstimator(Estimator<?> value)` |
| `TrainValidationSplit` | `setEstimatorParamMaps(ParamMap[] value)` |
| `TrainValidationSplit` | `setEvaluator(Evaluator value)` |
| `TrainValidationSplit` | `setParallelism(int value)` Set the maximum level of parallelism to evaluate models in parallel. |
| `TrainValidationSplit` | `setSeed(long value)` |
| `TrainValidationSplit` | `setTrainRatio(double value)` |
| `DoubleParam` | `trainRatio()` Param for ratio between train and validation data. |
| `StructType` | `transformSchema(StructType schema)` Check transform validity and derive the output schema from the input schema. |
| `String` | `uid()` An immutable unique ID for the object and its derivatives. |
| `MLWriter` | `write()` Returns an MLWriter instance for this ML instance. |
* ### Methods inherited from class org.apache.spark.ml.[Estimator](../../../../../org/apache/spark/ml/Estimator.html "class in org.apache.spark.ml")
  `[fit](../../../../../org/apache/spark/ml/Estimator.html#fit-org.apache.spark.sql.Dataset-org.apache.spark.ml.param.ParamMap-), [fit](../../../../../org/apache/spark/ml/Estimator.html#fit-org.apache.spark.sql.Dataset-org.apache.spark.ml.param.ParamPair-org.apache.spark.ml.param.ParamPair...-), [fit](../../../../../org/apache/spark/ml/Estimator.html#fit-org.apache.spark.sql.Dataset-org.apache.spark.ml.param.ParamPair-scala.collection.Seq-), [fit](../../../../../org/apache/spark/ml/Estimator.html#fit-org.apache.spark.sql.Dataset-scala.collection.Seq-)`
* ### Methods inherited from class org.apache.spark.ml.[PipelineStage](../../../../../org/apache/spark/ml/PipelineStage.html "class in org.apache.spark.ml")
  `[params](../../../../../org/apache/spark/ml/PipelineStage.html#params--)`
* ### Methods inherited from class Object
  `equals, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait`
* ### Methods inherited from interface org.apache.spark.ml.tuning.[TrainValidationSplitParams](../../../../../org/apache/spark/ml/tuning/TrainValidationSplitParams.html "interface in org.apache.spark.ml.tuning")
  `[getTrainRatio](../../../../../org/apache/spark/ml/tuning/TrainValidationSplitParams.html#getTrainRatio--)`
* ### Methods inherited from interface org.apache.spark.ml.tuning.[ValidatorParams](../../../../../org/apache/spark/ml/tuning/ValidatorParams.html "interface in org.apache.spark.ml.tuning")
  `[getEstimator](../../../../../org/apache/spark/ml/tuning/ValidatorParams.html#getEstimator--), [getEstimatorParamMaps](../../../../../org/apache/spark/ml/tuning/ValidatorParams.html#getEstimatorParamMaps--), [getEvaluator](../../../../../org/apache/spark/ml/tuning/ValidatorParams.html#getEvaluator--), [logTuningParams](../../../../../org/apache/spark/ml/tuning/ValidatorParams.html#logTuningParams-org.apache.spark.ml.util.Instrumentation-), [transformSchemaImpl](../../../../../org/apache/spark/ml/tuning/ValidatorParams.html#transformSchemaImpl-org.apache.spark.sql.types.StructType-)`
* ### Methods inherited from interface org.apache.spark.ml.param.shared.[HasSeed](../../../../../org/apache/spark/ml/param/shared/HasSeed.html "interface in org.apache.spark.ml.param.shared")
  `[getSeed](../../../../../org/apache/spark/ml/param/shared/HasSeed.html#getSeed--)`
* ### Methods inherited from interface org.apache.spark.ml.param.[Params](../../../../../org/apache/spark/ml/param/Params.html "interface in org.apache.spark.ml.param")
  `[clear](../../../../../org/apache/spark/ml/param/Params.html#clear-org.apache.spark.ml.param.Param-), [copyValues](../../../../../org/apache/spark/ml/param/Params.html#copyValues-T-org.apache.spark.ml.param.ParamMap-), [defaultCopy](../../../../../org/apache/spark/ml/param/Params.html#defaultCopy-org.apache.spark.ml.param.ParamMap-), [defaultParamMap](../../../../../org/apache/spark/ml/param/Params.html#defaultParamMap--), [explainParam](../../../../../org/apache/spark/ml/param/Params.html#explainParam-org.apache.spark.ml.param.Param-), [explainParams](../../../../../org/apache/spark/ml/param/Params.html#explainParams--), [extractParamMap](../../../../../org/apache/spark/ml/param/Params.html#extractParamMap--), [extractParamMap](../../../../../org/apache/spark/ml/param/Params.html#extractParamMap-org.apache.spark.ml.param.ParamMap-), [get](../../../../../org/apache/spark/ml/param/Params.html#get-org.apache.spark.ml.param.Param-), [getDefault](../../../../../org/apache/spark/ml/param/Params.html#getDefault-org.apache.spark.ml.param.Param-), [getOrDefault](../../../../../org/apache/spark/ml/param/Params.html#getOrDefault-org.apache.spark.ml.param.Param-), [getParam](../../../../../org/apache/spark/ml/param/Params.html#getParam-java.lang.String-), [hasDefault](../../../../../org/apache/spark/ml/param/Params.html#hasDefault-org.apache.spark.ml.param.Param-), [hasParam](../../../../../org/apache/spark/ml/param/Params.html#hasParam-java.lang.String-), [isDefined](../../../../../org/apache/spark/ml/param/Params.html#isDefined-org.apache.spark.ml.param.Param-), [isSet](../../../../../org/apache/spark/ml/param/Params.html#isSet-org.apache.spark.ml.param.Param-), [onParamChange](../../../../../org/apache/spark/ml/param/Params.html#onParamChange-org.apache.spark.ml.param.Param-), [paramMap](../../../../../org/apache/spark/ml/param/Params.html#paramMap--), [params](../../../../../org/apache/spark/ml/param/Params.html#params--), [set](../../../../../org/apache/spark/ml/param/Params.html#set-org.apache.spark.ml.param.Param-T-), [set](../../../../../org/apache/spark/ml/param/Params.html#set-org.apache.spark.ml.param.ParamPair-), [set](../../../../../org/apache/spark/ml/param/Params.html#set-java.lang.String-java.lang.Object-), [setDefault](../../../../../org/apache/spark/ml/param/Params.html#setDefault-org.apache.spark.ml.param.Param-T-), [setDefault](../../../../../org/apache/spark/ml/param/Params.html#setDefault-scala.collection.Seq-), [shouldOwn](../../../../../org/apache/spark/ml/param/Params.html#shouldOwn-org.apache.spark.ml.param.Param-)`
* ### Methods inherited from interface org.apache.spark.ml.util.[Identifiable](../../../../../org/apache/spark/ml/util/Identifiable.html "interface in org.apache.spark.ml.util")
  `[toString](../../../../../org/apache/spark/ml/util/Identifiable.html#toString--)`
* ### Methods inherited from interface org.apache.spark.ml.param.shared.[HasParallelism](../../../../../org/apache/spark/ml/param/shared/HasParallelism.html "interface in org.apache.spark.ml.param.shared")
  `[getExecutionContext](../../../../../org/apache/spark/ml/param/shared/HasParallelism.html#getExecutionContext--), [getParallelism](../../../../../org/apache/spark/ml/param/shared/HasParallelism.html#getParallelism--)`
* ### Methods inherited from interface org.apache.spark.ml.param.shared.[HasCollectSubModels](../../../../../org/apache/spark/ml/param/shared/HasCollectSubModels.html "interface in org.apache.spark.ml.param.shared")
  `[getCollectSubModels](../../../../../org/apache/spark/ml/param/shared/HasCollectSubModels.html#getCollectSubModels--)`
* ### Methods inherited from interface org.apache.spark.ml.util.[MLWritable](../../../../../org/apache/spark/ml/util/MLWritable.html "interface in org.apache.spark.ml.util")
  `[save](../../../../../org/apache/spark/ml/util/MLWritable.html#save-java.lang.String-)`
* ### Methods inherited from interface org.apache.spark.internal.Logging
  `$init$, initializeForcefully, initializeLogIfNecessary, initializeLogIfNecessary, initializeLogIfNecessary$default$2, initLock, isTraceEnabled, log, logDebug, logDebug, logError, logError, logInfo, logInfo, logName, logTrace, logTrace, logWarning, logWarning, org$apache$spark$internal$Logging$$log__$eq, org$apache$spark$internal$Logging$$log_, uninitialize`
Constructor Detail
* #### TrainValidationSplit
  public TrainValidationSplit(String uid)
* #### TrainValidationSplit
  public TrainValidationSplit()
Method Detail
* #### read
  public static [MLReader](../../../../../org/apache/spark/ml/util/MLReader.html "class in org.apache.spark.ml.util")<[TrainValidationSplit](../../../../../org/apache/spark/ml/tuning/TrainValidationSplit.html "class in org.apache.spark.ml.tuning")> read()
* #### load
  public static [TrainValidationSplit](../../../../../org/apache/spark/ml/tuning/TrainValidationSplit.html "class in org.apache.spark.ml.tuning") load(String path)
* #### collectSubModels
  public final [BooleanParam](../../../../../org/apache/spark/ml/param/BooleanParam.html "class in org.apache.spark.ml.param") collectSubModels()
  Param for whether to collect a list of sub-models trained during tuning. If set to false, then only the single best sub-model will be available after fitting. If set to true, then all sub-models will be available. Warning: For large models, collecting all sub-models can cause OOMs on the Spark driver.
  Specified by: `[collectSubModels](../../../../../org/apache/spark/ml/param/shared/HasCollectSubModels.html#collectSubModels--)` in interface `[HasCollectSubModels](../../../../../org/apache/spark/ml/param/shared/HasCollectSubModels.html "interface in org.apache.spark.ml.param.shared")`
  Returns: (undocumented)
* #### parallelism
  public [IntParam](../../../../../org/apache/spark/ml/param/IntParam.html "class in org.apache.spark.ml.param") parallelism()
  The number of threads to use when running parallel algorithms. Default is 1 for serial execution.
  Specified by: `[parallelism](../../../../../org/apache/spark/ml/param/shared/HasParallelism.html#parallelism--)` in interface `[HasParallelism](../../../../../org/apache/spark/ml/param/shared/HasParallelism.html "interface in org.apache.spark.ml.param.shared")`
  Returns: (undocumented)
* #### trainRatio
  public [DoubleParam](../../../../../org/apache/spark/ml/param/DoubleParam.html "class in org.apache.spark.ml.param") trainRatio()
  Param for the ratio between train and validation data. Must be between 0 and 1. Default: 0.75
  Specified by: `[trainRatio](../../../../../org/apache/spark/ml/tuning/TrainValidationSplitParams.html#trainRatio--)` in interface `[TrainValidationSplitParams](../../../../../org/apache/spark/ml/tuning/TrainValidationSplitParams.html "interface in org.apache.spark.ml.tuning")`
  Returns: (undocumented)
* #### estimator
  public [Param](../../../../../org/apache/spark/ml/param/Param.html "class in org.apache.spark.ml.param")<[Estimator](../../../../../org/apache/spark/ml/Estimator.html "class in org.apache.spark.ml")<?>> estimator()
  Param for the estimator to be validated.
  Specified by: `[estimator](../../../../../org/apache/spark/ml/tuning/ValidatorParams.html#estimator--)` in interface `[ValidatorParams](../../../../../org/apache/spark/ml/tuning/ValidatorParams.html "interface in org.apache.spark.ml.tuning")`
  Returns: (undocumented)
* #### estimatorParamMaps
  public [Param](../../../../../org/apache/spark/ml/param/Param.html "class in org.apache.spark.ml.param")<[ParamMap](../../../../../org/apache/spark/ml/param/ParamMap.html "class in org.apache.spark.ml.param")[]> estimatorParamMaps()
  Param for estimator param maps.
  Specified by: `[estimatorParamMaps](../../../../../org/apache/spark/ml/tuning/ValidatorParams.html#estimatorParamMaps--)` in interface `[ValidatorParams](../../../../../org/apache/spark/ml/tuning/ValidatorParams.html "interface in org.apache.spark.ml.tuning")`
  Returns: (undocumented)
* #### evaluator
  public [Param](../../../../../org/apache/spark/ml/param/Param.html "class in org.apache.spark.ml.param")<[Evaluator](../../../../../org/apache/spark/ml/evaluation/Evaluator.html "class in org.apache.spark.ml.evaluation")> evaluator()
  Param for the evaluator used to select hyper-parameters that maximize the validated metric.
  Specified by: `[evaluator](../../../../../org/apache/spark/ml/tuning/ValidatorParams.html#evaluator--)` in interface `[ValidatorParams](../../../../../org/apache/spark/ml/tuning/ValidatorParams.html "interface in org.apache.spark.ml.tuning")`
  Returns: (undocumented)
* #### seed
  public final [LongParam](../../../../../org/apache/spark/ml/param/LongParam.html "class in org.apache.spark.ml.param") seed()
  Description copied from interface: `[HasSeed](../../../../../org/apache/spark/ml/param/shared/HasSeed.html#seed--)`
  Param for random seed.
  Specified by: `[seed](../../../../../org/apache/spark/ml/param/shared/HasSeed.html#seed--)` in interface `[HasSeed](../../../../../org/apache/spark/ml/param/shared/HasSeed.html "interface in org.apache.spark.ml.param.shared")`
  Returns: (undocumented)
* #### uid
  public String uid()
  An immutable unique ID for the object and its derivatives.
  Specified by: `[uid](../../../../../org/apache/spark/ml/util/Identifiable.html#uid--)` in interface `[Identifiable](../../../../../org/apache/spark/ml/util/Identifiable.html "interface in org.apache.spark.ml.util")`
  Returns: (undocumented)
* #### setEstimator
  public [TrainValidationSplit](../../../../../org/apache/spark/ml/tuning/TrainValidationSplit.html "class in org.apache.spark.ml.tuning") setEstimator([Estimator](../../../../../org/apache/spark/ml/Estimator.html "class in org.apache.spark.ml")<?> value)
* #### setEstimatorParamMaps
  public [TrainValidationSplit](../../../../../org/apache/spark/ml/tuning/TrainValidationSplit.html "class in org.apache.spark.ml.tuning") setEstimatorParamMaps([ParamMap](../../../../../org/apache/spark/ml/param/ParamMap.html "class in org.apache.spark.ml.param")[] value)
* #### setEvaluator
  public [TrainValidationSplit](../../../../../org/apache/spark/ml/tuning/TrainValidationSplit.html "class in org.apache.spark.ml.tuning") setEvaluator([Evaluator](../../../../../org/apache/spark/ml/evaluation/Evaluator.html "class in org.apache.spark.ml.evaluation") value)
* #### setTrainRatio
  public [TrainValidationSplit](../../../../../org/apache/spark/ml/tuning/TrainValidationSplit.html "class in org.apache.spark.ml.tuning") setTrainRatio(double value)
* #### setSeed
  public [TrainValidationSplit](../../../../../org/apache/spark/ml/tuning/TrainValidationSplit.html "class in org.apache.spark.ml.tuning") setSeed(long value)
* #### setParallelism
  public [TrainValidationSplit](../../../../../org/apache/spark/ml/tuning/TrainValidationSplit.html "class in org.apache.spark.ml.tuning") setParallelism(int value)
  Set the maximum level of parallelism to evaluate models in parallel. Default is 1 for serial evaluation.
  Parameters: `value` - (undocumented)
  Returns: (undocumented)
* #### setCollectSubModels
  public [TrainValidationSplit](../../../../../org/apache/spark/ml/tuning/TrainValidationSplit.html "class in org.apache.spark.ml.tuning") setCollectSubModels(boolean value)
  Whether to collect submodels when fitting. If set, the submodels can be retrieved from the returned model. Note: If this param is set, you can set the option "persistSubModels" to "true" before saving the returned model in order to save these submodels. See the documentation of [TrainValidationSplitModel.TrainValidationSplitModelWriter](../../../../../org/apache/spark/ml/tuning/TrainValidationSplitModel.TrainValidationSplitModelWriter.html "class in org.apache.spark.ml.tuning") for more information.
  Parameters: `value` - (undocumented)
  Returns: (undocumented)
* #### fit
  public [TrainValidationSplitModel](../../../../../org/apache/spark/ml/tuning/TrainValidationSplitModel.html "class in org.apache.spark.ml.tuning") fit([Dataset](../../../../../org/apache/spark/sql/Dataset.html "class in org.apache.spark.sql")<?> dataset)
  Description copied from class: `[Estimator](../../../../../org/apache/spark/ml/Estimator.html#fit-org.apache.spark.sql.Dataset-)`
  Fits a model to the input data.
  Specified by: `[fit](../../../../../org/apache/spark/ml/Estimator.html#fit-org.apache.spark.sql.Dataset-)` in class `[Estimator](../../../../../org/apache/spark/ml/Estimator.html "class in org.apache.spark.ml")<[TrainValidationSplitModel](../../../../../org/apache/spark/ml/tuning/TrainValidationSplitModel.html "class in org.apache.spark.ml.tuning")>`
  Parameters: `dataset` - (undocumented)
  Returns: (undocumented)
* #### transformSchema
  public [StructType](../../../../../org/apache/spark/sql/types/StructType.html "class in org.apache.spark.sql.types") transformSchema([StructType](../../../../../org/apache/spark/sql/types/StructType.html "class in org.apache.spark.sql.types") schema)
  Check transform validity and derive the output schema from the input schema. We check validity for interactions between parameters during `transformSchema` and raise an exception if any parameter value is invalid. Parameter value checks which do not depend on other parameters are handled by `Param.validate()`. A typical implementation should first conduct verification on schema change and parameter validity, including complex parameter interaction checks.
  Specified by: `[transformSchema](../../../../../org/apache/spark/ml/PipelineStage.html#transformSchema-org.apache.spark.sql.types.StructType-)` in class `[PipelineStage](../../../../../org/apache/spark/ml/PipelineStage.html "class in org.apache.spark.ml")`
  Parameters: `schema` - (undocumented)
  Returns: (undocumented)
* #### copy
  public [TrainValidationSplit](../../../../../org/apache/spark/ml/tuning/TrainValidationSplit.html "class in org.apache.spark.ml.tuning") copy([ParamMap](../../../../../org/apache/spark/ml/param/ParamMap.html "class in org.apache.spark.ml.param") extra)
  Description copied from interface: `[Params](../../../../../org/apache/spark/ml/param/Params.html#copy-org.apache.spark.ml.param.ParamMap-)`
  Creates a copy of this instance with the same UID and some extra params. Subclasses should implement this method and set the return type properly. See `defaultCopy()`.
  Specified by: `[copy](../../../../../org/apache/spark/ml/param/Params.html#copy-org.apache.spark.ml.param.ParamMap-)` in interface `[Params](../../../../../org/apache/spark/ml/param/Params.html "interface in org.apache.spark.ml.param")`
  Specified by: `[copy](../../../../../org/apache/spark/ml/Estimator.html#copy-org.apache.spark.ml.param.ParamMap-)` in class `[Estimator](../../../../../org/apache/spark/ml/Estimator.html "class in org.apache.spark.ml")<[TrainValidationSplitModel](../../../../../org/apache/spark/ml/tuning/TrainValidationSplitModel.html "class in org.apache.spark.ml.tuning")>`
  Parameters: `extra` - (undocumented)
  Returns: (undocumented)
* #### write
  public [MLWriter](../../../../../org/apache/spark/ml/util/MLWriter.html "class in org.apache.spark.ml.util") write()
  Description copied from interface: `[MLWritable](../../../../../org/apache/spark/ml/util/MLWritable.html#write--)`
  Returns an `MLWriter` instance for this ML instance.
  Specified by: `[write](../../../../../org/apache/spark/ml/util/MLWritable.html#write--)` in interface `[MLWritable](../../../../../org/apache/spark/ml/util/MLWritable.html "interface in org.apache.spark.ml.util")`
  Returns: (undocumented)
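As a closing illustration of the `parallelism` param documented under Method Detail ("the number of threads to use when running parallel algorithms", default 1 for serial execution), the sketch below shows one way a thread cap on candidate evaluation can be pictured in plain Java. It is an assumption-laden illustration, not Spark's internal code: the class `ParallelCandidateEval`, the method `evaluateAll`, and the `fitAndScore` callback are hypothetical names, and a fixed-size thread pool is simply one common way to bound concurrency.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.ToDoubleFunction;

/** Hypothetical illustration of bounded-parallel candidate evaluation; not Spark code. */
public class ParallelCandidateEval {

    /**
     * Evaluates every candidate with at most {@code parallelism} threads and
     * returns the metrics in the same order as the candidates.
     */
    public static <P> double[] evaluateAll(List<P> candidates,
                                           ToDoubleFunction<P> fitAndScore,
                                           int parallelism) {
        if (parallelism < 1) {
            throw new IllegalArgumentException("parallelism must be >= 1");
        }
        // parallelism == 1 gives a single worker thread, i.e. serial evaluation.
        ExecutorService pool = Executors.newFixedThreadPool(parallelism);
        try {
            List<Future<Double>> pending = new ArrayList<>();
            for (P candidate : candidates) {
                Callable<Double> task = () -> fitAndScore.applyAsDouble(candidate);
                pending.add(pool.submit(task));
            }
            // Collect by index so the result order does not depend on thread timing.
            double[] metrics = new double[pending.size()];
            for (int i = 0; i < metrics.length; i++) {
                metrics[i] = pending.get(i).get();
            }
            return metrics;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while evaluating candidates", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("a candidate failed to evaluate", e.getCause());
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        // Toy candidates: regularization strengths scored by a made-up metric.
        List<Double> regParams = Arrays.asList(0.01, 0.1, 1.0);
        double[] metrics = evaluateAll(regParams, reg -> 1.0 - reg, 2);
        System.out.println(Arrays.toString(metrics));
    }
}
```

Raising the thread count changes only how many candidates are in flight at once; the metrics, and therefore the selected candidate, are the same as in the serial case.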