Skip to content

Commit c332cc9

Browse files
committed
ntt/{kernels.cu,kernels/*}: switch to shfl_bfly() method and clean up.
1 parent e5add3f commit c332cc9

File tree

5 files changed

+4
-29
lines changed

5 files changed

+4
-29
lines changed

ntt/kernels.cu

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,6 @@
77

88
#include <cooperative_groups.h>
99

10-
#ifdef __CUDA_ARCH__
11-
__device__ __forceinline__
12-
void shfl_bfly(fr_t& r, int laneMask)
13-
{
14-
#pragma unroll
15-
for (int iter = 0; iter < r.len(); iter++)
16-
r[iter] = __shfl_xor_sync(0xFFFFFFFF, r[iter], laneMask);
17-
}
18-
#endif
19-
20-
__device__ __forceinline__
21-
void shfl_bfly(index_t& index, int laneMask)
22-
{
23-
index = __shfl_xor_sync(0xFFFFFFFF, index, laneMask);
24-
}
25-
26-
template<typename T>
27-
__device__ __forceinline__
28-
void swap(T& u1, T& u2)
29-
{
30-
T temp = u1;
31-
u1 = u2;
32-
u2 = temp;
33-
}
34-
3510
template<typename T>
3611
__device__ __forceinline__
3712
T bit_rev(T i, unsigned int nbits)

ntt/kernels/ct_mixed_radix_narrow.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ void _CT_NTT(const unsigned int radix, const unsigned int lg_domain_size,
100100
for (int z = 0; z < z_count; z++) {
101101
fr_t t = fr_t::csel(r[1][z], r[0][z], pos);
102102

103-
shfl_bfly(t, laneMask);
103+
t.shfl_bfly(laneMask);
104104

105105
r[0][z] = fr_t::csel(t, r[0][z], !pos);
106106
r[1][z] = fr_t::csel(t, r[1][z], pos);

ntt/kernels/ct_mixed_radix_wide.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void _CT_NTT(const unsigned int radix, const unsigned int lg_domain_size,
8080

8181
#ifdef __CUDA_ARCH__
8282
fr_t x = fr_t::csel(r1, r0, pos);
83-
shfl_bfly(x, laneMask);
83+
x.shfl_bfly(laneMask);
8484
r0 = fr_t::csel(x, r0, !pos);
8585
r1 = fr_t::csel(x, r1, pos);
8686
#endif

ntt/kernels/gs_mixed_radix_narrow.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ void _GS_NTT(const unsigned int radix, const unsigned int lg_domain_size,
104104
#ifdef __CUDA_ARCH__
105105
t = fr_t::csel(r[1][z], r[0][z], pos);
106106

107-
shfl_bfly(t, laneMask);
107+
t.shfl_bfly(laneMask);
108108

109109
r[0][z] = fr_t::csel(t, r[0][z], !pos);
110110
r[1][z] = fr_t::csel(t, r[1][z], pos);

ntt/kernels/gs_mixed_radix_wide.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ void _GS_NTT(const unsigned int radix, const unsigned int lg_domain_size,
7070
bool pos = rank < laneMask;
7171
#ifdef __CUDA_ARCH__
7272
t = fr_t::csel(r1, r0, pos);
73-
shfl_bfly(t, laneMask);
73+
t.shfl_bfly(laneMask);
7474
r0 = fr_t::csel(t, r0, !pos);
7575
r1 = fr_t::csel(t, r1, pos);
7676
#endif

0 commit comments

Comments
 (0)