|
16 | 16 |
|
17 | 17 | import numpy as np
|
18 | 18 |
|
| 19 | +import pandas as pd |
19 | 20 | from pandas._libs import lib
|
20 | 21 | from pandas.util._decorators import set_module
|
21 | 22 | from pandas.util._exceptions import find_stack_level
|
|
47 | 48 | )
|
48 | 49 | from pandas.core.internals import concatenate_managers
|
49 | 50 |
|
| 51 | +from pandas.core.dtypes.common import is_extension_array_dtype |
| 52 | + |
| 53 | + |
50 | 54 | if TYPE_CHECKING:
|
51 | 55 | from collections.abc import (
|
52 | 56 | Callable,
|
@@ -817,9 +821,21 @@ def _get_sample_object(
|
817 | 821 |
|
818 | 822 | return objs[0], objs
|
819 | 823 |
|
820 |
| - |
821 | 824 | def _concat_indexes(indexes) -> Index:
|
822 |
| - return indexes[0].append(indexes[1:]) |
| 825 | + # try to preserve extension types such as timestamp[pyarrow] |
| 826 | + values = [] |
| 827 | + for idx in indexes: |
| 828 | + values.extend(idx._values if hasattr(idx, "_values") else idx) |
| 829 | + |
| 830 | + # use the first index as a sample to infer the desired dtype |
| 831 | + sample = indexes[0] |
| 832 | + try: |
| 833 | + # this helps preserve extension types like timestamp[pyarrow] |
| 834 | + arr = pd.array(values, dtype=sample.dtype) |
| 835 | + except Exception: |
| 836 | + arr = pd.array(values) # fallback |
| 837 | + |
| 838 | + return pd.Index(arr) |
823 | 839 |
|
824 | 840 |
|
825 | 841 | def validate_unique_levels(levels: list[Index]) -> None:
|
@@ -876,14 +892,33 @@ def _make_concat_multiindex(indexes, keys, levels=None, names=None) -> MultiInde
|
876 | 892 |
|
877 | 893 | concat_index = _concat_indexes(indexes)
|
878 | 894 |
|
879 |
| - # these go at the end |
880 | 895 | if isinstance(concat_index, MultiIndex):
|
881 | 896 | levels.extend(concat_index.levels)
|
882 | 897 | codes_list.extend(concat_index.codes)
|
883 | 898 | else:
|
884 |
| - codes, categories = factorize_from_iterable(concat_index) |
885 |
| - levels.append(categories) |
886 |
| - codes_list.append(codes) |
| 899 | + # handle the case where the resulting index is a flat Index |
| 900 | + # but contains tuples (i.e., a collapsed MultiIndex) |
| 901 | + if isinstance(concat_index[0], tuple): |
| 902 | + # retrieve the original dtypes |
| 903 | + original_dtypes = [lvl.dtype for lvl in indexes[0].levels] |
| 904 | + |
| 905 | + unzipped = list(zip(*concat_index)) |
| 906 | + for i, level_values in enumerate(unzipped): |
| 907 | + # reconstruct each level using original dtype |
| 908 | + arr = pd.array(level_values, dtype=original_dtypes[i]) |
| 909 | + level_codes, _ = factorize_from_iterable(arr) |
| 910 | + levels.append(ensure_index(arr)) |
| 911 | + codes_list.append(level_codes) |
| 912 | + else: |
| 913 | + # simple indexes factorize directly |
| 914 | + codes, categories = factorize_from_iterable(concat_index) |
| 915 | + values = getattr(concat_index, "_values", concat_index) |
| 916 | + if is_extension_array_dtype(values): |
| 917 | + levels.append(values) |
| 918 | + else: |
| 919 | + levels.append(categories) |
| 920 | + codes_list.append(codes) |
| 921 | + |
887 | 922 |
|
888 | 923 | if len(names) == len(levels):
|
889 | 924 | names = list(names)
|
|
0 commit comments