REF: melt by mroeschke · Pull Request #55948 · pandas-dev/pandas (original) (raw)
@@ -13,11 +13,7 @@
import pandas.core.algorithms as algos
from pandas.core.arrays import Categorical
import pandas.core.common as com
from pandas.core.indexes.api import (
Index,
MultiIndex,
)
from pandas.core.indexes.api import MultiIndex
from pandas.core.reshape.concat import concat
from pandas.core.reshape.util import tile_compat
from pandas.core.shared_docs import _shared_docs
@@ -31,6 +27,20 @@
from pandas import DataFrame
def ensure_list_vars(arg_vars, variable: str, columns) -> list:
    """
    Normalize a column-selection argument into a list.

    Parameters
    ----------
    arg_vars : scalar, list-like or None
        The value to normalize.
    variable : str
        Name of the argument being normalized; only used to build the
        error message.
    columns : object
        The columns the selection refers to; only inspected to see whether
        it is a MultiIndex.

    Returns
    -------
    list
        An empty list when ``arg_vars`` is None, a one-element list when it
        is not list-like, otherwise ``list(arg_vars)``.

    Raises
    ------
    ValueError
        If ``columns`` is a MultiIndex and ``arg_vars`` is list-like but is
        not a ``list`` instance.
    """
    if arg_vars is None:
        return []

    if not is_list_like(arg_vars):
        # A single label: wrap it so callers can always treat it as a list.
        return [arg_vars]

    needs_list_of_tuples = isinstance(columns, MultiIndex)
    if needs_list_of_tuples and not isinstance(arg_vars, list):
        # With MultiIndex columns a bare tuple is ambiguous (one key vs.
        # several labels), so anything other than a real list is rejected.
        raise ValueError(
            f"{variable} must be a list of tuples when columns are a MultiIndex"
        )

    return list(arg_vars)
@Appender(_shared_docs["melt"] % {"caller": "pd.melt(df, ", "other": "DataFrame.melt"})
def melt(
frame: DataFrame,
@@ -41,61 +51,35 @@ def melt(
col_level=None,
ignore_index: bool = True,
) -> DataFrame:
# If the columns are a MultiIndex, gather the column names from all levels
# so the presence of `id_vars` and `value_vars` can be checked
if isinstance(frame.columns, MultiIndex):
cols = [x for c in frame.columns for x in c]
else:
cols = list(frame.columns)
if value_name in frame.columns:
raise ValueError(
f"value_name ({value_name}) cannot match an element in "
"the DataFrame columns."
)
id_vars = ensure_list_vars(id_vars, "id_vars", frame.columns)
value_vars_was_not_none = value_vars is not None
value_vars = ensure_list_vars(value_vars, "value_vars", frame.columns)
if id_vars is not None:
if not is_list_like(id_vars):
id_vars = [id_vars]
elif isinstance(frame.columns, MultiIndex) and not isinstance(id_vars, list):
raise ValueError(
"id_vars must be a list of tuples when columns are a MultiIndex"
)
else:
# Check that `id_vars` are in frame
id_vars = list(id_vars)
missing = Index(com.flatten(id_vars)).difference(cols)
if not missing.empty:
raise KeyError(
"The following 'id_vars' are not present "
f"in the DataFrame: {list(missing)}"
)
else:
id_vars = []
if value_vars is not None:
if not is_list_like(value_vars):
value_vars = [value_vars]
elif isinstance(frame.columns, MultiIndex) and not isinstance(value_vars, list):
raise ValueError(
"value_vars must be a list of tuples when columns are a MultiIndex"
)
else:
value_vars = list(value_vars)
# Check that `value_vars` are in frame
missing = Index(com.flatten(value_vars)).difference(cols)
if not missing.empty:
raise KeyError(
"The following 'value_vars' are not present in "
f"the DataFrame: {list(missing)}"
)
if id_vars or value_vars:
if col_level is not None:
idx = frame.columns.get_level_values(col_level).get_indexer(
id_vars + value_vars
level = frame.columns.get_level_values(col_level)
else:
level = frame.columns
labels = id_vars + value_vars
idx = level.get_indexer_for(labels)
missing = idx == -1
if missing.any():
missing_labels = [
lab for lab, not_found in zip(labels, missing) if not_found
]
raise KeyError(
"The following id_vars or value_vars are not present in "
f"the DataFrame: {missing_labels}"
)
if value_vars_was_not_none:
frame = frame.iloc[:, algos.unique(idx)]
else:
idx = algos.unique(frame.columns.get_indexer_for(id_vars + value_vars))
frame = frame.iloc[:, idx]
frame = frame.copy()
else:
frame = frame.copy()
@@ -113,24 +97,26 @@ def melt(
var_name = [
frame.columns.name if frame.columns.name is not None else "variable"
]
if isinstance(var_name, str):
elif is_list_like(var_name):
raise ValueError(f"{var_name=} must be a scalar.")
else:
var_name = [var_name]
N, K = frame.shape
K -= len(id_vars)
num_rows, K = frame.shape
num_cols_adjusted = K - len(id_vars)
mdata: dict[Hashable, AnyArrayLike] = {}
for col in id_vars:
id_data = frame.pop(col)
if not isinstance(id_data.dtype, np.dtype):
# i.e. ExtensionDtype
if K > 0:
mdata[col] = concat([id_data] * K, ignore_index=True)
if num_cols_adjusted > 0:
mdata[col] = concat([id_data] * num_cols_adjusted, ignore_index=True)
else:
# We can't concat empty list. (GH 46044)
mdata[col] = type(id_data)([], name=id_data.name, dtype=id_data.dtype)
else:
mdata[col] = np.tile(id_data._values, K)
mdata[col] = np.tile(id_data._values, num_cols_adjusted)
mcolumns = id_vars + var_name + [value_name]
@@ -143,12 +129,12 @@ def melt(
else:
mdata[value_name] = frame._values.ravel("F")
for i, col in enumerate(var_name):
mdata[col] = frame.columns._get_level_values(i).repeat(N)
mdata[col] = frame.columns._get_level_values(i).repeat(num_rows)
result = frame._constructor(mdata, columns=mcolumns)
if not ignore_index:
result.index = tile_compat(frame.index, K)
result.index = tile_compat(frame.index, num_cols_adjusted)
return result