Skip to content

Commit 235bfbc

Browse files
committed
Updated to super plotnine 0.13.x (in middle)
1 parent a8da1d5 commit 235bfbc

File tree

2 files changed

+186
-26
lines changed

2 files changed

+186
-26
lines changed

patchworklib/modified_plotnine.py

+112
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,117 @@
1+
import numpy as np
2+
import itertools
3+
from plotnine.facets.strips import Strips
14
from copy import deepcopy
25

6+
def make_figure(self, figure=None):
7+
"""
8+
Create and return Matplotlib figure and subplot axes
9+
"""
10+
num_panels = len(self.layout.layout)
11+
axsarr = np.empty((self.nrow, self.ncol), dtype=object)
12+
13+
# Create figure & gridspec
14+
if figure is None:
15+
figure, gs = self._make_figure()
16+
else:
17+
_, gs = self._make_figure()
18+
self.grid_spec = gs
19+
20+
# Create axes
21+
it = itertools.product(range(self.nrow), range(self.ncol))
22+
for i, (row, col) in enumerate(it):
23+
axsarr[row, col] = figure.add_subplot(gs[i])
24+
25+
# Rearrange axes
26+
# They are ordered to match the positions in the layout table
27+
if self.dir == "h":
28+
order: Literal["C", "F"] = "C"
29+
if not self.as_table:
30+
axsarr = axsarr[::-1]
31+
elif self.dir == "v":
32+
order = "F"
33+
if not self.as_table:
34+
axsarr = np.array([row[::-1] for row in axsarr])
35+
else:
36+
raise ValueError(f'Bad value `dir="{self.dir}"` for direction')
37+
38+
axs = axsarr.ravel(order)
39+
40+
# Delete unused axes
41+
for ax in axs[num_panels:]:
42+
figure.delaxes(ax)
43+
axs = axs[:num_panels]
44+
return figure, list(axs)
45+
46+
def setup(self, figure, plot):
47+
self.plot = plot
48+
self.layout = plot.layout
49+
50+
if hasattr(plot, "figure"):
51+
self.figure, self.axs = plot.figure, plot.axs
52+
else:
53+
self.figure, self.axs = self.make_figure(self, figure=figure)
54+
55+
self.coordinates = plot.coordinates
56+
self.theme = plot.theme
57+
self.layout.axs = self.axs
58+
self.strips = Strips.from_facet(self)
59+
return self.figure, self.axs
60+
61+
def newdraw(self, return_ggplot=False, show: bool = False):
62+
"""
63+
Render the complete plot
64+
65+
Parameters
66+
----------
67+
show :
68+
Whether to show the plot.
69+
70+
Returns
71+
-------
72+
:
73+
Matplotlib figure
74+
"""
75+
import matplotlib as mpl
76+
from plotnine._mpl.layout_engine import PlotnineLayoutEngine
77+
from plotnine.ggplot import plot_context
78+
79+
# Do not draw if drawn already.
80+
# This prevents a needless error when reusing
81+
# figure & axes in the jupyter notebook.
82+
if hasattr(self, "figure"):
83+
return self.figure
84+
85+
# Prevent against any modifications to the users
86+
# ggplot object. Do the copy here as we may/may not
87+
# assign a default theme
88+
self = deepcopy(self)
89+
with plot_context(self, show=show):
90+
self._build()
91+
92+
# setup
93+
self.figure, self.axs = self.facet.setup(self)
94+
self.guides._setup(self)
95+
self.theme.setup(self)
96+
97+
# Drawing
98+
self._draw_layers()
99+
self._draw_panel_borders()
100+
self._draw_breaks_and_labels()
101+
self.guides.draw()
102+
self._draw_figure_texts()
103+
self._draw_watermarks()
104+
105+
# Artist object theming
106+
self.theme.apply()
107+
self.figure.set_layout_engine(PlotnineLayoutEngine(self))
108+
if return_ggplot == True:
109+
return self.figure, self
110+
else:
111+
return self.figure
112+
return self.figure
113+
114+
3115
def draw(self, return_ggplot=False, show: bool = False):
4116
"""
5117
Render the complete plot

patchworklib/patchworklib.py

+74-26
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,12 @@ def _reset_ggplot_legend(bricks):
210210
pass
211211

212212
def overwrite_plotnine():
213-
plotnine.ggplot.draw = mp9.draw
213+
if StrictVersion(plotnine.__version__) >= StrictVersion("0.13"):
214+
plotnine.ggplot.draw = mp9.newdraw
215+
plotnine.facets.facet.setup = mp9.setup
216+
plotnine.facets.facet.make_figure = mp9.make_figure
217+
else:
218+
plotnine.ggplot.draw = mp9.draw
214219

215220
def load_ggplot(ggplot=None, figsize=None):
216221
"""
@@ -277,8 +282,28 @@ def draw_labels(bricks, gori, gcp, figsize):
277282
else:
278283
xlabel = bricks.set_xlabel(labels.x, labelpad=pad_x, va="top")
279284
ylabel = bricks.set_ylabel(labels.y, labelpad=pad_y)
285+
286+
if StrictVersion(plotnine_version) >= StrictVersion("0.13"):
287+
gori.theme.targets.axis_title_x = xlabel
288+
gori.theme.targets.axis_title_y = ylabel
289+
if 'axis_title_x' in gori.theme.themeables:
290+
gori.theme.themeables['axis_title_x'].apply_figure(gori.figure, gori.theme.targets)
291+
for ax in gori.axs:
292+
gori.theme.themeables['axis_title_x'].apply_ax(ax)
280293

