Support ExtensionArray types in where · Issue #24077 · pandas-dev/pandas
This is blocking DatetimeArray. It's also a slight regression in 0.24, since things like `.where` on a DataFrame with period objects used to work (via object dtype).
I think the easiest approach is to define `ExtensionBlock.where` and restrict it to cases where the dtypes of `self` and `other` match (so that the result has the same dtype).
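One way to express the "dtypes must match" restriction is a guard that decides whether the fast path applies at all. This is a sketch, not pandas code: the helper `where_preserves_dtype` is hypothetical and uses only public pandas API. It treats any missing scalar as compatible (it can be stored as the array's own NA value) and otherwise requires `other` to carry a `dtype` equal to `values.dtype`.

```python
import numpy as np
import pandas as pd
from pandas.api.extensions import ExtensionArray


def where_preserves_dtype(values: ExtensionArray, other) -> bool:
    """Return True if filling ``values`` from ``other`` can keep ``values.dtype``.

    Hypothetical guard for the proposed ExtensionBlock.where fast path: it is
    only taken when ``other`` is missing or has exactly the same dtype.
    """
    # None, NaN and NaT can be stored as the array's own NA value.
    if other is None or (pd.api.types.is_scalar(other) and pd.isna(other)):
        return True
    # Everything else (ExtensionArray, Series, Index, ndarray) must carry
    # a dtype equal to values.dtype; objects without a dtype are rejected.
    other_dtype = getattr(other, "dtype", None)
    return other_dtype is not None and values.dtype == other_dtype


daily = pd.period_range("2000-01-01", periods=3, freq="D").array
monthly = pd.period_range("2000-01", periods=3, freq="M").array

print(where_preserves_dtype(daily, pd.NaT))        # True
print(where_preserves_dtype(daily, daily[::-1]))   # True
print(where_preserves_dtype(daily, monthly))       # False
print(where_preserves_dtype(daily, np.arange(3)))  # False
```

Anything rejected by such a guard would go through the existing (object-dtype) path instead.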
We can do this pretty easily for our EAs by performing the `.where` on `_ndarray_values`. But `_ndarray_values` isn't part of the EA interface yet, and I'm not sure we'll have time to properly design and implement a generic `.where` for any ExtensionArray, since there are a couple of subtleties.
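To make "perform the `.where` on `_ndarray_values`" concrete, here is a self-contained toy rather than pandas code. `OrdinalArray`, its `unit` attribute, and the `_from_ndarray_values` constructor are hypothetical stand-ins (the last one is named after the constructor the diff's comment wishes existed). It stores int64 ordinals with a sentinel NA analogous to pandas' `iNaT`, runs `np.where` on the storage, and rewraps the result without going through a `_from_sequence`-style parser.

```python
import numpy as np

# Storage sentinel for missing values, analogous to pandas' iNaT (min int64).
NA_ORDINAL = np.iinfo(np.int64).min


class OrdinalArray:
    """Hypothetical toy array backed by an int64 ndarray of ordinals."""

    def __init__(self, ordinals, unit="D"):
        self._ordinals = np.asarray(ordinals, dtype=np.int64)
        self.unit = unit  # plays the role of the dtype (e.g. a Period freq)

    @property
    def _ndarray_values(self):
        # The raw storage that the where() fast path operates on.
        return self._ordinals

    @classmethod
    def _from_ndarray_values(cls, values, unit):
        # Wrap raw storage directly, with no scalar parsing/validation
        # (unlike a _from_sequence-style constructor).
        return cls(values, unit=unit)

    def isna(self):
        return self._ordinals == NA_ORDINAL

    def where(self, cond, other=None):
        """Keep entries where cond is True; take them from other elsewhere."""
        # Accept a 1-D mask or a single-column 2-D mask, as in the diff.
        cond = np.asarray(cond, dtype=bool)
        if cond.ndim == 2:
            assert cond.shape[-1] == 1
            cond = cond.ravel()
        if cond.shape != self._ordinals.shape:
            raise ValueError("cond must match the array's length")

        if other is None:
            # A missing scalar becomes the *storage* NA value, not a boxed NaT.
            other_values = NA_ORDINAL
        elif isinstance(other, OrdinalArray) and other.unit == self.unit:
            other_values = other._ndarray_values
        else:
            # Mismatched "dtypes" are outside this fast path.
            raise TypeError("other must be missing or share this array's unit")

        result = np.where(cond, self._ndarray_values, other_values)
        return type(self)._from_ndarray_values(result, self.unit)


arr = OrdinalArray([10, 11, 12])
print(arr.where([False, True, False]).isna())  # [ True False  True]
```

Raising on a unit mismatch mirrors the proposal to restrict `ExtensionBlock.where` to matching dtypes; a generic ExtensionArray `where` would have to choose a result dtype instead.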
Here's a start:

```diff
diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py
index 1b67c2053..ce5c01359 100644
--- a/pandas/core/internals/blocks.py
+++ b/pandas/core/internals/blocks.py
@@ -1955,6 +1955,37 @@ class ExtensionBlock(NonConsolidatableMixIn, Block):
                               placement=self.mgr_locs, ndim=self.ndim)]
 
+    def where(self, other, cond, align=True, errors='raise',
+              try_cast=False, axis=0, transpose=False):
+        import pandas.core.computation.expressions as expressions
+
+        values = self.values._ndarray_values
+
+        if cond.ndim == 2:
+            assert cond.shape[-1] == 1
+            cond = cond._data.blocks[0].values.ravel()
+
+        if hasattr(other, 'ndim') and other.ndim == 2:
+            # TODO: this hasn't been normalized
+            assert other.shape[-1] == 1
+            other = other._data.blocks[0].values
+        elif (lib.is_scalar(other) and isna(other)) or other is None:
+            # TODO: we need the storage NA value (e.g. iNaT)
+            other = self.values.dtype.na_value
+            # other = tslibs.iNaT
+
+        # TODO: cond.ravel().all() short-circuit
+
+        if cond.ndim > 1:
+            cond = cond.ravel()
+
+        result = expressions.where(cond, values, other)
+        if not isinstance(result, self._holder):
+            # Need a kind of _from_ndarray_values()
+            # this is different from _from_sequence
+            result = self._holder.(result, dtype=self.dtype)
+
+        return self.make_block_same_class(result)
+
     @property
     def _ftype(self):
         return getattr(self.values, '_pandas_ftype', Block._ftype)
```
There are a couple of TODOs there, tests still need to be written, and I'm sure there are plenty of edge cases.
```python
In [7]: df = pd.DataFrame({"A": pd.period_range("2000", periods=12)})

In [8]: df.where(df.A.dt.day == 2)
Out[8]:
             A
0          NaT
1   2000-01-02
2          NaT
3          NaT
4          NaT
5          NaT
6          NaT
7          NaT
8          NaT
9          NaT
10         NaT
11         NaT
```
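The session above also suggests a regression test. This is a sketch assuming a pandas version in which `where` is supported for period-dtype data (as recent releases do); the function name is hypothetical. It checks that the period dtype survives rather than falling back to object, and that masked-out entries become `NaT`. The condition is passed as an explicit single-column DataFrame to avoid relying on Series-to-DataFrame broadcasting.

```python
import pandas as pd


def check_where_preserves_period_dtype():
    df = pd.DataFrame({"A": pd.period_range("2000", periods=12, freq="D")})
    # Explicit boolean frame with the same shape and labels as df.
    cond = pd.DataFrame({"A": df["A"].dt.day == 2})

    result = df.where(cond)

    # The column should keep its period dtype rather than falling back to object.
    assert result["A"].dtype == df["A"].dtype
    # Only 2000-01-02 (row 1) satisfies the condition; every other row becomes NaT.
    assert result["A"].isna().tolist() == [i != 1 for i in range(12)]
    assert result.loc[1, "A"] == pd.Period("2000-01-02", freq="D")

    # Same check at the Series level.
    s = df["A"]
    s_result = s.where(s.dt.day == 2)
    assert s_result.dtype == s.dtype
    assert int(s_result.isna().sum()) == 11
    return result


check_where_preserves_period_dtype()
```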