Skip to content

Commit

Permalink
Fix bug when setting some plot_kwargs to false (#134)
Browse files Browse the repository at this point in the history
* fix bug setting when setting some plot_kwargs to false

* remove references
  • Loading branch information
aloctavodia authored Feb 12, 2025
1 parent eafb6e0 commit e99f3f1
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 47 deletions.
4 changes: 2 additions & 2 deletions src/arviz_plots/plots/essplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,7 +548,7 @@ def plot_ess(
# plot x and y axis labels
# Add varnames as x and y labels
_, labels_aes, labels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims)
xlabel_kwargs = plot_kwargs.get("xlabel", {}).copy()
xlabel_kwargs = copy(plot_kwargs.get("xlabel", {}))
if xlabel_kwargs is not False:
if "color" not in labels_aes:
xlabel_kwargs.setdefault("color", "black")
Expand All @@ -565,7 +565,7 @@ def plot_ess(
)

_, labels_aes, labels_ignore = filter_aes(plot_collection, aes_map, "ylabel", sample_dims)
ylabel_kwargs = plot_kwargs.get("ylabel", {}).copy()
ylabel_kwargs = copy(plot_kwargs.get("ylabel", {}))
if ylabel_kwargs is not False:
if "color" not in labels_aes:
ylabel_kwargs.setdefault("color", "black")
Expand Down
4 changes: 2 additions & 2 deletions src/arviz_plots/plots/evolutionplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def compute_ess_dataset(
# plot x and y axis labels
# Add varnames as x and y labels
_, xlabels_aes, xlabels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims)
xlabel_kwargs = plot_kwargs.get("xlabel", {}).copy()
xlabel_kwargs = copy(plot_kwargs.get("xlabel", {}))
if xlabel_kwargs is not False:
if "color" not in xlabels_aes:
xlabel_kwargs.setdefault("color", "black")
Expand All @@ -618,7 +618,7 @@ def compute_ess_dataset(
)

_, ylabels_aes, ylabels_ignore = filter_aes(plot_collection, aes_map, "ylabel", sample_dims)
ylabel_kwargs = plot_kwargs.get("ylabel", {}).copy()
ylabel_kwargs = copy(plot_kwargs.get("ylabel", {}))
if ylabel_kwargs is not False:
if "color" not in ylabels_aes:
ylabel_kwargs.setdefault("color", "black")
Expand Down
4 changes: 2 additions & 2 deletions src/arviz_plots/plots/pavacalibrationplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def plot_pava_calibration(

# set xlabel
_, xlabels_aes, xlabels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims)
xlabel_kwargs = plot_kwargs.get("xlabel", {}).copy()
xlabel_kwargs = copy(plot_kwargs.get("xlabel", {}))
if xlabel_kwargs is not False:
if "color" not in xlabels_aes:
xlabel_kwargs.setdefault("color", "black")
Expand All @@ -271,7 +271,7 @@ def plot_pava_calibration(

# set ylabel
_, ylabels_aes, ylabels_ignore = filter_aes(plot_collection, aes_map, "ylabel", sample_dims)
ylabel_kwargs = plot_kwargs.get("ylabel", {}).copy()
ylabel_kwargs = copy(plot_kwargs.get("ylabel", {}))
if ylabel_kwargs is not False:
if "color" not in ylabels_aes:
ylabel_kwargs.setdefault("color", "black")
Expand Down
83 changes: 42 additions & 41 deletions src/arviz_plots/plots/psensequantitiesplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,29 +284,29 @@ def plot_psense_quantities(
prior_ms_kwargs.setdefault("marker", markers[0])
prior_ms_kwargs.setdefault("color", colors[3])

plot_collection.map(
scatter_xy,
"prior_markers",
data=quantities_ds.sel(component_group="prior"),
x=quantities_ds.alpha,
ignore_aes=prior_ms_ignore,
**prior_ms_kwargs,
)
plot_collection.map(
scatter_xy,
"prior_markers",
data=quantities_ds.sel(component_group="prior"),
x=quantities_ds.alpha,
ignore_aes=prior_ms_ignore,
**prior_ms_kwargs,
)
## lines
prior_ls_kwargs = copy(plot_kwargs.get("prior_lines", {}))

if prior_ls_kwargs is not False:
_, _, prior_ms_ignore = filter_aes(plot_collection, aes_map, "prior_lines", sample_dims)
prior_ls_kwargs.setdefault("color", colors[3])

plot_collection.map(
line_xy,
"prior_lines",
data=quantities_ds.sel(component_group="prior"),
x=quantities_ds.alpha,
ignore_aes=prior_ms_ignore,
**prior_ls_kwargs,
)
plot_collection.map(
line_xy,
"prior_lines",
data=quantities_ds.sel(component_group="prior"),
x=quantities_ds.alpha,
ignore_aes=prior_ms_ignore,
**prior_ls_kwargs,
)

# plot quantities for likelihood-perturbations
## markers
Expand All @@ -320,14 +320,14 @@ def plot_psense_quantities(
likelihood_ms_kwargs.setdefault("marker", markers[5])
likelihood_ms_kwargs.setdefault("color", colors[4])

plot_collection.map(
scatter_xy,
"likelihood_markers",
data=quantities_ds.sel(component_group="likelihood"),
x=quantities_ds.alpha,
ignore_aes=likelihood_ms_ignore,
**likelihood_ms_kwargs,
)
plot_collection.map(
scatter_xy,
"likelihood_markers",
data=quantities_ds.sel(component_group="likelihood"),
x=quantities_ds.alpha,
ignore_aes=likelihood_ms_ignore,
**likelihood_ms_kwargs,
)
## lines
likelihood_ls_kwargs = copy(plot_kwargs.get("likelihood_lines", {}))

Expand All @@ -338,14 +338,14 @@ def plot_psense_quantities(

likelihood_ls_kwargs.setdefault("color", colors[4])

plot_collection.map(
line_xy,
"prior_lines",
data=quantities_ds.sel(component_group="likelihood"),
x=quantities_ds.alpha,
ignore_aes=likelihood_ls_ignore,
**likelihood_ls_kwargs,
)
plot_collection.map(
line_xy,
"prior_lines",
data=quantities_ds.sel(component_group="likelihood"),
x=quantities_ds.alpha,
ignore_aes=likelihood_ls_ignore,
**likelihood_ls_kwargs,
)

# plot mcse
if mcse:
Expand Down Expand Up @@ -374,7 +374,7 @@ def plot_psense_quantities(

# set xlabel
_, xlabels_aes, xlabels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims)
xlabel_kwargs = plot_kwargs.get("xlabel", {}).copy()
xlabel_kwargs = copy(plot_kwargs.get("xlabel", {}))
if xlabel_kwargs is not False:
if "color" not in xlabels_aes:
xlabel_kwargs.setdefault("color", "black")
Expand All @@ -393,13 +393,14 @@ def plot_psense_quantities(
title_kwargs = copy(plot_kwargs.get("title", {}))
_, _, title_ignore = filter_aes(plot_collection, aes_map, "title", sample_dims)

plot_collection.map(
labelled_title,
"title",
ignore_aes=title_ignore,
subset_info=True,
labeller=labeller,
**title_kwargs,
)
if title_kwargs is not False:
plot_collection.map(
labelled_title,
"title",
ignore_aes=title_ignore,
subset_info=True,
labeller=labeller,
**title_kwargs,
)

return plot_collection

0 comments on commit e99f3f1

Please sign in to comment.