Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add templates for recip_to_real in the pw_basis #6023

Merged
merged 10 commits into from
Mar 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions source/module_basis/module_pw/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ list(APPEND objects
pw_distributer.cpp
pw_init.cpp
pw_transform.cpp
pw_transform_gpu.cpp
pw_transform_k.cpp
module_fft/fft_bundle.cpp
module_fft/fft_cpu.cpp
Expand Down
87 changes: 81 additions & 6 deletions source/module_basis/module_pw/kernels/cuda/pw_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,24 @@ __global__ void set_recip_to_real_output(
}
}

/// Kernel: write the real part of each complex input element into a real-valued output.
///
/// Each thread handles the single element at its flat global index; threads whose
/// index is not below nrxx do nothing.
/// @param nrxx   - number of elements to process
/// @param add    - true: accumulate `factor * in[i].real()` into `out[i]`;
///                 false: overwrite `out[i]` with `in[i].real()` (factor is not applied)
/// @param factor - scale applied only on the accumulate path
/// @param in     - complex input array
/// @param out    - real output array
template<class FPTYPE>
__global__ void set_recip_to_real_output(
    const int nrxx,
    const bool add,
    const FPTYPE factor,
    const thrust::complex<FPTYPE>* in,
    FPTYPE* out)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < nrxx)
    {
        const FPTYPE re = in[tid].real();
        if (add)
        {
            out[tid] += factor * re;
        }
        else
        {
            out[tid] = re;
        }
    }
}

template<class FPTYPE>
__global__ void set_real_to_recip_output(
const int npwk,
Expand All @@ -61,9 +79,28 @@ __global__ void set_real_to_recip_output(
}
}

/// Kernel: gather the real part of `in` through `box_index` into a real-valued output,
/// dividing by nxyz.
///
/// Each thread handles the single element at its flat global index; threads whose
/// index is not below npwk do nothing.
/// @param npwk      - number of output elements to process
/// @param nxyz      - divisor applied to every gathered value
/// @param add       - true: accumulate `factor / nxyz * in[box_index[i]].real()` into `out[i]`;
///                    false: overwrite `out[i]` with `in[box_index[i]].real() / nxyz`
/// @param factor    - scale applied only on the accumulate path
/// @param box_index - for each output element, the index of the element of `in` to read
/// @param in        - complex input array
/// @param out       - real output array
template<class FPTYPE>
__global__ void set_real_to_recip_output(
    const int npwk,
    const int nxyz,
    const bool add,
    const FPTYPE factor,
    const int* box_index,
    const thrust::complex<FPTYPE>* in,
    FPTYPE* out)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid < npwk)
    {
        const FPTYPE re = in[box_index[tid]].real();
        if (add)
        {
            out[tid] += factor / nxyz * re;
        }
        else
        {
            out[tid] = re / nxyz;
        }
    }
}

template <typename FPTYPE>
void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
const int npwk,
void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
const int* box_index,
const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out)
Expand All @@ -79,8 +116,7 @@ void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_d
}

template <typename FPTYPE>
void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
const int nrxx,
void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx,
const bool add,
const FPTYPE factor,
const std::complex<FPTYPE>* in,
Expand All @@ -98,8 +134,25 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
}

template <typename FPTYPE>
void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
const int npwk,
/// Host-side launcher for the real-output overload of the set_recip_to_real_output kernel.
/// @param nrxx   - number of elements to process
/// @param add    - forwarded to the kernel (accumulate vs. overwrite)
/// @param factor - forwarded to the kernel
/// @param in     - complex input, reinterpreted as thrust::complex for the kernel
/// @param out    - real output
void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx,
                                                                              const bool add,
                                                                              const FPTYPE factor,
                                                                              const std::complex<FPTYPE>* in,
                                                                              FPTYPE* out)
{
    // Round the block count up so that all nrxx elements get a thread.
    const int num_blocks = (nrxx + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
    const thrust::complex<FPTYPE>* in_thrust = reinterpret_cast<const thrust::complex<FPTYPE>*>(in);

    set_recip_to_real_output<FPTYPE><<<num_blocks, THREADS_PER_BLOCK>>>(nrxx, add, factor, in_thrust, out);

    cudaCheckOnDebug();
}

template <typename FPTYPE>
void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
const int nxyz,
const bool add,
const FPTYPE factor,
Expand All @@ -120,6 +173,28 @@ void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
cudaCheckOnDebug();
}

