BUG: numba function being cached with arguments · Issue #41647 · pandas-dev/pandas
- I have checked that this issue has not already been reported.
- I have confirmed this bug exists on the latest version of pandas.
- (optional) I have confirmed this bug exists on the master branch of pandas.
Code Sample
import pandas as pd

def sum_last(values, index, n):
    return values[-n:].sum()

df = pd.DataFrame({'id': [0, 0, 1, 1], 'x': [1, 1, 1, 1]})
grouped_x = df.groupby('id')['x']
grouped_x.transform(sum_last, 1, engine='numba')
0 1.0
1 1.0
2 1.0
3 1.0
Name: x, dtype: float64
grouped_x.transform(sum_last, 2, engine='numba')
0 1.0
1 1.0
2 1.0
3 1.0
Name: x, dtype: float64
Problem description
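When using `groupby.transform` with `engine='numba'`, calling the same function with different arguments returns the same result as the first call. I believe this is because the function is cached together with its arguments: `args` are passed into `generate_numba_transform_func`, but the cache key `(func, "groupby_transform")` does not include them.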
numba_transform_func = numba_.generate_numba_transform_func(
    tuple(args), kwargs, func, engine_kwargs
)

cache_key = (func, "groupby_transform")
if cache_key not in NUMBA_FUNC_CACHE:
    NUMBA_FUNC_CACHE[cache_key] = numba_transform_func
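The suspected mechanism can be modelled without numba. The sketch below is a hypothetical, pure-Python stand-in (`get_transform_buggy`, `get_transform_fixed` and `_CACHE` are illustrative names, not pandas internals). It assumes that a later call reuses the cached function whenever the key is already present, so the arguments captured on the first call stick. The fixed variant shows one possible remedy, passing the extra arguments at call time; including `args` in the cache key would be another.

```python
# Hypothetical, pure-Python model of the caching pattern described in this issue.
# Names are illustrative; this is not pandas' actual implementation.

def sum_last(values, n):
    return sum(values[-n:])

groups = [[1, 1], [1, 1]]  # two groups of x values, as in the Code Sample

_CACHE = {}

def get_transform_buggy(func, args):
    # Assumption: a cached entry is reused whenever the key is present.
    # The key ignores args, but the generated function closes over them.
    key = (func, "groupby_transform")
    if key not in _CACHE:
        def group_transform(groups):
            return [func(g, *args) for g in groups]  # per-group results
        _CACHE[key] = group_transform
    return _CACHE[key]

_FIXED_CACHE = {}

def get_transform_fixed(func):
    # Args are supplied at call time, so caching by func alone is safe.
    key = (func, "groupby_transform")
    if key not in _FIXED_CACHE:
        def group_transform(groups, *args):
            return [func(g, *args) for g in groups]
        _FIXED_CACHE[key] = group_transform
    return _FIXED_CACHE[key]

print(get_transform_buggy(sum_last, (1,))(groups))  # [1, 1]
print(get_transform_buggy(sum_last, (2,))(groups))  # [1, 1]  (stale n=1)
print(get_transform_fixed(sum_last)(groups, 1))     # [1, 1]
print(get_transform_fixed(sum_last)(groups, 2))     # [2, 2]
```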
Expected Output
grouped_x.transform(sum_last, 2, engine='numba')
0 2.0
1 2.0
2 2.0
3 2.0
Name: x, dtype: float64
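As a cross-check on the expected values that does not go through the numba cache, the same reduction can be run with the default engine. This is a sketch: `sum_last_series` is a hypothetical helper, and it assumes the default engine calls the function once per group with that group as a Series, forwarding the extra positional argument each time (unlike the numba engine's `values, index` signature).

```python
import pandas as pd

df = pd.DataFrame({'id': [0, 0, 1, 1], 'x': [1, 1, 1, 1]})
grouped_x = df.groupby('id')['x']

# Default (non-numba) engine: each group arrives as a Series and the
# scalar result is broadcast back to every row of that group.
def sum_last_series(group, n):
    return group.iloc[-n:].sum()

# Each row holds its group's sum of the last n values.
print(grouped_x.transform(sum_last_series, 1).tolist())
print(grouped_x.transform(sum_last_series, 2).tolist())
```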
Output of pd.show_versions()
INSTALLED VERSIONS
commit : 2cb9652
python : 3.8.8.final.0
python-bits : 64
OS : Linux
OS-release : 5.11.0-7614-generic
Version : #15161862669320.10~ecb25cd-Ubuntu SMP Thu Apr 22 16:00:45 UTC
machine : x86_64
processor : x86_64
byteorder : little
LC_ALL : None
LANG : en_US.UTF-8
LOCALE : en_US.UTF-8
pandas : 1.2.4
numpy : 1.20.2
pytz : 2021.1
dateutil : 2.8.1
pip : 21.1
setuptools : 49.6.0.post20210108
Cython : None
pytest : None
hypothesis : None
sphinx : 4.0.0
blosc : None
feather : None
xlsxwriter : None
lxml.etree : None
html5lib : None
pymysql : None
psycopg2 : 2.8.6 (dt dec pq3 ext lo64)
jinja2 : 2.11.3
IPython : 7.23.1
pandas_datareader: None
bs4 : 4.9.3
bottleneck : None
fsspec : 2021.04.0
fastparquet : None
gcsfs : None
matplotlib : 3.4.1
numexpr : None
odfpy : None
openpyxl : None
pandas_gbq : None
pyarrow : 3.0.0
pyxlsb : None
s3fs : 2021.04.0
scipy : 1.6.3
sqlalchemy : None
tables : None
tabulate : 0.8.9
xarray : None
xlrd : None
xlwt : None
numba : 0.53.1