check_is_fitted
sklearn.utils.validation.check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all)
Perform is_fitted validation for estimator.
Checks if the estimator is fitted by verifying the presence of fitted attributes (ending with a trailing underscore) and otherwise raises a NotFittedError with the given message.
If an estimator does not set any attributes with a trailing underscore, it can define a __sklearn_is_fitted__ method returning a boolean to specify if the estimator is fitted or not. See "__sklearn_is_fitted__ as Developer API" for an example on how to use the API.
If no attributes are passed, this function will pass if an estimator is stateless. An estimator can indicate it is stateless by setting the requires_fit tag. See "Estimator Tags" for more information. Note that the requires_fit tag is ignored if attributes are passed.
Parameters:
estimator : estimator instance
Estimator instance for which the check is performed.
attributes : str, list or tuple of str, default=None
Attribute name(s) given as a string or a list/tuple of strings, e.g. ["coef_", "estimator_", ...] or "coef_".
If None, estimator is considered fitted if there exists an attribute that ends with an underscore and does not start with a double underscore.
msg : str, default=None
The default error message is, “This %(name)s instance is not fitted yet. Call ‘fit’ with appropriate arguments before using this estimator.”
For custom messages, if “%(name)s” is present in the message string, it is replaced with the estimator name.
E.g.: “Estimator, %(name)s, must be fitted before sparsifying”.
all_or_any : callable, {all, any}, default=all
Specify whether all or any of the given attributes must exist.
Raises:
TypeError
If the estimator is a class or not an estimator instance.
NotFittedError
If the attributes are not found.
Examples
>>> from sklearn.linear_model import LogisticRegression
>>> from sklearn.utils.validation import check_is_fitted
>>> from sklearn.exceptions import NotFittedError
>>> lr = LogisticRegression()
>>> try:
...     check_is_fitted(lr)
... except NotFittedError as exc:
...     print(f"Model is not fitted yet.")
Model is not fitted yet.
>>> lr.fit([[1, 2], [1, 3]], [1, 0])
LogisticRegression()
>>> check_is_fitted(lr)
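The remaining parameters can be exercised the same way. The sketch below is illustrative rather than part of the official example: it passes explicit attributes, a custom msg, and both values of all_or_any, and it also triggers the TypeError raised for a class. The attribute name "not_a_real_attr_" is a made-up placeholder that LogisticRegression never sets.

```python
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.utils.validation import check_is_fitted

lr = LogisticRegression()

# Custom message: "%(name)s" is replaced with the estimator's class name.
try:
    check_is_fitted(
        lr, "coef_", msg="Estimator, %(name)s, must be fitted before sparsifying"
    )
except NotFittedError as exc:
    unfitted_message = str(exc)

lr.fit([[1, 2], [1, 3]], [1, 0])

# After fit, "coef_" exists; "not_a_real_attr_" (placeholder name) does not.
names = ["coef_", "not_a_real_attr_"]

# With all_or_any=any, one existing attribute is enough, so this passes.
check_is_fitted(lr, names, all_or_any=any)

# With the default all_or_any=all, every attribute must exist.
try:
    check_is_fitted(lr, names, all_or_any=all)
    all_passed = True
except NotFittedError:
    all_passed = False

# Passing the class itself instead of an instance raises TypeError.
try:
    check_is_fitted(LogisticRegression)
    raised_type_error = False
except TypeError:
    raised_type_error = True

print(unfitted_message)
print(all_passed, raised_type_error)
```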