Skip to content

Commit 3aaf335

Browse files
committed
EAMxx: add horizontal average diagnostic field
1 parent b13a08f commit 3aaf335

File tree

6 files changed

+311
-0
lines changed

6 files changed

+311
-0
lines changed

components/eamxx/src/diagnostics/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ set(DIAGNOSTIC_SRCS
88
field_at_height.cpp
99
field_at_level.cpp
1010
field_at_pressure_level.cpp
11+
horiz_avg.cpp
1112
longwave_cloud_forcing.cpp
1213
number_path.cpp
1314
potential_temperature.cpp
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#include "diagnostics/horiz_avg.hpp"
2+
3+
#include "share/field/field_utils.hpp"
4+
5+
namespace scream {
6+
7+
HorizAvgDiag::HorizAvgDiag(const ekat::Comm &comm,
8+
const ekat::ParameterList &params)
9+
: AtmosphereDiagnostic(comm, params) {
10+
const auto &fname = m_params.get<std::string>("field_name");
11+
m_diag_name = fname + "_horiz_avg";
12+
}
13+
14+
void HorizAvgDiag::set_grids(
15+
const std::shared_ptr<const GridsManager> grids_manager) {
16+
const auto &fn = m_params.get<std::string>("field_name");
17+
const auto &gn = m_params.get<std::string>("grid_name");
18+
const auto g = grids_manager->get_grid("Physics");
19+
20+
add_field<Required>(fn, gn);
21+
22+
// first clone the area unscaled, we will scale it later in initialize_impl
23+
m_scaled_area = g->get_geometry_data("area").clone();
24+
}
25+
26+
void HorizAvgDiag::initialize_impl(const RunType /*run_type*/) {
27+
using namespace ShortFieldTagsNames;
28+
const auto &f = get_fields_in().front();
29+
const auto &fid = f.get_header().get_identifier();
30+
const auto &layout = fid.get_layout();
31+
32+
EKAT_REQUIRE_MSG(layout.rank() >= 1 && layout.rank() <= 3,
33+
"Error! Field rank not supported by HorizAvgDiag.\n"
34+
" - field name: " +
35+
fid.name() +
36+
"\n"
37+
" - field layout: " +
38+
layout.to_string() + "\n");
39+
EKAT_REQUIRE_MSG(layout.tags()[0] == COL,
40+
"Error! HorizAvgDiag diagnostic expects a layout starting "
41+
"with the 'COL' tag.\n"
42+
" - field name : " +
43+
fid.name() +
44+
"\n"
45+
" - field layout: " +
46+
layout.to_string() + "\n");
47+
48+
FieldIdentifier d_fid(m_diag_name, layout.clone().strip_dim(COL),
49+
fid.get_units(), fid.get_grid_name());
50+
m_diagnostic_output = Field(d_fid);
51+
m_diagnostic_output.allocate_view();
52+
53+
// scale the area field
54+
auto total_area = field_sum<Real>(m_scaled_area, &m_comm);
55+
m_scaled_area.scale(sp(1.0) / total_area);
56+
}
57+
58+
void HorizAvgDiag::compute_diagnostic_impl() {
59+
const auto &f = get_fields_in().front();
60+
const auto &d = m_diagnostic_output;
61+
// Call the horiz_contraction impl that will take care of everything
62+
horiz_contraction<Real>(d, f, m_scaled_area, &m_comm);
63+
}
64+
65+
} // namespace scream
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#ifndef EAMXX_HORIZ_AVERAGE_HPP
2+
#define EAMXX_HORIZ_AVERAGE_HPP
3+
4+
#include "share/atm_process/atmosphere_diagnostic.hpp"
5+
6+
namespace scream {
7+
8+
/*
9+
* This diagnostic will calculate the area-weighted average of a field
10+
* across the COL tag dimension, producing an N-1 dimensional field
11+
* that is area-weighted average of the input field.
12+
*/
13+
14+
class HorizAvgDiag : public AtmosphereDiagnostic {
15+
public:
16+
// Constructors
17+
HorizAvgDiag(const ekat::Comm &comm, const ekat::ParameterList &params);
18+
19+
// The name of the diagnostic
20+
std::string name() const { return m_diag_name; }
21+
22+
// Set the grid
23+
void set_grids(const std::shared_ptr<const GridsManager> grids_manager);
24+
25+
protected:
26+
#ifdef KOKKOS_ENABLE_CUDA
27+
public:
28+
#endif
29+
void compute_diagnostic_impl();
30+
31+
protected:
32+
void initialize_impl(const RunType /*run_type*/);
33+
34+
// Name of each field (because the diagnostic impl is generic)
35+
std::string m_diag_name;
36+
37+
// Need area field, let's store it scaled by its norm
38+
Field m_scaled_area;
39+
};
40+
41+
} // namespace scream
42+
43+
#endif // EAMXX_HORIZ_AVERAGE_HPP

components/eamxx/src/diagnostics/register_diagnostics.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "diagnostics/number_path.hpp"
2525
#include "diagnostics/aerocom_cld.hpp"
2626
#include "diagnostics/atm_backtend.hpp"
27+
#include "diagnostics/horiz_avg.hpp"
2728

2829
namespace scream {
2930

@@ -51,6 +52,7 @@ inline void register_diagnostics () {
5152
diag_factory.register_product("NumberPath",&create_atmosphere_diagnostic<NumberPathDiagnostic>);
5253
diag_factory.register_product("AeroComCld",&create_atmosphere_diagnostic<AeroComCld>);
5354
diag_factory.register_product("AtmBackTendDiag",&create_atmosphere_diagnostic<AtmBackTendDiag>);
55+
diag_factory.register_product("HorizAvgDiag",&create_atmosphere_diagnostic<HorizAvgDiag>);
5456
}
5557

5658
} // namespace scream

components/eamxx/src/diagnostics/tests/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,6 @@ CreateDiagTest(aerocom_cld "aerocom_cld_test.cpp")
7171

7272
# Test atm_tend
7373
CreateDiagTest(atm_backtend "atm_backtend_test.cpp")
74+
75+
# Test horizontal averaging
76+
CreateDiagTest(horiz_avg "horiz_avg_test.cpp")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
#include "catch2/catch.hpp"
2+
#include "diagnostics/register_diagnostics.hpp"
3+
#include "share/field/field_utils.hpp"
4+
#include "share/grid/mesh_free_grids_manager.hpp"
5+
#include "share/util/scream_setup_random_test.hpp"
6+
#include "share/util/scream_universal_constants.hpp"
7+
8+
namespace scream {
9+
10+
std::shared_ptr<GridsManager> create_gm(const ekat::Comm &comm, const int ncols,
11+
const int nlevs) {
12+
const int num_global_cols = ncols * comm.size();
13+
14+
using vos_t = std::vector<std::string>;
15+
ekat::ParameterList gm_params;
16+
gm_params.set("grids_names", vos_t{"Point Grid"});
17+
auto &pl = gm_params.sublist("Point Grid");
18+
pl.set<std::string>("type", "point_grid");
19+
pl.set("aliases", vos_t{"Physics"});
20+
pl.set<int>("number_of_global_columns", num_global_cols);
21+
pl.set<int>("number_of_vertical_levels", nlevs);
22+
23+
auto gm = create_mesh_free_grids_manager(comm, gm_params);
24+
gm->build_grids();
25+
26+
return gm;
27+
}
28+
29+
TEST_CASE("horiz_avg") {
30+
using namespace ShortFieldTagsNames;
31+
using namespace ekat::units;
32+
using TeamPolicy = Kokkos::TeamPolicy<Field::device_t::execution_space>;
33+
using TeamMember = typename TeamPolicy::member_type;
34+
using KT = ekat::KokkosTypes<DefaultDevice>;
35+
using ESU = ekat::ExeSpaceUtils<typename KT::ExeSpace>;
36+
37+
// A numerical tolerance
38+
auto tol = std::numeric_limits<Real>::epsilon() * 100;
39+
40+
// A world comm
41+
ekat::Comm comm(MPI_COMM_WORLD);
42+
43+
// A time stamp
44+
util::TimeStamp t0({2024, 1, 1}, {0, 0, 0});
45+
46+
// Create a grids manager - single column for these tests
47+
constexpr int nlevs = 3;
48+
constexpr int dim3 = 4;
49+
const int ngcols = 6 * comm.size();
50+
51+
auto gm = create_gm(comm, ngcols, nlevs);
52+
auto grid = gm->get_grid("Physics");
53+
54+
// Input (randomized) qc
55+
FieldLayout scalar1d_layout{{COL}, {ngcols}};
56+
FieldLayout scalar2d_layout{{COL, LEV}, {ngcols, nlevs}};
57+
FieldLayout scalar3d_layout{{COL, CMP, LEV}, {ngcols, dim3, nlevs}};
58+
59+
FieldIdentifier qc1_fid("qc", scalar1d_layout, kg / kg, grid->name());
60+
FieldIdentifier qc2_fid("qc", scalar2d_layout, kg / kg, grid->name());
61+
FieldIdentifier qc3_fid("qc", scalar3d_layout, kg / kg, grid->name());
62+
63+
Field qc1(qc1_fid);
64+
Field qc2(qc2_fid);
65+
Field qc3(qc3_fid);
66+
67+
qc1.allocate_view();
68+
qc2.allocate_view();
69+
qc3.allocate_view();
70+
71+
// Construct random number generator stuff
72+
using RPDF = std::uniform_real_distribution<Real>;
73+
RPDF pdf(sp(0.0), sp(200.0));
74+
75+
auto engine = scream::setup_random_test();
76+
77+
// Construct the Diagnostics
78+
std::map<std::string, std::shared_ptr<AtmosphereDiagnostic>> diags;
79+
auto &diag_factory = AtmosphereDiagnosticFactory::instance();
80+
register_diagnostics();
81+
82+
ekat::ParameterList params;
83+
REQUIRE_THROWS(diag_factory.create("HorizAvgDiag", comm,
84+
params)); // No 'field_name' parameter
85+
86+
// Set time for qc and randomize its values
87+
qc1.get_header().get_tracking().update_time_stamp(t0);
88+
qc2.get_header().get_tracking().update_time_stamp(t0);
89+
qc3.get_header().get_tracking().update_time_stamp(t0);
90+
randomize(qc1, engine, pdf);
91+
randomize(qc2, engine, pdf);
92+
randomize(qc3, engine, pdf);
93+
94+
// Create and set up the diagnostic
95+
params.set("grid_name", grid->name());
96+
params.set<std::string>("field_name", "qc");
97+
auto diag1 = diag_factory.create("HorizAvgDiag", comm, params);
98+
auto diag2 = diag_factory.create("HorizAvgDiag", comm, params);
99+
auto diag3 = diag_factory.create("HorizAvgDiag", comm, params);
100+
diag1->set_grids(gm);
101+
diag2->set_grids(gm);
102+
diag3->set_grids(gm);
103+
104+
auto area = grid->get_geometry_data("area");
105+
106+
diag1->set_required_field(qc1);
107+
diag1->initialize(t0, RunType::Initial);
108+
109+
diag1->compute_diagnostic();
110+
auto diag1_f = diag1->get_diagnostic();
111+
112+
FieldIdentifier diag0_fid("qc_horiz_avg_manual",
113+
scalar1d_layout.clone().strip_dim(COL), kg / kg,
114+
grid->name());
115+
Field diag0(diag0_fid);
116+
diag0.allocate_view();
117+
auto diag0_v = diag0.get_view<Real>();
118+
119+
auto qc1_v = qc1.get_view<Real *>();
120+
auto area_v = area.get_view<const Real *>();
121+
122+
// calculate total area
123+
Real atot = field_sum<Real>(area, &comm);
124+
// calculate weighted avg
125+
Real wavg = sp(0.0);
126+
Kokkos::parallel_reduce(
127+
"HorizAvgDiag::compute_diagnostic_impl::weighted_sum", ngcols,
128+
KOKKOS_LAMBDA(const int icol, Real &local_wavg) {
129+
local_wavg += (area_v[icol] / atot) * qc1_v[icol];
130+
},
131+
wavg);
132+
Kokkos::deep_copy(diag0_v, wavg);
133+
134+
diag1_f.sync_to_host();
135+
auto diag1_v_h = diag1_f.get_view<Real, Host>();
136+
REQUIRE(diag1_v_h() == wavg);
137+
138+
// Try known cases
139+
// Set qc1_v to 1.0 to get weighted average of 1.0
140+
wavg = sp(1.0);
141+
Kokkos::deep_copy(qc1_v, wavg);
142+
diag1->compute_diagnostic();
143+
auto diag1_v2_host = diag1_f.get_view<Real, Host>();
144+
REQUIRE_THAT(diag1_v2_host(),
145+
Catch::Matchers::WithinRel(
146+
wavg, tol)); // Catch2's floating point comparison
147+
148+
// other diags
149+
// Set qc2_v to 5.0 to get weighted average of 5.0
150+
wavg = sp(5.0);
151+
auto qc2_v = qc2.get_view<Real **>();
152+
Kokkos::deep_copy(qc2_v, wavg);
153+
154+
diag2->set_required_field(qc2);
155+
diag2->initialize(t0, RunType::Initial);
156+
diag2->compute_diagnostic();
157+
auto diag2_f = diag2->get_diagnostic();
158+
159+
auto diag2_v_host = diag2_f.get_view<Real *, Host>();
160+
161+
for(int i = 0; i < nlevs; ++i) {
162+
REQUIRE_THAT(diag2_v_host(i), Catch::Matchers::WithinRel(wavg, tol));
163+
}
164+
165+
auto qc3_v = qc3.get_view<Real ***>();
166+
FieldIdentifier diag3_manual_fid("qc_horiz_avg_manual",
167+
scalar3d_layout.clone().strip_dim(COL),
168+
kg / kg, grid->name());
169+
Field diag3_manual(diag3_manual_fid);
170+
diag3_manual.allocate_view();
171+
auto diag3_manual_v = diag3_manual.get_view<Real **>();
172+
// calculate diag3_manual by hand
173+
auto p = ESU::get_default_team_policy(dim3 * nlevs, ngcols);
174+
Kokkos::parallel_for(
175+
"HorizAvgDiag::compute_diagnostic_impl::manual_diag3", p,
176+
KOKKOS_LAMBDA(const TeamMember &m) {
177+
const int idx = m.league_rank();
178+
const int j = idx / nlevs;
179+
const int k = idx % nlevs;
180+
Real sum = sp(0.0);
181+
Kokkos::parallel_reduce(
182+
Kokkos::TeamThreadRange(m, ngcols),
183+
[&](const int icol, Real &accum) {
184+
accum += (area_v(icol) / atot) * qc3_v(icol, j, k);
185+
},
186+
sum);
187+
Kokkos::single(Kokkos::PerTeam(m),
188+
[&]() { diag3_manual_v(j, k) = sum; });
189+
});
190+
diag3->set_required_field(qc3);
191+
diag3->initialize(t0, RunType::Initial);
192+
diag3->compute_diagnostic();
193+
auto diag3_f = diag3->get_diagnostic();
194+
REQUIRE(views_are_equal(diag3_f, diag3_manual));
195+
}
196+
197+
} // namespace scream

0 commit comments

Comments
 (0)