Skip to content

Commit 51746a4

Browse files
Fix explode to preserve datetime unit without type loss
1 parent 401a42a commit 51746a4

File tree

1 file changed

+64
-25
lines changed

1 file changed

+64
-25
lines changed

pandas/core/frame.py

Lines changed: 64 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,12 @@
1010
"""
1111

1212
from __future__ import annotations
13-
13+
from pandas.core.dtypes.common import (
14+
is_list_like,
15+
is_scalar,
16+
is_datetime64_dtype,
17+
isna,
18+
)
1419
import collections
1520
from collections import abc
1621
from collections.abc import (
@@ -9881,52 +9886,86 @@ def explode(
98819886
3 3 1 d
98829887
3 4 1 e
98839888
"""
9884-
if not self.columns.is_unique:
9885-
duplicate_cols = self.columns[self.columns.duplicated()].tolist()
9886-
raise ValueError(
9887-
f"DataFrame columns must be unique. Duplicate columns: {duplicate_cols}"
9888-
)
9889+
df = self.reset_index(drop=True)
98899890

98909891
columns: list[Hashable]
98919892
if is_scalar(column) or isinstance(column, tuple):
98929893
columns = [column]
9893-
elif isinstance(column, list) and all(
9894-
is_scalar(c) or isinstance(c, tuple) for c in column
9895-
):
9894+
elif isinstance(column, list) and all(is_scalar(c) or isinstance(c, tuple) for c in column):
98969895
if not column:
98979896
raise ValueError("column must be nonempty")
98989897
if len(column) > len(set(column)):
98999898
raise ValueError("column must be unique")
99009899
columns = column
99019900
else:
99029901
raise ValueError("column must be a scalar, tuple, or list thereof")
9903-
9904-
df = self.reset_index(drop=True)
99059902
if len(columns) == 1:
9906-
result = df[columns[0]].explode()
9907-
orig_dtype = df[columns[0]].dtype
9908-
if pd.api.types.is_datetime64_dtype(orig_dtype):
9909-
result = result.astype(orig_dtype)
9903+
col = columns[0]
9904+
orig_dtype = df[col].dtype
9905+
9906+
exploded_values = []
9907+
exploded_index = []
9908+
9909+
for i, val in enumerate(df[col]):
9910+
if is_list_like(val) and not isinstance(val, (str, bytes)):
9911+
for item in val:
9912+
exploded_values.append(item)
9913+
exploded_index.append(i)
9914+
elif isna(val):
9915+
exploded_values.append(np.datetime64("NaT") if is_datetime64_dtype(orig_dtype) else np.nan)
9916+
exploded_index.append(i)
9917+
else:
9918+
exploded_values.append(val)
9919+
exploded_index.append(i)
9920+
9921+
exploded_series = pd.Series(
9922+
np.array(exploded_values, dtype=orig_dtype if is_datetime64_dtype(orig_dtype) else None),
9923+
index=exploded_index,
9924+
name=col
9925+
)
9926+
9927+
result = df.drop(columns, axis=1).iloc[exploded_series.index]
9928+
result[col] = exploded_series.values
99109929
else:
99119930
mylen = lambda x: len(x) if (is_list_like(x) and len(x) > 0) else 1
99129931
counts0 = self[columns[0]].apply(mylen)
99139932
for c in columns[1:]:
99149933
if not all(counts0 == self[c].apply(mylen)):
99159934
raise ValueError("columns must have matching element counts")
9916-
result_data = {}
9917-
for c in columns:
9918-
exploded_series = df[c].explode()
9919-
orig_dtype = df[c].dtype
9920-
if pd.api.types.is_datetime64_dtype(orig_dtype):
9921-
exploded_series = exploded_series.astype(orig_dtype)
9922-
result_data[c] = exploded_series
9923-
result = DataFrame(result_data)
9924-
9925-
result = df.drop(columns, axis=1).join(result)
9935+
9936+
exploded_columns = {}
9937+
exploded_index = []
9938+
9939+
for i in range(len(df)):
9940+
row_counts = mylen(df[columns[0]].iloc[i])
9941+
for j in range(row_counts):
9942+
exploded_index.append(i)
9943+
9944+
for col in columns:
9945+
orig_dtype = df[col].dtype
9946+
values = []
9947+
for val in df[col]:
9948+
if is_list_like(val) and not isinstance(val, (str, bytes)):
9949+
values.extend(val)
9950+
elif isna(val):
9951+
values.append(np.datetime64("NaT") if is_datetime64_dtype(orig_dtype) else np.nan)
9952+
else:
9953+
values.append(val)
9954+
exploded_columns[col] = pd.Series(
9955+
np.array(values, dtype=orig_dtype if is_datetime64_dtype(orig_dtype) else None),
9956+
index=exploded_index
9957+
)
9958+
9959+
result = df.drop(columns, axis=1).iloc[exploded_index].copy()
9960+
for col in columns:
9961+
result[col] = exploded_columns[col].values
9962+
9963+
# Handle index
99269964
if ignore_index:
99279965
result.index = default_index(len(result))
99289966
else:
99299967
result.index = self.index.take(result.index)
9968+
99309969
result = result.reindex(columns=self.columns)
99319970

99329971
return result.__finalize__(self, method="explode")

0 commit comments

Comments (0)