ENH Enables array_api for LinearDiscriminantAnalysis by thomasjpfan · Pull Request #102 · thomasjpfan/scikit-learn
Oh I see, you moved the aggregation from `at` into the mean calculation.
I think you could get rid of the loop with something like
np.sum(np.where((y[None] == np.arange(classes.shape[0])[:, None])[:, :, None], X[None], np.asarray(0.)), axis=1) / cnt[:, None]
except `None` indexing isn't in the spec yet (data-apis/array-api#360), so you'd have to use `expand_dims`. That's still not as "efficient" as the original because you are adding a lot of redundant 0s in the sum, so depending on how many classes there typically are, the loop version might be better anyway (at least in the sense that it's more readable). Another thing to note is that not all array API modules are guaranteed to have boolean indexing (https://data-apis.org/array-api/latest/API_specification/indexing.html#boolean-array-indexing).
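To make the trade-off concrete, here is a minimal sketch of both variants in plain NumPy. The function names are hypothetical, not the PR's actual code, and it assumes `y` is already encoded as integers `0..n_classes-1` with every class present. The vectorized version uses only `expand_dims`, `where`, and `sum`, so it avoids both `None` indexing and boolean indexing.

```python
import numpy as np


def class_means_loop(X, y, n_classes):
    """Per-class means with one boolean-indexed mean per class."""
    means = np.zeros((n_classes, X.shape[1]), dtype=X.dtype)
    for k in range(n_classes):
        means[k, :] = np.mean(X[y == k], axis=0)
    return means


def class_means_vectorized(X, y, n_classes):
    """Per-class means without a Python loop or boolean indexing."""
    # mask[k, i] is True when sample i belongs to class k.
    # Shape: (n_classes, n_samples).
    mask = np.expand_dims(y, 0) == np.expand_dims(np.arange(n_classes), 1)
    # Broadcast to (n_classes, n_samples, n_features); rows from other
    # classes become 0 so they do not contribute to the sum.
    masked = np.where(np.expand_dims(mask, 2), np.expand_dims(X, 0), 0.0)
    # Number of samples per class, shape (n_classes,).
    cnt = np.sum(mask, axis=1)
    return np.sum(masked, axis=1) / np.expand_dims(cnt, 1)


X = np.arange(12, dtype=np.float64).reshape(6, 2)
y = np.asarray([0, 1, 0, 2, 1, 2])
means_loop = class_means_loop(X, y, 3)
means_vec = class_means_vectorized(X, y, 3)
```

Note that the vectorized version materializes an intermediate of shape `(n_classes, n_samples, n_features)`, which is where the redundant zeros come from.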
Also, I think `cnt` can be obtained from `unique_counts` or `unique_all` in the array API (`return_counts` in NumPy).
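As a rough illustration of the counts suggestion, the NumPy spelling would look something like the following. The sample labels are made up for the example, and the comparison with `np.bincount` only holds when the labels are already encoded as `0..n_classes-1`.

```python
import numpy as np

# Hypothetical integer-encoded labels, for illustration only.
y = np.asarray([0, 1, 0, 2, 1, 2, 2])

# NumPy: unique values and their counts in one call.
classes, cnt = np.unique(y, return_counts=True)

# With labels encoded as 0..n_classes-1 this matches np.bincount.
bincount_cnt = np.bincount(y)
```

In an array API namespace, the analogous call would be `unique_counts(y)`, which returns the values and counts together.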