Skip to content

Commit 585e7da

Browse files
authored
colorbar: ax instead of ax1 & ax2 (#107)
* colorbar: ax instead of ax1 & ax2 * changelog * update example * Apply suggestions from code review * add docs for pad * fix for 2D numpy arrays
1 parent e500db6 commit 585e7da

File tree

6 files changed

+243
-87
lines changed

6 files changed

+243
-87
lines changed

CHANGELOG.md

+21
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,26 @@
11
# Changelog
22

3+
4+
## v0.5.0 (unreleased)
5+
6+
### Breaking changes
7+
8+
- The `ax1` and `ax2` arguments of `mpu.colorbar` have been combined into `ax` ([#107](https://github.com/mathause/mplotutils/pull/107))
9+
To update
10+
11+
```diff
12+
- mpu.colorbar(h, axs[0], axs[1])
13+
+ mpu.colorbar(h, [axs[0], axs[1]])
14+
```
15+
or
16+
```diff
17+
- mpu.colorbar(h, axs[0], axs[1])
18+
+ mpu.colorbar(h, axs)
19+
```
20+
- When passing `size` to `mpu.colorbar` it now uses the height/ width of _all_ passed axes to scale the colorbar. This is consistent with `plt.colorbar` but may lead to differences compared to the previous version ([#107](https://github.com/mathause/mplotutils/pull/107)).
21+
- Similarly for `pad`, which is also scaled by the height/ width of _all_ passed axes. This is consistent with `plt.colorbar` but may change the padding of the colorbar compared to the previous version ([#107](https://github.com/mathause/mplotutils/pull/107)).
22+
23+
324
## v0.4.0 (23.02.2024)
425

526
Version 0.4.0 simplifies the way figures with a `mpu.colorbar` have to be saved and

docs/example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def plot_map_mpu():
4141
ax.coastlines()
4242

4343
# add colorbar using mplotutils
44-
mpu.colorbar(mappable=mappable, ax1=axs[1], ax2=axs[3])
44+
mpu.colorbar(mappable=mappable, ax=axs)
4545

4646
# adjust the margins manually
4747
f.subplots_adjust(left=0.025, right=0.85, top=0.9, bottom=0.05)

docs/example_mpu.png

67 Bytes
Loading

docs/example_no_mpu.png

40 Bytes
Loading

mplotutils/_colorbar.py

+90-70
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,59 @@
11
import warnings
22

3+
import matplotlib as mpl
34
import matplotlib.pyplot as plt
5+
import matplotlib.transforms as mtransforms
46
import numpy as np
57

68
from mplotutils._deprecate import _deprecate_positional_args
79

810

11+
def _deprecate_ax1_ax2(ax, ax2, ax1):
12+
if ax is None:
13+
if ax1 is None:
14+
raise TypeError("colorbar() missing 1 required positional argument: 'ax'")
15+
16+
if ax2 is None:
17+
ax = ax1
18+
warnings.warn(
19+
"`ax1` has been deprecated in favor of `ax`",
20+
FutureWarning,
21+
stacklevel=4,
22+
)
23+
24+
else:
25+
ax = [ax1, ax2]
26+
warnings.warn(
27+
"`ax1` and `ax2` has been deprecated in favor of `ax`, i.e. pass ax=[ax1, ax2]",
28+
FutureWarning,
29+
stacklevel=4,
30+
)
31+
32+
else:
33+
34+
if ax1 is not None and ax2 is not None:
35+
raise TypeError("Cannot pass `ax`, and `ax2`")
36+
37+
if ax2 is not None and np.ndim(ax) != 0:
38+
raise TypeError("Cannot pass ax2 in addition to a list of axes")
39+
40+
if ax2 is not None:
41+
ax = [ax, ax2]
42+
43+
warnings.warn(
44+
"Passing axes individually has been deprecated in favor of passing them"
45+
" as list, i.e. pass ``ax=[ax1, ax2]``, or ``ax=axs``",
46+
FutureWarning,
47+
stacklevel=4,
48+
)
49+
50+
return ax
51+
52+
953
@_deprecate_positional_args("0.3")
1054
def colorbar(
1155
mappable,
12-
ax1,
56+
ax=None,
1357
ax2=None,
1458
*,
1559
orientation="vertical",
@@ -20,32 +64,29 @@ def colorbar(
2064
shrink=None,
2165
**kwargs,
2266
):
23-
"""
24-
automatically resize colorbars on draw
67+
"""colorbar that adjusts to the axes height (and automatically resizes)
2568
2669
See below for Example
2770
2871
Parameters
2972
----------
3073
mappable : handle
3174
The `matplotlib.cm.ScalarMappable` described by this colorbar.
32-
ax1 : `matplotlib.axes.Axes`
33-
The axes to adjust the colorbar to.
34-
ax2 : `~matplotlib.axes.Axes`, default: None.
35-
If the colorbar should span more than one axes.
75+
ax : `~matplotlib.axes.Axes` or iterable or `numpy.ndarray` of Axes
76+
One or more parent Axes of the colorbar.
3677
orientation : 'vertical' | 'horizontal'. Default: 'vertical'.
3778
Orientation of the colorbar.
3879
aspect : float, default 20.
39-
The ratio of long to short dimensions of the colorbar Mutually exclusive with
80+
The ratio of long to short dimensions of the colorbar. Mutually exclusive with
4081
`size`.
4182
size : float, default: None
42-
Width of the colorbar as fraction of the axes width (vertical) or
83+
Width of the colorbar as fraction of all parent axes width (vertical) or
4384
height (horizontal). Mutually exclusive with `aspect`.
4485
pad : float, default: None.
45-
Distance of the colorbar to the axes in Figure coordinates.
46-
Default: 0.05 (vertical) or 0.15 (horizontal).
86+
Distance between axes and colorbar. In fraction of parent axes.
87+
Default: 0.05 (vertical) or 0.15 (horizontal).
4788
shift : 'symmetric' or float in 0..1, default: 'symmetric'
48-
Fraction of the total height that the colorbar is shifted upwards. See Note.
89+
Fraction of the total height that the colorbar is shifted up/ right. See Note.
4990
shrink : None or float in 0..1, default: None.
5091
Fraction of the total height that the colorbar is shrunk. See Note.
5192
**kwargs : keyword arguments
@@ -119,7 +160,7 @@ def colorbar(
119160
ax.set_global()
120161
h = ax.pcolormesh([[0, 1]])
121162
122-
cbar = mpu.colorbar(h, axs[0], axs[1])
163+
cbar = mpu.colorbar(h, axs)
123164
124165
# =========================
125166
# example with 3 axes & 2 colorbars
@@ -134,7 +175,7 @@ def colorbar(
134175
h1 = ax.pcolormesh([[0, 1]])
135176
h2 = ax.pcolormesh([[0, 1]], cmap='Blues')
136177
137-
cbar = mpu.colorbar(h, axs[0], axs[1], size=0.05)
178+
cbar = mpu.colorbar(h, [axs[0], axs[1]], size=0.05)
138179
cbar = mpu.colorbar(h, axs[2], size=0.05)
139180
140181
plt.draw()
@@ -151,6 +192,9 @@ def colorbar(
151192
plt.colorbar
152193
"""
153194

195+
ax = _deprecate_ax1_ax2(ax, ax2, kwargs.pop("ax1", None))
196+
axs = np.asarray(ax).flatten()
197+
154198
if orientation not in ("vertical", "horizontal"):
155199
raise ValueError("orientation must be 'vertical' or 'horizontal'")
156200

@@ -159,17 +203,13 @@ def colorbar(
159203
msg = "'anchor' and 'panchor' keywords not supported, use 'shrink' and 'shift'"
160204
raise ValueError(msg)
161205

162-
# ensure 'ax' does not end up in plt.colorbar(**kwargs)
163-
if "ax" in k:
164-
if ax2 is not None:
165-
raise ValueError("Cannot pass `ax`, and `ax2`")
166-
# assume it is ax2 (it can't be ax1)
167-
ax2 = kwargs.pop("ax")
206+
if not all(isinstance(ax, mpl.axes.Axes) for ax in axs):
207+
raise TypeError("ax must be of Type mpl.axes.Axes")
168208

169-
f = ax1.get_figure()
209+
f = axs[0].get_figure()
170210

171-
if ax2 is not None and f != ax2.get_figure():
172-
raise ValueError("'ax1' and 'ax2' must belong to the same figure")
211+
if not all(f == ax.get_figure() for ax in axs):
212+
raise TypeError("All passed axes must belong to the same figure")
173213

174214
cbax = _get_cbax(f)
175215

@@ -178,8 +218,8 @@ def colorbar(
178218
if orientation == "vertical":
179219
func = _resize_colorbar_vert(
180220
cbax,
181-
ax1,
182-
ax2=ax2,
221+
f,
222+
axs,
183223
aspect=aspect,
184224
size=size,
185225
pad=pad,
@@ -189,8 +229,8 @@ def colorbar(
189229
else:
190230
func = _resize_colorbar_horz(
191231
cbax,
192-
ax1,
193-
ax2=ax2,
232+
f,
233+
axs,
194234
aspect=aspect,
195235
size=size,
196236
pad=pad,
@@ -219,8 +259,8 @@ def _get_cbax(f):
219259

220260
def _resize_colorbar_vert(
221261
cbax,
222-
ax1,
223-
ax2=None,
262+
f,
263+
axs,
224264
aspect=None,
225265
size=None,
226266
pad=None,
@@ -263,40 +303,31 @@ def _resize_colorbar_vert(
263303

264304
size, aspect, pad = _parse_size_aspect_pad(size, aspect, pad, "vertical")
265305

266-
f = ax1.get_figure()
267-
268-
# swap axes if ax1 is above ax2
269-
if ax2 is not None:
270-
posn1 = ax1.get_position()
271-
posn2 = ax2.get_position()
272-
273-
ax1, ax2 = (ax1, ax2) if posn1.y0 < posn2.y0 else (ax2, ax1)
274-
275306
if aspect is not None:
276307
anchor = (0, 0.5)
277308
cbax.set_anchor(anchor)
278309
cbax.set_box_aspect(aspect)
279310

280311
# inner function is called by event handler
281312
def inner(event=None):
282-
pos1 = ax1.get_position()
313+
314+
# from mpl.colorbar (but not using ax.get_position(original=True).frozen())
315+
parents_bbox = mtransforms.Bbox.union([ax.get_position() for ax in axs])
283316

284317
# determine total height of all axes
285-
if ax2 is None:
286-
full_height = pos1.height
287-
else:
288-
pos2 = ax2.get_position()
289-
full_height = pos2.y0 - pos1.y0 + pos2.height
318+
full_height = parents_bbox.height
290319

291-
pad_scaled = pad * pos1.width
320+
pad_scaled = pad * parents_bbox.width
292321

293322
# calculate position of cbax
294-
left = pos1.x0 + pos1.width + pad_scaled
295-
bottom = pos1.y0 + shift * full_height
323+
left = parents_bbox.x1 + pad_scaled
324+
325+
bottom = parents_bbox.y0 + shift * full_height
326+
296327
height = (1 - shrink) * full_height
297328

298329
if aspect is None:
299-
size_scaled = size * pos1.width
330+
size_scaled = size * parents_bbox.width
300331
width = size_scaled
301332
else:
302333
figure_aspect = np.divide(*f.get_size_inches())
@@ -314,8 +345,8 @@ def inner(event=None):
314345

315346
def _resize_colorbar_horz(
316347
cbax,
317-
ax1,
318-
ax2=None,
348+
f,
349+
axs,
319350
aspect=None,
320351
size=None,
321352
pad=None,
@@ -357,43 +388,32 @@ def _resize_colorbar_horz(
357388

358389
size, aspect, pad = _parse_size_aspect_pad(size, aspect, pad, "horizontal")
359390

360-
f = ax1.get_figure()
361-
362-
if ax2 is not None:
363-
posn1 = ax1.get_position()
364-
posn2 = ax2.get_position()
365-
366-
# swap axes if ax1 is right of ax2
367-
ax1, ax2 = (ax1, ax2) if posn1.x0 < posn2.x0 else (ax2, ax1)
368-
369391
if aspect is not None:
370392
aspect = 1 / aspect
371393
anchor = (0.5, 1.0)
372394
cbax.set_anchor(anchor)
373395
cbax.set_box_aspect(aspect)
374396

375397
def inner(event=None):
376-
posn1 = ax1.get_position()
377398

378-
if ax2 is None:
379-
full_width = posn1.width
380-
else:
381-
posn2 = ax2.get_position()
382-
full_width = posn2.x0 - posn1.x0 + posn2.width
399+
# from mpl.colorbar (but not using ax.get_position(original=True).frozen())
400+
parents_bbox = mtransforms.Bbox.union([ax.get_position() for ax in axs])
401+
402+
full_width = parents_bbox.width
383403

384-
pad_scaled = pad * posn1.height
404+
pad_scaled = pad * parents_bbox.height
385405

386-
width = full_width - shrink * full_width
406+
width = (1 - shrink) * full_width
387407

388408
if aspect is None:
389-
size_scaled = size * posn1.height
409+
size_scaled = size * parents_bbox.height
390410
height = size_scaled
391411
else:
392412
figure_aspect = np.divide(*f.get_size_inches())
393413
height = width * (aspect * figure_aspect)
394414

395-
left = posn1.x0 + shift * full_width
396-
bottom = posn1.y0 - (pad_scaled + height)
415+
left = parents_bbox.x0 + shift * full_width
416+
bottom = parents_bbox.y0 - (pad_scaled + height)
397417

398418
pos = [left, bottom, width, height]
399419

0 commit comments

Comments
 (0)