/// Host-side launcher for the real-output overload of the set_real_to_recip_output kernel.
/// @param npwk      - number of output elements to process
/// @param nxyz      - divisor forwarded to the kernel
/// @param add       - forwarded to the kernel (accumulate vs. overwrite)
/// @param factor    - forwarded to the kernel
/// @param box_index - gather indices forwarded to the kernel
/// @param in        - complex input, reinterpreted as thrust::complex for the kernel
/// @param out       - real output
template <typename FPTYPE>
void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
                                                                              const int nxyz,
                                                                              const bool add,
                                                                              const FPTYPE factor,
                                                                              const int* box_index,
                                                                              const std::complex<FPTYPE>* in,
                                                                              FPTYPE* out)
{
    // Round the block count up so that all npwk elements get a thread.
    const int num_blocks = (npwk + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
    const thrust::complex<FPTYPE>* in_thrust = reinterpret_cast<const thrust::complex<FPTYPE>*>(in);

    set_real_to_recip_output<FPTYPE><<<num_blocks, THREADS_PER_BLOCK>>>(npwk,
                                                                        nxyz,
                                                                        add,
                                                                        factor,
                                                                        box_index,
                                                                        in_thrust,
                                                                        out);

    cudaCheckOnDebug();
}

// Explicit instantiations of the GPU functor specializations for FPTYPE = float.
template struct set_3d_fft_box_op<float, base_device::DEVICE_GPU>;
template struct set_recip_to_real_output_op<float, base_device::DEVICE_GPU>;
template struct set_real_to_recip_output_op<float, base_device::DEVICE_GPU>;
Expand Down
31 changes: 25 additions & 6 deletions source/module_basis/module_pw/kernels/pw_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@ namespace ModulePW {
template <typename FPTYPE>
struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(const base_device::DEVICE_CPU* /*dev*/,
const int npwk,
void operator()(const int npwk,
const int* box_index,
const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out)
Expand All @@ -21,8 +20,7 @@ struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_CPU>
template <typename FPTYPE>
struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(const base_device::DEVICE_CPU* /*dev*/,
const int nrxx,
void operator()(const int nrxx,
const bool add,
const FPTYPE factor,
const std::complex<FPTYPE>* in,
Expand All @@ -39,13 +37,34 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU>
}
}
}

/// Real-output overload: store the real part of each complex input element.
/// @param nrxx   - number of elements to process
/// @param add    - true: accumulate `factor * in[ir].real()` into `out[ir]`;
///                 false: overwrite `out[ir]` with `in[ir].real()` (factor is not applied)
/// @param factor - scale applied only on the accumulate path
/// @param in     - complex input array
/// @param out    - real output array
void operator()(const int nrxx,
                const bool add,
                const FPTYPE factor,
                const std::complex<FPTYPE>* in,
                FPTYPE* out)
{
    if (!add)
    {
        // Plain copy of the real parts; factor is intentionally unused here.
        for (int i = 0; i < nrxx; ++i)
        {
            out[i] = in[i].real();
        }
        return;
    }

    for (int i = 0; i < nrxx; ++i)
    {
        out[i] += factor * in[i].real();
    }
}
};

template <typename FPTYPE>
struct set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>
{
void operator()(const base_device::DEVICE_CPU* /*dev*/,
const int npw_k,
void operator()(const int npw_k,
const int nxyz,
const bool add,
const FPTYPE factor,
Expand Down
41 changes: 32 additions & 9 deletions source/module_basis/module_pw/kernels/pw_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ struct set_3d_fft_box_op {
/// Output Parameters
/// @param out - output psi within the 3D box(in recip space)
void operator() (
const Device* dev,
const int npwk,
const int* box_index,
const std::complex<FPTYPE>* in,
Expand All @@ -39,12 +38,18 @@ struct set_recip_to_real_output_op {
/// Output Parameters
/// @param out - output psi within the 3D box(in real space)
void operator() (
const Device* dev,
const int nrxx,
const bool add,
const FPTYPE factor,
const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out);

void operator() (
const int nrxx,
const bool add,
const FPTYPE factor,
const std::complex<FPTYPE>* in,
FPTYPE* out);
};

template <typename FPTYPE, typename Device>
Expand All @@ -62,23 +67,30 @@ struct set_real_to_recip_output_op {
/// Output Parameters
/// @param out - output psi within the 3D box(in recip space)
void operator() (
const Device* dev,
const int npw_k,
const int nxyz,
const bool add,
const FPTYPE factor,
const int* box_index,
const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out);

void operator() (
const int npw_k,
const int nxyz,
const bool add,
const FPTYPE factor,
const int* box_index,
const std::complex<FPTYPE>* in,
FPTYPE* out);
};

#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
// Partially specialize functor for base_device::GpuDevice.
template <typename FPTYPE>
struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>
{
void operator()(const base_device::DEVICE_GPU* dev,
const int npwk,
void operator()(const int npwk,
const int* box_index,
const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out);
Expand All @@ -87,25 +99,36 @@ struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>
// GPU partial specialization of set_recip_to_real_output_op.
// Declares two operator() overloads that take the same inputs and differ only in the
// output element type: std::complex<FPTYPE>* or FPTYPE*.
// NOTE(review): the two lines carrying the `dev` parameter look like the pre-change
// signature retained by the diff view — confirm against the actual header.
template <typename FPTYPE>
struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>
{
void operator()(const base_device::DEVICE_GPU* dev,
const int nrxx,
void operator()(const int nrxx,
const bool add,
const FPTYPE factor,
const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out);

// Overload writing into a real-valued (FPTYPE) output array.
void operator()(const int nrxx,
const bool add,
const FPTYPE factor,
const std::complex<FPTYPE>* in,
FPTYPE* out);
};

// GPU partial specialization of set_real_to_recip_output_op.
// Declares two operator() overloads that take the same inputs and differ only in the
// output element type: std::complex<FPTYPE>* or FPTYPE*.
// NOTE(review): the two lines carrying the `dev` parameter look like the pre-change
// signature retained by the diff view — confirm against the actual header.
template <typename FPTYPE>
struct set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>
{
void operator()(const base_device::DEVICE_GPU* dev,
const int npw_k,
void operator()(const int npw_k,
const int nxyz,
const bool add,
const FPTYPE factor,
const int* box_index,
const std::complex<FPTYPE>* in,
std::complex<FPTYPE>* out);
// Overload writing into a real-valued (FPTYPE) output array.
void operator()(const int npw_k,
const int nxyz,
const bool add,
const FPTYPE factor,
const int* box_index,
const std::complex<FPTYPE>* in,
FPTYPE* out);
};

#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
Expand Down
Loading
Loading