Skip to content

Commit

Permalink
few more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
chillenb committed Oct 8, 2024
1 parent 4168f10 commit 7834c99
Show file tree
Hide file tree
Showing 9 changed files with 50 additions and 48 deletions.
16 changes: 15 additions & 1 deletion pyscf/lib/cc/ccsd_pack.c
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,10 @@ void CCload_eri(double *out, double *eri, int *orbs_slice, int nao)
#endif
int i, j, k, l, ij;
double *pout;

#ifndef PYSCF_USE_MKL
double *buf = pyscf_malloc(sizeof(double) * nao*nao);

#pragma omp for schedule (static)
for (ij = 0; ij < ni*nj; ij++) {
i = ij / nj;
Expand All @@ -193,7 +196,18 @@ void CCload_eri(double *out, double *eri, int *orbs_slice, int nao)
} }
}
pyscf_free(buf);
#ifdef PYSCF_USE_MKL

#else

#pragma omp for schedule(static)
for (ij = 0; ij < ni*nj; ij++) {
i = ij / nj;
j = ij % nj;
pout = out + (i*nn+j)*nao;
LAPACKE_mkl_dtpunpack(LAPACK_ROW_MAJOR, 'L', 'N', nao, eri+ij*nao_pair, 1, 1, nao, nao, pout, nn);
LAPACKE_mkl_dtpunpack(LAPACK_ROW_MAJOR, 'L', 'T', nao, eri+ij*nao_pair, 1, 1, nao, nao, pout, nn);
}

