diff --git a/source/module_basis/module_pw/CMakeLists.txt b/source/module_basis/module_pw/CMakeLists.txt
index e365e12b5e..50b7ec3430 100644
--- a/source/module_basis/module_pw/CMakeLists.txt
+++ b/source/module_basis/module_pw/CMakeLists.txt
@@ -30,6 +30,7 @@ list(APPEND objects
     pw_distributer.cpp
     pw_init.cpp
     pw_transform.cpp
+    pw_transform_gpu.cpp
     pw_transform_k.cpp
     module_fft/fft_bundle.cpp
     module_fft/fft_cpu.cpp
diff --git a/source/module_basis/module_pw/kernels/cuda/pw_op.cu b/source/module_basis/module_pw/kernels/cuda/pw_op.cu
index a9128db318..1e0fce284e 100644
--- a/source/module_basis/module_pw/kernels/cuda/pw_op.cu
+++ b/source/module_basis/module_pw/kernels/cuda/pw_op.cu
@@ -41,6 +41,24 @@ __global__ void set_recip_to_real_output(
     }
 }
 
+template<class FPTYPE>
+__global__ void set_recip_to_real_output( // overload of the kernel above that writes a real-valued output array
+    const int nrxx, // number of elements to process; threads with idx >= nrxx return immediately
+    const bool add, // true: accumulate into out; false: overwrite out
+    const FPTYPE factor, // scale applied to the input; only used when add is true
+    const thrust::complex<FPTYPE>* in, // complex input; only its real part is read
+    FPTYPE* out) // real output, indexed identically to in
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x; // flat element index built from the x dimension only
+    if(idx >= nrxx) {return;} // bounds guard: surplus threads do nothing
+    if(add) {
+        out[idx] += factor * in[idx].real(); // accumulate the scaled real part
+    }
+    else {
+        out[idx] = in[idx].real(); // NOTE(review): factor is not applied in the overwrite path - confirm this is intended
+    }
+}
+
 template<class FPTYPE>
 __global__ void set_real_to_recip_output(
     const int npwk,
@@ -61,9 +79,28 @@ __global__ void set_real_to_recip_output(
     }
 }
 
+template<class FPTYPE>
+__global__ void set_real_to_recip_output( // overload of the kernel above that writes a real-valued output array
+    const int npwk, // number of output elements; threads with idx >= npwk return immediately
+    const int nxyz, // divisor applied to every value (presumably the FFT grid size used for normalization - confirm)
+    const bool add, // true: accumulate into out; false: overwrite out
+    const FPTYPE factor, // extra scale, only used when add is true
+    const int* box_index, // box_index[idx] is the position in `in` that feeds out[idx]
+    const thrust::complex<FPTYPE>* in, // complex input; only its real part is read
+    FPTYPE* out) // real output of length >= npwk (length is not checked here)
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x; // flat element index built from the x dimension only
+    if(idx >= npwk) {return;} // bounds guard: surplus threads do nothing
+    if(add) {
+        out[idx] += factor / nxyz * in[box_index[idx]].real(); // gather, scale by factor / nxyz, accumulate
+    }
+    else {
+        out[idx] = in[box_index[idx]].real() / nxyz; // gather and divide by nxyz; factor is not applied here
+    }
+}
+
 template <typename FPTYPE>
-void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                    const int npwk,
+void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
                                                                     const int* box_index,
                                                                     const std::complex<FPTYPE>* in,
                                                                     std::complex<FPTYPE>* out)
@@ -79,8 +116,7 @@ void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_d
 }
 
 template <typename FPTYPE>
-void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                              const int nrxx,
+void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx,
                                                                               const bool add,
                                                                               const FPTYPE factor,
                                                                               const std::complex<FPTYPE>* in,
@@ -98,8 +134,25 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
 }
 
 template <typename FPTYPE>
-void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                              const int npwk,
+void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx,
+                                                                              const bool add,
+                                                                              const FPTYPE factor,
+                                                                              const std::complex<FPTYPE>* in,
+                                                                              FPTYPE* out) // real-valued output overload
+{
+    const int block = (nrxx + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; // ceil-div: enough blocks to cover nrxx elements
+    set_recip_to_real_output<FPTYPE><<<block, THREADS_PER_BLOCK>>>( // no stream argument, so the default stream is used
+        nrxx,
+        add,
+        factor,
+        reinterpret_cast<const thrust::complex<FPTYPE>*>(in), // pass the std::complex buffer as thrust::complex for device code
+        reinterpret_cast<FPTYPE*>(out)); // out is already FPTYPE*, so this cast is a no-op
+
+    cudaCheckOnDebug(); // macro defined outside this diff; presumably checks CUDA errors in debug builds - confirm
+}
+
+template <typename FPTYPE>
+void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
                                                                               const int nxyz,
                                                                               const bool add,
                                                                               const FPTYPE factor,
@@ -120,6 +173,28 @@ void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
     cudaCheckOnDebug();
 }
 
