From bdbb70b0b9ef15f0a4f38a074c078d0cd2e9c102 Mon Sep 17 00:00:00 2001
From: ubuntu <3158793232@qq.com>
Date: Wed, 19 Mar 2025 09:40:14 +0800
Subject: [PATCH 1/8] update recip_to_real in the rhpw

---
 source/module_basis/module_pw/pw_basis.h      |  36 +-
 .../module_basis/module_pw/pw_transform.cpp   | 361 ++++++++++++++----
 source/module_elecstate/elecstate_pw.cpp      |   3 +-
 3 files changed, 329 insertions(+), 71 deletions(-)

diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h
index 47827000eb..4f1ad9a1ee 100644
--- a/source/module_basis/module_pw/pw_basis.h
+++ b/source/module_basis/module_pw/pw_basis.h
@@ -267,6 +267,35 @@ class PW_Basis
                     const bool add = false,
                     const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
 
+    template <typename FPTYPE, typename Device>
+    void real_to_recip(const std::complex<FPTYPE>* in,
+                       std::complex<FPTYPE>* out,
+                       const Device* ctx=get_default_device_ctx(),
+                       const int ik=0,
+                       const bool add = false,
+                       const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny)  ; out(nz, ns)
+    template <typename FPTYPE, typename Device>
+    void recip_to_real(const std::complex<FPTYPE>* in,
+                       std::complex<FPTYPE>* out,
+                       const Device* ctx=get_default_device_ctx(),
+                       const int ik=0,
+                       const bool add = false,
+                       const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
+
+    template <typename FPTYPE, typename Device = base_device::DEVICE_CPU>
+    void real_to_recip(FPTYPE* in,
+                       std::complex<FPTYPE>* out,
+                       const Device* ctx=get_default_device_ctx(),
+                       const int ik=0,
+                       const bool add = false,
+                       const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny)  ; out(nz, ns)
+    template <typename FPTYPE, typename Device = base_device::DEVICE_CPU>
+    void recip_to_real(const std::complex<FPTYPE>* in,
+                       FPTYPE* out,
+                       const Device* ctx=get_default_device_ctx() ,
+                       const int ik=0,
+                       const bool add = false,
+                       const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
   protected:
     //gather planes and scatter sticks of all processors
     template <typename T>
@@ -282,15 +311,16 @@ class PW_Basis
 
     using resmem_int_op = base_device::memory::resize_memory_op<int, base_device::DEVICE_GPU>;
     using delmem_int_op = base_device::memory::delete_memory_op<int, base_device::DEVICE_GPU>;
-    using syncmem_int_h2d_op
-        = base_device::memory::synchronize_memory_op<int, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
-
+    using syncmem_int_h2d_op = base_device::memory::synchronize_memory_op<int, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
+    // using default_device_cpu = base_device::DEVICE_CPU;
+    
     void set_device(std::string device_);
     void set_precision(std::string precision_);
 
 protected:
     std::string device = "cpu";
     std::string precision = "double";
+    static const base_device::DEVICE_CPU* get_default_device_ctx(); 
 };
 
 }
diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp
index 8e458b2561..38d4bfc49d 100644
--- a/source/module_basis/module_pw/pw_transform.cpp
+++ b/source/module_basis/module_pw/pw_transform.cpp
@@ -1,12 +1,19 @@
-#include "module_fft/fft_bundle.h"
-#include <complex>
-#include "pw_basis.h"
-#include <cassert>
 #include "module_base/global_function.h"
 #include "module_base/timer.h"
+#include "module_basis/module_pw/kernels/pw_op.h"
+#include "module_fft/fft_bundle.h"
+#include "pw_basis.h"
 #include "pw_gatherscatter.h"
 
-namespace ModulePW {
+#include <cassert>
+#include <complex>
+
+namespace ModulePW
+{
+    const base_device::DEVICE_CPU* PW_Basis::get_default_device_ctx() {
+        static const base_device::DEVICE_CPU* default_device_cpu;
+    return default_device_cpu;
+}
 /**
  * @brief transform real space to reciprocal space
  * @details c(g)=\int dr*f(r)*exp(-ig*r)
@@ -24,25 +31,25 @@ void PW_Basis::real2recip(const std::complex<FPTYPE>* in,
 
     assert(this->gamma_only == false);
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-    for(int ir = 0 ; ir < this->nrxx ; ++ir)
+    for (int ir = 0; ir < this->nrxx; ++ir)
     {
         this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = in[ir];
     }
-    this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>());
+    this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());
 
     this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
-    
-    this->fft_bundle.fftzfor(fft_bundle.get_auxg_data<FPTYPE>(),fft_bundle.get_auxg_data<FPTYPE>());
 
-    if(add)
+    this->fft_bundle.fftzfor(fft_bundle.get_auxg_data<FPTYPE>(), fft_bundle.get_auxg_data<FPTYPE>());
+
+    if (add)
     {
         FPTYPE tmpfac = factor / FPTYPE(this->nxyz);
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ig = 0 ; ig < this->npw ; ++ig)
+        for (int ig = 0; ig < this->npw; ++ig)
         {
             out[ig] += tmpfac * this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]];
         }
@@ -51,9 +58,9 @@ void PW_Basis::real2recip(const std::complex<FPTYPE>* in,
     {
         FPTYPE tmpfac = 1.0 / FPTYPE(this->nxyz);
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ig = 0 ; ig < this->npw ; ++ig)
+        for (int ig = 0; ig < this->npw; ++ig)
         {
             out[ig] = tmpfac * this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]];
         }
@@ -72,44 +79,44 @@ template <typename FPTYPE>
 void PW_Basis::real2recip(const FPTYPE* in, std::complex<FPTYPE>* out, const bool add, const FPTYPE factor) const
 {
     ModuleBase::timer::tick(this->classname, "real2recip");
-    if(this->gamma_only)
+    if (this->gamma_only)
     {
         const int npy = this->ny * this->nplane;
 #ifdef _OPENMP
-#pragma omp parallel for collapse(2) schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ix = 0 ; ix < this->nx ; ++ix)
+        for (int ix = 0; ix < this->nx; ++ix)
         {
-            for(int ipy = 0 ; ipy < npy ; ++ipy)
+            for (int ipy = 0; ipy < npy; ++ipy)
             {
-                this->fft_bundle.get_rspace_data<FPTYPE>()[ix*npy + ipy] = in[ix*npy + ipy];
+                this->fft_bundle.get_rspace_data<FPTYPE>()[ix * npy + ipy] = in[ix * npy + ipy];
             }
         }
 
-        this->fft_bundle.fftxyr2c(fft_bundle.get_rspace_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>());
+        this->fft_bundle.fftxyr2c(fft_bundle.get_rspace_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());
     }
     else
     {
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ir = 0 ; ir < this->nrxx ; ++ir)
+        for (int ir = 0; ir < this->nrxx; ++ir)
         {
-            this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = std::complex<FPTYPE>(in[ir],0);
+            this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = std::complex<FPTYPE>(in[ir], 0);
         }
-        this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>());
+        this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());
     }
     this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
-    
-    this->fft_bundle.fftzfor(fft_bundle.get_auxg_data<FPTYPE>(),fft_bundle.get_auxg_data<FPTYPE>());
 
-    if(add)
+    this->fft_bundle.fftzfor(fft_bundle.get_auxg_data<FPTYPE>(), fft_bundle.get_auxg_data<FPTYPE>());
+
+    if (add)
     {
         FPTYPE tmpfac = factor / FPTYPE(this->nxyz);
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ig = 0 ; ig < this->npw ; ++ig)
+        for (int ig = 0; ig < this->npw; ++ig)
         {
             out[ig] += tmpfac * this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]];
         }
@@ -118,9 +125,9 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex<FPTYPE>* out, const boo
     {
         FPTYPE tmpfac = 1.0 / FPTYPE(this->nxyz);
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ig = 0 ; ig < this->npw ; ++ig)
+        for (int ig = 0; ig < this->npw; ++ig)
         {
             out[ig] = tmpfac * this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]];
         }
@@ -144,32 +151,32 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in,
     ModuleBase::timer::tick(this->classname, "recip2real");
     assert(this->gamma_only == false);
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-    for(int i = 0 ; i < this->nst * this->nz ; ++i)
+    for (int i = 0; i < this->nst * this->nz; ++i)
     {
         fft_bundle.get_auxg_data<FPTYPE>()[i] = std::complex<FPTYPE>(0, 0);
     }
 
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-    for(int ig = 0 ; ig < this->npw ; ++ig)
+    for (int ig = 0; ig < this->npw; ++ig)
     {
         this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]] = in[ig];
     }
     this->fft_bundle.fftzbac(fft_bundle.get_auxg_data<FPTYPE>(), fft_bundle.get_auxg_data<FPTYPE>());
 
-    this->gathers_scatterp(this->fft_bundle.get_auxg_data<FPTYPE>(),this->fft_bundle.get_auxr_data<FPTYPE>());
+    this->gathers_scatterp(this->fft_bundle.get_auxg_data<FPTYPE>(), this->fft_bundle.get_auxr_data<FPTYPE>());
+
+    this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());
 
-    this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>());
-    
-    if(add)
+    if (add)
     {
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ir = 0 ; ir < this->nrxx ; ++ir)
+        for (int ir = 0; ir < this->nrxx; ++ir)
         {
             out[ir] += factor * this->fft_bundle.get_auxr_data<FPTYPE>()[ir];
         }
@@ -177,9 +184,9 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in,
     else
     {
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ir = 0 ; ir < this->nrxx ; ++ir)
+        for (int ir = 0; ir < this->nrxx; ++ir)
         {
             out[ir] = this->fft_bundle.get_auxr_data<FPTYPE>()[ir];
         }
@@ -199,17 +206,17 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo
 {
     ModuleBase::timer::tick(this->classname, "recip2real");
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-    for(int i = 0 ; i < this->nst * this->nz ; ++i)
+    for (int i = 0; i < this->nst * this->nz; ++i)
     {
         fft_bundle.get_auxg_data<FPTYPE>()[i] = std::complex<double>(0, 0);
     }
 
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-    for(int ig = 0 ; ig < this->npw ; ++ig)
+    for (int ig = 0; ig < this->npw; ++ig)
     {
         this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]] = in[ig];
     }
@@ -217,49 +224,49 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo
 
     this->gathers_scatterp(this->fft_bundle.get_auxg_data<FPTYPE>(), this->fft_bundle.get_auxr_data<FPTYPE>());
 
-    if(this->gamma_only)
+    if (this->gamma_only)
     {
-        this->fft_bundle.fftxyc2r(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_rspace_data<FPTYPE>());
+        this->fft_bundle.fftxyc2r(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_rspace_data<FPTYPE>());
 
         // r2c in place
         const int npy = this->ny * this->nplane;
 
-        if(add)
+        if (add)
         {
 #ifdef _OPENMP
-#pragma omp parallel for collapse(2) schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-            for(int ix = 0 ; ix < this->nx ; ++ix)
+            for (int ix = 0; ix < this->nx; ++ix)
             {
-                for(int ipy = 0 ; ipy < npy ; ++ipy)
+                for (int ipy = 0; ipy < npy; ++ipy)
                 {
-                    out[ix*npy + ipy] += factor * this->fft_bundle.get_rspace_data<FPTYPE>()[ix*npy + ipy];
+                    out[ix * npy + ipy] += factor * this->fft_bundle.get_rspace_data<FPTYPE>()[ix * npy + ipy];
                 }
             }
         }
         else
         {
 #ifdef _OPENMP
-#pragma omp parallel for collapse(2) schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-            for(int ix = 0 ; ix < this->nx ; ++ix)
+            for (int ix = 0; ix < this->nx; ++ix)
             {
-                for(int ipy = 0 ; ipy < npy ; ++ipy)
+                for (int ipy = 0; ipy < npy; ++ipy)
                 {
-                    out[ix*npy + ipy] = this->fft_bundle.get_rspace_data<FPTYPE>()[ix*npy + ipy];
+                    out[ix * npy + ipy] = this->fft_bundle.get_rspace_data<FPTYPE>()[ix * npy + ipy];
                 }
             }
         }
     }
     else
     {
-        this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>());
-        if(add)
+        this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());
+        if (add)
         {
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-            for(int ir = 0 ; ir < this->nrxx ; ++ir)
+            for (int ir = 0; ir < this->nrxx; ++ir)
             {
                 out[ir] += factor * this->fft_bundle.get_auxr_data<FPTYPE>()[ir].real();
             }
@@ -267,9 +274,9 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo
         else
         {
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-            for(int ir = 0 ; ir < this->nrxx ; ++ir)
+            for (int ir = 0; ir < this->nrxx; ++ir)
             {
                 out[ir] = this->fft_bundle.get_auxr_data<FPTYPE>()[ir].real();
             }
@@ -278,6 +285,228 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo
     ModuleBase::timer::tick(this->classname, "recip2real");
 }
 
+template <>
+void PW_Basis::real_to_recip(const std::complex<float>* in,
+                             std::complex<float>* out,
+                             const base_device::DEVICE_CPU* ctx,
+                             const int ik,
+                             const bool add,
+                             const float factor) const
+{
+    this->real2recip(in, out, add, factor);
+}
+template <>
+void PW_Basis::real_to_recip(const std::complex<double>* in,
+                             std::complex<double>* out,
+                             const base_device::DEVICE_CPU* ctx,
+                             const int ik,
+                             const bool add,
+                             const double factor) const
+{
+    this->real2recip(in, out, add, factor);
+}
+
+template <>
+void PW_Basis::recip_to_real(const std::complex<float>* in,
+                             std::complex<float>* out,
+                             const base_device::DEVICE_CPU* ctx,
+                             const int ik,
+                             const bool add,
+                             const float factor) const
+{
+    this->recip2real(in, out, add, factor);
+}
+template <>
+void PW_Basis::recip_to_real(const std::complex<double>* in,
+                             std::complex<double>* out,
+                             const base_device::DEVICE_CPU* ctx,
+                             const int ik,
+                             const bool add,
+                             const double factor) const
+{
+    this->recip2real(in, out, add, factor);
+}
+
+template <>
+void PW_Basis::real_to_recip(float* in,
+                             std::complex<float>* out,
+                             const base_device::DEVICE_CPU* ctx,
+                             const int ik,
+                             const bool add,
+                             const float factor) const
+{
+    this->real2recip(in, out, add, factor);
+}
+template <>
+void PW_Basis::real_to_recip(double* in,
+                             std::complex<double>* out,
+                             const base_device::DEVICE_CPU* ctx,
+                             const int ik,
+                             const bool add,
+                             const double factor) const
+{
+    this->real2recip(in, out, add, factor);
+}
+
+template <>
+void PW_Basis::recip_to_real(const std::complex<float>* in,
+                             float* out,
+                             const base_device::DEVICE_CPU* ctx,
+                             const int ik,
+                             const bool add,
+                             const float factor) const
+{
+    this->recip2real(in, out, add, factor);
+}
+
+template <>
+void PW_Basis::recip_to_real(const std::complex<double>* in,
+                             double* out,
+                             const base_device::DEVICE_CPU* ctx,
+                             const int ik,
+                             const bool add,
+                             const double factor) const
+{
+    this->recip2real(in, out, add, factor);
+}
+
+#if (defined(__CUDA) || defined(__ROCM))
+template <>
+void PW_Basis::real_to_recip(const std::complex<float>* in,
+                             std::complex<float>* out,
+                             const base_device::DEVICE_GPU* ctx,
+                             const int ik,
+                             const bool add,
+                             const float factor) const
+{
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+
+    base_device::memory::synchronize_memory_op<std::complex<float>, base_device::DEVICE_GPU, base_device::DEVICE_GPU>()(
+        this->fft_bundle.get_auxr_3d_data<float>(),
+        in,
+        this->nrxx);
+
+    this->fft_bundle.fft3D_forward(ctx,
+                                   this->fft_bundle.get_auxr_3d_data<float>(),
+                                   this->fft_bundle.get_auxr_3d_data<float>());
+
+    set_real_to_recip_output_op<float, base_device::DEVICE_GPU>()(ctx,
+                                                                  npw,
+                                                                  this->nxyz,
+                                                                  add,
+                                                                  factor,
+                                                                  this->ig2isz,
+                                                                  this->fft_bundle.get_auxr_3d_data<float>(),
+                                                                  out);
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+}
+template <>
+void PW_Basis::real_to_recip(const std::complex<double>* in,
+                             std::complex<double>* out,
+                             const base_device::DEVICE_GPU* ctx,
+                             const int ik,
+                             const bool add,
+                             const double factor) const
+{
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+
+    base_device::memory::synchronize_memory_op<std::complex<double>,
+                                               base_device::DEVICE_GPU,
+                                               base_device::DEVICE_GPU>()(this->fft_bundle.get_auxr_3d_data<double>(),
+                                                                          in,
+                                                                          this->nrxx);
+
+    this->fft_bundle.fft3D_forward(ctx,
+                                   this->fft_bundle.get_auxr_3d_data<double>(),
+                                   this->fft_bundle.get_auxr_3d_data<double>());
+
+    set_real_to_recip_output_op<double, base_device::DEVICE_GPU>()(ctx,
+                                                                   npw,
+                                                                   this->nxyz,
+                                                                   add,
+                                                                   factor,
+                                                                   this->ig2isz,
+                                                                   this->fft_bundle.get_auxr_3d_data<double>(),
+                                                                   out);
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+}
+
+template <>
+void PW_Basis::recip_to_real(const std::complex<float>* in,
+                             std::complex<float>* out,
+                             const base_device::DEVICE_GPU* ctx,
+                             const int ik,
+                             const bool add,
+                             const float factor) const
+{
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+    // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<float>(), this->nxyz);
+    base_device::memory::set_memory_op<std::complex<float>, base_device::DEVICE_GPU>()(
+        this->fft_bundle.get_auxr_3d_data<float>(),
+        0,
+        this->nxyz);
+
+    set_3d_fft_box_op<float, base_device::DEVICE_GPU>()(ctx,
+                                                        npw,
+                                                        this->ig2isz,
+                                                        in,
+                                                        this->fft_bundle.get_auxr_3d_data<float>());
+    this->fft_bundle.fft3D_backward(ctx,
+                                    this->fft_bundle.get_auxr_3d_data<float>(),
+                                    this->fft_bundle.get_auxr_3d_data<float>());
+
+    set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(ctx,
+                                                                  this->nrxx,
+                                                                  add,
+                                                                  factor,
+                                                                  this->fft_bundle.get_auxr_3d_data<float>(),
+                                                                  out);
+
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+}
+template <>
+void PW_Basis::recip_to_real(const std::complex<double>* in,
+                             std::complex<double>* out,
+                             const base_device::DEVICE_GPU* ctx,
+                             const int ik,
+                             const bool add,
+                             const double factor) const
+{
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+    // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<double>(), this->nxyz);
+    base_device::memory::set_memory_op<std::complex<double>, base_device::DEVICE_GPU>()(
+        this->fft_bundle.get_auxr_3d_data<double>(),
+        0,
+        this->nxyz);
+
+    set_3d_fft_box_op<double, base_device::DEVICE_GPU>()(ctx,
+                                                         npw,
+                                                         this->ig2isz,
+                                                         in,
+                                                         this->fft_bundle.get_auxr_3d_data<double>());
+    this->fft_bundle.fft3D_backward(ctx,
+                                    this->fft_bundle.get_auxr_3d_data<double>(),
+                                    this->fft_bundle.get_auxr_3d_data<double>());
+
+    set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(ctx,
+                                                                   this->nrxx,
+                                                                   add,
+                                                                   factor,
+                                                                   this->fft_bundle.get_auxr_3d_data<double>(),
+                                                                   out);
+
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+}
+#endif
+
 template void PW_Basis::real2recip<float>(const float* in,
                                           std::complex<float>* out,
                                           const bool add,
@@ -311,4 +540,4 @@ template void PW_Basis::recip2real<double>(const std::complex<double>* in,
                                            std::complex<double>* out,
                                            const bool add,
                                            const double factor) const;
-}
\ No newline at end of file
+} // namespace ModulePW
\ No newline at end of file
diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp
index 739e497515..e09a1a3dfa 100644
--- a/source/module_elecstate/elecstate_pw.cpp
+++ b/source/module_elecstate/elecstate_pw.cpp
@@ -443,13 +443,12 @@ void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi)
     {
         this->addusdens_g(becsum, rhog);
     }
-
     // transform back to real space using dense grids
     if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
     {
         for (int is = 0; is < PARAM.inp.nspin; is++)
         {
-            this->charge->rhopw->recip2real(this->rhog[is], this->rho[is]);
+            this->charge->rhopw->recip_to_real(this->rhog[is], this->rho[is]);
         }
     }
 }

From 926f608371e5bb6b9ae93a577e97300440261e15 Mon Sep 17 00:00:00 2001
From: ubuntu <3158793232@qq.com>
Date: Fri, 21 Mar 2025 16:07:36 +0800
Subject: [PATCH 2/8] add the operator in pw_basis

---
 source/module_basis/module_pw/CMakeLists.txt  |   1 +
 .../module_pw/kernels/cuda/pw_op.cu           |  80 ++++++
 .../module_basis/module_pw/kernels/pw_op.cpp  |  23 ++
 source/module_basis/module_pw/kernels/pw_op.h |  33 +++
 .../module_pw/kernels/rocm/pw_op.hip.cu       |  80 ++++++
 source/module_basis/module_pw/pw_basis.h      | 168 ++++++++++---
 .../module_basis/module_pw/pw_transform.cpp   | 231 +-----------------
 .../module_pw/pw_transform_gpu.cpp            | 173 +++++++++++++
 source/module_elecstate/elecstate_pw.cpp      |   2 +-
 9 files changed, 535 insertions(+), 256 deletions(-)
 create mode 100644 source/module_basis/module_pw/pw_transform_gpu.cpp

diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt
index e365e12b5e..50b7ec3430 100644
--- a/source/module_basis/module_pw/CMakeLists.txt
+++ b/source/module_basis/module_pw/CMakeLists.txt
@@ -30,6 +30,7 @@ list(APPEND objects
     pw_distributer.cpp
     pw_init.cpp
     pw_transform.cpp
+    pw_transform_gpu.cpp
     pw_transform_k.cpp
     module_fft/fft_bundle.cpp
     module_fft/fft_cpu.cpp
diff --git a/source/module_basis/module_pw/kernels/cuda/pw_op.cu b/source/module_basis/module_pw/kernels/cuda/pw_op.cu
index a9128db318..d9418ca486 100644
--- a/source/module_basis/module_pw/kernels/cuda/pw_op.cu
+++ b/source/module_basis/module_pw/kernels/cuda/pw_op.cu
@@ -41,6 +41,24 @@ __global__ void set_recip_to_real_output(
     }
 }
 
+template<class FPTYPE>
+__global__ void set_recip_to_real_output(
+    const int nrxx,
+    const bool add,
+    const FPTYPE factor,
+    const thrust::complex<FPTYPE>* in,
+    FPTYPE* out)
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if(idx >= nrxx) {return;}
+    if(add) {
+        out[idx] += factor * in[idx].real();
+    }
+    else {
+        out[idx] = in[idx].real();
+    }
+}
+
 template<class FPTYPE>
 __global__ void set_real_to_recip_output(
     const int npwk,
@@ -61,6 +79,26 @@ __global__ void set_real_to_recip_output(
     }
 }
 
+template<class FPTYPE>
+__global__ void set_real_to_recip_output(
+    const int npwk,
+    const int nxyz,
+    const bool add,
+    const FPTYPE factor,
+    const int* box_index,
+    const thrust::complex<FPTYPE>* in,
+    FPTYPE* out)
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if(idx >= npwk) {return;}
+    if(add) {
+        out[idx] += factor / nxyz * in[box_index[idx]].real();
+    }
+    else {
+        out[idx] = in[box_index[idx]].real() / nxyz;
+    }
+}
+
 template <typename FPTYPE>
 void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
                                                                     const int npwk,
@@ -97,6 +135,25 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
     cudaCheckOnDebug();
 }
 
+template <typename FPTYPE>
+void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
+                                                                              const int nrxx,
+                                                                              const bool add,
+                                                                              const FPTYPE factor,
+                                                                              const std::complex<FPTYPE>* in,
+                                                                              FPTYPE* out)
+{
+    const int block = (nrxx + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+    set_recip_to_real_output<FPTYPE><<<block, THREADS_PER_BLOCK>>>(
+        nrxx,
+        add,
+        factor,
+        reinterpret_cast<const thrust::complex<FPTYPE>*>(in),
+        reinterpret_cast<FPTYPE*>(out));
+
+    cudaCheckOnDebug();
+}
+
 template <typename FPTYPE>
 void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
                                                                               const int npwk,
@@ -120,6 +177,29 @@ void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
     cudaCheckOnDebug();
 }
 
+template <typename FPTYPE>
+void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
+                                                                              const int npwk,
+                                                                              const int nxyz,
+                                                                              const bool add,
+                                                                              const FPTYPE factor,
+                                                                              const int* box_index,
+                                                                              const std::complex<FPTYPE>* in,
+                                                                              FPTYPE* out)
+{
+    const int block = (npwk + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+    set_real_to_recip_output<FPTYPE><<<block, THREADS_PER_BLOCK>>>(
+        npwk,
+        nxyz,
+        add,
+        factor,
+        box_index,
+        reinterpret_cast<const thrust::complex<FPTYPE>*>(in),
+        reinterpret_cast<FPTYPE*>(out));
+
+    cudaCheckOnDebug();
+}
+
 template struct set_3d_fft_box_op<float, base_device::DEVICE_GPU>;
 template struct set_recip_to_real_output_op<float, base_device::DEVICE_GPU>;
 template struct set_real_to_recip_output_op<float, base_device::DEVICE_GPU>;
diff --git a/source/module_basis/module_pw/kernels/pw_op.cpp b/source/module_basis/module_pw/kernels/pw_op.cpp
index b5fb453354..e39e8a6a7c 100644
--- a/source/module_basis/module_pw/kernels/pw_op.cpp
+++ b/source/module_basis/module_pw/kernels/pw_op.cpp
@@ -39,6 +39,29 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU>
             }
         }
     }
+
+    void operator()(const base_device::DEVICE_CPU* /*dev*/,
+                    const int nrxx,
+                    const bool add,
+                    const FPTYPE factor,
+                    const std::complex<FPTYPE>* in,
+                    FPTYPE* out)
+    {
+        if (add)
+        {
+            for (int ir = 0; ir < nrxx; ++ir)
+            {
+                out[ir] += factor * in[ir].real();
+            }
+        }
+        else
+        {
+            for (int ir = 0; ir < nrxx; ++ir)
+            {
+                out[ir] = in[ir].real();
+            }
+        }
+    }
 };
 
 template <typename FPTYPE>
