ENH: Add numba engine to groupby.transform (#32854) · pandas-dev/pandas@b8b6471 (original) (raw)
`@@ -75,6 +75,13 @@
`
75
75
`import pandas.core.indexes.base as ibase
`
76
76
`from pandas.core.internals import BlockManager, make_block
`
77
77
`from pandas.core.series import Series
`
``
78
`+
from pandas.core.util.numba_ import (
`
``
79
`+
check_kwargs_and_nopython,
`
``
80
`+
get_jit_arguments,
`
``
81
`+
jit_user_function,
`
``
82
`+
split_for_numba,
`
``
83
`+
validate_udf,
`
``
84
`+
)
`
78
85
``
79
86
`from pandas.plotting import boxplot_frame_groupby
`
80
87
``
`@@ -154,6 +161,8 @@ def pinner(cls):
`
154
161
`class SeriesGroupBy(GroupBy[Series]):
`
155
162
`_apply_whitelist = base.series_apply_whitelist
`
156
163
``
``
164
`+
_numba_func_cache: Dict[Callable, Callable] = {}
`
``
165
+
157
166
`def _iterate_slices(self) -> Iterable[Series]:
`
158
167
`yield self._selected_obj
`
159
168
``
`@@ -463,11 +472,13 @@ def _aggregate_named(self, func, *args, **kwargs):
`
463
472
``
464
473
`@Substitution(klass="Series", selected="A.")
`
465
474
`@Appender(_transform_template)
`
466
``
`-
def transform(self, func, *args, **kwargs):
`
``
475
`+
def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):
`
467
476
`func = self._get_cython_func(func) or func
`
468
477
``
469
478
`if not isinstance(func, str):
`
470
``
`-
return self._transform_general(func, *args, **kwargs)
`
``
479
`+
return self._transform_general(
`
``
480
`+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
`
``
481
`+
)
`
471
482
``
472
483
`elif func not in base.transform_kernel_whitelist:
`
473
484
`msg = f"'{func}' is not a valid function name for transform(name)"
`
`@@ -482,16 +493,33 @@ def transform(self, func, *args, **kwargs):
`
482
493
`result = getattr(self, func)(*args, **kwargs)
`
483
494
`return self._transform_fast(result, func)
`
484
495
``
485
``
`-
def _transform_general(self, func, *args, **kwargs):
`
``
496
`+
def _transform_general(
`
``
497
`+
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
`
``
498
`+
):
`
486
499
`"""
`
487
500
`` Transform with a non-str func.
``
488
501
` """
`
``
502
+
``
503
`+
if engine == "numba":
`
``
504
`+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
`
``
505
`+
check_kwargs_and_nopython(kwargs, nopython)
`
``
506
`+
validate_udf(func)
`
``
507
`+
numba_func = self._numba_func_cache.get(
`
``
508
`+
func, jit_user_function(func, nopython, nogil, parallel)
`
``
509
`+
)
`
``
510
+
489
511
`klass = type(self._selected_obj)
`
490
512
``
491
513
`results = []
`
492
514
`for name, group in self:
`
493
515
`object.__setattr__(group, "name", name)
`
494
``
`-
res = func(group, *args, **kwargs)
`
``
516
`+
if engine == "numba":
`
``
517
`+
values, index = split_for_numba(group)
`
``
518
`+
res = numba_func(values, index, *args)
`
``
519
`+
if func not in self._numba_func_cache:
`
``
520
`+
self._numba_func_cache[func] = numba_func
`
``
521
`+
else:
`
``
522
`+
res = func(group, *args, **kwargs)
`
495
523
``
496
524
`if isinstance(res, (ABCDataFrame, ABCSeries)):
`
497
525
`res = res._values
`
`@@ -819,6 +847,8 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
`
819
847
``
820
848
`_apply_whitelist = base.dataframe_apply_whitelist
`
821
849
``
``
850
`+
_numba_func_cache: Dict[Callable, Callable] = {}
`
``
851
+
822
852
`_agg_see_also_doc = dedent(
`
823
853
`"""
`
824
854
` See Also
`
`@@ -1355,19 +1385,35 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
`
1355
1385
`# Handle cases like BinGrouper
`
1356
1386
`return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
`
1357
1387
``
1358
``
`-
def _transform_general(self, func, *args, **kwargs):
`
``
1388
`+
def _transform_general(
`
``
1389
`+
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
`
``
1390
`+
):
`
1359
1391
`from pandas.core.reshape.concat import concat
`
1360
1392
``
1361
1393
`applied = []
`
1362
1394
`obj = self._obj_with_exclusions
`
1363
1395
`gen = self.grouper.get_iterator(obj, axis=self.axis)
`
1364
``
`-
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
`
``
1396
`+
if engine == "numba":
`
``
1397
`+
nopython, nogil, parallel = get_jit_arguments(engine_kwargs)
`
``
1398
`+
check_kwargs_and_nopython(kwargs, nopython)
`
``
1399
`+
validate_udf(func)
`
``
1400
`+
numba_func = self._numba_func_cache.get(
`
``
1401
`+
func, jit_user_function(func, nopython, nogil, parallel)
`
``
1402
`+
)
`
``
1403
`+
else:
`
``
1404
`+
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
`
1365
1405
``
1366
``
`-
path = None
`
1367
1406
`for name, group in gen:
`
1368
1407
`object.__setattr__(group, "name", name)
`
1369
1408
``
1370
``
`-
if path is None:
`
``
1409
`+
if engine == "numba":
`
``
1410
`+
values, index = split_for_numba(group)
`
``
1411
`+
res = numba_func(values, index, *args)
`
``
1412
`+
if func not in self._numba_func_cache:
`
``
1413
`+
self._numba_func_cache[func] = numba_func
`
``
1414
`+
# Return the result as a DataFrame for concatenation later
`
``
1415
`+
res = DataFrame(res, index=group.index, columns=group.columns)
`
``
1416
`+
else:
`
1371
1417
`# Try slow path and fast path.
`
1372
1418
`try:
`
1373
1419
`path, res = self._choose_path(fast_path, slow_path, group)
`
`@@ -1376,8 +1422,6 @@ def _transform_general(self, func, *args, **kwargs):
`
1376
1422
`except ValueError as err:
`
1377
1423
`msg = "transform must return a scalar value for each group"
`
1378
1424
`raise ValueError(msg) from err
`
1379
``
`-
else:
`
1380
``
`-
res = path(group)
`
1381
1425
``
1382
1426
`if isinstance(res, Series):
`
1383
1427
``
`@@ -1411,13 +1455,15 @@ def _transform_general(self, func, *args, **kwargs):
`
1411
1455
``
1412
1456
`@Substitution(klass="DataFrame", selected="")
`
1413
1457
`@Appender(_transform_template)
`
1414
``
`-
def transform(self, func, *args, **kwargs):
`
``
1458
`+
def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):
`
1415
1459
``
1416
1460
`# optimized transforms
`
1417
1461
`func = self._get_cython_func(func) or func
`
1418
1462
``
1419
1463
`if not isinstance(func, str):
`
1420
``
`-
return self._transform_general(func, *args, **kwargs)
`
``
1464
`+
return self._transform_general(
`
``
1465
`+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
`
``
1466
`+
)
`
1421
1467
``
1422
1468
`elif func not in base.transform_kernel_whitelist:
`
1423
1469
`msg = f"'{func}' is not a valid function name for transform(name)"
`
`@@ -1439,7 +1485,9 @@ def transform(self, func, *args, **kwargs):
`
1439
1485
` ):
`
1440
1486
`return self._transform_fast(result, func)
`
1441
1487
``
1442
``
`-
return self._transform_general(func, *args, **kwargs)
`
``
1488
`+
return self._transform_general(
`
``
1489
`+
func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs
`
``
1490
`+
)
`
1443
1491
``
1444
1492
`def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
`
1445
1493
`"""
`