From 1046362c661db17ad6e3471dada80a655cc89017 Mon Sep 17 00:00:00 2001 From: Kevin Boyd Date: Fri, 5 Dec 2025 13:56:03 -0500 Subject: [PATCH] Rework MMFF loop to warp iteration --- src/forcefields/mmff_kernels_device.cuh | 451 +++++++++++++++--------- 1 file changed, 290 insertions(+), 161 deletions(-) diff --git a/src/forcefields/mmff_kernels_device.cuh b/src/forcefields/mmff_kernels_device.cuh index 8752d58..94590b4 100644 --- a/src/forcefields/mmff_kernels_device.cuh +++ b/src/forcefields/mmff_kernels_device.cuh @@ -16,6 +16,8 @@ #ifndef NVMOLKIT_MMFF_KERNELS_DEVICE_CUH #define NVMOLKIT_MMFF_KERNELS_DEVICE_CUH +#include + #include "kernel_utils.cuh" using namespace nvMolKit::FFKernelUtils; @@ -672,92 +674,155 @@ static __device__ __inline__ double molEnergy(const EnergyForceContribsDevicePtr double energy = 0.0; - const auto& [idx1s, idx2s, r0s, kbs] = terms.bondTerms; - const int bondStart = systemIndices.bondTermStarts[molIdx]; - const int bondEnd = systemIndices.bondTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = bondStart + tid; i < bondEnd; i += stride) { - const int localIdx1 = idx1s[i] - atomStart; - const int localIdx2 = idx2s[i] - atomStart; - energy += bondStretchEnergy(molCoords, localIdx1, localIdx2, r0s[i], kbs[i]); - } - + namespace cg = cooperative_groups; + constexpr int WARP_SIZE = 32; + auto tile32 = cg::tiled_partition(cg::this_thread_block()); + const int laneId = tile32.thread_rank(); + const int warpId = tile32.meta_group_rank(); + const int numWarps = tile32.meta_group_size(); + + // Get term ranges + const int bondStart = systemIndices.bondTermStarts[molIdx]; + const int bondEnd = systemIndices.bondTermStarts[molIdx + 1]; + const int angleStart = systemIndices.angleTermStarts[molIdx]; + const int angleEnd = systemIndices.angleTermStarts[molIdx + 1]; + const int bendStart = systemIndices.bendTermStarts[molIdx]; + const int bendEnd = systemIndices.bendTermStarts[molIdx + 1]; + const int oopStart = systemIndices.oopTermStarts[molIdx]; + const int oopEnd = systemIndices.oopTermStarts[molIdx + 1]; + const int torsionStart = systemIndices.torsionTermStarts[molIdx]; + const int torsionEnd = systemIndices.torsionTermStarts[molIdx + 1]; + const int vdwStart = systemIndices.vdwTermStarts[molIdx]; + const int vdwEnd = systemIndices.vdwTermStarts[molIdx + 1]; + const int eleStart = systemIndices.eleTermStarts[molIdx]; + const int eleEnd = systemIndices.eleTermStarts[molIdx + 1]; + + const int numBond = bondEnd - bondStart; + const int numAngle = angleEnd - angleStart; + const int numBend = bendEnd - bendStart; + const int numOop = oopEnd - oopStart; + const int numTorsion = torsionEnd - torsionStart; + const int numVdw = vdwEnd - vdwStart; + const int numEle = eleEnd - eleStart; + + // Get term data + const auto& [idx1s, idx2s, r0s, kbs] = terms.bondTerms; const auto& [a_idx1s, a_idx2s, a_idx3s, theta0s, kas, isLinears] = terms.angleTerms; - const int angleStart = systemIndices.angleTermStarts[molIdx]; - const int angleEnd = systemIndices.angleTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = angleStart + tid; i < angleEnd; i += stride) { - const int localIdx1 = a_idx1s[i] - atomStart; - const int localIdx2 = a_idx2s[i] - atomStart; - const int localIdx3 = a_idx3s[i] - atomStart; - const bool isLinear = static_cast(isLinears[i]); - energy += angleBendEnergy(molCoords, localIdx1, localIdx2, localIdx3, theta0s[i], kas[i], isLinear); - } - const auto& [bs_idx1s, bs_idx2s, bs_idx3s, bs_theta0s, restLen1s, restLen2s, forceConst1s, forceConst2s] = terms.bendTerms; - const int bendStart = systemIndices.bendTermStarts[molIdx]; - const int bendEnd = systemIndices.bendTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = bendStart + tid; i < bendEnd; i += stride) { - const int localIdx1 = bs_idx1s[i] - atomStart; - const int localIdx2 = bs_idx2s[i] - atomStart; - const int localIdx3 = bs_idx3s[i] - atomStart; - energy += bendStretchEnergy(molCoords, + const auto& [o_idx1s, o_idx2s, o_idx3s, o_idx4s, koops] = terms.oopTerms; + const auto& [t_idx1s, t_idx2s, t_idx3s, t_idx4s, V1s, V2s, V3s] = terms.torsionTerms; + const auto& [v_idx1s, v_idx2s, R_ij_stars, wellDepths] = terms.vdwTerms; + const auto& [e_idx1s, e_idx2s, chargeTerms, dielModels, is1_4s] = terms.eleTerms; + + // Calculate number of warps needed for each term type + const int warpsForBond = (numBond + WARP_SIZE - 1) / WARP_SIZE; + const int warpsForAngle = (numAngle + WARP_SIZE - 1) / WARP_SIZE; + const int warpsForBend = (numBend + WARP_SIZE - 1) / WARP_SIZE; + const int warpsForOop = (numOop + WARP_SIZE - 1) / WARP_SIZE; + const int warpsForTorsion = (numTorsion + WARP_SIZE - 1) / WARP_SIZE; + const int warpsForVdw = (numVdw + WARP_SIZE - 1) / WARP_SIZE; + const int warpsForEle = (numEle + WARP_SIZE - 1) / WARP_SIZE; + const int totalWarpsNeeded = + warpsForBond + warpsForAngle + warpsForBend + warpsForOop + warpsForTorsion + warpsForVdw + warpsForEle; + + // Each warp processes chunks in round-robin fashion + for (int chunkIdx = warpId; chunkIdx < totalWarpsNeeded; chunkIdx += numWarps) { + if (chunkIdx < warpsForBond) { + // Bond terms + const int baseIdx = chunkIdx * WARP_SIZE; + const int termIdx = bondStart + baseIdx + laneId; + if (baseIdx + laneId < numBond) { + const int localIdx1 = idx1s[termIdx] - atomStart; + const int localIdx2 = idx2s[termIdx] - atomStart; + energy += bondStretchEnergy(molCoords, localIdx1, localIdx2, r0s[termIdx], kbs[termIdx]); + } + } else if (chunkIdx < warpsForBond + warpsForAngle) { + // Angle terms + const int warpOffset = chunkIdx - warpsForBond; + const int baseIdx = warpOffset * WARP_SIZE; + const int termIdx = angleStart + baseIdx + laneId; + if (baseIdx + laneId < numAngle) { + const int localIdx1 = a_idx1s[termIdx] - atomStart; + const int localIdx2 = a_idx2s[termIdx] - atomStart; + const int localIdx3 = a_idx3s[termIdx] - atomStart; + const bool isLinear = static_cast(isLinears[termIdx]); + energy += angleBendEnergy(molCoords, localIdx1, localIdx2, localIdx3, theta0s[termIdx], kas[termIdx], isLinear); + } + } else if (chunkIdx < warpsForBond + warpsForAngle + warpsForBend) { + // Bend (stretch-bend) terms + const int warpOffset = chunkIdx - warpsForBond - warpsForAngle; + const int baseIdx = warpOffset * WARP_SIZE; + const int termIdx = bendStart + baseIdx + laneId; + if (baseIdx + laneId < numBend) { + const int localIdx1 = bs_idx1s[termIdx] - atomStart; + const int localIdx2 = bs_idx2s[termIdx] - atomStart; + const int localIdx3 = bs_idx3s[termIdx] - atomStart; + energy += bendStretchEnergy(molCoords, + localIdx1, + localIdx2, + localIdx3, + bs_theta0s[termIdx], + restLen1s[termIdx], + restLen2s[termIdx], + forceConst1s[termIdx], + forceConst2s[termIdx]); + } + } else if (chunkIdx < warpsForBond + warpsForAngle + warpsForBend + warpsForOop) { + // OOP terms + const int warpOffset = chunkIdx - warpsForBond - warpsForAngle - warpsForBend; + const int baseIdx = warpOffset * WARP_SIZE; + const int termIdx = oopStart + baseIdx + laneId; + if (baseIdx + laneId < numOop) { + const int localIdx1 = o_idx1s[termIdx] - atomStart; + const int localIdx2 = o_idx2s[termIdx] - atomStart; + const int localIdx3 = o_idx3s[termIdx] - atomStart; + const int localIdx4 = o_idx4s[termIdx] - atomStart; + energy += oopBendEnergy(molCoords, localIdx1, localIdx2, localIdx3, localIdx4, koops[termIdx]); + } + } else if (chunkIdx < warpsForBond + warpsForAngle + warpsForBend + warpsForOop + warpsForTorsion) { + // Torsion terms + const int warpOffset = chunkIdx - warpsForBond - warpsForAngle - warpsForBend - warpsForOop; + const int baseIdx = warpOffset * WARP_SIZE; + const int termIdx = torsionStart + baseIdx + laneId; + if (baseIdx + laneId < numTorsion) { + const int localIdx1 = t_idx1s[termIdx] - atomStart; + const int localIdx2 = t_idx2s[termIdx] - atomStart; + const int localIdx3 = t_idx3s[termIdx] - atomStart; + const int localIdx4 = t_idx4s[termIdx] - atomStart; + energy += torsionEnergy(molCoords, localIdx1, localIdx2, localIdx3, - bs_theta0s[i], - restLen1s[i], - restLen2s[i], - forceConst1s[i], - forceConst2s[i]); - } - - const auto& [o_idx1s, o_idx2s, o_idx3s, o_idx4s, koops] = terms.oopTerms; - const int oopStart = systemIndices.oopTermStarts[molIdx]; - const int oopEnd = systemIndices.oopTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = oopStart + tid; i < oopEnd; i += stride) { - const int localIdx1 = o_idx1s[i] - atomStart; - const int localIdx2 = o_idx2s[i] - atomStart; - const int localIdx3 = o_idx3s[i] - atomStart; - const int localIdx4 = o_idx4s[i] - atomStart; - energy += oopBendEnergy(molCoords, localIdx1, localIdx2, localIdx3, localIdx4, koops[i]); - } - - const auto& [t_idx1s, t_idx2s, t_idx3s, t_idx4s, V1s, V2s, V3s] = terms.torsionTerms; - const int torsionStart = systemIndices.torsionTermStarts[molIdx]; - const int torsionEnd = systemIndices.torsionTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = torsionStart + tid; i < torsionEnd; i += stride) { - const int localIdx1 = t_idx1s[i] - atomStart; - const int localIdx2 = t_idx2s[i] - atomStart; - const int localIdx3 = t_idx3s[i] - atomStart; - const int localIdx4 = t_idx4s[i] - atomStart; - energy += torsionEnergy(molCoords, localIdx1, localIdx2, localIdx3, localIdx4, V1s[i], V2s[i], V3s[i]); - } - - const auto& [v_idx1s, v_idx2s, R_ij_stars, wellDepths] = terms.vdwTerms; - const int vdwStart = systemIndices.vdwTermStarts[molIdx]; - const int vdwEnd = systemIndices.vdwTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = vdwStart + tid; i < vdwEnd; i += stride) { - const int localIdx1 = v_idx1s[i] - atomStart; - const int localIdx2 = v_idx2s[i] - atomStart; - energy += vdwEnergy(molCoords, localIdx1, localIdx2, R_ij_stars[i], wellDepths[i]); - } - - const auto& [e_idx1s, e_idx2s, chargeTerms, dielModels, is1_4s] = terms.eleTerms; - const int eleStart = systemIndices.eleTermStarts[molIdx]; - const int eleEnd = systemIndices.eleTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = eleStart + tid; i < eleEnd; i += stride) { - const int localIdx1 = e_idx1s[i] - atomStart; - const int localIdx2 = e_idx2s[i] - atomStart; - const int dielModel = static_cast(dielModels[i]); - const bool is14 = is1_4s[i] > 0; - energy += eleEnergy(molCoords, localIdx1, localIdx2, chargeTerms[i], dielModel, is14); + localIdx4, + V1s[termIdx], + V2s[termIdx], + V3s[termIdx]); + } + } else if (chunkIdx < warpsForBond + warpsForAngle + warpsForBend + warpsForOop + warpsForTorsion + warpsForVdw) { + // VDW terms + const int warpOffset = chunkIdx - warpsForBond - warpsForAngle - warpsForBend - warpsForOop - warpsForTorsion; + const int baseIdx = warpOffset * WARP_SIZE; + const int termIdx = vdwStart + baseIdx + laneId; + if (baseIdx + laneId < numVdw) { + const int localIdx1 = v_idx1s[termIdx] - atomStart; + const int localIdx2 = v_idx2s[termIdx] - atomStart; + energy += vdwEnergy(molCoords, localIdx1, localIdx2, R_ij_stars[termIdx], wellDepths[termIdx]); + } + } else { + // ELE terms + const int warpOffset = + chunkIdx - warpsForBond - warpsForAngle - warpsForBend - warpsForOop - warpsForTorsion - warpsForVdw; + const int baseIdx = warpOffset * WARP_SIZE; + const int termIdx = eleStart + baseIdx + laneId; + if (baseIdx + laneId < numEle) { + const int localIdx1 = e_idx1s[termIdx] - atomStart; + const int localIdx2 = e_idx2s[termIdx] - atomStart; + const int dielModel = static_cast(dielModels[termIdx]); + const bool is14 = is1_4s[termIdx] > 0; + energy += eleEnergy(molCoords, localIdx1, localIdx2, chargeTerms[termIdx], dielModel, is14); + } + } } return energy; @@ -773,92 +838,156 @@ static __device__ __inline__ void molGrad(const EnergyForceContribsDevicePtr& te const int atomStart = systemIndices.atomStarts[molIdx]; const double* molCoords = coords + atomStart * 3; - const auto& [idx1s, idx2s, r0s, kbs] = terms.bondTerms; - const int bondStart = systemIndices.bondTermStarts[molIdx]; - const int bondEnd = systemIndices.bondTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = bondStart + tid; i < bondEnd; i += stride) { - const int localIdx1 = idx1s[i] - atomStart; - const int localIdx2 = idx2s[i] - atomStart; - bondStretchGrad(molCoords, localIdx1, localIdx2, r0s[i], kbs[i], grad); - } - + namespace cg = cooperative_groups; + constexpr int WARP_SIZE = 32; + auto tile32 = cg::tiled_partition(cg::this_thread_block()); + const int laneId = tile32.thread_rank(); + const int warpId = tile32.meta_group_rank(); + const int numWarps = tile32.meta_group_size(); + + // Get term ranges + const int bondStart = systemIndices.bondTermStarts[molIdx]; + const int bondEnd = systemIndices.bondTermStarts[molIdx + 1]; + const int angleStart = systemIndices.angleTermStarts[molIdx]; + const int angleEnd = systemIndices.angleTermStarts[molIdx + 1]; + const int bendStart = systemIndices.bendTermStarts[molIdx]; + const int bendEnd = systemIndices.bendTermStarts[molIdx + 1]; + const int oopStart = systemIndices.oopTermStarts[molIdx]; + const int oopEnd = systemIndices.oopTermStarts[molIdx + 1]; + const int torsionStart = systemIndices.torsionTermStarts[molIdx]; + const int torsionEnd = systemIndices.torsionTermStarts[molIdx + 1]; + const int vdwStart = systemIndices.vdwTermStarts[molIdx]; + const int vdwEnd = systemIndices.vdwTermStarts[molIdx + 1]; + const int eleStart = systemIndices.eleTermStarts[molIdx]; + const int eleEnd = systemIndices.eleTermStarts[molIdx + 1]; + + const int numBond = bondEnd - bondStart; + const int numAngle = angleEnd - angleStart; + const int numBend = bendEnd - bendStart; + const int numOop = oopEnd - oopStart; + const int numTorsion = torsionEnd - torsionStart; + const int numVdw = vdwEnd - vdwStart; + const int numEle = eleEnd - eleStart; + + // Get term data + const auto& [idx1s, idx2s, r0s, kbs] = terms.bondTerms; const auto& [a_idx1s, a_idx2s, a_idx3s, theta0s, kas, isLinears] = terms.angleTerms; - const int angleStart = systemIndices.angleTermStarts[molIdx]; - const int angleEnd = systemIndices.angleTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = angleStart + tid; i < angleEnd; i += stride) { - const int localIdx1 = a_idx1s[i] - atomStart; - const int localIdx2 = a_idx2s[i] - atomStart; - const int localIdx3 = a_idx3s[i] - atomStart; - const bool isLinear = static_cast(isLinears[i]); - angleBendGrad(localIdx1, localIdx2, localIdx3, theta0s[i], kas[i], isLinear, molCoords, grad); - } - const auto& [bs_idx1s, bs_idx2s, bs_idx3s, bs_theta0s, restLen1s, restLen2s, forceConst1s, forceConst2s] = terms.bendTerms; - const int bendStart = systemIndices.bendTermStarts[molIdx]; - const int bendEnd = systemIndices.bendTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = bendStart + tid; i < bendEnd; i += stride) { - const int localIdx1 = bs_idx1s[i] - atomStart; - const int localIdx2 = bs_idx2s[i] - atomStart; - const int localIdx3 = bs_idx3s[i] - atomStart; - bendStretchGrad(molCoords, - localIdx1, - localIdx2, - localIdx3, - bs_theta0s[i], - restLen1s[i], - restLen2s[i], - forceConst1s[i], - forceConst2s[i], - grad); - } - - const auto& [o_idx1s, o_idx2s, o_idx3s, o_idx4s, koops] = terms.oopTerms; - const int oopStart = systemIndices.oopTermStarts[molIdx]; - const int oopEnd = systemIndices.oopTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = oopStart + tid; i < oopEnd; i += stride) { - const int localIdx1 = o_idx1s[i] - atomStart; - const int localIdx2 = o_idx2s[i] - atomStart; - const int localIdx3 = o_idx3s[i] - atomStart; - const int localIdx4 = o_idx4s[i] - atomStart; - rdkit_ports::oopGrad(molCoords, localIdx1, localIdx2, localIdx3, localIdx4, koops[i], grad); - } - + const auto& [o_idx1s, o_idx2s, o_idx3s, o_idx4s, koops] = terms.oopTerms; const auto& [t_idx1s, t_idx2s, t_idx3s, t_idx4s, V1s, V2s, V3s] = terms.torsionTerms; - const int torsionStart = systemIndices.torsionTermStarts[molIdx]; - const int torsionEnd = systemIndices.torsionTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = torsionStart + tid; i < torsionEnd; i += stride) { - const int localIdx1 = t_idx1s[i] - atomStart; - const int localIdx2 = t_idx2s[i] - atomStart; - const int localIdx3 = t_idx3s[i] - atomStart; - const int localIdx4 = t_idx4s[i] - atomStart; - rdkit_ports::torsionGrad(molCoords, localIdx1, localIdx2, localIdx3, localIdx4, V1s[i], V2s[i], V3s[i], grad); - } - - const auto& [v_idx1s, v_idx2s, R_ij_stars, wellDepths] = terms.vdwTerms; - const int vdwStart = systemIndices.vdwTermStarts[molIdx]; - const int vdwEnd = systemIndices.vdwTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = vdwStart + tid; i < vdwEnd; i += stride) { - const int localIdx1 = v_idx1s[i] - atomStart; - const int localIdx2 = v_idx2s[i] - atomStart; - rdkit_ports::vDWGrad(molCoords, localIdx1, localIdx2, R_ij_stars[i], wellDepths[i], grad); - } - + const auto& [v_idx1s, v_idx2s, R_ij_stars, wellDepths] = terms.vdwTerms; const auto& [e_idx1s, e_idx2s, chargeTerms, dielModels, is1_4s] = terms.eleTerms; - const int eleStart = systemIndices.eleTermStarts[molIdx]; - const int eleEnd = systemIndices.eleTermStarts[molIdx + 1]; -#pragma unroll 1 - for (int i = eleStart + tid; i < eleEnd; i += stride) { - const int localIdx1 = e_idx1s[i] - atomStart; - const int localIdx2 = e_idx2s[i] - atomStart; - const bool is14 = is1_4s[i] > 0; - eleGrad(molCoords, localIdx1, localIdx2, chargeTerms[i], dielModels[i], is14, grad); + + // Calculate number of warps needed for each term type + const int warpsForBond = (numBond + WARP_SIZE - 1) / WARP_SIZE; + const int warpsForAngle = (numAngle + WARP_SIZE - 1) / WARP_SIZE; + const int warpsForBend = (numBend + WARP_SIZE - 1) / WARP_SIZE; + const int warpsForOop = (numOop + WARP_SIZE - 1) / WARP_SIZE; + const int warpsForTorsion = (numTorsion + WARP_SIZE - 1) / WARP_SIZE; + const int warpsForVdw = (numVdw + WARP_SIZE - 1) / WARP_SIZE; + const int warpsForEle = (numEle + WARP_SIZE - 1) / WARP_SIZE; + const int totalWarpsNeeded = + warpsForBond + warpsForAngle + warpsForBend + warpsForOop + warpsForTorsion + warpsForVdw + warpsForEle; + + // Each warp processes chunks in round-robin fashion + for (int chunkIdx = warpId; chunkIdx < totalWarpsNeeded; chunkIdx += numWarps) { + if (chunkIdx < warpsForBond) { + // Bond terms + const int baseIdx = chunkIdx * WARP_SIZE; + const int termIdx = bondStart + baseIdx + laneId; + if (baseIdx + laneId < numBond) { + const int localIdx1 = idx1s[termIdx] - atomStart; + const int localIdx2 = idx2s[termIdx] - atomStart; + bondStretchGrad(molCoords, localIdx1, localIdx2, r0s[termIdx], kbs[termIdx], grad); + } + } else if (chunkIdx < warpsForBond + warpsForAngle) { + // Angle terms + const int warpOffset = chunkIdx - warpsForBond; + const int baseIdx = warpOffset * WARP_SIZE; + const int termIdx = angleStart + baseIdx + laneId; + if (baseIdx + laneId < numAngle) { + const int localIdx1 = a_idx1s[termIdx] - atomStart; + const int localIdx2 = a_idx2s[termIdx] - atomStart; + const int localIdx3 = a_idx3s[termIdx] - atomStart; + const bool isLinear = static_cast(isLinears[termIdx]); + angleBendGrad(localIdx1, localIdx2, localIdx3, theta0s[termIdx], kas[termIdx], isLinear, molCoords, grad); + } + } else if (chunkIdx < warpsForBond + warpsForAngle + warpsForBend) { + // Bend (stretch-bend) terms + const int warpOffset = chunkIdx - warpsForBond - warpsForAngle; + const int baseIdx = warpOffset * WARP_SIZE; + const int termIdx = bendStart + baseIdx + laneId; + if (baseIdx + laneId < numBend) { + const int localIdx1 = bs_idx1s[termIdx] - atomStart; + const int localIdx2 = bs_idx2s[termIdx] - atomStart; + const int localIdx3 = bs_idx3s[termIdx] - atomStart; + bendStretchGrad(molCoords, + localIdx1, + localIdx2, + localIdx3, + bs_theta0s[termIdx], + restLen1s[termIdx], + restLen2s[termIdx], + forceConst1s[termIdx], + forceConst2s[termIdx], + grad); + } + } else if (chunkIdx < warpsForBond + warpsForAngle + warpsForBend + warpsForOop) { + // OOP terms + const int warpOffset = chunkIdx - warpsForBond - warpsForAngle - warpsForBend; + const int baseIdx = warpOffset * WARP_SIZE; + const int termIdx = oopStart + baseIdx + laneId; + if (baseIdx + laneId < numOop) { + const int localIdx1 = o_idx1s[termIdx] - atomStart; + const int localIdx2 = o_idx2s[termIdx] - atomStart; + const int localIdx3 = o_idx3s[termIdx] - atomStart; + const int localIdx4 = o_idx4s[termIdx] - atomStart; + rdkit_ports::oopGrad(molCoords, localIdx1, localIdx2, localIdx3, localIdx4, koops[termIdx], grad); + } + } else if (chunkIdx < warpsForBond + warpsForAngle + warpsForBend + warpsForOop + warpsForTorsion) { + // Torsion terms + const int warpOffset = chunkIdx - warpsForBond - warpsForAngle - warpsForBend - warpsForOop; + const int baseIdx = warpOffset * WARP_SIZE; + const int termIdx = torsionStart + baseIdx + laneId; + if (baseIdx + laneId < numTorsion) { + const int localIdx1 = t_idx1s[termIdx] - atomStart; + const int localIdx2 = t_idx2s[termIdx] - atomStart; + const int localIdx3 = t_idx3s[termIdx] - atomStart; + const int localIdx4 = t_idx4s[termIdx] - atomStart; + rdkit_ports::torsionGrad(molCoords, + localIdx1, + localIdx2, + localIdx3, + localIdx4, + V1s[termIdx], + V2s[termIdx], + V3s[termIdx], + grad); + } + } else if (chunkIdx < warpsForBond + warpsForAngle + warpsForBend + warpsForOop + warpsForTorsion + warpsForVdw) { + // VDW terms + const int warpOffset = chunkIdx - warpsForBond - warpsForAngle - warpsForBend - warpsForOop - warpsForTorsion; + const int baseIdx = warpOffset * WARP_SIZE; + const int termIdx = vdwStart + baseIdx + laneId; + if (baseIdx + laneId < numVdw) { + const int localIdx1 = v_idx1s[termIdx] - atomStart; + const int localIdx2 = v_idx2s[termIdx] - atomStart; + rdkit_ports::vDWGrad(molCoords, localIdx1, localIdx2, R_ij_stars[termIdx], wellDepths[termIdx], grad); + } + } else { + // ELE terms + const int warpOffset = + chunkIdx - warpsForBond - warpsForAngle - warpsForBend - warpsForOop - warpsForTorsion - warpsForVdw; + const int baseIdx = warpOffset * WARP_SIZE; + const int termIdx = eleStart + baseIdx + laneId; + if (baseIdx + laneId < numEle) { + const int localIdx1 = e_idx1s[termIdx] - atomStart; + const int localIdx2 = e_idx2s[termIdx] - atomStart; + const bool is14 = is1_4s[termIdx] > 0; + eleGrad(molCoords, localIdx1, localIdx2, chargeTerms[termIdx], dielModels[termIdx], is14, grad); + } + } } }