Skip to content

Feature: Optimize dngvd_op with CPU/GPU Branching Based on nstart #5919

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 9 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion source/module_esolver/esolver_ks_pw.cpp
Original file line number Diff line number Diff line change
@@ -556,7 +556,8 @@ void ESolver_KS_PW<T, Device>::hamilt2density_single(UnitCell& ucell,
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
hsolver::DiagoIterAssist<T, Device>::need_subspace);
hsolver::DiagoIterAssist<T, Device>::need_subspace,
PARAM.inp.use_k_continuity);

hsolver_pw_obj.solve(this->p_hamilt,
this->kspw_psi[0],
3 changes: 2 additions & 1 deletion source/module_esolver/esolver_sdft_pw.cpp
Original file line number Diff line number Diff line change
@@ -175,7 +175,8 @@ void ESolver_SDFT_PW<T, Device>::hamilt2density_single(UnitCell& ucell, int iste
hsolver::DiagoIterAssist<T, Device>::SCF_ITER,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<T, Device>::PW_DIAG_THR,
hsolver::DiagoIterAssist<T, Device>::need_subspace);
hsolver::DiagoIterAssist<T, Device>::need_subspace,
PARAM.inp.use_k_continuity);

hsolver_pw_sdft_obj.solve(ucell,
this->p_hamilt,
Original file line number Diff line number Diff line change
@@ -428,7 +428,8 @@ void spinconstrain::SpinConstrain<std::complex<double>>::update_psi_charge(const
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::PW_DIAG_THR,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::need_subspace);
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_CPU>::need_subspace,
PARAM.inp.use_k_continuity);

hsolver_pw_obj.solve(hamilt_t,
psi_t[0],
@@ -503,7 +504,8 @@ void spinconstrain::SpinConstrain<std::complex<double>>::update_psi_charge(const
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_GPU>::SCF_ITER,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_GPU>::PW_DIAG_NMAX,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_GPU>::PW_DIAG_THR,
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_GPU>::need_subspace);
hsolver::DiagoIterAssist<std::complex<double>, base_device::DEVICE_GPU>::need_subspace,
PARAM.inp.use_k_continuity);

hsolver_pw_obj.solve(hamilt_t,
psi_t[0],
119 changes: 117 additions & 2 deletions source/module_hsolver/hsolver_pw.cpp
Original file line number Diff line number Diff line change
@@ -280,9 +280,18 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
std::vector<Real> eigenvalues(this->wfc_basis->nks * psi.get_nbands(), 0.0);
ethr_band.resize(psi.get_nbands(), this->diag_thr);

// using k-point continuity
if (use_k_continuity) {
build_k_neighbors();
}

static int count = 0;

/// Loop over k points for solve Hamiltonian to charge density
for (int ik = 0; ik < this->wfc_basis->nks; ++ik)
for (int i = 0; i < this->wfc_basis->nks; ++i)
{
const int ik = use_k_continuity ? k_order[i] : i;
// ModuleBase::timer::tick("HsolverPW", "k_point: " + std::to_string(ik));
/// update H(k) for each k point
pHamilt->updateHk(ik);

@@ -293,6 +302,11 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
/// update psi pointer for each k point
psi.fix_k(ik);

// If using k-point continuity and not first k-point, propagate from parent
if (use_k_continuity && ik > 0 && psi.is_first_iter) {
propagate_psi(psi, k_parent[ik], ik);
}

// template add precondition calculating here
update_precondition(precondition, ik, this->wfc_basis->npwk[ik], Real(pes->pot->get_vl_of_0()));

@@ -319,8 +333,10 @@ void HSolverPW<T, Device>::solve(hamilt::Hamilt<T, Device>* pHamilt,
<< " ; where current threshold is: " << this->diag_thr << " . " << std::endl;
DiagoIterAssist<T, Device>::avg_iter = 0.0;
}
// ModuleBase::timer::tick("HsolverPW", "k_point: " + std::to_string(ik));
/// calculate the contribution of Psi for charge density rho
}
psi.is_first_iter = false;
// END Loop over k points

// copy eigenvalues to ekb in ElecState
@@ -666,11 +682,110 @@ void HSolverPW<T, Device>::output_iterInfo()
}
}

