1- from typing import List , Tuple , Union
1+ from typing import List , Optional , Tuple
22
33import matplotlib .colors as mcolor
44import napari
@@ -22,21 +22,21 @@ class ScatterBaseWidget(NapariMPLWidget):
2222 # the scatter is plotted as a 2dhist
2323 _threshold_to_switch_to_histogram = 500
2424
25- def __init__ (
26- self ,
27- napari_viewer : napari .viewer .Viewer ,
28- ):
25+ def __init__ (self , napari_viewer : napari .viewer .Viewer ):
2926 super ().__init__ (napari_viewer )
3027
3128 self .axes = self .canvas .figure .subplots ()
3229 self .update_layers (None )
3330
3431 def clear (self ) -> None :
32+ """
33+ Clear the axes.
34+ """
3535 self .axes .clear ()
3636
3737 def draw (self ) -> None :
3838 """
39- Clear the axes and scatter the currently selected layers.
39+ Scatter the currently selected layers.
4040 """
4141 data , x_axis_name , y_axis_name = self ._get_data ()
4242
@@ -86,14 +86,6 @@ class ScatterWidget(ScatterBaseWidget):
8686
8787 n_layers_input = 2
8888
89- def __init__ (
90- self ,
91- napari_viewer : napari .viewer .Viewer ,
92- ):
93- super ().__init__ (
94- napari_viewer ,
95- )
96-
9789 def _get_data (self ) -> Tuple [List [np .ndarray ], str , str ]:
9890 """Get the plot data.
9991
@@ -116,42 +108,34 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
116108class FeaturesScatterWidget (ScatterBaseWidget ):
117109 n_layers_input = 1
118110
119- def __init__ (
120- self ,
121- napari_viewer : napari .viewer .Viewer ,
122- key_selection_gui : bool = True ,
123- ):
124- self ._key_selection_widget = None
125- super ().__init__ (
126- napari_viewer ,
111+ def __init__ (self , napari_viewer : napari .viewer .Viewer ):
112+ super ().__init__ (napari_viewer )
113+ self ._key_selection_widget = magicgui (
114+ self ._set_axis_keys ,
115+ x_axis_key = {"choices" : self ._get_valid_axis_keys },
116+ y_axis_key = {"choices" : self ._get_valid_axis_keys },
117+ call_button = "plot" ,
127118 )
128119
129- if key_selection_gui is True :
130- self ._key_selection_widget = magicgui (
131- self ._set_axis_keys ,
132- x_axis_key = {"choices" : self ._get_valid_axis_keys },
133- y_axis_key = {"choices" : self ._get_valid_axis_keys },
134- call_button = "plot" ,
135- )
136- self .layout ().addWidget (self ._key_selection_widget .native )
120+ self .layout ().addWidget (self ._key_selection_widget .native )
137121
138122 @property
139- def x_axis_key (self ) -> Union [ None , str ]:
123+ def x_axis_key (self ) -> Optional [ str ]:
140124 """Key to access x axis data from the FeaturesTable"""
141125 return self ._x_axis_key
142126
143127 @x_axis_key .setter
144- def x_axis_key (self , key : Union [ None , str ]):
128+ def x_axis_key (self , key : Optional [ str ]):
145129 self ._x_axis_key = key
146130 self ._draw ()
147131
148132 @property
149- def y_axis_key (self ) -> Union [ None , str ]:
133+ def y_axis_key (self ) -> Optional [ str ]:
150134 """Key to access y axis data from the FeaturesTable"""
151135 return self ._y_axis_key
152136
153137 @y_axis_key .setter
154- def y_axis_key (self , key : Union [ None , str ]):
138+ def y_axis_key (self , key : Optional [ str ]):
155139 self ._y_axis_key = key
156140 self ._draw ()
157141
@@ -214,10 +198,11 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
214198 return data , x_axis_name , y_axis_name
215199
216200 def _on_update_layers (self ) -> None :
217- """This is called when the layer selection changes
218- by self.update_layers().
219201 """
220- if self ._key_selection_widget is not None :
202+ This is called when the layer selection changes by
203+ ``self.update_layers()``.
204+ """
205+ if hasattr (self , "_key_selection_widget" ):
221206 self ._key_selection_widget .reset_choices ()
222207
223208 # reset the axis keys
0 commit comments