Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Float operations SqrtLower, MulSubAdd, GetExponent etc. #2425

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions g3doc/quick_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,10 @@ from left to right, of the arguments passed to `Create{2-4}`.
<code>V **ApproximateReciprocal**(V a)</code>: returns an approximation of
`1.0 / a[i]`.

* `V`: `{f}` \
<code>V **GetExponent**(V v)</code>: returns the unbiased exponent of `v[i]`
as a floating-point value. Essentially calculates `floor(log2(|v[i]|))`.

#### Min/Max

**Note**: Min/Max corner cases are target-specific and may change. If either
Expand Down Expand Up @@ -864,6 +868,10 @@ variants are somewhat slower on Arm, and unavailable for integer inputs; if the
c))` or `MulAddSub(a, b, OddEven(c, Neg(c)))`, but `MulSub(a, b, c)` is more
efficient on some targets (including AVX2/AVX3).

* <code>V **MulSubAdd**(V a, V b, V c)</code>: returns `a[i] * b[i] + c[i]` in
the even lanes and `a[i] * b[i] - c[i]` in the odd lanes. Essentially
`MulAddSub(a, b, c)` with `c[i]` negated.

* `V`: `bf16`, `D`: `RepartitionToWide<DFromV<V>>`, `VW`: `Vec<D>` \
<code>VW **MulEvenAdd**(D d, V a, V b, VW c)</code>: equivalent to and
potentially more efficient than `MulAdd(PromoteEvenTo(d, a),
Expand All @@ -881,6 +889,9 @@ exceptions for those lanes if that is supported by the ISA. When exceptions are
not a concern, these are equivalent to, and potentially more efficient than,
`IfThenElse(m, Add(a, b), no);` etc.

* `V`: `{f}` \
<code>V **MaskedSqrtOr**(V no, M m, V a)</code>: returns `sqrt(a[i])` or
`no[i]` if `m[i]` is false.
* <code>V **MaskedMinOr**(V no, M m, V a, V b)</code>: returns `Min(a, b)[i]`
or `no[i]` if `m[i]` is false.
* <code>V **MaskedMaxOr**(V no, M m, V a, V b)</code>: returns `Max(a, b)[i]`
Expand All @@ -905,6 +916,21 @@ not a concern, these are equivalent to, and potentially more efficient than,
b[i]` saturated to the minimum/maximum representable value, or `no[i]` if
`m[i]` is false.

#### Zero masked arithmetic

All ops in this section return `0` for `mask=false` lanes. These are equivalent
to, and potentially more efficient than, `IfThenElseZero(m, Add(a, b));` etc.

* `V`: `{f}` \
<code>V **MaskedSqrt**(M m, V a)</code>: returns `sqrt(a[i])` where
`m[i]` is true, and zero otherwise.
* `V`: `{f}` \
<code>V **MaskedApproximateReciprocalSqrt**(M m, V a)</code>: returns
`ApproximateReciprocalSqrt(a)[i]` where `m[i]` is true, and zero otherwise.
* `V`: `{f}` \
<code>V **MaskedApproximateReciprocal**(M m, V a)</code>: returns
`ApproximateReciprocal(a)[i]` where `m[i]` is true, and zero otherwise.

#### Shifts

**Note**: Counts not in `[0, sizeof(T)*8)` yield implementation-defined results.
Expand Down
53 changes: 53 additions & 0 deletions hwy/ops/arm_sve-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,14 @@ HWY_SVE_FOREACH_BF16_UNCONDITIONAL(HWY_SPECIALIZE, _, _)
HWY_API HWY_SVE_V(BASE, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS(v); \
}
// vector = f(mask, vector), e.g. detail::MaskedSqrt. Expands to a call of the
// `_x`-suffixed intrinsic with the mask as first argument.
// NOTE(review): per ACLE naming, `_x` presumably leaves mask=false lanes
// unspecified - callers should not rely on their values; confirm.
#define HWY_SVE_RETV_ARGMV(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS##_x(m, v); \
}
// vector = f(mask, vector), e.g. MaskedSqrt. Expands to a call of the
// `_z`-suffixed intrinsic with the mask as first argument.
// NOTE(review): per ACLE naming, `_z` presumably zeroes mask=false lanes,
// which is what the zero-masked ops are documented to return; confirm.
#define HWY_SVE_RETV_ARGMV_Z(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(BASE, BITS) NAME(svbool_t m, HWY_SVE_V(BASE, BITS) a) { \
return sv##OP##_##CHAR##BITS##_z(m, a); \
}

// vector = f(vector, scalar), e.g. detail::AddN
#define HWY_SVE_RETV_ARGPVN(BASE, CHAR, BITS, HALF, NAME, OP) \
Expand Down Expand Up @@ -1234,6 +1242,15 @@ HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGV, ApproximateReciprocal, recpe)
// ------------------------------ Sqrt
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPV, Sqrt, sqrt)

// ------------------------------ MaskedSqrt
// Toggling HWY_NATIVE_MASKED_SQRT marks this target as supplying its own
// MaskedSqrt/MaskedSqrtOr; generic_ops-inl.h compares the macro against
// HWY_TARGET_TOGGLE before emitting its fallback versions.
#ifdef HWY_NATIVE_MASKED_SQRT
#undef HWY_NATIVE_MASKED_SQRT
#else
#define HWY_NATIVE_MASKED_SQRT
#endif

// Instantiates MaskedSqrt(m, v) through the `_z` sqrt intrinsic wrapper.
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV_Z, MaskedSqrt, sqrt)

// ------------------------------ ApproximateReciprocalSqrt
#ifdef HWY_NATIVE_F64_APPROX_RSQRT
#undef HWY_NATIVE_F64_APPROX_RSQRT
Expand Down Expand Up @@ -1521,6 +1538,7 @@ HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedDiv, div)
HWY_SVE_FOREACH_UI32(HWY_SVE_RETV_ARGMVV, MaskedDiv, div)
HWY_SVE_FOREACH_UI64(HWY_SVE_RETV_ARGMVV, MaskedDiv, div)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMV, MaskedSqrt, sqrt)
#if HWY_SVE_HAVE_2
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatAdd, qadd)
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedSatSub, qsub)
Expand Down Expand Up @@ -1584,6 +1602,11 @@ HWY_API V MaskedSatSubOr(V no, M m, V a, V b) {
}
#endif

// Returns sqrt(v) in lanes where `m` is true and the corresponding lane of
// `no` elsewhere. The square root itself is evaluated under the same mask via
// detail::MaskedSqrt, then blended with `no`.
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedSqrtOr(V no, M m, V v) {
  const V masked_root = detail::MaskedSqrt(m, v);
  return IfThenElse(m, masked_root, no);
}

// ================================================== REDUCE

#ifdef HWY_NATIVE_REDUCE_SCALAR
Expand Down Expand Up @@ -3094,6 +3117,34 @@ HWY_API VFromD<D> Iota(const D d, T2 first) {
ConvertScalarTo<TFromD<D>>(first));
}

// ------------------------------ GetExponent

#if HWY_SVE_HAVE_2 || HWY_IDE
#ifdef HWY_NATIVE_GET_EXPONENT
#undef HWY_NATIVE_GET_EXPONENT
#else
#define HWY_NATIVE_GET_EXPONENT
#endif

namespace detail {
// Defines NAME(v) for a float vector `v`: calls the `_x` form of intrinsic OP
// under the HWY_SVE_PTRUE(BITS) predicate and returns a signed-integer vector
// with the same lane width (BITS) as the input.
#define HWY_SVE_GET_EXP(BASE, CHAR, BITS, HALF, NAME, OP) \
HWY_API HWY_SVE_V(int, BITS) NAME(HWY_SVE_V(BASE, BITS) v) { \
return sv##OP##_##CHAR##BITS##_x(HWY_SVE_PTRUE(BITS), v); \
}
// Instantiate detail::GetExponent from the `logb` intrinsic for float types.
HWY_SVE_FOREACH_F(HWY_SVE_GET_EXP, GetExponent, logb)
// The helper macro is only needed for the instantiation above.
#undef HWY_SVE_GET_EXP
} // namespace detail

// Returns the exponent of each lane of `v` as a value of the same
// floating-point type. detail::GetExponent produces the exponent in
// signed-integer lanes of the same width; ConvertTo maps it back to float.
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V GetExponent(V v) {
  const DFromV<V> d;
  return ConvertTo(d, detail::GetExponent(v));
}
#endif // HWY_SVE_HAVE_2

// ------------------------------ InterleaveLower

template <class D, class V>
Expand Down Expand Up @@ -6352,6 +6403,8 @@ HWY_API V HighestSetBitIndex(V v) {
#undef HWY_SVE_IF_NOT_EMULATED_D
#undef HWY_SVE_PTRUE
#undef HWY_SVE_RETV_ARGMVV
#undef HWY_SVE_RETV_ARGMV_Z
#undef HWY_SVE_RETV_ARGMV
#undef HWY_SVE_RETV_ARGPV
#undef HWY_SVE_RETV_ARGPVN
#undef HWY_SVE_RETV_ARGPVV
Expand Down
73 changes: 73 additions & 0 deletions hwy/ops/generic_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1175,6 +1175,34 @@ HWY_API V MulByFloorPow2(V v, V exp) {

#endif // HWY_NATIVE_MUL_BY_POW2

// ------------------------------ GetExponent

#if (defined(HWY_NATIVE_GET_EXPONENT) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_GET_EXPONENT
#undef HWY_NATIVE_GET_EXPONENT
#else
#define HWY_NATIVE_GET_EXPONENT
#endif

// Generic GetExponent: reads the exponent field straight out of the bit
// pattern of each lane and returns it, with the bias removed, as a value of
// the original floating-point type.
// NOTE(review): because only the exponent field is inspected, zero/subnormal
// and inf/NaN inputs presumably yield the field's extreme values rather than
// a true floor(log2) - confirm whether callers depend on that.
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V GetExponent(V v) {
  using T = TFromV<V>;
  const DFromV<V> d;
  const RebindToSigned<decltype(d)> di;
  const RebindToUnsigned<decltype(d)> du;

  constexpr uint8_t kMantissaBits = MantissaBits<T>();

  // Abs clears the sign so that the logical right shift leaves only the
  // (still biased) exponent field in the low bits.
  const auto biased_bits = ShiftRight<kMantissaBits>(BitCast(du, Abs(v)));

  // Remove the bias (half of the maximum exponent field) in signed arithmetic.
  const auto bias = Set(di, MaxExponentField<T>() >> 1);
  const auto unbiased = Sub(BitCast(di, biased_bits), bias);

  // Convert the integer exponent back to the caller's float type.
  return ConvertTo(d, unbiased);
}

#endif // HWY_NATIVE_GET_EXPONENT
// ------------------------------ LoadInterleaved2

#if HWY_IDE || \
Expand Down Expand Up @@ -4409,6 +4437,19 @@ HWY_API V MulAddSub(V mul, V x, V sub_or_add) {
OddEven(sub_or_add, BitCast(d, Neg(BitCast(d_negate, sub_or_add))));
return MulAdd(mul, x, add);
}
// ------------------------------ MulSubAdd

// Forwards to MulAddSub with the third operand negated, which swaps which
// lanes add and which subtract `sub_or_add` relative to MulAddSub.
template <class V>
HWY_API V MulSubAdd(V mul, V x, V sub_or_add) {
  using T = TFromV<V>;
  // Neg is applied in a signed lane type: a T for which IsSigned is false is
  // reinterpreted as MakeSigned<T>; any other T is used unchanged.
  using TSigned = If<!IsSigned<T>(), MakeSigned<T>, T>;

  const DFromV<V> d;
  const Rebind<TSigned, DFromV<V>> d_signed;

  const V negated = BitCast(d, Neg(BitCast(d_signed, sub_or_add)));
  return MulAddSub(mul, x, negated);
}

// ------------------------------ Integer division
#if (defined(HWY_NATIVE_INT_DIV) == defined(HWY_TARGET_TOGGLE))
Expand Down Expand Up @@ -5234,6 +5275,26 @@ HWY_API VFromD<DI32> SatWidenMulAccumFixedPoint(DI32 di32,

#endif // HWY_NATIVE_I16_SATWIDENMULACCUMFIXEDPOINT

// ------------------------------ MaskedSqrt

#if (defined(HWY_NATIVE_MASKED_SQRT) == defined(HWY_TARGET_TOGGLE))

#ifdef HWY_NATIVE_MASKED_SQRT
#undef HWY_NATIVE_MASKED_SQRT
#else
#define HWY_NATIVE_MASKED_SQRT
#endif
// Generic fallback: computes Sqrt(v) for all lanes, then zeroes the lanes
// where `m` is false.
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedSqrt(M m, V v) {
  const V root = Sqrt(v);
  return IfThenElseZero(m, root);
}

// Generic fallback: computes Sqrt(v) for all lanes, then takes the lane from
// `no` wherever `m` is false.
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedSqrtOr(V no, M m, V v) {
  const V root = Sqrt(v);
  return IfThenElse(m, root, no);
}
#endif

// ------------------------------ SumOfMulQuadAccumulate

#if (defined(HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE) == \
Expand Down Expand Up @@ -5418,6 +5479,12 @@ HWY_API V ApproximateReciprocal(V v) {

#endif // HWY_NATIVE_F64_APPROX_RECIP

// ------------------------------ MaskedApproximateReciprocal
// Computes ApproximateReciprocal(v) for all lanes, then zeroes the lanes
// where `m` is false.
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedApproximateReciprocal(M m, V v) {
  const V approx_recip = ApproximateReciprocal(v);
  return IfThenElseZero(m, approx_recip);
}

// ------------------------------ F64 ApproximateReciprocalSqrt

#if (defined(HWY_NATIVE_F64_APPROX_RSQRT) == defined(HWY_TARGET_TOGGLE))
Expand All @@ -5443,6 +5510,12 @@ HWY_API V ApproximateReciprocalSqrt(V v) {

#endif // HWY_NATIVE_F64_APPROX_RSQRT

// ------------------------------ MaskedApproximateReciprocalSqrt
// Computes ApproximateReciprocalSqrt(v) for all lanes, then zeroes the lanes
// where `m` is false.
template <class V, HWY_IF_FLOAT_V(V), class M>
HWY_API V MaskedApproximateReciprocalSqrt(M m, V v) {
  const V approx_rsqrt = ApproximateReciprocalSqrt(v);
  return IfThenElseZero(m, approx_rsqrt);
}

// ------------------------------ Compress*

#if (defined(HWY_NATIVE_COMPRESS8) == defined(HWY_TARGET_TOGGLE))
Expand Down
3 changes: 3 additions & 0 deletions hwy/ops/ppc_vsx-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1939,6 +1939,9 @@ HWY_API Vec128<T, N> ApproximateReciprocal(Vec128<T, N> v) {
#endif
}

// TODO: Implement GetExponent using vec_extract_exp (which returns the biased
// exponent) followed by a subtraction by MaxExponentField<T>() >> 1

// ------------------------------ Floating-point square root

#if HWY_S390X_HAVE_Z14
Expand Down
2 changes: 2 additions & 0 deletions hwy/ops/x86_512-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,8 @@ HWY_API Vec512<double> ApproximateReciprocal(Vec512<double> v) {
return Vec512<double>{_mm512_rcp14_pd(v.raw)};
}

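// TODO: Implement GetExponent using the getexp intrinsics
// (_mm512_getexp_ps/_mm512_getexp_pd/_mm512_getexp_ph for 512-bit vectors;
// the _mm_/_mm256_ variants for narrower vectors).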

// ------------------------------ MaskedMinOr

template <typename T, HWY_IF_U8(T)>
Expand Down
Loading
Loading