Skip to content

Commit

Permalink
ff/baby_bear.hpp: add dedicated bb31_4_t::batch_inversion.
Browse files Browse the repository at this point in the history
  • Loading branch information
dot-asm committed Jul 24, 2024
1 parent a176292 commit e5c62ff
Showing 1 changed file with 56 additions and 11 deletions.
67 changes: 56 additions & 11 deletions ff/baby_bear.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,27 +484,39 @@ class __align__(16) bb31_4_t {
inline bb31_4_t& operator-=(bb31_t b)
{ c[0] -= b; return *this; }

private:
// don't bother with breaking these down, 1/x dominates.
inline bb31_t recip_b0(bb31_t beta) const
{ return c[0]*c[0] - beta*(c[1]*bb31_t{u[3]<<1} - c[2]*c[2]); }
inline bb31_t recip_b2(bb31_t beta) const
{ return c[0]*bb31_t{u[2]<<1} - c[1]*c[1] - beta*(c[3]*c[3]); }
inline bb31_4_t recip_ret(bb31_t b0, bb31_t b2, bb31_t beta) const
{
bb31_4_t ret;
bb31_t beta_b2 = beta*b2;

ret[0] = c[0]*b0 - c[2]*beta_b2;
ret[1] = c[3]*beta_b2 - c[1]*b0;
ret[2] = c[2]*b0 - c[0]*b2;
ret[3] = c[1]*b2 - c[3]*b0;

return ret;
}

public:
inline bb31_4_t reciprocal() const
{
const bb31_t beta{BETA};

// don't bother with breaking this down, 1/x dominates.
bb31_t b0 = c[0]*c[0] - beta*(c[1]*bb31_t{u[3]<<1} - c[2]*c[2]);
bb31_t b2 = c[0]*bb31_t{u[2]<<1} - c[1]*c[1] - beta*(c[3]*c[3]);
bb31_t b0 = recip_b0(beta);
bb31_t b2 = recip_b2(beta);

bb31_t inv = 1/(b0*b0 - beta*b2*b2);

b0 *= inv;
b2 *= inv;

bb31_4_t ret;
bb31_t beta_b2 = beta*b2;
ret[0] = c[0]*b0 - c[2]*beta_b2;
ret[1] = c[3]*beta_b2 - c[1]*b0;
ret[2] = c[2]*b0 - c[0]*b2;
ret[3] = c[1]*b2 - c[3]*b0;

return ret;
return recip_ret(b0, b2, beta);
}
friend inline bb31_4_t operator/(int one, const bb31_4_t& a)
{ assert(one == 1); return a.reciprocal(); }
Expand All @@ -519,6 +531,39 @@ class __align__(16) bb31_4_t {
inline bb31_4_t& operator/=(bb31_t a)
{ return *this *= a.reciprocal(); }

# ifdef __SPPARK_FF_BATCH_INVERSION_HPP__
template<size_t N, typename S = bb31_4_t[N]>
friend inline void batch_inversion(bb31_4_t out[N], const S inp)
{
const bb31_t beta{BETA};
bb31_t b0[N], b2[N], bx[N];

for (size_t i = 0; i < N; i++) {
bb31_4_t tmp = inp[i];
b0[i] = tmp.recip_b0(beta);
b2[i] = tmp.recip_b2(beta);
bx[i] = b0[i]*b0[i] - beta*b2[i]*b2[i];
}

bb31_t inv[N];

batch_inversion<bb31_t, N>(inv, bx);

for (size_t i = N; i--;) {
b0[i] *= inv[i];
b2[i] *= inv[i];
bb31_4_t tmp = inp[i];
out[i] = tmp.recip_ret(b0[i], b2[i], beta);
}
}

// Unlike the generic batch_inversion<T, N> bb31_4_t procedure
// can perform the inversion in-place.
template<size_t N>
friend inline void batch_inversion(bb31_4_t inout[N])
{ batch_inversion<N>(inout, inout); }
# endif

inline bool is_one() const
{ return c[0].is_one() & u[1]==0 & u[2]==0 & u[3]==0; }
inline bool is_zero() const
Expand Down

0 comments on commit e5c62ff

Please sign in to comment.