diff --git a/source/module_basis/module_pw/kernels/pw_op.h b/source/module_basis/module_pw/kernels/pw_op.h
index 8415ad9677..b547190606 100644
--- a/source/module_basis/module_pw/kernels/pw_op.h
+++ b/source/module_basis/module_pw/kernels/pw_op.h
@@ -45,6 +45,14 @@ struct set_recip_to_real_output_op {
         const FPTYPE factor,
         const std::complex<FPTYPE>* in,
         std::complex<FPTYPE>* out);
+
+    void operator() (
+        const Device* dev,
+        const int nrxx,
+        const bool add,
+        const FPTYPE factor,
+        const std::complex<FPTYPE>* in,
+        FPTYPE* out);
 };
 
 template <typename FPTYPE, typename Device>
@@ -70,6 +78,16 @@ struct set_real_to_recip_output_op {
         const int* box_index,
         const std::complex<FPTYPE>* in,
         std::complex<FPTYPE>* out);
+
+    void operator() (
+        const Device* dev,
+        const int npw_k,
+        const int nxyz,
+        const bool add,
+        const FPTYPE factor,
+        const int* box_index,
+        const std::complex<FPTYPE>* in,
+        FPTYPE* out);
 };
 
 #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
@@ -93,6 +111,13 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>
                     const FPTYPE factor,
                     const std::complex<FPTYPE>* in,
                     std::complex<FPTYPE>* out);
+
+    void operator()(const base_device::DEVICE_GPU* dev,
+                    const int nrxx,
+                    const bool add,
+                    const FPTYPE factor,
+                    const std::complex<FPTYPE>* in,
+                    FPTYPE* out);
 };
 
 template <typename FPTYPE>
@@ -106,6 +131,14 @@ struct set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>
                     const int* box_index,
                     const std::complex<FPTYPE>* in,
                     std::complex<FPTYPE>* out);
+    void operator()(const base_device::DEVICE_GPU* dev,
+                    const int npw_k,
+                    const int nxyz,
+                    const bool add,
+                    const FPTYPE factor,
+                    const int* box_index,
+                    const std::complex<FPTYPE>* in,
+                    FPTYPE* out);
 };
 
 #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
diff --git a/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu b/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu
index a3f5fe2c2b..44be7d0671 100644
--- a/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu
+++ b/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu
@@ -42,6 +42,24 @@ __global__ void set_recip_to_real_output(
     }
 }
 
+template<class FPTYPE>
+__global__ void set_recip_to_real_output(
+    const int nrxx,
+    const bool add,
+    const FPTYPE factor,
+    const thrust::complex<FPTYPE>* in,
+    FPTYPE* out)
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if(idx >= nrxx) {return;}
+    if(add) {
+        out[idx] += factor * in[idx].real();
+    }
+    else {
+        out[idx] = in[idx].real();
+    }
+}
+
 template<class FPTYPE>
 __global__ void set_real_to_recip_output(
     const int npwk,
@@ -62,6 +80,26 @@ __global__ void set_real_to_recip_output(
     }
 }
 
+template<class FPTYPE>
+__global__ void set_real_to_recip_output(
+    const int npwk,
+    const int nxyz,
+    const bool add,
+    const FPTYPE factor,
+    const int* box_index,
+    const thrust::complex<FPTYPE>* in,
+    FPTYPE* out)
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if(idx >= npwk) {return;}
+    if(add) {
+        out[idx] += factor / nxyz * in[box_index[idx]].real();
+    }
+    else {
+        out[idx] = in[box_index[idx]].real() / nxyz;
+    }
+}
+
 template <typename FPTYPE>
 void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
                                                                     const int npwk,
@@ -98,6 +136,25 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
     hipCheckOnDebug();
 }
 
+template <typename FPTYPE>
+void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
+                                                                              const int nrxx,
+                                                                              const bool add,
+                                                                              const FPTYPE factor,
+                                                                              const std::complex<FPTYPE>* in,
+                                                                              FPTYPE* out)
+{
+    const int block = (nrxx + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+    hipLaunchKernelGGL(HIP_KERNEL_NAME(set_recip_to_real_output<FPTYPE>), dim3(block), dim3(THREADS_PER_BLOCK), 0, 0,
+        nrxx,
+        add,
+        factor,
+        reinterpret_cast<const thrust::complex<FPTYPE>*>(in),
+        reinterpret_cast<FPTYPE*>(out));
+
+    hipCheckOnDebug();
+}
+
 template <typename FPTYPE>
 void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
                                                                               const int npwk,
@@ -121,6 +178,29 @@ void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
     hipCheckOnDebug();
 }
 
