PERF: Allow jitting of groupby agg loop (#35759) · pandas-dev/pandas@068e654 (original) (raw)
`@@ -34,7 +34,7 @@ class providing the base-class of operations.
`
34
34
``
35
35
`from pandas._config.config import option_context
`
36
36
``
37
``
`-
from pandas._libs import Timestamp
`
``
37
`+
from pandas._libs import Timestamp, lib
`
38
38
`import pandas._libs.groupby as libgroupby
`
39
39
`from pandas._typing import F, FrameOrSeries, FrameOrSeriesUnion, Scalar
`
40
40
`from pandas.compat.numpy import function as nv
`
`@@ -61,11 +61,11 @@ class providing the base-class of operations.
`
61
61
`import pandas.core.common as com
`
62
62
`from pandas.core.frame import DataFrame
`
63
63
`from pandas.core.generic import NDFrame
`
64
``
`-
from pandas.core.groupby import base, ops
`
``
64
`+
from pandas.core.groupby import base, numba_, ops
`
65
65
`from pandas.core.indexes.api import CategoricalIndex, Index, MultiIndex
`
66
66
`from pandas.core.series import Series
`
67
67
`from pandas.core.sorting import get_group_index_sorter
`
68
``
`-
from pandas.core.util.numba_ import maybe_use_numba
`
``
68
`+
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE
`
69
69
``
70
70
`_common_see_also = """
`
71
71
` See Also
`
`@@ -384,7 +384,8 @@ class providing the base-class of operations.
`
384
384
` - dict of axis labels -> functions, function names or list of such.
`
385
385
``
386
386
` Can also accept a Numba JIT function with
`
387
``
`-
``engine='numba'`` specified.
`
``
387
`+
``engine='numba'`` specified. Only passing a single function is supported
`
``
388
`+
with this engine.
`
388
389
``
389
390
` If the 'numba' engine is chosen, the function must be
`
390
391
` a user defined function with ``values`` and ``index`` as the
`
`@@ -1053,12 +1054,43 @@ def _cython_agg_general(
`
1053
1054
``
1054
1055
`return self._wrap_aggregated_output(output, index=self.grouper.result_index)
`
1055
1056
``
1056
``
`-
def _python_agg_general(
`
1057
``
`-
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
`
1058
``
`-
):
`
``
1057
`+
def _aggregate_with_numba(self, data, func, *args, engine_kwargs=None, **kwargs):
`
``
1058
`+
"""
`
``
1059
`+
Perform groupby aggregation routine with the numba engine.
`
``
1060
+
``
1061
`+
This routine mimics the data splitting routine of the DataSplitter class
`
``
1062
`+
to generate the indices of each group in the sorted data and then passes the
`
``
1063
`+
data and indices into a Numba jitted function.
`
``
1064
`+
"""
`
``
1065
`+
group_keys = self.grouper._get_group_keys()
`
``
1066
`+
labels, _, n_groups = self.grouper.group_info
`
``
1067
`+
sorted_index = get_group_index_sorter(labels, n_groups)
`
``
1068
`+
sorted_labels = algorithms.take_nd(labels, sorted_index, allow_fill=False)
`
``
1069
`+
sorted_data = data.take(sorted_index, axis=self.axis).to_numpy()
`
``
1070
`+
starts, ends = lib.generate_slices(sorted_labels, n_groups)
`
``
1071
`+
cache_key = (func, "groupby_agg")
`
``
1072
`+
if cache_key in NUMBA_FUNC_CACHE:
`
``
1073
`+
# Return an already compiled version of roll_apply if available
`
``
1074
`+
numba_agg_func = NUMBA_FUNC_CACHE[cache_key]
`
``
1075
`+
else:
`
``
1076
`+
numba_agg_func = numba_.generate_numba_agg_func(
`
``
1077
`+
tuple(args), kwargs, func, engine_kwargs
`
``
1078
`+
)
`
``
1079
`+
result = numba_agg_func(
`
``
1080
`+
sorted_data, sorted_index, starts, ends, len(group_keys), len(data.columns),
`
``
1081
`+
)
`
``
1082
`+
if cache_key not in NUMBA_FUNC_CACHE:
`
``
1083
`+
NUMBA_FUNC_CACHE[cache_key] = numba_agg_func
`
``
1084
+
``
1085
`+
if self.grouper.nkeys > 1:
`
``
1086
`+
index = MultiIndex.from_tuples(group_keys, names=self.grouper.names)
`
``
1087
`+
else:
`
``
1088
`+
index = Index(group_keys, name=self.grouper.names[0])
`
``
1089
`+
return result, index
`
``
1090
+
``
1091
`+
def _python_agg_general(self, func, *args, **kwargs):
`
1059
1092
`func = self._is_builtin_func(func)
`
1060
``
`-
if engine != "numba":
`
1061
``
`-
f = lambda x: func(x, *args, **kwargs)
`
``
1093
`+
f = lambda x: func(x, *args, **kwargs)
`
1062
1094
``
1063
1095
`# iterate through "columns" ex exclusions to populate output dict
`
1064
1096
`output: Dict[base.OutputKey, np.ndarray] = {}
`
`@@ -1069,21 +1101,11 @@ def _python_agg_general(
`
1069
1101
`# agg_series below assumes ngroups > 0
`
1070
1102
`continue
`
1071
1103
``
1072
``
`-
if maybe_use_numba(engine):
`
1073
``
`-
result, counts = self.grouper.agg_series(
`
1074
``
`-
obj,
`
1075
``
`-
func,
`
1076
``
`-
*args,
`
1077
``
`-
engine=engine,
`
1078
``
`-
engine_kwargs=engine_kwargs,
`
1079
``
`-
**kwargs,
`
1080
``
`-
)
`
1081
``
`-
else:
`
1082
``
`-
try:
`
1083
``
`-
# if this function is invalid for this dtype, we will ignore it.
`
1084
``
`-
result, counts = self.grouper.agg_series(obj, f)
`
1085
``
`-
except TypeError:
`
1086
``
`-
continue
`
``
1104
`+
try:
`
``
1105
`+
# if this function is invalid for this dtype, we will ignore it.
`
``
1106
`+
result, counts = self.grouper.agg_series(obj, f)
`
``
1107
`+
except TypeError:
`
``
1108
`+
continue
`
1087
1109
``
1088
1110
`assert result is not None
`
1089
1111
`key = base.OutputKey(label=name, position=idx)
`