Skip to content

Commit

Permalink
fix bokeh legend generation
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Mar 13, 2024
1 parent f624885 commit 46828e6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
12 changes: 7 additions & 5 deletions src/arviz_plots/backend/bokeh/legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,16 @@ def legend(
artist_kwargs = {}
if legend_target is None:
legend_target = (0, -1)
# TODO: improve selection of Figure object from what is stored as "chart"
children = target.children
if not isinstance(children[0], tuple):
children = children[1].children
plots = [child[0] for child in children]
row_id = np.array([child[1] for child in children], dtype=int)
col_id = np.array([child[1] for child in children], dtype=int)
col_id = np.array([child[2] for child in children], dtype=int)
legend_id = np.argmax(
row_id
== np.unique(row_id)[legend_target[0]] & col_id
== np.unique(col_id)[legend_target[1]]
(row_id == np.unique(row_id)[legend_target[0]])
& (col_id == np.unique(col_id)[legend_target[1]])
)
target_plot = plots[legend_id]
if target_plot.legend:
Expand All @@ -60,7 +62,7 @@ def legend(
glyph = artist_fun(**{**artist_kwargs, **kws})
glyph_list.append(glyph)
leg = Legend(
items=[(label, [glyph]) for label, glyph in zip(label_list, glyph_list)],
items=[(str(label), [glyph]) for label, glyph in zip(label_list, glyph_list)],
title=title,
**kwargs,
)
Expand Down
3 changes: 2 additions & 1 deletion src/arviz_plots/plot_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,8 @@ def add_legend(self, dim, var_name=None, aes=None, artist_kwargs=None, title=Non
)
aes_ds = aes_ds.drop_dims([d for d in aes_ds.dims if d != dim])
if aes is None:
aes_ds = aes_ds.drop_vars(("y", "x"), errors="ignore")
dropped_vars = ["x", "y"] + [name for name, da in aes_ds.items() if dim not in da.dims]
aes_ds = aes_ds.drop_vars(dropped_vars, errors="ignore")
else:
if isinstance(aes, str):
aes = [aes]
Expand Down

0 comments on commit 46828e6

Please sign in to comment.