Skip to content

Commit e670c4b

Browse files
authored
Merge pull request #4246 from legendof-selda/fix/pd_perf_issue
Fix/pandas Performance Warning Issue due to multiple `frame.insert`
2 parents e430257 + f1ed3d7 commit e670c4b

File tree

3 files changed

+60
-20
lines changed

3 files changed

+60
-20
lines changed

.flake8

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[flake8]
2+
max-line-length = 88

packages/python/plotly/plotly/express/_core.py

+30-20
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,6 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
322322
and args["y"]
323323
and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
324324
):
325-
326325
# sorting is bad but trace_specs with "trendline" have no other attrs
327326
sorted_trace_data = trace_data.sort_values(by=args["x"])
328327
y = sorted_trace_data[args["y"]].values
@@ -563,7 +562,6 @@ def set_cartesian_axis_opts(args, axis, letter, orders):
563562

564563

565564
def configure_cartesian_marginal_axes(args, fig, orders):
566-
567565
if "histogram" in [args["marginal_x"], args["marginal_y"]]:
568566
fig.layout["barmode"] = "overlay"
569567

@@ -1065,14 +1063,14 @@ def _escape_col_name(columns, col_name, extra):
10651063
return col_name
10661064

10671065

1068-
def to_unindexed_series(x):
1066+
def to_unindexed_series(x, name=None):
10691067
"""
10701068
assuming x is list-like or even an existing pd.Series, return a new pd.Series with
10711069
no index, without extracting the data from an existing Series via numpy, which
10721070
seems to mangle datetime columns. Stripping the index from existing pd.Series is
10731071
required to get things to match up right in the new DataFrame we're building
10741072
"""
1075-
return pd.Series(x).reset_index(drop=True)
1073+
return pd.Series(x, name=name).reset_index(drop=True)
10761074

10771075

10781076
def process_args_into_dataframe(args, wide_mode, var_name, value_name):
@@ -1087,9 +1085,12 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
10871085
df_input = args["data_frame"]
10881086
df_provided = df_input is not None
10891087

1090-
df_output = pd.DataFrame()
1091-
constants = dict()
1092-
ranges = list()
1088+
# we use a dict instead of a dataframe directly so that it doesn't cause
1089+
# PerformanceWarning by pandas by repeatedly setting the columns.
1090+
# a dict is used instead of a list as the columns needs to be overwritten.
1091+
df_output = {}
1092+
constants = {}
1093+
ranges = []
10931094
wide_id_vars = set()
10941095
reserved_names = _get_reserved_col_names(args) if df_provided else set()
10951096

@@ -1100,7 +1101,7 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
11001101
"No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument."
11011102
)
11021103
else:
1103-
df_output[df_input.columns] = df_input[df_input.columns]
1104+
df_output = {col: series for col, series in df_input.items()}
11041105

11051106
# hover_data is a dict
11061107
hover_data_is_dict = (
@@ -1141,7 +1142,7 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
11411142
# argument_list and field_list ready, iterate over them
11421143
# Core of the loop starts here
11431144
for i, (argument, field) in enumerate(zip(argument_list, field_list)):
1144-
length = len(df_output)
1145+
length = len(df_output[next(iter(df_output))]) if len(df_output) else 0
11451146
if argument is None:
11461147
continue
11471148
col_name = None
@@ -1182,11 +1183,11 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
11821183
% (
11831184
argument,
11841185
len(real_argument),
1185-
str(list(df_output.columns)),
1186+
str(list(df_output.keys())),
11861187
length,
11871188
)
11881189
)
1189-
df_output[col_name] = to_unindexed_series(real_argument)
1190+
df_output[col_name] = to_unindexed_series(real_argument, col_name)
11901191
elif not df_provided:
11911192
raise ValueError(
11921193
"String or int arguments are only possible when a "
@@ -1215,13 +1216,15 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
12151216
% (
12161217
field,
12171218
len(df_input[argument]),
1218-
str(list(df_output.columns)),
1219+
str(list(df_output.keys())),
12191220
length,
12201221
)
12211222
)
12221223
else:
12231224
col_name = str(argument)
1224-
df_output[col_name] = to_unindexed_series(df_input[argument])
1225+
df_output[col_name] = to_unindexed_series(
1226+
df_input[argument], col_name
1227+
)
12251228
# ----------------- argument is likely a column / array / list.... -------
12261229
else:
12271230
if df_provided and hasattr(argument, "name"):
@@ -1248,9 +1251,9 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
12481251
"All arguments should have the same length. "
12491252
"The length of argument `%s` is %d, whereas the "
12501253
"length of previously-processed arguments %s is %d"
1251-
% (field, len(argument), str(list(df_output.columns)), length)
1254+
% (field, len(argument), str(list(df_output.keys())), length)
12521255
)
1253-
df_output[str(col_name)] = to_unindexed_series(argument)
1256+
df_output[str(col_name)] = to_unindexed_series(argument, str(col_name))
12541257

12551258
# Finally, update argument with column name now that column exists
12561259
assert col_name is not None, (
@@ -1268,12 +1271,19 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name):
12681271
if field_name != "wide_variable":
12691272
wide_id_vars.add(str(col_name))
12701273

1271-
for col_name in ranges:
1272-
df_output[col_name] = range(len(df_output))
1273-
1274-
for col_name in constants:
1275-
df_output[col_name] = constants[col_name]
1274+
length = len(df_output[next(iter(df_output))]) if len(df_output) else 0
1275+
df_output.update(
1276+
{col_name: to_unindexed_series(range(length), col_name) for col_name in ranges}
1277+
)
1278+
df_output.update(
1279+
{
1280+
# constant is single value. repeat by len to avoid creating NaN on concating
1281+
col_name: to_unindexed_series([constants[col_name]] * length, col_name)
1282+
for col_name in constants
1283+
}
1284+
)
12761285

1286+
df_output = pd.DataFrame(df_output)
12771287
return df_output, wide_id_vars
12781288

12791289

packages/python/plotly/plotly/tests/test_optional/test_px/test_px_wide.py

+28
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import plotly.express as px
22
import plotly.graph_objects as go
33
import pandas as pd
4+
import numpy as np
45
from plotly.express._core import build_dataframe, _is_col_list
56
from pandas.testing import assert_frame_equal
67
import pytest
8+
import warnings
79

810

911
def test_is_col_list():
@@ -847,3 +849,29 @@ def test_line_group():
847849
assert len(fig.data) == 4
848850
fig = px.scatter(df, x="x", y=["miss", "score"], color="who")
849851
assert len(fig.data) == 2
852+
853+
854+
def test_no_pd_perf_warning():
855+
n_cols = 1000
856+
n_rows = 1000
857+
858+
columns = list(f"col_{c}" for c in range(n_cols))
859+
index = list(f"i_{r}" for r in range(n_rows))
860+
861+
df = pd.DataFrame(
862+
np.random.uniform(size=(n_rows, n_cols)), index=index, columns=columns
863+
)
864+
865+
with warnings.catch_warnings(record=True) as warn_list:
866+
_ = px.bar(
867+
df,
868+
x=df.index,
869+
y=df.columns[:-2],
870+
labels=df.columns[:-2],
871+
)
872+
performance_warnings = [
873+
warn
874+
for warn in warn_list
875+
if issubclass(warn.category, pd.errors.PerformanceWarning)
876+
]
877+
assert len(performance_warnings) == 0, "PerformanceWarning(s) raised!"

0 commit comments

Comments
 (0)