diff --git a/g3doc/quick_reference.md b/g3doc/quick_reference.md
index 8213f4ac0b..4a11b42e37 100644
--- a/g3doc/quick_reference.md
+++ b/g3doc/quick_reference.md
@@ -1054,6 +1054,15 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
neither NaN nor infinity, i.e. normal, subnormal or zero. Equivalent to
`Not(Or(IsNaN(v), IsInf(v)))`.
+#### Masked floating-point classification
+
+All ops in this section return `false` for `mask=false` lanes. These are
+equivalent to, and potentially more efficient than, `And(m, IsNaN(v));` etc.
+
+* `V`: `{f}` \
+  <code>M **MaskedIsNaN**(M m, V v)</code>: returns mask indicating whether
+  `v[i]` is "not a number" (unordered), or `false` if `m[i]` is false.
+
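+A minimal usage sketch (the pointer `in` and the mask `selected` are
+illustrative, not part of the op):
+
+```c++
+const ScalableTag<float> d;
+const auto v = LoadU(d, in);
+// Replace NaN lanes with zero, but only within the selected lanes;
+// unselected lanes, even NaN ones, pass through unchanged.
+const auto cleaned = IfThenElse(MaskedIsNaN(selected, v), Zero(d), v);
+```
+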
### Logical
* `V`: `{u,i}` \
@@ -1532,6 +1541,29 @@ These return a mask (see above) indicating whether the condition is true.
for comparing 64-bit keys alongside 64-bit values. Only available if
`HWY_TARGET != HWY_SCALAR`.
+#### Masked comparison
+
+All ops in this section return `false` for `mask=false` lanes. These are
+equivalent to, and potentially more efficient than, `And(m, Eq(a, b));` etc.
+
+* <code>M **MaskedEq**(M m, V a, V b)</code>: returns `a[i] == b[i]` or
+  `false` if `m[i]` is false.
+
+* <code>M **MaskedNe**(M m, V a, V b)</code>: returns `a[i] != b[i]` or
+  `false` if `m[i]` is false.
+
+* <code>M **MaskedLt**(M m, V a, V b)</code>: returns `a[i] < b[i]` or
+  `false` if `m[i]` is false.
+
+* <code>M **MaskedGt**(M m, V a, V b)</code>: returns `a[i] > b[i]` or
+  `false` if `m[i]` is false.
+
+* <code>M **MaskedLe**(M m, V a, V b)</code>: returns `a[i] <= b[i]` or
+  `false` if `m[i]` is false.
+
+* <code>M **MaskedGe**(M m, V a, V b)</code>: returns `a[i] >= b[i]` or
+  `false` if `m[i]` is false.
+
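+A minimal usage sketch (the names `weights`, `values`, `num` and `threshold`
+are illustrative; remainder handling is omitted):
+
+```c++
+const ScalableTag<float> d;
+size_t count = 0;
+for (size_t i = 0; i + Lanes(d) <= num; i += Lanes(d)) {
+  const auto w = LoadU(d, weights + i);
+  const auto x = LoadU(d, values + i);
+  const auto valid = Ne(w, Zero(d));
+  // Lanes that are both valid and below the threshold. Equivalent to
+  // And(valid, Lt(x, threshold)), but may lower to a single predicated
+  // compare on targets such as SVE.
+  count += CountTrue(d, MaskedLt(valid, x, Set(d, threshold)));
+}
+```
+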
### Memory
Memory operands are little-endian, otherwise their order would depend on the
diff --git a/hwy/ops/arm_sve-inl.h b/hwy/ops/arm_sve-inl.h
index 4c4a37e5d4..e6566ab5eb 100644
--- a/hwy/ops/arm_sve-inl.h
+++ b/hwy/ops/arm_sve-inl.h
@@ -2138,6 +2138,45 @@ HWY_SVE_FOREACH_F(HWY_SVE_MUL_BY_POW2, MulByPow2, scale)
#undef HWY_SVE_MUL_BY_POW2
+// ------------------------------ MaskedEq etc.
+#ifdef HWY_NATIVE_MASKED_COMP
+#undef HWY_NATIVE_MASKED_COMP
+#else
+#define HWY_NATIVE_MASKED_COMP
+#endif
+
+// mask = f(mask, vector, vector)
+#define HWY_SVE_COMPARE_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
+ HWY_API svbool_t NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, \
+ HWY_SVE_V(BASE, BITS) b) { \
+ return sv##OP##_##CHAR##BITS(m, a, b); \
+ }
+
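+// For i32 lanes, for example, HWY_SVE_FOREACH below is expected to expand
+// this macro roughly to (sketch, not the exact generated text):
+//   HWY_API svbool_t MaskedEq(svbool_t m, svint32_t a, svint32_t b) {
+//     return svcmpeq_s32(m, a, b);  // inactive lanes of m yield false
+//   }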
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedEq, cmpeq)
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedNe, cmpne)
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLt, cmplt)
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLe, cmple)
+
+#undef HWY_SVE_COMPARE_Z
+
+template <class V, class M>
+HWY_API MFromD<DFromV<V>> MaskedGt(M m, V a, V b) {
+ // Swap args to reverse comparison
+ return MaskedLt(m, b, a);
+}
+
+template <class V, class M>
+HWY_API MFromD<DFromV<V>> MaskedGe(M m, V a, V b) {
+ // Swap args to reverse comparison
+ return MaskedLe(m, b, a);
+}
+
+template <class V, class M>
+HWY_API MFromD<DFromV<V>> MaskedIsNaN(const M m, const V v) {
+ return MaskedNe(m, v, v);
+}
+
// ================================================== MEMORY
// ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream
diff --git a/hwy/ops/generic_ops-inl.h b/hwy/ops/generic_ops-inl.h
index 05ff70e0fb..fd2e7bdd26 100644
--- a/hwy/ops/generic_ops-inl.h
+++ b/hwy/ops/generic_ops-inl.h
@@ -631,6 +631,50 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
}
#endif // HWY_NATIVE_MASKED_ARITH
+// ------------------------------ MaskedEq etc.
+#if (defined(HWY_NATIVE_MASKED_COMP) == defined(HWY_TARGET_TOGGLE))
+#ifdef HWY_NATIVE_MASKED_COMP
+#undef HWY_NATIVE_MASKED_COMP
+#else
+#define HWY_NATIVE_MASKED_COMP
+#endif
+
+template <class V, class M>
+HWY_API auto MaskedEq(M m, V a, V b) -> decltype(a == b) {
+ return And(m, Eq(a, b));
+}
+
+template <class V, class M>
+HWY_API auto MaskedNe(M m, V a, V b) -> decltype(a == b) {
+ return And(m, Ne(a, b));
+}
+
+template <class V, class M>
+HWY_API auto MaskedLt(M m, V a, V b) -> decltype(a == b) {
+ return And(m, Lt(a, b));
+}
+
+template <class V, class M>
+HWY_API auto MaskedGt(M m, V a, V b) -> decltype(a == b) {
+ return And(m, Gt(a, b));
+}
+
+template <class V, class M>
+HWY_API auto MaskedLe(M m, V a, V b) -> decltype(a == b) {
+ return And(m, Le(a, b));
+}
+
+template <class V, class M>
+HWY_API auto MaskedGe(M m, V a, V b) -> decltype(a == b) {
+ return And(m, Ge(a, b));
+}
+
+template <class V, class M>
+HWY_API MFromD<DFromV<V>> MaskedIsNaN(const M m, const V v) {
+ return And(m, IsNaN(v));
+}
+#endif // HWY_NATIVE_MASKED_COMP
+
// ------------------------------ IfNegativeThenNegOrUndefIfZero
#if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \
diff --git a/hwy/tests/compare_test.cc b/hwy/tests/compare_test.cc
index 728b58c3dc..7f1827f35e 100644
--- a/hwy/tests/compare_test.cc
+++ b/hwy/tests/compare_test.cc
@@ -673,7 +673,160 @@ HWY_NOINLINE void TestAllEq128Upper() {
ForGEVectors<128, TestEq128Upper>()(uint64_t());
}
-} // namespace
+struct TestMaskedCompare {
+ template <typename T, class D>
+ HWY_NOINLINE void operator()(T /*unused*/, D d) {
+ RandomState rng;
+
+ const Vec<D> v0 = Zero(d);
+ const Vec<D> v2 = Iota(d, 2);
+ const Vec<D> v2b = Iota(d, 2);
+ const Vec<D> v3 = Iota(d, 3);
+ const size_t N = Lanes(d);
+
+ const Mask<D> mask_false = MaskFalse(d);
+ const Mask<D> mask_true = MaskTrue(d);
+
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedEq(mask_true, v2, v3));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedEq(mask_true, v3, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedEq(mask_true, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedEq(mask_true, v2, v2b));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedNe(mask_true, v2, v3));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedNe(mask_true, v3, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedNe(mask_true, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedNe(mask_true, v2, v2b));
+
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask_true, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask_true, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask_true, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask_true, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask_true, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask_true, v2, v2));
+
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedGt(mask_true, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedLt(mask_true, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask_true, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask_true, v0, v2));
+
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLe(mask_true, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGe(mask_true, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedLe(mask_true, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedGe(mask_true, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedLe(mask_true, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedGe(mask_true, v2, v2));
+
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedGe(mask_true, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedLe(mask_true, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLe(mask_true, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGe(mask_true, v0, v2));
+
+ auto bool_lanes = AllocateAligned<T>(N);
+ HWY_ASSERT(bool_lanes);
+
+ for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
+ for (size_t i = 0; i < N; ++i) {
+ bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
+ }
+
+ const Vec<D> mask_i = Load(d, bool_lanes.get());
+ const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));
+
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedEq(mask, v2, v3));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedEq(mask, v3, v2));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedEq(mask, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedEq(mask, v2, v2b));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedNe(mask, v2, v3));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedNe(mask, v3, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedNe(mask, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedNe(mask, v2, v2b));
+
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask, v2, v2));
+
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedGt(mask, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedLt(mask, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask, v0, v2));
+
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLe(mask, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGe(mask, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedLe(mask, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedGe(mask, v0, v0));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedLe(mask, v2, v2));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedGe(mask, v2, v2));
+
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedGe(mask, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedLe(mask, v0, v2));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLe(mask, v2, v0));
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGe(mask, v0, v2));
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllMaskedCompare() {
+ ForAllTypes(ForPartialVectors<TestMaskedCompare>());
+}
+
+struct TestMaskedFloatClassification {
+ template <typename T, class D>
+ HWY_NOINLINE void operator()(T /*unused*/, D d) {
+ RandomState rng;
+
+ const Vec<D> v0 = Zero(d);
+ const Vec<D> v1 = Iota(d, 2);
+ const Vec<D> v2 = Inf(d);
+ const Vec<D> v3 = NaN(d);
+ const size_t N = Lanes(d);
+
+ const Mask<D> mask_false = MaskFalse(d);
+ const Mask<D> mask_true = MaskTrue(d);
+
+ // Test against all zeros
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v0));
+
+ // Test against finite values
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v1));
+
+ // Test against infinite values
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v2));
+
+ // Test against NaN values
+ HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsNaN(mask_true, v3));
+
+ auto bool_lanes = AllocateAligned<T>(N);
+ HWY_ASSERT(bool_lanes);
+
+ for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
+ for (size_t i = 0; i < N; ++i) {
+ bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
+ }
+
+ const Vec<D> mask_i = Load(d, bool_lanes.get());
+ const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));
+
+ // Test against all zeros
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v0));
+
+ // Test against finite values
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v1));
+
+ // Test against infinite values
+ HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v2));
+
+ // Test against NaN values
+ HWY_ASSERT_MASK_EQ(d, mask, MaskedIsNaN(mask, v3));
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllMaskedFloatClassification() {
+ ForFloatTypes(ForPartialVectors<TestMaskedFloatClassification>());
+}
+} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
@@ -695,6 +848,9 @@ HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128Upper);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128Upper);
+
+HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedCompare);
+HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedFloatClassification);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy