Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- name: Set up python
uses: actions/setup-python@v5
with:
python-version: 3.11
python-version: 3.13
- name: Runs pre-commit
run: |
pip install pre-commit
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]

steps:
- uses: actions/checkout@v4
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,4 @@ cython_debug/

# netCDF files
*.nc
*.nc.*
70 changes: 70 additions & 0 deletions modelmimic/config/MVKO.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@

# MVKO test configurations for modelmimic: three scenarios — pass_b4b,
# pass_nb4b, and fail — each defining a "baseline" and a "test" ensemble of
# synthetic MPAS-Ocean ("mpaso") climatology history files.

# Scenario expected to pass bit-for-bit: both ensembles are generated with
# seed = true, so baseline and test data should be identical.
[pass_b4b]
# Shape of each generated field; `dims` names the corresponding axes
# (time, nCells, nVertLevels) — presumably one value per `size` entry.
size = [1, 5, 2]
dims = ["time", "nCells", "nVertLevels"]
# Variable names to synthesize in each history file.
variables = [
"timeClimatology_avg_activeTracers_salinity",
"timeClimatology_avg_activeTracers_temperature",
"timeClimatology_avg_ssh",
"timeClimatology_avg_velocityMeridional",
"timeClimatology_avg_velocityZonal",
]
# Number of output time samples per instance.
ntimes = 2
# Frequency key for file timestamps; "yearly-year-month-day" produces
# yearly steps formatted as YYYY-MM-DD file-name suffixes.
timestep = "yearly-year-month-day"
# Number of ensemble instance members (fills {inst:04d} below, 1-based).
ninst = 30
# History file name template; {inst} is the instance number, {time} the
# formatted timestamp from `timestep`.
hist_file_fmt = "mpaso_{inst:04d}.hist.am.timeSeriesStatsClimatology.{time}"
# The two ensembles to generate for this scenario.
ensembles = ["baseline", "test"]

[pass_b4b.baseline]
# `ensemble` is forwarded as keyword arguments to ensemble generation;
# seed = true presumably fixes the RNG seed — TODO confirm.
ensemble = { seed = true }
name = "mvko_pass_b4b_base"

[pass_b4b.test]
ensemble = { seed = true }
name = "mvko_pass_b4b_test"

# Scenario expected to pass statistically but NOT bit-for-bit: unseeded
# (seed = false) draws from the same population for both ensembles.
[pass_nb4b]
size = [1, 5, 2]
dims = ["time", "nCells", "nVertLevels"]
variables = [
"timeClimatology_avg_activeTracers_salinity",
"timeClimatology_avg_activeTracers_temperature",
"timeClimatology_avg_ssh",
"timeClimatology_avg_velocityMeridional",
"timeClimatology_avg_velocityZonal",
]

ntimes = 2
timestep = "yearly-year-month-day"
ninst = 30
hist_file_fmt = "mpaso_{inst:04d}.hist.am.timeSeriesStatsClimatology.{time}"
ensembles = ["baseline", "test"]
[pass_nb4b.baseline]
ensemble = { seed = false }
name = "mvko_pass_nb4b_base"
[pass_nb4b.test]
ensemble = { seed = false }
name = "mvko_pass_nb4b_test"

# Scenario expected to FAIL: the two ensembles are drawn with different
# population means (1.0 vs 2.0), so a statistical comparison should detect
# the difference.
[fail]
size = [1, 5, 2]
dims = ["time", "nCells", "nVertLevels"]
variables = [
"timeClimatology_avg_activeTracers_salinity",
"timeClimatology_avg_activeTracers_temperature",
"timeClimatology_avg_ssh",
"timeClimatology_avg_velocityMeridional",
"timeClimatology_avg_velocityZonal",
]

ntimes = 2
timestep = "yearly-year-month-day"
ninst = 30
hist_file_fmt = "mpaso_{inst:04d}.hist.am.timeSeriesStatsClimatology.{time}"
ensembles = ["baseline", "test"]
[fail.baseline]
ensemble = { seed = false, popmean = 1.0 }
name = "mvko_fail_base"
[fail.test]
ensemble = { seed = false, popmean = 2.0 }
name = "mvko_fail_test"
57 changes: 57 additions & 0 deletions modelmimic/config/MVKxx.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@

# MVKxx test configurations for modelmimic: three scenarios — pass_b4b,
# pass_nb4b, and fail — each defining a "baseline" and a "test" ensemble of
# synthetic SCREAM ("scream") history files. No `dims`/`timestep` keys are
# given here, so the consumer's defaults apply (dims ("nlev", "ncol"),
# monthly timestep — TODO confirm against modelmimic defaults).

