Skip to content

Commit

Permalink
Add ff/batch_inversion.hpp.
Browse files Browse the repository at this point in the history
  • Loading branch information
dot-asm committed Jul 24, 2024
1 parent 74643bc commit a176292
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions ff/batch_inversion.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// Copyright Supranational LLC
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

#ifndef __SPPARK_FF_BATCH_INVERSION_HPP__
#define __SPPARK_FF_BATCH_INVERSION_HPP__

/*
* Since batch inversion requires twice the storage (the input array
* plus an equally sized output array), on GPU there is an incentive to
* use shared memory. If deemed beneficial, the suggestion is to have
* the caller wrap T[] in S with a custom operator[] that addresses
* shared memory and offloads the input.
*/
/**
 * Invert N elements at the cost of a single division (Montgomery's
 * simultaneous-inversion trick); zero inputs produce zero outputs.
 *
 * The forward pass stores running products of the non-zero inputs in
 * |out| and records which inputs were zero in a bitmap. One division
 * inverts the total product, and the backward pass peels individual
 * inverses off it while re-reading the original inputs from |inp|.
 * Zero handling goes through T::csel()/czero() selections, not through
 * branches on the zero flags.
 *
 * @tparam T  Element type. Must provide is_zero(), static one(), static
 *            csel(a, b, flag), operator*=, operator* and |1/x|, plus a
 *            czero(x, flag) that is callable unqualified.
 *            NOTE(review): csel(a, b, flag) is presumed to yield |a|
 *            when |flag| is set and |b| otherwise, and czero(x, flag)
 *            to yield zero when |flag| is set and |x| otherwise; both
 *            are defined outside this file - confirm.
 * @tparam N  Number of elements, 1 <= N <= 32 (one flag bit per element
 *            is packed into an unsigned int).
 * @tparam S  Input container type; only inp[i] is used, and its result
 *            must be assignable to T.
 *
 * @param out        Receives the N results. With |preloaded| set it
 *                   must already hold copies of the inputs on entry.
 * @param inp        The N inputs. Elements 1..N-1 are always re-read in
 *                   the backward pass, after out[i] has been
 *                   overwritten, so |inp| must keep its original values
 *                   and must not alias |out|.
 * @param preloaded  When true, skip the initial copy from |inp| to
 *                   |out|.
 */
template<class T, size_t N, typename S = T[N]>
#ifdef __CUDACC__
__device__ __host__ __forceinline__
#endif
static void batch_inversion(T out[N], const S inp, bool preloaded = false)
{
    // N == 0 would index out[N-1] out of bounds and wrap the backward
    // loop counter, so reject it at compile time.
    static_assert(N > 0, "N must be at least 1");
    static_assert(N <= 32, "too large N");

    if (!preloaded)
        out[0] = inp[0];

    // Substitute one for a zero first element, so that it does not
    // wipe out the running product.
    bool zero = out[0].is_zero();
    out[0] = T::csel(T::one(), out[0], zero);
    unsigned int map = zero;

    // Forward pass: out[i] becomes the product of all non-zero inputs
    // with index <= i; a zero input simply repeats out[i-1].
    for (size_t i = 1; i < N; i++) {
        if (!preloaded)
            out[i] = inp[i];
        zero = out[i].is_zero();
        out[i] *= out[i-1];
        out[i] = T::csel(out[i-1], out[i], zero);
        // Newest flag goes into bit 0, so the backward pass can consume
        // the flags in reverse order by shifting right.
        map = (map << 1) + zero;
    }

    // The only division: inverse of the product of all non-zero inputs.
    T tmp, inv = 1/out[N-1];

    // Backward pass, i = N-1 down to 1. On entry to each iteration
    // |inv| is the inverse of the running product through index i.
    for (size_t i = N; --i; map >>= 1) {
        out[i] = inv*out[i-1];          // candidate inverse of input i
        tmp = inp[i];
        tmp *= inv;                     // strip input i from |inv|
        inv = T::csel(inv, tmp, map&1); // a zero input leaves |inv| as is
        out[i] = czero(out[i], map&1);  // a zero input yields zero
    }

    // Only the flag of element 0 is left in |map| at this point.
    out[0] = czero(inv, map);
}
#endif

0 comments on commit a176292

Please sign in to comment.