mkl_set_num_threads_local(save);
#endif
}
Expand Down
1 change: 1 addition & 0 deletions pyscf/lib/dft/grid_basis.c
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ void VXCgen_grid(double *out, double *coords, double *atm_coords,
int i, j;
double dx, dy, dz;
double *atom_dist = pyscf_malloc(sizeof(double) * natm*natm);

for (i = 0; i < natm; i++) {
for (j = 0; j < i; j++) {
dx = atm_coords[i*3+0] - atm_coords[j*3+0];
Expand Down
7 changes: 0 additions & 7 deletions pyscf/lib/dft/grid_common.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,6 @@ void get_cart2sph_coeff(double** contr_coeff, double** gto_norm,
ncart, nsph, nprim, nctr,\
ptr_exp, ptr_coeff)
{
#ifdef PYSCF_USE_MKL
int save = mkl_set_num_threads_local(1);
#endif

#pragma omp for schedule(dynamic)
for (ish = ish0; ish < ish1; ish++) {
l = bas[ANG_OF+ish*BAS_SLOTS];
Expand Down Expand Up @@ -137,9 +133,6 @@ void get_cart2sph_coeff(double** contr_coeff, double** gto_norm,
pyscf_free(buf);
}

#ifdef PYSCF_USE_MKL
mkl_set_num_threads_local(save);
#endif
}

for (l = 0; l <= lmax; l++) {
Expand Down
8 changes: 8 additions & 0 deletions pyscf/lib/dft/grid_integrate.c
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ static void fill_tril(double* mat, int comp, int* ish_ao_loc, int* jsh_ao_loc,
double *pmat_up = mat + i0*((size_t)naoj) + j0;
double *pmat_low = mat + j0*((size_t)naoj) + i0;
int ic, i, j;
#ifndef PYSCF_USE_MKL
for (ic = 0; ic < comp; ic++) {
for (i = 0; i < ni; i++) {
for (j = 0; j < nj; j++) {
Expand All @@ -95,6 +96,13 @@ static void fill_tril(double* mat, int comp, int* ish_ao_loc, int* jsh_ao_loc,
pmat_up += nao2;
pmat_low += nao2;
}
#else
mkl_domatcopy_batch_strided(
'R', 'T', ni, nj,
1.0, pmat_up, naoj, nao2,
pmat_low, naoj, nao2, comp
);
#endif
}


Expand Down
30 changes: 0 additions & 30 deletions pyscf/lib/dft/nr_numint.c
Original file line number Diff line number Diff line change
Expand Up @@ -221,10 +221,6 @@ void VXC_dscale_ao(double *aow, double *ao, double *wv,
{
#pragma omp parallel
{
#ifdef PYSCF_USE_MKL
int save = mkl_set_num_threads_local(1);
#endif

size_t Ngrids = ngrids;
size_t ao_size = nao * Ngrids;
int i, j, ic;
Expand All @@ -240,10 +236,6 @@ void VXC_dscale_ao(double *aow, double *ao, double *wv,
aow[i*Ngrids+j] += pao[ic*ao_size+j] * wv[ic*Ngrids+j];
} }
}

#ifdef PYSCF_USE_MKL
mkl_set_num_threads_local(save);
#endif
}
}

Expand All @@ -253,10 +245,6 @@ void VXC_dcontract_rho(double *rho, double *bra, double *ket,
{
#pragma omp parallel
{
#ifdef PYSCF_USE_MKL
int save = mkl_set_num_threads_local(1);
#endif

size_t Ngrids = ngrids;
int nthread = omp_get_num_threads();
int blksize = MAX((Ngrids+nthread-1) / nthread, 1);
Expand All @@ -273,10 +261,6 @@ void VXC_dcontract_rho(double *rho, double *bra, double *ket,
rho[j] += bra[i*Ngrids+j] * ket[i*Ngrids+j];
} }
}

#ifdef PYSCF_USE_MKL
mkl_set_num_threads_local(save);
#endif
}
}

Expand All @@ -287,10 +271,6 @@ void VXC_vv10nlc(double *Fvec, double *Uvec, double *Wvec,
{
#pragma omp parallel
{
#ifdef PYSCF_USE_MKL
int save = mkl_set_num_threads_local(1);
#endif

double DX, DY, DZ, R2;
double gp, g, gt, T, F, U, W;
int i, j;
Expand Down Expand Up @@ -318,9 +298,6 @@ void VXC_vv10nlc(double *Fvec, double *Uvec, double *Wvec,
Wvec[i] = W;
}

#ifdef PYSCF_USE_MKL
mkl_set_num_threads_local(save);
#endif
}
}

Expand All @@ -330,10 +307,6 @@ void VXC_vv10nlc_grad(double *Fvec, double *vvcoords, double *coords,
{
#pragma omp parallel
{
#ifdef PYSCF_USE_MKL
int save = mkl_set_num_threads_local(1);
#endif

double DX, DY, DZ, R2;
double gp, g, gt, T, Q, FX, FY, FZ;
int i, j;
Expand Down Expand Up @@ -361,8 +334,5 @@ void VXC_vv10nlc_grad(double *Fvec, double *vvcoords, double *coords,
Fvec[i*3+2] = FZ * -3;
}

#ifdef PYSCF_USE_MKL
mkl_set_num_threads_local(save);
#endif
}
}
6 changes: 0 additions & 6 deletions pyscf/lib/dft/nr_numint_sparse.c
Original file line number Diff line number Diff line change
Expand Up @@ -1152,9 +1152,6 @@ void VXCdscale_ao_sparse(double *aow, double *ao, double *wv,
{
#pragma omp parallel
{
#ifdef PYSCF_USE_MKL
int save = mkl_set_num_threads_local(1);
#endif
size_t Ngrids = ngrids;
size_t ao_size = nao * Ngrids;
int ish, i, j, ic, i0, i1, ig0, ig1, row;
Expand All @@ -1180,8 +1177,5 @@ for (i = i0; i < i1; i++) {
}
}
}
#ifdef PYSCF_USE_MKL
mkl_set_num_threads_local(save);
#endif
}
}
18 changes: 14 additions & 4 deletions pyscf/lib/np_helper/np_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,20 @@
for (I = 0, j1 = MIN(j0+BLOCK_DIM, n); I < j1; I++) \
for (J = MAX(I,j0); J < j1; J++)

#ifdef PYSCF_USE_MKL
#define pyscf_malloc mkl_malloc
#define pyscf_free mkl_free
#define pyscf_calloc mkl_calloc
#define pyscf_realloc mkl_realloc

#else

#define pyscf_malloc malloc
#define pyscf_free free
#define pyscf_calloc calloc
#define pyscf_realloc realloc
#endif

void NPdsymm_triu(int n, double *mat, int hermi);
void NPzhermi_triu(int n, double complex *mat, int hermi);
void NPdunpack_tril(int n, double *tril, double *mat, int hermi);
Expand Down Expand Up @@ -78,8 +92,4 @@ void NPdgemm(const char trans_a, const char trans_b,
double *a, double *b, double *c,
const double alpha, const double beta);

void *pyscf_malloc(size_t alloc_size);
void *pyscf_calloc(size_t n, size_t size);
void *pyscf_realloc(void *ptr, size_t size);
void pyscf_free(void *ptr);
int pyscf_has_mkl(void);
4 changes: 4 additions & 0 deletions pyscf/lib/pbc/fft.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

#include <fftw3.h>

#ifdef PYSCF_USE_MKL
#include "fftw3_mkl.h"
#endif

#define FFT_PLAN fftw_plan

FFT_PLAN fft_create_r2c_plan(double* in, complex double* out, int rank, int* mesh);
Expand Down
8 changes: 8 additions & 0 deletions pyscf/lib/pbc/hf_grad.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ void contract_vhf_dm(double* out, double* vhf, double* dm,

#pragma omp parallel
{
#ifdef PYSCF_USE_MKL
int save = mkl_set_num_threads_local(1);
#endif

size_t ij, ish, jsh, p0, q0;
int ni, nj, i, ic, iatm, nimgs=1;
NeighborList *nl0=NULL;
Expand Down Expand Up @@ -91,5 +95,9 @@ void contract_vhf_dm(double* out, double* vhf, double* dm,
if (thread_id != 0) {
pyscf_free(buf);
}

#ifdef PYSCF_USE_MKL
mkl_set_num_threads_local(save);
#endif
}
}

0 comments on commit 7834c99

Please sign in to comment.