Skip to content

Commit f83921f

Browse files
Merge pull request #3765 from plotly/one_group_short_circuit
PX: Avoid `groupby` when possible and access groups more efficiently
2 parents a4b9887 + 4b22199 commit f83921f

File tree

2 files changed

+49
-20
lines changed

2 files changed

+49
-20
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ This project adheres to [Semantic Versioning](http://semver.org/).
1010
- `pattern_shape` options now available in `px.timeline()` [#3774](https://github.com/plotly/plotly.py/pull/3774)
1111
- `facet_*` and `category_orders` now available in `px.pie()` [#3775](https://github.com/plotly/plotly.py/pull/3775)
1212

13+
### Performance
14+
15+
- `px` methods no longer call `groupby` on the input dataframe when the result would be a single group, and no longer group by a lambda, for significant speedups [#3765](https://github.com/plotly/plotly.py/pull/3765)
16+
1317
### Updated
1418

1519
- Allow non-string extras in `flaglist` attributes, to support upcoming changes to `ax.automargin` in plotly.js [plotly.js#6193](https://github.com/plotly/plotly.js/pull/6193), [#3749](https://github.com/plotly/plotly.py/pull/3749)

packages/python/plotly/plotly/express/_core.py

+45-20
Original file line numberDiff line numberDiff line change
@@ -1920,40 +1920,66 @@ def infer_config(args, constructor, trace_patch, layout_patch):
19201920
return trace_specs, grouped_mappings, sizeref, show_colorbar
19211921

19221922

1923-
def get_orderings(args, grouper, grouped):
1923+
def get_groups_and_orders(args, grouper):
19241924
"""
19251925
`orders` is the user-supplied ordering with the remaining data-frame-supplied
19261926
ordering appended if the column is used for grouping. It includes anything the user
19271927
gave, for any variable, including values not present in the dataset. It's a dict
19281928
where the keys are e.g. "x" or "color"
19291929
1930-
`sorted_group_names` is the set of groups, ordered by the order above. It's a list
1931-
of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
1930+
`groups` is a dict of groups, ordered by the order above. Its keys are
1931+
tuples like ("value1", "") or ("value2", "") where each tuple contains the name
19321932
of a single dimension-group
19331933
"""
1934-
19351934
orders = {} if "category_orders" not in args else args["category_orders"].copy()
1935+
1936+
# figure out orders and what the single group name would be if there were one
1937+
single_group_name = []
1938+
unique_cache = dict()
19361939
for col in grouper:
1937-
if col != one_group:
1938-
uniques = list(args["data_frame"][col].unique())
1940+
if col == one_group:
1941+
single_group_name.append("")
1942+
else:
1943+
if col not in unique_cache:
1944+
unique_cache[col] = list(args["data_frame"][col].unique())
1945+
uniques = unique_cache[col]
1946+
if len(uniques) == 1:
1947+
single_group_name.append(uniques[0])
19391948
if col not in orders:
19401949
orders[col] = uniques
19411950
else:
19421951
orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques))
1952+
df = args["data_frame"]
1953+
if len(single_group_name) == len(grouper):
1954+
# we have a single group, so we can skip all group-by operations!
1955+
groups = {tuple(single_group_name): df}
1956+
else:
1957+
required_grouper = [g for g in grouper if g != one_group]
1958+
grouped = df.groupby(required_grouper, sort=False) # skip one_group groupers
1959+
group_indices = grouped.indices
1960+
sorted_group_names = [
1961+
g if len(required_grouper) != 1 else (g,) for g in group_indices
1962+
]
19431963

1944-
sorted_group_names = []
1945-
for group_name in grouped.groups:
1946-
if len(grouper) == 1:
1947-
group_name = (group_name,)
1948-
sorted_group_names.append(group_name)
1949-
1950-
for i, col in reversed(list(enumerate(grouper))):
1951-
if col != one_group:
1964+
for i, col in reversed(list(enumerate(required_grouper))):
19521965
sorted_group_names = sorted(
19531966
sorted_group_names,
19541967
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
19551968
)
1956-
return orders, sorted_group_names
1969+
1970+
# calculate the full group_names by inserting "" in the tuple index for one_group groups
1971+
full_sorted_group_names = [list(t) for t in sorted_group_names]
1972+
for i, col in enumerate(grouper):
1973+
if col == one_group:
1974+
for g in full_sorted_group_names:
1975+
g.insert(i, "")
1976+
full_sorted_group_names = [tuple(g) for g in full_sorted_group_names]
1977+
1978+
groups = {
1979+
sf: grouped.get_group(s if len(s) > 1 else s[0])
1980+
for sf, s in zip(full_sorted_group_names, sorted_group_names)
1981+
}
1982+
return groups, orders
19571983

19581984

19591985
def make_figure(args, constructor, trace_patch=None, layout_patch=None):
@@ -1974,9 +2000,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19742000
args, constructor, trace_patch, layout_patch
19752001
)
19762002
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
1977-
grouped = args["data_frame"].groupby(grouper, sort=False)
1978-
1979-
orders, sorted_group_names = get_orderings(args, grouper, grouped)
2003+
groups, orders = get_groups_and_orders(args, grouper)
19802004

19812005
col_labels = []
19822006
row_labels = []
@@ -2005,8 +2029,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
20052029
trendline_rows = []
20062030
trace_name_labels = None
20072031
facet_col_wrap = args.get("facet_col_wrap", 0)
2008-
for group_name in sorted_group_names:
2009-
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
2032+
for group_name, group in groups.items():
20102033
mapping_labels = OrderedDict()
20112034
trace_name_labels = OrderedDict()
20122035
frame_name = ""
@@ -2224,6 +2247,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
22242247
fig.update_layout(layout_patch)
22252248
if "template" in args and args["template"] is not None:
22262249
fig.update_layout(template=args["template"], overwrite=True)
2250+
for f in frame_list:
2251+
f["name"] = str(f["name"])
22272252
fig.frames = frame_list if len(frames) > 1 else []
22282253

22292254
if args.get("trendline") and args.get("trendline_scope", "trace") == "overall":

0 commit comments

Comments
 (0)