ENH Enables array_api for LinearDiscriminantAnalysis by thomasjpfan · Pull Request #102 · thomasjpfan/scikit-learn (original) (raw)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, you moved the aggregation from at into the mean calculation.

I think you could get rid of the loop with something like

np.sum(np.where(y[None] == np.arange(classes.shape[0])[:, None], X, np.asarray(0.)), axis=1)/cnt

except None indexing isn't in the spec yet (data-apis/array-api#360), so you'd have to use expand_dims. That's still not as "efficient" as the original because you are adding a lot of redundant 0s in the sum, so depending on how many classes there typically are the loop version might be better anyway (at least in the sense that it's more readable). Another thing to note is that not all array API modules are guaranteed to have boolean indexing (https://data-apis.org/array-api/latest/API_specification/indexing.html#boolean-array-indexing).

Also, I think cnt can be gotten from unique_counts or unique_all in the array API (include_counts in NumPy).