# Scenario expected to pass bit-for-bit: both ensembles are generated with
# seed = true, so baseline and test data should be identical.
[pass_b4b]
# Shape of each generated field (presumably nlev x ncol — TODO confirm).
size = [3, 5]
# Variable names to synthesize in each history file.
variables = [
"SW_flux_up_at_model_top",
"LW_flux_up_at_model_top",
"SW_flux_dn_at_model_top",
"LiqWaterPath",
]
# Number of output time samples per instance.
ntimes = 12
# Number of ensemble instance members (fills {inst:04d} below, 1-based).
ninst = 30
# History file name template; {inst} is the instance number, {time} the
# formatted timestamp.
hist_file_fmt = "scream_{inst:04d}.h.{time}"
# The two ensembles to generate for this scenario.
ensembles = ["baseline", "test"]
[pass_b4b.baseline]
# `ensemble` is forwarded as keyword arguments to ensemble generation;
# seed = true presumably fixes the RNG seed — TODO confirm.
ensemble = { seed = true }
name = "mvkxx_pass_b4b_base"
[pass_b4b.test]
ensemble = { seed = true }
name = "mvkxx_pass_b4b_test"

# Scenario expected to pass statistically but NOT bit-for-bit: unseeded
# (seed = false) draws from the same population for both ensembles.
[pass_nb4b]
size = [3, 5]
variables = [
"SW_flux_up_at_model_top",
"LW_flux_up_at_model_top",
"SW_flux_dn_at_model_top",
"LiqWaterPath",
]
ntimes = 12
ninst = 30
hist_file_fmt = "scream_{inst:04d}.h.{time}"
ensembles = ["baseline", "test"]
[pass_nb4b.baseline]
ensemble = { seed = false }
name = "mvkxx_pass_nb4b_base"
[pass_nb4b.test]
ensemble = { seed = false }
name = "mvkxx_pass_nb4b_test"

