From e5c62ff0b232be6e5935b5b73c6b5ea97575907d Mon Sep 17 00:00:00 2001 From: Andy Polyakov Date: Tue, 16 Jul 2024 10:41:42 +0200 Subject: [PATCH] ff/baby_bear.hpp: add dedicated bb31_4_t::batch_inversion. --- ff/baby_bear.hpp | 67 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 56 insertions(+), 11 deletions(-) diff --git a/ff/baby_bear.hpp b/ff/baby_bear.hpp index 9950380..3e866c0 100644 --- a/ff/baby_bear.hpp +++ b/ff/baby_bear.hpp @@ -484,27 +484,39 @@ class __align__(16) bb31_4_t { inline bb31_4_t& operator-=(bb31_t b) { c[0] -= b; return *this; } +private: + // don't bother with breaking these down, 1/x dominates. + inline bb31_t recip_b0(bb31_t beta) const + { return c[0]*c[0] - beta*(c[1]*bb31_t{u[3]<<1} - c[2]*c[2]); } + inline bb31_t recip_b2(bb31_t beta) const + { return c[0]*bb31_t{u[2]<<1} - c[1]*c[1] - beta*(c[3]*c[3]); } + inline bb31_4_t recip_ret(bb31_t b0, bb31_t b2, bb31_t beta) const + { + bb31_4_t ret; + bb31_t beta_b2 = beta*b2; + + ret[0] = c[0]*b0 - c[2]*beta_b2; + ret[1] = c[3]*beta_b2 - c[1]*b0; + ret[2] = c[2]*b0 - c[0]*b2; + ret[3] = c[1]*b2 - c[3]*b0; + + return ret; + } + +public: inline bb31_4_t reciprocal() const { const bb31_t beta{BETA}; - // don't bother with breaking this down, 1/x dominates. - bb31_t b0 = c[0]*c[0] - beta*(c[1]*bb31_t{u[3]<<1} - c[2]*c[2]); - bb31_t b2 = c[0]*bb31_t{u[2]<<1} - c[1]*c[1] - beta*(c[3]*c[3]); + bb31_t b0 = recip_b0(beta); + bb31_t b2 = recip_b2(beta); bb31_t inv = 1/(b0*b0 - beta*b2*b2); b0 *= inv; b2 *= inv; - bb31_4_t ret; - bb31_t beta_b2 = beta*b2; - ret[0] = c[0]*b0 - c[2]*beta_b2; - ret[1] = c[3]*beta_b2 - c[1]*b0; - ret[2] = c[2]*b0 - c[0]*b2; - ret[3] = c[1]*b2 - c[3]*b0; - - return ret; + return recip_ret(b0, b2, beta); } friend inline bb31_4_t operator/(int one, const bb31_4_t& a) { assert(one == 1); return a.reciprocal(); } @@ -519,6 +531,39 @@ class __align__(16) bb31_4_t { inline bb31_4_t& operator/=(bb31_t a) { return *this *= a.reciprocal(); } +# ifdef __SPPARK_FF_BATCH_INVERSION_HPP__ + template + friend inline void batch_inversion(bb31_4_t out[N], const S inp) + { + const bb31_t beta{BETA}; + bb31_t b0[N], b2[N], bx[N]; + + for (size_t i = 0; i < N; i++) { + bb31_4_t tmp = inp[i]; + b0[i] = tmp.recip_b0(beta); + b2[i] = tmp.recip_b2(beta); + bx[i] = b0[i]*b0[i] - beta*b2[i]*b2[i]; + } + + bb31_t inv[N]; + + batch_inversion(inv, bx); + + for (size_t i = N; i--;) { + b0[i] *= inv[i]; + b2[i] *= inv[i]; + bb31_4_t tmp = inp[i]; + out[i] = tmp.recip_ret(b0[i], b2[i], beta); + } + } + + // Unlike the generic batch_inversion bb31_4_t procedure + // can perform the inversion in-place. + template + friend inline void batch_inversion(bb31_4_t inout[N]) + { batch_inversion(inout, inout); } +# endif + inline bool is_one() const { return c[0].is_one() & u[1]==0 & u[2]==0 & u[3]==0; } inline bool is_zero() const