+template <typename FPTYPE>
+void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
+                                                                              const int npwk,
+                                                                              const int nxyz,
+                                                                              const bool add,
+                                                                              const FPTYPE factor,
+                                                                              const int* box_index,
+                                                                              const std::complex<FPTYPE>* in,
+                                                                              FPTYPE* out)
+{
+    const int block = (npwk + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
+    hipLaunchKernelGGL(HIP_KERNEL_NAME(set_real_to_recip_output<FPTYPE>), dim3(block), dim3(THREADS_PER_BLOCK), 0, 0,
+        npwk,
+        nxyz,
+        add,
+        factor,
+        box_index,
+        reinterpret_cast<const thrust::complex<FPTYPE>*>(in),
+        reinterpret_cast<FPTYPE*>(out));
+
+    hipCheckOnDebug();
+}
+
 template struct set_3d_fft_box_op<float, base_device::DEVICE_GPU>;
 template struct set_recip_to_real_output_op<float, base_device::DEVICE_GPU>;
 template struct set_real_to_recip_output_op<float, base_device::DEVICE_GPU>;
diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h
index 4f1ad9a1ee..2baf937af8 100644
--- a/source/module_basis/module_pw/pw_basis.h
+++ b/source/module_basis/module_pw/pw_basis.h
@@ -1,6 +1,7 @@
 #ifndef PWBASIS_H
 #define PWBASIS_H
 
+#include "module_base/macros.h"
 #include "module_base/module_device/memory_op.h"
 #include "module_base/matrix.h"
 #include "module_base/matrix3.h"
@@ -245,7 +246,7 @@ class PW_Basis
     // FFT ft;
     FFT_Bundle fft_bundle;
     //The position of pointer in and out can be equal(in-place transform) or different(out-of-place transform).
-
+    
     template <typename FPTYPE>
     void real2recip(const FPTYPE* in,
                     std::complex<FPTYPE>* out,
@@ -266,36 +267,148 @@ class PW_Basis
                     std::complex<FPTYPE>* out,
                     const bool add = false,
                     const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
+    
+    template <typename FPTYPE>
+    void real2recip_gpu(const FPTYPE* in,
+                    std::complex<FPTYPE>* out,
+                    const bool add = false,
+                    const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny)  ; out(nz, ns)
+    template <typename FPTYPE>
+    void real2recip_gpu(const std::complex<FPTYPE>* in,
+                    std::complex<FPTYPE>* out,
+                    const bool add = false,
+                    const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny)  ; out(nz, ns)
+    template <typename FPTYPE>
+    void recip2real_gpu(const std::complex<FPTYPE>* in,
+                    FPTYPE* out,
+                    const bool add = false,
+                    const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
+    template <typename FPTYPE>
+    void recip2real_gpu(const std::complex<FPTYPE>* in,
+                    std::complex<FPTYPE>* out,
+                    const bool add = false,
+                    const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
 
-    template <typename FPTYPE, typename Device>
-    void real_to_recip(const std::complex<FPTYPE>* in,
-                       std::complex<FPTYPE>* out,
-                       const Device* ctx=get_default_device_ctx(),
-                       const int ik=0,
-                       const bool add = false,
-                       const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny)  ; out(nz, ns)
-    template <typename FPTYPE, typename Device>
-    void recip_to_real(const std::complex<FPTYPE>* in,
-                       std::complex<FPTYPE>* out,
-                       const Device* ctx=get_default_device_ctx(),
-                       const int ik=0,
+    /**
+     * @brief Converts data from reciprocal space to real space on Cpu
+     *
+     * This function handles the conversion of data from reciprocal space (Fourier space) to real space.
+     * It supports complex types as input.
+     * The output can be either the same fundamental type or the underlying real type of a complex type.
+     *
+     * @tparam FPTYPE The type of the input data, which can only be a compelx type (e.g., std::complex<float>, std::complex<double>)
+     * @tparam Device The device type, must be base_device::DEVICE_CPU.
+     * @tparam std::enable_if<!std::is_same<FPTYPE, typename GetTypeReal<FPTYPE>::type>::value, int>::type
+     *         SFINAE constraint to ensure that FPTYPE is a complex type.
+     * @tparam std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type
+     *         SFINAE constraint to ensure that Device is base_device::DEVICE_CPU.
+     *
+     * @param in Pointer to the input data array in reciprocal space.
+     * @param out Pointer to the output data array in real space. If FPTYPE is a complex type,
+     *            this should point to an array of the underlying real type.
+     * @param add Boolean flag indicating whether to add the result to the existing values in the output array.
+     * @param factor A scaling factor applied to the output values.
+     */
+    template <typename TK,
+              typename TR,
+              typename Device,
+              typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
+                                          && (std::is_same<TR, typename GetTypeReal<TK>::type>::value
+                                              || std::is_same<TR, TK>::value)
+                                          && std::is_same<Device, base_device::DEVICE_CPU>::value,
+                                      int>::type
+              = 0>
+    void recip_to_real(TK* in, TR* out, const bool add = false, const typename GetTypeReal<TK>::type factor = 1.0) const
+    {
+        this->recip2real(in, out, add, factor);
+    };
+
+    /**
+     * @brief Converts data from reciprocal space (Fourier space) to real space.
+     *
+     * This function handles the conversion of data from reciprocal space (typically after a Fourier transform)
+     * to real space, supporting scenarios where the input is of a complex type and the output is of the underlying
+     * fundamental type.
+     *
+     * @tparam FPTYPE The underlying fundamental type of the input data (e.g., float, double).
+     *         Note that the actual type passed should be std::complex<FPTYPE>.
+     * @tparam Device The device type, which can be any supported device type (e.g., base_device::DEVICE_CPU).
+     *
+     * @param in Pointer to the input data array in reciprocal space, of type std::complex<FPTYPE>*.
+     * @param out Pointer to the output data array in real space, of type FPTYPE*.
+     * @param add Optional boolean flag, default value false, indicating whether to add the result to the existing
+     * values in the output array.
+     * @param factor Optional scaling factor, default value 1.0, applied to the output values.
+     */
+    template <typename TK,
+              typename TR,
+              typename Device,
+              typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
+                                          && (std::is_same<TR, typename GetTypeReal<TK>::type>::value
+                                              || std::is_same<TR, TK>::value)
+                                          && std::is_same<Device, base_device::DEVICE_GPU>::value,
+                                      int>::type
+              = 0>
+    void recip_to_real(TK* in,
+                       TR* out,
                        const bool add = false,
-                       const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
-
-    template <typename FPTYPE, typename Device = base_device::DEVICE_CPU>
-    void real_to_recip(FPTYPE* in,
-                       std::complex<FPTYPE>* out,
-                       const Device* ctx=get_default_device_ctx(),
-                       const int ik=0,
+                       const typename GetTypeReal<TK>::type factor = 1.0) const
+    {
+        this->recip2real_gpu(in, out, add, factor);
+    };
+
+    // template <typename FPTYPE,
+    //         typename Device,
+    //         typename std::enable_if<!std::is_same<FPTYPE, typename GetTypeReal<FPTYPE>::type>::value, int>::type = 0>
+    // void recip_to_real(FPTYPE* in,
+    //                    FPTYPE* out,
+    //                    const bool add = false,
+    //                    const typename GetTypeReal<FPTYPE>::type factor = 1.0) const;
+    /**
+     * @brief Converts data from real space to reciprocal space (Fourier space).
+     *
+     * This function handles the conversion of data from real space to reciprocal space (typically after performing a
+     * Fourier transform), supporting scenarios where the input is of a fundamental type (e.g., float, double) and the
+     * output is of a complex type.
+     *
+     * @tparam FPTYPE The underlying fundamental type of the input data (e.g., float, double).
+     *         SFINAE constraint ensures that FPTYPE is a fundamental type, not a complex type.
+     * @tparam Device The device type, which must be base_device::DEVICE_CPU.
+     * @tparam std::enable_if<std::is_same<FPTYPE, typename GetTypeReal<FPTYPE>::type>::value, int>::type
+     *         SFINAE constraint to ensure that FPTYPE is a fundamental type.
+     * @tparam std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type
+     *         SFINAE constraint to ensure that Device is base_device::DEVICE_CPU.
+     *
+     * @param in Pointer to the input data array in real space, of type FPTYPE*.
+     * @param out Pointer to the output data array in reciprocal space, of type std::complex<FPTYPE>*.
+     * @param ik Optional parameter, default value 0, representing some index or identifier.
+     * @param add Optional boolean flag, default value false, indicating whether to add the result to the existing
+     * values in the output array.
+     * @param factor Optional scaling factor, default value 1.0, applied to the output values.
+     */
+    template <typename TK,
+            typename TR,
+            typename Device,
+            typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
+                    && (std::is_same<TR, typename GetTypeReal<TK>::type>::value || std::is_same<TR, TK>::value)
+                    && std::is_same<Device, base_device::DEVICE_CPU>::value ,int>::type = 0>
+    void real_to_recip(TR* in,
+                       TK* out,
                        const bool add = false,
-                       const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny)  ; out(nz, ns)
-    template <typename FPTYPE, typename Device = base_device::DEVICE_CPU>
-    void recip_to_real(const std::complex<FPTYPE>* in,
-                       FPTYPE* out,
-                       const Device* ctx=get_default_device_ctx() ,
-                       const int ik=0,
+                       const typename GetTypeReal<TK>::type factor = 1.0) const
+    {
+        this->real2recip(in, out, add, factor);
+    }
+
+    template <typename TK,typename TR, typename Device,
+            typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
+                    && (std::is_same<TR, typename GetTypeReal<TK>::type>::value || std::is_same<TR, TK>::value)
+                    && !std::is_same<Device, base_device::DEVICE_CPU>::value ,int>::type = 0>
+    void real_to_recip(TR* in,
+                       TK* out,
                        const bool add = false,
-                       const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
+                       const typename GetTypeReal<TK>::type factor = 1.0) const;
+
   protected:
     //gather planes and scatter sticks of all processors
     template <typename T>
@@ -325,6 +438,5 @@ class PW_Basis
 
 }
 #endif // PWBASIS_H
-
 #include "pw_basis_sup.h"
 #include "pw_basis_big.h" //temporary it will be removed
diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp
index 38d4bfc49d..4f34221775 100644
--- a/source/module_basis/module_pw/pw_transform.cpp
+++ b/source/module_basis/module_pw/pw_transform.cpp
@@ -10,10 +10,10 @@
 
 namespace ModulePW
 {
-    const base_device::DEVICE_CPU* PW_Basis::get_default_device_ctx() {
-        static const base_device::DEVICE_CPU* default_device_cpu;
-    return default_device_cpu;
-}
+//     const base_device::DEVICE_CPU* PW_Basis::get_default_device_ctx() {
+//         static const base_device::DEVICE_CPU* default_device_cpu;
+//     return default_device_cpu;
+// }
 /**
  * @brief transform real space to reciprocal space
  * @details c(g)=\int dr*f(r)*exp(-ig*r)
@@ -284,229 +284,6 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo
     }
     ModuleBase::timer::tick(this->classname, "recip2real");
 }
-
-template <>
-void PW_Basis::real_to_recip(const std::complex<float>* in,
-                             std::complex<float>* out,
-                             const base_device::DEVICE_CPU* ctx,
-                             const int ik,
-                             const bool add,
-                             const float factor) const
-{
-    this->real2recip(in, out, add, factor);
-}
-template <>
-void PW_Basis::real_to_recip(const std::complex<double>* in,
-                             std::complex<double>* out,
-                             const base_device::DEVICE_CPU* ctx,
-                             const int ik,
-                             const bool add,
-                             const double factor) const
-{
-    this->real2recip(in, out, add, factor);
-}
-
-template <>
-void PW_Basis::recip_to_real(const std::complex<float>* in,
-                             std::complex<float>* out,
-                             const base_device::DEVICE_CPU* ctx,
-                             const int ik,
-                             const bool add,
-                             const float factor) const
-{
-    this->recip2real(in, out, add, factor);
-}
-template <>
-void PW_Basis::recip_to_real(const std::complex<double>* in,
-                             std::complex<double>* out,
-                             const base_device::DEVICE_CPU* ctx,
-                             const int ik,
-                             const bool add,
-                             const double factor) const
-{
-    this->recip2real(in, out, add, factor);
-}
-
-template <>
-void PW_Basis::real_to_recip(float* in,
-                             std::complex<float>* out,
-                             const base_device::DEVICE_CPU* ctx,
-                             const int ik,
-                             const bool add,
-                             const float factor) const
-{
-    this->real2recip(in, out, add, factor);
-}
-template <>
-void PW_Basis::real_to_recip(double* in,
-                             std::complex<double>* out,
-                             const base_device::DEVICE_CPU* ctx,
-                             const int ik,
-                             const bool add,
-                             const double factor) const
-{
-    this->real2recip(in, out, add, factor);
-}
-
-template <>
-void PW_Basis::recip_to_real(const std::complex<float>* in,
-                             float* out,
-                             const base_device::DEVICE_CPU* ctx,
-                             const int ik,
-                             const bool add,
-                             const float factor) const
-{
-    this->recip2real(in, out, add, factor);
-}
-
-template <>
-void PW_Basis::recip_to_real(const std::complex<double>* in,
-                             double* out,
-                             const base_device::DEVICE_CPU* ctx,
-                             const int ik,
-                             const bool add,
-                             const double factor) const
-{
-    this->recip2real(in, out, add, factor);
-}
-
-#if (defined(__CUDA) || defined(__ROCM))
-template <>
-void PW_Basis::real_to_recip(const std::complex<float>* in,
-                             std::complex<float>* out,
-                             const base_device::DEVICE_GPU* ctx,
-                             const int ik,
-                             const bool add,
-                             const float factor) const
-{
-    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
-    assert(this->gamma_only == false);
-    assert(this->poolnproc == 1);
-
-    base_device::memory::synchronize_memory_op<std::complex<float>, base_device::DEVICE_GPU, base_device::DEVICE_GPU>()(
-        this->fft_bundle.get_auxr_3d_data<float>(),
-        in,
-        this->nrxx);
-
-    this->fft_bundle.fft3D_forward(ctx,
-                                   this->fft_bundle.get_auxr_3d_data<float>(),
-                                   this->fft_bundle.get_auxr_3d_data<float>());
-
-    set_real_to_recip_output_op<float, base_device::DEVICE_GPU>()(ctx,
-                                                                  npw,
-                                                                  this->nxyz,
-                                                                  add,
-                                                                  factor,
-                                                                  this->ig2isz,
-                                                                  this->fft_bundle.get_auxr_3d_data<float>(),
-                                                                  out);
-    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
-}
-template <>
-void PW_Basis::real_to_recip(const std::complex<double>* in,
-                             std::complex<double>* out,
-                             const base_device::DEVICE_GPU* ctx,
-                             const int ik,
-                             const bool add,
-                             const double factor) const
-{
-    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
-    assert(this->gamma_only == false);
-    assert(this->poolnproc == 1);
-
-    base_device::memory::synchronize_memory_op<std::complex<double>,
-                                               base_device::DEVICE_GPU,
-                                               base_device::DEVICE_GPU>()(this->fft_bundle.get_auxr_3d_data<double>(),
-                                                                          in,
-                                                                          this->nrxx);
-
-    this->fft_bundle.fft3D_forward(ctx,
-                                   this->fft_bundle.get_auxr_3d_data<double>(),
-                                   this->fft_bundle.get_auxr_3d_data<double>());
-
-    set_real_to_recip_output_op<double, base_device::DEVICE_GPU>()(ctx,
-                                                                   npw,
-                                                                   this->nxyz,
-                                                                   add,
-                                                                   factor,
-                                                                   this->ig2isz,
-                                                                   this->fft_bundle.get_auxr_3d_data<double>(),
-                                                                   out);
-    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
-}
-
-template <>
-void PW_Basis::recip_to_real(const std::complex<float>* in,
-                             std::complex<float>* out,
-                             const base_device::DEVICE_GPU* ctx,
-                             const int ik,
-                             const bool add,
-                             const float factor) const
-{
-    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
-    assert(this->gamma_only == false);
-    assert(this->poolnproc == 1);
-    // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<float>(), this->nxyz);
-    base_device::memory::set_memory_op<std::complex<float>, base_device::DEVICE_GPU>()(
-        this->fft_bundle.get_auxr_3d_data<float>(),
-        0,
-        this->nxyz);
-
-    set_3d_fft_box_op<float, base_device::DEVICE_GPU>()(ctx,
-                                                        npw,
-                                                        this->ig2isz,
-                                                        in,
-                                                        this->fft_bundle.get_auxr_3d_data<float>());
-    this->fft_bundle.fft3D_backward(ctx,
-                                    this->fft_bundle.get_auxr_3d_data<float>(),
-                                    this->fft_bundle.get_auxr_3d_data<float>());
-
-    set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(ctx,
-                                                                  this->nrxx,
-                                                                  add,
-                                                                  factor,
-                                                                  this->fft_bundle.get_auxr_3d_data<float>(),
-                                                                  out);
-
-    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
-}
-template <>
-void PW_Basis::recip_to_real(const std::complex<double>* in,
-                             std::complex<double>* out,
-                             const base_device::DEVICE_GPU* ctx,
-                             const int ik,
-                             const bool add,
-                             const double factor) const
-{
-    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
-    assert(this->gamma_only == false);
-    assert(this->poolnproc == 1);
-    // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<double>(), this->nxyz);
-    base_device::memory::set_memory_op<std::complex<double>, base_device::DEVICE_GPU>()(
-        this->fft_bundle.get_auxr_3d_data<double>(),
-        0,
-        this->nxyz);
-
-    set_3d_fft_box_op<double, base_device::DEVICE_GPU>()(ctx,
-                                                         npw,
-                                                         this->ig2isz,
-                                                         in,
-                                                         this->fft_bundle.get_auxr_3d_data<double>());
-    this->fft_bundle.fft3D_backward(ctx,
-                                    this->fft_bundle.get_auxr_3d_data<double>(),
-                                    this->fft_bundle.get_auxr_3d_data<double>());
-
-    set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(ctx,
-                                                                   this->nrxx,
-                                                                   add,
-                                                                   factor,
-                                                                   this->fft_bundle.get_auxr_3d_data<double>(),
-                                                                   out);
-
-    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
-}
-#endif
-
 template void PW_Basis::real2recip<float>(const float* in,
                                           std::complex<float>* out,
                                           const bool add,
diff --git a/source/module_basis/module_pw/pw_transform_gpu.cpp b/source/module_basis/module_pw/pw_transform_gpu.cpp
new file mode 100644
index 0000000000..c17e72ca50
--- /dev/null
+++ b/source/module_basis/module_pw/pw_transform_gpu.cpp
@@ -0,0 +1,173 @@
+#include "pw_basis.h"
+#include "module_base/timer.h"
+#include "module_basis/module_pw/kernels/pw_op.h"
+namespace ModulePW
+{
+#if (defined(__CUDA) || defined(__ROCM))
+template <typename FPTYPE>
+void PW_Basis::real2recip_gpu(const FPTYPE* in,
+                             std::complex<FPTYPE>* out,
+                             const bool add,
+                             const FPTYPE factor) const
+{
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+    base_device::DEVICE_GPU* ctx;
+    // base_device::memory::synchronize_memory_op<std::complex<FPTYPE>,
+    //                                            base_device::DEVICE_GPU,
+    //                                            base_device::DEVICE_GPU>()(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+    //                                                                       in,
+    //                                                                       this->nrxx);
+
+    this->fft_bundle.fft3D_forward(ctx,
+                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+
+    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx,
+                                                                  npw,
+                                                                  this->nxyz,
+                                                                  add,
+                                                                  factor,
+                                                                  this->ig2isz,
+                                                                  this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                  out);
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+}
+template <typename FPTYPE>
+void PW_Basis::real2recip_gpu(const std::complex<FPTYPE>* in,
+                             std::complex<FPTYPE>* out,
+                             const bool add,
+                             const FPTYPE factor) const
+{
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+    base_device::DEVICE_GPU* ctx;
+    base_device::memory::synchronize_memory_op<std::complex<FPTYPE>,
+                                               base_device::DEVICE_GPU,
+                                               base_device::DEVICE_GPU>()(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                          in,
+                                                                          this->nrxx);
+
+    this->fft_bundle.fft3D_forward(ctx,
+                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+
+    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx,
+                                                                   npw,
+                                                                   this->nxyz,
+                                                                   add,
+                                                                   factor,
+                                                                   this->ig2isz,
+                                                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                   out);
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+}
+
+template <typename FPTYPE>
+void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in,
+                             FPTYPE* out,
+                             const bool add,
+                             const FPTYPE factor) const
+{
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+    base_device::DEVICE_GPU* ctx;
+    // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<FPTYPE>(), this->nxyz);
+    base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(
+        this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+        0,
+        this->nxyz);
+
+    set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(ctx,
+                                                        npw,
+                                                        this->ig2isz,
+                                                        in,
+                                                        this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+    this->fft_bundle.fft3D_backward(ctx,
+                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+
+    set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx,
+                                                                  this->nrxx,
+                                                                  add,
+                                                                  factor,
+                                                                  this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                  out);
+
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+}
+template <typename FPTYPE>
+    void PW_Basis::recip2real_gpu(const std::complex<FPTYPE> *in,
+                                 std::complex<FPTYPE> *out,
+                                 const bool add,
+                                 const FPTYPE factor) const
+{
+    base_device::DEVICE_GPU* ctx;
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+    // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<double>(), this->nxyz);
+    base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(
+        this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+        0,
+        this->nxyz);
+
+    set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(ctx,
+                                                         npw,
+                                                         this->ig2isz,
+                                                         in,
+                                                         this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+    this->fft_bundle.fft3D_backward(ctx,
+                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+
+    set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx,
+                                                                   this->nrxx,
+                                                                   add,
+                                                                   factor,
+                                                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                   out);
+
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+}
+template void PW_Basis::real2recip_gpu<double>(const double* in,
+                                              std::complex<double>* out,
+                                              const bool add,
+                                              const double factor) const;
+template void PW_Basis::real2recip_gpu<float>(const float* in,
+                                              std::complex<float>* out,
+                                              const bool add,
+                                              const float factor) const;
+
+template void PW_Basis::real2recip_gpu<double>(const std::complex<double>* in,
+                                              std::complex<double>* out,
+                                              const bool add,
+                                              const double factor) const;
+template void PW_Basis::real2recip_gpu<float>(const std::complex<float>* in,
+                                              std::complex<float>* out,
+                                              const bool add,
+                                              const float factor) const;
+
+template void PW_Basis::recip2real_gpu<double>(const std::complex<double>* in,
+                                              double* out,
+                                              const bool add,
+                                              const double factor) const;
+template void PW_Basis::recip2real_gpu<float>(const std::complex<float>* in,
+                                              float* out,
+                                              const bool add,
+                                              const float factor) const;
+
+template void PW_Basis::recip2real_gpu<double>(const std::complex<double>* in,
+                                              std::complex<double>* out,
+                                              const bool add,
+                                              const double factor) const;
+template void PW_Basis::recip2real_gpu<float>(const std::complex<float>* in,
+                                              std::complex<float>* out,
+                                              const bool add,
+                                              const float factor) const;
+
+#endif
+} // namespace ModulePW
\ No newline at end of file
diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp
index e09a1a3dfa..ef7a168752 100644
--- a/source/module_elecstate/elecstate_pw.cpp
+++ b/source/module_elecstate/elecstate_pw.cpp
@@ -448,7 +448,7 @@ void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi)
     {
         for (int is = 0; is < PARAM.inp.nspin; is++)
         {
-            this->charge->rhopw->recip_to_real(this->rhog[is], this->rho[is]);
+            this->charge->rhopw->recip2real(this->rhog[is], this->rho[is]);
         }
     }
 }

From b96bb777cf6ee8712a73b6126e70809fe9fd05f9 Mon Sep 17 00:00:00 2001
From: ubuntu <3158793232@qq.com>
Date: Fri, 21 Mar 2025 16:22:14 +0800
Subject: [PATCH 3/8] remove ctx in the file

---
 .../module_pw/kernels/cuda/pw_op.cu           | 15 +++++--------
 .../module_basis/module_pw/kernels/pw_op.cpp  | 12 ++++------
 source/module_basis/module_pw/kernels/pw_op.h | 20 +++++------------
 .../module_pw/kernels/rocm/pw_op.hip.cu       | 15 +++++--------
 .../module_pw/pw_transform_gpu.cpp            | 22 +++++++------------
 .../module_basis/module_pw/pw_transform_k.cpp | 18 +++++----------
 .../module_pw/pw_transform_k_dsp.cpp          | 10 ++++-----
 7 files changed, 37 insertions(+), 75 deletions(-)

diff --git a/source/module_basis/module_pw/kernels/cuda/pw_op.cu b/source/module_basis/module_pw/kernels/cuda/pw_op.cu
index d9418ca486..1e0fce284e 100644
--- a/source/module_basis/module_pw/kernels/cuda/pw_op.cu
+++ b/source/module_basis/module_pw/kernels/cuda/pw_op.cu
@@ -100,8 +100,7 @@ __global__ void set_real_to_recip_output(
 }
 
 template <typename FPTYPE>
-void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                    const int npwk,
+void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
                                                                     const int* box_index,
                                                                     const std::complex<FPTYPE>* in,
                                                                     std::complex<FPTYPE>* out)
@@ -117,8 +116,7 @@ void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_d
 }
 
 template <typename FPTYPE>
-void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                              const int nrxx,
+void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx,
                                                                               const bool add,
                                                                               const FPTYPE factor,
                                                                               const std::complex<FPTYPE>* in,
@@ -136,8 +134,7 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
 }
 
 template <typename FPTYPE>
-void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                              const int nrxx,
+void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx,
                                                                               const bool add,
                                                                               const FPTYPE factor,
                                                                               const std::complex<FPTYPE>* in,
@@ -155,8 +152,7 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
 }
 
 template <typename FPTYPE>
-void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                              const int npwk,
+void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
                                                                               const int nxyz,
                                                                               const bool add,
                                                                               const FPTYPE factor,
@@ -178,8 +174,7 @@ void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
 }
 
 template <typename FPTYPE>
-void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                              const int npwk,
+void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
                                                                               const int nxyz,
                                                                               const bool add,
                                                                               const FPTYPE factor,
diff --git a/source/module_basis/module_pw/kernels/pw_op.cpp b/source/module_basis/module_pw/kernels/pw_op.cpp
index e39e8a6a7c..a1bafae69b 100644
--- a/source/module_basis/module_pw/kernels/pw_op.cpp
+++ b/source/module_basis/module_pw/kernels/pw_op.cpp
@@ -5,8 +5,7 @@ namespace ModulePW {
 template <typename FPTYPE>
 struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_CPU>
 {
-    void operator()(const base_device::DEVICE_CPU* /*dev*/,
-                    const int npwk,
+    void operator()(const int npwk,
                     const int* box_index,
                     const std::complex<FPTYPE>* in,
                     std::complex<FPTYPE>* out)
@@ -21,8 +20,7 @@ struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_CPU>
 template <typename FPTYPE>
 struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU>
 {
-    void operator()(const base_device::DEVICE_CPU* /*dev*/,
-                    const int nrxx,
+    void operator()(const int nrxx,
                     const bool add,
                     const FPTYPE factor,
                     const std::complex<FPTYPE>* in,
@@ -40,8 +38,7 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU>
         }
     }
 
-    void operator()(const base_device::DEVICE_CPU* /*dev*/,
-                    const int nrxx,
+    void operator()(const int nrxx,
                     const bool add,
                     const FPTYPE factor,
                     const std::complex<FPTYPE>* in,
@@ -67,8 +64,7 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU>
 template <typename FPTYPE>
 struct set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>
 {
-    void operator()(const base_device::DEVICE_CPU* /*dev*/,
-                    const int npw_k,
+    void operator()(const int npw_k,
                     const int nxyz,
                     const bool add,
                     const FPTYPE factor,
diff --git a/source/module_basis/module_pw/kernels/pw_op.h b/source/module_basis/module_pw/kernels/pw_op.h
index b547190606..6221c8b68c 100644
--- a/source/module_basis/module_pw/kernels/pw_op.h
+++ b/source/module_basis/module_pw/kernels/pw_op.h
@@ -19,7 +19,6 @@ struct set_3d_fft_box_op {
     /// Output Parameters
     /// @param out - output psi within the 3D box(in recip space)
     void operator() (
-        const Device* dev,
         const int npwk,
         const int* box_index,
         const std::complex<FPTYPE>* in,
@@ -39,7 +38,6 @@ struct set_recip_to_real_output_op {
     /// Output Parameters
     /// @param out - output psi within the 3D box(in real space)
     void operator() (
-        const Device* dev,
         const int nrxx,
         const bool add,
         const FPTYPE factor,
@@ -47,7 +45,6 @@ struct set_recip_to_real_output_op {
         std::complex<FPTYPE>* out);
 
     void operator() (
-        const Device* dev,
         const int nrxx,
         const bool add,
         const FPTYPE factor,
@@ -70,7 +67,6 @@ struct set_real_to_recip_output_op {
     /// Output Parameters
     /// @param out - output psi within the 3D box(in recip space)
     void operator() (
-        const Device* dev,
         const int npw_k,
         const int nxyz,
         const bool add,
@@ -80,7 +76,6 @@ struct set_real_to_recip_output_op {
         std::complex<FPTYPE>* out);
 
     void operator() (
-        const Device* dev,
         const int npw_k,
         const int nxyz,
         const bool add,
@@ -95,8 +90,7 @@ struct set_real_to_recip_output_op {
 template <typename FPTYPE>
 struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>
 {
-    void operator()(const base_device::DEVICE_GPU* dev,
-                    const int npwk,
+    void operator()(const int npwk,
                     const int* box_index,
                     const std::complex<FPTYPE>* in,
                     std::complex<FPTYPE>* out);
@@ -105,15 +99,13 @@ struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>
 template <typename FPTYPE>
 struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>
 {
-    void operator()(const base_device::DEVICE_GPU* dev,
-                    const int nrxx,
+    void operator()(const int nrxx,
                     const bool add,
                     const FPTYPE factor,
                     const std::complex<FPTYPE>* in,
                     std::complex<FPTYPE>* out);
 
-    void operator()(const base_device::DEVICE_GPU* dev,
-                    const int nrxx,
+    void operator()(const int nrxx,
                     const bool add,
                     const FPTYPE factor,
                     const std::complex<FPTYPE>* in,
@@ -123,16 +115,14 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>
 template <typename FPTYPE>
 struct set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>
 {
-    void operator()(const base_device::DEVICE_GPU* dev,
-                    const int npw_k,
+    void operator()(const int npw_k,
                     const int nxyz,
                     const bool add,
                     const FPTYPE factor,
                     const int* box_index,
                     const std::complex<FPTYPE>* in,
                     std::complex<FPTYPE>* out);
-    void operator()(const base_device::DEVICE_GPU* dev,
-                    const int npw_k,
+    void operator()(const int npw_k,
                     const int nxyz,
                     const bool add,
                     const FPTYPE factor,
diff --git a/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu b/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu
index 44be7d0671..51b50c6b8b 100644
--- a/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu
+++ b/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu
@@ -101,8 +101,7 @@ __global__ void set_real_to_recip_output(
 }
 
 template <typename FPTYPE>
-void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                    const int npwk,
+void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
                                                                     const int* box_index,
                                                                     const std::complex<FPTYPE>* in,
                                                                     std::complex<FPTYPE>* out)
@@ -118,8 +117,7 @@ void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_d
 }
 
 template <typename FPTYPE>
-void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                              const int nrxx,
+void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx,
                                                                               const bool add,
                                                                               const FPTYPE factor,
                                                                               const std::complex<FPTYPE>* in,
@@ -137,8 +135,7 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
 }
 
 template <typename FPTYPE>
-void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                              const int nrxx,
+void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx,
                                                                               const bool add,
                                                                               const FPTYPE factor,
                                                                               const std::complex<FPTYPE>* in,
@@ -156,8 +153,7 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
 }
 
 template <typename FPTYPE>
-void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                              const int npwk,
+void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
                                                                               const int nxyz,
                                                                               const bool add,
                                                                               const FPTYPE factor,
@@ -179,8 +175,7 @@ void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
 }
 
 template <typename FPTYPE>
-void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                              const int npwk,
+void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
                                                                               const int nxyz,
                                                                               const bool add,
                                                                               const FPTYPE factor,
diff --git a/source/module_basis/module_pw/pw_transform_gpu.cpp b/source/module_basis/module_pw/pw_transform_gpu.cpp
index c17e72ca50..db89d2fbaf 100644
--- a/source/module_basis/module_pw/pw_transform_gpu.cpp
+++ b/source/module_basis/module_pw/pw_transform_gpu.cpp
@@ -3,7 +3,7 @@
 #include "module_basis/module_pw/kernels/pw_op.h"
 namespace ModulePW
 {
-#if (defined(__CUDA) || defined(__ROCM))
+// #if (defined(__CUDA) || defined(__ROCM))
 template <typename FPTYPE>
 void PW_Basis::real2recip_gpu(const FPTYPE* in,
                              std::complex<FPTYPE>* out,
@@ -24,8 +24,7 @@ void PW_Basis::real2recip_gpu(const FPTYPE* in,
                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>());
 
-    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx,
-                                                                  npw,
+    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw,
                                                                   this->nxyz,
                                                                   add,
                                                                   factor,
@@ -54,8 +53,7 @@ void PW_Basis::real2recip_gpu(const std::complex<FPTYPE>* in,
                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>());
 
-    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx,
-                                                                   npw,
+    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw,
                                                                    this->nxyz,
                                                                    add,
                                                                    factor,
@@ -81,8 +79,7 @@ void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in,
         0,
         this->nxyz);
 
-    set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(ctx,
-                                                        npw,
+    set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(npw,
                                                         this->ig2isz,
                                                         in,
                                                         this->fft_bundle.get_auxr_3d_data<FPTYPE>());
@@ -90,8 +87,7 @@ void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in,
                                     this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
                                     this->fft_bundle.get_auxr_3d_data<FPTYPE>());
 
-    set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx,
-                                                                  this->nrxx,
+    set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx,
                                                                   add,
                                                                   factor,
                                                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
@@ -115,8 +111,7 @@ template <typename FPTYPE>
         0,
         this->nxyz);
 
-    set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(ctx,
-                                                         npw,
+    set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(npw,
                                                          this->ig2isz,
                                                          in,
                                                          this->fft_bundle.get_auxr_3d_data<FPTYPE>());
@@ -124,8 +119,7 @@ template <typename FPTYPE>
                                     this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
                                     this->fft_bundle.get_auxr_3d_data<FPTYPE>());
 
-    set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(ctx,
-                                                                   this->nrxx,
+    set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx,
                                                                    add,
                                                                    factor,
                                                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
@@ -169,5 +163,5 @@ template void PW_Basis::recip2real_gpu<float>(const std::complex<float>* in,
                                               const bool add,
                                               const float factor) const;
 
-#endif
+// #endif
 } // namespace ModulePW
\ No newline at end of file
diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp
index 3d75f07f6f..4e3a9ecc9d 100644
--- a/source/module_basis/module_pw/pw_transform_k.cpp
+++ b/source/module_basis/module_pw/pw_transform_k.cpp
@@ -361,8 +361,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx,
 
     const int startig = ik * this->npwk_max;
     const int npw_k = this->npwk[ik];
-    set_real_to_recip_output_op<float, base_device::DEVICE_GPU>()(ctx,
-                                                                  npw_k,
+    set_real_to_recip_output_op<float, base_device::DEVICE_GPU>()(npw_k,
                                                                   this->nxyz,
                                                                   add,
                                                                   factor,
@@ -393,8 +392,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx,
 
     const int startig = ik * this->npwk_max;
     const int npw_k = this->npwk[ik];
-    set_real_to_recip_output_op<double, base_device::DEVICE_GPU>()(ctx,
-                                                                   npw_k,
+    set_real_to_recip_output_op<double, base_device::DEVICE_GPU>()(npw_k,
                                                                    this->nxyz,
                                                                    add,
                                                                    factor,
@@ -424,15 +422,13 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
     const int startig = ik * this->npwk_max;
     const int npw_k = this->npwk[ik];
 
-    set_3d_fft_box_op<float, base_device::DEVICE_GPU>()(ctx,
-                                                        npw_k,
+    set_3d_fft_box_op<float, base_device::DEVICE_GPU>()(npw_k,
                                                         this->ig2ixyz_k + startig,
                                                         in,
                                                         this->fft_bundle.get_auxr_3d_data<float>());
     this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
 
-    set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(ctx,
-                                                                  this->nrxx,
+    set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(this->nrxx,
                                                                   add,
                                                                   factor,
                                                                   this->fft_bundle.get_auxr_3d_data<float>(),
@@ -460,15 +456,13 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
     const int startig = ik * this->npwk_max;
     const int npw_k = this->npwk[ik];
 
-    set_3d_fft_box_op<double, base_device::DEVICE_GPU>()(ctx,
-                                                         npw_k,
+    set_3d_fft_box_op<double, base_device::DEVICE_GPU>()(npw_k,
                                                          this->ig2ixyz_k + startig,
                                                          in,
                                                          this->fft_bundle.get_auxr_3d_data<double>());
     this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
 
-    set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(ctx,
-                                                                   this->nrxx,
+    set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(this->nrxx,
                                                                    add,
                                                                    factor,
                                                                    this->fft_bundle.get_auxr_3d_data<double>(),
diff --git a/source/module_basis/module_pw/pw_transform_k_dsp.cpp b/source/module_basis/module_pw/pw_transform_k_dsp.cpp
index b292e25f0a..1632a20c5e 100644
--- a/source/module_basis/module_pw/pw_transform_k_dsp.cpp
+++ b/source/module_basis/module_pw/pw_transform_k_dsp.cpp
@@ -32,8 +32,7 @@ void PW_Basis_K::real2recip_dsp(const std::complex<FPTYPE>* in,
                                    auxr);
     this->fft_bundle.resource_handler(0);
     // copy the result from the auxr to the out ,while consider the add
-    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>()(ctx,
-                                                                   npw_k,
+    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>()(npw_k,
                                                                    this->nxyz,
                                                                    add,
                                                                    factor,
@@ -58,7 +57,7 @@ void PW_Basis_K::recip2real_dsp(const std::complex<FPTYPE>* in,
     const int startig = ik * this->npwk_max;
     const int npw_k = this->npwk[ik];
     // copy the mapping form the type of stick to the 3dfft
-    set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(ctx, npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr);
+    set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr);
     // use 3d fft backward
     this->fft_bundle.resource_handler(1);
     this->fft_bundle.fft3D_backward(gpux, auxr, auxr);
@@ -107,7 +106,7 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
     const int npw_k = this->npwk[ik];
 
     // copy the mapping form the type of stick to the 3dfft
-    set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(ctx, npw_k, this->ig2ixyz_k_cpu.data() + startig, input, auxr);
+    set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(npw_k, this->ig2ixyz_k_cpu.data() + startig, input, auxr);
 
     // use 3d fft backward
     this->fft_bundle.fft3D_backward(gpux, auxr, auxr);
@@ -120,8 +119,7 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
     // 3d fft
     this->fft_bundle.fft3D_forward(gpux, auxr, auxr);
     // copy the result from the auxr to the out ,while consider the add
-    set_real_to_recip_output_op<double, base_device::DEVICE_CPU>()(ctx,
-                                                                   npw_k,
+    set_real_to_recip_output_op<double, base_device::DEVICE_CPU>()(npw_k,
                                                                    this->nxyz,
                                                                    add,
                                                                    factor,

From a105a3cd92f40778404fc5928daf2121d778625d Mon Sep 17 00:00:00 2001
From: ubuntu <3158793232@qq.com>
Date: Fri, 21 Mar 2025 16:42:04 +0800
Subject: [PATCH 4/8] remove ctx in bundle

---
 .../module_pw/kernels/test/pw_op_test.cpp            | 12 ++++++------
 .../module_basis/module_pw/module_fft/fft_bundle.cpp | 12 ++++--------
 .../module_basis/module_pw/module_fft/fft_bundle.h   |  8 ++++----
 source/module_basis/module_pw/pw_transform_gpu.cpp   | 12 ++++--------
 source/module_basis/module_pw/pw_transform_k.cpp     |  8 ++++----
 source/module_basis/module_pw/pw_transform_k_dsp.cpp |  9 ++++-----
 6 files changed, 26 insertions(+), 35 deletions(-)

diff --git a/source/module_basis/module_pw/kernels/test/pw_op_test.cpp b/source/module_basis/module_pw/kernels/test/pw_op_test.cpp
index 6adac4613f..e9bc07bd9e 100644
--- a/source/module_basis/module_pw/kernels/test/pw_op_test.cpp
+++ b/source/module_basis/module_pw/kernels/test/pw_op_test.cpp
@@ -72,7 +72,7 @@ class TestModulePWPWMultiDevice : public ::testing::Test
 TEST_F(TestModulePWPWMultiDevice, set_3d_fft_box_op_cpu)
 {
     std::vector<std::complex<double>> res(out_1.size(), std::complex<double>{0, 0});
-    set_3d_fft_box_cpu_op()(cpu_ctx, this->npwk, box_index.data(), in_1.data(), res.data());
+    set_3d_fft_box_cpu_op()(this->npwk, box_index.data(), in_1.data(), res.data());
     for (int ii = 0; ii < this->nxyz; ii++) {
         EXPECT_LT(std::abs(res[ii] - out_1[ii]), 1e-12);
     }
@@ -81,7 +81,7 @@ TEST_F(TestModulePWPWMultiDevice, set_3d_fft_box_op_cpu)
 TEST_F(TestModulePWPWMultiDevice, set_recip_to_real_output_op_cpu)
 {
     std::vector<std::complex<double>> res(out_2.size(), std::complex<double>{0, 0});
-    set_recip_to_real_output_cpu_op()(cpu_ctx, this->nxyz, this->add, this->factor, in_2.data(), res.data());
+    set_recip_to_real_output_cpu_op()(this->nxyz, this->add, this->factor, in_2.data(), res.data());
     for (int ii = 0; ii < this->nxyz; ii++) {
         EXPECT_LT(std::abs(res[ii] - out_2[ii]), 1e-12);
     }
@@ -90,7 +90,7 @@ TEST_F(TestModulePWPWMultiDevice, set_recip_to_real_output_op_cpu)
 TEST_F(TestModulePWPWMultiDevice, set_real_to_recip_output_op_cpu)
 {
     std::vector<std::complex<double>> res = out_3_init;
-    set_real_to_recip_output_cpu_op()(cpu_ctx, this->npwk, this->nxyz, true, this->factor, box_index.data(), in_3.data(), res.data());
+    set_real_to_recip_output_cpu_op()(this->npwk, this->nxyz, true, this->factor, box_index.data(), in_3.data(), res.data());
     for (int ii = 0; ii < out_3.size(); ii++) {
         EXPECT_LT(std::abs(res[ii] - out_3[ii]), 5e-6);
     }
@@ -109,7 +109,7 @@ TEST_F(TestModulePWPWMultiDevice, set_3d_fft_box_op_gpu)
     synchronize_memory_complex_h2d_op()(d_res, res.data(), res.size());
     synchronize_memory_complex_h2d_op()(d_in_1, in_1.data(), in_1.size());
 
-    set_3d_fft_box_gpu_op()(gpu_ctx, this->npwk, d_box_index, d_in_1, d_res);
+    set_3d_fft_box_gpu_op()(this->npwk, d_box_index, d_in_1, d_res);
 
     synchronize_memory_complex_d2h_op()(res.data(), d_res, res.size());
 
@@ -130,7 +130,7 @@ TEST_F(TestModulePWPWMultiDevice, set_recip_to_real_output_op_gpu)
     synchronize_memory_complex_h2d_op()(d_res, res.data(), res.size());
     synchronize_memory_complex_h2d_op()(d_in_2, in_2.data(), in_2.size());
 
-    set_recip_to_real_output_gpu_op()(gpu_ctx, this->nxyz, this->add, this->factor, d_in_2, d_res);
+    set_recip_to_real_output_gpu_op()(this->nxyz, this->add, this->factor, d_in_2, d_res);
 
     synchronize_memory_complex_d2h_op()(res.data(), d_res, res.size());
 
@@ -153,7 +153,7 @@ TEST_F(TestModulePWPWMultiDevice, set_real_to_recip_output_op_gpu)
     synchronize_memory_complex_h2d_op()(d_res, res.data(), res.size());
     synchronize_memory_complex_h2d_op()(d_in_3, in_3.data(), in_3.size());
 
-    set_real_to_recip_output_gpu_op()(gpu_ctx, this->npwk, this->nxyz, true, this->factor, d_box_index, d_in_3, d_res);
+    set_real_to_recip_output_gpu_op()(this->npwk, this->nxyz, true, this->factor, d_box_index, d_in_3, d_res);
 
     synchronize_memory_complex_d2h_op()(res.data(), d_res, res.size());
 
diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp
index 7289e8ab02..270f985498 100644
--- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp
+++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp
@@ -227,30 +227,26 @@ void FFT_Bundle::fftxyc2r(std::complex<double>* in, double* out) const
 }
 
 template <>
-void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx,
-                               std::complex<float>* in,
+void FFT_Bundle::fft3D_forward(std::complex<float>* in,
                                std::complex<float>* out) const
 {
     fft_float->fft3D_forward(in, out);
 }
 template <>
-void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx,
-                               std::complex<double>* in,
+void FFT_Bundle::fft3D_forward(std::complex<double>* in,
                                std::complex<double>* out) const
 {
     fft_double->fft3D_forward(in, out);
 }
 
 template <>
-void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx,
-                                std::complex<float>* in,
+void FFT_Bundle::fft3D_backward(std::complex<float>* in,
                                 std::complex<float>* out) const
 {
     fft_float->fft3D_backward(in, out);
 }
 template <>
-void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx,
-                                std::complex<double>* in,
+void FFT_Bundle::fft3D_backward(std::complex<double>* in,
                                 std::complex<double>* out) const
 {
     fft_double->fft3D_backward(in, out);
diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h
index 1982a79a0c..5b8b527cc4 100644
--- a/source/module_basis/module_pw/module_fft/fft_bundle.h
+++ b/source/module_basis/module_pw/module_fft/fft_bundle.h
@@ -188,10 +188,10 @@ class FFT_Bundle
     template <typename FPTYPE>
     void fftxyc2r(std::complex<FPTYPE>* in, FPTYPE* out) const;
 
-    template <typename FPTYPE, typename Device>
-    void fft3D_forward(const Device* ctx, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
-    template <typename FPTYPE, typename Device>
-    void fft3D_backward(const Device* ctx, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
+    template <typename FPTYPE>
+    void fft3D_forward(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
+    template <typename FPTYPE>
+    void fft3D_backward(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
 
   private:
     int fft_mode = 0;
diff --git a/source/module_basis/module_pw/pw_transform_gpu.cpp b/source/module_basis/module_pw/pw_transform_gpu.cpp
index db89d2fbaf..b5970a48a6 100644
--- a/source/module_basis/module_pw/pw_transform_gpu.cpp
+++ b/source/module_basis/module_pw/pw_transform_gpu.cpp
@@ -20,8 +20,7 @@ void PW_Basis::real2recip_gpu(const FPTYPE* in,
     //                                                                       in,
     //                                                                       this->nrxx);
 
-    this->fft_bundle.fft3D_forward(ctx,
-                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+    this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>());
 
     set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw,
@@ -49,8 +48,7 @@ void PW_Basis::real2recip_gpu(const std::complex<FPTYPE>* in,
                                                                           in,
                                                                           this->nrxx);
 
-    this->fft_bundle.fft3D_forward(ctx,
-                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+    this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>());
 
     set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw,
@@ -83,8 +81,7 @@ void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in,
                                                         this->ig2isz,
                                                         in,
                                                         this->fft_bundle.get_auxr_3d_data<FPTYPE>());
-    this->fft_bundle.fft3D_backward(ctx,
-                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+    this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
                                     this->fft_bundle.get_auxr_3d_data<FPTYPE>());
 
     set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx,
@@ -115,8 +112,7 @@ template <typename FPTYPE>
                                                          this->ig2isz,
                                                          in,
                                                          this->fft_bundle.get_auxr_3d_data<FPTYPE>());
-    this->fft_bundle.fft3D_backward(ctx,
-                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+    this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
                                     this->fft_bundle.get_auxr_3d_data<FPTYPE>());
 
     set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx,
diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp
index 4e3a9ecc9d..776b493d35 100644
--- a/source/module_basis/module_pw/pw_transform_k.cpp
+++ b/source/module_basis/module_pw/pw_transform_k.cpp
@@ -357,7 +357,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx,
         in,
         this->nrxx);
 
-    this->fft_bundle.fft3D_forward(ctx, this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
+    this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
 
     const int startig = ik * this->npwk_max;
     const int npw_k = this->npwk[ik];
@@ -388,7 +388,7 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx,
                                                                           in,
                                                                           this->nrxx);
 
-    this->fft_bundle.fft3D_forward(ctx, this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
+    this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
 
     const int startig = ik * this->npwk_max;
     const int npw_k = this->npwk[ik];
@@ -426,7 +426,7 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
                                                         this->ig2ixyz_k + startig,
                                                         in,
                                                         this->fft_bundle.get_auxr_3d_data<float>());
-    this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
+    this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
 
     set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(this->nrxx,
                                                                   add,
@@ -460,7 +460,7 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
                                                          this->ig2ixyz_k + startig,
                                                          in,
                                                          this->fft_bundle.get_auxr_3d_data<double>());
-    this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
+    this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
 
     set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(this->nrxx,
                                                                    add,
diff --git a/source/module_basis/module_pw/pw_transform_k_dsp.cpp b/source/module_basis/module_pw/pw_transform_k_dsp.cpp
index 1632a20c5e..56ba26eb2a 100644
--- a/source/module_basis/module_pw/pw_transform_k_dsp.cpp
+++ b/source/module_basis/module_pw/pw_transform_k_dsp.cpp
@@ -27,8 +27,7 @@ void PW_Basis_K::real2recip_dsp(const std::complex<FPTYPE>* in,
 
     // 3d fft
     this->fft_bundle.resource_handler(1);
-    this->fft_bundle.fft3D_forward(gpux, 
-                                   auxr, 
+    this->fft_bundle.fft3D_forward(auxr, 
                                    auxr);
     this->fft_bundle.resource_handler(0);
     // copy the result from the auxr to the out ,while consider the add
@@ -60,7 +59,7 @@ void PW_Basis_K::recip2real_dsp(const std::complex<FPTYPE>* in,
     set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr);
     // use 3d fft backward
     this->fft_bundle.resource_handler(1);
-    this->fft_bundle.fft3D_backward(gpux, auxr, auxr);
+    this->fft_bundle.fft3D_backward(auxr, auxr);
     this->fft_bundle.resource_handler(0);
     if (add)
     {
@@ -109,7 +108,7 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
     set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(npw_k, this->ig2ixyz_k_cpu.data() + startig, input, auxr);
 
     // use 3d fft backward
-    this->fft_bundle.fft3D_backward(gpux, auxr, auxr);
+    this->fft_bundle.fft3D_backward(auxr, auxr);
 
     for (int ir = 0; ir < size; ir++)
     {
@@ -117,7 +116,7 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
     }
 
     // 3d fft
-    this->fft_bundle.fft3D_forward(gpux, auxr, auxr);
+    this->fft_bundle.fft3D_forward(auxr, auxr);
     // copy the result from the auxr to the out ,while consider the add
     set_real_to_recip_output_op<double, base_device::DEVICE_CPU>()(npw_k,
                                                                    this->nxyz,

From 977b713f5fb103056e152d7995120d3882051ed1 Mon Sep 17 00:00:00 2001
From: ubuntu <3158793232@qq.com>
Date: Fri, 21 Mar 2025 16:49:35 +0800
Subject: [PATCH 5/8] update compile bug

---
 source/module_basis/module_pw/pw_transform_gpu.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/source/module_basis/module_pw/pw_transform_gpu.cpp b/source/module_basis/module_pw/pw_transform_gpu.cpp
index b5970a48a6..1898ff23eb 100644
--- a/source/module_basis/module_pw/pw_transform_gpu.cpp
+++ b/source/module_basis/module_pw/pw_transform_gpu.cpp
@@ -3,7 +3,7 @@
 #include "module_basis/module_pw/kernels/pw_op.h"
 namespace ModulePW
 {
-// #if (defined(__CUDA) || defined(__ROCM))
+#if (defined(__CUDA) || defined(__ROCM))
 template <typename FPTYPE>
 void PW_Basis::real2recip_gpu(const FPTYPE* in,
                              std::complex<FPTYPE>* out,
@@ -159,5 +159,5 @@ template void PW_Basis::recip2real_gpu<float>(const std::complex<float>* in,
                                               const bool add,
                                               const float factor) const;
 
-// #endif
+#endif
 } // namespace ModulePW
\ No newline at end of file

From ee48127a32eaae8c139373f69ca8582c88c6e0d5 Mon Sep 17 00:00:00 2001
From: ubuntu <3158793232@qq.com>
Date: Sun, 23 Mar 2025 19:53:19 +0800
Subject: [PATCH 6/8] add T,device

---
 source/module_basis/module_pw/pw_basis_k.h    | 72 ++++++++++++++-
 .../module_basis/module_pw/pw_transform_k.cpp | 89 +++++++++++++++++++
 .../module_pw/test/test-other.cpp             |  8 +-
 source/module_elecstate/elecstate_pw.cpp      |  8 +-
 .../module_elecstate/elecstate_pw_cal_tau.cpp |  4 +-
 .../module_xc/test/xc3_mock.h                 | 18 ++--
 .../module_xc/xc_functional_gradcorr.cpp      |  3 +-
 .../hamilt_ofdft/ml_data_descriptor.cpp       |  4 +-
 .../hamilt_pwdft/operator_pw/meta_pw.cpp      |  4 +-
 .../hamilt_pwdft/operator_pw/op_exx_pw.cpp    | 10 +--
 .../hamilt_pwdft/operator_pw/veff_pw.cpp      | 12 +--
 .../hamilt_stodft/sto_iter.cpp                |  2 +-
 source/module_io/get_pchg_pw.h                |  4 +-
 13 files changed, 195 insertions(+), 43 deletions(-)

diff --git a/source/module_basis/module_pw/pw_basis_k.h b/source/module_basis/module_pw/pw_basis_k.h
index ae5076bba9..ef578852ae 100644
--- a/source/module_basis/module_pw/pw_basis_k.h
+++ b/source/module_basis/module_pw/pw_basis_k.h
@@ -161,7 +161,7 @@ class PW_Basis_K : public PW_Basis
     
     #endif
 
-    template <typename FPTYPE, typename Device>
+     template <typename FPTYPE, typename Device>
     void real_to_recip(const Device* ctx,
                        const std::complex<FPTYPE>* in,
                        std::complex<FPTYPE>* out,
@@ -176,6 +176,76 @@ class PW_Basis_K : public PW_Basis
                        const bool add = false,
                        const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
 
+
+    template <typename TK,
+              typename Device,
+              typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
+    void real_to_recip(const TK* in,
+                       TK* out,
+                       const int ik,
+                       const bool add = false,
+                       const typename GetTypeReal<TK>::type factor = 1.0) const
+    {
+      #if defined(__DSP)
+        this->recip2real_dsp(in, out, ik, add, factor);
+      #else
+        this->real2recip(in,out,ik,add,factor);
+      #endif
+    }
+    template <typename TK,
+              typename Device,
+              typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
+    void recip_to_real(const TK* in,
+                       TK* out,
+                       const int ik,
+                       const bool add = false,
+                       const typename GetTypeReal<TK>::type factor = 1.0) const
+    {
+      
+      #if defined(__DSP)
+        this->recip2real_dsp(in,out,ik,add,factor);
+      #else
+        this->recip2real(in,out,ik,add,factor);
+      #endif
+    }
+    template <typename FPTYPE>
+    void real2recip_gpu(const std::complex<FPTYPE>* in,
+                    std::complex<FPTYPE>* out,
+                    const int ik,
+                    const bool add = false,
+                    const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny)  ; out(nz, ns)
+                    
+    template <typename FPTYPE>
+    void recip2real_gpu(const std::complex<FPTYPE>* in,
+                    std::complex<FPTYPE>* out,
+                    const int ik,
+                    const bool add = false,
+                    const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
+
+    template <typename FPTYPE,
+              typename Device,
+              typename std::enable_if<!std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
+    void real_to_recip(const FPTYPE* in,
+                       FPTYPE* out,
+                       const int ik,
+                       const bool add = false,
+                       const typename GetTypeReal<FPTYPE>::type factor = 1.0) const
+    {
+        this->real2recip_gpu(in, out, ik, add, factor);
+    }
+
+    template <typename TK,
+              typename Device,
+              typename std::enable_if<std::is_same<Device, base_device::DEVICE_GPU>::value, int>::type = 0>
+    void recip_to_real(const TK* in,
+                       TK* out,
+                       const int ik,
+                       const bool add = false,
+                       const typename GetTypeReal<TK>::type factor = 1.0) const
+    {
+        this->recip2real_gpu(in, out, ik, add, factor);
+    }
+
   public:
     //operator:
     //get (G+K)^2:
diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp
index 776b493d35..a709b60429 100644
--- a/source/module_basis/module_pw/pw_transform_k.cpp
+++ b/source/module_basis/module_pw/pw_transform_k.cpp
@@ -470,6 +470,95 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
 
     ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
 }
+
+template <typename FPTYPE>
+void PW_Basis_K::real2recip_gpu(const std::complex<FPTYPE>* in,
+                               std::complex<FPTYPE>* out,
+                               const int ik,
+                               const bool add,
+                               const FPTYPE factor) const
+{
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+
+    base_device::memory::synchronize_memory_op<std::complex<FPTYPE>,
+                                               base_device::DEVICE_GPU,
+                                               base_device::DEVICE_GPU>()(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                          in,
+                                                                          this->nrxx);
+
+    this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+
+    const int startig = ik * this->npwk_max;
+    const int npw_k = this->npwk[ik];
+    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k,
+                                                                   this->nxyz,
+                                                                   add,
+                                                                   factor,
+                                                                   this->ig2ixyz_k + startig,
+                                                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                   out);
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+}
+template <typename FPTYPE>
+void PW_Basis_K::recip2real_gpu(const std::complex<FPTYPE>* in,
+                               std::complex<FPTYPE>* out,
+                               const int ik,
+                               const bool add,
+                               const FPTYPE factor) const
+{
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+    // ModuleBase::GlobalFunc::ZEROS(fft_bundle.get_auxr_3d_data<FPTYPE>(), this->nxyz);
+    base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(
+        this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+        0,
+        this->nxyz);
+
+    const int startig = ik * this->npwk_max;
+    const int npw_k = this->npwk[ik];
+
+    set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(npw_k,
+                                                         this->ig2ixyz_k + startig,
+                                                         in,
+                                                         this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+    this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(), this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+
+    set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx,
+                                                                   add,
+                                                                   factor,
+                                                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                   out);
+
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+}
+
+template void PW_Basis_K::real2recip_gpu<float>(const std::complex<float>*,
+                                                std::complex<float>*,
+                                                const int,
+                                                const bool,
+                                                const float) const;
+
+template void PW_Basis_K::real2recip_gpu<double>(const std::complex<double>*,
+                                                 std::complex<double>*,
+                                                 const int,
+                                                 const bool,
+                                                 const double) const;
+
+template void PW_Basis_K::recip2real_gpu<float>(const std::complex<float>*,
+                                                std::complex<float>*,
+                                                const int,
+                                                const bool,
+                                                const float) const;
+
+template void PW_Basis_K::recip2real_gpu<double>(const std::complex<double>*,
+                                                 std::complex<double>*,
+                                                 const int,
+                                                 const bool,
+                                                 const double) const;
+
 #endif
 
 template void PW_Basis_K::real2recip<float>(const float* in,
diff --git a/source/module_basis/module_pw/test/test-other.cpp b/source/module_basis/module_pw/test/test-other.cpp
index 308f4c6c68..c5eb0093b2 100644
--- a/source/module_basis/module_pw/test/test-other.cpp
+++ b/source/module_basis/module_pw/test/test-other.cpp
@@ -76,26 +76,26 @@ TEST_F(PWTEST,test_other)
         }  
 #endif
 
-        pwktest.recip_to_real(ctx, rhog1, rhor1, ik);
+        pwktest.recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(rhog1, rhor1, ik);
         pwktest.recip2real(rhog2, rhor2, ik);
         for(int ir = 0 ; ir < nrxx; ++ir)
         {
             EXPECT_NEAR(std::abs(rhor1[ir]),std::abs(rhor2[ir]),1e-8);
         }
-        pwktest.real_to_recip(ctx, rhor1, rhog1, ik);
+        pwktest.real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(rhor1, rhog1, ik);
         pwktest.real2recip(rhor2, rhog2, ik);
         for(int ig = 0 ; ig < npwk; ++ig)
         {
             EXPECT_NEAR(std::abs(rhog1[ig]),std::abs(rhog2[ig]),1e-8);
         }
 #ifdef __ENABLE_FLOAT_FFTW
-        pwktest.recip_to_real(ctx, rhofg1, rhofr1, ik);
+        pwktest.recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(rhofg1, rhofr1, ik);
         pwktest.recip2real(rhofg2, rhofr2, ik);
         for(int ir = 0 ; ir < nrxx; ++ir)
         {
             EXPECT_NEAR(std::abs(rhofr1[ir]),std::abs(rhofr2[ir]),1e-6);
         }
-        pwktest.real_to_recip(ctx, rhofr1, rhofg1, ik);
+        pwktest.real_to_recip<std::complex<float>,base_device::DEVICE_CPU>(rhofr1, rhofg1, ik);
         pwktest.real2recip(rhofr2, rhofg2, ik);
         for(int ig = 0 ; ig < npwk; ++ig)
         {
diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp
index ef7a168752..76a4ce6c42 100644
--- a/source/module_elecstate/elecstate_pw.cpp
+++ b/source/module_elecstate/elecstate_pw.cpp
@@ -202,9 +202,9 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
             /// be care of when smearing_sigma is large, wg would less than 0
             ///
 
-            this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik);
+            this->basis->recip_to_real<T,Device>( &psi(ibnd,0), this->wfcr, ik);
 
-            this->basis->recip_to_real(this->ctx, &psi(ibnd,npwx), this->wfcr_another_spin, ik);
+            this->basis->recip_to_real<T,Device>( &psi(ibnd,npwx), this->wfcr_another_spin, ik);
 
             const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega);
 
@@ -230,7 +230,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
             /// only occupied band should be calculated.
             ///
 
-            this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik);
+            this->basis->recip_to_real<T,Device>(&psi(ibnd,0), this->wfcr, ik);
 
             const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega);
 
@@ -258,7 +258,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
                               &psi(ibnd, 0),
                               this->wfcr);
 
-                    this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik);
+                    this->basis->recip_to_real<T,Device>( this->wfcr, this->wfcr, ik);
 
                     elecstate_pw_op()(this->ctx, current_spin, this->charge->nrxx, w1, this->kin_r, this->wfcr);
                 }
diff --git a/source/module_elecstate/elecstate_pw_cal_tau.cpp b/source/module_elecstate/elecstate_pw_cal_tau.cpp
index 628dd25aef..1cf4f57b70 100644
--- a/source/module_elecstate/elecstate_pw_cal_tau.cpp
+++ b/source/module_elecstate/elecstate_pw_cal_tau.cpp
@@ -23,7 +23,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
         int nbands = psi.get_nbands();
         for (int ibnd = 0; ibnd < nbands; ibnd++)
         {
-            this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik);
+            this->basis->recip_to_real<T,Device>(&psi(ibnd,0), this->wfcr, ik);
 
             const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega);
 
@@ -43,7 +43,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
                             &psi(ibnd, 0),
                             this->wfcr);
 
-                this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik);
+                this->basis->recip_to_real<T,Device>(this->wfcr, this->wfcr, ik);
 
                 elecstate_pw_op()(this->ctx, current_spin, this->charge->nrxx, w1, this->kin_r, this->wfcr);
             }
diff --git a/source/module_hamilt_general/module_xc/test/xc3_mock.h b/source/module_hamilt_general/module_xc/test/xc3_mock.h
index ffca76c70c..ebbc139d75 100644
--- a/source/module_hamilt_general/module_xc/test/xc3_mock.h
+++ b/source/module_hamilt_general/module_xc/test/xc3_mock.h
@@ -78,8 +78,7 @@ namespace ModulePW
 
 
     template <typename FPTYPE, typename Device>
-    void PW_Basis_K::real_to_recip(const Device* ctx,
-                       const std::complex<FPTYPE>* in,
+    void PW_Basis_K::real_to_recip(const std::complex<FPTYPE>* in,
                        std::complex<FPTYPE>* out,
                        const int ik,
                        const bool add,
@@ -91,8 +90,7 @@ namespace ModulePW
         }
     }
     template <typename FPTYPE, typename Device>
-    void PW_Basis_K::recip_to_real(const Device* ctx,
-                       const std::complex<FPTYPE>* in,
+    void PW_Basis_K::recip_to_real(const std::complex<FPTYPE>* in,
                        std::complex<FPTYPE>* out,
                        const int ik,
                        const bool add,
@@ -104,28 +102,24 @@ namespace ModulePW
         }
     }
 
-    template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_CPU>(const base_device::DEVICE_CPU* ctx,
-                                                                             const std::complex<double>* in,
+    template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_CPU>(const std::complex<double>* in,
                                                                              std::complex<double>* out,
                                                                              const int ik,
                                                                              const bool add,
                                                                              const double factor) const;
-    template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_CPU>(const base_device::DEVICE_CPU* ctx,
-                                                                             const std::complex<double>* in,
+    template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_CPU>(const std::complex<double>* in,
                                                                              std::complex<double>* out,
                                                                              const int ik,
                                                                              const bool add,
                                                                              const double factor) const;
 #if __CUDA || __ROCM
-    template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* ctx,
-                                                                             const std::complex<double>* in,
+    template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_GPU>(const std::complex<double>* in,
                                                                              std::complex<double>* out,
                                                                              const int ik,
                                                                              const bool add,
                                                                              const double factor) const;
 
-    template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* ctx,
-                                                                             const std::complex<double>* in,
+    template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_GPU>(const std::complex<double>* in,
                                                                              std::complex<double>* out,
                                                                              const int ik,
                                                                              const bool add,
diff --git a/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp b/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp
index 4250146510..b3cad37c44 100644
--- a/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp
+++ b/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp
@@ -644,8 +644,7 @@ void XC_Functional::grad_wfc(
 			rhog, porter.data<T>());    // Array of std::complex<double>
 
 		// bring the gdr from G --> R
-		Device * ctx = nullptr;
-		wfc_basis->recip_to_real(ctx, porter.data<T>(), porter.data<T>(), ik);
+		wfc_basis->recip_to_real<T,Device>(porter.data<T>(), porter.data<T>(), ik);
 
 		xc_functional_grad_wfc_solver(
             ipol, wfc_basis->nrxx,	// Integers
diff --git a/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp b/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp
index 85135d789f..e9ed4e7268 100644
--- a/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp
+++ b/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp
@@ -197,7 +197,7 @@ void ML_data::getF_KS1(
                 continue;
             }
 
-            pw_psi->recip_to_real(ctx, &psi->operator()(ibnd,0), wfcr, ik);
+            pw_psi->recip_to_real<T,Device>( &psi->operator()(ibnd,0), wfcr, ik);
             const double w1 = pelec->wg(ik, ibnd) / ucell.omega;
             
             // output one wf, to check KS equation
@@ -308,7 +308,7 @@ void ML_data::getF_KS2(
                 continue;
             }
 
-            pw_psi->recip_to_real(ctx, &psi->operator()(ibnd,0), wfcr, ik);
+            pw_psi->recip_to_real<T,Device>( &psi->operator()(ibnd,0), wfcr, ik);
             const double w1 = pelec->wg(ik, ibnd) / ucell.omega;
 
             if (pelec->ekb(ik,ibnd) > epsilonM)
diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp
index 817dba61ce..538819f452 100644
--- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp
+++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp
@@ -67,13 +67,13 @@ void Meta<OperatorPW<T, Device>>::act(
         for (int j = 0; j < 3; j++)
         {
             meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<Real>(), wfcpw->get_kvec_c_data<Real>(), tmpsi_in, this->porter);
-            wfcpw->recip_to_real(this->ctx, this->porter, this->porter, this->ik);
+            wfcpw->recip_to_real<T,Device>(this->porter, this->porter, this->ik);
 
             if(this->vk_col != 0) {
                 vector_mul_vector_op()(this->vk_col, this->porter, this->porter, this->vk + current_spin * this->vk_col);
             }
 
-            wfcpw->real_to_recip(this->ctx, this->porter, this->porter, this->ik);
+            wfcpw->real_to_recip<T,Device>(this->porter, this->porter, this->ik);
             meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<Real>(), wfcpw->get_kvec_c_data<Real>(), this->porter, tmhpsi, true);
 
         } // x,y,z directions
diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp
index 4b6d36908a..202800026c 100644
--- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp
+++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp
@@ -188,7 +188,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
     {
         const T *psi_nk = tmpsi_in + n_iband * nbasis;
         // retrieve \psi_nk in real space
-        wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, this->ik);
+        wfcpw->recip_to_real<T,Device>( psi_nk, psi_nk_real, this->ik);
 
         // for \psi_nk, get the pw of iq and band m
         auto q_points = get_q_points(this->ik);
@@ -208,7 +208,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
                 // if (has_real.find({iq, m_iband}) == has_real.end())
                 // {
                     const T* psi_mq = get_pw(m_iband, iq);
-                    wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq);
+                    wfcpw->recip_to_real<T,Device>( psi_mq, psi_mq_real, iq);
                 //     syncmem_complex_op()(this->ctx, this->ctx, psi_all_real + m_iband * wfcpw->nrxx, psi_mq_real, wfcpw->nrxx);
                 //     has_real[{iq, m_iband}] = true;
                 // }
@@ -271,7 +271,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
         } // end of iq
         auto h_psi_nk = tmhpsi + n_iband * nbasis;
         Real hybrid_alpha = GlobalC::exx_info.info_global.hybrid_alpha;
-        wfcpw->real_to_recip(ctx, h_psi_real, h_psi_nk, this->ik, true, hybrid_alpha);
+        wfcpw->real_to_recip<T,Device>(h_psi_real, h_psi_nk, this->ik, true, hybrid_alpha);
         setmem_complex_op()(h_psi_real, 0, rhopw->nrxx);
         
     }
@@ -810,7 +810,7 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_op(psi::Psi<T, Device> *ppsi_) c
             psi.fix_kb(ik, n_iband);
             const T* psi_nk = psi.get_pointer();
             // retrieve \psi_nk in real space
-            wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, ik);
+            wfcpw->recip_to_real<T,Device>( psi_nk, psi_nk_real, ik);
 
             // for \psi_nk, get the pw of iq and band m
             // q_points is a vector of integers, 0 to nks-1
@@ -839,7 +839,7 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_op(psi::Psi<T, Device> *ppsi_) c
                     psi_.fix_kb(iq, m_iband);
                     const T* psi_mq = psi_.get_pointer();
                     // const T* psi_mq = get_pw(m_iband, iq);
-                    wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq);
+                    wfcpw->recip_to_real<T,Device>(psi_mq, psi_mq_real, iq);
 
                     T omega_inv = 1.0 / ucell->omega;
 
diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp
index 54e1a052be..c931d61f2f 100644
--- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp
+++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp
@@ -62,7 +62,7 @@ void Veff<OperatorPW<T, Device>>::act(
         if (npol == 1)
         {
             // wfcpw->recip2real(tmpsi_in, porter, this->ik);
-            wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik);
+            wfcpw->recip_to_real<T,Device>(tmpsi_in, this->porter, this->ik);
             // NOTICE: when MPI threads are larger than number of Z grids
             // veff would contain nothing, and nothing should be done in real space
             // but the 3DFFT can not be skipped, it will cause hanging
@@ -76,7 +76,7 @@ void Veff<OperatorPW<T, Device>>::act(
                 // }
             }
             // wfcpw->real2recip(porter, tmhpsi, this->ik, true);
-            wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true);
+            wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true);
             // wfcpw->convolution(this->ctx,
             // this->ik,
             // this->veff_col,
@@ -89,8 +89,8 @@ void Veff<OperatorPW<T, Device>>::act(
         {
             // T *porter1 = new T[wfcpw->nmaxgr];
             // fft to real space and doing things.
-            wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik);
-            wfcpw->recip_to_real(this->ctx, tmpsi_in + max_npw, this->porter1, this->ik);
+            wfcpw->recip_to_real<T,Device>(tmpsi_in, this->porter, this->ik);
+            wfcpw->recip_to_real<T,Device>(tmpsi_in + max_npw, this->porter1, this->ik);
             if(this->veff_col != 0)
             {
                 /// denghui added at 20221109
@@ -114,8 +114,8 @@ void Veff<OperatorPW<T, Device>>::act(
                 // }
             }
             // (3) fft back to G space.
-            wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true);
-            wfcpw->real_to_recip(this->ctx, this->porter1, tmhpsi + max_npw, this->ik, true);
+            wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true);
+            wfcpw->real_to_recip<T,Device>(this->porter1, tmhpsi + max_npw, this->ik, true);
         }
         tmhpsi += max_npw * npol;
         tmpsi_in += max_npw * npol;
diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp
index 1134c7777b..e8eddca95f 100644
--- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp
+++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp
@@ -642,7 +642,7 @@ void Stochastic_Iter<T, Device>::cal_storho(const UnitCell& ucell,
         T* tmpout = stowf.shchi->get_pointer();
         for (int ichi = 0; ichi < nchip_ik; ++ichi)
         {
-            wfc_basis->recip_to_real(this->ctx, tmpout, porter, ik);
+            wfc_basis->recip_to_real<T,Device>(tmpout, porter, ik);
             const auto w1 = static_cast<Real>(this->pkv->wk[ik]);
             elecstate::elecstate_pw_op<Real, Device>()(this->ctx, current_spin, nrxx, w1, pes->rho, porter);
             // for (int ir = 0; ir < nrxx; ++ir)
diff --git a/source/module_io/get_pchg_pw.h b/source/module_io/get_pchg_pw.h
index 58f82f4a9c..7fd1eff9ad 100644
--- a/source/module_io/get_pchg_pw.h
+++ b/source/module_io/get_pchg_pw.h
@@ -96,7 +96,7 @@ void get_pchg_pw(const std::vector<int>& bands_to_print,
                           << ik % (nks / nspin) + 1 << ", spin " << spin_index + 1 << std::endl;
 
                 psi->fix_k(ik);
-                pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik);
+                pw_wfc->recip_to_real<std::complex<double>,Device>(&psi[0](ib, 0), wfcr.data(), ik);
 
                 // To ensure the normalization of charge density in multi-k calculation (if if_separate_k is true)
                 double wg_sum_k = 0;
@@ -139,7 +139,7 @@ void get_pchg_pw(const std::vector<int>& bands_to_print,
                           << ik % (nks / nspin) + 1 << ", spin " << spin_index + 1 << std::endl;
 
                 psi->fix_k(ik);
-                pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik);
+                pw_wfc->recip_to_real<std::complex<double>,Device>( &psi[0](ib, 0), wfcr.data(), ik);
 
                 double w1 = static_cast<double>(wk[ik] / ucell->omega);
 

From 3301b52570dde9888b2bc119c875c8fbc5da8e49 Mon Sep 17 00:00:00 2001
From: ubuntu <3158793232@qq.com>
Date: Sun, 23 Mar 2025 20:06:22 +0800
Subject: [PATCH 7/8] moidfy back the func

---
 .../module_basis/module_pw/test/test-other.cpp |  8 ++++----
 source/module_elecstate/elecstate_pw.cpp       |  8 ++++----
 .../module_elecstate/elecstate_pw_cal_tau.cpp  |  4 ++--
 .../module_xc/test/xc3_mock.h                  | 18 ++++++++++++------
 .../module_xc/xc_functional_gradcorr.cpp       |  3 ++-
 .../hamilt_ofdft/ml_data_descriptor.cpp        |  4 ++--
 .../hamilt_pwdft/operator_pw/meta_pw.cpp       |  4 ++--
 .../hamilt_pwdft/operator_pw/op_exx_pw.cpp     | 10 +++++-----
 .../hamilt_stodft/sto_iter.cpp                 |  2 +-
 source/module_io/get_pchg_pw.h                 |  4 ++--
 10 files changed, 36 insertions(+), 29 deletions(-)

diff --git a/source/module_basis/module_pw/test/test-other.cpp b/source/module_basis/module_pw/test/test-other.cpp
index c5eb0093b2..308f4c6c68 100644
--- a/source/module_basis/module_pw/test/test-other.cpp
+++ b/source/module_basis/module_pw/test/test-other.cpp
@@ -76,26 +76,26 @@ TEST_F(PWTEST,test_other)
         }  
 #endif
 
-        pwktest.recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(rhog1, rhor1, ik);
+        pwktest.recip_to_real(ctx, rhog1, rhor1, ik);
         pwktest.recip2real(rhog2, rhor2, ik);
         for(int ir = 0 ; ir < nrxx; ++ir)
         {
             EXPECT_NEAR(std::abs(rhor1[ir]),std::abs(rhor2[ir]),1e-8);
         }
-        pwktest.real_to_recip<std::complex<double>,base_device::DEVICE_CPU>(rhor1, rhog1, ik);
+        pwktest.real_to_recip(ctx, rhor1, rhog1, ik);
         pwktest.real2recip(rhor2, rhog2, ik);
         for(int ig = 0 ; ig < npwk; ++ig)
         {
             EXPECT_NEAR(std::abs(rhog1[ig]),std::abs(rhog2[ig]),1e-8);
         }
 #ifdef __ENABLE_FLOAT_FFTW
-        pwktest.recip_to_real<std::complex<double>,base_device::DEVICE_CPU>(rhofg1, rhofr1, ik);
+        pwktest.recip_to_real(ctx, rhofg1, rhofr1, ik);
         pwktest.recip2real(rhofg2, rhofr2, ik);
         for(int ir = 0 ; ir < nrxx; ++ir)
         {
             EXPECT_NEAR(std::abs(rhofr1[ir]),std::abs(rhofr2[ir]),1e-6);
         }
-        pwktest.real_to_recip<std::complex<float>,base_device::DEVICE_CPU>(rhofr1, rhofg1, ik);
+        pwktest.real_to_recip(ctx, rhofr1, rhofg1, ik);
         pwktest.real2recip(rhofr2, rhofg2, ik);
         for(int ig = 0 ; ig < npwk; ++ig)
         {
diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp
index 76a4ce6c42..ef7a168752 100644
--- a/source/module_elecstate/elecstate_pw.cpp
+++ b/source/module_elecstate/elecstate_pw.cpp
@@ -202,9 +202,9 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
             /// be care of when smearing_sigma is large, wg would less than 0
             ///
 
-            this->basis->recip_to_real<T,Device>( &psi(ibnd,0), this->wfcr, ik);
+            this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik);
 
-            this->basis->recip_to_real<T,Device>( &psi(ibnd,npwx), this->wfcr_another_spin, ik);
+            this->basis->recip_to_real(this->ctx, &psi(ibnd,npwx), this->wfcr_another_spin, ik);
 
             const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega);
 
@@ -230,7 +230,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
             /// only occupied band should be calculated.
             ///
 
-            this->basis->recip_to_real<T,Device>(&psi(ibnd,0), this->wfcr, ik);
+            this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik);
 
             const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega);
 
@@ -258,7 +258,7 @@ void ElecStatePW<T, Device>::rhoBandK(const psi::Psi<T, Device>& psi)
                               &psi(ibnd, 0),
                               this->wfcr);
 
-                    this->basis->recip_to_real<T,Device>( this->wfcr, this->wfcr, ik);
+                    this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik);
 
                     elecstate_pw_op()(this->ctx, current_spin, this->charge->nrxx, w1, this->kin_r, this->wfcr);
                 }
diff --git a/source/module_elecstate/elecstate_pw_cal_tau.cpp b/source/module_elecstate/elecstate_pw_cal_tau.cpp
index 1cf4f57b70..628dd25aef 100644
--- a/source/module_elecstate/elecstate_pw_cal_tau.cpp
+++ b/source/module_elecstate/elecstate_pw_cal_tau.cpp
@@ -23,7 +23,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
         int nbands = psi.get_nbands();
         for (int ibnd = 0; ibnd < nbands; ibnd++)
         {
-            this->basis->recip_to_real<T,Device>(&psi(ibnd,0), this->wfcr, ik);
+            this->basis->recip_to_real(this->ctx, &psi(ibnd,0), this->wfcr, ik);
 
             const auto w1 = static_cast<Real>(this->wg(ik, ibnd) / ucell->omega);
 
@@ -43,7 +43,7 @@ void ElecStatePW<T, Device>::cal_tau(const psi::Psi<T, Device>& psi)
                             &psi(ibnd, 0),
                             this->wfcr);
 
-                this->basis->recip_to_real<T,Device>(this->wfcr, this->wfcr, ik);
+                this->basis->recip_to_real(this->ctx, this->wfcr, this->wfcr, ik);
 
                 elecstate_pw_op()(this->ctx, current_spin, this->charge->nrxx, w1, this->kin_r, this->wfcr);
             }
diff --git a/source/module_hamilt_general/module_xc/test/xc3_mock.h b/source/module_hamilt_general/module_xc/test/xc3_mock.h
index ebbc139d75..ffca76c70c 100644
--- a/source/module_hamilt_general/module_xc/test/xc3_mock.h
+++ b/source/module_hamilt_general/module_xc/test/xc3_mock.h
@@ -78,7 +78,8 @@ namespace ModulePW
 
 
     template <typename FPTYPE, typename Device>
-    void PW_Basis_K::real_to_recip(const std::complex<FPTYPE>* in,
+    void PW_Basis_K::real_to_recip(const Device* ctx,
+                       const std::complex<FPTYPE>* in,
                        std::complex<FPTYPE>* out,
                        const int ik,
                        const bool add,
@@ -90,7 +91,8 @@ namespace ModulePW
         }
     }
     template <typename FPTYPE, typename Device>
-    void PW_Basis_K::recip_to_real(const std::complex<FPTYPE>* in,
+    void PW_Basis_K::recip_to_real(const Device* ctx,
+                       const std::complex<FPTYPE>* in,
                        std::complex<FPTYPE>* out,
                        const int ik,
                        const bool add,
@@ -102,24 +104,28 @@ namespace ModulePW
         }
     }
 
-    template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_CPU>(const std::complex<double>* in,
+    template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_CPU>(const base_device::DEVICE_CPU* ctx,
+                                                                             const std::complex<double>* in,
                                                                              std::complex<double>* out,
                                                                              const int ik,
                                                                              const bool add,
                                                                              const double factor) const;
-    template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_CPU>(const std::complex<double>* in,
+    template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_CPU>(const base_device::DEVICE_CPU* ctx,
+                                                                             const std::complex<double>* in,
                                                                              std::complex<double>* out,
                                                                              const int ik,
                                                                              const bool add,
                                                                              const double factor) const;
 #if __CUDA || __ROCM
-    template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_GPU>(const std::complex<double>* in,
+    template void PW_Basis_K::real_to_recip<double, base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* ctx,
+                                                                             const std::complex<double>* in,
                                                                              std::complex<double>* out,
                                                                              const int ik,
                                                                              const bool add,
                                                                              const double factor) const;
 
-    template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_GPU>(const std::complex<double>* in,
+    template void PW_Basis_K::recip_to_real<double, base_device::DEVICE_GPU>(const base_device::DEVICE_GPU* ctx,
+                                                                             const std::complex<double>* in,
                                                                              std::complex<double>* out,
                                                                              const int ik,
                                                                              const bool add,
diff --git a/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp b/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp
index b3cad37c44..4250146510 100644
--- a/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp
+++ b/source/module_hamilt_general/module_xc/xc_functional_gradcorr.cpp
@@ -644,7 +644,8 @@ void XC_Functional::grad_wfc(
 			rhog, porter.data<T>());    // Array of std::complex<double>
 
 		// bring the gdr from G --> R
-		wfc_basis->recip_to_real<T,Device>(porter.data<T>(), porter.data<T>(), ik);
+		Device * ctx = nullptr;
+		wfc_basis->recip_to_real(ctx, porter.data<T>(), porter.data<T>(), ik);
 
 		xc_functional_grad_wfc_solver(
             ipol, wfc_basis->nrxx,	// Integers
diff --git a/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp b/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp
index e9ed4e7268..85135d789f 100644
--- a/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp
+++ b/source/module_hamilt_pw/hamilt_ofdft/ml_data_descriptor.cpp
@@ -197,7 +197,7 @@ void ML_data::getF_KS1(
                 continue;
             }
 
-            pw_psi->recip_to_real<T,Device>( &psi->operator()(ibnd,0), wfcr, ik);
+            pw_psi->recip_to_real(ctx, &psi->operator()(ibnd,0), wfcr, ik);
             const double w1 = pelec->wg(ik, ibnd) / ucell.omega;
             
             // output one wf, to check KS equation
@@ -308,7 +308,7 @@ void ML_data::getF_KS2(
                 continue;
             }
 
-            pw_psi->recip_to_real<T,Device>( &psi->operator()(ibnd,0), wfcr, ik);
+            pw_psi->recip_to_real(ctx, &psi->operator()(ibnd,0), wfcr, ik);
             const double w1 = pelec->wg(ik, ibnd) / ucell.omega;
 
             if (pelec->ekb(ik,ibnd) > epsilonM)
diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp
index 538819f452..817dba61ce 100644
--- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp
+++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/meta_pw.cpp
@@ -67,13 +67,13 @@ void Meta<OperatorPW<T, Device>>::act(
         for (int j = 0; j < 3; j++)
         {
             meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<Real>(), wfcpw->get_kvec_c_data<Real>(), tmpsi_in, this->porter);
-            wfcpw->recip_to_real<T,Device>(this->porter, this->porter, this->ik);
+            wfcpw->recip_to_real(this->ctx, this->porter, this->porter, this->ik);
 
             if(this->vk_col != 0) {
                 vector_mul_vector_op()(this->vk_col, this->porter, this->porter, this->vk + current_spin * this->vk_col);
             }
 
-            wfcpw->real_to_recip<T,Device>(this->porter, this->porter, this->ik);
+            wfcpw->real_to_recip(this->ctx, this->porter, this->porter, this->ik);
             meta_op()(this->ctx, this->ik, j, ngk_ik, this->wfcpw->npwk_max, this->tpiba, wfcpw->get_gcar_data<Real>(), wfcpw->get_kvec_c_data<Real>(), this->porter, tmhpsi, true);
 
         } // x,y,z directions
diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp
index 202800026c..4b6d36908a 100644
--- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp
+++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/op_exx_pw.cpp
@@ -188,7 +188,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
     {
         const T *psi_nk = tmpsi_in + n_iband * nbasis;
         // retrieve \psi_nk in real space
-        wfcpw->recip_to_real<T,Device>( psi_nk, psi_nk_real, this->ik);
+        wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, this->ik);
 
         // for \psi_nk, get the pw of iq and band m
         auto q_points = get_q_points(this->ik);
@@ -208,7 +208,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
                 // if (has_real.find({iq, m_iband}) == has_real.end())
                 // {
                     const T* psi_mq = get_pw(m_iband, iq);
-                    wfcpw->recip_to_real<T,Device>( psi_mq, psi_mq_real, iq);
+                    wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq);
                 //     syncmem_complex_op()(this->ctx, this->ctx, psi_all_real + m_iband * wfcpw->nrxx, psi_mq_real, wfcpw->nrxx);
                 //     has_real[{iq, m_iband}] = true;
                 // }
@@ -271,7 +271,7 @@ void OperatorEXXPW<T, Device>::act_op(const int nbands,
         } // end of iq
         auto h_psi_nk = tmhpsi + n_iband * nbasis;
         Real hybrid_alpha = GlobalC::exx_info.info_global.hybrid_alpha;
-        wfcpw->real_to_recip<T,Device>(h_psi_real, h_psi_nk, this->ik, true, hybrid_alpha);
+        wfcpw->real_to_recip(ctx, h_psi_real, h_psi_nk, this->ik, true, hybrid_alpha);
         setmem_complex_op()(h_psi_real, 0, rhopw->nrxx);
         
     }
@@ -810,7 +810,7 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_op(psi::Psi<T, Device> *ppsi_) c
             psi.fix_kb(ik, n_iband);
             const T* psi_nk = psi.get_pointer();
             // retrieve \psi_nk in real space
-            wfcpw->recip_to_real<T,Device>( psi_nk, psi_nk_real, ik);
+            wfcpw->recip_to_real(ctx, psi_nk, psi_nk_real, ik);
 
             // for \psi_nk, get the pw of iq and band m
             // q_points is a vector of integers, 0 to nks-1
@@ -839,7 +839,7 @@ double OperatorEXXPW<T, Device>::cal_exx_energy_op(psi::Psi<T, Device> *ppsi_) c
                     psi_.fix_kb(iq, m_iband);
                     const T* psi_mq = psi_.get_pointer();
                     // const T* psi_mq = get_pw(m_iband, iq);
-                    wfcpw->recip_to_real<T,Device>(psi_mq, psi_mq_real, iq);
+                    wfcpw->recip_to_real(ctx, psi_mq, psi_mq_real, iq);
 
                     T omega_inv = 1.0 / ucell->omega;
 
diff --git a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp
index e8eddca95f..1134c7777b 100644
--- a/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp
+++ b/source/module_hamilt_pw/hamilt_stodft/sto_iter.cpp
@@ -642,7 +642,7 @@ void Stochastic_Iter<T, Device>::cal_storho(const UnitCell& ucell,
         T* tmpout = stowf.shchi->get_pointer();
         for (int ichi = 0; ichi < nchip_ik; ++ichi)
         {
-            wfc_basis->recip_to_real<T,Device>(tmpout, porter, ik);
+            wfc_basis->recip_to_real(this->ctx, tmpout, porter, ik);
             const auto w1 = static_cast<Real>(this->pkv->wk[ik]);
             elecstate::elecstate_pw_op<Real, Device>()(this->ctx, current_spin, nrxx, w1, pes->rho, porter);
             // for (int ir = 0; ir < nrxx; ++ir)
diff --git a/source/module_io/get_pchg_pw.h b/source/module_io/get_pchg_pw.h
index 7fd1eff9ad..58f82f4a9c 100644
--- a/source/module_io/get_pchg_pw.h
+++ b/source/module_io/get_pchg_pw.h
@@ -96,7 +96,7 @@ void get_pchg_pw(const std::vector<int>& bands_to_print,
                           << ik % (nks / nspin) + 1 << ", spin " << spin_index + 1 << std::endl;
 
                 psi->fix_k(ik);
-                pw_wfc->recip_to_real<std::complex<double>,Device>(&psi[0](ib, 0), wfcr.data(), ik);
+                pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik);
 
                 // To ensure the normalization of charge density in multi-k calculation (if if_separate_k is true)
                 double wg_sum_k = 0;
@@ -139,7 +139,7 @@ void get_pchg_pw(const std::vector<int>& bands_to_print,
                           << ik % (nks / nspin) + 1 << ", spin " << spin_index + 1 << std::endl;
 
                 psi->fix_k(ik);
-                pw_wfc->recip_to_real<std::complex<double>,Device>( &psi[0](ib, 0), wfcr.data(), ik);
+                pw_wfc->recip_to_real(ctx, &psi[0](ib, 0), wfcr.data(), ik);
 
                 double w1 = static_cast<double>(wk[ik] / ucell->omega);
 

From 283c6745c162fc98c1f9af27131cf092350f0b4f Mon Sep 17 00:00:00 2001
From: ubuntu <3158793232@qq.com>
Date: Mon, 24 Mar 2025 14:53:22 +0800
Subject: [PATCH 8/8] update func

---
 source/module_basis/module_pw/pw_basis.h | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h
index 52a59d2fb3..8bc45518a2 100644
--- a/source/module_basis/module_pw/pw_basis.h
+++ b/source/module_basis/module_pw/pw_basis.h
@@ -352,10 +352,7 @@ class PW_Basis
     void recip_to_real(TK* in,
                        TR* out,
                        const bool add = false,
-                       const typename GetTypeReal<TK>::type factor = 1.0) const
-    {
-        this->recip2real_gpu(in, out, add, factor);
-    };
+                       const typename GetTypeReal<TK>::type factor = 1.0) const;
 
     // template <typename FPTYPE,
     //         typename Device,