PERF: Allow jitting of groupby agg loop (#35759) · pandas-dev/pandas@068e654 (original) (raw)

`@@ -34,7 +34,7 @@ class providing the base-class of operations.

`

34

34

``

35

35

`from pandas._config.config import option_context

`

36

36

``

37

``

`-

from pandas._libs import Timestamp

`

``

37

`+

from pandas._libs import Timestamp, lib

`

38

38

`import pandas._libs.groupby as libgroupby

`

39

39

`from pandas._typing import F, FrameOrSeries, FrameOrSeriesUnion, Scalar

`

40

40

`from pandas.compat.numpy import function as nv

`

`@@ -61,11 +61,11 @@ class providing the base-class of operations.

`

61

61

`import pandas.core.common as com

`

62

62

`from pandas.core.frame import DataFrame

`

63

63

`from pandas.core.generic import NDFrame

`

64

``

`-

from pandas.core.groupby import base, ops

`

``

64

`+

from pandas.core.groupby import base, numba_, ops

`

65

65

`from pandas.core.indexes.api import CategoricalIndex, Index, MultiIndex

`

66

66

`from pandas.core.series import Series

`

67

67

`from pandas.core.sorting import get_group_index_sorter

`

68

``

`-

from pandas.core.util.numba_ import maybe_use_numba

`

``

68

`+

from pandas.core.util.numba_ import NUMBA_FUNC_CACHE

`

69

69

``

70

70

`_common_see_also = """

`

71

71

` See Also

`

`@@ -384,7 +384,8 @@ class providing the base-class of operations.

`

384

384

` - dict of axis labels -> functions, function names or list of such.

`

385

385

``

386

386

` Can also accept a Numba JIT function with

`

387

``


 ``engine='numba'`` specified.

``

387


 ``engine='numba'`` specified. Only passing a single function is supported

``

388

`+

with this engine.

`

388

389

``

389

390

``` If the 'numba' engine is chosen, the function must be


`390`

`391`

```  a user defined function with ``values`` and ``index`` as the

`@@ -1053,12 +1054,43 @@ def _cython_agg_general(

`

1053

1054

``

1054

1055

`return self._wrap_aggregated_output(output, index=self.grouper.result_index)

`

1055

1056

``

1056

``

`-

def _python_agg_general(

`

1057

``

`-

self, func, *args, engine="cython", engine_kwargs=None, **kwargs

`

1058

``

`-

):

`

``

1057

`+

def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):

`

``

1058

`+

"""

`

``

1059

`+

Perform groupby aggregation routine with the numba engine.

`

``

1060

+

``

1061

`+

This routine mimics the data splitting routine of the DataSplitter class

`

``

1062

`+

to generate the indices of each group in the sorted data and then passes the

`

``

1063

`+

data and indices into a Numba jitted function.

`

``

1064

`+

"""

`

``

1065

`+

group_keys = self.grouper._get_group_keys()

`

``

1066

`+

labels, _, n_groups = self.grouper.group_info

`

``

1067

`+

sorted_index = get_group_index_sorter(labels, n_groups)

`

``

1068

`+

sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False)

`

``

1069

`+

sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()

`

``

1070

`+

starts, ends = lib.generate_slices(sorted_labels, n_groups)

`

``

1071

`+

cache_key = (func, "groupby_agg")

`

``

1072

`+

if cache_key in NUMBA_FUNC_CACHE:

`

``

1073

`+

Return an already compiled version of roll_apply if available

`

``

1074

`+

numba_agg_func = NUMBA_FUNC_CACHE[cache_key]

`

``

1075

`+

else:

`

``

1076

`+

numba_agg_func = numba_.generate_numba_agg_func(

`

``

1077

`+

tuple(args), kwargs, func, engine_kwargs

`

``

1078

`+

)

`

``

1079

`+

result = numba_agg_func(

`

``

1080

`+

sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns),

`

``

1081

`+

)

`

``

1082

`+

if cache_key not in NUMBA_FUNC_CACHE:

`

``

1083

`+

NUMBA_FUNC_CACHE[cache_key] = numba_agg_func

`

``

1084

+

``

1085

`+

if self.grouper.nkeys > 1:

`

``

1086

`+

index = MultiIndex.from_tuples(group_keys, names=self.grouper.names)

`

``

1087

`+

else:

`

``

1088

`+

index = Index(group_keys, name=self.grouper.names[0])

`

``

1089

`+

return result, index

`

``

1090

+

``

1091

`+

def _python_agg_general(self, func, *args, **kwargs):

`

1059

1092

`func = self._is_builtin_func(func)

`

1060

``

`-

if engine != "numba":

`

1061

``

`-

f = lambda x: func(x, *args, **kwargs)

`

``

1093

`+

f = lambda x: func(x, *args, **kwargs)

`

1062

1094

``

1063

1095

`# iterate through "columns" ex exclusions to populate output dict

`

1064

1096

`output: Dict[base.OutputKey, np.ndarray] = {}

`

`@@ -1069,21 +1101,11 @@ def _python_agg_general(

`

1069

1101

`# agg_series below assumes ngroups > 0

`

1070

1102

`continue

`

1071

1103

``

1072

``

`-

if maybe_use_numba(engine):

`

1073

``

`-

result, counts = self.grouper.agg_series(

`

1074

``

`-

obj,

`

1075

``

`-

func,

`

1076

``

`-

*args,

`

1077

``

`-

engine=engine,

`

1078

``

`-

engine_kwargs=engine_kwargs,

`

1079

``

`-

**kwargs,

`

1080

``

`-

)

`

1081

``

`-

else:

`

1082

``

`-

try:

`

1083

``

`-

if this function is invalid for this dtype, we will ignore it.

`

1084

``

`-

result, counts = self.grouper.agg_series(obj, f)

`

1085

``

`-

except TypeError:

`

1086

``

`-

continue

`

``

1104

`+

try:

`

``

1105

`+

if this function is invalid for this dtype, we will ignore it.

`

``

1106

`+

result, counts = self.grouper.agg_series(obj, f)

`

``

1107

`+

except TypeError:

`

``

1108

`+

continue

`

1087

1109

``

1088

1110

`assert result is not None

`

1089

1111

`key = base.OutputKey(label=name, position=idx)

`