66from magicgui import magicgui
77
88from .base import NapariMPLWidget
9+ from .util import Interval
910
1011__all__ = ["ScatterWidget" , "FeaturesScatterWidget" ]
1112
@@ -84,7 +85,8 @@ class ScatterWidget(ScatterBaseWidget):
8485 of a scatter plot, to avoid too many scatter points.
8586 """
8687
87- n_layers_input = 2
88+ n_layers_input = Interval (2 , 2 )
89+ input_layer_types = (napari .layers .Image ,)
8890
8991 def _get_data (self ) -> Tuple [List [np .ndarray ], str , str ]:
9092 """Get the plot data.
@@ -106,7 +108,15 @@ def _get_data(self) -> Tuple[List[np.ndarray], str, str]:
106108
107109
108110class FeaturesScatterWidget (ScatterBaseWidget ):
109- n_layers_input = 1
111+ n_layers_input = Interval (1 , 1 )
112+ # All layers that have a .features attributes
113+ input_layer_types = (
114+ napari .layers .Labels ,
115+ napari .layers .Points ,
116+ napari .layers .Shapes ,
117+ napari .layers .Tracks ,
118+ napari .layers .Vectors ,
119+ )
110120
111121 def __init__ (self , napari_viewer : napari .viewer .Viewer ):
112122 super ().__init__ (napari_viewer )
@@ -146,7 +156,8 @@ def _set_axis_keys(self, x_axis_key: str, y_axis_key: str):
146156 self ._draw ()
147157
148158 def _get_valid_axis_keys (self , combo_widget = None ) -> List [str ]:
149- """Get the valid axis keys from the layer FeatureTable.
159+ """
160+ Get the valid axis keys from the layer FeatureTable.
150161
151162 Returns
152163 -------
0 commit comments