Skip to content

Commit 81a1e70

Browse files
committed
Fixes for pool-allocated view sizes
1 parent 33f9559 commit 81a1e70

File tree

1 file changed

+54
-40
lines changed

1 file changed

+54
-40
lines changed

components/eamxx/src/physics/rrtmgp/scream_rrtmgp_interface.hpp

+54 -40
Original file line number | Diff line number | Diff line change
@@ -409,45 +409,51 @@ static void rrtmgp_main(
409409
const int int_size2 = sw_nband;
410410
const int int_size3 = 2*lw_nband;
411411
const int int_size4 = lw_nband;
412+
const int int_size5 = sw_ngpt;
413+
const int int_size6 = lw_ngpt;
412414

413415
const int real_size1 = ncol*nlay*sw_nband;
414416
const int real_size2 = ncol*nlay*lw_nband;
417+
const int real_size3 = ncol*nlay*sw_ngpt;
418+
const int real_size4 = ncol*nlay*lw_ngpt;
415419

416-
const int total_int_size = 3 * (int_size1 + int_size2 + int_size3 + int_size4);
417-
const int total_real_size = 3 * (3 * real_size1 + real_size2);
420+
const int total_int_size = 2 * (int_size1 + int_size2 + int_size3 + int_size4) + (int_size1 + int_size5 + int_size3 + int_size6);
421+
const int total_real_size = 2 * (3 * real_size1 + real_size2) + (3*real_size3 + real_size4);
418422
auto int_data = pool_t::template alloc_and_init<int>(total_int_size); int *dcurr_int = int_data.data();
419423

420424
view_t<int**> sw_band2gpt_mem(dcurr_int, 2, sw_nband); dcurr_int += int_size1;
421-
view_t<int*> sw_gpt2band_mem(dcurr_int, sw_nband); dcurr_int += int_size2;
425+
view_t<int*> sw_gpt2band_mem(dcurr_int, sw_nband); dcurr_int += int_size2;
422426
view_t<int**> lw_band2gpt_mem(dcurr_int, 2, lw_nband); dcurr_int += int_size3;
423-
view_t<int*> lw_gpt2band_mem(dcurr_int, lw_nband); dcurr_int += int_size4;
427+
view_t<int*> lw_gpt2band_mem(dcurr_int, lw_nband); dcurr_int += int_size4;
424428

425429
view_t<int**> sw_cloud_band2gpt_mem(dcurr_int, 2, sw_nband); dcurr_int += int_size1;
426-
view_t<int*> sw_cloud_gpt2band_mem(dcurr_int, sw_nband); dcurr_int += int_size2;
430+
view_t<int*> sw_cloud_gpt2band_mem(dcurr_int, sw_nband); dcurr_int += int_size2;
427431
view_t<int**> lw_cloud_band2gpt_mem(dcurr_int, 2, lw_nband); dcurr_int += int_size3;
428-
view_t<int*> lw_cloud_gpt2band_mem(dcurr_int, lw_nband); dcurr_int += int_size4;
432+
view_t<int*> lw_cloud_gpt2band_mem(dcurr_int, lw_nband); dcurr_int += int_size4;
429433

430434
view_t<int**> sw_subcloud_band2gpt_mem(dcurr_int, 2, sw_nband); dcurr_int += int_size1;
431-
view_t<int*> sw_subcloud_gpt2band_mem(dcurr_int, sw_ngpt); dcurr_int += int_size2;
435+
view_t<int*> sw_subcloud_gpt2band_mem(dcurr_int, sw_ngpt); dcurr_int += int_size5;
432436
view_t<int**> lw_subcloud_band2gpt_mem(dcurr_int, 2, lw_nband); dcurr_int += int_size3;
433-
view_t<int*> lw_subcloud_gpt2band_mem(dcurr_int, lw_ngpt); dcurr_int += int_size4;
437+
view_t<int*> lw_subcloud_gpt2band_mem(dcurr_int, lw_ngpt); dcurr_int += int_size6;
438+
assert(dcurr_int - int_data.data() == total_int_size);
434439

435-
auto data = pool_t::template alloc<RealT>(total_real_size); RealT *dcurr = data.data();
440+
auto data = pool_t::template alloc_and_init<RealT>(total_real_size); RealT *dcurr = data.data();
436441

437442
view_t<RealT***> sw_tau_mem(dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
438443
view_t<RealT***> sw_ssa_mem(dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
439-
view_t<RealT***> sw_g_mem(dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
444+
view_t<RealT***> sw_g_mem (dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
440445
view_t<RealT***> lw_tau_mem(dcurr, ncol, nlay, lw_nband); dcurr += real_size2;
441446

442447
view_t<RealT***> sw_cloud_tau_mem(dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
443448
view_t<RealT***> sw_cloud_ssa_mem(dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
444-
view_t<RealT***> sw_cloud_g_mem(dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
449+
view_t<RealT***> sw_cloud_g_mem (dcurr, ncol, nlay, sw_nband); dcurr += real_size1;
445450
view_t<RealT***> lw_cloud_tau_mem(dcurr, ncol, nlay, lw_nband); dcurr += real_size2;
446451

447-
view_t<RealT***> sw_subcloud_tau_mem(dcurr, ncol, nlay, sw_ngpt); dcurr += real_size1;
448-
view_t<RealT***> sw_subcloud_ssa_mem(dcurr, ncol, nlay, sw_ngpt); dcurr += real_size1;
449-
view_t<RealT***> sw_subcloud_g_mem(dcurr, ncol, nlay, sw_ngpt); dcurr += real_size1;
450-
view_t<RealT***> lw_subcloud_tau_mem(dcurr, ncol, nlay, lw_ngpt); dcurr += real_size2;
452+
view_t<RealT***> sw_subcloud_tau_mem(dcurr, ncol, nlay, sw_ngpt); dcurr += real_size3;
453+
view_t<RealT***> sw_subcloud_ssa_mem(dcurr, ncol, nlay, sw_ngpt); dcurr += real_size3;
454+
view_t<RealT***> sw_subcloud_g_mem (dcurr, ncol, nlay, sw_ngpt); dcurr += real_size3;
455+
view_t<RealT***> lw_subcloud_tau_mem(dcurr, ncol, nlay, lw_ngpt); dcurr += real_size4;
456+
assert(dcurr - data.data() == total_real_size);
451457

452458
// Setup pointers to RRTMGP SW fluxes
453459
fluxes_t fluxes_sw;
@@ -532,6 +538,7 @@ static void rrtmgp_main(
532538
// subcolumn (cloud state) to each gpoint.
533539
auto nswgpts = k_dist_sw_k.get_ngpt();
534540
auto clouds_sw_gpt = get_subsampled_clouds(ncol, nlay, nswbands, nswgpts, clouds_sw, k_dist_sw_k, cldfrac, p_lay, sw_subcloud_band2gpt_mem, sw_subcloud_gpt2band_mem, sw_subcloud_tau_mem, sw_subcloud_ssa_mem, sw_subcloud_g_mem);
541+
535542
// Longwave
536543
auto nlwgpts = k_dist_lw_k.get_ngpt();
537544
auto clouds_lw_gpt = get_subsampled_clouds(ncol, nlay, nlwbands, nlwgpts, clouds_lw, k_dist_lw_k, cldfrac, p_lay, lw_subcloud_band2gpt_mem, lw_subcloud_gpt2band_mem, lw_subcloud_tau_mem);
@@ -729,21 +736,23 @@ static void rrtmgp_sw(
729736
auto sw_noaero_tau_mem = view_t<RealT***>(dcurr, nday, nlay, ngpt); dcurr += size11;
730737
auto sw_noaero_ssa_mem = view_t<RealT***>(dcurr, nday, nlay, ngpt); dcurr += size11;
731738
auto sw_noaero_g_mem = view_t<RealT***>(dcurr, nday, nlay, ngpt); dcurr += size11;
739+
assert(dcurr - data.data() == total_size);
732740

733741
const int int_size1 = 2*nbnd;
734742
const int int_size2 = nbnd;
735743
const int int_size3 = ngpt;
736744
const int total_int_size = 3 * (int_size1 + int_size3) + (int_size1 + int_size2);
737745
auto int_data = pool_t::template alloc_and_init<int>(total_int_size); int *dcurr_int = int_data.data();
738746

739-
auto sw_aero_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
740-
auto sw_aero_gpt2band_mem = view_t<int*>(dcurr_int, nbnd); dcurr_int += int_size2;
741-
auto sw_cloud_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
742-
auto sw_cloud_gpt2band_mem = view_t<int*>(dcurr_int, ngpt); dcurr_int += int_size3;
747+
auto sw_aero_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
748+
auto sw_aero_gpt2band_mem = view_t<int*> (dcurr_int, nbnd); dcurr_int += int_size2;
749+
auto sw_cloud_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
750+
auto sw_cloud_gpt2band_mem = view_t<int*> (dcurr_int, ngpt); dcurr_int += int_size3;
743751
auto sw_optics_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
744-
auto sw_optics_gpt2band_mem = view_t<int*>(dcurr_int, ngpt); dcurr_int += int_size3;
752+
auto sw_optics_gpt2band_mem = view_t<int*> (dcurr_int, ngpt); dcurr_int += int_size3;
745753
auto sw_noaero_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
746-
auto sw_noaero_gpt2band_mem = view_t<int*>(dcurr_int, ngpt); dcurr_int += int_size3;
754+
auto sw_noaero_gpt2band_mem = view_t<int*> (dcurr_int, ngpt); dcurr_int += int_size3;
755+
assert(dcurr_int - int_data.data() == total_int_size);
747756

748757
// Subset mu0
749758
TIMED_KERNEL(Kokkos::parallel_for(nday, KOKKOS_LAMBDA(int iday) {
@@ -942,31 +951,33 @@ static void rrtmgp_lw(
942951
const int total_size = size1 + size2 + size3*2 + size4 + size5 + size6 + size7*5 + size8;
943952
auto data = pool_t::template alloc_and_init<RealT>(total_size); RealT *dcurr = data.data();
944953

945-
view_t<RealT*> t_sfc (dcurr, ncol); dcurr += size1;
946-
view_t<RealT**> emis_sfc (dcurr, nbnd,ncol); dcurr += size2;
947-
view_t<RealT**> gauss_Ds (dcurr, max_gauss_pts,max_gauss_pts); dcurr += size3;
948-
view_t<RealT**> gauss_wts (dcurr, max_gauss_pts,max_gauss_pts); dcurr += size3;
949-
view_t<RealT**> t_lay_limited(dcurr, ncol, nlay); dcurr += size4;
950-
view_t<RealT**> t_lev_limited(dcurr, ncol, nlay+1); dcurr += size5;
951-
view_t<RealT***> col_gas (dcurr, ncol, nlay, k_dist.get_ngas()+1); dcurr += size6;
952-
view_t<RealT***> lw_optics_tau_mem(dcurr, ncol, nlay, ngpt); dcurr += size7;
953-
view_t<RealT***> lw_noaero_tau_mem(dcurr, ncol, nlay, ngpt); dcurr += size7;
954-
view_t<RealT***> lay_source_mem(dcurr, ncol, nlay, ngpt); dcurr += size7;
955-
view_t<RealT***> lev_source_inc_mem(dcurr, ncol, nlay, ngpt); dcurr += size7;
956-
view_t<RealT***> lev_source_dec_mem(dcurr, ncol, nlay, ngpt); dcurr += size7;
957-
view_t<RealT**> sfc_source_mem(dcurr, ncol, ngpt); dcurr += size8;
954+
view_t<RealT*> t_sfc (dcurr, ncol); dcurr += size1;
955+
view_t<RealT**> emis_sfc (dcurr, nbnd, ncol); dcurr += size2;
956+
view_t<RealT**> gauss_Ds (dcurr, max_gauss_pts, max_gauss_pts); dcurr += size3;
957+
view_t<RealT**> gauss_wts (dcurr, max_gauss_pts, max_gauss_pts); dcurr += size3;
958+
view_t<RealT**> t_lay_limited (dcurr, ncol, nlay); dcurr += size4;
959+
view_t<RealT**> t_lev_limited (dcurr, ncol, nlay+1); dcurr += size5;
960+
view_t<RealT***> col_gas (dcurr, ncol, nlay, k_dist.get_ngas()+1); dcurr += size6;
961+
view_t<RealT***> lw_optics_tau_mem (dcurr, ncol, nlay, ngpt); dcurr += size7;
962+
view_t<RealT***> lw_noaero_tau_mem (dcurr, ncol, nlay, ngpt); dcurr += size7;
963+
view_t<RealT***> lay_source_mem (dcurr, ncol, nlay, ngpt); dcurr += size7;
964+
view_t<RealT***> lev_source_inc_mem (dcurr, ncol, nlay, ngpt); dcurr += size7;
965+
view_t<RealT***> lev_source_dec_mem (dcurr, ncol, nlay, ngpt); dcurr += size7;
966+
view_t<RealT**> sfc_source_mem (dcurr, ncol, ngpt); dcurr += size8;
967+
assert(dcurr - data.data() == total_size);
958968

959969
const int int_size1 = 2*nbnd;
960970
const int int_size2 = ngpt;
961971
const int total_int_size = 3 * (int_size1 + int_size2);
962972
auto int_data = pool_t::template alloc_and_init<int>(total_int_size); int *dcurr_int = int_data.data();
963973

964-
auto lw_optics_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd);
965-
auto lw_optics_gpt2band_mem = view_t<int*>(dcurr_int, ngpt);
966-
auto lw_noaero_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd);
967-
auto lw_noaero_gpt2band_mem = view_t<int*>(dcurr_int, ngpt);
968-
auto lw_source_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd);
969-
auto lw_source_gpt2band_mem = view_t<int*>(dcurr_int, ngpt);
974+
auto lw_optics_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
975+
auto lw_optics_gpt2band_mem = view_t<int*> (dcurr_int, ngpt); dcurr_int += int_size2;
976+
auto lw_noaero_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
977+
auto lw_noaero_gpt2band_mem = view_t<int*> (dcurr_int, ngpt); dcurr_int += int_size2;
978+
auto lw_source_band2gpt_mem = view_t<int**>(dcurr_int, 2, nbnd); dcurr_int += int_size1;
979+
auto lw_source_gpt2band_mem = view_t<int*> (dcurr_int, ngpt); dcurr_int += int_size2;
980+
assert(dcurr_int - int_data.data() == total_int_size);
970981

971982
// Associate local pointers for fluxes
972983
auto &flux_up = fluxes.flux_up;
@@ -1002,6 +1013,7 @@ static void rrtmgp_lw(
10021013
// Allocate space for optical properties
10031014
optical_props1_t optics;
10041015
optics.alloc_1scl_no_alloc(ncol, nlay, k_dist, lw_optics_band2gpt_mem, lw_optics_gpt2band_mem, lw_optics_tau_mem);
1016+
10051017
optical_props1_t optics_no_aerosols;
10061018
if (extra_clnsky_diag) {
10071019
// Allocate space for optical properties (no aerosols)
@@ -1490,6 +1502,7 @@ static optical_props2_t get_subsampled_clouds(
14901502
cldfrac_rad(icol,ilay) = cld(icol,ilay);
14911503
}
14921504
}));
1505+
14931506
// Get subcolumn cloud mask; note that get_subcolumn_mask exposes overlap assumption as an option,
14941507
// but the only currently supported options are 0 (trivial all-or-nothing cloud) or 1 (max-rand),
14951508
// so overlap has not been exposed as an option beyond this subcolumn. In the future, we should
@@ -1503,6 +1516,7 @@ static optical_props2_t get_subsampled_clouds(
15031516
seeds(icol) = 1e9 * (p_lay(icol,nlay-1) - int(p_lay(icol,nlay-1)));
15041517
}));
15051518
get_subcolumn_mask(ncol, nlay, ngpt, cldfrac_rad, overlap, seeds, cldmask);
1519+
15061520
// Assign optical properties to subcolumns (note this implements MCICA)
15071521
auto gpoint_bands = kdist.get_gpoint_bands();
15081522
TIMED_KERNEL(Kokkos::parallel_for(MDRP::template get<3>({ncol,nlay,ngpt}), KOKKOS_LAMBDA(int icol, int ilay, int igpt) {

0 commit comments

Comments
 (0)