281-
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
294+
if 'axis_title_y' in gori.theme.themeables:
295+
gori.theme.themeables['axis_title_y'].apply_figure(gori.figure, gori.theme.targets)
296+
for ax in gori.axs:
297+
gori.theme.themeables['axis_title_y'].apply_ax(ax)
298+
299+
for key in gori.theme.themeables:
300+
if "legend" in key:
301+
gori.theme.themeables[key].apply_figure(gori.figure, gori.theme.targets)
302+
for ax in gori.axs:
303+
gori.theme.themeables[key].apply_ax(ax)
304+
305+
306+
elif StrictVersion(plotnine_version) >= StrictVersion("0.12"):
282307
gori.theme._targets['axis_title_x'] = xlabel
283308
gori.theme._targets['axis_title_y'] = ylabel
284309
if 'axis_title_x' in gori.theme.themeables:
@@ -415,8 +440,9 @@ def draw_title(bricks, gori, gcp, figsize):
415440
x = ha
416441
ha = "center"
417442

418-
except KeyError:
419-
ha = 0.5
443+
except Exception as e:
444+
x = 0.5
445+
ha = "center"
420446

421447
try:
422448
va = get_property('plot_title', 'va')
@@ -440,7 +466,13 @@ def draw_title(bricks, gori, gcp, figsize):
440466
else:
441467
text = bricks._case.set_title(title, pad=pad, fontsize=fontsize, x=x, ha=ha, va=va)
442468

443-
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
469+
if StrictVersion(plotnine_version) >= StrictVersion("0.13"):
470+
gori.theme.targets.plot_title = text
471+
gori.theme.themeables['plot_title'].apply_figure(gori.figure, gori.theme.targets)
472+
for ax in gori.axs:
473+
gori.theme.themeables['plot_title'].apply_ax(ax)
474+
475+
elif StrictVersion(plotnine_version) >= StrictVersion("0.12"):
444476
gori.theme._targets['plot_title'] = text
445477
gori.theme.themeables['plot_title'].apply_figure(gori.figure, gori.theme._targets)
446478
for ax in gori.axs:
@@ -487,15 +519,23 @@ def draw_title(bricks, gori, gcp, figsize):
487519
strips = []
488520

489521
ggplot._build()
490-
axs = ggplot.facet.make_axes(
491-
_basefigure,
492-
ggplot.layout.layout,
493-
ggplot.coordinates)
522+
523+
if StrictVersion(plotnine_version) >= StrictVersion("0.13"):
524+
ggplot.facet.make_figure = mp9.make_figure
525+
fig, axs = plotnine.facets.facet.setup(ggplot.facet, _basefigure, ggplot)
526+
else:
527+
axs = ggplot.facet.make_axes(
528+
_basefigure,
529+
ggplot.layout.layout,
530+
ggplot.coordinates)
494531

495532
ggplot.figure = _basefigure
496533
ggplot.axs = axs
497534

498-
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
535+
if StrictVersion(plotnine_version) >= StrictVersion("0.13"):
536+
ggplot.theme = gcp.theme
537+
538+
elif StrictVersion(plotnine_version) >= StrictVersion("0.12"):
499539
ggplot.theme = gcp.theme
500540
ggplot.theme._targets = gcp.theme._targets
501541

@@ -517,8 +557,12 @@ def draw_title(bricks, gori, gcp, figsize):
517557
ggplot.axs[i].spines[bar].set_ec(gcp.axs[i].spines[bar].get_ec())
518558
ggplot.axs[i].spines[bar].set_visible(gcp.axs[i].spines[bar].get_visible())
519559

