Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,15 @@ Per-lane variable shifts (slow if SSSE3/SSE4, or 16-bit, or Shr i64 on AVX2):
neither NaN nor infinity, i.e. normal, subnormal or zero. Equivalent to
`Not(Or(IsNaN(v), IsInf(v)))`.

#### Masked floating-point classification

All ops in this section return `false` in lanes where `m[i]` is false. They are
equivalent to, and potentially more efficient than, `And(m, IsNaN(v))` etc.

* `V`: `{f}` \
<code>M **MaskedIsNaN**(M m, V v)</code>: returns mask indicating whether
`v[i]` is "not a number" (unordered) or `false` if `m[i]` is false.

### Logical

* `V`: `{u,i}` \
Expand Down Expand Up @@ -1532,6 +1541,29 @@ These return a mask (see above) indicating whether the condition is true.
for comparing 64-bit keys alongside 64-bit values. Only available if
`HWY_TARGET != HWY_SCALAR`.

#### Masked comparison

All ops in this section return `false` in lanes where `m[i]` is false. They are
equivalent to, and potentially more efficient than, `And(m, Eq(a, b))` etc.

* <code>M **MaskedEq**(M m, V a, V b)</code>: returns `a[i] == b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedNe**(M m, V a, V b)</code>: returns `a[i] != b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedLt**(M m, V a, V b)</code>: returns `a[i] < b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedGt**(M m, V a, V b)</code>: returns `a[i] > b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedLe**(M m, V a, V b)</code>: returns `a[i] <= b[i]` or
`false` if `m[i]` is false.

* <code>M **MaskedGe**(M m, V a, V b)</code>: returns `a[i] >= b[i]` or
`false` if `m[i]` is false.

### Memory

Memory operands are little-endian, otherwise their order would depend on the
Expand Down
39 changes: 39 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2138,6 +2138,45 @@ HWY_SVE_FOREACH_F(HWY_SVE_MUL_BY_POW2, MulByPow2, scale)

#undef HWY_SVE_MUL_BY_POW2

// ------------------------------ MaskedEq etc.
// Flip whether HWY_NATIVE_MASKED_COMP is defined. generic_ops-inl.h only
// compiles its And()-based MaskedEq etc. fallbacks while the definedness of
// this macro equals that of HWY_TARGET_TOGGLE, so flipping it here presumably
// suppresses those fallbacks in favor of the SVE versions that follow.
#ifdef HWY_NATIVE_MASKED_COMP
#undef HWY_NATIVE_MASKED_COMP
#else
#define HWY_NATIVE_MASKED_COMP
#endif

// mask = f(mask, vector, vector)
// Defines NAME(m, a, b) for one lane type by forwarding to the intrinsic
// sv<OP>_<CHAR><BITS>(m, a, b), i.e. `m` is passed through as the intrinsic's
// first (predicate) argument. The HALF parameter is part of the macro
// signature but is not used in the body.
#define HWY_SVE_COMPARE_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API svbool_t NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a, \
HWY_SVE_V(BASE, BITS) b) { \
return sv##OP##_##CHAR##BITS(m, a, b); \
}

// One overload set per op, generated for each lane type that HWY_SVE_FOREACH
// expands to. MaskedGt/MaskedGe are not generated here; they are implemented
// further down in terms of MaskedLt/MaskedLe.
HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedEq, cmpeq)
HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedNe, cmpne)
HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLt, cmplt)
HWY_SVE_FOREACH(HWY_SVE_COMPARE_Z, MaskedLe, cmple)

// The helper macro is only needed for the instantiations above.
#undef HWY_SVE_COMPARE_Z


// Masked `a[i] > b[i]`: lanes where `m` is false yield false.
template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedGt(M m, V a, V b) {
  // a > b is the same predicate as b < a, so reuse MaskedLt with the operands
  // exchanged.
  const MFromD<D> b_less_than_a = MaskedLt(m, b, a);
  return b_less_than_a;
}

// Masked `a[i] >= b[i]`: lanes where `m` is false yield false.
template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedGe(M m, V a, V b) {
  // a >= b is the same predicate as b <= a, so reuse MaskedLe with the
  // operands exchanged.
  const MFromD<D> b_not_above_a = MaskedLe(m, b, a);
  return b_not_above_a;
}

// Masked NaN test: true where `m` is true and `v[i]` is NaN, false elsewhere.
template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) {
  // NaN is the only value that compares unequal to itself, so a masked
  // self-inequality test identifies exactly the NaN lanes.
  const MFromD<D> differs_from_self = MaskedNe(m, v, v);
  return differs_from_self;
}

// ================================================== MEMORY

// ------------------------------ LoadU/MaskedLoad/LoadDup128/StoreU/Stream
Expand Down
44 changes: 44 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,50 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
}
#endif // HWY_NATIVE_MASKED_ARITH

// ------------------------------ MaskedEq etc.
// Generic fallbacks built from And() plus the corresponding unmasked op. This
// section is compiled only while HWY_NATIVE_MASKED_COMP and HWY_TARGET_TOGGLE
// are either both defined or both undefined; the flip below then changes the
// state of HWY_NATIVE_MASKED_COMP.
#if (defined(HWY_NATIVE_MASKED_COMP) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_MASKED_COMP
#undef HWY_NATIVE_MASKED_COMP
#else
#define HWY_NATIVE_MASKED_COMP
#endif

