11.1. Array API support (experimental)#
The Array API specification defines a standard API for all array manipulation libraries with a NumPy-like API. Scikit-learn's Array API support requires array-api-compat to be installed, and the environment variable SCIPY_ARRAY_API must be set to 1 before importing scipy and scikit-learn.
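For example, a minimal sketch that sets the variable from Python before the first import (exporting SCIPY_ARRAY_API=1 in the shell before starting the interpreter works equally well):

import os

os.environ["SCIPY_ARRAY_API"] = "1"  # must be set before the first scipy/sklearn import

import scipy
import sklearn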
Please note that this environment variable is intended for temporary use. For more details, refer to SciPy’s Array API documentation.
Some scikit-learn estimators that primarily rely on NumPy (as opposed to using Cython) to implement the algorithmic logic of their fit, predict or transform methods can be configured to accept any Array API compatible input data structures and automatically dispatch operations to the underlying namespace instead of relying on NumPy.
At this stage, this support is considered experimental and must be enabled explicitly, as explained below.
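Concretely, dispatching is opt-in per session or per block of code. A minimal sketch (sklearn.set_config and sklearn.config_context are scikit-learn's public configuration API):

import sklearn
from sklearn import config_context

# Enable array API dispatching globally for the rest of the session ...
sklearn.set_config(array_api_dispatch=True)

# ... or only within a delimited block of code:
with config_context(array_api_dispatch=True):
    pass  # estimator fit/predict calls placed here dispatch to the input's namespace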
Note

Currently, only array-api-strict, cupy, and PyTorch are known to work with scikit-learn's estimators.
The following video provides an overview of the standard’s design principles and how it facilitates interoperability between array libraries:
- Scikit-learn on GPUs with Array API by Thomas Fan at PyData NYC 2023.
11.1.1. Example usage#
Here is an example code snippet to demonstrate how to use CuPy to run LinearDiscriminantAnalysis on a GPU:

>>> from sklearn.datasets import make_classification
>>> from sklearn import config_context
>>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
>>> import cupy

>>> X_np, y_np = make_classification(random_state=0)
>>> X_cu = cupy.asarray(X_np)
>>> y_cu = cupy.asarray(y_np)
>>> X_cu.device
<CUDA Device 0>

>>> with config_context(array_api_dispatch=True):
...     lda = LinearDiscriminantAnalysis()
...     X_trans = lda.fit_transform(X_cu, y_cu)
>>> X_trans.device
<CUDA Device 0>
After the model is trained, fitted attributes that are arrays will also be from the same Array API namespace as the training data. For example, if CuPy's Array API namespace was used for training, then fitted attributes will be on the GPU. We provide an experimental _estimator_with_converted_arrays utility that transfers an estimator's attributes from Array API arrays to NumPy ndarrays:

>>> from sklearn.utils._array_api import _estimator_with_converted_arrays
>>> cupy_to_ndarray = lambda array: array.get()
>>> lda_np = _estimator_with_converted_arrays(lda, cupy_to_ndarray)
>>> X_trans = lda_np.transform(X_np)
>>> type(X_trans)
<class 'numpy.ndarray'>
11.1.1.1. PyTorch Support#
PyTorch Tensors are supported by setting array_api_dispatch=True and passing in the tensors directly:

>>> import torch
>>> X_torch = torch.asarray(X_np, device="cuda", dtype=torch.float32)
>>> y_torch = torch.asarray(y_np, device="cuda", dtype=torch.float32)

>>> with config_context(array_api_dispatch=True):
...     lda = LinearDiscriminantAnalysis()
...     X_trans = lda.fit_transform(X_torch, y_torch)
>>> type(X_trans)
<class 'torch.Tensor'>
>>> X_trans.device.type
'cuda'
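The _estimator_with_converted_arrays utility shown earlier works with any conversion function. For instance, a hedged sketch moving the fitted attributes of the torch-trained estimator above back to NumPy (torch.Tensor.cpu() and .numpy() are standard PyTorch calls; lda is reused from the snippet above):

>>> torch_to_ndarray = lambda array: array.cpu().numpy()
>>> lda_np = _estimator_with_converted_arrays(lda, torch_to_ndarray)
>>> type(lda_np.coef_)
<class 'numpy.ndarray'>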
11.1.2. Support for Array API-compatible inputs#
The following sections list the estimators and other tools in scikit-learn that support Array API compatible inputs.
11.1.2.1. Estimators#
- decomposition.PCA (with svd_solver="full", svd_solver="randomized" and power_iteration_normalizer="QR")
- linear_model.Ridge (with solver="svd")
- discriminant_analysis.LinearDiscriminantAnalysis (with solver="svd")
- preprocessing.KernelCenterer
- preprocessing.LabelEncoder
- preprocessing.MaxAbsScaler
- preprocessing.MinMaxScaler
- preprocessing.Normalizer
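Each of these behaves the same way under dispatching. A hedged sketch fitting Ridge with solver="svd" on CPU PyTorch tensors (the data here is synthetic and purely illustrative):

import torch
from sklearn import config_context
from sklearn.linear_model import Ridge

X = torch.randn(100, 5)
y = X @ torch.randn(5) + 0.1 * torch.randn(100)

with config_context(array_api_dispatch=True):
    ridge = Ridge(solver="svd").fit(X, y)

# Fitted attributes follow the input namespace: coef_ is a torch.Tensor.
print(type(ridge.coef_))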
11.1.2.2. Meta-estimators#
Meta-estimators that accept Array API inputs, provided that the base estimator does as well:
- model_selection.GridSearchCV
- model_selection.RandomizedSearchCV
- model_selection.HalvingGridSearchCV
- model_selection.HalvingRandomSearchCV
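For example, a hedged sketch wrapping an Array API capable base estimator in GridSearchCV (synthetic torch data, purely illustrative):

import torch
from sklearn import config_context
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import GridSearchCV

X = torch.randn(100, 5)
y = (X[:, 0] > 0).to(torch.int64)  # synthetic binary labels

with config_context(array_api_dispatch=True):
    search = GridSearchCV(
        LinearDiscriminantAnalysis(solver="svd"),
        param_grid={"tol": [1e-4, 1e-3]},
        cv=3,
    )
    search.fit(X, y)

print(search.best_params_)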
11.1.2.3. Metrics#
- sklearn.metrics.cluster.entropy
- sklearn.metrics.accuracy_score
- sklearn.metrics.d2_tweedie_score
- sklearn.metrics.f1_score
- sklearn.metrics.max_error
- sklearn.metrics.mean_absolute_error
- sklearn.metrics.mean_absolute_percentage_error
- sklearn.metrics.mean_gamma_deviance
- sklearn.metrics.mean_poisson_deviance (requires enabling array API support for SciPy)
- sklearn.metrics.mean_squared_error
- sklearn.metrics.mean_squared_log_error
- sklearn.metrics.mean_tweedie_deviance
- sklearn.metrics.multilabel_confusion_matrix
- sklearn.metrics.pairwise.additive_chi2_kernel
- sklearn.metrics.pairwise.chi2_kernel
- sklearn.metrics.pairwise.cosine_similarity
- sklearn.metrics.pairwise.cosine_distances
- sklearn.metrics.pairwise.euclidean_distances (see Note on device support for float64)
- sklearn.metrics.pairwise.linear_kernel
- sklearn.metrics.pairwise.paired_cosine_distances
- sklearn.metrics.pairwise.paired_euclidean_distances
- sklearn.metrics.pairwise.polynomial_kernel
- sklearn.metrics.pairwise.rbf_kernel (see Note on device support for float64)
- sklearn.metrics.pairwise.sigmoid_kernel
- sklearn.metrics.precision_recall_fscore_support
- sklearn.metrics.r2_score
- sklearn.metrics.root_mean_squared_error
- sklearn.metrics.root_mean_squared_log_error
- sklearn.metrics.zero_one_loss
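As with estimators, metric functions dispatch on their inputs. A hedged sketch with torch tensors, tying in with the return-value conventions described below (scalar metrics come back as Python floats):

import torch
from sklearn import config_context
from sklearn.metrics import accuracy_score

y_true = torch.tensor([0, 1, 1, 0])
y_pred = torch.tensor([0, 1, 0, 0])

with config_context(array_api_dispatch=True):
    acc = accuracy_score(y_true, y_pred)

print(acc)  # a plain Python float: 0.75 for this data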
11.1.2.4. Tools#
Coverage is expected to grow over time. Please follow the dedicated meta-issue on GitHub to track progress.
11.1.2.5. Type of return values and fitted attributes#
When calling functions or methods with Array API compatible inputs, the convention is to return array values of the same array container type and device as the input data.
Similarly, when an estimator is fitted with Array API compatible inputs, the fitted attributes will be arrays from the same library as the input and stored on the same device. The predict and transform methods subsequently expect inputs from the same array library and device as the data passed to the fit method.
Note however that scoring functions that return scalar values return Python scalars (typically a float instance) instead of an array scalar value.
11.1.3. Common estimator checks#
Add the array_api_support tag to an estimator's set of tags to indicate that it supports the Array API. This will enable dedicated checks as part of the common tests to verify that the estimator's results are the same when using vanilla NumPy and Array API inputs.
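A minimal sketch of declaring the tag, assuming the _more_tags developer mechanism (the tag infrastructure has evolved across scikit-learn releases, so check the developer guide for the version you target):

from sklearn.base import BaseEstimator

class MyEstimator(BaseEstimator):
    def _more_tags(self):
        # Opt this estimator into the dedicated array API common checks.
        return {"array_api_support": True}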
To run these checks you need to install array-api-compat in your test environment. To run the full set of checks you need to install both PyTorch and CuPy and have a GPU. Checks that cannot be executed or have missing dependencies will be automatically skipped. Therefore it's important to run the tests with the -v flag to see which checks are skipped:

pip install array-api-compat  # and other libraries as needed
pytest -k "array_api" -v
11.1.3.1. Note on MPS device support#
On macOS, PyTorch can use the Metal Performance Shaders (MPS) to access hardware accelerators (e.g. the internal GPU component of the M1 or M2 chips). However, MPS device support for PyTorch is incomplete at the time of writing. See the corresponding GitHub issue in the PyTorch repository for more details.
To enable the MPS support in PyTorch, set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 before running the tests:
PYTORCH_ENABLE_MPS_FALLBACK=1 pytest -k "array_api" -v
At the time of writing, all scikit-learn tests should pass; however, the computational speed is not necessarily better than with the CPU device.
11.1.3.2. Note on device support for float64#
Certain operations within scikit-learn will automatically perform operations on floating-point values with float64 precision to prevent overflows and ensure correctness (e.g., metrics.pairwise.euclidean_distances). However, certain combinations of array namespaces and devices, such as PyTorch on MPS (see Note on MPS device support), do not support the float64 data type. In these cases, scikit-learn will revert to using the float32 data type instead. This can result in different behavior (typically numerically unstable results) compared to not using array API dispatching or using a device with float64 support.
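A hedged illustration of the underlying constraint (runs only on Apple Silicon with the MPS backend available; the exact exception type may vary between PyTorch releases):

import torch

# float32 tensors are fine on the MPS device ...
x32 = torch.ones(3, dtype=torch.float32, device="mps")

# ... but float64 is not supported there, which is why scikit-learn
# falls back to float32 computations on such devices:
try:
    x64 = torch.ones(3, dtype=torch.float64, device="mps")
except (TypeError, RuntimeError) as exc:
    print(f"float64 not supported on MPS: {exc}")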