From 46828e6c9bfc56bcf2334991d2f975475b695b97 Mon Sep 17 00:00:00 2001 From: "Oriol (ProDesk)" Date: Wed, 13 Mar 2024 03:15:26 +0100 Subject: [PATCH] fix bokeh legend generation --- src/arviz_plots/backend/bokeh/legend.py | 12 +++++++----- src/arviz_plots/plot_collection.py | 3 ++- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/arviz_plots/backend/bokeh/legend.py b/src/arviz_plots/backend/bokeh/legend.py index 5948459..4248c6a 100644 --- a/src/arviz_plots/backend/bokeh/legend.py +++ b/src/arviz_plots/backend/bokeh/legend.py @@ -37,14 +37,16 @@ def legend( artist_kwargs = {} if legend_target is None: legend_target = (0, -1) + # TODO: improve selection of Figure object from what is stored as "chart" children = target.children + if not isinstance(children[0], tuple): + children = children[1].children plots = [child[0] for child in children] row_id = np.array([child[1] for child in children], dtype=int) - col_id = np.array([child[1] for child in children], dtype=int) + col_id = np.array([child[2] for child in children], dtype=int) legend_id = np.argmax( - row_id - == np.unique(row_id)[legend_target[0]] & col_id - == np.unique(col_id)[legend_target[1]] + (row_id == np.unique(row_id)[legend_target[0]]) + & (col_id == np.unique(col_id)[legend_target[1]]) ) target_plot = plots[legend_id] if target_plot.legend: @@ -60,7 +62,7 @@ def legend( glyph = artist_fun(**{**artist_kwargs, **kws}) glyph_list.append(glyph) leg = Legend( - items=[(label, [glyph]) for label, glyph in zip(label_list, glyph_list)], + items=[(str(label), [glyph]) for label, glyph in zip(label_list, glyph_list)], title=title, **kwargs, ) diff --git a/src/arviz_plots/plot_collection.py b/src/arviz_plots/plot_collection.py index b333d3e..d33bec3 100644 --- a/src/arviz_plots/plot_collection.py +++ b/src/arviz_plots/plot_collection.py @@ -681,7 +681,8 @@ def add_legend(self, dim, var_name=None, aes=None, artist_kwargs=None, title=Non ) aes_ds = aes_ds.drop_dims([d for d in aes_ds.dims if d != dim]) if aes is None: - aes_ds = aes_ds.drop_vars(("y", "x"), errors="ignore") + dropped_vars = ["x", "y"] + [name for name, da in aes_ds.items() if dim not in da.dims] + aes_ds = aes_ds.drop_vars(dropped_vars, errors="ignore") else: if isinstance(aes, str): aes = [aes]