+template <typename FPTYPE>
+void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
+                                                                              const int nxyz,
+                                                                              const bool add,
+                                                                              const FPTYPE factor,
+                                                                              const int* box_index,
+                                                                              const std::complex<FPTYPE>* in,
+                                                                              FPTYPE* out) // real-valued output overload
+{
+    const int block = (npwk + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; // ceil-div: enough blocks to cover npwk elements
+    set_real_to_recip_output<FPTYPE><<<block, THREADS_PER_BLOCK>>>( // no stream argument, so the default stream is used
+        npwk,
+        nxyz,
+        add,
+        factor,
+        box_index, // forwarded as-is; the kernel dereferences it, so it must be device-accessible
+        reinterpret_cast<const thrust::complex<FPTYPE>*>(in), // pass the std::complex buffer as thrust::complex for device code
+        reinterpret_cast<FPTYPE*>(out)); // out is already FPTYPE*, so this cast is a no-op
+
+    cudaCheckOnDebug(); // macro defined outside this diff; presumably checks CUDA errors in debug builds - confirm
+}
+
 template struct set_3d_fft_box_op<float, base_device::DEVICE_GPU>;
 template struct set_recip_to_real_output_op<float, base_device::DEVICE_GPU>;
 template struct set_real_to_recip_output_op<float, base_device::DEVICE_GPU>;
diff --git a/source/module_basis/module_pw/kernels/pw_op.cpp b/source/module_basis/module_pw/kernels/pw_op.cpp
index b5fb453354..a1bafae69b 100644
--- a/source/module_basis/module_pw/kernels/pw_op.cpp
+++ b/source/module_basis/module_pw/kernels/pw_op.cpp
@@ -5,8 +5,7 @@ namespace ModulePW {
 template <typename FPTYPE>
 struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_CPU>
 {
-    void operator()(const base_device::DEVICE_CPU* /*dev*/,
-                    const int npwk,
+    void operator()(const int npwk,
                     const int* box_index,
                     const std::complex<FPTYPE>* in,
                     std::complex<FPTYPE>* out)
@@ -21,8 +20,7 @@ struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_CPU>
 template <typename FPTYPE>
 struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU>
 {
-    void operator()(const base_device::DEVICE_CPU* /*dev*/,
-                    const int nrxx,
+    void operator()(const int nrxx,
                     const bool add,
                     const FPTYPE factor,
                     const std::complex<FPTYPE>* in,
@@ -39,13 +37,34 @@ struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_CPU>
             }
         }
     }
+
+    void operator()(const int nrxx, // number of elements to process
+                    const bool add, // true: accumulate into out; false: overwrite out
+                    const FPTYPE factor, // scale applied to the input; only used when add is true
+                    const std::complex<FPTYPE>* in, // complex input; only its real part is read
+                    FPTYPE* out) // real-valued output overload, indexed identically to in
+    {
+        if (add)
+        {
+            for (int ir = 0; ir < nrxx; ++ir)
+            {
+                out[ir] += factor * in[ir].real(); // accumulate the scaled real part
+            }
+        }
+        else
+        {
+            for (int ir = 0; ir < nrxx; ++ir)
+            {
+                out[ir] = in[ir].real(); // NOTE(review): factor is not applied in the overwrite path - confirm this is intended
+            }
+        }
+    }
 };
 
 template <typename FPTYPE>
 struct set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>
 {
-    void operator()(const base_device::DEVICE_CPU* /*dev*/,
-                    const int npw_k,
+    void operator()(const int npw_k,
                     const int nxyz,
                     const bool add,
                     const FPTYPE factor,
diff --git a/source/module_basis/module_pw/kernels/pw_op.h b/source/module_basis/module_pw/kernels/pw_op.h
index 8415ad9677..6221c8b68c 100644
--- a/source/module_basis/module_pw/kernels/pw_op.h
+++ b/source/module_basis/module_pw/kernels/pw_op.h
@@ -19,7 +19,6 @@ struct set_3d_fft_box_op {
     /// Output Parameters
     /// @param out - output psi within the 3D box(in recip space)
     void operator() (
-        const Device* dev,
         const int npwk,
         const int* box_index,
         const std::complex<FPTYPE>* in,
@@ -39,12 +38,18 @@ struct set_recip_to_real_output_op {
     /// Output Parameters
     /// @param out - output psi within the 3D box(in real space)
     void operator() (
-        const Device* dev,
         const int nrxx,
         const bool add,
         const FPTYPE factor,
         const std::complex<FPTYPE>* in,
         std::complex<FPTYPE>* out);
+
+    void operator() (
+        const int nrxx,
+        const bool add,
+        const FPTYPE factor,
+        const std::complex<FPTYPE>* in,
+        FPTYPE* out);
 };
 
 template <typename FPTYPE, typename Device>
@@ -62,7 +67,6 @@ struct set_real_to_recip_output_op {
     /// Output Parameters
     /// @param out - output psi within the 3D box(in recip space)
     void operator() (
-        const Device* dev,
         const int npw_k,
         const int nxyz,
         const bool add,
@@ -70,6 +74,15 @@ struct set_real_to_recip_output_op {
         const int* box_index,
         const std::complex<FPTYPE>* in,
         std::complex<FPTYPE>* out);
+
+    void operator() (
+        const int npw_k,
+        const int nxyz,
+        const bool add,
+        const FPTYPE factor,
+        const int* box_index,
+        const std::complex<FPTYPE>* in,
+        FPTYPE* out);
 };
 
 #if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
@@ -77,8 +90,7 @@ struct set_real_to_recip_output_op {
 template <typename FPTYPE>
 struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>
 {
-    void operator()(const base_device::DEVICE_GPU* dev,
-                    const int npwk,
+    void operator()(const int npwk,
                     const int* box_index,
                     const std::complex<FPTYPE>* in,
                     std::complex<FPTYPE>* out);
@@ -87,25 +99,36 @@ struct set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>
 template <typename FPTYPE>
 struct set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>
 {
-    void operator()(const base_device::DEVICE_GPU* dev,
-                    const int nrxx,
+    void operator()(const int nrxx,
                     const bool add,
                     const FPTYPE factor,
                     const std::complex<FPTYPE>* in,
                     std::complex<FPTYPE>* out);
+
+    void operator()(const int nrxx,
+                    const bool add,
+                    const FPTYPE factor,
+                    const std::complex<FPTYPE>* in,
+                    FPTYPE* out);
 };
 
 template <typename FPTYPE>
 struct set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>
 {
-    void operator()(const base_device::DEVICE_GPU* dev,
-                    const int npw_k,
+    void operator()(const int npw_k,
                     const int nxyz,
                     const bool add,
                     const FPTYPE factor,
                     const int* box_index,
                     const std::complex<FPTYPE>* in,
                     std::complex<FPTYPE>* out);
+    void operator()(const int npw_k,
+                    const int nxyz,
+                    const bool add,
+                    const FPTYPE factor,
+                    const int* box_index,
+                    const std::complex<FPTYPE>* in,
+                    FPTYPE* out);
 };
 
 #endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
diff --git a/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu b/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu
index a3f5fe2c2b..51b50c6b8b 100644
--- a/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu
+++ b/source/module_basis/module_pw/kernels/rocm/pw_op.hip.cu
@@ -42,6 +42,24 @@ __global__ void set_recip_to_real_output(
     }
 }
 
+template<class FPTYPE>
+__global__ void set_recip_to_real_output( // overload of the kernel above that writes a real-valued output array
+    const int nrxx, // number of elements to process; threads with idx >= nrxx return immediately
+    const bool add, // true: accumulate into out; false: overwrite out
+    const FPTYPE factor, // scale applied to the input; only used when add is true
+    const thrust::complex<FPTYPE>* in, // complex input; only its real part is read
+    FPTYPE* out) // real output, indexed identically to in
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x; // flat element index built from the x dimension only
+    if(idx >= nrxx) {return;} // bounds guard: surplus threads do nothing
+    if(add) {
+        out[idx] += factor * in[idx].real(); // accumulate the scaled real part
+    }
+    else {
+        out[idx] = in[idx].real(); // NOTE(review): factor is not applied in the overwrite path - confirm this is intended
+    }
+}
+
 template<class FPTYPE>
 __global__ void set_real_to_recip_output(
     const int npwk,
@@ -62,9 +80,28 @@ __global__ void set_real_to_recip_output(
     }
 }
 
+template<class FPTYPE>
+__global__ void set_real_to_recip_output( // overload of the kernel above that writes a real-valued output array
+    const int npwk, // number of output elements; threads with idx >= npwk return immediately
+    const int nxyz, // divisor applied to every value (presumably the FFT grid size used for normalization - confirm)
+    const bool add, // true: accumulate into out; false: overwrite out
+    const FPTYPE factor, // extra scale, only used when add is true
+    const int* box_index, // box_index[idx] is the position in `in` that feeds out[idx]
+    const thrust::complex<FPTYPE>* in, // complex input; only its real part is read
+    FPTYPE* out) // real output of length >= npwk (length is not checked here)
+{
+    int idx = blockIdx.x * blockDim.x + threadIdx.x; // flat element index built from the x dimension only
+    if(idx >= npwk) {return;} // bounds guard: surplus threads do nothing
+    if(add) {
+        out[idx] += factor / nxyz * in[box_index[idx]].real(); // gather, scale by factor / nxyz, accumulate
+    }
+    else {
+        out[idx] = in[box_index[idx]].real() / nxyz; // gather and divide by nxyz; factor is not applied here
+    }
+}
+
 template <typename FPTYPE>
-void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                    const int npwk,
+void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
                                                                     const int* box_index,
                                                                     const std::complex<FPTYPE>* in,
                                                                     std::complex<FPTYPE>* out)
@@ -80,8 +117,7 @@ void set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_d
 }
 
 template <typename FPTYPE>
-void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                              const int nrxx,
+void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx,
                                                                               const bool add,
                                                                               const FPTYPE factor,
                                                                               const std::complex<FPTYPE>* in,
@@ -99,8 +135,25 @@ void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
 }
 
 template <typename FPTYPE>
-void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const base_device::DEVICE_GPU* /*dev*/,
-                                                                              const int npwk,
+void set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int nrxx,
+                                                                              const bool add,
+                                                                              const FPTYPE factor,
+                                                                              const std::complex<FPTYPE>* in,
+                                                                              FPTYPE* out) // real-valued output overload
+{
+    const int block = (nrxx + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; // ceil-div: enough blocks to cover nrxx elements
+    hipLaunchKernelGGL(HIP_KERNEL_NAME(set_recip_to_real_output<FPTYPE>), dim3(block), dim3(THREADS_PER_BLOCK), 0, 0, // 0 bytes dynamic shared memory, stream 0
+        nrxx,
+        add,
+        factor,
+        reinterpret_cast<const thrust::complex<FPTYPE>*>(in), // pass the std::complex buffer as thrust::complex for device code
+        reinterpret_cast<FPTYPE*>(out)); // out is already FPTYPE*, so this cast is a no-op
+
+    hipCheckOnDebug(); // macro defined outside this diff; presumably checks HIP errors in debug builds - confirm
+}
+
+template <typename FPTYPE>
+void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
                                                                               const int nxyz,
                                                                               const bool add,
                                                                               const FPTYPE factor,
@@ -121,6 +174,28 @@ void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(co
     hipCheckOnDebug();
 }
 
+template <typename FPTYPE>
+void set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>::operator()(const int npwk,
+                                                                              const int nxyz,
+                                                                              const bool add,
+                                                                              const FPTYPE factor,
+                                                                              const int* box_index,
+                                                                              const std::complex<FPTYPE>* in,
+                                                                              FPTYPE* out) // real-valued output overload
+{
+    const int block = (npwk + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; // ceil-div: enough blocks to cover npwk elements
+    hipLaunchKernelGGL(HIP_KERNEL_NAME(set_real_to_recip_output<FPTYPE>), dim3(block), dim3(THREADS_PER_BLOCK), 0, 0, // 0 bytes dynamic shared memory, stream 0
+        npwk,
+        nxyz,
+        add,
+        factor,
+        box_index, // forwarded as-is; the kernel dereferences it, so it must be device-accessible
+        reinterpret_cast<const thrust::complex<FPTYPE>*>(in), // pass the std::complex buffer as thrust::complex for device code
+        reinterpret_cast<FPTYPE*>(out)); // out is already FPTYPE*, so this cast is a no-op
+
+    hipCheckOnDebug(); // macro defined outside this diff; presumably checks HIP errors in debug builds - confirm
+}
+
 template struct set_3d_fft_box_op<float, base_device::DEVICE_GPU>;
 template struct set_recip_to_real_output_op<float, base_device::DEVICE_GPU>;
 template struct set_real_to_recip_output_op<float, base_device::DEVICE_GPU>;
diff --git a/source/module_basis/module_pw/kernels/test/pw_op_test.cpp b/source/module_basis/module_pw/kernels/test/pw_op_test.cpp
index 6adac4613f..e9bc07bd9e 100644
--- a/source/module_basis/module_pw/kernels/test/pw_op_test.cpp
+++ b/source/module_basis/module_pw/kernels/test/pw_op_test.cpp
@@ -72,7 +72,7 @@ class TestModulePWPWMultiDevice : public ::testing::Test
 TEST_F(TestModulePWPWMultiDevice, set_3d_fft_box_op_cpu)
 {
     std::vector<std::complex<double>> res(out_1.size(), std::complex<double>{0, 0});
-    set_3d_fft_box_cpu_op()(cpu_ctx, this->npwk, box_index.data(), in_1.data(), res.data());
+    set_3d_fft_box_cpu_op()(this->npwk, box_index.data(), in_1.data(), res.data());
     for (int ii = 0; ii < this->nxyz; ii++) {
         EXPECT_LT(std::abs(res[ii] - out_1[ii]), 1e-12);
     }
@@ -81,7 +81,7 @@ TEST_F(TestModulePWPWMultiDevice, set_3d_fft_box_op_cpu)
 TEST_F(TestModulePWPWMultiDevice, set_recip_to_real_output_op_cpu)
 {
     std::vector<std::complex<double>> res(out_2.size(), std::complex<double>{0, 0});
-    set_recip_to_real_output_cpu_op()(cpu_ctx, this->nxyz, this->add, this->factor, in_2.data(), res.data());
+    set_recip_to_real_output_cpu_op()(this->nxyz, this->add, this->factor, in_2.data(), res.data());
     for (int ii = 0; ii < this->nxyz; ii++) {
         EXPECT_LT(std::abs(res[ii] - out_2[ii]), 1e-12);
     }
@@ -90,7 +90,7 @@ TEST_F(TestModulePWPWMultiDevice, set_recip_to_real_output_op_cpu)
 TEST_F(TestModulePWPWMultiDevice, set_real_to_recip_output_op_cpu)
 {
     std::vector<std::complex<double>> res = out_3_init;
-    set_real_to_recip_output_cpu_op()(cpu_ctx, this->npwk, this->nxyz, true, this->factor, box_index.data(), in_3.data(), res.data());
+    set_real_to_recip_output_cpu_op()(this->npwk, this->nxyz, true, this->factor, box_index.data(), in_3.data(), res.data());
     for (int ii = 0; ii < out_3.size(); ii++) {
         EXPECT_LT(std::abs(res[ii] - out_3[ii]), 5e-6);
     }
@@ -109,7 +109,7 @@ TEST_F(TestModulePWPWMultiDevice, set_3d_fft_box_op_gpu)
     synchronize_memory_complex_h2d_op()(d_res, res.data(), res.size());
     synchronize_memory_complex_h2d_op()(d_in_1, in_1.data(), in_1.size());
 
-    set_3d_fft_box_gpu_op()(gpu_ctx, this->npwk, d_box_index, d_in_1, d_res);
+    set_3d_fft_box_gpu_op()(this->npwk, d_box_index, d_in_1, d_res);
 
     synchronize_memory_complex_d2h_op()(res.data(), d_res, res.size());
 
@@ -130,7 +130,7 @@ TEST_F(TestModulePWPWMultiDevice, set_recip_to_real_output_op_gpu)
     synchronize_memory_complex_h2d_op()(d_res, res.data(), res.size());
     synchronize_memory_complex_h2d_op()(d_in_2, in_2.data(), in_2.size());
 
-    set_recip_to_real_output_gpu_op()(gpu_ctx, this->nxyz, this->add, this->factor, d_in_2, d_res);
+    set_recip_to_real_output_gpu_op()(this->nxyz, this->add, this->factor, d_in_2, d_res);
 
     synchronize_memory_complex_d2h_op()(res.data(), d_res, res.size());
 
@@ -153,7 +153,7 @@ TEST_F(TestModulePWPWMultiDevice, set_real_to_recip_output_op_gpu)
     synchronize_memory_complex_h2d_op()(d_res, res.data(), res.size());
     synchronize_memory_complex_h2d_op()(d_in_3, in_3.data(), in_3.size());
 
-    set_real_to_recip_output_gpu_op()(gpu_ctx, this->npwk, this->nxyz, true, this->factor, d_box_index, d_in_3, d_res);
+    set_real_to_recip_output_gpu_op()(this->npwk, this->nxyz, true, this->factor, d_box_index, d_in_3, d_res);
 
     synchronize_memory_complex_d2h_op()(res.data(), d_res, res.size());
 
diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.cpp b/source/module_basis/module_pw/module_fft/fft_bundle.cpp
index 98889c69db..150475dd2e 100644
--- a/source/module_basis/module_pw/module_fft/fft_bundle.cpp
+++ b/source/module_basis/module_pw/module_fft/fft_bundle.cpp
@@ -220,30 +220,26 @@ void FFT_Bundle::fftxyc2r(std::complex<double>* in, double* out) const
 }
 
 template <>
-void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx,
-                               std::complex<float>* in,
+void FFT_Bundle::fft3D_forward(std::complex<float>* in,
                                std::complex<float>* out) const
 {
     fft_float->fft3D_forward(in, out);
 }
 template <>
-void FFT_Bundle::fft3D_forward(const base_device::DEVICE_GPU* ctx,
-                               std::complex<double>* in,
+void FFT_Bundle::fft3D_forward(std::complex<double>* in,
                                std::complex<double>* out) const
 {
     fft_double->fft3D_forward(in, out);
 }
 
 template <>
-void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx,
-                                std::complex<float>* in,
+void FFT_Bundle::fft3D_backward(std::complex<float>* in,
                                 std::complex<float>* out) const
 {
     fft_float->fft3D_backward(in, out);
 }
 template <>
-void FFT_Bundle::fft3D_backward(const base_device::DEVICE_GPU* ctx,
-                                std::complex<double>* in,
+void FFT_Bundle::fft3D_backward(std::complex<double>* in,
                                 std::complex<double>* out) const
 {
     fft_double->fft3D_backward(in, out);
diff --git a/source/module_basis/module_pw/module_fft/fft_bundle.h b/source/module_basis/module_pw/module_fft/fft_bundle.h
index 1cef050884..1c5d388292 100644
--- a/source/module_basis/module_pw/module_fft/fft_bundle.h
+++ b/source/module_basis/module_pw/module_fft/fft_bundle.h
@@ -188,10 +188,10 @@ class FFT_Bundle
     template <typename FPTYPE>
     void fftxyc2r(std::complex<FPTYPE>* in, FPTYPE* out) const;
 
-    template <typename FPTYPE, typename Device>
-    void fft3D_forward(const Device* ctx, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
-    template <typename FPTYPE, typename Device>
-    void fft3D_backward(const Device* ctx, std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
+    template <typename FPTYPE>
+    void fft3D_forward(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
+    template <typename FPTYPE>
+    void fft3D_backward(std::complex<FPTYPE>* in, std::complex<FPTYPE>* out) const;
 
   private:
     int fft_mode = 0;
diff --git a/source/module_basis/module_pw/pw_basis.h b/source/module_basis/module_pw/pw_basis.h
index d988305d31..8bc45518a2 100644
--- a/source/module_basis/module_pw/pw_basis.h
+++ b/source/module_basis/module_pw/pw_basis.h
@@ -1,6 +1,7 @@
 #ifndef PWBASIS_H
 #define PWBASIS_H
 
+#include "module_base/macros.h"
 #include "module_base/module_device/memory_op.h"
 #include "module_base/matrix.h"
 #include "module_base/matrix3.h"
@@ -245,7 +246,7 @@ class PW_Basis
     // FFT ft;
     FFT_Bundle fft_bundle;
     //The position of pointer in and out can be equal(in-place transform) or different(out-of-place transform).
-
+    
     template <typename FPTYPE>
     void real2recip(const FPTYPE* in,
                     std::complex<FPTYPE>* out,
@@ -266,6 +267,144 @@ class PW_Basis
                     std::complex<FPTYPE>* out,
                     const bool add = false,
                     const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
+    
+    template <typename FPTYPE>
+    void real2recip_gpu(const FPTYPE* in,
+                    std::complex<FPTYPE>* out,
+                    const bool add = false,
+                    const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny)  ; out(nz, ns)
+    template <typename FPTYPE>
+    void real2recip_gpu(const std::complex<FPTYPE>* in,
+                    std::complex<FPTYPE>* out,
+                    const bool add = false,
+                    const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny)  ; out(nz, ns)
+    template <typename FPTYPE>
+    void recip2real_gpu(const std::complex<FPTYPE>* in,
+                    FPTYPE* out,
+                    const bool add = false,
+                    const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
+    template <typename FPTYPE>
+    void recip2real_gpu(const std::complex<FPTYPE>* in,
+                    std::complex<FPTYPE>* out,
+                    const bool add = false,
+                    const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
+
+    /**
+     * @brief Converts data from reciprocal space to real space on the CPU.
+     *
+     * Thin wrapper that forwards all arguments unchanged to PW_Basis::recip2real.
+     * The input must be of a complex type TK.
+     * The output type TR is either TK itself or the underlying real type GetTypeReal<TK>::type.
+     *
+     * @tparam TK The type of the input data; rejected when it equals its own GetTypeReal<TK>::type
+     *         (presumably std::complex<float> or std::complex<double> - confirm against GetTypeReal).
+     * @tparam TR The type of the output data: GetTypeReal<TK>::type or TK.
+     * @tparam Device The device type, must be base_device::DEVICE_CPU.
+     * @tparam std::enable_if<...>::type
+     *         SFINAE constraint that enforces the three conditions on TK, TR and Device above.
+     *
+     * @param in Pointer to the input data array in reciprocal space.
+     * @param out Pointer to the output data array in real space; its element type is TR,
+     *            i.e. either complex (TK) or the underlying real type.
+     * @param add Boolean flag indicating whether to add the result to the existing values in the output array.
+     * @param factor A scaling factor forwarded to recip2real (see that function for how it is applied).
+     */
+    template <typename TK,
+              typename TR,
+              typename Device,
+              typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
+                                          && (std::is_same<TR, typename GetTypeReal<TK>::type>::value
+                                              || std::is_same<TR, TK>::value)
+                                          && std::is_same<Device, base_device::DEVICE_CPU>::value,
+                                      int>::type
+              = 0>
+    void recip_to_real(TK* in, TR* out, const bool add = false, const typename GetTypeReal<TK>::type factor = 1.0) const
+    {
+        this->recip2real(in, out, add, factor);
+    };
+
+    /**
+     * @brief Converts data from reciprocal space (Fourier space) to real space on the GPU.
+     *
+     * Only declared here (presumably defined in pw_transform_gpu.cpp - confirm). This overload is
+     * selected when the input is of a complex type TK, the output type TR is either TK or its
+     * underlying real type, and Device is base_device::DEVICE_GPU.
+     *
+     * @tparam TK The complex type of the input data; rejected when it equals GetTypeReal<TK>::type.
+     * @tparam TR The type of the output data: GetTypeReal<TK>::type or TK.
+     * @tparam Device The device type, which must be base_device::DEVICE_GPU.
+     *
+     * @param in Pointer to the input data array in reciprocal space, of type TK*.
+     * @param out Pointer to the output data array in real space, of type TR*.
+     * @param add Optional boolean flag, default value false, indicating whether to add the result to the existing
+     * values in the output array.
+     * @param factor Optional scaling factor, default value 1.0, applied to the output values.
+     */
+    template <typename TK,
+              typename TR,
+              typename Device,
+              typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
+                                          && (std::is_same<TR, typename GetTypeReal<TK>::type>::value
+                                              || std::is_same<TR, TK>::value)
+                                          && std::is_same<Device, base_device::DEVICE_GPU>::value,
+                                      int>::type
+              = 0>
+    void recip_to_real(TK* in,
+                       TR* out,
+                       const bool add = false,
+                       const typename GetTypeReal<TK>::type factor = 1.0) const;
+
+    // template <typename FPTYPE,
+    //         typename Device,
+    //         typename std::enable_if<!std::is_same<FPTYPE, typename GetTypeReal<FPTYPE>::type>::value, int>::type = 0>
+    // void recip_to_real(FPTYPE* in,
+    //                    FPTYPE* out,
+    //                    const bool add = false,
+    //                    const typename GetTypeReal<FPTYPE>::type factor = 1.0) const;
+    /**
+     * @brief Converts data from real space to reciprocal space (Fourier space) on the CPU.
+     *
+     * Thin wrapper that forwards all arguments unchanged to PW_Basis::real2recip. The output is of a
+     * complex type TK, and the input type TR is either TK itself or the underlying real type
+     * GetTypeReal<TK>::type.
+     *
+     * @tparam TK The type of the output data; rejected when it equals its own GetTypeReal<TK>::type
+     *         (presumably std::complex<float> or std::complex<double> - confirm against GetTypeReal).
+     * @tparam TR The type of the input data: GetTypeReal<TK>::type or TK.
+     * @tparam Device The device type, which must be base_device::DEVICE_CPU.
+     * @tparam std::enable_if<...>::type
+     *         SFINAE constraint that enforces the three conditions on TK, TR and Device above.
+     *         Note that the template parameter order is TK, TR, Device although TR is the input type.
+     *
+     * @param in Pointer to the input data array in real space, of type TR*.
+     * @param out Pointer to the output data array in reciprocal space, of type TK*.
+     * @param add Optional boolean flag, default value false, indicating whether to add the result to the existing
+     * values in the output array.
+     * @param factor Optional scaling factor, default value 1.0, forwarded to real2recip
+     *        (see that function for how it is applied).
+     */
+    template <typename TK,
+            typename TR,
+            typename Device,
+            typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value
+                    && (std::is_same<TR, typename GetTypeReal<TK>::type>::value || std::is_same<TR, TK>::value)
+                    && std::is_same<Device, base_device::DEVICE_CPU>::value ,int>::type = 0>
+    void real_to_recip(TR* in,
+                       TK* out,
+                       const bool add = false,
+                       const typename GetTypeReal<TK>::type factor = 1.0) const
+    {
+        this->real2recip(in, out, add, factor);
+    }
+
+    template <typename TK,typename TR, typename Device, // non-CPU counterpart of the real_to_recip wrapper above; declaration only
+            typename std::enable_if<!std::is_same<TK, typename GetTypeReal<TK>::type>::value // TK must differ from its own real type
+                    && (std::is_same<TR, typename GetTypeReal<TK>::type>::value || std::is_same<TR, TK>::value) // TR is the real type of TK, or TK itself
+                    && !std::is_same<Device, base_device::DEVICE_CPU>::value ,int>::type = 0> // NOTE(review): enabled for any non-CPU Device, whereas recip_to_real requires DEVICE_GPU exactly
+    void real_to_recip(TR* in, // input data in real space
+                       TK* out, // output data in reciprocal space
+                       const bool add = false, // presumably: accumulate into out instead of overwriting - confirm in the definition
+                       const typename GetTypeReal<TK>::type factor = 1.0) const; // scaling factor; how it is applied is decided by the definition
 
   protected:
     //gather planes and scatter sticks of all processors
@@ -282,13 +421,14 @@ class PW_Basis
 
     using resmem_int_op = base_device::memory::resize_memory_op<int, base_device::DEVICE_GPU>;
     using delmem_int_op = base_device::memory::delete_memory_op<int, base_device::DEVICE_GPU>;
-    using syncmem_int_h2d_op
-        = base_device::memory::synchronize_memory_op<int, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
-
+    using syncmem_int_h2d_op = base_device::memory::synchronize_memory_op<int, base_device::DEVICE_GPU, base_device::DEVICE_CPU>;
+    // using default_device_cpu = base_device::DEVICE_CPU;
+    
     void set_device(std::string device_);
     void set_precision(std::string precision_);
 
 protected:
+
   std::string device = "cpu";       ///< cpu or gpu
   std::string precision = "double"; ///< single, double, mixing
   bool double_data_ = true;         ///<  if has double data
@@ -296,6 +436,5 @@ class PW_Basis
 };
 }
 #endif // PWBASIS_H
-
 #include "pw_basis_sup.h"
 #include "pw_basis_big.h" //temporary it will be removed
diff --git a/source/module_basis/module_pw/pw_basis_k.h b/source/module_basis/module_pw/pw_basis_k.h
index defa0cb5cc..379aba6cee 100644
--- a/source/module_basis/module_pw/pw_basis_k.h
+++ b/source/module_basis/module_pw/pw_basis_k.h
@@ -161,7 +161,7 @@ class PW_Basis_K : public PW_Basis
     
     #endif
 
-    template <typename FPTYPE, typename Device>
+     template <typename FPTYPE, typename Device>
     void real_to_recip(const Device* ctx,
                        const std::complex<FPTYPE>* in,
                        std::complex<FPTYPE>* out,
@@ -176,6 +176,76 @@ class PW_Basis_K : public PW_Basis
                        const bool add = false,
                        const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
 
+
+    template <typename TK,  // element type shared by in and out
+              typename Device,  // this overload is enabled only for base_device::DEVICE_CPU
+              typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
+    void real_to_recip(const TK* in,  // real-space input
+                       TK* out,  // reciprocal-space output
+                       const int ik,  // k-point index, forwarded unchanged
+                       const bool add = false,  // forwarded unchanged to the underlying transform
+                       const typename GetTypeReal<TK>::type factor = 1.0) const  // forwarded unchanged
+    {
+      #if defined(__DSP)
+        this->real2recip_dsp(in, out, ik, add, factor);  // forward transform (previously called recip2real_dsp by mistake)
+      #else
+        this->real2recip(in, out, ik, add, factor);
+      #endif
+    }
+    template <typename TK,  // element type shared by in and out
+              typename Device,  // this overload is enabled only for base_device::DEVICE_CPU
+              typename std::enable_if<std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
+    void recip_to_real(const TK* in,
+                       TK* out,
+                       const int ik,  // k-point index, forwarded unchanged
+                       const bool add = false,
+                       const typename GetTypeReal<TK>::type factor = 1.0) const
+    {
+      // CPU dispatch: a __DSP build forwards to recip2real_dsp, every other build to recip2real.
+      #if defined(__DSP)
+        this->recip2real_dsp(in,out,ik,add,factor);
+      #else
+        this->recip2real(in,out,ik,add,factor);
+      #endif
+    }
+    template <typename FPTYPE>  // GPU real -> reciprocal transform for k-point ik; defined in pw_transform_k.cpp
+    void real2recip_gpu(const std::complex<FPTYPE>* in,
+                    std::complex<FPTYPE>* out,
+                    const int ik,
+                    const bool add = false,
+                    const FPTYPE factor = 1.0) const; // in:(nplane,nx*ny)  ; out(nz, ns)
+                    
+    template <typename FPTYPE>  // GPU reciprocal -> real transform for k-point ik (counterpart of real2recip_gpu)
+    void recip2real_gpu(const std::complex<FPTYPE>* in,
+                    std::complex<FPTYPE>* out,
+                    const int ik,
+                    const bool add = false,
+                    const FPTYPE factor = 1.0) const; // in:(nz, ns)  ; out(nplane,nx*ny)
+
+    template <typename TK,  // element type of in/out (named TK like the sibling wrappers)
+              typename Device,  // selected for every Device other than base_device::DEVICE_CPU
+              typename std::enable_if<!std::is_same<Device, base_device::DEVICE_CPU>::value, int>::type = 0>
+    void real_to_recip(const TK* in,
+                       TK* out,
+                       const int ik,  // k-point index, forwarded unchanged
+                       const bool add = false,
+                       const typename GetTypeReal<TK>::type factor = 1.0) const
+    {
+        this->real2recip_gpu(in, out, ik, add, factor);  // non-CPU path: thin forwarder to real2recip_gpu
+    }
+
+    template <typename TK,  // element type shared by in and out
+              typename Device,  // this overload is enabled only when Device is base_device::DEVICE_GPU
+              typename std::enable_if<std::is_same<Device, base_device::DEVICE_GPU>::value, int>::type = 0>
+    void recip_to_real(const TK* in,
+                       TK* out,
+                       const int ik,  // k-point index, forwarded unchanged
+                       const bool add = false,
+                       const typename GetTypeReal<TK>::type factor = 1.0) const
+    {
+        this->recip2real_gpu(in, out, ik, add, factor);  // GPU path: thin forwarder to recip2real_gpu
+    }
+
   public:
     //operator:
     //get (G+K)^2:
diff --git a/source/module_basis/module_pw/pw_transform.cpp b/source/module_basis/module_pw/pw_transform.cpp
index 8e458b2561..4f34221775 100644
--- a/source/module_basis/module_pw/pw_transform.cpp
+++ b/source/module_basis/module_pw/pw_transform.cpp
@@ -1,12 +1,19 @@
-#include "module_fft/fft_bundle.h"
-#include <complex>
-#include "pw_basis.h"
-#include <cassert>
 #include "module_base/global_function.h"
 #include "module_base/timer.h"
+#include "module_basis/module_pw/kernels/pw_op.h"
+#include "module_fft/fft_bundle.h"
+#include "pw_basis.h"
 #include "pw_gatherscatter.h"
 
-namespace ModulePW {
+#include <cassert>
+#include <complex>
+
+namespace ModulePW
+{
+//     const base_device::DEVICE_CPU* PW_Basis::get_default_device_ctx() {
+//         static const base_device::DEVICE_CPU* default_device_cpu;
+//     return default_device_cpu;
+// }
 /**
  * @brief transform real space to reciprocal space
  * @details c(g)=\int dr*f(r)*exp(-ig*r)
@@ -24,25 +31,25 @@ void PW_Basis::real2recip(const std::complex<FPTYPE>* in,
 
     assert(this->gamma_only == false);
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-    for(int ir = 0 ; ir < this->nrxx ; ++ir)
+    for (int ir = 0; ir < this->nrxx; ++ir)
     {
         this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = in[ir];
     }
-    this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>());
+    this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());
 
     this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
-    
-    this->fft_bundle.fftzfor(fft_bundle.get_auxg_data<FPTYPE>(),fft_bundle.get_auxg_data<FPTYPE>());
 
-    if(add)
+    this->fft_bundle.fftzfor(fft_bundle.get_auxg_data<FPTYPE>(), fft_bundle.get_auxg_data<FPTYPE>());
+
+    if (add)
     {
         FPTYPE tmpfac = factor / FPTYPE(this->nxyz);
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ig = 0 ; ig < this->npw ; ++ig)
+        for (int ig = 0; ig < this->npw; ++ig)
         {
             out[ig] += tmpfac * this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]];
         }
@@ -51,9 +58,9 @@ void PW_Basis::real2recip(const std::complex<FPTYPE>* in,
     {
         FPTYPE tmpfac = 1.0 / FPTYPE(this->nxyz);
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ig = 0 ; ig < this->npw ; ++ig)
+        for (int ig = 0; ig < this->npw; ++ig)
         {
             out[ig] = tmpfac * this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]];
         }
@@ -72,44 +79,44 @@ template <typename FPTYPE>
 void PW_Basis::real2recip(const FPTYPE* in, std::complex<FPTYPE>* out, const bool add, const FPTYPE factor) const
 {
     ModuleBase::timer::tick(this->classname, "real2recip");
-    if(this->gamma_only)
+    if (this->gamma_only)
     {
         const int npy = this->ny * this->nplane;
 #ifdef _OPENMP
-#pragma omp parallel for collapse(2) schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ix = 0 ; ix < this->nx ; ++ix)
+        for (int ix = 0; ix < this->nx; ++ix)
         {
-            for(int ipy = 0 ; ipy < npy ; ++ipy)
+            for (int ipy = 0; ipy < npy; ++ipy)
             {
-                this->fft_bundle.get_rspace_data<FPTYPE>()[ix*npy + ipy] = in[ix*npy + ipy];
+                this->fft_bundle.get_rspace_data<FPTYPE>()[ix * npy + ipy] = in[ix * npy + ipy];
             }
         }
 
-        this->fft_bundle.fftxyr2c(fft_bundle.get_rspace_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>());
+        this->fft_bundle.fftxyr2c(fft_bundle.get_rspace_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());
     }
     else
     {
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ir = 0 ; ir < this->nrxx ; ++ir)
+        for (int ir = 0; ir < this->nrxx; ++ir)
         {
-            this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = std::complex<FPTYPE>(in[ir],0);
+            this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = std::complex<FPTYPE>(in[ir], 0);
         }
-        this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>());
+        this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());
     }
     this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
-    
-    this->fft_bundle.fftzfor(fft_bundle.get_auxg_data<FPTYPE>(),fft_bundle.get_auxg_data<FPTYPE>());
 
-    if(add)
+    this->fft_bundle.fftzfor(fft_bundle.get_auxg_data<FPTYPE>(), fft_bundle.get_auxg_data<FPTYPE>());
+
+    if (add)
     {
         FPTYPE tmpfac = factor / FPTYPE(this->nxyz);
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ig = 0 ; ig < this->npw ; ++ig)
+        for (int ig = 0; ig < this->npw; ++ig)
         {
             out[ig] += tmpfac * this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]];
         }
@@ -118,9 +125,9 @@ void PW_Basis::real2recip(const FPTYPE* in, std::complex<FPTYPE>* out, const boo
     {
         FPTYPE tmpfac = 1.0 / FPTYPE(this->nxyz);
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ig = 0 ; ig < this->npw ; ++ig)
+        for (int ig = 0; ig < this->npw; ++ig)
         {
             out[ig] = tmpfac * this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]];
         }
@@ -144,32 +151,32 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in,
     ModuleBase::timer::tick(this->classname, "recip2real");
     assert(this->gamma_only == false);
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-    for(int i = 0 ; i < this->nst * this->nz ; ++i)
+    for (int i = 0; i < this->nst * this->nz; ++i)
     {
         fft_bundle.get_auxg_data<FPTYPE>()[i] = std::complex<FPTYPE>(0, 0);
     }
 
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-    for(int ig = 0 ; ig < this->npw ; ++ig)
+    for (int ig = 0; ig < this->npw; ++ig)
     {
         this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]] = in[ig];
     }
     this->fft_bundle.fftzbac(fft_bundle.get_auxg_data<FPTYPE>(), fft_bundle.get_auxg_data<FPTYPE>());
 
-    this->gathers_scatterp(this->fft_bundle.get_auxg_data<FPTYPE>(),this->fft_bundle.get_auxr_data<FPTYPE>());
+    this->gathers_scatterp(this->fft_bundle.get_auxg_data<FPTYPE>(), this->fft_bundle.get_auxr_data<FPTYPE>());
+
+    this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());
 
-    this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>());
-    
-    if(add)
+    if (add)
     {
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ir = 0 ; ir < this->nrxx ; ++ir)
+        for (int ir = 0; ir < this->nrxx; ++ir)
         {
             out[ir] += factor * this->fft_bundle.get_auxr_data<FPTYPE>()[ir];
         }
@@ -177,9 +184,9 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in,
     else
     {
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-        for(int ir = 0 ; ir < this->nrxx ; ++ir)
+        for (int ir = 0; ir < this->nrxx; ++ir)
         {
             out[ir] = this->fft_bundle.get_auxr_data<FPTYPE>()[ir];
         }
@@ -199,17 +206,17 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo
 {
     ModuleBase::timer::tick(this->classname, "recip2real");
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-    for(int i = 0 ; i < this->nst * this->nz ; ++i)
+    for (int i = 0; i < this->nst * this->nz; ++i)
     {
         fft_bundle.get_auxg_data<FPTYPE>()[i] = std::complex<double>(0, 0);
     }
 
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-    for(int ig = 0 ; ig < this->npw ; ++ig)
+    for (int ig = 0; ig < this->npw; ++ig)
     {
         this->fft_bundle.get_auxg_data<FPTYPE>()[this->ig2isz[ig]] = in[ig];
     }
@@ -217,49 +224,49 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo
 
     this->gathers_scatterp(this->fft_bundle.get_auxg_data<FPTYPE>(), this->fft_bundle.get_auxr_data<FPTYPE>());
 
-    if(this->gamma_only)
+    if (this->gamma_only)
     {
-        this->fft_bundle.fftxyc2r(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_rspace_data<FPTYPE>());
+        this->fft_bundle.fftxyc2r(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_rspace_data<FPTYPE>());
 
         // r2c in place
         const int npy = this->ny * this->nplane;
 
-        if(add)
+        if (add)
         {
 #ifdef _OPENMP
-#pragma omp parallel for collapse(2) schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-            for(int ix = 0 ; ix < this->nx ; ++ix)
+            for (int ix = 0; ix < this->nx; ++ix)
             {
-                for(int ipy = 0 ; ipy < npy ; ++ipy)
+                for (int ipy = 0; ipy < npy; ++ipy)
                 {
-                    out[ix*npy + ipy] += factor * this->fft_bundle.get_rspace_data<FPTYPE>()[ix*npy + ipy];
+                    out[ix * npy + ipy] += factor * this->fft_bundle.get_rspace_data<FPTYPE>()[ix * npy + ipy];
                 }
             }
         }
         else
         {
 #ifdef _OPENMP
-#pragma omp parallel for collapse(2) schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for collapse(2) schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-            for(int ix = 0 ; ix < this->nx ; ++ix)
+            for (int ix = 0; ix < this->nx; ++ix)
             {
-                for(int ipy = 0 ; ipy < npy ; ++ipy)
+                for (int ipy = 0; ipy < npy; ++ipy)
                 {
-                    out[ix*npy + ipy] = this->fft_bundle.get_rspace_data<FPTYPE>()[ix*npy + ipy];
+                    out[ix * npy + ipy] = this->fft_bundle.get_rspace_data<FPTYPE>()[ix * npy + ipy];
                 }
             }
         }
     }
     else
     {
-        this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(),fft_bundle.get_auxr_data<FPTYPE>());
-        if(add)
+        this->fft_bundle.fftxybac(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());
+        if (add)
         {
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-            for(int ir = 0 ; ir < this->nrxx ; ++ir)
+            for (int ir = 0; ir < this->nrxx; ++ir)
             {
                 out[ir] += factor * this->fft_bundle.get_auxr_data<FPTYPE>()[ir].real();
             }
@@ -267,9 +274,9 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo
         else
         {
 #ifdef _OPENMP
-#pragma omp parallel for schedule(static, 4096/sizeof(FPTYPE))
+#pragma omp parallel for schedule(static, 4096 / sizeof(FPTYPE))
 #endif
-            for(int ir = 0 ; ir < this->nrxx ; ++ir)
+            for (int ir = 0; ir < this->nrxx; ++ir)
             {
                 out[ir] = this->fft_bundle.get_auxr_data<FPTYPE>()[ir].real();
             }
@@ -277,7 +284,6 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in, FPTYPE* out, const boo
     }
     ModuleBase::timer::tick(this->classname, "recip2real");
 }
-
 template void PW_Basis::real2recip<float>(const float* in,
                                           std::complex<float>* out,
                                           const bool add,
@@ -311,4 +317,4 @@ template void PW_Basis::recip2real<double>(const std::complex<double>* in,
                                            std::complex<double>* out,
                                            const bool add,
                                            const double factor) const;
-}
\ No newline at end of file
+} // namespace ModulePW
\ No newline at end of file
diff --git a/source/module_basis/module_pw/pw_transform_gpu.cpp b/source/module_basis/module_pw/pw_transform_gpu.cpp
new file mode 100644
index 0000000000..1898ff23eb
--- /dev/null
+++ b/source/module_basis/module_pw/pw_transform_gpu.cpp
@@ -0,0 +1,163 @@
+#include "pw_basis.h"
+#include "module_base/timer.h"
+#include "module_basis/module_pw/kernels/pw_op.h"
+namespace ModulePW
+{
+#if (defined(__CUDA) || defined(__ROCM))
+template <typename FPTYPE>  // FPTYPE: real scalar type (explicitly instantiated below for float and double)
+void PW_Basis::real2recip_gpu(const FPTYPE* in,
+                             std::complex<FPTYPE>* out,
+                             const bool add,
+                             const FPTYPE factor) const
+{
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+    // FIXME(review): `in` is never read. The device-to-device copy used by the complex overload
+    // (synchronize_memory_op<std::complex<FPTYPE>, DEVICE_GPU, DEVICE_GPU> into get_auxr_3d_data)
+    // cannot take a real-valued `in`, so it was left disabled; as written the FFT below transforms
+    // whatever the work buffer already holds. A real -> complex fill of the buffer is still needed.
+    // NOTE(review): ig2isz is dereferenced on the host in pw_transform.cpp, where it indexes the
+    // auxg buffer; confirm it is a valid device-side index into the 3D FFT box used here.
+
+    this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+
+    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw,
+                                                                  this->nxyz,
+                                                                  add,
+                                                                  factor,
+                                                                  this->ig2isz,
+                                                                  this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                  out);
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+}
+template <typename FPTYPE>  // complex-input overload: copy, in-place 3D forward FFT, then write npw outputs
+void PW_Basis::real2recip_gpu(const std::complex<FPTYPE>* in,
+                             std::complex<FPTYPE>* out,
+                             const bool add,
+                             const FPTYPE factor) const
+{
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+    // Stage the nrxx input values in the FFT work buffer (DEVICE_GPU -> DEVICE_GPU copy).
+    base_device::memory::synchronize_memory_op<std::complex<FPTYPE>,
+                                               base_device::DEVICE_GPU,
+                                               base_device::DEVICE_GPU>()(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                          in,
+                                                                          this->nrxx);
+
+    this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+
+    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_GPU>()(npw,
+                                                                   this->nxyz,
+                                                                   add,
+                                                                   factor,
+                                                                   this->ig2isz,  // NOTE(review): host-indexed in pw_transform.cpp; confirm device validity
+                                                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                   out);
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+}
+
+template <typename FPTYPE>  // real-output overload: zero buffer, place inputs, in-place 3D backward FFT, write nrxx outputs
+void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in,
+                             FPTYPE* out,
+                             const bool add,
+                             const FPTYPE factor) const
+{
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+    // Zero all nxyz entries of the FFT work buffer before set_3d_fft_box_op fills in the npw
+    // input coefficients (device-side replacement for a host ZEROS call).
+    base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(
+        this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+        0,
+        this->nxyz);
+
+    set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(npw,
+                                                        this->ig2isz,  // NOTE(review): host-indexed in pw_transform.cpp; confirm device validity
+                                                        in,
+                                                        this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+    this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+
+    set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx,
+                                                                  add,
+                                                                  factor,
+                                                                  this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                  out);
+
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+}
+template <typename FPTYPE>  // complex-output overload: zero buffer, place inputs, in-place 3D backward FFT, write nrxx outputs
+void PW_Basis::recip2real_gpu(const std::complex<FPTYPE>* in,
+                             std::complex<FPTYPE>* out,
+                             const bool add,
+                             const FPTYPE factor) const
+{
+    // The transform is only implemented for the non-gamma-only, single-process-per-pool case (asserted below).
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+    // Zero all nxyz entries of the FFT work buffer before set_3d_fft_box_op fills in the npw input coefficients.
+    base_device::memory::set_memory_op<std::complex<FPTYPE>, base_device::DEVICE_GPU>()(
+        this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+        0,
+        this->nxyz);
+
+    set_3d_fft_box_op<FPTYPE, base_device::DEVICE_GPU>()(npw,
+                                                         this->ig2isz,  // NOTE(review): host-indexed in pw_transform.cpp; confirm device validity
+                                                         in,
+                                                         this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+    this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                    this->fft_bundle.get_auxr_3d_data<FPTYPE>());
+
+    set_recip_to_real_output_op<FPTYPE, base_device::DEVICE_GPU>()(this->nrxx,
+                                                                   add,
+                                                                   factor,
+                                                                   this->fft_bundle.get_auxr_3d_data<FPTYPE>(),
+                                                                   out);
+
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+}
+template void PW_Basis::real2recip_gpu<double>(const double* in,  // explicit instantiations: real2recip_gpu, real input
+                                              std::complex<double>* out,
+                                              const bool add,
+                                              const double factor) const;
+template void PW_Basis::real2recip_gpu<float>(const float* in,
+                                              std::complex<float>* out,
+                                              const bool add,
+                                              const float factor) const;
+// Explicit instantiations: real2recip_gpu, complex input.
+template void PW_Basis::real2recip_gpu<double>(const std::complex<double>* in,
+                                              std::complex<double>* out,
+                                              const bool add,
+                                              const double factor) const;
+template void PW_Basis::real2recip_gpu<float>(const std::complex<float>* in,
+                                              std::complex<float>* out,
+                                              const bool add,
+                                              const float factor) const;
+// Explicit instantiations: recip2real_gpu, real output.
+template void PW_Basis::recip2real_gpu<double>(const std::complex<double>* in,
+                                              double* out,
+                                              const bool add,
+                                              const double factor) const;
+template void PW_Basis::recip2real_gpu<float>(const std::complex<float>* in,
+                                              float* out,
+                                              const bool add,
+                                              const float factor) const;
+// Explicit instantiations: recip2real_gpu, complex output.
+template void PW_Basis::recip2real_gpu<double>(const std::complex<double>* in,
+                                              std::complex<double>* out,
+                                              const bool add,
+                                              const double factor) const;
+template void PW_Basis::recip2real_gpu<float>(const std::complex<float>* in,
+                                              std::complex<float>* out,
+                                              const bool add,
+                                              const float factor) const;
+
+#endif
+} // namespace ModulePW
\ No newline at end of file
diff --git a/source/module_basis/module_pw/pw_transform_k.cpp b/source/module_basis/module_pw/pw_transform_k.cpp
index 3d75f07f6f..a709b60429 100644
--- a/source/module_basis/module_pw/pw_transform_k.cpp
+++ b/source/module_basis/module_pw/pw_transform_k.cpp
@@ -357,12 +357,11 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx,
         in,
         this->nrxx);
 
-    this->fft_bundle.fft3D_forward(ctx, this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
+    this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
 
     const int startig = ik * this->npwk_max;
     const int npw_k = this->npwk[ik];
-    set_real_to_recip_output_op<float, base_device::DEVICE_GPU>()(ctx,
-                                                                  npw_k,
+    set_real_to_recip_output_op<float, base_device::DEVICE_GPU>()(npw_k,
                                                                   this->nxyz,
                                                                   add,
                                                                   factor,
@@ -389,12 +388,11 @@ void PW_Basis_K::real_to_recip(const base_device::DEVICE_GPU* ctx,
                                                                           in,
                                                                           this->nrxx);
 
-    this->fft_bundle.fft3D_forward(ctx, this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
+    this->fft_bundle.fft3D_forward(this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
 
     const int startig = ik * this->npwk_max;
     const int npw_k = this->npwk[ik];
-    set_real_to_recip_output_op<double, base_device::DEVICE_GPU>()(ctx,
-                                                                   npw_k,
+    set_real_to_recip_output_op<double, base_device::DEVICE_GPU>()(npw_k,
                                                                    this->nxyz,
                                                                    add,
                                                                    factor,
@@ -424,15 +422,13 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
     const int startig = ik * this->npwk_max;
     const int npw_k = this->npwk[ik];
 
-    set_3d_fft_box_op<float, base_device::DEVICE_GPU>()(ctx,
-                                                        npw_k,
+    set_3d_fft_box_op<float, base_device::DEVICE_GPU>()(npw_k,
                                                         this->ig2ixyz_k + startig,
                                                         in,
                                                         this->fft_bundle.get_auxr_3d_data<float>());
-    this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
+    this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<float>(), this->fft_bundle.get_auxr_3d_data<float>());
 
-    set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(ctx,
-                                                                  this->nrxx,
+    set_recip_to_real_output_op<float, base_device::DEVICE_GPU>()(this->nrxx,
                                                                   add,
                                                                   factor,
                                                                   this->fft_bundle.get_auxr_3d_data<float>(),
@@ -460,15 +456,13 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
     const int startig = ik * this->npwk_max;
     const int npw_k = this->npwk[ik];
 
-    set_3d_fft_box_op<double, base_device::DEVICE_GPU>()(ctx,
-                                                         npw_k,
+    set_3d_fft_box_op<double, base_device::DEVICE_GPU>()(npw_k,
                                                          this->ig2ixyz_k + startig,
                                                          in,
                                                          this->fft_bundle.get_auxr_3d_data<double>());
-    this->fft_bundle.fft3D_backward(ctx, this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
+    this->fft_bundle.fft3D_backward(this->fft_bundle.get_auxr_3d_data<double>(), this->fft_bundle.get_auxr_3d_data<double>());
 
-    set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(ctx,
-                                                                   this->nrxx,
+    set_recip_to_real_output_op<double, base_device::DEVICE_GPU>()(this->nrxx,
                                                                    add,
                                                                    factor,
                                                                    this->fft_bundle.get_auxr_3d_data<double>(),
@@ -476,6 +470,95 @@ void PW_Basis_K::recip_to_real(const base_device::DEVICE_GPU* ctx,
 
     ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
 }
+
+// GPU path, real space -> reciprocal space for k-point `ik`: copy nrxx elements of `in`
+// into the FFT work array, forward-FFT it in place, then fill `out` from the result.
+template <typename FPTYPE>
+void PW_Basis_K::real2recip_gpu(const std::complex<FPTYPE>* in,
+                               std::complex<FPTYPE>* out,
+                               const int ik,
+                               const bool add,
+                               const FPTYPE factor) const
+{
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+
+    using complex_t = std::complex<FPTYPE>;
+    using gpu_t = base_device::DEVICE_GPU;
+    // Work array obtained from the FFT bundle; input and output of the FFT below.
+    auto* const box = this->fft_bundle.get_auxr_3d_data<FPTYPE>();
+    // Offset of this k-point's entries inside ig2ixyz_k, and how many there are.
+    const int ig_offset = ik * this->npwk_max;
+    const int npw = this->npwk[ik];
+
+    base_device::memory::synchronize_memory_op<complex_t, gpu_t, gpu_t>()(box, in, this->nrxx);
+    this->fft_bundle.fft3D_forward(box, box);
+    // `add` and `factor` are forwarded to the output operator unchanged.
+    set_real_to_recip_output_op<FPTYPE, gpu_t>()(
+        npw, this->nxyz, add, factor,
+        this->ig2ixyz_k + ig_offset, box, out);
+
+    ModuleBase::timer::tick(this->classname, "real_to_recip gpu");
+}
+// GPU path, reciprocal space -> real space for k-point `ik`: place `in` into a zeroed
+// FFT work array, backward-FFT it in place, then write nrxx values to `out`.
+template <typename FPTYPE>
+void PW_Basis_K::recip2real_gpu(const std::complex<FPTYPE>* in,
+                               std::complex<FPTYPE>* out,
+                               const int ik,
+                               const bool add,
+                               const FPTYPE factor) const
+{
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+    assert(this->gamma_only == false);
+    assert(this->poolnproc == 1);
+
+    using gpu_t = base_device::DEVICE_GPU;
+    // Work array obtained from the FFT bundle; input and output of the FFT below.
+    auto* const box = this->fft_bundle.get_auxr_3d_data<FPTYPE>();
+    // Offset of this k-point's entries inside ig2ixyz_k, and how many there are.
+    const int ig_offset = ik * this->npwk_max;
+    const int npw = this->npwk[ik];
+
+    // Clear all nxyz entries first; presumably the box fill below touches only npw of them.
+    base_device::memory::set_memory_op<std::complex<FPTYPE>, gpu_t>()(
+        box, 0, this->nxyz);
+    set_3d_fft_box_op<FPTYPE, gpu_t>()(npw, this->ig2ixyz_k + ig_offset, in, box);
+
+    this->fft_bundle.fft3D_backward(box, box);
+
+    // `add` and `factor` are forwarded to the output operator unchanged.
+    set_recip_to_real_output_op<FPTYPE, gpu_t>()(
+        this->nrxx, add, factor, box, out);
+
+    ModuleBase::timer::tick(this->classname, "recip_to_real gpu");
+}
+
+// Explicit instantiations of the GPU transform member templates for float and double.
+// Parameter names are repeated here only for readability.
+// real space -> reciprocal space
+template void PW_Basis_K::real2recip_gpu<float>(
+    const std::complex<float>* in,
+    std::complex<float>* out,
+    const int ik, const bool add, const float factor) const;
+
+template void PW_Basis_K::real2recip_gpu<double>(
+    const std::complex<double>* in,
+    std::complex<double>* out,
+    const int ik, const bool add, const double factor) const;
+
+// reciprocal space -> real space
+template void PW_Basis_K::recip2real_gpu<float>(
+    const std::complex<float>* in,
+    std::complex<float>* out,
+    const int ik, const bool add, const float factor) const;
+
+template void PW_Basis_K::recip2real_gpu<double>(
+    const std::complex<double>* in,
+    std::complex<double>* out,
+    const int ik, const bool add, const double factor) const;
+
 #endif
 
 template void PW_Basis_K::real2recip<float>(const float* in,
diff --git a/source/module_basis/module_pw/pw_transform_k_dsp.cpp b/source/module_basis/module_pw/pw_transform_k_dsp.cpp
index b292e25f0a..56ba26eb2a 100644
--- a/source/module_basis/module_pw/pw_transform_k_dsp.cpp
+++ b/source/module_basis/module_pw/pw_transform_k_dsp.cpp
@@ -27,13 +27,11 @@ void PW_Basis_K::real2recip_dsp(const std::complex<FPTYPE>* in,
 
     // 3d fft
     this->fft_bundle.resource_handler(1);
-    this->fft_bundle.fft3D_forward(gpux, 
-                                   auxr, 
+    this->fft_bundle.fft3D_forward(auxr, 
                                    auxr);
     this->fft_bundle.resource_handler(0);
     // copy the result from the auxr to the out ,while consider the add
-    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>()(ctx,
-                                                                   npw_k,
+    set_real_to_recip_output_op<FPTYPE, base_device::DEVICE_CPU>()(npw_k,
                                                                    this->nxyz,
                                                                    add,
                                                                    factor,
@@ -58,10 +56,10 @@ void PW_Basis_K::recip2real_dsp(const std::complex<FPTYPE>* in,
     const int startig = ik * this->npwk_max;
     const int npw_k = this->npwk[ik];
     // copy the mapping form the type of stick to the 3dfft
-    set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(ctx, npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr);
+    set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(npw_k, this->ig2ixyz_k_cpu.data() + startig, in, auxr);
     // use 3d fft backward
     this->fft_bundle.resource_handler(1);
-    this->fft_bundle.fft3D_backward(gpux, auxr, auxr);
+    this->fft_bundle.fft3D_backward(auxr, auxr);
     this->fft_bundle.resource_handler(0);
     if (add)
     {
@@ -107,10 +105,10 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
     const int npw_k = this->npwk[ik];
 
     // copy the mapping form the type of stick to the 3dfft
-    set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(ctx, npw_k, this->ig2ixyz_k_cpu.data() + startig, input, auxr);
+    set_3d_fft_box_op<double, base_device::DEVICE_CPU>()(npw_k, this->ig2ixyz_k_cpu.data() + startig, input, auxr);
 
     // use 3d fft backward
-    this->fft_bundle.fft3D_backward(gpux, auxr, auxr);
+    this->fft_bundle.fft3D_backward(auxr, auxr);
 
     for (int ir = 0; ir < size; ir++)
     {
@@ -118,10 +116,9 @@ void PW_Basis_K::convolution(const base_device::DEVICE_CPU* ctx,
     }
 
     // 3d fft
-    this->fft_bundle.fft3D_forward(gpux, auxr, auxr);
+    this->fft_bundle.fft3D_forward(auxr, auxr);
     // copy the result from the auxr to the out ,while consider the add
-    set_real_to_recip_output_op<double, base_device::DEVICE_CPU>()(ctx,
-                                                                   npw_k,
+    set_real_to_recip_output_op<double, base_device::DEVICE_CPU>()(npw_k,
                                                                    this->nxyz,
                                                                    add,
                                                                    factor,
diff --git a/source/module_elecstate/elecstate_pw.cpp b/source/module_elecstate/elecstate_pw.cpp
index 739e497515..ef7a168752 100644
--- a/source/module_elecstate/elecstate_pw.cpp
+++ b/source/module_elecstate/elecstate_pw.cpp
@@ -443,7 +443,6 @@ void ElecStatePW<T, Device>::add_usrho(const psi::Psi<T, Device>& psi)
     {
         this->addusdens_g(becsum, rhog);
     }
-
     // transform back to real space using dense grids
     if (PARAM.globalv.double_grid || PARAM.globalv.use_uspp)
     {
diff --git a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp
index 54e1a052be..c931d61f2f 100644
--- a/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp
+++ b/source/module_hamilt_pw/hamilt_pwdft/operator_pw/veff_pw.cpp
@@ -62,7 +62,7 @@ void Veff<OperatorPW<T, Device>>::act(
         if (npol == 1)
         {
             // wfcpw->recip2real(tmpsi_in, porter, this->ik);
-            wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik);
+            wfcpw->recip_to_real<T,Device>(tmpsi_in, this->porter, this->ik);
             // NOTICE: when MPI threads are larger than number of Z grids
             // veff would contain nothing, and nothing should be done in real space
             // but the 3DFFT can not be skipped, it will cause hanging
@@ -76,7 +76,7 @@ void Veff<OperatorPW<T, Device>>::act(
                 // }
             }
             // wfcpw->real2recip(porter, tmhpsi, this->ik, true);
-            wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true);
+            wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true);
             // wfcpw->convolution(this->ctx,
             // this->ik,
             // this->veff_col,
@@ -89,8 +89,8 @@ void Veff<OperatorPW<T, Device>>::act(
         {
             // T *porter1 = new T[wfcpw->nmaxgr];
             // fft to real space and doing things.
-            wfcpw->recip_to_real(this->ctx, tmpsi_in, this->porter, this->ik);
-            wfcpw->recip_to_real(this->ctx, tmpsi_in + max_npw, this->porter1, this->ik);
+            wfcpw->recip_to_real<T,Device>(tmpsi_in, this->porter, this->ik);
+            wfcpw->recip_to_real<T,Device>(tmpsi_in + max_npw, this->porter1, this->ik);
             if(this->veff_col != 0)
             {
                 /// denghui added at 20221109
@@ -114,8 +114,8 @@ void Veff<OperatorPW<T, Device>>::act(
                 // }
             }
             // (3) fft back to G space.
-            wfcpw->real_to_recip(this->ctx, this->porter, tmhpsi, this->ik, true);
-            wfcpw->real_to_recip(this->ctx, this->porter1, tmhpsi + max_npw, this->ik, true);
+            wfcpw->real_to_recip<T,Device>(this->porter, tmhpsi, this->ik, true);
+            wfcpw->real_to_recip<T,Device>(this->porter1, tmhpsi + max_npw, this->ik, true);
         }
         tmhpsi += max_npw * npol;
         tmpsi_in += max_npw * npol;