template <typename T, typename Device>
void HSolverPW<T, Device>::build_k_neighbors() {
const int nk = this->wfc_basis->nks;
kvecs_c.resize(nk);
k_order.clear();
k_order.reserve(nk);

/**
* @brief Structure representing a K-point with its vector, index, and norm.
*/
struct KPoint {
ModuleBase::Vector3<double> kvec; ///< K-vector of the K-point.
int index; ///< Index of the K-point.
double norm; ///< Norm of the K-vector.

KPoint(const ModuleBase::Vector3<double>& v, int i) :
kvec(v), index(i), norm(v.norm()) {}
};

// Build k-point list
std::vector<KPoint> klist;
for (int ik = 0; ik < nk; ++ik) {
kvecs_c[ik] = this->wfc_basis->kvec_c[ik];
klist.push_back(KPoint(kvecs_c[ik], ik));
}

// Sort k-points by distance from origin
std::sort(klist.begin(), klist.end(),
[](const KPoint& a, const KPoint& b) {
return a.norm < b.norm;
});

// Build parent-child relationships
k_order.push_back(klist[0].index);

// Find nearest processed k-point as parent for each k-point
for (int i = 1; i < nk; ++i) {
int current_k = klist[i].index;
double min_dist = 1e10;
int parent = -1;

// find the nearest k-point as parent
for (int j = 0; j < k_order.size(); ++j) {
int processed_k = k_order[j];
double dist = (kvecs_c[current_k] - kvecs_c[processed_k]).norm2();
if (dist < min_dist) {
min_dist = dist;
parent = processed_k;
}
}

k_parent[current_k] = parent;
k_order.push_back(current_k);
}
}

template <typename T, typename Device>
void HSolverPW<T, Device>::propagate_psi(psi::Psi<T, Device>& psi, const int from_ik, const int to_ik) {
const int nbands = psi.get_nbands();
const int npwk = this->wfc_basis->npwk[to_ik];

// Get k-point difference
ModuleBase::Vector3<double> dk = kvecs_c[to_ik] - kvecs_c[from_ik];

/**
* @brief Allocates memory for the porter used in FFT operations.
*/
T* porter = nullptr;
resmem_complex_op()(this->ctx, porter, this->wfc_basis->nmaxgr, "HSolverPW::porter");

// Process each band
for (int ib = 0; ib < nbands; ib++)
{
// Fix current k-point and band
// psi.fix_k(from_ik);

// FFT to real space
// this->wfc_basis->recip_to_real(this->ctx, psi.get_pointer(ib), porter, from_ik);
this->wfc_basis->recip_to_real(this->ctx, &psi(from_ik, ib, 0), porter, from_ik);

// Apply phase factor
// // TODO: Check how to get the r vector
// ModuleBase::Vector3<double> r = this->wfc_basis->get_ir2r(ir);
// double phase = this->wfc_basis->tpiba * (dk.x * r.x + dk.y * r.y + dk.z * r.z);
// psi_real[ir] *= std::exp(std::complex<double>(0.0, phase));
// }

// Fix k-point for target
// psi.fix_k(to_ik);

// FFT back to reciprocal space
// this->wfc_basis->real_to_recip(this->ctx, porter, psi.get_pointer(ib), to_ik, true);
this->wfc_basis->real_to_recip(this->ctx, porter, &psi(to_ik, ib, 0), to_ik);
}

// Clean up porter
delmem_complex_op()(this->ctx, porter);
}

template class HSolverPW<std::complex<float>, base_device::DEVICE_CPU>;
template class HSolverPW<std::complex<double>, base_device::DEVICE_CPU>;
#if ((defined __CUDA) || (defined __ROCM))
template class HSolverPW<std::complex<float>, base_device::DEVICE_GPU>;
template class HSolverPW<std::complex<double>, base_device::DEVICE_GPU>;
#endif

} // namespace hsolver
} // namespace hsolver
44 changes: 42 additions & 2 deletions source/module_hsolver/hsolver_pw.h
Original file line number Diff line number Diff line change
@@ -6,6 +6,8 @@
#include "module_base/macros.h"
#include "module_basis/module_pw/pw_basis_k.h"
#include "module_psi/wavefunc.h"
#include <unordered_map>
#include "module_base/memory.h"

