Performance issue while groupby.shift if fill_value specified · Issue #26615 · pandas-dev/pandas
Test code:
import numpy as np
import pandas as pd

SIZE_MULT = 5
data = np.random.randint(0, 255, size=10**SIZE_MULT, dtype='uint8')
index = pd.MultiIndex.from_product(
    [list(range(10**(SIZE_MULT-1))), list('ABCDEFGHIJ')],
    names=['d', 'l'])
test = pd.DataFrame(data, index, columns=['data'])
test.head()
test['data'].dtype
Output:
       data
d l
0 A     137
  B     156
  C      48
  D     186
  E     170
dtype('uint8')
Suppose we want to group by level 0 of the index and shift each group (by 2 steps, for example):
%%time
shifted = test.groupby(axis=0, level=[0]).shift(2)
print(shifted['data'].dtype)
float64
CPU times: user 9.43 ms, sys: 56 µs, total: 9.49 ms
Wall time: 8.29 ms
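The dtype change above comes from the missing values that the shift introduces at the start of every group: NaN cannot be stored in `uint8`, so the result is upcast to `float64`. A minimal sketch on a hypothetical six-row frame `small` (two level-0 groups of three rows), assuming a recent pandas version; `axis=0` is the default, so it is omitted:

```python
import numpy as np
import pandas as pd

# Hypothetical tiny frame with the same index layout as the report: groups by level "d"
index = pd.MultiIndex.from_product([[0, 1], list("ABC")], names=["d", "l"])
small = pd.DataFrame({"data": np.arange(1, 7, dtype="uint8")}, index=index)

# Shift within each level-0 group; the first 2 rows of each group have no source row
shifted = small.groupby(level=0).shift(2)

# The introduced NaNs force an upcast away from uint8
print(shifted["data"].dtype)
print(shifted["data"].tolist())
```

Each group of three rows keeps only its first value, moved to the third position; the two leading slots become NaN.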
Now to the problem: to preserve the 'uint8' dtype, we have to get rid of the NaNs introduced by the shift, for example by setting fill_value to 0. But the execution time then becomes huge:
%%time
shifted = test.groupby(axis=0, level=[0]).shift(2, fill_value=0)
shifted.head()
print(shifted['data'].dtype)
uint8
CPU times: user 5.9 s, sys: 38.4 ms, total: 5.94 s
Wall time: 5.89 s
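For reference, the dtype preservation itself can be checked on a hypothetical small frame (same `small` construction as a toy stand-in for `test`, assuming a recent pandas version): with `fill_value=0` no missing values are produced, so the column stays `uint8`. This only illustrates the result, not the timing, which depends on the pandas version.

```python
import numpy as np
import pandas as pd

# Hypothetical tiny frame: two level-0 groups of three rows each
index = pd.MultiIndex.from_product([[0, 1], list("ABC")], names=["d", "l"])
small = pd.DataFrame({"data": np.arange(1, 7, dtype="uint8")}, index=index)

# fill_value=0 fits in uint8, so no upcast is needed
shifted_fill = small.groupby(level=0).shift(2, fill_value=0)

print(shifted_fill["data"].dtype)
print(shifted_fill["data"].tolist())
```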
The execution time should be comparable to the following, which takes the first shifted DataFrame (computed without fill_value) and adds a few lines to achieve the same result:
%%time
shifted = test.groupby(axis=0, level=[0]).shift(2)
shifted.fillna(0, inplace=True)
shifted = shifted.astype(np.uint8)
print(shifted['data'].dtype)
Output:
uint8
CPU times: user 9.64 ms, sys: 3.68 ms, total: 13.3 ms
Wall time: 11.3 ms
This adds only a few milliseconds, not 5 seconds.
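The workaround above can be wrapped in a small helper. This is a sketch only: `shift_fill_keep_dtype` is a hypothetical name, not a pandas API, and it assumes the frame has no pre-existing NaNs (`fillna` would fill those too, unlike `fill_value`) and that `fill_value` fits in each column's original dtype. The check compares it against the `fill_value=0` route on the same hypothetical six-row frame.

```python
import numpy as np
import pandas as pd


def shift_fill_keep_dtype(df, periods, fill_value, level=0):
    """Hypothetical helper: group-wise shift, fill the gaps, restore dtypes.

    Mirrors the workaround in the report: shift without fill_value,
    fill the NaNs the shift introduced, then cast each column back.
    Assumes df has no NaNs of its own before the shift.
    """
    original_dtypes = df.dtypes.to_dict()
    shifted = df.groupby(level=level).shift(periods)
    return shifted.fillna(fill_value).astype(original_dtypes)


# Hypothetical tiny frame: two level-0 groups of three rows each
index = pd.MultiIndex.from_product([[0, 1], list("ABC")], names=["d", "l"])
small = pd.DataFrame({"data": np.arange(1, 7, dtype="uint8")}, index=index)

via_workaround = shift_fill_keep_dtype(small, 2, fill_value=0)
via_fill_value = small.groupby(level=0).shift(2, fill_value=0)

# Both routes should produce identical values, index and dtype
pd.testing.assert_frame_equal(via_workaround, via_fill_value)
print(via_workaround["data"].dtype)
```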
INSTALLED VERSIONS
------------------
commit: None
python: 3.7.3.final.0
python-bits: 64
OS: Linux
OS-release: 5.0.0-16-generic
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: en_US.UTF-8
LOCALE: en_US.UTF-8

pandas: 0.24.2
pytest: 4.5.0
pip: 19.1.1
setuptools: 41.0.1
Cython: 0.29.7
numpy: 1.16.3
scipy: 1.2.1
pyarrow: None
xarray: 0.12.1
IPython: 7.2.0
sphinx: None
patsy: None
dateutil: 2.8.0
pytz: 2019.1
blosc: None
bottleneck: None
tables: 3.5.1
numexpr: 2.6.9
feather: None
matplotlib: 3.0.3
openpyxl: None
xlrd: None
xlwt: None
xlsxwriter: None
lxml.etree: None
bs4: None
html5lib: None
sqlalchemy: None
pymysql: None
psycopg2: None
jinja2: 2.10.1
s3fs: None
fastparquet: None
pandas_gbq: None
pandas_datareader: None
gcsfs: None