[pandas.tools.plotting._subplots] squeeze option is not respected if ax is provided · Issue #16253 · pandas-dev/pandas
pandas.tools.plotting._subplots returns a flattened array of axes whenever the ax parameter is provided, regardless of the value of the squeeze parameter. If ax is a 2D list, it is not flattened, but then the check len(ax) == naxes fails, because len() counts only the rows. The check should be ax.size == naxes.
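A possible direction for the fix, as a minimal sketch only: the helper below is hypothetical (it is not pandas code), and it assumes ax is a single axes object, a (nested) list, or an ndarray. It compares the total element count instead of len(), and only flattens when squeeze is true. Strings stand in for matplotlib Axes objects.

```python
import numpy as np


def normalize_passed_axes(ax, naxes, squeeze=True):
    """Hypothetical helper: validate user-supplied axes and honour squeeze.

    Assumption: ax is a single axes object, a (nested) list, or an ndarray.
    """
    # Convert lists (including 2D lists) to an object array without flattening.
    arr = np.asarray(ax, dtype=object)
    # Compare the total number of elements; len() would only count the rows.
    if arr.size != naxes:
        raise ValueError(
            "The number of passed axes must be {0}, the same as the output "
            "plot".format(naxes))
    if squeeze:
        # squeeze=True: return a flat 1D array, matching the current behaviour.
        return arr.ravel()
    # squeeze=False: keep the caller's layout, promoted to at least 2D.
    return np.atleast_2d(arr)


grid = [["ax00", "ax01"], ["ax10", "ax11"]]  # strings stand in for Axes objects
print(normalize_passed_axes(grid, 4, squeeze=False).shape)  # (2, 2)
print(normalize_passed_axes(grid, 4).shape)                 # (4,)
```

With squeeze=False the 2x2 layout survives, which is what a caller such as scatter_matrix needs.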
This issue prevents using the ax parameter with plotting functions such as scatter_matrix, which expect ax to be an N x N array.
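To illustrate why a flattened array breaks such functions, here is a minimal sketch. iterate_grid is a hypothetical stand-in, not pandas code: it indexes axes[i, j] the way scatter_matrix does in the traceback below, with an explicit shape check added so the failure is easy to read. Strings stand in for Axes objects.

```python
import numpy as np


def iterate_grid(axes, n):
    """Hypothetical stand-in for the loop in scatter_matrix (not pandas code).

    It indexes axes[i, j] like scatter_matrix, so it needs an n x n array.
    """
    axes = np.asarray(axes, dtype=object)
    # Explicit check added for this sketch; the real loop raises IndexError instead.
    if axes.shape != (n, n):
        raise ValueError(
            "expected a {0}x{0} grid of axes, got shape {1}".format(n, axes.shape))
    return [(i, j, axes[i, j]) for i in range(n) for j in range(n)]


# A flattened array of four axes, as _subplots returns it when ax is passed.
flat = np.array(["ax00", "ax01", "ax10", "ax11"], dtype=object)
try:
    iterate_grid(flat, 2)
except ValueError as exc:
    print(exc)  # expected a 2x2 grid of axes, got shape (4,)

print(iterate_grid(flat.reshape(2, 2), 2)[1])  # (0, 1, 'ax01')
```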
import pandas as pd
from pandas.tools.plotting import scatter_matrix
import matplotlib.pyplot as plt

df = pd.DataFrame(dict(a=[0, 1, 2, 3, 4, 5], b=[5, 6, 7, 8, 9, 10]))
f, axes = plt.subplots(2, 2)
scatter_matrix(df, ax=axes)
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-10-8860714df14f> in <module>()
----> 1 scatter_matrix(df, ax=axes)
/Users/bolay/.envs/scantrust_tools/lib/python2.7/site-packages/pandas/tools/plotting.pyc in scatter_matrix(frame, alpha, figsize, ax, grid, diagonal, marker, density_kwds, hist_kwds, range_padding, **kwds)
371 for i, a in zip(lrange(n), df.columns):
372 for j, b in zip(lrange(n), df.columns):
--> 373 ax = axes[i, j]
374
375 if i == j:
IndexError: too many indices for array
And if axes is passed as a 2D list:
scatter_matrix(df, ax=axes.tolist())
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-11-aeac24929f47> in <module>()
----> 1 scatter_matrix(df, ax=axes.tolist())
/Users/bolay/.envs/scantrust_tools/lib/python2.7/site-packages/pandas/tools/plotting.pyc in scatter_matrix(frame, alpha, figsize, ax, grid, diagonal, marker, density_kwds, hist_kwds, range_padding, **kwds)
347 naxes = n * n
348 fig, axes = _subplots(naxes=naxes, figsize=figsize, ax=ax,
--> 349 squeeze=False)
350
351 # no gaps between subplots
/Users/bolay/.envs/scantrust_tools/lib/python2.7/site-packages/pandas/tools/plotting.pyc in _subplots(naxes, sharex, sharey, squeeze, subplot_kw, ax, layout, layout_type, **fig_kw)
3389 else:
3390 raise ValueError("The number of passed axes must be {0}, the "
-> 3391 "same as the output plot".format(naxes))
3392
3393 fig = ax.get_figure()
ValueError: The number of passed axes must be 4, the same as the output plot
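The numbers in this error can be reproduced without matplotlib. This is a minimal sketch, assuming (as the description above states) that a 2D list keeps its nested structure inside _subplots; strings stand in for Axes objects, and converting with np.array is an assumption made only for the comparison.

```python
import numpy as np

n = 2
naxes = n * n  # scatter_matrix computes naxes = n * n (4 for two columns)
grid_list = [["ax00", "ax01"], ["ax10", "ax11"]]  # placeholders for Axes objects
grid_arr = np.array(grid_list, dtype=object)

print(len(grid_list), len(grid_arr))                    # 2 2  (len() counts rows)
print(grid_arr.size)                                    # 4    (total number of axes)
print(len(grid_arr) == naxes, grid_arr.size == naxes)   # False True
```

So len(ax) == naxes compares 2 against 4 and raises, while ax.size == naxes would accept the 2x2 grid.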