PERF: GroupBy.quantile by jbrockmendel · Pull Request #51722 · pandas-dev/pandas

The existing implementation does a lexsort over (values, ids) for each column, which gets very expensive. By using grouper._get_splitter()._sorted_data, we only sort by ids once, then cheaply iterate over groups and do group-by-group argsorts. This is roughly equivalent to

obj = self._get_data_to_aggregate(...)

def quantile_func(df):
    return df.quantile(qs, interpolation=interpolation)

result = self._python_apply_general(quantile_func, obj)
# [wrapping stuff specific to quantile]
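
For intuition, here is a rough NumPy-level sketch of the sort-once-then-per-group idea (illustrative only, not the actual cython implementation; it assumes ids are integer codes in [0, ngroups) and values is a single float64 column):

import numpy as np

# Illustrative sketch: sort by group ids once, then do cheap per-group work,
# instead of a lexsort over (values, ids) for every column.
def group_quantiles_sketch(values, ids, ngroups, qs):
    order = np.argsort(ids, kind="stable")  # the one expensive sort
    sorted_vals = values[order]
    sorted_ids = ids[order]

    # Group boundaries within the id-sorted data.
    starts = np.searchsorted(sorted_ids, np.arange(ngroups), side="left")
    ends = np.searchsorted(sorted_ids, np.arange(ngroups), side="right")

    out = np.empty((ngroups, len(qs)), dtype=np.float64)
    for g in range(ngroups):
        grp = sorted_vals[starts[g]:ends[g]]
        # Each group is small relative to the whole column, so this is cheap.
        out[g] = np.quantile(grp, qs) if len(grp) else np.nan
    return out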

I tried an implementation that used _python_apply_general and ripped out the cython group_quantile entirely, and found it had some rough edges with a) axis=1, b) dtypes where GroupBy.quantile behaves differently from DataFrame.quantile (xref #51424), and it performed poorly as ngroups becomes large. With a large number of rows, though, that version performs better than this one. I speculate that np.percentile is doing something more efficiently than our cython group_quantile, but I haven't figured out what yet.

Some timings!

import numpy as np
import pandas as pd

def time_quantile(nrows, ncols, ngroups):
    qs = [0.5, 0.75]
    np.random.seed(342464)
    arr = np.random.randn(nrows, ncols)
    df = pd.DataFrame(arr)
    df["A"] = np.random.randint(ngroups, size=nrows)

    gb = df.groupby("A")

    res = %timeit -o gb.quantile(qs)
    return res.average

timings = {}

for nrows in [10**5, 10**6, 10**7]:
    for ncols in [1, 5, 10]:
        for ngroups in [10, 100, 1000, 10000]:
            key = (nrows, ncols, ngroups)
            timings[key] = time_quantile(nrows, ncols, ngroups)
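
For reference, the comparison table below can be assembled from three such dicts (one per branch) along these lines; the timings_* argument names are hypothetical:

def make_table(timings_pure_py, timings_pr, timings_main):
    # Build the comparison DataFrame shown below from three timing dicts
    # keyed by (nrows, ncols, ngroups).
    def to_series(d):
        s = pd.Series(d)
        s.index = pd.MultiIndex.from_tuples(
            list(d.keys()), names=["nrows", "ncols", "ngroups"]
        )
        return s

    return pd.DataFrame(
        {
            "pure_py": to_series(timings_pure_py),
            "PR": to_series(timings_pr),
            "main": to_series(timings_main),
        }
    )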

I did this for each of main, this PR, and the pure-python implementation described above. The results:

                         pure_py         PR       main
nrows    ncols ngroups                                
100000   1     10       0.006144   0.008039   0.020776
               100      0.028917   0.006472   0.023184
               1000     0.275030   0.006711   0.022179
               10000    2.558759   0.023987   0.024582
         5     10       0.015371   0.038649   0.110015
               100      0.038711   0.027942   0.111417
               1000     0.266454   0.029911   0.119993
               10000    2.418917   0.101383   0.121122
         10    10       0.026670   0.072206   0.213674
               100      0.047769   0.060826   0.224247
               1000     0.271641   0.061457   0.234644
               10000    2.441758   0.200785   0.255878
1000000  1     10       0.044218   0.116293   0.331845
               100      0.069265   0.098826   0.358672
               1000     0.287011   0.084252   0.366118
               10000    2.360640   0.094245   0.410894
         5     10       0.149162   0.484149   1.610431
               100      0.169841   0.387538   1.724159
               1000     0.408468   0.291384   1.726341
               10000    2.581606   0.309029   1.741293
         10    10       0.261970   0.879218   3.167597
               100      0.259414   0.687485   4.399621
               1000     0.499755   0.640692   3.856934
               10000    2.778011   0.630412   4.160573
10000000 1     10       0.493514   1.398108  10.061577
               100      0.780788   1.416905  11.622942
               1000     1.105673   1.319962  11.244395
               10000    3.552967   1.357090  11.206071
         5     10       1.895618   6.074704  50.622437
               100      1.956891   4.885986  56.239948
               1000     2.101247   3.985950  58.742726
               10000    4.756156   3.486512  59.193984
         10    10       3.561774  12.117515        NaN
               100      3.392814   9.700422        NaN
               1000     3.302072   7.369062        NaN
               10000    5.719315   6.224435        NaN

On main I had to interrupt the largest cases (with ctrl-z, not ctrl-c!), which is why they show as NaN above.

This PR outperforms main in all cases. The pure-python version outperforms this PR in small-ngroups and large-nrows cases, but suffers pretty dramatically in the opposite cases.

Potential caveats:

  1. I ran this on my laptop and tried to keep other heavy processes to a minimum, but still had browser tabs etc. open.
  2. I only tested float64 values.
  3. There are no NaNs in this data. I'm pretty sure the pure-python version goes through a less performant path when NaNs are present.
  4. Profiling one of the cases where this does poorly vs pure_py shows most of the time being spent in blk_func, but libgroupby.group_quantile doesn't show up, and certainly nothing inside it. My best guess is that most of the time is spent in the argsort inside group_quantile.

Potential downsides vs main:

  1. Accessing self.grouper._get_splitter()._sorted_data creates a copy. The np.lexsort in the status quo also allocates an array, but since that is done block-by-block, it's plausible the status quo ends up with less total memory allocation.
  2. Reaching into grouper internals (ish) is not super-pretty.

We could plausibly keep multiple implementations and dispatch based on size (a rough sketch of what that could look like is below), but I'm not sure groupby.quantile is important enough to really merit that. So which version to use really comes down to which sizes of cases we think are the most common.
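
For concreteness, a minimal sketch of what size-based dispatch could look like, with thresholds eyeballed from the table above (purely hypothetical, not part of this PR):

def choose_quantile_impl(nrows, ngroups):
    # Hypothetical dispatch rule: the pure-python apply path won above with
    # many rows and few groups; the cython group_quantile path wins otherwise.
    if nrows >= 10**6 and ngroups <= 100:
        return "python_apply"
    return "cython_group_quantile"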