PERF: avoid cast in algos.rank (#46175) · pandas-dev/pandas@d801a4b (original) (raw)

`@@ -46,7 +46,6 @@ cnp.import_array()

`

46

46

``

47

47

`cimport pandas._libs.util as util

`

48

48

`from pandas._libs.dtypes cimport (

`

49

``

`-

iu_64_floating_obj_t,

`

50

49

` numeric_object_t,

`

51

50

` numeric_t,

`

52

51

`)

`

`@@ -821,30 +820,54 @@ def is_monotonic(ndarray[numeric_object_t, ndim=1] arr, bint timelike):

`

821

820

`# rank_1d, rank_2d

`

822

821

`# ----------------------------------------------------------------------

`

823

822

``

824

``

`-

cdef iu_64_floating_obj_t get_rank_nan_fill_val(

`

825

``

`-

bint rank_nans_highest,

`

826

``

`-

iu_64_floating_obj_t[:] _=None

`

``

823

`+

cdef numeric_object_t get_rank_nan_fill_val(

`

``

824

`+

bint rank_nans_highest,

`

``

825

`+

numeric_object_t[:] _=None

`

827

826

`):

`

828

827

`"""

`

829

828

` Return the value we'll use to represent missing values when sorting depending

`

830

829

` on if we'd like missing values to end up at the top/bottom. (The second parameter

`

831

830

` is unused, but needed for fused type specialization)

`

832

831

`"""

`

833

832

`if rank_nans_highest:

`

834

``

`-

if iu_64_floating_obj_t is object:

`

``

833

`+

if numeric_object_t is object:

`

835

834

`return Infinity()

`

836

``

`-

elif iu_64_floating_obj_t is int64_t:

`

``

835

`+

elif numeric_object_t is int64_t:

`

837

836

`return util.INT64_MAX

`

838

``

`-

elif iu_64_floating_obj_t is uint64_t:

`

``

837

`+

elif numeric_object_t is int32_t:

`

``

838

`+

return util.INT32_MAX

`

``

839

`+

elif numeric_object_t is int16_t:

`

``

840

`+

return util.INT16_MAX

`

``

841

`+

elif numeric_object_t is int8_t:

`

``

842

`+

return util.INT8_MAX

`

``

843

`+

elif numeric_object_t is uint64_t:

`

839

844

`return util.UINT64_MAX

`

``

845

`+

elif numeric_object_t is uint32_t:

`

``

846

`+

return util.UINT32_MAX

`

``

847

`+

elif numeric_object_t is uint16_t:

`

``

848

`+

return util.UINT16_MAX

`

``

849

`+

elif numeric_object_t is uint8_t:

`

``

850

`+

return util.UINT8_MAX

`

840

851

`else:

`

841

852

`return np.inf

`

842

853

`else:

`

843

``

`-

if iu_64_floating_obj_t is object:

`

``

854

`+

if numeric_object_t is object:

`

844

855

`return NegInfinity()

`

845

``

`-

elif iu_64_floating_obj_t is int64_t:

`

``

856

`+

elif numeric_object_t is int64_t:

`

846

857

`return NPY_NAT

`

847

``

`-

elif iu_64_floating_obj_t is uint64_t:

`

``

858

`+

elif numeric_object_t is int32_t:

`

``

859

`+

return util.INT32_MIN

`

``

860

`+

elif numeric_object_t is int16_t:

`

``

861

`+

return util.INT16_MIN

`

``

862

`+

elif numeric_object_t is int8_t:

`

``

863

`+

return util.INT8_MIN

`

``

864

`+

elif numeric_object_t is uint64_t:

`

``

865

`+

return 0

`

``

866

`+

elif numeric_object_t is uint32_t:

`

``

867

`+

return 0

`

``

868

`+

elif numeric_object_t is uint16_t:

`

``

869

`+

return 0

`

``

870

`+

elif numeric_object_t is uint8_t:

`

848

871

`return 0

`

849

872

`else:

`

850

873

`return -np.inf

`

`@@ -853,7 +876,7 @@ cdef iu_64_floating_obj_t get_rank_nan_fill_val(

`

853

876

`@cython.wraparound(False)

`

854

877

`@cython.boundscheck(False)

`

855

878

`def rank_1d(

`

856

``

`-

ndarray[iu_64_floating_obj_t, ndim=1] values,

`

``

879

`+

ndarray[numeric_object_t, ndim=1] values,

`

857

880

` const intp_t[:] labels=None,

`

858

881

` bint is_datetimelike=False,

`

859

882

`ties_method="average",

`

`@@ -866,7 +889,7 @@ def rank_1d(

`

866

889

``

867

890

` Parameters

`

868

891

` ----------

`

869

``

`-

values : array of iu_64_floating_obj_t values to be ranked

`

``

892

`+

values : array of numeric_object_t values to be ranked

`

870

893

` labels : np.ndarray[np.intp] or None

`

871

894

` Array containing unique label for each group, with its ordering

`

872

895

`` matching up to the corresponding record in values. If not called

``

`@@ -896,11 +919,11 @@ def rank_1d(

`

896

919

` int64_t[::1] grp_sizes

`

897

920

` intp_t[:] lexsort_indexer

`

898

921

` float64_t[::1] out

`

899

``

`-

ndarray[iu_64_floating_obj_t, ndim=1] masked_vals

`

900

``

`-

iu_64_floating_obj_t[:] masked_vals_memview

`

``

922

`+

ndarray[numeric_object_t, ndim=1] masked_vals

`

``

923

`+

numeric_object_t[:] masked_vals_memview

`

901

924

` uint8_t[:] mask

`

902

925

` bint keep_na, nans_rank_highest, check_labels, check_mask

`

903

``

`-

iu_64_floating_obj_t nan_fill_val

`

``

926

`+

numeric_object_t nan_fill_val

`

904

927

``

905

928

` tiebreak = tiebreakers[ties_method]

`

906

929

`if tiebreak == TIEBREAK_FIRST:

`

`@@ -921,22 +944,26 @@ def rank_1d(

`

921

944

` check_labels = labels is not None

`

922

945

``

923

946

`# For cases where a mask is not possible, we can avoid mask checks

`

924

``

`-

check_mask = not (iu_64_floating_obj_t is uint64_t or

`

925

``

`-

(iu_64_floating_obj_t is int64_t and not is_datetimelike))

`

``

947

`+

check_mask = (

`

``

948

`+

numeric_object_t is float32_t

`

``

949

`+

or numeric_object_t is float64_t

`

``

950

`+

or numeric_object_t is object

`

``

951

`+

or (numeric_object_t is int64_t and is_datetimelike)

`

``

952

`+

)

`

926

953

``

927

954

`# Copy values into new array in order to fill missing data

`

928

955

`# with mask, without obfuscating location of missing data

`

929

956

`# in values array

`

930

``

`-

if iu_64_floating_obj_t is object and values.dtype != np.object_:

`

``

957

`+

if numeric_object_t is object and values.dtype != np.object_:

`

931

958

` masked_vals = values.astype('O')

`

932

959

`else:

`

933

960

` masked_vals = values.copy()

`

934

961

``

935

``

`-

if iu_64_floating_obj_t is object:

`

``

962

`+

if numeric_object_t is object:

`

936

963

` mask = missing.isnaobj(masked_vals)

`

937

``

`-

elif iu_64_floating_obj_t is int64_t and is_datetimelike:

`

``

964

`+

elif numeric_object_t is int64_t and is_datetimelike:

`

938

965

` mask = (masked_vals == NPY_NAT).astype(np.uint8)

`

939

``

`-

elif iu_64_floating_obj_t is float64_t:

`

``

966

`+

elif numeric_object_t is float64_t or numeric_object_t is float32_t:

`

940

967

` mask = np.isnan(masked_vals).astype(np.uint8)

`

941

968

`else:

`

942

969

` mask = np.zeros(shape=len(masked_vals), dtype=np.uint8)

`

`@@ -948,7 +975,7 @@ def rank_1d(

`

948

975

`# will flip the ordering to still end up with lowest rank.

`

949

976

`` # Symmetric logic applies to na_option == 'bottom'

``

950

977

` nans_rank_highest = ascending ^ (na_option == 'top')

`

951

``

`-

nan_fill_val = get_rank_nan_fill_valiu_64_floating_obj_t

`

``

978

`+

nan_fill_val = get_rank_nan_fill_valnumeric_object_t

`

952

979

`if nans_rank_highest:

`

953

980

` order = [masked_vals, mask]

`

954

981

`else:

`

`@@ -994,8 +1021,8 @@ cdef void rank_sorted_1d(

`

994

1021

` float64_t[::1] out,

`

995

1022

` int64_t[::1] grp_sizes,

`

996

1023

` const intp_t[:] sort_indexer,

`

997

``

`-

Can make const with cython3 (https://github.com/cython/cython/issues/3222)

`

998

``

`-

iu_64_floating_obj_t[:] masked_vals,

`

``

1024

`+

TODO(cython3): make const (https://github.com/cython/cython/issues/3222)

`

``

1025

`+

numeric_object_t[:] masked_vals,

`

999

1026

` const uint8_t[:] mask,

`

1000

1027

` bint check_mask,

`

1001

1028

` Py_ssize_t N,

`

`@@ -1019,7 +1046,7 @@ cdef void rank_sorted_1d(

`

1019

1046

` if labels is None.

`

1020

1047

` sort_indexer : intp_t[:]

`

1021

1048

` Array of indices which sorts masked_vals

`

1022

``

`-

masked_vals : iu_64_floating_obj_t[:]

`

``

1049

`+

masked_vals : numeric_object_t[:]

`

1023

1050

` The values input to rank_1d, with missing values replaced by fill values

`

1024

1051

` mask : uint8_t[:]

`

1025

1052

` Array where entries are True if the value is missing, False otherwise.

`

`@@ -1051,7 +1078,7 @@ cdef void rank_sorted_1d(

`

1051

1078

`# that sorted value for retrieval back from the original

`

1052

1079

`# values / masked_vals arrays

`

1053

1080

`# TODO(cython3): de-duplicate once cython supports conditional nogil

`

1054

``

`-

if iu_64_floating_obj_t is object:

`

``

1081

`+

if numeric_object_t is object:

`

1055

1082

`with gil:

`

1056

1083

`for i in range(N):

`

1057

1084

` at_end = i == N - 1

`

`@@ -1259,7 +1286,7 @@ cdef void rank_sorted_1d(

`

1259

1286

``

1260

1287

``

1261

1288

`def rank_2d(

`

1262

``

`-

ndarray[iu_64_floating_obj_t, ndim=2] in_arr,

`

``

1289

`+

ndarray[numeric_object_t, ndim=2] in_arr,

`

1263

1290

` int axis=0,

`

1264

1291

` bint is_datetimelike=False,

`

1265

1292

`ties_method="average",

`

`@@ -1274,13 +1301,13 @@ def rank_2d(

`

1274

1301

` Py_ssize_t k, n, col

`

1275

1302

` float64_t[::1, :] out # Column-major so columns are contiguous

`

1276

1303

` int64_t[::1] grp_sizes

`

1277

``

`-

ndarray[iu_64_floating_obj_t, ndim=2] values

`

1278

``

`-

iu_64_floating_obj_t[:, :] masked_vals

`

``

1304

`+

ndarray[numeric_object_t, ndim=2] values

`

``

1305

`+

numeric_object_t[:, :] masked_vals

`

1279

1306

` intp_t[:, :] sort_indexer

`

1280

1307

` uint8_t[:, :] mask

`

1281

1308

` TiebreakEnumType tiebreak

`

1282

1309

` bint check_mask, keep_na, nans_rank_highest

`

1283

``

`-

iu_64_floating_obj_t nan_fill_val

`

``

1310

`+

numeric_object_t nan_fill_val

`

1284

1311

``

1285

1312

` tiebreak = tiebreakers[ties_method]

`

1286

1313

`if tiebreak == TIEBREAK_FIRST:

`

`@@ -1290,29 +1317,32 @@ def rank_2d(

`

1290

1317

` keep_na = na_option == 'keep'

`

1291

1318

``

1292

1319

`# For cases where a mask is not possible, we can avoid mask checks

`

1293

``

`-

check_mask = not (iu_64_floating_obj_t is uint64_t or

`

1294

``

`-

(iu_64_floating_obj_t is int64_t and not is_datetimelike))

`

``

1320

`+

check_mask = (

`

``

1321

`+

numeric_object_t is float32_t

`

``

1322

`+

or numeric_object_t is float64_t

`

``

1323

`+

or numeric_object_t is object

`

``

1324

`+

or (numeric_object_t is int64_t and is_datetimelike)

`

``

1325

`+

)

`

1295

1326

``

1296

1327

`if axis == 1:

`

1297

1328

` values = np.asarray(in_arr).T.copy()

`

1298

1329

`else:

`

1299

1330

` values = np.asarray(in_arr).copy()

`

1300

1331

``

1301

``

`-

if iu_64_floating_obj_t is object:

`

``

1332

`+

if numeric_object_t is object:

`

1302

1333

`if values.dtype != np.object_:

`

1303

1334

` values = values.astype('O')

`

1304

1335

``

1305

1336

` nans_rank_highest = ascending ^ (na_option == 'top')

`

1306

1337

`if check_mask:

`

1307

``

`-

nan_fill_val = get_rank_nan_fill_valiu_64_floating_obj_t

`

``

1338

`+

nan_fill_val = get_rank_nan_fill_valnumeric_object_t

`

1308

1339

``

1309

``

`-

if iu_64_floating_obj_t is object:

`

``

1340

`+

if numeric_object_t is object:

`

1310

1341

` mask = missing.isnaobj2d(values).view(np.uint8)

`

1311

``

`-

elif iu_64_floating_obj_t is float64_t:

`

``

1342

`+

elif numeric_object_t is float64_t or numeric_object_t is float32_t:

`

1312

1343

` mask = np.isnan(values).view(np.uint8)

`

1313

``

-

1314

``

`-

int64 and datetimelike

`

1315

1344

`else:

`

``

1345

`+

i.e. int64 and datetimelike

`

1316

1346

` mask = (values == NPY_NAT).view(np.uint8)

`

1317

1347

` np.putmask(values, mask, nan_fill_val)

`

1318

1348

`else:

`