1
1
import warnings
2
2
3
+ import matplotlib as mpl
3
4
import matplotlib .pyplot as plt
5
+ import matplotlib .transforms as mtransforms
4
6
import numpy as np
5
7
6
8
from mplotutils ._deprecate import _deprecate_positional_args
7
9
8
10
11
+ def _deprecate_ax1_ax2 (ax , ax2 , ax1 ):
12
+ if ax is None :
13
+ if ax1 is None :
14
+ raise TypeError ("colorbar() missing 1 required positional argument: 'ax'" )
15
+
16
+ if ax2 is None :
17
+ ax = ax1
18
+ warnings .warn (
19
+ "`ax1` has been deprecated in favor of `ax`" ,
20
+ FutureWarning ,
21
+ stacklevel = 4 ,
22
+ )
23
+
24
+ else :
25
+ ax = [ax1 , ax2 ]
26
+ warnings .warn (
27
+ "`ax1` and `ax2` has been deprecated in favor of `ax`, i.e. pass ax=[ax1, ax2]" ,
28
+ FutureWarning ,
29
+ stacklevel = 4 ,
30
+ )
31
+
32
+ else :
33
+
34
+ if ax1 is not None and ax2 is not None :
35
+ raise TypeError ("Cannot pass `ax`, and `ax2`" )
36
+
37
+ if ax2 is not None and np .ndim (ax ) != 0 :
38
+ raise TypeError ("Cannot pass ax2 in addition to a list of axes" )
39
+
40
+ if ax2 is not None :
41
+ ax = [ax , ax2 ]
42
+
43
+ warnings .warn (
44
+ "Passing axes individually has been deprecated in favor of passing them"
45
+ " as list, i.e. pass ``ax=[ax1, ax2]``, or ``ax=axs``" ,
46
+ FutureWarning ,
47
+ stacklevel = 4 ,
48
+ )
49
+
50
+ return ax
51
+
52
+
9
53
@_deprecate_positional_args ("0.3" )
10
54
def colorbar (
11
55
mappable ,
12
- ax1 ,
56
+ ax = None ,
13
57
ax2 = None ,
14
58
* ,
15
59
orientation = "vertical" ,
@@ -20,32 +64,29 @@ def colorbar(
20
64
shrink = None ,
21
65
** kwargs ,
22
66
):
23
- """
24
- automatically resize colorbars on draw
67
+ """colorbar that adjusts to the axes height (and automatically resizes)
25
68
26
69
See below for Example
27
70
28
71
Parameters
29
72
----------
30
73
mappable : handle
31
74
The `matplotlib.cm.ScalarMappable` described by this colorbar.
32
- ax1 : `matplotlib.axes.Axes`
33
- The axes to adjust the colorbar to.
34
- ax2 : `~matplotlib.axes.Axes`, default: None.
35
- If the colorbar should span more than one axes.
75
+ ax : `~matplotlib.axes.Axes` or iterable or `numpy.ndarray` of Axes
76
+ One or more parent Axes of the colorbar.
36
77
orientation : 'vertical' | 'horizontal'. Default: 'vertical'.
37
78
Orientation of the colorbar.
38
79
aspect : float, default 20.
39
- The ratio of long to short dimensions of the colorbar Mutually exclusive with
80
+ The ratio of long to short dimensions of the colorbar. Mutually exclusive with
40
81
`size`.
41
82
size : float, default: None
42
- Width of the colorbar as fraction of the axes width (vertical) or
83
+ Width of the colorbar as fraction of all parent axes width (vertical) or
43
84
height (horizontal). Mutually exclusive with `aspect`.
44
85
pad : float, default: None.
45
- Distance of the colorbar to the axes in Figure coordinates .
46
- Default: 0.05 (vertical) or 0.15 (horizontal).
86
+ Distance between axes and colorbar. In fraction of parent axes .
87
+ Default: 0.05 (vertical) or 0.15 (horizontal).
47
88
shift : 'symmetric' or float in 0..1, default: 'symmetric'
48
- Fraction of the total height that the colorbar is shifted upwards . See Note.
89
+ Fraction of the total height that the colorbar is shifted up/ right . See Note.
49
90
shrink : None or float in 0..1, default: None.
50
91
Fraction of the total height that the colorbar is shrunk. See Note.
51
92
**kwargs : keyword arguments
@@ -119,7 +160,7 @@ def colorbar(
119
160
ax.set_global()
120
161
h = ax.pcolormesh([[0, 1]])
121
162
122
- cbar = mpu.colorbar(h, axs[0], axs[1] )
163
+ cbar = mpu.colorbar(h, axs)
123
164
124
165
# =========================
125
166
# example with 3 axes & 2 colorbars
@@ -134,7 +175,7 @@ def colorbar(
134
175
h1 = ax.pcolormesh([[0, 1]])
135
176
h2 = ax.pcolormesh([[0, 1]], cmap='Blues')
136
177
137
- cbar = mpu.colorbar(h, axs[0], axs[1], size=0.05)
178
+ cbar = mpu.colorbar(h, [ axs[0], axs[1] ], size=0.05)
138
179
cbar = mpu.colorbar(h, axs[2], size=0.05)
139
180
140
181
plt.draw()
@@ -151,6 +192,9 @@ def colorbar(
151
192
plt.colorbar
152
193
"""
153
194
195
+ ax = _deprecate_ax1_ax2 (ax , ax2 , kwargs .pop ("ax1" , None ))
196
+ axs = np .asarray (ax ).flatten ()
197
+
154
198
if orientation not in ("vertical" , "horizontal" ):
155
199
raise ValueError ("orientation must be 'vertical' or 'horizontal'" )
156
200
@@ -159,17 +203,13 @@ def colorbar(
159
203
msg = "'anchor' and 'panchor' keywords not supported, use 'shrink' and 'shift'"
160
204
raise ValueError (msg )
161
205
162
- # ensure 'ax' does not end up in plt.colorbar(**kwargs)
163
- if "ax" in k :
164
- if ax2 is not None :
165
- raise ValueError ("Cannot pass `ax`, and `ax2`" )
166
- # assume it is ax2 (it can't be ax1)
167
- ax2 = kwargs .pop ("ax" )
206
+ if not all (isinstance (ax , mpl .axes .Axes ) for ax in axs ):
207
+ raise TypeError ("ax must be of Type mpl.axes.Axes" )
168
208
169
- f = ax1 .get_figure ()
209
+ f = axs [ 0 ] .get_figure ()
170
210
171
- if ax2 is not None and f != ax2 .get_figure ():
172
- raise ValueError ( "'ax1' and 'ax2' must belong to the same figure" )
211
+ if not all ( f == ax .get_figure () for ax in axs ):
212
+ raise TypeError ( "All passed axes must belong to the same figure" )
173
213
174
214
cbax = _get_cbax (f )
175
215
@@ -178,8 +218,8 @@ def colorbar(
178
218
if orientation == "vertical" :
179
219
func = _resize_colorbar_vert (
180
220
cbax ,
181
- ax1 ,
182
- ax2 = ax2 ,
221
+ f ,
222
+ axs ,
183
223
aspect = aspect ,
184
224
size = size ,
185
225
pad = pad ,
@@ -189,8 +229,8 @@ def colorbar(
189
229
else :
190
230
func = _resize_colorbar_horz (
191
231
cbax ,
192
- ax1 ,
193
- ax2 = ax2 ,
232
+ f ,
233
+ axs ,
194
234
aspect = aspect ,
195
235
size = size ,
196
236
pad = pad ,
@@ -219,8 +259,8 @@ def _get_cbax(f):
219
259
220
260
def _resize_colorbar_vert (
221
261
cbax ,
222
- ax1 ,
223
- ax2 = None ,
262
+ f ,
263
+ axs ,
224
264
aspect = None ,
225
265
size = None ,
226
266
pad = None ,
@@ -263,40 +303,31 @@ def _resize_colorbar_vert(
263
303
264
304
size , aspect , pad = _parse_size_aspect_pad (size , aspect , pad , "vertical" )
265
305
266
- f = ax1 .get_figure ()
267
-
268
- # swap axes if ax1 is above ax2
269
- if ax2 is not None :
270
- posn1 = ax1 .get_position ()
271
- posn2 = ax2 .get_position ()
272
-
273
- ax1 , ax2 = (ax1 , ax2 ) if posn1 .y0 < posn2 .y0 else (ax2 , ax1 )
274
-
275
306
if aspect is not None :
276
307
anchor = (0 , 0.5 )
277
308
cbax .set_anchor (anchor )
278
309
cbax .set_box_aspect (aspect )
279
310
280
311
# inner function is called by event handler
281
312
def inner (event = None ):
282
- pos1 = ax1 .get_position ()
313
+
314
+ # from mpl.colorbar (but not using ax.get_position(original=True).frozen())
315
+ parents_bbox = mtransforms .Bbox .union ([ax .get_position () for ax in axs ])
283
316
284
317
# determine total height of all axes
285
- if ax2 is None :
286
- full_height = pos1 .height
287
- else :
288
- pos2 = ax2 .get_position ()
289
- full_height = pos2 .y0 - pos1 .y0 + pos2 .height
318
+ full_height = parents_bbox .height
290
319
291
- pad_scaled = pad * pos1 .width
320
+ pad_scaled = pad * parents_bbox .width
292
321
293
322
# calculate position of cbax
294
- left = pos1 .x0 + pos1 .width + pad_scaled
295
- bottom = pos1 .y0 + shift * full_height
323
+ left = parents_bbox .x1 + pad_scaled
324
+
325
+ bottom = parents_bbox .y0 + shift * full_height
326
+
296
327
height = (1 - shrink ) * full_height
297
328
298
329
if aspect is None :
299
- size_scaled = size * pos1 .width
330
+ size_scaled = size * parents_bbox .width
300
331
width = size_scaled
301
332
else :
302
333
figure_aspect = np .divide (* f .get_size_inches ())
@@ -314,8 +345,8 @@ def inner(event=None):
314
345
315
346
def _resize_colorbar_horz (
316
347
cbax ,
317
- ax1 ,
318
- ax2 = None ,
348
+ f ,
349
+ axs ,
319
350
aspect = None ,
320
351
size = None ,
321
352
pad = None ,
@@ -357,43 +388,32 @@ def _resize_colorbar_horz(
357
388
358
389
size , aspect , pad = _parse_size_aspect_pad (size , aspect , pad , "horizontal" )
359
390
360
- f = ax1 .get_figure ()
361
-
362
- if ax2 is not None :
363
- posn1 = ax1 .get_position ()
364
- posn2 = ax2 .get_position ()
365
-
366
- # swap axes if ax1 is right of ax2
367
- ax1 , ax2 = (ax1 , ax2 ) if posn1 .x0 < posn2 .x0 else (ax2 , ax1 )
368
-
369
391
if aspect is not None :
370
392
aspect = 1 / aspect
371
393
anchor = (0.5 , 1.0 )
372
394
cbax .set_anchor (anchor )
373
395
cbax .set_box_aspect (aspect )
374
396
375
397
def inner (event = None ):
376
- posn1 = ax1 .get_position ()
377
398
378
- if ax2 is None :
379
- full_width = posn1 .width
380
- else :
381
- posn2 = ax2 .get_position ()
382
- full_width = posn2 .x0 - posn1 .x0 + posn2 .width
399
+ # from mpl.colorbar (but not using ax.get_position(original=True).frozen())
400
+ parents_bbox = mtransforms .Bbox .union ([ax .get_position () for ax in axs ])
401
+
402
+ full_width = parents_bbox .width
383
403
384
- pad_scaled = pad * posn1 .height
404
+ pad_scaled = pad * parents_bbox .height
385
405
386
- width = full_width - shrink * full_width
406
+ width = ( 1 - shrink ) * full_width
387
407
388
408
if aspect is None :
389
- size_scaled = size * posn1 .height
409
+ size_scaled = size * parents_bbox .height
390
410
height = size_scaled
391
411
else :
392
412
figure_aspect = np .divide (* f .get_size_inches ())
393
413
height = width * (aspect * figure_aspect )
394
414
395
- left = posn1 .x0 + shift * full_width
396
- bottom = posn1 .y0 - (pad_scaled + height )
415
+ left = parents_bbox .x0 + shift * full_width
416
+ bottom = parents_bbox .y0 - (pad_scaled + height )
397
417
398
418
pos = [left , bottom , width , height ]
399
419
0 commit comments