PERF: avoid cast in algos.rank (#46175) · pandas-dev/pandas@d801a4b (original) (raw)
`@@ -46,7 +46,6 @@ cnp.import_array()
`
46
46
``
47
47
`cimport pandas._libs.util as util
`
48
48
`from pandas._libs.dtypes cimport (
`
49
``
`-
iu_64_floating_obj_t,
`
50
49
` numeric_object_t,
`
51
50
` numeric_t,
`
52
51
`)
`
`@@ -821,30 +820,54 @@ def is_monotonic(ndarray[numeric_object_t, ndim=1] arr, bint timelike):
`
821
820
`# rank_1d, rank_2d
`
822
821
`# ----------------------------------------------------------------------
`
823
822
``
824
``
`-
cdef iu_64_floating_obj_t get_rank_nan_fill_val(
`
825
``
`-
bint rank_nans_highest,
`
826
``
`-
iu_64_floating_obj_t[:] _=None
`
``
823
`+
cdef numeric_object_t get_rank_nan_fill_val(
`
``
824
`+
bint rank_nans_highest,
`
``
825
`+
numeric_object_t[:] _=None
`
827
826
`):
`
828
827
`"""
`
829
828
` Return the value we'll use to represent missing values when sorting depending
`
830
829
` on if we'd like missing values to end up at the top/bottom. (The second parameter
`
831
830
` is unused, but needed for fused type specialization)
`
832
831
`"""
`
833
832
`if rank_nans_highest:
`
834
``
`-
if iu_64_floating_obj_t is object:
`
``
833
`+
if numeric_object_t is object:
`
835
834
`return Infinity()
`
836
``
`-
elif iu_64_floating_obj_t is int64_t:
`
``
835
`+
elif numeric_object_t is int64_t:
`
837
836
`return util.INT64_MAX
`
838
``
`-
elif iu_64_floating_obj_t is uint64_t:
`
``
837
`+
elif numeric_object_t is int32_t:
`
``
838
`+
return util.INT32_MAX
`
``
839
`+
elif numeric_object_t is int16_t:
`
``
840
`+
return util.INT16_MAX
`
``
841
`+
elif numeric_object_t is int8_t:
`
``
842
`+
return util.INT8_MAX
`
``
843
`+
elif numeric_object_t is uint64_t:
`
839
844
`return util.UINT64_MAX
`
``
845
`+
elif numeric_object_t is uint32_t:
`
``
846
`+
return util.UINT32_MAX
`
``
847
`+
elif numeric_object_t is uint16_t:
`
``
848
`+
return util.UINT16_MAX
`
``
849
`+
elif numeric_object_t is uint8_t:
`
``
850
`+
return util.UINT8_MAX
`
840
851
`else:
`
841
852
`return np.inf
`
842
853
`else:
`
843
``
`-
if iu_64_floating_obj_t is object:
`
``
854
`+
if numeric_object_t is object:
`
844
855
`return NegInfinity()
`
845
``
`-
elif iu_64_floating_obj_t is int64_t:
`
``
856
`+
elif numeric_object_t is int64_t:
`
846
857
`return NPY_NAT
`
847
``
`-
elif iu_64_floating_obj_t is uint64_t:
`
``
858
`+
elif numeric_object_t is int32_t:
`
``
859
`+
return util.INT32_MIN
`
``
860
`+
elif numeric_object_t is int16_t:
`
``
861
`+
return util.INT16_MIN
`
``
862
`+
elif numeric_object_t is int8_t:
`
``
863
`+
return util.INT8_MIN
`
``
864
`+
elif numeric_object_t is uint64_t:
`
``
865
`+
return 0
`
``
866
`+
elif numeric_object_t is uint32_t:
`
``
867
`+
return 0
`
``
868
`+
elif numeric_object_t is uint16_t:
`
``
869
`+
return 0
`
``
870
`+
elif numeric_object_t is uint8_t:
`
848
871
`return 0
`
849
872
`else:
`
850
873
`return -np.inf
`
`@@ -853,7 +876,7 @@ cdef iu_64_floating_obj_t get_rank_nan_fill_val(
`
853
876
`@cython.wraparound(False)
`
854
877
`@cython.boundscheck(False)
`
855
878
`def rank_1d(
`
856
``
`-
ndarray[iu_64_floating_obj_t, ndim=1] values,
`
``
879
`+
ndarray[numeric_object_t, ndim=1] values,
`
857
880
` const intp_t[:] labels=None,
`
858
881
` bint is_datetimelike=False,
`
859
882
`ties_method="average",
`
`@@ -866,7 +889,7 @@ def rank_1d(
`
866
889
``
867
890
` Parameters
`
868
891
` ----------
`
869
``
`-
values : array of iu_64_floating_obj_t values to be ranked
`
``
892
`+
values : array of numeric_object_t values to be ranked
`
870
893
` labels : np.ndarray[np.intp] or None
`
871
894
` Array containing unique label for each group, with its ordering
`
872
895
`` matching up to the corresponding record in values. If not called
``
`@@ -896,11 +919,11 @@ def rank_1d(
`
896
919
` int64_t[::1] grp_sizes
`
897
920
` intp_t[:] lexsort_indexer
`
898
921
` float64_t[::1] out
`
899
``
`-
ndarray[iu_64_floating_obj_t, ndim=1] masked_vals
`
900
``
`-
iu_64_floating_obj_t[:] masked_vals_memview
`
``
922
`+
ndarray[numeric_object_t, ndim=1] masked_vals
`
``
923
`+
numeric_object_t[:] masked_vals_memview
`
901
924
` uint8_t[:] mask
`
902
925
` bint keep_na, nans_rank_highest, check_labels, check_mask
`
903
``
`-
iu_64_floating_obj_t nan_fill_val
`
``
926
`+
numeric_object_t nan_fill_val
`
904
927
``
905
928
` tiebreak = tiebreakers[ties_method]
`
906
929
`if tiebreak == TIEBREAK_FIRST:
`
`@@ -921,22 +944,26 @@ def rank_1d(
`
921
944
` check_labels = labels is not None
`
922
945
``
923
946
`# For cases where a mask is not possible, we can avoid mask checks
`
924
``
`-
check_mask = not (iu_64_floating_obj_t is uint64_t or
`
925
``
`-
(iu_64_floating_obj_t is int64_t and not is_datetimelike))
`
``
947
`+
check_mask = (
`
``
948
`+
numeric_object_t is float32_t
`
``
949
`+
or numeric_object_t is float64_t
`
``
950
`+
or numeric_object_t is object
`
``
951
`+
or (numeric_object_t is int64_t and is_datetimelike)
`
``
952
`+
)
`
926
953
``
927
954
`# Copy values into new array in order to fill missing data
`
928
955
`# with mask, without obfuscating location of missing data
`
929
956
`# in values array
`
930
``
`-
if iu_64_floating_obj_t is object and values.dtype != np.object_:
`
``
957
`+
if numeric_object_t is object and values.dtype != np.object_:
`
931
958
` masked_vals = values.astype('O')
`
932
959
`else:
`
933
960
` masked_vals = values.copy()
`
934
961
``
935
``
`-
if iu_64_floating_obj_t is object:
`
``
962
`+
if numeric_object_t is object:
`
936
963
` mask = missing.isnaobj(masked_vals)
`
937
``
`-
elif iu_64_floating_obj_t is int64_t and is_datetimelike:
`
``
964
`+
elif numeric_object_t is int64_t and is_datetimelike:
`
938
965
` mask = (masked_vals == NPY_NAT).astype(np.uint8)
`
939
``
`-
elif iu_64_floating_obj_t is float64_t:
`
``
966
`+
elif numeric_object_t is float64_t or numeric_object_t is float32_t:
`
940
967
` mask = np.isnan(masked_vals).astype(np.uint8)
`
941
968
`else:
`
942
969
` mask = np.zeros(shape=len(masked_vals), dtype=np.uint8)
`
`@@ -948,7 +975,7 @@ def rank_1d(
`
948
975
`# will flip the ordering to still end up with lowest rank.
`
949
976
`` # Symmetric logic applies to na_option == 'bottom'
``
950
977
` nans_rank_highest = ascending ^ (na_option == 'top')
`
951
``
`-
nan_fill_val = get_rank_nan_fill_val[iu_64_floating_obj_t](nans_rank_highest)
`
``
978
`+
nan_fill_val = get_rank_nan_fill_val[numeric_object_t](nans_rank_highest)
`
952
979
`if nans_rank_highest:
`
953
980
` order = [masked_vals, mask]
`
954
981
`else:
`
`@@ -994,8 +1021,8 @@ cdef void rank_sorted_1d(
`
994
1021
` float64_t[::1] out,
`
995
1022
` int64_t[::1] grp_sizes,
`
996
1023
` const intp_t[:] sort_indexer,
`
997
``
`-
# Can make const with cython3 (https://github.com/cython/cython/issues/3222)
`
998
``
`-
iu_64_floating_obj_t[:] masked_vals,
`
``
1024
`+
# TODO(cython3): make const (https://github.com/cython/cython/issues/3222)
`
``
1025
`+
numeric_object_t[:] masked_vals,
`
999
1026
` const uint8_t[:] mask,
`
1000
1027
` bint check_mask,
`
1001
1028
` Py_ssize_t N,
`
`@@ -1019,7 +1046,7 @@ cdef void rank_sorted_1d(
`
1019
1046
` if labels is None.
`
1020
1047
` sort_indexer : intp_t[:]
`
1021
1048
` Array of indices which sorts masked_vals
`
1022
``
`-
masked_vals : iu_64_floating_obj_t[:]
`
``
1049
`+
masked_vals : numeric_object_t[:]
`
1023
1050
` The values input to rank_1d, with missing values replaced by fill values
`
1024
1051
` mask : uint8_t[:]
`
1025
1052
` Array where entries are True if the value is missing, False otherwise.
`
`@@ -1051,7 +1078,7 @@ cdef void rank_sorted_1d(
`
1051
1078
`# that sorted value for retrieval back from the original
`
1052
1079
`# values / masked_vals arrays
`
1053
1080
`# TODO(cython3): de-duplicate once cython supports conditional nogil
`
1054
``
`-
if iu_64_floating_obj_t is object:
`
``
1081
`+
if numeric_object_t is object:
`
1055
1082
`with gil:
`
1056
1083
`for i in range(N):
`
1057
1084
` at_end = i == N - 1
`
`@@ -1259,7 +1286,7 @@ cdef void rank_sorted_1d(
`
1259
1286
``
1260
1287
``
1261
1288
`def rank_2d(
`
1262
``
`-
ndarray[iu_64_floating_obj_t, ndim=2] in_arr,
`
``
1289
`+
ndarray[numeric_object_t, ndim=2] in_arr,
`
1263
1290
` int axis=0,
`
1264
1291
` bint is_datetimelike=False,
`
1265
1292
`ties_method="average",
`
`@@ -1274,13 +1301,13 @@ def rank_2d(
`
1274
1301
` Py_ssize_t k, n, col
`
1275
1302
` float64_t[::1, :] out # Column-major so columns are contiguous
`
1276
1303
` int64_t[::1] grp_sizes
`
1277
``
`-
ndarray[iu_64_floating_obj_t, ndim=2] values
`
1278
``
`-
iu_64_floating_obj_t[:, :] masked_vals
`
``
1304
`+
ndarray[numeric_object_t, ndim=2] values
`
``
1305
`+
numeric_object_t[:, :] masked_vals
`
1279
1306
` intp_t[:, :] sort_indexer
`
1280
1307
` uint8_t[:, :] mask
`
1281
1308
` TiebreakEnumType tiebreak
`
1282
1309
` bint check_mask, keep_na, nans_rank_highest
`
1283
``
`-
iu_64_floating_obj_t nan_fill_val
`
``
1310
`+
numeric_object_t nan_fill_val
`
1284
1311
``
1285
1312
` tiebreak = tiebreakers[ties_method]
`
1286
1313
`if tiebreak == TIEBREAK_FIRST:
`
`@@ -1290,29 +1317,32 @@ def rank_2d(
`
1290
1317
` keep_na = na_option == 'keep'
`
1291
1318
``
1292
1319
`# For cases where a mask is not possible, we can avoid mask checks
`
1293
``
`-
check_mask = not (iu_64_floating_obj_t is uint64_t or
`
1294
``
`-
(iu_64_floating_obj_t is int64_t and not is_datetimelike))
`
``
1320
`+
check_mask = (
`
``
1321
`+
numeric_object_t is float32_t
`
``
1322
`+
or numeric_object_t is float64_t
`
``
1323
`+
or numeric_object_t is object
`
``
1324
`+
or (numeric_object_t is int64_t and is_datetimelike)
`
``
1325
`+
)
`
1295
1326
``
1296
1327
`if axis == 1:
`
1297
1328
` values = np.asarray(in_arr).T.copy()
`
1298
1329
`else:
`
1299
1330
` values = np.asarray(in_arr).copy()
`
1300
1331
``
1301
``
`-
if iu_64_floating_obj_t is object:
`
``
1332
`+
if numeric_object_t is object:
`
1302
1333
`if values.dtype != np.object_:
`
1303
1334
` values = values.astype('O')
`
1304
1335
``
1305
1336
` nans_rank_highest = ascending ^ (na_option == 'top')
`
1306
1337
`if check_mask:
`
1307
``
`-
nan_fill_val = get_rank_nan_fill_val[iu_64_floating_obj_t](nans_rank_highest)
`
``
1338
`+
nan_fill_val = get_rank_nan_fill_val[numeric_object_t](nans_rank_highest)
`
1308
1339
``
1309
``
`-
if iu_64_floating_obj_t is object:
`
``
1340
`+
if numeric_object_t is object:
`
1310
1341
` mask = missing.isnaobj2d(values).view(np.uint8)
`
1311
``
`-
elif iu_64_floating_obj_t is float64_t:
`
``
1342
`+
elif numeric_object_t is float64_t or numeric_object_t is float32_t:
`
1312
1343
` mask = np.isnan(values).view(np.uint8)
`
1313
``
-
1314
``
`-
# int64 and datetimelike
`
1315
1344
`else:
`
``
1345
`+
# i.e. int64 and datetimelike
`
1316
1346
` mask = (values == NPY_NAT).view(np.uint8)
`
1317
1347
` np.putmask(values, mask, nan_fill_val)
`
1318
1348
`else:
`