From bdbb70b0b9ef15f0a4f38a074c078d0cd2e9c102 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Wed, 19 Mar 2025 09:40:14 +0800 Subject: [PATCH 1/8] update recip_to_real in the rhpw --- source/module_basis/module_pw/pw_basis.h | 36 +- .../module_basis/module_pw/pw_transform.cpp | 361 ++++++++++++++---- source/module_elecstate/elecstate_pw.cpp | 3 +- 3 files changed, 329 insertions(+), 71 deletions(-) diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index 47827000eb..4f1ad9a1ee 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -267,6 +267,35 @@ class PW_Basis const bool add = false, const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + template <typename FPTYPE, typename Device> + void real_to_recip(const std::complex<FPTYPE>* in, + std::complex<FPTYPE>* out, + const Device* ctx=get_default_device_ctx(), + const int ik=0, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) + template <typename FPTYPE, typename Device> + void recip_to_real(const std::complex<FPTYPE>* in, + std::complex<FPTYPE>* out, + const Device* ctx=get_default_device_ctx(), + const int ik=0, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + + template <typename FPTYPE, typename Device = base_device::DEVICE_CPU> + void real_to_recip(FPTYPE* in, + std::complex<FPTYPE>* out, + const Device* ctx=get_default_device_ctx(), + const int ik=0, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) + template <typename FPTYPE, typename Device = base_device::DEVICE_CPU> + void recip_to_real(const std::complex<FPTYPE>* in, + FPTYPE* out, + const Device* ctx=get_default_device_ctx() , + const int ik=0, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) protected: //gather planes and scatter sticks of 
all processors template <typename T> @@ -282,15 +311,16 @@ class PW_Basis using resmem_int_op = base_device::memory::resize_memory_op<int, base_device::DEVICE_GPU>; using delmem_int_op = base_device::memory::delete_memory_op<int, base_device::DEVICE_GPU>; - using syncmem_int_h2d_op - = base_device::memory::synchronize_memory_op<int, base_device::DEVICE_GPU, base_device::DEVICE_CPU>; - + using syncmem_int_h2d_op = base_device::memory::synchronize_memory_op<int, base_device::DEVICE_GPU, base_device::DEVICE_CPU>; + // using default_device_cpu = base_device::DEVICE_CPU; + void set_device(std::string device_); void set_precision(std::string precision_); protected: std::string device = "cpu"; std::string precision = "double"; + static const base_device::DEVICE_CPU* get_default_device_ctx(); }; } diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp index 8e458b2561..38d4bfc49d 100644 --- a/source/module_basis/module_pw/pw_transform.cpp +++ b/source/module_basis/module_pw/pw_transform.cpp @@ -1,12 +1,19 @@ -#include "module_fft/fft_bundle.h" -#include <complex> -#include "pw_basis.h" -#include <cassert> #include "module_base/global_function.h" #include "module_base/timer.h" +#include "module_basis/module_pw/kernels/pw_op.h" +#include "module_fft/fft_bundle.h" +#include "pw_basis.h" #include "pw_gatherscatter.h" -namespace ModulePW { +#include <cassert> +#include <complex> + +namespace ModulePW +{ + const base_device::DEVICE_CPU* PW_Basis::get_default_device_ctx() { + static const base_device::DEVICE_CPU* default_device_cpu; + return default_device_cpu; +} /** * @brief transform real space to reciprocal space * @details c(g)=\int dr*f(r)*exp(-ig*r) @@ -24,25 +31,25 @@ void PW_Basis::real2recip(const std::complex<FPTYPE>* in, assert(this->gamma_only == false); #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int 
ir = 0 ; ir < this->nrxx ; ++ir) + for (int ir = 0; ir < this->nrxx; ++ir) { this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = in[ir]; } - this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>()); + this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>()); this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>()); - - this->fft_bundle.fftzfor(fft_bundle.get_auxg_data<FPTYPE>(),fft_bundle.get_auxg_data<FPTYPE>()); - if(add) + this->fft_bundle.fftzfor(fft_bundle.get_auxg_data<FPTYPE>(), fft_bundle.get_auxg_data<FPTYPE>()); + + if (add) { FPTYPE tmpfac = factor / FPTYPE(this->nxyz); #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ig = 0 ; ig < this->npw ; ++ig) + for (int ig = 0; ig < this->npw; ++ig) { out[ig] += tmpfac * this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]]; } @@ -51,9 +58,9 @@ void PW_Basis::real2recip(const std::complex<FPTYPE>* in, { FPTYPE tmpfac = 1.0 / FPTYPE(this->nxyz); #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ig = 0 ; ig < this->npw ; ++ig) + for (int ig = 0; ig < this->npw; ++ig) { out[ig] = tmpfac * this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]]; } @@ -72,44 +79,44 @@ template <typename FPTYPE> void PW_Basis::real2recip(const FPTYPE* in, std::complex<FPTYPE>* out, const bool add, const FPTYPE factor) const { ModuleBase::timer::tick(this->classname, "real2recip"); - if(this->gamma_only) + if (this->gamma_only) { const int npy = this->ny * this->nplane; #ifdef _OPENMP -#pragma omp parallel for collapse(2) schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ix = 0 ; ix < this->nx ; ++ix) + 
for (int ix = 0; ix < this->nx; ++ix) { - for(int ipy = 0 ; ipy < npy ; ++ipy) + for (int ipy = 0; ipy < npy; ++ipy) { - this->fft_bundle.get_rspace_data<FPTYPE>()[ix*npy + ipy] = in[ix*npy + ipy]; + this->fft_bundle.get_rspace_data<FPTYPE>()[ix * npy + ipy] = in[ix * npy + ipy]; } } - this->fft_bundle.fftxyr2c(fft_bundle.get_rspace_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>()); + this->fft_bundle.fftxyr2c(fft_bundle.get_rspace_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>()); } else { #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ir = 0 ; ir < this->nrxx ; ++ir) + for (int ir = 0; ir < this->nrxx; ++ir) { - this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = std::complex<FPTYPE>(in[ir],0); + this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = std::complex<FPTYPE>(in[ir], 0); } - this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>()); + this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>()); } this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>()); - - this->fft_bundle.fftzfor(fft_bundle.get_auxg_data<FPTYPE>(),fft_bundle.get_auxg_data<FPTYPE>()); - if(add) + this->fft_bundle.fftzfor(fft_bundle.get_auxg_data<FPTYPE>(), fft_bundle.get_auxg_data<FPTYPE>()); + + if (add) { FPTYPE tmpfac = factor / FPTYPE(this->nxyz); #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ig = 0 ; ig < this->npw ; ++ig) + for (int ig = 0; ig < this->npw; ++ig) { out[ig] += tmpfac * this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]]; } @@ -118,9 +125,9 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex<FPTYPE>* out, const boo { FPTYPE tmpfac = 1.0 / FPTYPE(this->nxyz); #ifdef _OPENMP -#pragma omp parallel for 
schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ig = 0 ; ig < this->npw ; ++ig) + for (int ig = 0; ig < this->npw; ++ig) { out[ig] = tmpfac * this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]]; } @@ -144,32 +151,32 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, ModuleBase::timer::tick(this->classname, "recip2real"); assert(this->gamma_only == false); #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int i = 0 ; i < this->nst * this->nz ; ++i) + for (int i = 0; i < this->nst * this->nz; ++i) { fft_bundle.get_auxg_data<FPTYPE>()[i] = std::complex<FPTYPE>(0, 0); } #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ig = 0 ; ig < this->npw ; ++ig) + for (int ig = 0; ig < this->npw; ++ig) { this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]] = in[ig]; } this->fft_bundle.fftzbac(fft_bundle.get_auxg_data<FPTYPE>(), fft_bundle.get_auxg_data<FPTYPE>()); - this->gathers_scatterp(this->fft_bundle.get_auxg_data<FPTYPE>(),this->fft_bundle.get_auxr_data<FPTYPE>()); + this->gathers_scatterp(this->fft_bundle.get_auxg_data<FPTYPE>(), this->fft_bundle.get_auxr_data<FPTYPE>()); + + this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>()); - this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>()); - - if(add) + if (add) { #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ir = 0 ; ir < this->nrxx ; ++ir) + for (int ir = 0; ir < this->nrxx; ++ir) { out[ir] += factor * this->fft_bundle.get_auxr_data<FPTYPE>()[ir]; } @@ -177,9 +184,9 @@ void PW_Basis::recip2real(const 
std::complex<FPTYPE>* in, else { #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ir = 0 ; ir < this->nrxx ; ++ir) + for (int ir = 0; ir < this->nrxx; ++ir) { out[ir] = this->fft_bundle.get_auxr_data<FPTYPE>()[ir]; } @@ -199,17 +206,17 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo { ModuleBase::timer::tick(this->classname, "recip2real"); #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int i = 0 ; i < this->nst * this->nz ; ++i) + for (int i = 0; i < this->nst * this->nz; ++i) { fft_bundle.get_auxg_data<FPTYPE>()[i] = std::complex<double>(0, 0); } #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ig = 0 ; ig < this->npw ; ++ig) + for (int ig = 0; ig < this->npw; ++ig) { this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]] = in[ig]; } @@ -217,49 +224,49 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo this->gathers_scatterp(this->fft_bundle.get_auxg_data<FPTYPE>(), this->fft_bundle.get_auxr_data<FPTYPE>()); - if(this->gamma_only) + if (this->gamma_only) { - this->fft_bundle.fftxyc2r(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_rspace_data<FPTYPE>()); + this->fft_bundle.fftxyc2r(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_rspace_data<FPTYPE>()); // r2c in place const int npy = this->ny * this->nplane; - if(add) + if (add) { #ifdef _OPENMP -#pragma omp parallel for collapse(2) schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ix = 0 ; ix < this->nx ; ++ix) + for (int ix = 0; ix < this->nx; ++ix) { - for(int ipy = 0 ; ipy < npy ; ++ipy) + for (int ipy = 0; ipy < npy; 
++ipy) { - out[ix*npy + ipy] += factor * this->fft_bundle.get_rspace_data<FPTYPE>()[ix*npy + ipy]; + out[ix * npy + ipy] += factor * this->fft_bundle.get_rspace_data<FPTYPE>()[ix * npy + ipy]; } } } else { #ifdef _OPENMP -#pragma omp parallel for collapse(2) schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ix = 0 ; ix < this->nx ; ++ix) + for (int ix = 0; ix < this->nx; ++ix) { - for(int ipy = 0 ; ipy < npy ; ++ipy) + for (int ipy = 0; ipy < npy; ++ipy) { - out[ix*npy + ipy] = this->fft_bundle.get_rspace_data<FPTYPE>()[ix*npy + ipy]; + out[ix * npy + ipy] = this->fft_bundle.get_rspace_data<FPTYPE>()[ix * npy + ipy]; } } } } else { - this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>()); - if(add) + this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>()); + if (add) { #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ir = 0 ; ir < this->nrxx ; ++ir) + for (int ir = 0; ir < this->nrxx; ++ir) { out[ir] += factor * this->fft_bundle.get_auxr_data<FPTYPE>()[ir].real(); } @@ -267,9 +274,9 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo else { #ifdef _OPENMP -#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE)) +#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE)) #endif - for(int ir = 0 ; ir < this->nrxx ; ++ir) + for (int ir = 0; ir < this->nrxx; ++ir) { out[ir] = this->fft_bundle.get_auxr_data<FPTYPE>()[ir].real(); } @@ -278,6 +285,228 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo ModuleBase::timer::tick(this->classname, "recip2real"); } +template <> +void PW_Basis::real_to_recip(const std::complex<float>* in, + std::complex<float>* out, + const base_device::DEVICE_CPU* ctx, + const int ik, + const 
bool add, + const float factor) const +{ + this->real2recip(in, out, add, factor); +} +template <> +void PW_Basis::real_to_recip(const std::complex<double>* in, + std::complex<double>* out, + const base_device::DEVICE_CPU* ctx, + const int ik, + const bool add, + const double factor) const +{ + this->real2recip(in, out, add, factor); +} + +template <> +void PW_Basis::recip_to_real(const std::complex<float>* in, + std::complex<float>* out, + const base_device::DEVICE_CPU* ctx, + const int ik, + const bool add, + const float factor) const +{ + this->recip2real(in, out, add, factor); +} +template <> +void PW_Basis::recip_to_real(const std::complex<double>* in, + std::complex<double>* out, + const base_device::DEVICE_CPU* ctx, + const int ik, + const bool add, + const double factor) const +{ + this->recip2real(in, out, add, factor); +} + +template <> +void PW_Basis::real_to_recip(float* in, + std::complex<float>* out, + const base_device::DEVICE_CPU* ctx, + const int ik, + const bool add, + const float factor) const +{ + this->real2recip(in, out, add, factor); +} +template <> +void PW_Basis::real_to_recip(double* in, + std::complex<double>* out, + const base_device::DEVICE_CPU* ctx, + const int ik, + const bool add, + const double factor) const +{ + this->real2recip(in, out, add, factor); +} + +template <> +void PW_Basis::recip_to_real(const std::complex<float>* in, + float* out, + const base_device::DEVICE_CPU* ctx, + const int ik, + const bool add, + const float factor) const +{ + this->recip2real(in, out, add, factor); +} + +template <> +void PW_Basis::recip_to_real(const std::complex<double>* in, + double* out, + const base_device::DEVICE_CPU* ctx, + const int ik, + const bool add, + const double factor) const +{ + this->recip2real(in, out, add, factor); +} + +#if (defined(__CUDA) || defined(__ROCM)) +template <> +void PW_Basis::real_to_recip(const std::complex<float>* in, + std::complex<float>* out, + const base_device::DEVICE_GPU* ctx, + const int ik, + const 
bool add, + const float factor) const +{ + ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + + base_device::memory::synchronize_memory_op<std::complex<float>, base_device::DEVICE_GPU, base_device::DEVICE_GPU>()( + this->fft_bundle.get_auxr_3d_data<float>(), + in, + this->nrxx); + + this->fft_bundle.fft3D_forward(ctx, + this->fft_bundle.get_auxr_3d_data<float>(), + this->fft_bundle.get_auxr_3d_data<float>()); + + set_real_to_recip_output_op<float, base_device::DEVICE_GPU>()(ctx, + npw, + this->nxyz, + add, + factor, + this->ig2isz, + this->fft_bundle.get_auxr_3d_data<float>(), + out); + ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); +} +template <> +void PW_Basis::real_to_recip(const std::complex<double>* in, + std::complex<double>* out, + const base_device::DEVICE_GPU* ctx, + const int ik, + const bool add, + const double factor) const +{ + ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + + base_device::memory::synchronize_memory_op<std::complex<double>, + base_device::DEVICE_GPU, + base_device::DEVICE_GPU>()(this->fft_bundle.get_auxr_3d_data<double>(), + in, + this->nrxx); + + this->fft_bundle.fft3D_forward(ctx, + this->fft_bundle.get_auxr_3d_data<double>(), + this->fft_bundle.get_auxr_3d_data<double>()); + + set_real_to_recip_output_op<double, base_device::DEVICE_GPU>()(ctx, + npw, + this->nxyz, + add, + factor, + this->ig2isz, + this->fft_bundle.get_auxr_3d_data<double>(), + out); + ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); +} + +template <> +void PW_Basis::recip_to_real(const std::complex<float>* in, + std::complex<float>* out, + const base_device::DEVICE_GPU* ctx, + const int ik, + const bool add, + const float factor) const +{ + ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 
1); + // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<float>(), this->nxyz); + base_device::memory::set_memory_op<std::complex<float>, base_device::DEVICE_GPU>()( + this->fft_bundle.get_auxr_3d_data<float>(), + 0, + this->nxyz); + + set_3d_fft_box_op<float, base_device::DEVICE_GPU>()(ctx, + npw, + this->ig2isz, + in, + this->fft_bundle.get_auxr_3d_data<float>()); + this->fft_bundle.fft3D_backward(ctx, + this->fft_bundle.get_auxr_3d_data<float>(), + this->fft_bundle.get_auxr_3d_data<float>()); + + set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(ctx, + this->nrxx, + add, + factor, + this->fft_bundle.get_auxr_3d_data<float>(), + out); + + ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); +} +template <> +void PW_Basis::recip_to_real(const std::complex<double>* in, + std::complex<double>* out, + const base_device::DEVICE_GPU* ctx, + const int ik, + const bool add, + const double factor) const +{ + ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<double>(), this->nxyz); + base_device::memory::set_memory_op<std::complex<double>, base_device::DEVICE_GPU>()( + this->fft_bundle.get_auxr_3d_data<double>(), + 0, + this->nxyz); + + set_3d_fft_box_op<double, base_device::DEVICE_GPU>()(ctx, + npw, + this->ig2isz, + in, + this->fft_bundle.get_auxr_3d_data<double>()); + this->fft_bundle.fft3D_backward(ctx, + this->fft_bundle.get_auxr_3d_data<double>(), + this->fft_bundle.get_auxr_3d_data<double>()); + + set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(ctx, + this->nrxx, + add, + factor, + this->fft_bundle.get_auxr_3d_data<double>(), + out); + + ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); +} +#endif + template void PW_Basis::real2recip<float>(const float* in, std::complex<float>* out, const bool add, @@ -311,4 +540,4 @@ template void 
PW_Basis::recip2real<double>(const std::complex<double>* in, std::complex<double>* out, const bool add, const double factor) const; -} \ No newline at end of file +} // namespace ModulePW \ No newline at end of file diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index 739e497515..e09a1a3dfa 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -443,13 +443,12 @@ void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi) { this->addusdens_g(becsum, rhog); } - // transform back to real space using dense grids if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp) { for (int is = 0; is < PARAM.inp.nspin; is++) { - this->charge->rhopw->recip2real(this->rhog[is], this->rho[is]); + this->charge->rhopw->recip_to_real(this->rhog[is], this->rho[is]); } } } From 926f608371e5bb6b9ae93a577e97300440261e15 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Fri, 21 Mar 2025 16:07:36 +0800 Subject: [PATCH 2/8] add the operator in pw_basis --- source/module_basis/module_pw/CMakeLists.txt | 1 + .../module_pw/kernels/cuda/pw_op.cu | 80 ++++++ .../module_basis/module_pw/kernels/pw_op.cpp | 23 ++ source/module_basis/module_pw/kernels/pw_op.h | 33 +++ .../module_pw/kernels/rocm/pw_op.hip.cu | 80 ++++++ source/module_basis/module_pw/pw_basis.h | 168 ++++++++++--- .../module_basis/module_pw/pw_transform.cpp | 231 +----------------- .../module_pw/pw_transform_gpu.cpp | 173 +++++++++++++ source/module_elecstate/elecstate_pw.cpp | 2 +- 9 files changed, 535 insertions(+), 256 deletions(-) create mode 100644 source/module_basis/module_pw/pw_transform_gpu.cpp diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt index e365e12b5e..50b7ec3430 100644 --- a/source/module_basis/module_pw/CMakeLists.txt +++ b/source/module_basis/module_pw/CMakeLists.txt @@ -30,6 +30,7 @@ list(APPEND objects pw_distributer.cpp pw_init.cpp 
pw_transform.cpp + pw_transform_gpu.cpp pw_transform_k.cpp module_fft/fft_bundle.cpp module_fft/fft_cpu.cpp diff --git a/source/module_basis/module_pw/kernels/cuda/pw_op.cu b/source/module_basis/module_pw/kernels/cuda/pw_op.cu index a9128db318..d9418ca486 100644 --- a/source/module_basis/module_pw/kernels/cuda/pw_op.cu +++ b/source/module_basis/module_pw/kernels/cuda/pw_op.cu @@ -41,6 +41,24 @@ __global__ void set_recip_to_real_output( } } +template<class FPTYPE> +__global__ void set_recip_to_real_output( + const int nrxx, + const bool add, + const FPTYPE factor, + const thrust::complex<FPTYPE>* in, + FPTYPE* out) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx >= nrxx) {return;} + if(add) { + out[idx] += factor * in[idx].real(); + } + else { + out[idx] = in[idx].real(); + } +} + template<class FPTYPE> __global__ void set_real_to_recip_output( const int npwk, @@ -61,6 +79,26 @@ __global__ void set_real_to_recip_output( } } +template<class FPTYPE> +__global__ void set_real_to_recip_output( + const int npwk, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const thrust::complex<FPTYPE>* in, + FPTYPE* out) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx >= npwk) {return;} + if(add) { + out[idx] += factor / nxyz * in[box_index[idx]].real(); + } + else { + out[idx] = in[box_index[idx]].real() / nxyz; + } +} + template <typename FPTYPE> void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, const int npwk, @@ -97,6 +135,25 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co cudaCheckOnDebug(); } +template <typename FPTYPE> +void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, + const int nrxx, + const bool add, + const FPTYPE factor, + const std::complex<FPTYPE>* in, + FPTYPE* out) +{ + const int block = (nrxx + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + 
set_recip_to_real_output<FPTYPE><<<block, THREADS_PER_BLOCK>>>( + nrxx, + add, + factor, + reinterpret_cast<const thrust::complex<FPTYPE>*>(in), + reinterpret_cast<FPTYPE*>(out)); + + cudaCheckOnDebug(); +} + template <typename FPTYPE> void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, const int npwk, @@ -120,6 +177,29 @@ void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co cudaCheckOnDebug(); } +template <typename FPTYPE> +void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, + const int npwk, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const std::complex<FPTYPE>* in, + FPTYPE* out) +{ + const int block = (npwk + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + set_real_to_recip_output<FPTYPE><<<block, THREADS_PER_BLOCK>>>( + npwk, + nxyz, + add, + factor, + box_index, + reinterpret_cast<const thrust::complex<FPTYPE>*>(in), + reinterpret_cast<FPTYPE*>(out)); + + cudaCheckOnDebug(); +} + template struct set_3d_fft_box_op<float, base_device::DEVICE_GPU>; template struct set_recip_to_real_output_op<float, base_device::DEVICE_GPU>; template struct set_real_to_recip_output_op<float, base_device::DEVICE_GPU>; diff --git a/source/module_basis/module_pw/kernels/pw_op.cpp b/source/module_basis/module_pw/kernels/pw_op.cpp index b5fb453354..e39e8a6a7c 100644 --- a/source/module_basis/module_pw/kernels/pw_op.cpp +++ b/source/module_basis/module_pw/kernels/pw_op.cpp @@ -39,6 +39,29 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU> } } } + + void operator()(const base_device::DEVICE_CPU* /*dev*/, + const int nrxx, + const bool add, + const FPTYPE factor, + const std::complex<FPTYPE>* in, + FPTYPE* out) + { + if (add) + { + for (int ir = 0; ir < nrxx; ++ir) + { + out[ir] += factor * in[ir].real(); + } + } + else + { + for (int ir = 0; ir < nrxx; ++ir) + { + out[ir] 
= in[ir].real(); + } + } + } }; template <typename FPTYPE> diff --git a/source/module_basis/module_pw/kernels/pw_op.h b/source/module_basis/module_pw/kernels/pw_op.h index 8415ad9677..b547190606 100644 --- a/source/module_basis/module_pw/kernels/pw_op.h +++ b/source/module_basis/module_pw/kernels/pw_op.h @@ -45,6 +45,14 @@ struct set_recip_to_real_output_op { const FPTYPE factor, const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out); + + void operator() ( + const Device* dev, + const int nrxx, + const bool add, + const FPTYPE factor, + const std::complex<FPTYPE>* in, + FPTYPE* out); }; template <typename FPTYPE, typename Device> @@ -70,6 +78,16 @@ struct set_real_to_recip_output_op { const int* box_index, const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out); + + void operator() ( + const Device* dev, + const int npw_k, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const std::complex<FPTYPE>* in, + FPTYPE* out); }; #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM @@ -93,6 +111,13 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU> const FPTYPE factor, const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out); + + void operator()(const base_device::DEVICE_GPU* dev, + const int nrxx, + const bool add, + const FPTYPE factor, + const std::complex<FPTYPE>* in, + FPTYPE* out); }; template <typename FPTYPE> @@ -106,6 +131,14 @@ struct set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU> const int* box_index, const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out); + void operator()(const base_device::DEVICE_GPU* dev, + const int npw_k, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const std::complex<FPTYPE>* in, + FPTYPE* out); }; #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM diff --git a/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu b/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu index a3f5fe2c2b..44be7d0671 
100644 --- a/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu +++ b/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu @@ -42,6 +42,24 @@ __global__ void set_recip_to_real_output( } } +template<class FPTYPE> +__global__ void set_recip_to_real_output( + const int nrxx, + const bool add, + const FPTYPE factor, + const thrust::complex<FPTYPE>* in, + FPTYPE* out) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx >= nrxx) {return;} + if(add) { + out[idx] += factor * in[idx].real(); + } + else { + out[idx] = in[idx].real(); + } +} + template<class FPTYPE> __global__ void set_real_to_recip_output( const int npwk, @@ -62,6 +80,26 @@ __global__ void set_real_to_recip_output( } } +template<class FPTYPE> +__global__ void set_real_to_recip_output( + const int npwk, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const thrust::complex<FPTYPE>* in, + FPTYPE* out) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if(idx >= npwk) {return;} + if(add) { + out[idx] += factor / nxyz * in[box_index[idx]].real(); + } + else { + out[idx] = in[box_index[idx]].real() / nxyz; + } +} + template <typename FPTYPE> void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, const int npwk, @@ -98,6 +136,25 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co hipCheckOnDebug(); } +template <typename FPTYPE> +void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, + const int nrxx, + const bool add, + const FPTYPE factor, + const std::complex<FPTYPE>* in, + FPTYPE* out) +{ + const int block = (nrxx + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + hipLaunchKernelGGL(HIP_KERNEL_NAME(set_recip_to_real_output<FPTYPE>), dim3(block), dim3(THREADS_PER_BLOCK), 0, 0, + nrxx, + add, + factor, + reinterpret_cast<const thrust::complex<FPTYPE>*>(in), + reinterpret_cast<FPTYPE*>(out)); + + 
hipCheckOnDebug(); +} + template <typename FPTYPE> void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, const int npwk, @@ -121,6 +178,29 @@ void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co hipCheckOnDebug(); } +template <typename FPTYPE> +void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, + const int npwk, + const int nxyz, + const bool add, + const FPTYPE factor, + const int* box_index, + const std::complex<FPTYPE>* in, + FPTYPE* out) +{ + const int block = (npwk + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + hipLaunchKernelGGL(HIP_KERNEL_NAME(set_real_to_recip_output<FPTYPE>), dim3(block), dim3(THREADS_PER_BLOCK), 0, 0, + npwk, + nxyz, + add, + factor, + box_index, + reinterpret_cast<const thrust::complex<FPTYPE>*>(in), + reinterpret_cast<FPTYPE*>(out)); + + hipCheckOnDebug(); +} + template struct set_3d_fft_box_op<float, base_device::DEVICE_GPU>; template struct set_recip_to_real_output_op<float, base_device::DEVICE_GPU>; template struct set_real_to_recip_output_op<float, base_device::DEVICE_GPU>; diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index 4f1ad9a1ee..2baf937af8 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -1,6 +1,7 @@ #ifndef PWBASIS_H #define PWBASIS_H +#include "module_base/macros.h" #include "module_base/module_device/memory_op.h" #include "module_base/matrix.h" #include "module_base/matrix3.h" @@ -245,7 +246,7 @@ class PW_Basis // FFT ft; FFT_Bundle fft_bundle; //The position of pointer in and out can be equal(in-place transform) or different(out-of-place transform). 
- + template <typename FPTYPE> void real2recip(const FPTYPE* in, std::complex<FPTYPE>* out, @@ -266,36 +267,148 @@ class PW_Basis std::complex<FPTYPE>* out, const bool add = false, const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + + template <typename FPTYPE> + void real2recip_gpu(const FPTYPE* in, + std::complex<FPTYPE>* out, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) + template <typename FPTYPE> + void real2recip_gpu(const std::complex<FPTYPE>* in, + std::complex<FPTYPE>* out, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) + template <typename FPTYPE> + void recip2real_gpu(const std::complex<FPTYPE>* in, + FPTYPE* out, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + template <typename FPTYPE> + void recip2real_gpu(const std::complex<FPTYPE>* in, + std::complex<FPTYPE>* out, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) - template <typename FPTYPE, typename Device> - void real_to_recip(const std::complex<FPTYPE>* in, - std::complex<FPTYPE>* out, - const Device* ctx=get_default_device_ctx(), - const int ik=0, - const bool add = false, - const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) - template <typename FPTYPE, typename Device> - void recip_to_real(const std::complex<FPTYPE>* in, - std::complex<FPTYPE>* out, - const Device* ctx=get_default_device_ctx(), - const int ik=0, + /** + * @brief Converts data from reciprocal space to real space on the CPU. + * + * This function handles the conversion of data from reciprocal space (Fourier space) to real space. + * It supports complex types as input. + * The output can be either the same fundamental type or the underlying real type of a complex type. 
+ * + * @tparam FPTYPE The type of the input data, which can only be a complex type (e.g., std::complex<float>, std::complex<double>) + * @tparam Device The device type, must be base_device::DEVICE_CPU. + * @tparam std::enable_if<!std::is_same<FPTYPE, typename GetTypeReal<FPTYPE>::type>::value, int>::type + * SFINAE constraint to ensure that FPTYPE is a complex type. + * @tparam std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type + * SFINAE constraint to ensure that Device is base_device::DEVICE_CPU. + * + * @param in Pointer to the input data array in reciprocal space. + * @param out Pointer to the output data array in real space. If FPTYPE is a complex type, + * this should point to an array of the underlying real type. + * @param add Boolean flag indicating whether to add the result to the existing values in the output array. + * @param factor A scaling factor applied to the output values. + */ + template <typename TK, + typename TR, + typename Device, + typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value + && (std::is_same<TR, typename GetTypeReal<TK>::type>::value + || std::is_same<TR, TK>::value) + && std::is_same<Device, base_device::DEVICE_CPU>::value, + int>::type + = 0> + void recip_to_real(TK* in, TR* out, const bool add = false, const typename GetTypeReal<TK>::type factor = 1.0) const + { + this->recip2real(in, out, add, factor); + }; + + /** + * @brief Converts data from reciprocal space (Fourier space) to real space. + * + * This function handles the conversion of data from reciprocal space (typically after a Fourier transform) + * to real space, supporting scenarios where the input is of a complex type and the output is of the underlying + * fundamental type. + * + * @tparam FPTYPE The underlying fundamental type of the input data (e.g., float, double). + * Note that the actual type passed should be std::complex<FPTYPE>. 
+ * @tparam Device The device type, which can be any supported device type (e.g., base_device::DEVICE_CPU). + * + * @param in Pointer to the input data array in reciprocal space, of type std::complex<FPTYPE>*. + * @param out Pointer to the output data array in real space, of type FPTYPE*. + * @param add Optional boolean flag, default value false, indicating whether to add the result to the existing + * values in the output array. + * @param factor Optional scaling factor, default value 1.0, applied to the output values. + */ + template <typename TK, + typename TR, + typename Device, + typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value + && (std::is_same<TR, typename GetTypeReal<TK>::type>::value + || std::is_same<TR, TK>::value) + && std::is_same<Device, base_device::DEVICE_GPU>::value, + int>::type + = 0> + void recip_to_real(TK* in, + TR* out, const bool add = false, - const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) - - template <typename FPTYPE, typename Device = base_device::DEVICE_CPU> - void real_to_recip(FPTYPE* in, - std::complex<FPTYPE>* out, - const Device* ctx=get_default_device_ctx(), - const int ik=0, + const typename GetTypeReal<TK>::type factor = 1.0) const + { + this->recip2real_gpu(in, out, add, factor); + }; + + // template <typename FPTYPE, + // typename Device, + // typename std::enable_if<!std::is_same<FPTYPE, typename GetTypeReal<FPTYPE>::type>::value, int>::type = 0> + // void recip_to_real(FPTYPE* in, + // FPTYPE* out, + // const bool add = false, + // const typename GetTypeReal<FPTYPE>::type factor = 1.0) const; + /** + * @brief Converts data from real space to reciprocal space (Fourier space). + * + * This function handles the conversion of data from real space to reciprocal space (typically after performing a + * Fourier transform), supporting scenarios where the input is of a fundamental type (e.g., float, double) and the + * output is of a complex type. 
+ * + * @tparam FPTYPE The underlying fundamental type of the input data (e.g., float, double). + * SFINAE constraint ensures that FPTYPE is a fundamental type, not a complex type. + * @tparam Device The device type, which must be base_device::DEVICE_CPU. + * @tparam std::enable_if<std::is_same<FPTYPE, typename GetTypeReal<FPTYPE>::type>::value, int>::type + * SFINAE constraint to ensure that FPTYPE is a fundamental type. + * @tparam std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type + * SFINAE constraint to ensure that Device is base_device::DEVICE_CPU. + * + * @param in Pointer to the input data array in real space, of type FPTYPE*. + * @param out Pointer to the output data array in reciprocal space, of type std::complex<FPTYPE>*. + * @param ik Optional parameter, default value 0, representing some index or identifier. + * @param add Optional boolean flag, default value false, indicating whether to add the result to the existing + * values in the output array. + * @param factor Optional scaling factor, default value 1.0, applied to the output values. 
+ */ + template <typename TK, + typename TR, + typename Device, + typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value + && (std::is_same<TR, typename GetTypeReal<TK>::type>::value || std::is_same<TR, TK>::value) + && std::is_same<Device, base_device::DEVICE_CPU>::value ,int>::type = 0> + void real_to_recip(TR* in, + TK* out, const bool add = false, - const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) - template <typename FPTYPE, typename Device = base_device::DEVICE_CPU> - void recip_to_real(const std::complex<FPTYPE>* in, - FPTYPE* out, - const Device* ctx=get_default_device_ctx() , - const int ik=0, + const typename GetTypeReal<TK>::type factor = 1.0) const + { + this->real2recip(in, out, add, factor); + } + + template <typename TK,typename TR, typename Device, + typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value + && (std::is_same<TR, typename GetTypeReal<TK>::type>::value || std::is_same<TR, TK>::value) + && !std::is_same<Device, base_device::DEVICE_CPU>::value ,int>::type = 0> + void real_to_recip(TR* in, + TK* out, const bool add = false, - const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + const typename GetTypeReal<TK>::type factor = 1.0) const; + protected: //gather planes and scatter sticks of all processors template <typename T> @@ -325,6 +438,5 @@ class PW_Basis } #endif // PWBASIS_H - #include "pw_basis_sup.h" #include "pw_basis_big.h" //temporary it will be removed diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp index 38d4bfc49d..4f34221775 100644 --- a/source/module_basis/module_pw/pw_transform.cpp +++ b/source/module_basis/module_pw/pw_transform.cpp @@ -10,10 +10,10 @@ namespace ModulePW { - const base_device::DEVICE_CPU* PW_Basis::get_default_device_ctx() { - static const base_device::DEVICE_CPU* default_device_cpu; - return default_device_cpu; -} +// const base_device::DEVICE_CPU* 
PW_Basis::get_default_device_ctx() { +// static const base_device::DEVICE_CPU* default_device_cpu; +// return default_device_cpu; +// } /** * @brief transform real space to reciprocal space * @details c(g)=\int dr*f(r)*exp(-ig*r) @@ -284,229 +284,6 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo } ModuleBase::timer::tick(this->classname, "recip2real"); } - -template <> -void PW_Basis::real_to_recip(const std::complex<float>* in, - std::complex<float>* out, - const base_device::DEVICE_CPU* ctx, - const int ik, - const bool add, - const float factor) const -{ - this->real2recip(in, out, add, factor); -} -template <> -void PW_Basis::real_to_recip(const std::complex<double>* in, - std::complex<double>* out, - const base_device::DEVICE_CPU* ctx, - const int ik, - const bool add, - const double factor) const -{ - this->real2recip(in, out, add, factor); -} - -template <> -void PW_Basis::recip_to_real(const std::complex<float>* in, - std::complex<float>* out, - const base_device::DEVICE_CPU* ctx, - const int ik, - const bool add, - const float factor) const -{ - this->recip2real(in, out, add, factor); -} -template <> -void PW_Basis::recip_to_real(const std::complex<double>* in, - std::complex<double>* out, - const base_device::DEVICE_CPU* ctx, - const int ik, - const bool add, - const double factor) const -{ - this->recip2real(in, out, add, factor); -} - -template <> -void PW_Basis::real_to_recip(float* in, - std::complex<float>* out, - const base_device::DEVICE_CPU* ctx, - const int ik, - const bool add, - const float factor) const -{ - this->real2recip(in, out, add, factor); -} -template <> -void PW_Basis::real_to_recip(double* in, - std::complex<double>* out, - const base_device::DEVICE_CPU* ctx, - const int ik, - const bool add, - const double factor) const -{ - this->real2recip(in, out, add, factor); -} - -template <> -void PW_Basis::recip_to_real(const std::complex<float>* in, - float* out, - const base_device::DEVICE_CPU* ctx, - 
const int ik, - const bool add, - const float factor) const -{ - this->recip2real(in, out, add, factor); -} - -template <> -void PW_Basis::recip_to_real(const std::complex<double>* in, - double* out, - const base_device::DEVICE_CPU* ctx, - const int ik, - const bool add, - const double factor) const -{ - this->recip2real(in, out, add, factor); -} - -#if (defined(__CUDA) || defined(__ROCM)) -template <> -void PW_Basis::real_to_recip(const std::complex<float>* in, - std::complex<float>* out, - const base_device::DEVICE_GPU* ctx, - const int ik, - const bool add, - const float factor) const -{ - ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); - assert(this->gamma_only == false); - assert(this->poolnproc == 1); - - base_device::memory::synchronize_memory_op<std::complex<float>, base_device::DEVICE_GPU, base_device::DEVICE_GPU>()( - this->fft_bundle.get_auxr_3d_data<float>(), - in, - this->nrxx); - - this->fft_bundle.fft3D_forward(ctx, - this->fft_bundle.get_auxr_3d_data<float>(), - this->fft_bundle.get_auxr_3d_data<float>()); - - set_real_to_recip_output_op<float, base_device::DEVICE_GPU>()(ctx, - npw, - this->nxyz, - add, - factor, - this->ig2isz, - this->fft_bundle.get_auxr_3d_data<float>(), - out); - ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); -} -template <> -void PW_Basis::real_to_recip(const std::complex<double>* in, - std::complex<double>* out, - const base_device::DEVICE_GPU* ctx, - const int ik, - const bool add, - const double factor) const -{ - ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); - assert(this->gamma_only == false); - assert(this->poolnproc == 1); - - base_device::memory::synchronize_memory_op<std::complex<double>, - base_device::DEVICE_GPU, - base_device::DEVICE_GPU>()(this->fft_bundle.get_auxr_3d_data<double>(), - in, - this->nrxx); - - this->fft_bundle.fft3D_forward(ctx, - this->fft_bundle.get_auxr_3d_data<double>(), - this->fft_bundle.get_auxr_3d_data<double>()); - - 
set_real_to_recip_output_op<double, base_device::DEVICE_GPU>()(ctx, - npw, - this->nxyz, - add, - factor, - this->ig2isz, - this->fft_bundle.get_auxr_3d_data<double>(), - out); - ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); -} - -template <> -void PW_Basis::recip_to_real(const std::complex<float>* in, - std::complex<float>* out, - const base_device::DEVICE_GPU* ctx, - const int ik, - const bool add, - const float factor) const -{ - ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); - assert(this->gamma_only == false); - assert(this->poolnproc == 1); - // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<float>(), this->nxyz); - base_device::memory::set_memory_op<std::complex<float>, base_device::DEVICE_GPU>()( - this->fft_bundle.get_auxr_3d_data<float>(), - 0, - this->nxyz); - - set_3d_fft_box_op<float, base_device::DEVICE_GPU>()(ctx, - npw, - this->ig2isz, - in, - this->fft_bundle.get_auxr_3d_data<float>()); - this->fft_bundle.fft3D_backward(ctx, - this->fft_bundle.get_auxr_3d_data<float>(), - this->fft_bundle.get_auxr_3d_data<float>()); - - set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(ctx, - this->nrxx, - add, - factor, - this->fft_bundle.get_auxr_3d_data<float>(), - out); - - ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); -} -template <> -void PW_Basis::recip_to_real(const std::complex<double>* in, - std::complex<double>* out, - const base_device::DEVICE_GPU* ctx, - const int ik, - const bool add, - const double factor) const -{ - ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); - assert(this->gamma_only == false); - assert(this->poolnproc == 1); - // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<double>(), this->nxyz); - base_device::memory::set_memory_op<std::complex<double>, base_device::DEVICE_GPU>()( - this->fft_bundle.get_auxr_3d_data<double>(), - 0, - this->nxyz); - - set_3d_fft_box_op<double, base_device::DEVICE_GPU>()(ctx, - npw, - this->ig2isz, - in, - 
this->fft_bundle.get_auxr_3d_data<double>()); - this->fft_bundle.fft3D_backward(ctx, - this->fft_bundle.get_auxr_3d_data<double>(), - this->fft_bundle.get_auxr_3d_data<double>()); - - set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(ctx, - this->nrxx, - add, - factor, - this->fft_bundle.get_auxr_3d_data<double>(), - out); - - ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); -} -#endif - template void PW_Basis::real2recip<float>(const float* in, std::complex<float>* out, const bool add, diff --git a/source/module_basis/module_pw/pw_transform_gpu.cpp b/source/module_basis/module_pw/pw_transform_gpu.cpp new file mode 100644 index 0000000000..c17e72ca50 --- /dev/null +++ b/source/module_basis/module_pw/pw_transform_gpu.cpp @@ -0,0 +1,173 @@ +#include "pw_basis.h" +#include "module_base/timer.h" +#include "module_basis/module_pw/kernels/pw_op.h" +namespace ModulePW +{ +#if (defined(__CUDA) || defined(__ROCM)) +template <typename FPTYPE> +void PW_Basis::real2recip_gpu(const FPTYPE* in, + std::complex<FPTYPE>* out, + const bool add, + const FPTYPE factor) const +{ + ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + base_device::DEVICE_GPU* ctx; + // base_device::memory::synchronize_memory_op<std::complex<FPTYPE>, + // base_device::DEVICE_GPU, + // base_device::DEVICE_GPU>()(this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + // in, + // this->nrxx); + + this->fft_bundle.fft3D_forward(ctx, + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + this->fft_bundle.get_auxr_3d_data<FPTYPE>()); + + set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx, + npw, + this->nxyz, + add, + factor, + this->ig2isz, + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + out); + ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); +} +template <typename FPTYPE> +void PW_Basis::real2recip_gpu(const std::complex<FPTYPE>* in, + std::complex<FPTYPE>* out, + const bool add, + 
const FPTYPE factor) const +{ + ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + base_device::DEVICE_GPU* ctx; + base_device::memory::synchronize_memory_op<std::complex<FPTYPE>, + base_device::DEVICE_GPU, + base_device::DEVICE_GPU>()(this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + in, + this->nrxx); + + this->fft_bundle.fft3D_forward(ctx, + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + this->fft_bundle.get_auxr_3d_data<FPTYPE>()); + + set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx, + npw, + this->nxyz, + add, + factor, + this->ig2isz, + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + out); + ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); +} + +template <typename FPTYPE> +void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in, + FPTYPE* out, + const bool add, + const FPTYPE factor) const +{ + ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + base_device::DEVICE_GPU* ctx; + // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<FPTYPE>(), this->nxyz); + base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()( + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + 0, + this->nxyz); + + set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(ctx, + npw, + this->ig2isz, + in, + this->fft_bundle.get_auxr_3d_data<FPTYPE>()); + this->fft_bundle.fft3D_backward(ctx, + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + this->fft_bundle.get_auxr_3d_data<FPTYPE>()); + + set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx, + this->nrxx, + add, + factor, + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + out); + + ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); +} +template <typename FPTYPE> + void PW_Basis::recip2real_gpu(const std::complex<FPTYPE> *in, + std::complex<FPTYPE> *out, + const bool add, + const FPTYPE factor) 
const +{ + base_device::DEVICE_GPU* ctx; + ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<double>(), this->nxyz); + base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()( + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + 0, + this->nxyz); + + set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(ctx, + npw, + this->ig2isz, + in, + this->fft_bundle.get_auxr_3d_data<FPTYPE>()); + this->fft_bundle.fft3D_backward(ctx, + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + this->fft_bundle.get_auxr_3d_data<FPTYPE>()); + + set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx, + this->nrxx, + add, + factor, + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + out); + + ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); +} +template void PW_Basis::real2recip_gpu<double>(const double* in, + std::complex<double>* out, + const bool add, + const double factor) const; +template void PW_Basis::real2recip_gpu<float>(const float* in, + std::complex<float>* out, + const bool add, + const float factor) const; + +template void PW_Basis::real2recip_gpu<double>(const std::complex<double>* in, + std::complex<double>* out, + const bool add, + const double factor) const; +template void PW_Basis::real2recip_gpu<float>(const std::complex<float>* in, + std::complex<float>* out, + const bool add, + const float factor) const; + +template void PW_Basis::recip2real_gpu<double>(const std::complex<double>* in, + double* out, + const bool add, + const double factor) const; +template void PW_Basis::recip2real_gpu<float>(const std::complex<float>* in, + float* out, + const bool add, + const float factor) const; + +template void PW_Basis::recip2real_gpu<double>(const std::complex<double>* in, + std::complex<double>* out, + const bool add, + const double factor) const; +template void 
PW_Basis::recip2real_gpu<float>(const std::complex<float>* in, + std::complex<float>* out, + const bool add, + const float factor) const; + +#endif +} // namespace ModulePW \ No newline at end of file diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index e09a1a3dfa..ef7a168752 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -448,7 +448,7 @@ void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi) { for (int is = 0; is < PARAM.inp.nspin; is++) { - this->charge->rhopw->recip_to_real(this->rhog[is], this->rho[is]); + this->charge->rhopw->recip2real(this->rhog[is], this->rho[is]); } } } From b96bb777cf6ee8712a73b6126e70809fe9fd05f9 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Fri, 21 Mar 2025 16:22:14 +0800 Subject: [PATCH 3/8] remove ctx in the file --- .../module_pw/kernels/cuda/pw_op.cu | 15 +++++-------- .../module_basis/module_pw/kernels/pw_op.cpp | 12 ++++------ source/module_basis/module_pw/kernels/pw_op.h | 20 +++++------------ .../module_pw/kernels/rocm/pw_op.hip.cu | 15 +++++-------- .../module_pw/pw_transform_gpu.cpp | 22 +++++++------------ .../module_basis/module_pw/pw_transform_k.cpp | 18 +++++---------- .../module_pw/pw_transform_k_dsp.cpp | 10 ++++----- 7 files changed, 37 insertions(+), 75 deletions(-) diff --git a/source/module_basis/module_pw/kernels/cuda/pw_op.cu b/source/module_basis/module_pw/kernels/cuda/pw_op.cu index d9418ca486..1e0fce284e 100644 --- a/source/module_basis/module_pw/kernels/cuda/pw_op.cu +++ b/source/module_basis/module_pw/kernels/cuda/pw_op.cu @@ -100,8 +100,7 @@ __global__ void set_real_to_recip_output( } template <typename FPTYPE> -void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, - const int npwk, +void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk, const int* box_index, const std::complex<FPTYPE>* 
in, std::complex<FPTYPE>* out) @@ -117,8 +116,7 @@ void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_d } template <typename FPTYPE> -void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, - const int nrxx, +void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx, const bool add, const FPTYPE factor, const std::complex<FPTYPE>* in, @@ -136,8 +134,7 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co } template <typename FPTYPE> -void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, - const int nrxx, +void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx, const bool add, const FPTYPE factor, const std::complex<FPTYPE>* in, @@ -155,8 +152,7 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co } template <typename FPTYPE> -void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, - const int npwk, +void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk, const int nxyz, const bool add, const FPTYPE factor, @@ -178,8 +174,7 @@ void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co } template <typename FPTYPE> -void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, - const int npwk, +void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk, const int nxyz, const bool add, const FPTYPE factor, diff --git a/source/module_basis/module_pw/kernels/pw_op.cpp b/source/module_basis/module_pw/kernels/pw_op.cpp index e39e8a6a7c..a1bafae69b 100644 --- a/source/module_basis/module_pw/kernels/pw_op.cpp +++ b/source/module_basis/module_pw/kernels/pw_op.cpp @@ -5,8 +5,7 @@ 
namespace ModulePW { template <typename FPTYPE> struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_CPU> { - void operator()(const base_device::DEVICE_CPU* /*dev*/, - const int npwk, + void operator()(const int npwk, const int* box_index, const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) @@ -21,8 +20,7 @@ struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_CPU> template <typename FPTYPE> struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU> { - void operator()(const base_device::DEVICE_CPU* /*dev*/, - const int nrxx, + void operator()(const int nrxx, const bool add, const FPTYPE factor, const std::complex<FPTYPE>* in, @@ -40,8 +38,7 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU> } } - void operator()(const base_device::DEVICE_CPU* /*dev*/, - const int nrxx, + void operator()(const int nrxx, const bool add, const FPTYPE factor, const std::complex<FPTYPE>* in, @@ -67,8 +64,7 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU> template <typename FPTYPE> struct set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU> { - void operator()(const base_device::DEVICE_CPU* /*dev*/, - const int npw_k, + void operator()(const int npw_k, const int nxyz, const bool add, const FPTYPE factor, diff --git a/source/module_basis/module_pw/kernels/pw_op.h b/source/module_basis/module_pw/kernels/pw_op.h index b547190606..6221c8b68c 100644 --- a/source/module_basis/module_pw/kernels/pw_op.h +++ b/source/module_basis/module_pw/kernels/pw_op.h @@ -19,7 +19,6 @@ struct set_3d_fft_box_op { /// Output Parameters /// @param out - output psi within the 3D box(in recip space) void operator() ( - const Device* dev, const int npwk, const int* box_index, const std::complex<FPTYPE>* in, @@ -39,7 +38,6 @@ struct set_recip_to_real_output_op { /// Output Parameters /// @param out - output psi within the 3D box(in real space) void operator() ( - const Device* dev, const int nrxx, const bool add, const FPTYPE factor, @@ -47,7 +45,6 
@@ struct set_recip_to_real_output_op { std::complex<FPTYPE>* out); void operator() ( - const Device* dev, const int nrxx, const bool add, const FPTYPE factor, @@ -70,7 +67,6 @@ struct set_real_to_recip_output_op { /// Output Parameters /// @param out - output psi within the 3D box(in recip space) void operator() ( - const Device* dev, const int npw_k, const int nxyz, const bool add, @@ -80,7 +76,6 @@ struct set_real_to_recip_output_op { std::complex<FPTYPE>* out); void operator() ( - const Device* dev, const int npw_k, const int nxyz, const bool add, @@ -95,8 +90,7 @@ struct set_real_to_recip_output_op { template <typename FPTYPE> struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU> { - void operator()(const base_device::DEVICE_GPU* dev, - const int npwk, + void operator()(const int npwk, const int* box_index, const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out); @@ -105,15 +99,13 @@ struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU> template <typename FPTYPE> struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU> { - void operator()(const base_device::DEVICE_GPU* dev, - const int nrxx, + void operator()(const int nrxx, const bool add, const FPTYPE factor, const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out); - void operator()(const base_device::DEVICE_GPU* dev, - const int nrxx, + void operator()(const int nrxx, const bool add, const FPTYPE factor, const std::complex<FPTYPE>* in, @@ -123,16 +115,14 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU> template <typename FPTYPE> struct set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU> { - void operator()(const base_device::DEVICE_GPU* dev, - const int npw_k, + void operator()(const int npw_k, const int nxyz, const bool add, const FPTYPE factor, const int* box_index, const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out); - void operator()(const base_device::DEVICE_GPU* dev, - const int npw_k, + void operator()(const int npw_k, const int nxyz, 
const bool add, const FPTYPE factor, diff --git a/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu b/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu index 44be7d0671..51b50c6b8b 100644 --- a/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu +++ b/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu @@ -101,8 +101,7 @@ __global__ void set_real_to_recip_output( } template <typename FPTYPE> -void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, - const int npwk, +void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk, const int* box_index, const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) @@ -118,8 +117,7 @@ void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_d } template <typename FPTYPE> -void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, - const int nrxx, +void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx, const bool add, const FPTYPE factor, const std::complex<FPTYPE>* in, @@ -137,8 +135,7 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co } template <typename FPTYPE> -void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, - const int nrxx, +void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx, const bool add, const FPTYPE factor, const std::complex<FPTYPE>* in, @@ -156,8 +153,7 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co } template <typename FPTYPE> -void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, - const int npwk, +void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk, const int nxyz, const bool add, const FPTYPE factor, @@ -179,8 
+175,7 @@ void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co } template <typename FPTYPE> -void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/, - const int npwk, +void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk, const int nxyz, const bool add, const FPTYPE factor, diff --git a/source/module_basis/module_pw/pw_transform_gpu.cpp b/source/module_basis/module_pw/pw_transform_gpu.cpp index c17e72ca50..db89d2fbaf 100644 --- a/source/module_basis/module_pw/pw_transform_gpu.cpp +++ b/source/module_basis/module_pw/pw_transform_gpu.cpp @@ -3,7 +3,7 @@ #include "module_basis/module_pw/kernels/pw_op.h" namespace ModulePW { -#if (defined(__CUDA) || defined(__ROCM)) +// #if (defined(__CUDA) || defined(__ROCM)) template <typename FPTYPE> void PW_Basis::real2recip_gpu(const FPTYPE* in, std::complex<FPTYPE>* out, @@ -24,8 +24,7 @@ void PW_Basis::real2recip_gpu(const FPTYPE* in, this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>()); - set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx, - npw, + set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw, this->nxyz, add, factor, @@ -54,8 +53,7 @@ void PW_Basis::real2recip_gpu(const std::complex<FPTYPE>* in, this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>()); - set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx, - npw, + set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw, this->nxyz, add, factor, @@ -81,8 +79,7 @@ void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in, 0, this->nxyz); - set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(ctx, - npw, + set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(npw, this->ig2isz, in, this->fft_bundle.get_auxr_3d_data<FPTYPE>()); @@ -90,8 +87,7 @@ void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in, 
this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>()); - set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx, - this->nrxx, + set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx, add, factor, this->fft_bundle.get_auxr_3d_data<FPTYPE>(), @@ -115,8 +111,7 @@ template <typename FPTYPE> 0, this->nxyz); - set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(ctx, - npw, + set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(npw, this->ig2isz, in, this->fft_bundle.get_auxr_3d_data<FPTYPE>()); @@ -124,8 +119,7 @@ template <typename FPTYPE> this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>()); - set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx, - this->nrxx, + set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx, add, factor, this->fft_bundle.get_auxr_3d_data<FPTYPE>(), @@ -169,5 +163,5 @@ template void PW_Basis::recip2real_gpu<float>(const std::complex<float>* in, const bool add, const float factor) const; -#endif +// #endif } // namespace ModulePW \ No newline at end of file diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp index 3d75f07f6f..4e3a9ecc9d 100644 --- a/source/module_basis/module_pw/pw_transform_k.cpp +++ b/source/module_basis/module_pw/pw_transform_k.cpp @@ -361,8 +361,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx, const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; - set_real_to_recip_output_op<float, base_device::DEVICE_GPU>()(ctx, - npw_k, + set_real_to_recip_output_op<float, base_device::DEVICE_GPU>()(npw_k, this->nxyz, add, factor, @@ -393,8 +392,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx, const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; - set_real_to_recip_output_op<double, base_device::DEVICE_GPU>()(ctx, - npw_k, + 
set_real_to_recip_output_op<double, base_device::DEVICE_GPU>()(npw_k, this->nxyz, add, factor, @@ -424,15 +422,13 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx, const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; - set_3d_fft_box_op<float, base_device::DEVICE_GPU>()(ctx, - npw_k, + set_3d_fft_box_op<float, base_device::DEVICE_GPU>()(npw_k, this->ig2ixyz_k + startig, in, this->fft_bundle.get_auxr_3d_data<float>()); this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>()); - set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(ctx, - this->nrxx, + set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(this->nrxx, add, factor, this->fft_bundle.get_auxr_3d_data<float>(), @@ -460,15 +456,13 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx, const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; - set_3d_fft_box_op<double, base_device::DEVICE_GPU>()(ctx, - npw_k, + set_3d_fft_box_op<double, base_device::DEVICE_GPU>()(npw_k, this->ig2ixyz_k + startig, in, this->fft_bundle.get_auxr_3d_data<double>()); this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>()); - set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(ctx, - this->nrxx, + set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(this->nrxx, add, factor, this->fft_bundle.get_auxr_3d_data<double>(), diff --git a/source/module_basis/module_pw/pw_transform_k_dsp.cpp b/source/module_basis/module_pw/pw_transform_k_dsp.cpp index b292e25f0a..1632a20c5e 100644 --- a/source/module_basis/module_pw/pw_transform_k_dsp.cpp +++ b/source/module_basis/module_pw/pw_transform_k_dsp.cpp @@ -32,8 +32,7 @@ void PW_Basis_K::real2recip_dsp(const std::complex<FPTYPE>* in, auxr); this->fft_bundle.resource_handler(0); // copy the result from the auxr to the out ,while consider the add - 
set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>()(ctx, - npw_k, + set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>()(npw_k, this->nxyz, add, factor, @@ -58,7 +57,7 @@ void PW_Basis_K::recip2real_dsp(const std::complex<FPTYPE>* in, const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; // copy the mapping form the type of stick to the 3dfft - set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(ctx, npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr); + set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr); // use 3d fft backward this->fft_bundle.resource_handler(1); this->fft_bundle.fft3D_backward(gpux, auxr, auxr); @@ -107,7 +106,7 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, const int npw_k = this->npwk[ik]; // copy the mapping form the type of stick to the 3dfft - set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(ctx, npw_k, this->ig2ixyz_k_cpu.data() + startig, input, auxr); + set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(npw_k, this->ig2ixyz_k_cpu.data() + startig, input, auxr); // use 3d fft backward this->fft_bundle.fft3D_backward(gpux, auxr, auxr); @@ -120,8 +119,7 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, // 3d fft this->fft_bundle.fft3D_forward(gpux, auxr, auxr); // copy the result from the auxr to the out ,while consider the add - set_real_to_recip_output_op<double, base_device::DEVICE_CPU>()(ctx, - npw_k, + set_real_to_recip_output_op<double, base_device::DEVICE_CPU>()(npw_k, this->nxyz, add, factor, From a105a3cd92f40778404fc5928daf2121d778625d Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Fri, 21 Mar 2025 16:42:04 +0800 Subject: [PATCH 4/8] remove ctx in bundle --- .../module_pw/kernels/test/pw_op_test.cpp | 12 ++++++------ .../module_basis/module_pw/module_fft/fft_bundle.cpp | 12 ++++-------- .../module_basis/module_pw/module_fft/fft_bundle.h | 8 ++++---- 
source/module_basis/module_pw/pw_transform_gpu.cpp | 12 ++++-------- source/module_basis/module_pw/pw_transform_k.cpp | 8 ++++---- source/module_basis/module_pw/pw_transform_k_dsp.cpp | 9 ++++----- 6 files changed, 26 insertions(+), 35 deletions(-) diff --git a/source/module_basis/module_pw/kernels/test/pw_op_test.cpp b/source/module_basis/module_pw/kernels/test/pw_op_test.cpp index 6adac4613f..e9bc07bd9e 100644 --- a/source/module_basis/module_pw/kernels/test/pw_op_test.cpp +++ b/source/module_basis/module_pw/kernels/test/pw_op_test.cpp @@ -72,7 +72,7 @@ class TestModulePWPWMultiDevice : public ::testing::Test TEST_F(TestModulePWPWMultiDevice, set_3d_fft_box_op_cpu) { std::vector<std::complex<double>> res(out_1.size(), std::complex<double>{0, 0}); - set_3d_fft_box_cpu_op()(cpu_ctx, this->npwk, box_index.data(), in_1.data(), res.data()); + set_3d_fft_box_cpu_op()(this->npwk, box_index.data(), in_1.data(), res.data()); for (int ii = 0; ii < this->nxyz; ii++) { EXPECT_LT(std::abs(res[ii] - out_1[ii]), 1e-12); } @@ -81,7 +81,7 @@ TEST_F(TestModulePWPWMultiDevice, set_3d_fft_box_op_cpu) TEST_F(TestModulePWPWMultiDevice, set_recip_to_real_output_op_cpu) { std::vector<std::complex<double>> res(out_2.size(), std::complex<double>{0, 0}); - set_recip_to_real_output_cpu_op()(cpu_ctx, this->nxyz, this->add, this->factor, in_2.data(), res.data()); + set_recip_to_real_output_cpu_op()(this->nxyz, this->add, this->factor, in_2.data(), res.data()); for (int ii = 0; ii < this->nxyz; ii++) { EXPECT_LT(std::abs(res[ii] - out_2[ii]), 1e-12); } @@ -90,7 +90,7 @@ TEST_F(TestModulePWPWMultiDevice, set_recip_to_real_output_op_cpu) TEST_F(TestModulePWPWMultiDevice, set_real_to_recip_output_op_cpu) { std::vector<std::complex<double>> res = out_3_init; - set_real_to_recip_output_cpu_op()(cpu_ctx, this->npwk, this->nxyz, true, this->factor, box_index.data(), in_3.data(), res.data()); + set_real_to_recip_output_cpu_op()(this->npwk, this->nxyz, true, this->factor, box_index.data(), in_3.data(), 
res.data()); for (int ii = 0; ii < out_3.size(); ii++) { EXPECT_LT(std::abs(res[ii] - out_3[ii]), 5e-6); } @@ -109,7 +109,7 @@ TEST_F(TestModulePWPWMultiDevice, set_3d_fft_box_op_gpu) synchronize_memory_complex_h2d_op()(d_res, res.data(), res.size()); synchronize_memory_complex_h2d_op()(d_in_1, in_1.data(), in_1.size()); - set_3d_fft_box_gpu_op()(gpu_ctx, this->npwk, d_box_index, d_in_1, d_res); + set_3d_fft_box_gpu_op()(this->npwk, d_box_index, d_in_1, d_res); synchronize_memory_complex_d2h_op()(res.data(), d_res, res.size()); @@ -130,7 +130,7 @@ TEST_F(TestModulePWPWMultiDevice, set_recip_to_real_output_op_gpu) synchronize_memory_complex_h2d_op()(d_res, res.data(), res.size()); synchronize_memory_complex_h2d_op()(d_in_2, in_2.data(), in_2.size()); - set_recip_to_real_output_gpu_op()(gpu_ctx, this->nxyz, this->add, this->factor, d_in_2, d_res); + set_recip_to_real_output_gpu_op()(this->nxyz, this->add, this->factor, d_in_2, d_res); synchronize_memory_complex_d2h_op()(res.data(), d_res, res.size()); @@ -153,7 +153,7 @@ TEST_F(TestModulePWPWMultiDevice, set_real_to_recip_output_op_gpu) synchronize_memory_complex_h2d_op()(d_res, res.data(), res.size()); synchronize_memory_complex_h2d_op()(d_in_3, in_3.data(), in_3.size()); - set_real_to_recip_output_gpu_op()(gpu_ctx, this->npwk, this->nxyz, true, this->factor, d_box_index, d_in_3, d_res); + set_real_to_recip_output_gpu_op()(this->npwk, this->nxyz, true, this->factor, d_box_index, d_in_3, d_res); synchronize_memory_complex_d2h_op()(res.data(), d_res, res.size()); diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp index 7289e8ab02..270f985498 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp +++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp @@ -227,30 +227,26 @@ void FFT_Bundle::fftxyc2r(std::complex<double>* in, double* out) const } template <> -void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, - 
std::complex<float>* in, +void FFT_Bundle::fft3D_forward(std::complex<float>* in, std::complex<float>* out) const { fft_float->fft3D_forward(in, out); } template <> -void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx, - std::complex<double>* in, +void FFT_Bundle::fft3D_forward(std::complex<double>* in, std::complex<double>* out) const { fft_double->fft3D_forward(in, out); } template <> -void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, - std::complex<float>* in, +void FFT_Bundle::fft3D_backward(std::complex<float>* in, std::complex<float>* out) const { fft_float->fft3D_backward(in, out); } template <> -void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx, - std::complex<double>* in, +void FFT_Bundle::fft3D_backward(std::complex<double>* in, std::complex<double>* out) const { fft_double->fft3D_backward(in, out); diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h index 1982a79a0c..5b8b527cc4 100644 --- a/source/module_basis/module_pw/module_fft/fft_bundle.h +++ b/source/module_basis/module_pw/module_fft/fft_bundle.h @@ -188,10 +188,10 @@ class FFT_Bundle template <typename FPTYPE> void fftxyc2r(std::complex<FPTYPE>* in, FPTYPE* out) const; - template <typename FPTYPE, typename Device> - void fft3D_forward(const Device* ctx, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const; - template <typename FPTYPE, typename Device> - void fft3D_backward(const Device* ctx, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const; + template <typename FPTYPE> + void fft3D_forward(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const; + template <typename FPTYPE> + void fft3D_backward(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const; private: int fft_mode = 0; diff --git a/source/module_basis/module_pw/pw_transform_gpu.cpp b/source/module_basis/module_pw/pw_transform_gpu.cpp index db89d2fbaf..b5970a48a6 100644 --- 
a/source/module_basis/module_pw/pw_transform_gpu.cpp +++ b/source/module_basis/module_pw/pw_transform_gpu.cpp @@ -20,8 +20,7 @@ void PW_Basis::real2recip_gpu(const FPTYPE* in, // in, // this->nrxx); - this->fft_bundle.fft3D_forward(ctx, - this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>()); set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw, @@ -49,8 +48,7 @@ void PW_Basis::real2recip_gpu(const std::complex<FPTYPE>* in, in, this->nrxx); - this->fft_bundle.fft3D_forward(ctx, - this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>()); set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw, @@ -83,8 +81,7 @@ void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in, this->ig2isz, in, this->fft_bundle.get_auxr_3d_data<FPTYPE>()); - this->fft_bundle.fft3D_backward(ctx, - this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>()); set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx, @@ -115,8 +112,7 @@ template <typename FPTYPE> this->ig2isz, in, this->fft_bundle.get_auxr_3d_data<FPTYPE>()); - this->fft_bundle.fft3D_backward(ctx, - this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>()); set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx, diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp index 4e3a9ecc9d..776b493d35 100644 --- a/source/module_basis/module_pw/pw_transform_k.cpp +++ b/source/module_basis/module_pw/pw_transform_k.cpp @@ -357,7 +357,7 @@ void PW_Basis_K::real_to_recip(const 
base_device::DEVICE_GPU* ctx, in, this->nrxx); - this->fft_bundle.fft3D_forward(ctx, this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>()); + this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>()); const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; @@ -388,7 +388,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx, in, this->nrxx); - this->fft_bundle.fft3D_forward(ctx, this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>()); + this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>()); const int startig = ik * this->npwk_max; const int npw_k = this->npwk[ik]; @@ -426,7 +426,7 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx, this->ig2ixyz_k + startig, in, this->fft_bundle.get_auxr_3d_data<float>()); - this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>()); + this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>()); set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(this->nrxx, add, @@ -460,7 +460,7 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx, this->ig2ixyz_k + startig, in, this->fft_bundle.get_auxr_3d_data<double>()); - this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>()); + this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>()); set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(this->nrxx, add, diff --git a/source/module_basis/module_pw/pw_transform_k_dsp.cpp b/source/module_basis/module_pw/pw_transform_k_dsp.cpp index 1632a20c5e..56ba26eb2a 100644 --- 
a/source/module_basis/module_pw/pw_transform_k_dsp.cpp +++ b/source/module_basis/module_pw/pw_transform_k_dsp.cpp @@ -27,8 +27,7 @@ void PW_Basis_K::real2recip_dsp(const std::complex<FPTYPE>* in, // 3d fft this->fft_bundle.resource_handler(1); - this->fft_bundle.fft3D_forward(gpux, - auxr, + this->fft_bundle.fft3D_forward(auxr, auxr); this->fft_bundle.resource_handler(0); // copy the result from the auxr to the out ,while consider the add @@ -60,7 +59,7 @@ void PW_Basis_K::recip2real_dsp(const std::complex<FPTYPE>* in, set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr); // use 3d fft backward this->fft_bundle.resource_handler(1); - this->fft_bundle.fft3D_backward(gpux, auxr, auxr); + this->fft_bundle.fft3D_backward(auxr, auxr); this->fft_bundle.resource_handler(0); if (add) { @@ -109,7 +108,7 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(npw_k, this->ig2ixyz_k_cpu.data() + startig, input, auxr); // use 3d fft backward - this->fft_bundle.fft3D_backward(gpux, auxr, auxr); + this->fft_bundle.fft3D_backward(auxr, auxr); for (int ir = 0; ir < size; ir++) { @@ -117,7 +116,7 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx, } // 3d fft - this->fft_bundle.fft3D_forward(gpux, auxr, auxr); + this->fft_bundle.fft3D_forward(auxr, auxr); // copy the result from the auxr to the out ,while consider the add set_real_to_recip_output_op<double, base_device::DEVICE_CPU>()(npw_k, this->nxyz, From 977b713f5fb103056e152d7995120d3882051ed1 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Fri, 21 Mar 2025 16:49:35 +0800 Subject: [PATCH 5/8] update compile bug --- source/module_basis/module_pw/pw_transform_gpu.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/source/module_basis/module_pw/pw_transform_gpu.cpp b/source/module_basis/module_pw/pw_transform_gpu.cpp index b5970a48a6..1898ff23eb 100644 --- 
a/source/module_basis/module_pw/pw_transform_gpu.cpp +++ b/source/module_basis/module_pw/pw_transform_gpu.cpp @@ -3,7 +3,7 @@ #include "module_basis/module_pw/kernels/pw_op.h" namespace ModulePW { -// #if (defined(__CUDA) || defined(__ROCM)) +#if (defined(__CUDA) || defined(__ROCM)) template <typename FPTYPE> void PW_Basis::real2recip_gpu(const FPTYPE* in, std::complex<FPTYPE>* out, @@ -159,5 +159,5 @@ template void PW_Basis::recip2real_gpu<float>(const std::complex<float>* in, const bool add, const float factor) const; -// #endif +#endif } // namespace ModulePW \ No newline at end of file From ee48127a32eaae8c139373f69ca8582c88c6e0d5 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Sun, 23 Mar 2025 19:53:19 +0800 Subject: [PATCH 6/8] add T,device --- source/module_basis/module_pw/pw_basis_k.h | 72 ++++++++++++++- .../module_basis/module_pw/pw_transform_k.cpp | 89 +++++++++++++++++++ .../module_pw/test/test-other.cpp | 8 +- source/module_elecstate/elecstate_pw.cpp | 8 +- .../module_elecstate/elecstate_pw_cal_tau.cpp | 4 +- .../module_xc/test/xc3_mock.h | 18 ++-- .../module_xc/xc_functional_gradcorr.cpp | 3 +- .../hamilt_ofdft/ml_data_descriptor.cpp | 4 +- .../hamilt_pwdft/operator_pw/meta_pw.cpp | 4 +- .../hamilt_pwdft/operator_pw/op_exx_pw.cpp | 10 +-- .../hamilt_pwdft/operator_pw/veff_pw.cpp | 12 +-- .../hamilt_stodft/sto_iter.cpp | 2 +- source/module_io/get_pchg_pw.h | 4 +- 13 files changed, 195 insertions(+), 43 deletions(-) diff --git a/source/module_basis/module_pw/pw_basis_k.h b/source/module_basis/module_pw/pw_basis_k.h index ae5076bba9..ef578852ae 100644 --- a/source/module_basis/module_pw/pw_basis_k.h +++ b/source/module_basis/module_pw/pw_basis_k.h @@ -161,7 +161,7 @@ class PW_Basis_K : public PW_Basis #endif - template <typename FPTYPE, typename Device> + template <typename FPTYPE, typename Device> void real_to_recip(const Device* ctx, const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, @@ -176,6 +176,76 @@ class PW_Basis_K : 
public PW_Basis const bool add = false, const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + + template <typename TK, + typename Device, + typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0> + void real_to_recip(const TK* in, + TK* out, + const int ik, + const bool add = false, + const typename GetTypeReal<TK>::type factor = 1.0) const + { + #if defined(__DSP) + this->real2recip_dsp(in, out, ik, add, factor); /* forward (real -> recip) transform must use the real2recip DSP path */ + #else + this->real2recip(in,out,ik,add,factor); + #endif + } + template <typename TK, + typename Device, + typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0> + void recip_to_real(const TK* in, + TK* out, + const int ik, + const bool add = false, + const typename GetTypeReal<TK>::type factor = 1.0) const + { + + #if defined(__DSP) + this->recip2real_dsp(in,out,ik,add,factor); + #else + this->recip2real(in,out,ik,add,factor); + #endif + } + template <typename FPTYPE> + void real2recip_gpu(const std::complex<FPTYPE>* in, + std::complex<FPTYPE>* out, + const int ik, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny) ; out(nz, ns) + + template <typename FPTYPE> + void recip2real_gpu(const std::complex<FPTYPE>* in, + std::complex<FPTYPE>* out, + const int ik, + const bool add = false, + const FPTYPE factor = 1.0) const; // in:(nz, ns) ; out(nplane,nx*ny) + + template <typename FPTYPE, + typename Device, + typename std::enable_if<!std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0> + void real_to_recip(const FPTYPE* in, + FPTYPE* out, + const int ik, + const bool add = false, + const typename GetTypeReal<FPTYPE>::type factor = 1.0) const + { + this->real2recip_gpu(in, out, ik, add, factor); + } + + template <typename TK, + typename Device, + typename std::enable_if<std::is_same<Device, base_device::DEVICE_GPU>::value, int>::type = 0> + void recip_to_real(const TK* in, + TK* out, + const int ik, + const bool add = 
false, + const typename GetTypeReal<TK>::type factor = 1.0) const + { + this->recip2real_gpu(in, out, ik, add, factor); + } + public: //operator: //get (G+K)^2: diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp index 776b493d35..a709b60429 100644 --- a/source/module_basis/module_pw/pw_transform_k.cpp +++ b/source/module_basis/module_pw/pw_transform_k.cpp @@ -470,6 +470,95 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx, ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); } + +template <typename FPTYPE> +void PW_Basis_K::real2recip_gpu(const std::complex<FPTYPE>* in, + std::complex<FPTYPE>* out, + const int ik, + const bool add, + const FPTYPE factor) const +{ + ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + + base_device::memory::synchronize_memory_op<std::complex<FPTYPE>, + base_device::DEVICE_GPU, + base_device::DEVICE_GPU>()(this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + in, + this->nrxx); + + this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>()); + + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k, + this->nxyz, + add, + factor, + this->ig2ixyz_k + startig, + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + out); + ModuleBase::timer::tick(this->classname, "real_to_recip gpu"); +} +template <typename FPTYPE> +void PW_Basis_K::recip2real_gpu(const std::complex<FPTYPE>* in, + std::complex<FPTYPE>* out, + const int ik, + const bool add, + const FPTYPE factor) const +{ + ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); + assert(this->gamma_only == false); + assert(this->poolnproc == 1); + // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<FPTYPE>(), this->nxyz); + 
base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()( + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + 0, + this->nxyz); + + const int startig = ik * this->npwk_max; + const int npw_k = this->npwk[ik]; + + set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k, + this->ig2ixyz_k + startig, + in, + this->fft_bundle.get_auxr_3d_data<FPTYPE>()); + this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>()); + + set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx, + add, + factor, + this->fft_bundle.get_auxr_3d_data<FPTYPE>(), + out); + + ModuleBase::timer::tick(this->classname, "recip_to_real gpu"); +} + +template void PW_Basis_K::real2recip_gpu<float>(const std::complex<float>*, + std::complex<float>*, + const int, + const bool, + const float) const; + +template void PW_Basis_K::real2recip_gpu<double>(const std::complex<double>*, + std::complex<double>*, + const int, + const bool, + const double) const; + +template void PW_Basis_K::recip2real_gpu<float>(const std::complex<float>*, + std::complex<float>*, + const int, + const bool, + const float) const; + +template void PW_Basis_K::recip2real_gpu<double>(const std::complex<double>*, + std::complex<double>*, + const int, + const bool, + const double) const; + #endif template void PW_Basis_K::real2recip<float>(const float* in, diff --git a/source/module_basis/module_pw/test/test-other.cpp b/source/module_basis/module_pw/test/test-other.cpp index 308f4c6c68..c5eb0093b2 100644 --- a/source/module_basis/module_pw/test/test-other.cpp +++ b/source/module_basis/module_pw/test/test-other.cpp @@ -76,26 +76,26 @@ TEST_F(PWTEST,test_other) } #endif - pwktest.recip_to_real(ctx, rhog1, rhor1, ik); + pwktest.recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(rhog1, rhor1, ik); pwktest.recip2real(rhog2, rhor2, ik); for(int ir = 0 ; ir < nrxx; ++ir) { EXPECT_NEAR(std::abs(rhor1[ir]),std::abs(rhor2[ir]),1e-8); 
} - pwktest.real_to_recip(ctx, rhor1, rhog1, ik); + pwktest.real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(rhor1, rhog1, ik); pwktest.real2recip(rhor2, rhog2, ik); for(int ig = 0 ; ig < npwk; ++ig) { EXPECT_NEAR(std::abs(rhog1[ig]),std::abs(rhog2[ig]),1e-8); } #ifdef __ENABLE_FLOAT_FFTW - pwktest.recip_to_real(ctx, rhofg1, rhofr1, ik); + pwktest.recip_to_real<std::complex<float>,base_device::DEVICE_CPU>(rhofg1, rhofr1, ik); pwktest.recip2real(rhofg2, rhofr2, ik); for(int ir = 0 ; ir < nrxx; ++ir) { EXPECT_NEAR(std::abs(rhofr1[ir]),std::abs(rhofr2[ir]),1e-6); } - pwktest.real_to_recip(ctx, rhofr1, rhofg1, ik); + pwktest.real_to_recip<std::complex<float>,base_device::DEVICE_CPU>(rhofr1, rhofg1, ik); pwktest.real2recip(rhofr2, rhofg2, ik); for(int ig = 0 ; ig < npwk; ++ig) { diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index ef7a168752..76a4ce6c42 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -202,9 +202,9 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi) /// be care of when smearing_sigma is large, wg would less than 0 /// - this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik); + this->basis->recip_to_real<T,Device>( &psi(ibnd,0), this->wfcr, ik); - this->basis->recip_to_real(this->ctx, &psi(ibnd,npwx), this->wfcr_another_spin, ik); + this->basis->recip_to_real<T,Device>( &psi(ibnd,npwx), this->wfcr_another_spin, ik); const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega); @@ -230,7 +230,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi) /// only occupied band should be calculated. 
/// - this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik); + this->basis->recip_to_real<T,Device>(&psi(ibnd,0), this->wfcr, ik); const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega); @@ -258,7 +258,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi) &psi(ibnd, 0), this->wfcr); - this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik); + this->basis->recip_to_real<T,Device>( this->wfcr, this->wfcr, ik); elecstate_pw_op()(this->ctx, current_spin, this->charge->nrxx, w1, this->kin_r, this->wfcr); } diff --git a/source/module_elecstate/elecstate_pw_cal_tau.cpp b/source/module_elecstate/elecstate_pw_cal_tau.cpp index 628dd25aef..1cf4f57b70 100644 --- a/source/module_elecstate/elecstate_pw_cal_tau.cpp +++ b/source/module_elecstate/elecstate_pw_cal_tau.cpp @@ -23,7 +23,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi) int nbands = psi.get_nbands(); for (int ibnd = 0; ibnd < nbands; ibnd++) { - this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik); + this->basis->recip_to_real<T,Device>(&psi(ibnd,0), this->wfcr, ik); const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega); @@ -43,7 +43,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi) &psi(ibnd, 0), this->wfcr); - this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik); + this->basis->recip_to_real<T,Device>(this->wfcr, this->wfcr, ik); elecstate_pw_op()(this->ctx, current_spin, this->charge->nrxx, w1, this->kin_r, this->wfcr); } diff --git a/source/module_hamilt_general/module_xc/test/xc3_mock.h b/source/module_hamilt_general/module_xc/test/xc3_mock.h index ffca76c70c..ebbc139d75 100644 --- a/source/module_hamilt_general/module_xc/test/xc3_mock.h +++ b/source/module_hamilt_general/module_xc/test/xc3_mock.h @@ -78,8 +78,7 @@ namespace ModulePW template <typename FPTYPE, typename Device> - void PW_Basis_K::real_to_recip(const Device* ctx, - const std::complex<FPTYPE>* in, + 
void PW_Basis_K::real_to_recip(const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int ik, const bool add, @@ -91,8 +90,7 @@ namespace ModulePW } } template <typename FPTYPE, typename Device> - void PW_Basis_K::recip_to_real(const Device* ctx, - const std::complex<FPTYPE>* in, + void PW_Basis_K::recip_to_real(const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int ik, const bool add, @@ -104,28 +102,24 @@ namespace ModulePW } } - template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_CPU>(const base_device::DEVICE_CPU* ctx, - const std::complex<double>* in, + template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_CPU>(const std::complex<double>* in, std::complex<double>* out, const int ik, const bool add, const double factor) const; - template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_CPU>(const base_device::DEVICE_CPU* ctx, - const std::complex<double>* in, + template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_CPU>(const std::complex<double>* in, std::complex<double>* out, const int ik, const bool add, const double factor) const; #if __CUDA || __ROCM - template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* ctx, - const std::complex<double>* in, + template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_GPU>(const std::complex<double>* in, std::complex<double>* out, const int ik, const bool add, const double factor) const; - template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* ctx, - const std::complex<double>* in, + template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_GPU>(const std::complex<double>* in, std::complex<double>* out, const int ik, const bool add, diff --git a/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp b/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp index 4250146510..b3cad37c44 100644 --- 
a/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp +++ b/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp @@ -644,8 +644,7 @@ void XC_Functional::grad_wfc( rhog, porter.data<T>()); // Array of std::complex<double> // bring the gdr from G --> R - Device * ctx = nullptr; - wfc_basis->recip_to_real(ctx, porter.data<T>(), porter.data<T>(), ik); + wfc_basis->recip_to_real<T,Device>(porter.data<T>(), porter.data<T>(), ik); xc_functional_grad_wfc_solver( ipol, wfc_basis->nrxx, // Integers diff --git a/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp b/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp index 85135d789f..e9ed4e7268 100644 --- a/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp +++ b/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp @@ -197,7 +197,7 @@ void ML_data::getF_KS1( continue; } - pw_psi->recip_to_real(ctx, &psi->operator()(ibnd,0), wfcr, ik); + pw_psi->recip_to_real<T,Device>( &psi->operator()(ibnd,0), wfcr, ik); const double w1 = pelec->wg(ik, ibnd) / ucell.omega; // output one wf, to check KS equation @@ -308,7 +308,7 @@ void ML_data::getF_KS2( continue; } - pw_psi->recip_to_real(ctx, &psi->operator()(ibnd,0), wfcr, ik); + pw_psi->recip_to_real<T,Device>( &psi->operator()(ibnd,0), wfcr, ik); const double w1 = pelec->wg(ik, ibnd) / ucell.omega; if (pelec->ekb(ik,ibnd) > epsilonM) diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp index 817dba61ce..538819f452 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp @@ -67,13 +67,13 @@ void Meta<OperatorPW<T, Device>>::act( for (int j = 0; j < 3; j++) { meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<Real>(), wfcpw->get_kvec_c_data<Real>(), tmpsi_in, this->porter); - wfcpw->recip_to_real(this->ctx, 
this->porter, this->porter, this->ik); + wfcpw->recip_to_real<T,Device>(this->porter, this->porter, this->ik); if(this->vk_col != 0) { vector_mul_vector_op()(this->vk_col, this->porter, this->porter, this->vk + current_spin * this->vk_col); } - wfcpw->real_to_recip(this->ctx, this->porter, this->porter, this->ik); + wfcpw->real_to_recip<T,Device>(this->porter, this->porter, this->ik); meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<Real>(), wfcpw->get_kvec_c_data<Real>(), this->porter, tmhpsi, true); } // x,y,z directions diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp index 4b6d36908a..202800026c 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp @@ -188,7 +188,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands, { const T *psi_nk = tmpsi_in + n_iband * nbasis; // retrieve \psi_nk in real space - wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, this->ik); + wfcpw->recip_to_real<T,Device>( psi_nk, psi_nk_real, this->ik); // for \psi_nk, get the pw of iq and band m auto q_points = get_q_points(this->ik); @@ -208,7 +208,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands, // if (has_real.find({iq, m_iband}) == has_real.end()) // { const T* psi_mq = get_pw(m_iband, iq); - wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq); + wfcpw->recip_to_real<T,Device>( psi_mq, psi_mq_real, iq); // syncmem_complex_op()(this->ctx, this->ctx, psi_all_real + m_iband * wfcpw->nrxx, psi_mq_real, wfcpw->nrxx); // has_real[{iq, m_iband}] = true; // } @@ -271,7 +271,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands, } // end of iq auto h_psi_nk = tmhpsi + n_iband * nbasis; Real hybrid_alpha = GlobalC::exx_info.info_global.hybrid_alpha; - wfcpw->real_to_recip(ctx, h_psi_real, h_psi_nk, this->ik, true, hybrid_alpha); + 
wfcpw->real_to_recip<T,Device>(h_psi_real, h_psi_nk, this->ik, true, hybrid_alpha); setmem_complex_op()(h_psi_real, 0, rhopw->nrxx); } @@ -810,7 +810,7 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_op(psi::Psi<T, Device> *ppsi_) c psi.fix_kb(ik, n_iband); const T* psi_nk = psi.get_pointer(); // retrieve \psi_nk in real space - wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, ik); + wfcpw->recip_to_real<T,Device>( psi_nk, psi_nk_real, ik); // for \psi_nk, get the pw of iq and band m // q_points is a vector of integers, 0 to nks-1 @@ -839,7 +839,7 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_op(psi::Psi<T, Device> *ppsi_) c psi_.fix_kb(iq, m_iband); const T* psi_mq = psi_.get_pointer(); // const T* psi_mq = get_pw(m_iband, iq); - wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq); + wfcpw->recip_to_real<T,Device>(psi_mq, psi_mq_real, iq); T omega_inv = 1.0 / ucell->omega; diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp index 54e1a052be..c931d61f2f 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp @@ -62,7 +62,7 @@ void Veff<OperatorPW<T, Device>>::act( if (npol == 1) { // wfcpw->recip2real(tmpsi_in, porter, this->ik); - wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik); + wfcpw->recip_to_real<T,Device>(tmpsi_in, this->porter, this->ik); // NOTICE: when MPI threads are larger than number of Z grids // veff would contain nothing, and nothing should be done in real space // but the 3DFFT can not be skipped, it will cause hanging @@ -76,7 +76,7 @@ void Veff<OperatorPW<T, Device>>::act( // } } // wfcpw->real2recip(porter, tmhpsi, this->ik, true); - wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true); + wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true); // wfcpw->convolution(this->ctx, // this->ik, // this->veff_col, @@ -89,8 +89,8 @@ 
void Veff<OperatorPW<T, Device>>::act( { // T *porter1 = new T[wfcpw->nmaxgr]; // fft to real space and doing things. - wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik); - wfcpw->recip_to_real(this->ctx, tmpsi_in + max_npw, this->porter1, this->ik); + wfcpw->recip_to_real<T,Device>(tmpsi_in, this->porter, this->ik); + wfcpw->recip_to_real<T,Device>(tmpsi_in + max_npw, this->porter1, this->ik); if(this->veff_col != 0) { /// denghui added at 20221109 @@ -114,8 +114,8 @@ void Veff<OperatorPW<T, Device>>::act( // } } // (3) fft back to G space. - wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true); - wfcpw->real_to_recip(this->ctx, this->porter1, tmhpsi + max_npw, this->ik, true); + wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true); + wfcpw->real_to_recip<T,Device>(this->porter1, tmhpsi + max_npw, this->ik, true); } tmhpsi += max_npw * npol; tmpsi_in += max_npw * npol; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index 1134c7777b..e8eddca95f 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -642,7 +642,7 @@ void Stochastic_Iter<T, Device>::cal_storho(const UnitCell& ucell, T* tmpout = stowf.shchi->get_pointer(); for (int ichi = 0; ichi < nchip_ik; ++ichi) { - wfc_basis->recip_to_real(this->ctx, tmpout, porter, ik); + wfc_basis->recip_to_real<T,Device>(tmpout, porter, ik); const auto w1 = static_cast<Real>(this->pkv->wk[ik]); elecstate::elecstate_pw_op<Real, Device>()(this->ctx, current_spin, nrxx, w1, pes->rho, porter); // for (int ir = 0; ir < nrxx; ++ir) diff --git a/source/module_io/get_pchg_pw.h b/source/module_io/get_pchg_pw.h index 58f82f4a9c..7fd1eff9ad 100644 --- a/source/module_io/get_pchg_pw.h +++ b/source/module_io/get_pchg_pw.h @@ -96,7 +96,7 @@ void get_pchg_pw(const std::vector<int>& bands_to_print, << ik % (nks / nspin) + 1 << ", spin " << spin_index + 1 << 
std::endl; psi->fix_k(ik); - pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik); + pw_wfc->recip_to_real<std::complex<double>,Device>(&psi[0](ib, 0), wfcr.data(), ik); // To ensure the normalization of charge density in multi-k calculation (if if_separate_k is true) double wg_sum_k = 0; @@ -139,7 +139,7 @@ void get_pchg_pw(const std::vector<int>& bands_to_print, << ik % (nks / nspin) + 1 << ", spin " << spin_index + 1 << std::endl; psi->fix_k(ik); - pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik); + pw_wfc->recip_to_real<std::complex<double>,Device>( &psi[0](ib, 0), wfcr.data(), ik); double w1 = static_cast<double>(wk[ik] / ucell->omega); From 3301b52570dde9888b2bc119c875c8fbc5da8e49 Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Sun, 23 Mar 2025 20:06:22 +0800 Subject: [PATCH 7/8] moidfy back the func --- .../module_basis/module_pw/test/test-other.cpp | 8 ++++---- source/module_elecstate/elecstate_pw.cpp | 8 ++++---- .../module_elecstate/elecstate_pw_cal_tau.cpp | 4 ++-- .../module_xc/test/xc3_mock.h | 18 ++++++++++++------ .../module_xc/xc_functional_gradcorr.cpp | 3 ++- .../hamilt_ofdft/ml_data_descriptor.cpp | 4 ++-- .../hamilt_pwdft/operator_pw/meta_pw.cpp | 4 ++-- .../hamilt_pwdft/operator_pw/op_exx_pw.cpp | 10 +++++----- .../hamilt_stodft/sto_iter.cpp | 2 +- source/module_io/get_pchg_pw.h | 4 ++-- 10 files changed, 36 insertions(+), 29 deletions(-) diff --git a/source/module_basis/module_pw/test/test-other.cpp b/source/module_basis/module_pw/test/test-other.cpp index c5eb0093b2..308f4c6c68 100644 --- a/source/module_basis/module_pw/test/test-other.cpp +++ b/source/module_basis/module_pw/test/test-other.cpp @@ -76,26 +76,26 @@ TEST_F(PWTEST,test_other) } #endif - pwktest.recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(rhog1, rhor1, ik); + pwktest.recip_to_real(ctx, rhog1, rhor1, ik); pwktest.recip2real(rhog2, rhor2, ik); for(int ir = 0 ; ir < nrxx; ++ir) { 
EXPECT_NEAR(std::abs(rhor1[ir]),std::abs(rhor2[ir]),1e-8); } - pwktest.real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(rhor1, rhog1, ik); + pwktest.real_to_recip(ctx, rhor1, rhog1, ik); pwktest.real2recip(rhor2, rhog2, ik); for(int ig = 0 ; ig < npwk; ++ig) { EXPECT_NEAR(std::abs(rhog1[ig]),std::abs(rhog2[ig]),1e-8); } #ifdef __ENABLE_FLOAT_FFTW - pwktest.recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(rhofg1, rhofr1, ik); + pwktest.recip_to_real(ctx, rhofg1, rhofr1, ik); pwktest.recip2real(rhofg2, rhofr2, ik); for(int ir = 0 ; ir < nrxx; ++ir) { EXPECT_NEAR(std::abs(rhofr1[ir]),std::abs(rhofr2[ir]),1e-6); } - pwktest.real_to_recip<std::complex<float>,base_device::DEVICE_CPU>(rhofr1, rhofg1, ik); + pwktest.real_to_recip(ctx, rhofr1, rhofg1, ik); pwktest.real2recip(rhofr2, rhofg2, ik); for(int ig = 0 ; ig < npwk; ++ig) { diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp index 76a4ce6c42..ef7a168752 100644 --- a/source/module_elecstate/elecstate_pw.cpp +++ b/source/module_elecstate/elecstate_pw.cpp @@ -202,9 +202,9 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi) /// be care of when smearing_sigma is large, wg would less than 0 /// - this->basis->recip_to_real<T,Device>( &psi(ibnd,0), this->wfcr, ik); + this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik); - this->basis->recip_to_real<T,Device>( &psi(ibnd,npwx), this->wfcr_another_spin, ik); + this->basis->recip_to_real(this->ctx, &psi(ibnd,npwx), this->wfcr_another_spin, ik); const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega); @@ -230,7 +230,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi) /// only occupied band should be calculated. 
/// - this->basis->recip_to_real<T,Device>(&psi(ibnd,0), this->wfcr, ik); + this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik); const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega); @@ -258,7 +258,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi) &psi(ibnd, 0), this->wfcr); - this->basis->recip_to_real<T,Device>( this->wfcr, this->wfcr, ik); + this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik); elecstate_pw_op()(this->ctx, current_spin, this->charge->nrxx, w1, this->kin_r, this->wfcr); } diff --git a/source/module_elecstate/elecstate_pw_cal_tau.cpp b/source/module_elecstate/elecstate_pw_cal_tau.cpp index 1cf4f57b70..628dd25aef 100644 --- a/source/module_elecstate/elecstate_pw_cal_tau.cpp +++ b/source/module_elecstate/elecstate_pw_cal_tau.cpp @@ -23,7 +23,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi) int nbands = psi.get_nbands(); for (int ibnd = 0; ibnd < nbands; ibnd++) { - this->basis->recip_to_real<T,Device>(&psi(ibnd,0), this->wfcr, ik); + this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik); const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega); @@ -43,7 +43,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi) &psi(ibnd, 0), this->wfcr); - this->basis->recip_to_real<T,Device>(this->wfcr, this->wfcr, ik); + this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik); elecstate_pw_op()(this->ctx, current_spin, this->charge->nrxx, w1, this->kin_r, this->wfcr); } diff --git a/source/module_hamilt_general/module_xc/test/xc3_mock.h b/source/module_hamilt_general/module_xc/test/xc3_mock.h index ebbc139d75..ffca76c70c 100644 --- a/source/module_hamilt_general/module_xc/test/xc3_mock.h +++ b/source/module_hamilt_general/module_xc/test/xc3_mock.h @@ -78,7 +78,8 @@ namespace ModulePW template <typename FPTYPE, typename Device> - void PW_Basis_K::real_to_recip(const std::complex<FPTYPE>* in, + void 
PW_Basis_K::real_to_recip(const Device* ctx, + const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int ik, const bool add, @@ -90,7 +91,8 @@ namespace ModulePW } } template <typename FPTYPE, typename Device> - void PW_Basis_K::recip_to_real(const std::complex<FPTYPE>* in, + void PW_Basis_K::recip_to_real(const Device* ctx, + const std::complex<FPTYPE>* in, std::complex<FPTYPE>* out, const int ik, const bool add, @@ -102,24 +104,28 @@ namespace ModulePW } } - template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_CPU>(const std::complex<double>* in, + template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_CPU>(const base_device::DEVICE_CPU* ctx, + const std::complex<double>* in, std::complex<double>* out, const int ik, const bool add, const double factor) const; - template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_CPU>(const std::complex<double>* in, + template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_CPU>(const base_device::DEVICE_CPU* ctx, + const std::complex<double>* in, std::complex<double>* out, const int ik, const bool add, const double factor) const; #if __CUDA || __ROCM - template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_GPU>(const std::complex<double>* in, + template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* ctx, + const std::complex<double>* in, std::complex<double>* out, const int ik, const bool add, const double factor) const; - template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_GPU>(const std::complex<double>* in, + template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* ctx, + const std::complex<double>* in, std::complex<double>* out, const int ik, const bool add, diff --git a/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp b/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp index b3cad37c44..4250146510 100644 
--- a/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp +++ b/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp @@ -644,7 +644,8 @@ void XC_Functional::grad_wfc( rhog, porter.data<T>()); // Array of std::complex<double> // bring the gdr from G --> R - wfc_basis->recip_to_real<T,Device>(porter.data<T>(), porter.data<T>(), ik); + Device * ctx = nullptr; + wfc_basis->recip_to_real(ctx, porter.data<T>(), porter.data<T>(), ik); xc_functional_grad_wfc_solver( ipol, wfc_basis->nrxx, // Integers diff --git a/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp b/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp index e9ed4e7268..85135d789f 100644 --- a/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp +++ b/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp @@ -197,7 +197,7 @@ void ML_data::getF_KS1( continue; } - pw_psi->recip_to_real<T,Device>( &psi->operator()(ibnd,0), wfcr, ik); + pw_psi->recip_to_real(ctx, &psi->operator()(ibnd,0), wfcr, ik); const double w1 = pelec->wg(ik, ibnd) / ucell.omega; // output one wf, to check KS equation @@ -308,7 +308,7 @@ void ML_data::getF_KS2( continue; } - pw_psi->recip_to_real<T,Device>( &psi->operator()(ibnd,0), wfcr, ik); + pw_psi->recip_to_real(ctx, &psi->operator()(ibnd,0), wfcr, ik); const double w1 = pelec->wg(ik, ibnd) / ucell.omega; if (pelec->ekb(ik,ibnd) > epsilonM) diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp index 538819f452..817dba61ce 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp @@ -67,13 +67,13 @@ void Meta<OperatorPW<T, Device>>::act( for (int j = 0; j < 3; j++) { meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<Real>(), wfcpw->get_kvec_c_data<Real>(), tmpsi_in, this->porter); - 
wfcpw->recip_to_real<T,Device>(this->porter, this->porter, this->ik); + wfcpw->recip_to_real(this->ctx, this->porter, this->porter, this->ik); if(this->vk_col != 0) { vector_mul_vector_op()(this->vk_col, this->porter, this->porter, this->vk + current_spin * this->vk_col); } - wfcpw->real_to_recip<T,Device>(this->porter, this->porter, this->ik); + wfcpw->real_to_recip(this->ctx, this->porter, this->porter, this->ik); meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<Real>(), wfcpw->get_kvec_c_data<Real>(), this->porter, tmhpsi, true); } // x,y,z directions diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp index 202800026c..4b6d36908a 100644 --- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp +++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp @@ -188,7 +188,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands, { const T *psi_nk = tmpsi_in + n_iband * nbasis; // retrieve \psi_nk in real space - wfcpw->recip_to_real<T,Device>( psi_nk, psi_nk_real, this->ik); + wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, this->ik); // for \psi_nk, get the pw of iq and band m auto q_points = get_q_points(this->ik); @@ -208,7 +208,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands, // if (has_real.find({iq, m_iband}) == has_real.end()) // { const T* psi_mq = get_pw(m_iband, iq); - wfcpw->recip_to_real<T,Device>( psi_mq, psi_mq_real, iq); + wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq); // syncmem_complex_op()(this->ctx, this->ctx, psi_all_real + m_iband * wfcpw->nrxx, psi_mq_real, wfcpw->nrxx); // has_real[{iq, m_iband}] = true; // } @@ -271,7 +271,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands, } // end of iq auto h_psi_nk = tmhpsi + n_iband * nbasis; Real hybrid_alpha = GlobalC::exx_info.info_global.hybrid_alpha; - wfcpw->real_to_recip<T,Device>(h_psi_real, h_psi_nk, this->ik, 
true, hybrid_alpha); + wfcpw->real_to_recip(ctx, h_psi_real, h_psi_nk, this->ik, true, hybrid_alpha); setmem_complex_op()(h_psi_real, 0, rhopw->nrxx); } @@ -810,7 +810,7 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_op(psi::Psi<T, Device> *ppsi_) c psi.fix_kb(ik, n_iband); const T* psi_nk = psi.get_pointer(); // retrieve \psi_nk in real space - wfcpw->recip_to_real<T,Device>( psi_nk, psi_nk_real, ik); + wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, ik); // for \psi_nk, get the pw of iq and band m // q_points is a vector of integers, 0 to nks-1 @@ -839,7 +839,7 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_op(psi::Psi<T, Device> *ppsi_) c psi_.fix_kb(iq, m_iband); const T* psi_mq = psi_.get_pointer(); // const T* psi_mq = get_pw(m_iband, iq); - wfcpw->recip_to_real<T,Device>(psi_mq, psi_mq_real, iq); + wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq); T omega_inv = 1.0 / ucell->omega; diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp index e8eddca95f..1134c7777b 100644 --- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp +++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp @@ -642,7 +642,7 @@ void Stochastic_Iter<T, Device>::cal_storho(const UnitCell& ucell, T* tmpout = stowf.shchi->get_pointer(); for (int ichi = 0; ichi < nchip_ik; ++ichi) { - wfc_basis->recip_to_real<T,Device>(tmpout, porter, ik); + wfc_basis->recip_to_real(this->ctx, tmpout, porter, ik); const auto w1 = static_cast<Real>(this->pkv->wk[ik]); elecstate::elecstate_pw_op<Real, Device>()(this->ctx, current_spin, nrxx, w1, pes->rho, porter); // for (int ir = 0; ir < nrxx; ++ir) diff --git a/source/module_io/get_pchg_pw.h b/source/module_io/get_pchg_pw.h index 7fd1eff9ad..58f82f4a9c 100644 --- a/source/module_io/get_pchg_pw.h +++ b/source/module_io/get_pchg_pw.h @@ -96,7 +96,7 @@ void get_pchg_pw(const std::vector<int>& bands_to_print, << ik % (nks / nspin) + 1 << ", spin " << spin_index + 1 << std::endl; 
psi->fix_k(ik); - pw_wfc->recip_to_real<std::complex<double>,Device>(&psi[0](ib, 0), wfcr.data(), ik); + pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik); // To ensure the normalization of charge density in multi-k calculation (if if_separate_k is true) double wg_sum_k = 0; @@ -139,7 +139,7 @@ void get_pchg_pw(const std::vector<int>& bands_to_print, << ik % (nks / nspin) + 1 << ", spin " << spin_index + 1 << std::endl; psi->fix_k(ik); - pw_wfc->recip_to_real<std::complex<double>,Device>( &psi[0](ib, 0), wfcr.data(), ik); + pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik); double w1 = static_cast<double>(wk[ik] / ucell->omega); From 283c6745c162fc98c1f9af27131cf092350f0b4f Mon Sep 17 00:00:00 2001 From: ubuntu <3158793232@qq.com> Date: Mon, 24 Mar 2025 14:53:22 +0800 Subject: [PATCH 8/8] update func --- source/module_basis/module_pw/pw_basis.h | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h index 52a59d2fb3..8bc45518a2 100644 --- a/source/module_basis/module_pw/pw_basis.h +++ b/source/module_basis/module_pw/pw_basis.h @@ -352,10 +352,7 @@ class PW_Basis void recip_to_real(TK* in, TR* out, const bool add = false, - const typename GetTypeReal<TK>::type factor = 1.0) const - { - this->recip2real_gpu(in, out, add, factor); - }; + const typename GetTypeReal<TK>::type factor = 1.0) const; // template <typename FPTYPE, // typename Device,