// Returns `a[i] == b[i]` in lanes where `m[i]` is true, otherwise false.
// The return type is spelled via MFromD (as MaskedIsNaN in this section does)
// rather than `decltype(a == b)`, so that it does not depend on operator==
// being usable for V and yielding the mask type.
template <class V, class M>
HWY_API MFromD<DFromV<V>> MaskedEq(M m, V a, V b) {
  return And(m, Eq(a, b));
}

// Returns `a[i] != b[i]` in lanes where `m[i]` is true, otherwise false.
// The return type is spelled via MFromD (as MaskedIsNaN in this section does)
// rather than `decltype(a == b)`, so that it does not depend on operator==
// being usable for V and yielding the mask type.
template <class V, class M>
HWY_API MFromD<DFromV<V>> MaskedNe(M m, V a, V b) {
  return And(m, Ne(a, b));
}

// Returns `a[i] < b[i]` in lanes where `m[i]` is true, otherwise false.
// The return type is spelled via MFromD (as MaskedIsNaN in this section does)
// rather than `decltype(a == b)`, so that it does not depend on operator==
// being usable for V and yielding the mask type.
template <class V, class M>
HWY_API MFromD<DFromV<V>> MaskedLt(M m, V a, V b) {
  return And(m, Lt(a, b));
}

// Returns `a[i] > b[i]` in lanes where `m[i]` is true, otherwise false.
// The return type is spelled via MFromD (as MaskedIsNaN in this section does)
// rather than `decltype(a == b)`, so that it does not depend on operator==
// being usable for V and yielding the mask type.
template <class V, class M>
HWY_API MFromD<DFromV<V>> MaskedGt(M m, V a, V b) {
  return And(m, Gt(a, b));
}

// Returns `a[i] <= b[i]` in lanes where `m[i]` is true, otherwise false.
// The return type is spelled via MFromD (as MaskedIsNaN in this section does)
// rather than `decltype(a == b)`, so that it does not depend on operator==
// being usable for V and yielding the mask type.
template <class V, class M>
HWY_API MFromD<DFromV<V>> MaskedLe(M m, V a, V b) {
  return And(m, Le(a, b));
}

// Returns `a[i] >= b[i]` in lanes where `m[i]` is true, otherwise false.
// The return type is spelled via MFromD (as MaskedIsNaN in this section does)
// rather than `decltype(a == b)`, so that it does not depend on operator==
// being usable for V and yielding the mask type.
template <class V, class M>
HWY_API MFromD<DFromV<V>> MaskedGe(M m, V a, V b) {
  return And(m, Ge(a, b));
}

// Generic masked NaN test: true where `m` is true and `v[i]` is NaN, false
// in all other lanes.
template <class V, class M, class D = DFromV<V>>
HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) {
  // Classify every lane first, then discard the lanes `m` excludes.
  const auto nan_lanes = IsNaN(v);
  return And(m, nan_lanes);
}
#endif // HWY_NATIVE_MASKED_COMP

// ------------------------------ IfNegativeThenNegOrUndefIfZero

