[pandas.tools.plotting._subplots] squeeze option is not respected if ax is provided · Issue #16253 · pandas-dev/pandas (original) (raw)

`pandas.tools.plotting._subplots` returns a flattened array of axes whenever the `ax` parameter is provided, regardless of the value of the `squeeze` parameter. If `ax` is a 2-D list, it is not flattened, but then the check `len(ax) == naxes` fails, because `len` counts only the rows. The check should compare the total number of axes instead, e.g. `np.asarray(ax).size == naxes`.

This issue prevents using the `ax` parameter in plotting functions such as `scatter_matrix`, which expect `ax` to be an N×N array.

import pandas as pd
from pandas.tools.plotting import scatter_matrix
import matplotlib.pyplot as plt

df = pd.DataFrame(dict(a=[0, 1, 2, 3, 4], b=[5, 6, 7, 8, 9]))

f, axes = plt.subplots(2, 2)
scatter_matrix(df, ax=axes)

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-10-8860714df14f> in <module>()
----> 1 scatter_matrix(df, ax=axes)

/Users/bolay/.envs/scantrust_tools/lib/python2.7/site-packages/pandas/tools/plotting.pyc in scatter_matrix(frame, alpha, figsize, ax, grid, diagonal, marker, density_kwds, hist_kwds, range_padding, **kwds)
    371     for i, a in zip(lrange(n), df.columns):
    372         for j, b in zip(lrange(n), df.columns):
--> 373             ax = axes[i, j]
    374 
    375             if i == j:

IndexError: too many indices for array

And if `axes` is passed as a 2-D list:

scatter_matrix(df, ax=axes.tolist())

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-11-aeac24929f47> in <module>()
----> 1 scatter_matrix(df, ax=axes.tolist())

/Users/bolay/.envs/scantrust_tools/lib/python2.7/site-packages/pandas/tools/plotting.pyc in scatter_matrix(frame, alpha, figsize, ax, grid, diagonal, marker, density_kwds, hist_kwds, range_padding, **kwds)
    347     naxes = n * n
    348     fig, axes = _subplots(naxes=naxes, figsize=figsize, ax=ax,
--> 349                           squeeze=False)
    350 
    351     # no gaps between subplots

/Users/bolay/.envs/scantrust_tools/lib/python2.7/site-packages/pandas/tools/plotting.pyc in _subplots(naxes, sharex, sharey, squeeze, subplot_kw, ax, layout, layout_type, **fig_kw)
   3389             else:
   3390                 raise ValueError("The number of passed axes must be {0}, the "
-> 3391                                  "same as the output plot".format(naxes))
   3392 
   3393         fig = ax.get_figure()

ValueError: The number of passed axes must be 4, the same as the output plot