# Scenario expected to FAIL: the two ensembles are drawn with different
# population means (1.0 vs 2.0), so a statistical comparison should detect
# the difference.
[fail]
size = [3, 5]
variables = [
"SW_flux_up_at_model_top",
"LW_flux_up_at_model_top",
"SW_flux_dn_at_model_top",
"LiqWaterPath",
]
ntimes = 12
ninst = 30
hist_file_fmt = "scream_{inst:04d}.h.{time}"
ensembles = ["baseline", "test"]
[fail.baseline]
ensemble = { seed = false, popmean = 1.0 }
name = "mvkxx_fail_base"
[fail.test]
ensemble = { seed = false, popmean = 2.0 }
name = "mvkxx_fail_test"
82 changes: 63 additions & 19 deletions modelmimic/mimic.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ def gen_hybrid_pres(
in the TSC (time step convergence) test

"""
nlev, ncol = size
if len(size) == 2:
nlev, ncol = size
elif len(size) == 3:
_, nlev, ncol = size
p_0 = 100000

if seed is not None:
Expand Down Expand Up @@ -262,6 +265,10 @@ def make_ensemble(
_area, _ = gen_field(self.size[-1:])
_area = norm(_area)

# Last dimension of self.size should be number of columns
lat = np.linspace(-np.pi / 2, np.pi / 2, self.size[-1])
lon = np.linspace(0, 2 * np.pi, self.size[-1])

# 2/3 ocean, 1/3 land
_landfrac = np.zeros(self.size[-1])
_landfrac[: self.size[-1] // 3] = 1.0
Expand All @@ -288,6 +295,8 @@ def make_ensemble(
ens_data[iinst]["hybi"] = hybi
ens_data[iinst]["LANDFRAC"] = _landfrac
ens_data[iinst]["area"] = _area
ens_data[iinst]["lon"] = lon
ens_data[iinst]["lat"] = lat

self.ens_data = ens_data

Expand Down Expand Up @@ -319,6 +328,16 @@ def get_file_times(
f"{sim_start}-{istep:05d}"
for istep in range(0, step_mult * self.ntimes, step_mult)
]
elif timestep.lower() == "yearly-year-month-day":
file_times = pd.date_range(
start=sim_start, periods=self.ntimes, freq="YS", unit="s"
)
file_times = [_time.strftime("%04Y-%m-%d") for _time in file_times]
elif timestep.lower() == "year":
file_times = pd.date_range(
start=sim_start, periods=self.ntimes, freq="YS", unit="s"
)
file_times = [_time.strftime("%04Y") for _time in file_times]
else:
raise NotImplementedError(f"FREQ: {timestep} NOT YET IMPLEMENTED")
return file_times
Expand All @@ -330,7 +349,9 @@ def write_to_nc(
timestep: str = "month",
step_mult: int = 1,
hist_file_pattern: str = "eam_{inst:04d}.h0.{time}",
casename: bool = True,
file_suffix: str = None,
**kwargs,
):
"""Write generated data to a netCDF file."""
# Make an xarray.Dataset for each instance so it can be written to a file.
Expand Down Expand Up @@ -361,33 +382,46 @@ def write_to_nc(
if file_suffix is not None:
extn = f"{extn}.{file_suffix}"

_time_suffix = kwargs.get("time_suffix", None)
if _time_suffix is None:
_time_suffix = ""

# TODO: Parallelize this
for iinst in self.ens_data:
for itime in range(self.ntimes):
_outfile_name = hist_file_pattern.format(
inst=(iinst + 1), time=file_times[itime]
inst=(iinst + 1), time=file_times[itime], time_suffix=_time_suffix
)
data_vars = {}
for _var in self.vars:
data_vars[_var] = (self.dims, self.ens_data[iinst][_var][itime])
data_vars["PS"] = (self.dims[-1:], self.ens_data[iinst]["PS"])
data_vars["area"] = (self.dims[-1:], self.ens_data[iinst]["area"])
data_vars["LANDFRAC"] = (
("time", self.dims[-1]),
np.expand_dims(self.ens_data[iinst]["LANDFRAC"], 0),
)

data_vars["hyai"] = (self.dims[:1], self.ens_data[iinst]["hyai"])
data_vars["hybi"] = (self.dims[:1], self.ens_data[iinst]["hybi"])
data_vars["P0"] = self.ens_data[iinst]["P0"]

_dset = xr.Dataset(
data_vars=data_vars,
coords=coords,
attrs={**ds_attrs, "inst": iinst},
)
# 1-D variables only x/y
for _var in ["PS", "area", "lon", "lat"]:
data_vars[_var] = (self.dims[-1:], self.ens_data[iinst][_var])

if "mvko" not in self.name:
data_vars["LANDFRAC"] = (
("time", self.dims[-1]),
np.expand_dims(self.ens_data[iinst]["LANDFRAC"], 0),
)
data_vars["hyai"] = (self.dims[:1], self.ens_data[iinst]["hyai"])
data_vars["hybi"] = (self.dims[:1], self.ens_data[iinst]["hybi"])
data_vars["P0"] = self.ens_data[iinst]["P0"]
try:
_dset = xr.Dataset(
data_vars=data_vars,
coords=coords,
attrs={**ds_attrs, "inst": iinst},
)
except ValueError:
breakpoint()
ens_xarray[iinst] = _dset
_out_file = Path(out_path, f"{self.name}.{_outfile_name}.{extn}")
if casename:
_filename = f"{self.name}.{_outfile_name}.{extn}"
else:
_filename = f"{_outfile_name}.{extn}"
_out_file = Path(out_path, _filename)
_dset.to_netcdf(
_out_file,
unlimited_dims="time",
Expand Down Expand Up @@ -441,27 +475,37 @@ def main(args):
variables=_test["variables"],
ntimes=_test["ntimes"],
ninst=_test["ninst"],
dims=_test.get("dims", ("nlev", "ncol")),
)
mimic_case.make_ensemble(**_test[case]["ensemble"])

# Defaults for file output
out_path = Path("./data", _testname, mimic_case.name)
file_suffix = None
step_mult = 1
timestep = "month"
# Will be overridden if set in "to_nc" below
timestep = _test.get("timestep", "month")
casename = True

if "to_nc" in _test[case]:
# If the to_nc has variables in config file, use that, otherwise use default
out_path = Path(_test[case]["to_nc"].get("out_path", out_path))
file_suffix = _test[case]["to_nc"].get("file_suffix", file_suffix)
timestep = _test[case]["to_nc"].get("timestep", timestep)
step_mult = _test[case]["to_nc"].get("step_mult", step_mult)
casename = _test[case]["to_nc"].get("casename", casename)
time_suffix = _test[case]["to_nc"].get("time_suffix", None)
else:
time_suffix = ""

mimic_case.write_to_nc(
out_path=out_path,
hist_file_pattern=_test["hist_file_fmt"],
casename=casename,
file_suffix=file_suffix,
timestep=timestep,
step_mult=step_mult,
time_suffix=time_suffix,
)
out_dirs[_testname][case] = out_path

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ classifiers = [
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
]
dependencies = ["numpy", "xarray", "netCDF4", "pandas", "toml"]
dynamic = ["version"]
Expand Down
32 changes: 32 additions & 0 deletions tests/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,38 @@ def test_file_times():
]
assert _times == step_times

_times = gen.get_file_times("0001-01-01", timestep="year")
assert _times == [
"0001",
"0002",
"0003",
"0004",
"0005",
"0006",
"0007",
"0008",
"0009",
"0010",
"0011",
"0012",
]

_times = gen.get_file_times("0001-01-01", timestep="yearly-year-month-day")
assert _times == [
"0001-01-01",
"0002-01-01",
"0003-01-01",
"0004-01-01",
"0005-01-01",
"0006-01-01",
"0007-01-01",
"0008-01-01",
"0009-01-01",
"0010-01-01",
"0011-01-01",
"0012-01-01",
]


def test_to_netcdf():
ntimes = 12
Expand Down