From e8eabea8c066037eed96d6eddb74937db1a7c930 Mon Sep 17 00:00:00 2001 From: Afonso Antunes Date: Fri, 28 Mar 2025 11:53:59 +0100 Subject: [PATCH 1/2] Fix #58421: Index[timestamp[pyarrow]].union with itself return object type --- pandas/core/reshape/concat.py | 45 ++++++++++++++++-- .../frame/methods/test_concat_arrow_index.py | 46 +++++++++++++++++++ 2 files changed, 86 insertions(+), 5 deletions(-) create mode 100644 pandas/tests/frame/methods/test_concat_arrow_index.py diff --git a/pandas/core/reshape/concat.py b/pandas/core/reshape/concat.py index e7cb7069bbc26..d5079fdb1523f 100644 --- a/pandas/core/reshape/concat.py +++ b/pandas/core/reshape/concat.py @@ -47,6 +47,10 @@ ) from pandas.core.internals import concatenate_managers +from pandas.core.dtypes.common import is_extension_array_dtype + +from pandas.core.construction import array + if TYPE_CHECKING: from collections.abc import ( Callable, @@ -819,7 +823,20 @@ def _get_sample_object( def _concat_indexes(indexes) -> Index: - return indexes[0].append(indexes[1:]) + # try to preserve extension types such as timestamp[pyarrow] + values = [] + for idx in indexes: + values.extend(idx._values if hasattr(idx, "_values") else idx) + + # use the first index as a sample to infer the desired dtype + sample = indexes[0] + try: + # this helps preserve extension types like timestamp[pyarrow] + arr = array(values, dtype=sample.dtype) + except Exception: + arr = array(values) # fallback + + return Index(arr) def validate_unique_levels(levels: list[Index]) -> None: @@ -876,14 +893,32 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde concat_index = _concat_indexes(indexes) - # these go at the end if isinstance(concat_index, MultiIndex): levels.extend(concat_index.levels) codes_list.extend(concat_index.codes) else: - codes, categories = factorize_from_iterable(concat_index) - levels.append(categories) - codes_list.append(codes) + # handle the case where the resulting index is a flat Index + # 
but contains tuples (i.e., a collapsed MultiIndex) + if isinstance(concat_index[0], tuple): + # retrieve the original dtypes + original_dtypes = [lvl.dtype for lvl in indexes[0].levels] + + unzipped = list(zip(*concat_index)) + for i, level_values in enumerate(unzipped): + # reconstruct each level using original dtype + arr = array(level_values, dtype=original_dtypes[i]) + level_codes, _ = factorize_from_iterable(arr) + levels.append(ensure_index(arr)) + codes_list.append(level_codes) + else: + # simple indexes factorize directly + codes, categories = factorize_from_iterable(concat_index) + values = getattr(concat_index, "_values", concat_index) + if is_extension_array_dtype(values): + levels.append(values) + else: + levels.append(categories) + codes_list.append(codes) if len(names) == len(levels): names = list(names) diff --git a/pandas/tests/frame/methods/test_concat_arrow_index.py b/pandas/tests/frame/methods/test_concat_arrow_index.py new file mode 100644 index 0000000000000..a341f1477326a --- /dev/null +++ b/pandas/tests/frame/methods/test_concat_arrow_index.py @@ -0,0 +1,46 @@ +import pandas as pd +import pytest + + +schema = { + "id": "int64[pyarrow]", + "time": "timestamp[s][pyarrow]", + "value": "float[pyarrow]", +} + +@pytest.mark.parametrize("dtype", ["timestamp[s][pyarrow]"]) +def test_concat_preserves_pyarrow_timestamp(dtype): + dfA = ( + pd.DataFrame( + [ + (0, "2021-01-01 00:00:00", 5.3), + (1, "2021-01-01 00:01:00", 5.4), + (2, "2021-01-01 00:01:00", 5.4), + (3, "2021-01-01 00:02:00", 5.5), + ], + columns=schema, + ) + .astype(schema) + .set_index(["id", "time"]) + ) + + dfB = ( + pd.DataFrame( + [ + (1, "2022-01-01 08:00:00", 6.3), + (2, "2022-01-01 08:01:00", 6.4), + (3, "2022-01-01 08:02:00", 6.5), + ], + columns=schema, + ) + .astype(schema) + .set_index(["id", "time"]) + ) + + df = pd.concat([dfA, dfB], keys=[0, 1], names=["run"]) + + # chech whether df.index is multiIndex + assert isinstance(df.index, pd.MultiIndex), f"Expected MultiIndex, but 
received {type(df.index)}" + + # Verifying special dtype timestamp[s][pyarrow] stays intact after concat + assert df.index.levels[2].dtype == dtype, f"Expected {dtype}, but received {df.index.levels[2].dtype}" From d4d095f305d86e0fcd30af7f1c8513703d63dd0e Mon Sep 17 00:00:00 2001 From: Afonso Antunes Date: Wed, 2 Apr 2025 20:06:09 +0100 Subject: [PATCH 2/2] Now complies with pre-commit requirements and added an entry in v3.0.0.rst --- doc/source/whatsnew/v3.0.0.rst | 1 + pandas/core/reshape/concat.py | 12 +++++------- .../frame/methods/test_concat_arrow_index.py | 17 +++++++++++------ 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/doc/source/whatsnew/v3.0.0.rst b/doc/source/whatsnew/v3.0.0.rst index 4a6cf117fd196..d875b7cc8dd76 100644 --- a/doc/source/whatsnew/v3.0.0.rst +++ b/doc/source/whatsnew/v3.0.0.rst @@ -712,6 +712,7 @@ MultiIndex - :func:`MultiIndex.get_level_values` accessing a :class:`DatetimeIndex` does not carry the frequency attribute along (:issue:`58327`, :issue:`57949`) - Bug in :class:`DataFrame` arithmetic operations in case of unaligned MultiIndex columns (:issue:`60498`) - Bug in :class:`DataFrame` arithmetic operations with :class:`Series` in case of unaligned MultiIndex (:issue:`61009`) +- Bug in :func:`concat` with ``keys`` where :class:`MultiIndex` levels with extension dtypes such as ``timestamp[pyarrow]`` were silently coerced to ``object`` instead of preserving their original dtype (:issue:`58421`) - I/O diff --git a/pandas/core/reshape/concat.py b/pandas/core/reshape/concat.py index d5079fdb1523f..776e7a482d7fe 100644 --- a/pandas/core/reshape/concat.py +++ b/pandas/core/reshape/concat.py @@ -22,6 +22,7 @@ from pandas.core.dtypes.common import ( is_bool, + is_extension_array_dtype, is_scalar, ) from pandas.core.dtypes.concat import concat_compat @@ -36,6 +37,7 @@ factorize_from_iterables, ) import pandas.core.common as com +from pandas.core.construction import array as pd_array from pandas.core.indexes.api import ( Index, MultiIndex, @@ -47,10 +49,6 @@ ) 
from pandas.core.internals import concatenate_managers -from pandas.core.dtypes.common import is_extension_array_dtype - -from pandas.core.construction import array - if TYPE_CHECKING: from collections.abc import ( Callable, @@ -832,9 +830,9 @@ def _concat_indexes(indexes) -> Index: sample = indexes[0] try: # this helps preserve extension types like timestamp[pyarrow] - arr = array(values, dtype=sample.dtype) + arr = pd_array(values, dtype=sample.dtype) except Exception: - arr = array(values) # fallback + arr = pd_array(values) # fallback return Index(arr) @@ -906,7 +904,7 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde unzipped = list(zip(*concat_index)) for i, level_values in enumerate(unzipped): # reconstruct each level using original dtype - arr = array(level_values, dtype=original_dtypes[i]) + arr = pd_array(level_values, dtype=original_dtypes[i]) level_codes, _ = factorize_from_iterable(arr) levels.append(ensure_index(arr)) codes_list.append(level_codes) diff --git a/pandas/tests/frame/methods/test_concat_arrow_index.py b/pandas/tests/frame/methods/test_concat_arrow_index.py index a341f1477326a..6fcc5ee5119a6 100644 --- a/pandas/tests/frame/methods/test_concat_arrow_index.py +++ b/pandas/tests/frame/methods/test_concat_arrow_index.py @@ -1,6 +1,6 @@ -import pandas as pd import pytest +import pandas as pd schema = { "id": "int64[pyarrow]", @@ -8,6 +8,7 @@ "value": "float[pyarrow]", } + @pytest.mark.parametrize("dtype", ["timestamp[s][pyarrow]"]) def test_concat_preserves_pyarrow_timestamp(dtype): dfA = ( @@ -38,9 +39,13 @@ def test_concat_preserves_pyarrow_timestamp(dtype): ) df = pd.concat([dfA, dfB], keys=[0, 1], names=["run"]) - - # chech whether df.index is multiIndex - assert isinstance(df.index, pd.MultiIndex), f"Expected MultiIndex, but received {type(df.index)}" - + + # check whether df.index is multiIndex + assert isinstance(df.index, pd.MultiIndex), ( + f"Expected MultiIndex, but received {type(df.index)}" + ) + # 
Verifying special dtype timestamp[s][pyarrow] stays intact after concat - assert df.index.levels[2].dtype == dtype, f"Expected {dtype}, but received {df.index.levels[2].dtype}" + assert df.index.levels[2].dtype == dtype, ( + f"Expected {dtype}, but received {df.index.levels[2].dtype}" + )