diff --git a/g3doc/quick_reference.md b/g3doc/quick_reference.md
index 8213f4ac0b..4a11b42e37 100644
--- a/g3doc/quick_reference.md
+++ b/g3doc/quick_reference.md
@@ -1054,6 +1054,15 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
     neither NaN nor infinity, i.e. normal, subnormal or zero. Equivalent to
     `Not(Or(IsNaN(v), IsInf(v)))`.
 
+#### Masked floating-point classification
+
+All ops in this section return `false` for `mask=false` lanes. These are
+equivalent to, and potentially more efficient than, `And(m, IsNaN(v))` etc.
+
+*   `V`: `{f}` \
+    <code>M **MaskedIsNaN**(M m, V v)</code>: returns mask indicating whether
+    `v[i]` is "not a number" (unordered), or `false` if `m[i]` is false.
+
 ### Logical
 
 *   `V`: `{u,i}` \
@@ -1532,6 +1541,29 @@ These return a mask (see above) indicating whether the condition is true.
     for comparing 64-bit keys alongside 64-bit values. Only available if
     `HWY_TARGET != HWY_SCALAR`.
 
+#### Masked comparison
+
+All ops in this section return `false` for `mask=false` lanes. These are
+equivalent to, and potentially more efficient than, `And(m, Eq(a, b))` etc.
+
+*   <code>M **MaskedEq**(M m, V a, V b)</code>: returns `a[i] == b[i]` or
+    `false` if `m[i]` is false.
+
+*   <code>M **MaskedNe**(M m, V a, V b)</code>: returns `a[i] != b[i]` or
+    `false` if `m[i]` is false.
+
+*   <code>M **MaskedLt**(M m, V a, V b)</code>: returns `a[i] < b[i]` or
+    `false` if `m[i]` is false.
+
+*   <code>M **MaskedGt**(M m, V a, V b)</code>: returns `a[i] > b[i]` or
+    `false` if `m[i]` is false.
+
+*   <code>M **MaskedLe**(M m, V a, V b)</code>: returns `a[i] <= b[i]` or
+    `false` if `m[i]` is false.
+
+*   <code>M **MaskedGe**(M m, V a, V b)</code>: returns `a[i] >= b[i]` or
+    `false` if `m[i]` is false.
+
 ### Memory
 
 Memory operands are little-endian, otherwise their order would depend on the
diff --git a/hwy/ops/arm_sve-inl.h b/hwy/ops/arm_sve-inl.h
index 4c4a37e5d4..e6566ab5eb 100644
--- a/hwy/ops/arm_sve-inl.h
+++ b/hwy/ops/arm_sve-inl.h
@@ -2138,6 +2138,45 @@ HWY_SVE_FOREACH_F(HWY_SVE_MUL_BY_POW2, MulByPow2, scale)
 
 #undef HWY_SVE_MUL_BY_POW2
 
+// ------------------------------ MaskedEq etc.
+#ifdef HWY_NATIVE_MASKED_COMP
+#undef HWY_NATIVE_MASKED_COMP
+#else
+#define HWY_NATIVE_MASKED_COMP
+#endif
+
+// mask = f(mask, vector, vector)
+#define HWY_SVE_COMPARE_Z(BASE, CHAR, BITS, HALF, NAME, OP)   \
+  HWY_API svbool_t NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a,  \
+                        HWY_SVE_V(BASE, BITS) b) {            \
+    return sv##OP##_##CHAR##BITS(m, a, b);                    \
+  }
+
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedEq, cmpeq)
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedNe, cmpne)
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLt, cmplt)
+HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLe, cmple)
+
+#undef HWY_SVE_COMPARE_Z
+
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedGt(M m, V a, V b) {
+  // Swap args to reverse the comparison.
+  return MaskedLt(m, b, a);
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedGe(M m, V a, V b) {
+  // Swap args to reverse the comparison.
+  return MaskedLe(m, b, a);
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedIsNaN(const M m, const V v) {
+  return MaskedNe(m, v, v);
+}
+
 // ================================================== MEMORY
 
 // ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream
diff --git a/hwy/ops/generic_ops-inl.h b/hwy/ops/generic_ops-inl.h
index 05ff70e0fb..fd2e7bdd26 100644
--- a/hwy/ops/generic_ops-inl.h
+++ b/hwy/ops/generic_ops-inl.h
@@ -631,6 +631,50 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
 }
 #endif  // HWY_NATIVE_MASKED_ARITH
 
+// ------------------------------ MaskedEq etc.
+#if (defined(HWY_NATIVE_MASKED_COMP) == defined(HWY_TARGET_TOGGLE))
+#ifdef HWY_NATIVE_MASKED_COMP
+#undef HWY_NATIVE_MASKED_COMP
+#else
+#define HWY_NATIVE_MASKED_COMP
+#endif
+
+template <class M, class V>
+HWY_API auto MaskedEq(M m, V a, V b) -> decltype(a == b) {
+  return And(m, Eq(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedNe(M m, V a, V b) -> decltype(a == b) {
+  return And(m, Ne(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedLt(M m, V a, V b) -> decltype(a == b) {
+  return And(m, Lt(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedGt(M m, V a, V b) -> decltype(a == b) {
+  return And(m, Gt(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedLe(M m, V a, V b) -> decltype(a == b) {
+  return And(m, Le(a, b));
+}
+
+template <class M, class V>
+HWY_API auto MaskedGe(M m, V a, V b) -> decltype(a == b) {
+  return And(m, Ge(a, b));
+}
+
+template <class V, class M = MFromD<DFromV<V>>>
+HWY_API MFromD<DFromV<V>> MaskedIsNaN(const M m, const V v) {
+  return And(m, IsNaN(v));
+}
+#endif  // HWY_NATIVE_MASKED_COMP
+
 // ------------------------------ IfNegativeThenNegOrUndefIfZero
 
 #if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \
diff --git a/hwy/tests/compare_test.cc b/hwy/tests/compare_test.cc
index 728b58c3dc..7f1827f35e 100644
--- a/hwy/tests/compare_test.cc
+++ b/hwy/tests/compare_test.cc
@@ -673,7 +673,160 @@ HWY_NOINLINE void TestAllEq128Upper() {
   ForGEVectors<128, TestEq128Upper>()(uint64_t());
 }
 
-}  // namespace
+struct TestMaskedCompare {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    RandomState rng;
+
+    const Vec<D> v0 = Zero(d);
+    const Vec<D> v2 = Iota(d, 2);
+    const Vec<D> v2b = Iota(d, 2);
+    const Vec<D> v3 = Iota(d, 3);
+    const size_t N = Lanes(d);
+
+    const Mask<D> mask_false = MaskFalse(d);
+    const Mask<D> mask_true = MaskTrue(d);
+
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedEq(mask_true, v2, v3));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedEq(mask_true, v3, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedEq(mask_true, v2, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedEq(mask_true, v2, v2b));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedNe(mask_true, v2, v3));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedNe(mask_true, v3, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedNe(mask_true, v2, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedNe(mask_true, v2, v2b));
+
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask_true, v2, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask_true, v0, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask_true, v0, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask_true, v0, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask_true, v2, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask_true, v2, v2));
+
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedGt(mask_true, v2, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedLt(mask_true, v0, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask_true, v2, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask_true, v0, v2));
+
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLe(mask_true, v2, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGe(mask_true, v0, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedLe(mask_true, v0, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedGe(mask_true, v0, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedLe(mask_true, v2, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedGe(mask_true, v2, v2));
+
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedGe(mask_true, v2, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedLe(mask_true, v0, v2));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLe(mask_true, v2, v0));
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGe(mask_true, v0, v2));
+
+    auto bool_lanes = AllocateAligned<T>(N);
+    HWY_ASSERT(bool_lanes);
+
+    for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
+      for (size_t i = 0; i < N; ++i) {
+        bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
+      }
+
+      const Vec<D> mask_i = Load(d, bool_lanes.get());
+      const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));
+
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedEq(mask, v2, v3));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedEq(mask, v3, v2));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedEq(mask, v2, v2));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedEq(mask, v2, v2b));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedNe(mask, v2, v3));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedNe(mask, v3, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedNe(mask, v2, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedNe(mask, v2, v2b));
+
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask, v2, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask, v0, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask, v0, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask, v0, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask, v2, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask, v2, v2));
+
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedGt(mask, v2, v0));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedLt(mask, v0, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask, v2, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask, v0, v2));
+
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLe(mask, v2, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGe(mask, v0, v2));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedLe(mask, v0, v0));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedGe(mask, v0, v0));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedLe(mask, v2, v2));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedGe(mask, v2, v2));
+
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedGe(mask, v2, v0));
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedLe(mask, v0, v2));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLe(mask, v2, v0));
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGe(mask, v0, v2));
+    }
+  }
+};
+
+HWY_NOINLINE void TestAllMaskedCompare() {
+  ForAllTypes(ForPartialVectors<TestMaskedCompare>());
+}
+
+struct TestMaskedFloatClassification {
+  template <typename T, class D>
+  HWY_NOINLINE void operator()(T /*unused*/, D d) {
+    RandomState rng;
+
+    const Vec<D> v0 = Zero(d);
+    const Vec<D> v1 = Iota(d, 2);
+    const Vec<D> v2 = Inf(d);
+    const Vec<D> v3 = NaN(d);
+    const size_t N = Lanes(d);
+
+    const Mask<D> mask_false = MaskFalse(d);
+    const Mask<D> mask_true = MaskTrue(d);
+
+    // Test against all zeros
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v0));
+
+    // Test against finite values
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v1));
+
+    // Test against infinite values
+    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask_true, v2));
+
+    // Test against NaN values
+    HWY_ASSERT_MASK_EQ(d, mask_true, MaskedIsNaN(mask_true, v3));
+
+    auto bool_lanes = AllocateAligned<T>(N);
+    HWY_ASSERT(bool_lanes);
+
+    for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
+      for (size_t i = 0; i < N; ++i) {
+        bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
+      }
+
+      const Vec<D> mask_i = Load(d, bool_lanes.get());
+      const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));
+
+      // Test against all zeros
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v0));
+
+      // Test against finite values
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v1));
+
+      // Test against infinite values
+      HWY_ASSERT_MASK_EQ(d, mask_false, MaskedIsNaN(mask, v2));
+
+      // Test against NaN values
+      HWY_ASSERT_MASK_EQ(d, mask, MaskedIsNaN(mask, v3));
+    }
+  }
+};
+
+HWY_NOINLINE void TestAllMaskedFloatClassification() {
+  ForFloatTypes(ForPartialVectors<TestMaskedFloatClassification>());
+}
+}  // namespace
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace hwy
@@ -695,6 +848,9 @@ HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128);
 HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128Upper);
 HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128);
 HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128Upper);
+
+HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedCompare);
+HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedFloatClassification);
 HWY_AFTER_TEST();
 }  // namespace
 }  // namespace hwy
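
For reviewers, a minimal usage sketch of how the new ops compose; it is not part of the patch. It assumes this patch is applied, uses the single-target `hn` alias pattern from the Highway README, and `CountValidBelow` is a hypothetical helper name invented for illustration. The point is that `MaskedLt` already returns `false` for lanes whose mask is false, so no separate `And` of the two masks is needed.

```cpp
#include <stddef.h>

#include "hwy/highway.h"

namespace hn = hwy::HWY_NAMESPACE;

// Hypothetical helper: counts lanes of in[0, count) that are not NaN and are
// below `limit`. Lanes excluded by `valid` never compare true in MaskedLt.
size_t CountValidBelow(const float* HWY_RESTRICT in, size_t count,
                       float limit) {
  const hn::ScalableTag<float> d;
  const size_t N = hn::Lanes(d);
  const auto vlimit = hn::Set(d, limit);
  size_t total = 0;
  size_t i = 0;
  for (; i + N <= count; i += N) {
    const auto v = hn::LoadU(d, in + i);
    const auto valid = hn::Not(hn::IsNaN(v));  // ordered (non-NaN) lanes
    total += hn::CountTrue(d, hn::MaskedLt(valid, v, vlimit));
  }
  for (; i < count; ++i) {  // scalar tail, kept simple for the sketch
    if (!(in[i] != in[i]) && in[i] < limit) ++total;
  }
  return total;
}
```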