#if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \
Expand Down
158 changes: 157 additions & 1 deletion hwy/tests/compare_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,160 @@ HWY_NOINLINE void TestAllEq128Upper() {
ForGEVectors<128, TestEq128Upper>()(uint64_t());
}

} // namespace
// Verifies MaskedEq/Ne/Lt/Gt/Le/Ge on fixed operand pairs, first with an
// all-true mask and then with randomly generated masks.
struct TestMaskedCompare {
  // Checks every masked comparison under `mask`: a comparison that holds in
  // all lanes must return exactly `mask`, one that holds in no lane must
  // return an all-false mask.
  template <typename T, class D>
  static HWY_NOINLINE void VerifyForMask(D d, const Mask<D> mask) {
    const Mask<D> mask_false = MaskFalse(d);

    const Vec<D> v0 = Zero(d);
    // Lane-varying operands for Eq/Ne and for self-comparisons. v2 and v3
    // differ by one in every lane, and v2b is an independent copy of v2.
    const Vec<D> v2 = Iota(d, 2);
    const Vec<D> v2b = Iota(d, 2);
    const Vec<D> v3 = Iota(d, 3);
    // Ordering checks against zero need an operand that is positive in every
    // lane. Iota(d, 2) is unsuitable for that: with 8-bit lanes and a
    // sufficiently long vector its values wrap around to zero or negative
    // numbers, so a constant is used instead.
    const Vec<D> vpos = Set(d, T(2));

    // Equality / inequality.
    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedEq(mask, v2, v3));
    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedEq(mask, v3, v2));
    HWY_ASSERT_MASK_EQ(d, mask, MaskedEq(mask, v2, v2));
    HWY_ASSERT_MASK_EQ(d, mask, MaskedEq(mask, v2, v2b));
    HWY_ASSERT_MASK_EQ(d, mask, MaskedNe(mask, v2, v3));
    HWY_ASSERT_MASK_EQ(d, mask, MaskedNe(mask, v3, v2));
    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedNe(mask, v2, v2));
    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedNe(mask, v2, v2b));

    // Strict ordering: false for reversed and for equal operands.
    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask, vpos, v0));
    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask, v0, vpos));
    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask, v0, v0));
    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask, v0, v0));
    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLt(mask, v2, v2));
    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGt(mask, v2, v2));

    HWY_ASSERT_MASK_EQ(d, mask, MaskedGt(mask, vpos, v0));
    HWY_ASSERT_MASK_EQ(d, mask, MaskedLt(mask, v0, vpos));

    // Non-strict ordering: false for reversed operands, true for equal ones.
    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedLe(mask, vpos, v0));
    HWY_ASSERT_MASK_EQ(d, mask_false, MaskedGe(mask, v0, vpos));
    HWY_ASSERT_MASK_EQ(d, mask, MaskedLe(mask, v0, v0));
    HWY_ASSERT_MASK_EQ(d, mask, MaskedGe(mask, v0, v0));
    HWY_ASSERT_MASK_EQ(d, mask, MaskedLe(mask, v2, v2));
    HWY_ASSERT_MASK_EQ(d, mask, MaskedGe(mask, v2, v2));

    HWY_ASSERT_MASK_EQ(d, mask, MaskedGe(mask, vpos, v0));
    HWY_ASSERT_MASK_EQ(d, mask, MaskedLe(mask, v0, vpos));
  }

  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    RandomState rng;
    const size_t N = Lanes(d);

    // With every lane enabled, the masked ops must match the unmasked result.
    VerifyForMask<T>(d, MaskTrue(d));

    auto bool_lanes = AllocateAligned<T>(N);
    HWY_ASSERT(bool_lanes);

    for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
      // Each lane is enabled or disabled according to one bit of a fresh
      // random number.
      for (size_t i = 0; i < N; ++i) {
        bool_lanes[i] = (Random32(&rng) & 1024) ? T(1) : T(0);
      }

      const Vec<D> mask_i = Load(d, bool_lanes.get());
      const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));

      VerifyForMask<T>(d, mask);
    }
  }
};

// Runs TestMaskedCompare for all lane types and for partial vectors.
HWY_NOINLINE void TestAllMaskedCompare() {
  const ForPartialVectors<TestMaskedCompare> for_each_vector{};
  ForAllTypes(for_each_vector);
}

// Verifies MaskedIsNaN on zero, finite, infinite and NaN inputs, first with
// an all-true mask and then with randomly generated masks.
struct TestMaskedFloatClassification {
  // Checks MaskedIsNaN under `mask`: only the all-NaN input may produce true
  // lanes, and then exactly those enabled by `mask`.
  template <class D>
  static HWY_NOINLINE void VerifyForMask(D d, const Mask<D> mask) {
    const Mask<D> none = MaskFalse(d);

    const Vec<D> zeros = Zero(d);
    const Vec<D> finite = Iota(d, 2);
    const Vec<D> infinite = Inf(d);
    const Vec<D> nans = NaN(d);

    // Zero, finite and infinite values are never NaN.
    HWY_ASSERT_MASK_EQ(d, none, MaskedIsNaN(mask, zeros));
    HWY_ASSERT_MASK_EQ(d, none, MaskedIsNaN(mask, finite));
    HWY_ASSERT_MASK_EQ(d, none, MaskedIsNaN(mask, infinite));

    // NaN input: the result is the mask itself.
    HWY_ASSERT_MASK_EQ(d, mask, MaskedIsNaN(mask, nans));
  }

  template <typename T, class D>
  HWY_NOINLINE void operator()(T /*unused*/, D d) {
    RandomState rng;
    const size_t N = Lanes(d);

    // All lanes enabled.
    VerifyForMask(d, MaskTrue(d));

    auto bool_lanes = AllocateAligned<T>(N);
    HWY_ASSERT(bool_lanes);

    for (size_t rep = 0; rep < AdjustedReps(200); ++rep) {
      // Enable or disable each lane according to one bit of a fresh random
      // number.
      for (size_t lane = 0; lane < N; ++lane) {
        bool_lanes[lane] = (Random32(&rng) & 1024) ? T(1) : T(0);
      }

      const Vec<D> mask_i = Load(d, bool_lanes.get());
      const Mask<D> mask = RebindMask(d, Gt(mask_i, Zero(d)));

      VerifyForMask(d, mask);
    }
  }
};

// Runs TestMaskedFloatClassification for floating-point lane types and for
// partial vectors.
HWY_NOINLINE void TestAllMaskedFloatClassification() {
  const ForPartialVectors<TestMaskedFloatClassification> for_each_vector{};
  ForFloatTypes(for_each_vector);
}
} // namespace
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace hwy
Expand All @@ -695,6 +848,9 @@ HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllLt128Upper);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllEq128Upper);

HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedCompare);
HWY_EXPORT_AND_TEST_P(HwyCompareTest, TestAllMaskedFloatClassification);
HWY_AFTER_TEST();
} // namespace
} // namespace hwy
Expand Down
Loading