@@ -1920,40 +1920,66 @@ def infer_config(args, constructor, trace_patch, layout_patch):
1920
1920
return trace_specs , grouped_mappings , sizeref , show_colorbar
1921
1921
1922
1922
1923
- def get_orderings (args , grouper , grouped ):
1923
+ def get_groups_and_orders (args , grouper ):
1924
1924
"""
1925
1925
`orders` is the user-supplied ordering with the remaining data-frame-supplied
1926
1926
ordering appended if the column is used for grouping. It includes anything the user
1927
1927
gave, for any variable, including values not present in the dataset. It's a dict
1928
1928
where the keys are e.g. "x" or "color"
1929
1929
1930
- `sorted_group_names ` is the set of groups, ordered by the order above. It's a list
1931
- of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
1930
+ `groups ` is the dicts of groups, ordered by the order above. Its keys are
1931
+ tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
1932
1932
of a single dimension-group
1933
1933
"""
1934
-
1935
1934
orders = {} if "category_orders" not in args else args ["category_orders" ].copy ()
1935
+
1936
+ # figure out orders and what the single group name would be if there were one
1937
+ single_group_name = []
1938
+ unique_cache = dict ()
1936
1939
for col in grouper :
1937
- if col != one_group :
1938
- uniques = list (args ["data_frame" ][col ].unique ())
1940
+ if col == one_group :
1941
+ single_group_name .append ("" )
1942
+ else :
1943
+ if col not in unique_cache :
1944
+ unique_cache [col ] = list (args ["data_frame" ][col ].unique ())
1945
+ uniques = unique_cache [col ]
1946
+ if len (uniques ) == 1 :
1947
+ single_group_name .append (uniques [0 ])
1939
1948
if col not in orders :
1940
1949
orders [col ] = uniques
1941
1950
else :
1942
1951
orders [col ] = list (OrderedDict .fromkeys (list (orders [col ]) + uniques ))
1952
+ df = args ["data_frame" ]
1953
+ if len (single_group_name ) == len (grouper ):
1954
+ # we have a single group, so we can skip all group-by operations!
1955
+ groups = {tuple (single_group_name ): df }
1956
+ else :
1957
+ required_grouper = [g for g in grouper if g != one_group ]
1958
+ grouped = df .groupby (required_grouper , sort = False ) # skip one_group groupers
1959
+ group_indices = grouped .indices
1960
+ sorted_group_names = [
1961
+ g if len (required_grouper ) != 1 else (g ,) for g in group_indices
1962
+ ]
1943
1963
1944
- sorted_group_names = []
1945
- for group_name in grouped .groups :
1946
- if len (grouper ) == 1 :
1947
- group_name = (group_name ,)
1948
- sorted_group_names .append (group_name )
1949
-
1950
- for i , col in reversed (list (enumerate (grouper ))):
1951
- if col != one_group :
1964
+ for i , col in reversed (list (enumerate (required_grouper ))):
1952
1965
sorted_group_names = sorted (
1953
1966
sorted_group_names ,
1954
1967
key = lambda g : orders [col ].index (g [i ]) if g [i ] in orders [col ] else - 1 ,
1955
1968
)
1956
- return orders , sorted_group_names
1969
+
1970
+ # calculate the full group_names by inserting "" in the tuple index for one_group groups
1971
+ full_sorted_group_names = [list (t ) for t in sorted_group_names ]
1972
+ for i , col in enumerate (grouper ):
1973
+ if col == one_group :
1974
+ for g in full_sorted_group_names :
1975
+ g .insert (i , "" )
1976
+ full_sorted_group_names = [tuple (g ) for g in full_sorted_group_names ]
1977
+
1978
+ groups = {
1979
+ sf : grouped .get_group (s if len (s ) > 1 else s [0 ])
1980
+ for sf , s in zip (full_sorted_group_names , sorted_group_names )
1981
+ }
1982
+ return groups , orders
1957
1983
1958
1984
1959
1985
def make_figure (args , constructor , trace_patch = None , layout_patch = None ):
@@ -1974,9 +2000,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
1974
2000
args , constructor , trace_patch , layout_patch
1975
2001
)
1976
2002
grouper = [x .grouper or one_group for x in grouped_mappings ] or [one_group ]
1977
- grouped = args ["data_frame" ].groupby (grouper , sort = False )
1978
-
1979
- orders , sorted_group_names = get_orderings (args , grouper , grouped )
2003
+ groups , orders = get_groups_and_orders (args , grouper )
1980
2004
1981
2005
col_labels = []
1982
2006
row_labels = []
@@ -2005,8 +2029,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
2005
2029
trendline_rows = []
2006
2030
trace_name_labels = None
2007
2031
facet_col_wrap = args .get ("facet_col_wrap" , 0 )
2008
- for group_name in sorted_group_names :
2009
- group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
2032
+ for group_name , group in groups .items ():
2010
2033
mapping_labels = OrderedDict ()
2011
2034
trace_name_labels = OrderedDict ()
2012
2035
frame_name = ""
@@ -2224,6 +2247,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
2224
2247
fig .update_layout (layout_patch )
2225
2248
if "template" in args and args ["template" ] is not None :
2226
2249
fig .update_layout (template = args ["template" ], overwrite = True )
2250
+ for f in frame_list :
2251
+ f ["name" ] = str (f ["name" ])
2227
2252
fig .frames = frame_list if len (frames ) > 1 else []
2228
2253
2229
2254
if args .get ("trendline" ) and args .get ("trendline_scope" , "trace" ) == "overall" :
0 commit comments