ENH: Add numba engine to groupby.transform (#32854) · pandas-dev/pandas@b8b6471 (original) (raw)

`@@ -75,6 +75,13 @@

`

75

75

`import pandas.core.indexes.base as ibase

`

76

76

`from pandas.core.internals import BlockManager, make_block

`

77

77

`from pandas.core.series import Series

`

``

78

`+

from pandas.core.util.numba_ import (

`

``

79

`+

check_kwargs_and_nopython,

`

``

80

`+

get_jit_arguments,

`

``

81

`+

jit_user_function,

`

``

82

`+

split_for_numba,

`

``

83

`+

validate_udf,

`

``

84

`+

)

`

78

85

``

79

86

`from pandas.plotting import boxplot_frame_groupby

`

80

87

``

`@@ -154,6 +161,8 @@ def pinner(cls):

`

154

161

`class SeriesGroupBy(GroupBy[Series]):

`

155

162

`_apply_whitelist = base.series_apply_whitelist

`

156

163

``

``

164

`+

_numba_func_cache: Dict[Callable, Callable] = {}

`

``

165

+

157

166

`def _iterate_slices(self) -> Iterable[Series]:

`

158

167

`yield self._selected_obj

`

159

168

``

`@@ -463,11 +472,13 @@ def _aggregate_named(self, func, *args, **kwargs):

`

463

472

``

464

473

`@Substitution(klass="Series", selected="A.")

`

465

474

`@Appender(_transform_template)

`

466

``

`-

def transform(self, func, *args, **kwargs):

`

``

475

`+

def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):

`

467

476

`func = self._get_cython_func(func) or func

`

468

477

``

469

478

`if not isinstance(func, str):

`

470

``

`-

return self._transform_general(func, *args, **kwargs)

`

``

479

`+

return self._transform_general(

`

``

480

`+

func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs

`

``

481

`+

)

`

471

482

``

472

483

`elif func not in base.transform_kernel_whitelist:

`

473

484

`msg = f"'{func}' is not a valid function name for transform(name)"

`

`@@ -482,16 +493,33 @@ def transform(self, func, *args, **kwargs):

`

482

493

`result = getattr(self, func)(*args, **kwargs)

`

483

494

`return self._transform_fast(result, func)

`

484

495

``

485

``

`-

def _transform_general(self, func, *args, **kwargs):

`

``

496

`+

def _transform_general(

`

``

497

`+

self, func, *args, engine="cython", engine_kwargs=None, **kwargs

`

``

498

`+

):

`

486

499

`"""

`

487

500

`` Transform with a non-str func.

``

488

501

` """

`

``

502

+

``

503

`+

if engine == "numba":

`

``

504

`+

nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

`

``

505

`+

check_kwargs_and_nopython(kwargs, nopython)

`

``

506

`+

validate_udf(func)

`

``

507

`+

numba_func = self._numba_func_cache.get(

`

``

508

`+

func, jit_user_function(func, nopython, nogil, parallel)

`

``

509

`+

)

`

``

510

+

489

511

`klass = type(self._selected_obj)

`

490

512

``

491

513

`results = []

`

492

514

`for name, group in self:

`

493

515

`object.__setattr__(group, "name", name)

`

494

``

`-

res = func(group, *args, **kwargs)

`

``

516

`+

if engine == "numba":

`

``

517

`+

values, index = split_for_numba(group)

`

``

518

`+

res = numba_func(values, index, *args)

`

``

519

`+

if func not in self._numba_func_cache:

`

``

520

`+

self._numba_func_cache[func] = numba_func

`

``

521

`+

else:

`

``

522

`+

res = func(group, *args, **kwargs)

`

495

523

``

496

524

`if isinstance(res, (ABCDataFrame, ABCSeries)):

`

497

525

`res = res._values

`

`@@ -819,6 +847,8 @@ class DataFrameGroupBy(GroupBy[DataFrame]):

`

819

847

``

820

848

`_apply_whitelist = base.dataframe_apply_whitelist

`

821

849

``

``

850

`+

_numba_func_cache: Dict[Callable, Callable] = {}

`

``

851

+

822

852

`_agg_see_also_doc = dedent(

`

823

853

`"""

`

824

854

` See Also

`

`@@ -1355,19 +1385,35 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):

`

1355

1385

`# Handle cases like BinGrouper

`

1356

1386

`return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)

`

1357

1387

``

1358

``

`-

def _transform_general(self, func, *args, **kwargs):

`

``

1388

`+

def _transform_general(

`

``

1389

`+

self, func, *args, engine="cython", engine_kwargs=None, **kwargs

`

``

1390

`+

):

`

1359

1391

`from pandas.core.reshape.concat import concat

`

1360

1392

``

1361

1393

`applied = []

`

1362

1394

`obj = self._obj_with_exclusions

`

1363

1395

`gen = self.grouper.get_iterator(obj, axis=self.axis)

`

1364

``

`-

fast_path, slow_path = self._define_paths(func, *args, **kwargs)

`

``

1396

`+

if engine == "numba":

`

``

1397

`+

nopython, nogil, parallel = get_jit_arguments(engine_kwargs)

`

``

1398

`+

check_kwargs_and_nopython(kwargs, nopython)

`

``

1399

`+

validate_udf(func)

`

``

1400

`+

numba_func = self._numba_func_cache.get(

`

``

1401

`+

func, jit_user_function(func, nopython, nogil, parallel)

`

``

1402

`+

)

`

``

1403

`+

else:

`

``

1404

`+

fast_path, slow_path = self._define_paths(func, *args, **kwargs)

`

1365

1405

``

1366

``

`-

path = None

`

1367

1406

`for name, group in gen:

`

1368

1407

`object.__setattr__(group, "name", name)

`

1369

1408

``

1370

``

`-

if path is None:

`

``

1409

`+

if engine == "numba":

`

``

1410

`+

values, index = split_for_numba(group)

`

``

1411

`+

res = numba_func(values, index, *args)

`

``

1412

`+

if func not in self._numba_func_cache:

`

``

1413

`+

self._numba_func_cache[func] = numba_func

`

``

1414

`+

# Return the result as a DataFrame for concatenation later

`

``

1415

`+

res = DataFrame(res, index=group.index, columns=group.columns)

`

``

1416

`+

else:

`

1371

1417

`# Try slow path and fast path.

`

1372

1418

`try:

`

1373

1419

`path, res = self._choose_path(fast_path, slow_path, group)

`

`@@ -1376,8 +1422,6 @@ def _transform_general(self, func, *args, **kwargs):

`

1376

1422

`except ValueError as err:

`

1377

1423

`msg = "transform must return a scalar value for each group"

`

1378

1424

`raise ValueError(msg) from err

`

1379

``

`-

else:

`

1380

``

`-

res = path(group)

`

1381

1425

``

1382

1426

`if isinstance(res, Series):

`

1383

1427

``

`@@ -1411,13 +1455,15 @@ def _transform_general(self, func, *args, **kwargs):

`

1411

1455

``

1412

1456

`@Substitution(klass="DataFrame", selected="")

`

1413

1457

`@Appender(_transform_template)

`

1414

``

`-

def transform(self, func, *args, **kwargs):

`

``

1458

`+

def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):

`

1415

1459

``

1416

1460

`# optimized transforms

`

1417

1461

`func = self._get_cython_func(func) or func

`

1418

1462

``

1419

1463

`if not isinstance(func, str):

`

1420

``

`-

return self._transform_general(func, *args, **kwargs)

`

``

1464

`+

return self._transform_general(

`

``

1465

`+

func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs

`

``

1466

`+

)

`

1421

1467

``

1422

1468

`elif func not in base.transform_kernel_whitelist:

`

1423

1469

`msg = f"'{func}' is not a valid function name for transform(name)"

`

`@@ -1439,7 +1485,9 @@ def transform(self, func, *args, **kwargs):

`

1439

1485

` ):

`

1440

1486

`return self._transform_fast(result, func)

`

1441

1487

``

1442

``

`-

return self._transform_general(func, *args, **kwargs)

`

``

1488

`+

return self._transform_general(

`

``

1489

`+

func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs

`

``

1490

`+

)

`

1443

1491

``

1444

1492

`def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:

`

1445

1493

`"""

`