|
10 | 10 | """
|
11 | 11 |
|
12 | 12 | from __future__ import annotations
|
13 | | -
| 13 | +from pandas.core.dtypes.common import (
| 14 | +    is_datetime64_dtype,
| 15 | +    is_list_like,
| 16 | +    is_scalar,
| 17 | +)
| 18 | +from pandas.core.dtypes.missing import isna
14 | 19 | import collections
|
15 | 20 | from collections import abc
|
16 | 21 | from collections.abc import (
|
@@ -9881,52 +9886,86 @@ def explode(
|
9881 | 9886 | 3 3 1 d
|
9882 | 9887 | 3 4 1 e
|
9883 | 9888 | """
|
9884 | | -        if not self.columns.is_unique:
9885 | | -            duplicate_cols = self.columns[self.columns.duplicated()].tolist()
9886 | | -            raise ValueError(
9887 | | -                f"DataFrame columns must be unique. Duplicate columns: {duplicate_cols}"
9888 | | -            )
| 9889 | +        df = self.reset_index(drop=True)
9889 | 9890 |
|
9890 | 9891 | columns: list[Hashable]
|
9891 | 9892 | if is_scalar(column) or isinstance(column, tuple):
|
9892 | 9893 | columns = [column]
|
9893 | | -        elif isinstance(column, list) and all(
9894 | | -            is_scalar(c) or isinstance(c, tuple) for c in column
9895 | | -        ):
| 9894 | +        elif isinstance(column, list) and all(is_scalar(c) or isinstance(c, tuple) for c in column):
9896 | 9895 | if not column:
|
9897 | 9896 | raise ValueError("column must be nonempty")
|
9898 | 9897 | if len(column) > len(set(column)):
|
9899 | 9898 | raise ValueError("column must be unique")
|
9900 | 9899 | columns = column
|
9901 | 9900 | else:
|
9902 | 9901 | raise ValueError("column must be a scalar, tuple, or list thereof")
|
9903 | | -
9904 | | -        df = self.reset_index(drop=True)
9905 | 9902 | if len(columns) == 1:
|
9906 | | -            result = df[columns[0]].explode()
9907 | | -            orig_dtype = df[columns[0]].dtype
9908 | | -            if pd.api.types.is_datetime64_dtype(orig_dtype):
9909 | | -                result = result.astype(orig_dtype)
| 9903 | +            col = columns[0]
| 9904 | +            orig_dtype = df[col].dtype
| 9905 | +            # NaN placeholder that matches the column's dtype
| 9906 | +            na_value = np.datetime64("NaT") if is_datetime64_dtype(orig_dtype) else np.nan
| 9907 | +            exploded_values = []
| 9908 | +            exploded_index = []
| 9909 | +            for i, val in enumerate(df[col]):
| 9910 | +                if is_list_like(val) and not isinstance(val, (str, bytes)):
| 9911 | +                    # Empty list-likes explode to a single NaN row
| 9912 | +                    items = list(val) or [na_value]
| 9913 | +                    exploded_values.extend(items)
| 9914 | +                    exploded_index.extend([i] * len(items))
| 9915 | +                elif isna(val):
| 9916 | +                    exploded_values.append(na_value)
| 9917 | +                    exploded_index.append(i)
| 9918 | +                else:
| 9919 | +                    exploded_values.append(val)
| 9920 | +                    exploded_index.append(i)
| 9921 | +            exploded_series = Series(
| 9922 | +                np.array(exploded_values, dtype=orig_dtype if is_datetime64_dtype(orig_dtype) else None),
| 9923 | +                index=exploded_index,
| 9924 | +                name=col,
| 9925 | +            )
| 9926 | +            # Rebuild: repeat the remaining columns, then attach the exploded one
| 9927 | +            result = df.drop(columns, axis=1).iloc[exploded_index].copy()
| 9928 | +            result[col] = exploded_series.values
9910 | 9929 | else:
|
9911 | 9930 | mylen = lambda x: len(x) if (is_list_like(x) and len(x) > 0) else 1
|
9912 | 9931 | counts0 = self[columns[0]].apply(mylen)
|
9913 | 9932 | for c in columns[1:]:
|
9914 | 9933 | if not all(counts0 == self[c].apply(mylen)):
|
9915 | 9934 | raise ValueError("columns must have matching element counts")
|
9916 | | -            result_data = {}
9917 | | -            for c in columns:
9918 | | -                exploded_series = df[c].explode()
9919 | | -                orig_dtype = df[c].dtype
9920 | | -                if pd.api.types.is_datetime64_dtype(orig_dtype):
9921 | | -                    exploded_series = exploded_series.astype(orig_dtype)
9922 | | -                result_data[c] = exploded_series
9923 | | -            result = DataFrame(result_data)
9924 | | -
9925 | | -        result = df.drop(columns, axis=1).join(result)
| 9935 | +
| 9936 | +            # Positional index after explosion: row i repeats once per element
| 9937 | +            exploded_index = []
| 9938 | +            for i in range(len(df)):
| 9939 | +                for _ in range(mylen(df[columns[0]].iloc[i])):
| 9940 | +                    exploded_index.append(i)
| 9941 | +
| 9942 | +            exploded_columns = {}
| 9943 | +            for col in columns:
| 9944 | +                orig_dtype = df[col].dtype
| 9945 | +                na_value = np.datetime64("NaT") if is_datetime64_dtype(orig_dtype) else np.nan
| 9946 | +                values = []
| 9947 | +                for val in df[col]:
| 9948 | +                    if is_list_like(val) and not isinstance(val, (str, bytes)):
| 9949 | +                        # Empty list-likes explode to a single NaN row
| 9950 | +                        values.extend(list(val) or [na_value])
| 9951 | +                    elif isna(val):
| 9952 | +                        values.append(na_value)
| 9953 | +                    else:
| 9954 | +                        values.append(val)
| 9955 | +                exploded_columns[col] = Series(
| 9956 | +                    np.array(values, dtype=orig_dtype if is_datetime64_dtype(orig_dtype) else None),
| 9957 | +                    index=exploded_index,
| 9958 | +                )
| 9959 | +
| 9960 | +            result = df.drop(columns, axis=1).iloc[exploded_index].copy()
| 9961 | +            for col in columns:
| 9962 | +                result[col] = exploded_columns[col].values
| 9963 | +        # Restore the caller's index labels, or reset them when ignore_index=True
9926 | 9964 | if ignore_index:
|
9927 | 9965 | result.index = default_index(len(result))
|
9928 | 9966 | else:
|
9929 | 9967 | result.index = self.index.take(result.index)
|
| 9968 | +
9930 | 9969 | result = result.reindex(columns=self.columns)
|
9931 | 9970 |
|
9932 | 9971 | return result.__finalize__(self, method="explode")
|
|
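For reference, a minimal sketch of the behavior this patch is expected to preserve, written against the public API; the frame and column names below are illustrative, not taken from the diff:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({"A": [[0, 1, 2], [], np.nan, [3, 4]], "B": 1})

    # Each list element becomes its own row; empty lists and missing
    # values survive as a single NaN row, mirroring Series.explode.
    print(df.explode("A"))
    #      A  B
    # 0    0  1
    # 0    1  1
    # 0    2  1
    # 1  NaN  1
    # 2  NaN  1
    # 3    3  1
    # 3    4  1

The datetime-specific branches in the new code appear intended to keep a datetime64 column from decaying to object dtype after the explode, hence the np.datetime64("NaT") placeholder and the dtype-preserving np.array construction.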