Path for Adopting the Array API spec · Issue #22352 · scikit-learn/scikit-learn
I have been experimenting with adopting the Array API spec into scikit-learn. The Array API is one way for scikit-learn to run on other hardware such as GPUs.
I have some POCs on my fork for LinearDiscriminantAnalysis and GaussianMixture. Overall, there is a runtime performance benefit when running on CuPy compared to NumPy, as shown in notebooks for LDA (14x improvement) and GMM (7x improvement).
Official acceptance of the Array API in NumPy is tracked as NEP 47.
Proposed User API
Here is the proposed API for dispatching. We require the array to adopt the Array API standard and we have a configuration option to turn on Array API dispatching:
    # Create Array API arrays following spec
    import cupy.array_api as xp

    X_cu = xp.asarray(X_np)
    y_cu = xp.asarray(y_np)

    # Configure scikit-learn to dispatch
    from sklearn import set_config

    set_config(array_api_dispatch=True)

    # Dispatches using `array_api`
    lda_cu = LinearDiscriminantAnalysis()
    lda_cu.fit(X_cu, y_cu)
This way the user can decide between the old behavior of potentially casting to NumPy and the new behavior of using the Array API if available.
Developer Experience
The Array API spec and the NumPy API overlap in many cases, but some APIs we use from NumPy are not in the Array API. There are a few ways to bridge this gap while trying to keep a maintainable code base:
1. Wrap the Array-API namespace object to make it look "more like NumPy"
2. Wrap the NumPy module to make it look "more like Array API"
3. Helper functions everywhere
Options 1 and 2 are not mutually exclusive. To demonstrate these options, I'll do a case study on `unique`. The Array API spec does not define a `unique` function, but a `unique_values` instead.
Wrap the Array-API namespace object to make it look "more like NumPy"
    def check_y(y):
        np, _ = get_namespace(y)  # Returns _ArrayAPIWrapper or NumPy
        classes = np.unique(y)

    class _ArrayAPIWrapper:
        def unique(self, x):
            return self._array_namespace.unique_values(x)
Existing scikit-learn code does not need to change as much because the Array API "looks like NumPy"
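As a minimal sketch of how such a wrapper could stay small, the class below forwards every name it does not define to the wrapped namespace, so only the diverging functions need explicit methods. The `__getattr__` delegation and the `fake_xp` stand-in namespace are assumptions made for illustration; they are not taken from the POC.

```python
from types import SimpleNamespace


class _ArrayAPIWrapper:
    """Make an Array API namespace look more like NumPy (option 1)."""

    def __init__(self, array_namespace):
        self._array_namespace = array_namespace

    def __getattr__(self, name):
        # Only called when normal lookup fails, so names shared by both
        # APIs (asarray, mean, sum, ...) pass straight through.
        return getattr(self._array_namespace, name)

    def unique(self, x):
        # NumPy spelling layered on top of the Array API spelling.
        return self._array_namespace.unique_values(x)


# Hypothetical stand-in namespace so the sketch runs without an Array API
# library installed; it only mimics `unique_values` and `asarray`.
fake_xp = SimpleNamespace(
    unique_values=lambda x: sorted(set(x)),
    asarray=list,
)

np_like = _ArrayAPIWrapper(fake_xp)
classes = np_like.unique([2, 0, 1, 2, 0])   # NumPy-style call
passthrough = np_like.asarray((1, 2))       # forwarded unchanged
```

In real use the wrapped object would be an actual Array API namespace rather than the stand-in.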
Make NumPy object "look more like Array-API"
    def check_y(y):
        xp, _ = get_namespace(y)  # Returns Array API namespace or _NumPyApiWrapper
        classes = xp.unique_values(y)

    class _NumPyApiWrapper:
        def unique_values(self, x):
            return np.unique(x)
We need to update scikit-learn to use these new functions from the Array API spec.
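One way this second option might scale beyond `unique_values` is an alias table that maps Array API names to their NumPy counterparts, with explicit methods only where a plain rename is not enough. The sketch below is an assumption for illustration: the `_ALIASES` table, the constructor argument, and the `fake_np` stand-in module are hypothetical.

```python
from types import SimpleNamespace


class _NumPyApiWrapper:
    """Expose Array API spellings on top of a NumPy-like module (option 2)."""

    # Array API name -> NumPy name, for functions where a rename is
    # enough for this sketch.
    _ALIASES = {"unique_values": "unique", "concat": "concatenate"}

    def __init__(self, numpy_module):
        self._np = numpy_module

    def __getattr__(self, name):
        # Translate the Array API name if needed, then delegate.
        return getattr(self._np, self._ALIASES.get(name, name))

    def astype(self, x, dtype):
        # The Array API spells this as a function; NumPy arrays use a method.
        return x.astype(dtype)


# Hypothetical stand-in for the NumPy module so the sketch runs without
# NumPy installed; it only mimics `unique` and `concatenate` on lists.
fake_np = SimpleNamespace(
    unique=lambda x: sorted(set(x)),
    concatenate=lambda arrays: [v for a in arrays for v in a],
)

xp = _NumPyApiWrapper(fake_np)
classes = xp.unique_values([2, 0, 2])
joined = xp.concat(([1], [2, 3]))
```

With NumPy installed, passing the real module (`_NumPyApiWrapper(numpy)`) would be used the same way.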
Helper functions everywhere
    def check_y(y):
        classes = _unique_values(y)

    def _unique_values(x):
        xp, is_array_api = get_namespace(x)
        if is_array_api:
            return xp.unique_values(x)
        return np.unique(x)
We need to update scikit-learn to use these helper functions wherever the APIs diverge. Notable functions that need a wrapper or helper include `concat`, `astype`, `asarray`, `unique`, `errstate`, `may_share_memory`, etc.
For my POCs, I mostly went with option 1, wrapping the Array API namespace to look like NumPy. (I did wrap NumPy once to get `np.dtype`, which is the same as `array.astype`.)
CC @scikit-learn/core-devs
Other API considerations
- Type promotion is more strict with Array API
    import numpy.array_api as xp

    X = xp.asarray([1])
    y = xp.asarray([1.0])

    # fails
    X + y
- No method chaining. (Array API arrays do not have methods on them)
    (X.mean(axis=1) > 1.0).any()

    # becomes
    xp.any(xp.mean(X, axis=1) > 1.0)
- Array API has no concept of memory order (C versus Fortran layout)
- Array API does not have integer indexing with `__getitem__`; the alternative is `take`, which is going into the Array API spec.
- No views into arrays in the Array API spec
- Cannot support `Dask` or `JAX` at first because they do not support methods that have data-dependent output shapes, such as `unique`.