diff --git a/neural_tangents/_src/stax/requirements.py b/neural_tangents/_src/stax/requirements.py index 0858cc5f..d6e3e2c0 100644 --- a/neural_tangents/_src/stax/requirements.py +++ b/neural_tangents/_src/stax/requirements.py @@ -764,6 +764,11 @@ def get_x_cov_mask(x): x = _get_masked_array(x, mask_constant) x, mask = x.masked_value, x.mask + # reduce mask + if mask_constant and mask.shape[channel_axis] > 1: + warnings.warn("Assuming consistent masks (all zero or one) for features of dimension > 1, which is not verified.") + mask = np.any(mask, axis=channel_axis, keepdims=True) + # TODO(schsam): Think more about dtype automatic vs manual dtype promotion. x = x.astype(jax.dtypes.canonicalize_dtype(np.float64))