namespace hsolver
{
@@ -18,6 +20,9 @@ class HSolverPW
// return T if T is real type(float, double),
// otherwise return the real type of T(complex<float>, complex<double>)
using Real = typename GetTypeReal<T>::type;
using resmem_complex_op = base_device::memory::resize_memory_op<T, Device>;
using delmem_complex_op = base_device::memory::delete_memory_op<T, Device>;
using setmem_complex_op = base_device::memory::set_memory_op<T, Device>;

public:
HSolverPW(ModulePW::PW_Basis_K* wfc_basis_in,
@@ -30,10 +35,12 @@ class HSolverPW
const int scf_iter_in,
const int diag_iter_max_in,
const double diag_thr_in,
const bool need_subspace_in)
const bool need_subspace_in,
const bool use_k_continuity_in = true)
: wfc_basis(wfc_basis_in), calculation_type(calculation_type_in), basis_type(basis_type_in), method(method_in),
use_paw(use_paw_in), use_uspp(use_uspp_in), nspin(nspin_in), scf_iter(scf_iter_in),
diag_iter_max(diag_iter_max_in), diag_thr(diag_thr_in), need_subspace(need_subspace_in){};
diag_iter_max(diag_iter_max_in), diag_thr(diag_thr_in), need_subspace(need_subspace_in),
use_k_continuity(use_k_continuity_in){};

/// @brief solve function for pw
/// @param pHamilt interface to hamilt
@@ -51,6 +58,7 @@ class HSolverPW
const double tpiba,
const int nat);


protected:
// diago caller
void hamiltSolvePsiK(hamilt::Hamilt<T, Device>* hm,
@@ -79,6 +87,8 @@ class HSolverPW

const bool need_subspace; // for cg or dav_subspace

const bool use_k_continuity;

protected:
Device* ctx = {};

@@ -99,6 +109,36 @@ class HSolverPW

void paw_func_after_kloop(psi::Psi<T, Device>& psi, elecstate::ElecState* pes,const double tpiba,const int nat);
#endif

// K-point continuity related members
/**
* @brief Order of K-points.
*/
std::vector<int> k_order;

/**
* @brief Parent-child relationships for K-points.
*/
std::unordered_map<int, int> k_parent;

/**
* @brief K-vectors for K-points in continuous space.
*/
std::vector<ModuleBase::Vector3<double>> kvecs_c;

/**
* @brief Builds the K-neighbors for the K-points.
*/
void build_k_neighbors();

/**
* @brief Propagates the wave function from one K-point to another.
*
* @param psi The wave function object.
* @param from_ik The index of the starting K-point.
* @param to_ik The index of the target K-point.
*/
void propagate_psi(psi::Psi<T, Device>& psi, const int from_ik, const int to_ik);
};

} // namespace hsolver
6 changes: 4 additions & 2 deletions source/module_hsolver/hsolver_pw_sdft.h
Original file line number Diff line number Diff line change
@@ -26,7 +26,8 @@ class HSolverPW_SDFT : public HSolverPW<T, Device>
const int scf_iter_in,
const int diag_iter_max_in,
const double diag_thr_in,
const bool need_subspace_in)
const bool need_subspace_in,
const bool use_k_continuity_in = false)
: HSolverPW<T, Device>(wfc_basis_in,
calculation_type_in,
basis_type_in,
@@ -37,7 +38,8 @@ class HSolverPW_SDFT : public HSolverPW<T, Device>
scf_iter_in,
diag_iter_max_in,
diag_thr_in,
need_subspace_in)
need_subspace_in,
use_k_continuity_in)
{
stoiter.init(pkv, wfc_basis_in, stowf, stoche, p_hamilt_sto);
}
Loading
Loading