520-
ggplot._setup_parameters()
521-
ggplot.facet.strips.generate()
560+
if StrictVersion(plotnine_version) >= StrictVersion("0.13"):
561+
ggplot.theme.setup(ggplot)
562+
else:
563+
ggplot._setup_parameters()
564+
ggplot.facet.strips.generate()
565+
522566
for i in range(len(ggplot.facet.strips)):
523567
if StrictVersion(plotnine_version) >= StrictVersion("0.12"):
524568
ggplot.facet.strips[i].position = strips[i].draw_info.position
@@ -550,6 +594,10 @@ def draw_title(bricks, gori, gcp, figsize):
550594
for i, l in enumerate(ggplot.layers, start=1):
551595
l.zorder = i + 10
552596
l.draw(ggplot.layout, ggplot.coordinates)
597+
598+
if StrictVersion(plotnine_version) >= StrictVersion("0.13"):
599+
ggplot._draw_panel_borders()
600+
ggplot.facet.theme = ggplot.theme
553601
ggplot._draw_breaks_and_labels()
554602
ggplot._draw_watermarks()
555603
new = themeable.from_class_name
@@ -583,7 +631,7 @@ def draw_title(bricks, gori, gcp, figsize):
583631

584632
if StrictVersion(plotnine_version) >= StrictVersion("0.9"):
585633
xl, yl = draw_labels(ax, ggplot, gcp, figsize)
586-
draw_legend(ax, ggplot, gcp, figsize)
634+
draw_legend(ax, ggplot, gcp, figsize) #0.13 makes Erros here.
587635
draw_title(ax, ggplot, gcp, figsize)
588636

589637
elif StrictVersion("0.8") <= StrictVersion(plotnine_version) < StrictVersion("0.9"):
@@ -2449,7 +2497,7 @@ def get_outer_corner(self):
24492497

24502498
return min(x0_list), max(x1_list), min(y0_list), max(y1_list)
24512499

2452-
def savefig(self, fname=None, transparent=None, quick=True, _ggplot=False, **kwargs):
2500+
def savefig(self, fname=None, transparent=False, quick=True, _ggplot=False, **kwargs):
24532501
"""
24542502
24552503
Save figure.
@@ -2586,13 +2634,13 @@ def __sub__(self, other):
25862634

25872635
def _repr_png_(self):
25882636
buf = io.BytesIO()
2589-
self.savefig(buf, "png")
2637+
self.savefig(buf, format="png", dpi=300, transparent=False)
25902638
return buf.getvalue()
25912639

2592-
def _repr_pdf_(self):
2593-
buf = io.BytesIO()
2594-
self.savefig(buf, "pdf")
2595-
return buf.getvalue()
2640+
#def _repr_pdf_(self):
2641+
# buf = io.BytesIO()
2642+
# self.savefig(buf, format="pdf", transparent=False)
2643+
# return buf.getvalue()
25962644

25972645
#class Brick(axes.Axes):
25982646
class pBrick:
@@ -2892,7 +2940,7 @@ def get_outer_corner(self, labes=None):
28922940
self._outer_flag = True
28932941
return self._outer_corner
28942942

2895-
def savefig(self, fname=None, transparent=None, quick=True, _ggplot=False, **kwargs):
2943+
def savefig(self, fname=None, transparent=False, quick=True, _ggplot=False, **kwargs):
28962944
"""
28972945
28982946
Save figure.
@@ -3078,13 +3126,13 @@ def __sub__(self, other):
30783126

30793127
def _repr_png_(self):
30803128
buf = io.BytesIO()
3081-
self.savefig(buf, "png")
3129+
self.savefig(buf, format="png", dpi=300, transparent=False)
30823130
return buf.getvalue()
30833131

3084-
def _repr_pdf_(self):
3085-
buf = io.BytesIO()
3086-
self.savefig(buf, "pdf")
3087-
return buf.getvalue()
3132+
#def _repr_pdf_(self):
3133+
# buf = io.BytesIO()
3134+
# self.savefig(buf, format="pdf", transparent=False)
3135+
# return buf.getvalue()
30883136

30893137
class Brick(pBrick, axes.Axes):
30903138
def __getattribute__(self, name):
@@ -3519,4 +3567,4 @@ def resize(self, direction):
35193567

35203568
if StrictVersion(plotnine.__version__) >= StrictVersion("0.12"):
35213569
overwrite_plotnine()
3522-
3570+

0 commit comments

Comments
 (0)