BUG: numba function being cached with arguments · Issue #41647 · pandas-dev/pandas


Code Sample

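import pandas as pd

# with engine='numba', the UDF receives the group's values and index first;
# extra positional arguments (here n) follow
def sum_last(values, index, n):
    return values[-n:].sum()

df = pd.DataFrame({'id': [0, 0, 1, 1], 'x': [1, 1, 1, 1]})
grouped_x = df.groupby('id')['x']
grouped_x.transform(sum_last, 1, engine='numba')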

0    1.0
1    1.0
2    1.0
3    1.0
Name: x, dtype: float64

grouped_x.transform(sum_last, 2, engine='numba')

0    1.0
1    1.0
2    1.0
3    1.0
Name: x, dtype: float64

Problem description

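When using groupby.transform with engine='numba', calling the same function with different arguments returns the same result as the first call: the second call above passes n=2, so each group should sum its last two values, yet it still returns 1.0 for every row. I believe this is because the function is being cached together with its arguments. The arguments are passed to generate_numba_transform_func when the function is built, but the cache key is only (func, "groupby_transform"), so a later call with new arguments appears to reuse the function built for the first call: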

numba_transform_func = numba_.generate_numba_transform_func(
    tuple(args), kwargs, func, engine_kwargs
)
# the key does not include args or kwargs
cache_key = (func, "groupby_transform")
if cache_key not in NUMBA_FUNC_CACHE:
    NUMBA_FUNC_CACHE[cache_key] = numba_transform_func
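The suspected mechanism can be reproduced without numba. This is a minimal pure-Python model, assuming the generated function closes over `args` as the `generate_numba_transform_func(tuple(args), ...)` call suggests. `FUNC_CACHE`, `make_transform`, `transform_buggy` and `transform_fixed` are hypothetical names, not pandas APIs, and `sum_last` drops the `index` parameter for brevity. `transform_fixed` sketches one possible fix: adding the arguments to the cache key.

```python
# Hypothetical stand-in for pandas' NUMBA_FUNC_CACHE; plain Python, no numba.
FUNC_CACHE = {}


def make_transform(func, args):
    """Build a per-group transform that closes over ``args``."""
    def group_transform(groups):
        # ``args`` was fixed when this closure was built, not when it is called
        return [func(group, *args) for group in groups]
    return group_transform


def transform_buggy(groups, func, *args):
    key = (func, "groupby_transform")  # key ignores args, as in the report
    if key not in FUNC_CACHE:
        FUNC_CACHE[key] = make_transform(func, args)
    return FUNC_CACHE[key](groups)


def transform_fixed(groups, func, *args):
    key = (func, "groupby_transform", args)  # args are part of the key
    if key not in FUNC_CACHE:
        FUNC_CACHE[key] = make_transform(func, args)
    return FUNC_CACHE[key](groups)


def sum_last(values, n):
    return sum(values[-n:])


groups = [[1, 1], [1, 1]]  # the two groups of column 'x'
print(transform_buggy(groups, sum_last, 1))  # [1, 1]
print(transform_buggy(groups, sum_last, 2))  # [1, 1], stale: n=1 is reused
print(transform_fixed(groups, sum_last, 2))  # [2, 2]
```

Putting `args` in the key requires them to be hashable; passing the arguments at call time instead of closing over them would be another way to avoid the stale result.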

Expected Output

grouped_x.transform(sum_last, 2, engine='numba')

0    2.0
1    2.0
2    2.0
3    2.0
Name: x, dtype: float64
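To cross-check the expected output independently of any caching, the computation can be written out directly. This is a minimal pure-Python sketch, not pandas code: `transform_reference` is a hypothetical helper that groups `values` by `ids`, applies `func` to each group, and broadcasts the scalar result back to every row of that group, converting to float to mirror the float64 output above. As before, `sum_last` omits the `index` parameter.

```python
def transform_reference(ids, values, func, *args):
    """Apply func to each group's values and broadcast the result to its rows."""
    groups = {}
    for key, value in zip(ids, values):
        groups.setdefault(key, []).append(value)
    # evaluated fresh on every call, so args are never stale
    per_group = {key: float(func(vals, *args)) for key, vals in groups.items()}
    return [per_group[key] for key in ids]


def sum_last(values, n):
    return sum(values[-n:])


ids = [0, 0, 1, 1]
x = [1, 1, 1, 1]
print(transform_reference(ids, x, sum_last, 1))  # [1.0, 1.0, 1.0, 1.0]
print(transform_reference(ids, x, sum_last, 2))  # [2.0, 2.0, 2.0, 2.0]
```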

Output of pd.show_versions()

INSTALLED VERSIONS

commit : 2cb9652
python : 3.8.8.final.0
python-bits : 64
OS : Linux
OS-release : 5.11.0-7614-generic
Version : #15161862669320.10~ecb25cd-Ubuntu SMP Thu Apr 22 16:00:45 UTC
machine : x86_64
processor : x86_64
byteorder : little
LC_ALL : None
LANG : en_US.UTF-8
LOCALE : en_US.UTF-8

pandas : 1.2.4
numpy : 1.20.2
pytz : 2021.1
dateutil : 2.8.1
pip : 21.1
setuptools : 49.6.0.post20210108
Cython : None
pytest : None
hypothesis : None
sphinx : 4.0.0
blosc : None
feather : None
xlsxwriter : None
lxml.etree : None
html5lib : None
pymysql : None
psycopg2 : 2.8.6 (dt dec pq3 ext lo64)
jinja2 : 2.11.3
IPython : 7.23.1
pandas_datareader: None
bs4 : 4.9.3
bottleneck : None
fsspec : 2021.04.0
fastparquet : None
gcsfs : None
matplotlib : 3.4.1
numexpr : None
odfpy : None
openpyxl : None
pandas_gbq : None
pyarrow : 3.0.0
pyxlsb : None
s3fs : 2021.04.0
scipy : 1.6.3
sqlalchemy : None
tables : None
tabulate : 0.8.9
xarray : None
xlrd : None
xlwt : None
numba : 0.53.1