Skip to content

Commit 7f6afcd

Browse files
BUG: concat not preserving extension dtype in MultiIndex levels (pandas-dev#58421)
1 parent 56847c5 commit 7f6afcd

File tree

2 files changed

+88
-6
lines changed

2 files changed

+88
-6
lines changed

pandas/core/reshape/concat.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import numpy as np
1818

19+
import pandas as pd
1920
from pandas._libs import lib
2021
from pandas.util._decorators import set_module
2122
from pandas.util._exceptions import find_stack_level
@@ -47,6 +48,9 @@
4748
)
4849
from pandas.core.internals import concatenate_managers
4950

51+
from pandas.core.dtypes.common import is_extension_array_dtype
52+
53+
5054
if TYPE_CHECKING:
5155
from collections.abc import (
5256
Callable,
@@ -817,9 +821,21 @@ def _get_sample_object(
817821

818822
return objs[0], objs
819823

820-
821824
def _concat_indexes(indexes) -> Index:
822-
return indexes[0].append(indexes[1:])
825+
# try to preserve extension types such as timestamp[pyarrow]
826+
values = []
827+
for idx in indexes:
828+
values.extend(idx._values if hasattr(idx, "_values") else idx)
829+
830+
# use the first index as a sample to infer the desired dtype
831+
sample = indexes[0]
832+
try:
833+
# this helps preserve extension types like timestamp[pyarrow]
834+
arr = pd.array(values, dtype=sample.dtype)
835+
except Exception:
836+
arr = pd.array(values) # fallback
837+
838+
return pd.Index(arr)
823839

824840

825841
def validate_unique_levels(levels: list[Index]) -> None:
@@ -876,14 +892,33 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde
876892

877893
concat_index = _concat_indexes(indexes)
878894

879-
# these go at the end
880895
if isinstance(concat_index, MultiIndex):
881896
levels.extend(concat_index.levels)
882897
codes_list.extend(concat_index.codes)
883898
else:
884-
codes, categories = factorize_from_iterable(concat_index)
885-
levels.append(categories)
886-
codes_list.append(codes)
899+
# handle the case where the resulting index is a flat Index
900+
# but contains tuples (i.e., a collapsed MultiIndex)
901+
if isinstance(concat_index[0], tuple):
902+
# retrieve the original dtypes
903+
original_dtypes = [lvl.dtype for lvl in indexes[0].levels]
904+
905+
unzipped = list(zip(*concat_index))
906+
for i, level_values in enumerate(unzipped):
907+
# reconstruct each level using original dtype
908+
arr = pd.array(level_values, dtype=original_dtypes[i])
909+
level_codes, _ = factorize_from_iterable(arr)
910+
levels.append(ensure_index(arr))
911+
codes_list.append(level_codes)
912+
else:
913+
# simple indexes factorize directly
914+
codes, categories = factorize_from_iterable(concat_index)
915+
values = getattr(concat_index, "_values", concat_index)
916+
if is_extension_array_dtype(values):
917+
levels.append(values)
918+
else:
919+
levels.append(categories)
920+
codes_list.append(codes)
921+
887922

888923
if len(names) == len(levels):
889924
names = list(names)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pandas as pd
2+
import pytest
3+
4+
5+
schema = {
6+
"id": "int64[pyarrow]",
7+
"time": "timestamp[s][pyarrow]",
8+
"value": "float[pyarrow]",
9+
}
10+
11+
@pytest.mark.parametrize("dtype", ["timestamp[s][pyarrow]"])
12+
def test_concat_preserves_pyarrow_timestamp(dtype):
13+
dfA = (
14+
pd.DataFrame(
15+
[
16+
(0, "2021-01-01 00:00:00", 5.3),
17+
(1, "2021-01-01 00:01:00", 5.4),
18+
(2, "2021-01-01 00:01:00", 5.4),
19+
(3, "2021-01-01 00:02:00", 5.5),
20+
],
21+
columns=schema,
22+
)
23+
.astype(schema)
24+
.set_index(["id", "time"])
25+
)
26+
27+
dfB = (
28+
pd.DataFrame(
29+
[
30+
(1, "2022-01-01 08:00:00", 6.3),
31+
(2, "2022-01-01 08:01:00", 6.4),
32+
(3, "2022-01-01 08:02:00", 6.5),
33+
],
34+
columns=schema,
35+
)
36+
.astype(schema)
37+
.set_index(["id", "time"])
38+
)
39+
40+
df = pd.concat([dfA, dfB], keys=[0, 1], names=["run"])
41+
42+
# chech whether df.index is multiIndex
43+
assert isinstance(df.index, pd.MultiIndex), f"Expected MultiIndex, but received {type(df.index)}"
44+
45+
# Verifying special dtype timestamp[s][pyarrow] stays intact after concat
46+
assert df.index.levels[2].dtype == dtype, f"Expected {dtype}, but received {df.index.levels[2].dtype}"
47+

0 commit comments

Comments
 (0)