Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fme/core/dataset/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,8 @@ def __getitem__(self, idx: int) -> DatasetItem:
"""
time_slice = slice(idx, idx + self.sample_n_times)
time = xr.DataArray(self.all_times[time_slice].values, dims=["time"])
return (self._dummy_dict, time, self._labels, self._epoch)
data = {k: v[: len(time)] for k, v in self._dummy_dict.items()}
return (data, time, self._labels, self._epoch)

def set_epoch(self, epoch: int):
    """Switch the dataset to *epoch*'s sample window length.

    Looks up the number of timesteps scheduled for this epoch and
    re-applies it (presumably updating ``self.sample_n_times`` used by
    ``__getitem__`` — confirm against ``_apply_sample_n_times``).
    """
    self._apply_sample_n_times(self._n_timesteps_schedule.get_value(epoch))
Expand Down
35 changes: 28 additions & 7 deletions fme/core/dataset/test_dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,25 +10,46 @@
from fme.core.dataset.schedule import IntSchedule


def test_dummy_dataset_has_expected_information():
def _make_dummy_dataset(
n_total_times: int = 10,
n_timesteps: int = 3,
) -> DummyDataset:
start_time = cftime.DatetimeGregorian(2000, 1, 1)
end_time = cftime.DatetimeGregorian(2000, 1, 10)
end_time = start_time + datetime.timedelta(days=n_total_times - 1)
timestep = datetime.timedelta(days=1)
n_timesteps = IntSchedule.from_constant(3)
schedule = IntSchedule.from_constant(n_timesteps)
horizontal_coordinates = LatLonCoordinates(
lat=torch.Tensor(np.arange(12)),
lon=torch.Tensor(np.arange(6)),
)
dataset = DummyDataset(
return DummyDataset(
start_time=start_time,
end_time=end_time,
timestep=timestep,
n_timesteps=n_timesteps,
n_timesteps=schedule,
horizontal_coordinates=horizontal_coordinates,
)


def test_dummy_dataset_has_expected_information():
dataset = _make_dummy_dataset(n_total_times=10, n_timesteps=3)
assert isinstance(dataset.all_times, xr.CFTimeIndex)
assert len(dataset.all_times) == 10
assert dataset.all_times[0] == start_time
assert dataset.all_times[-1] == end_time
assert dataset.all_times[0] == cftime.DatetimeGregorian(2000, 1, 1)
assert dataset.all_times[-1] == cftime.DatetimeGregorian(2000, 1, 10)
assert len(dataset) == 8
assert dataset[0][0]["__dummy__"].shape == (3, 12, 6)


def test_dummy_dataset_getitem_truncates_near_end():
    """Windows that run past the end of the time axis are truncated.

    With 10 total times and a 7-step sample window, only indices 0..3
    yield a full window; indexing past ``len(dataset)`` returns a
    shorter, truncated window (data and time stay the same length)
    rather than raising.
    """
    dataset = _make_dummy_dataset(n_total_times=10, n_timesteps=7)
    # 10 times total, sample_n_times=7, so valid full-window indices are 0..3
    assert len(dataset) == 4
    data, time, _, _ = dataset[0]
    assert data["__dummy__"].shape[0] == 7
    assert len(time) == 7

    # Access beyond the last valid index — the last window is truncated
    # (times 5..9 remain, so both data and time have length 5).
    data, time, _, _ = dataset[5]
    assert len(time) == 5
    assert data["__dummy__"].shape[0] == 5
17 changes: 15 additions & 2 deletions fme/coupled/data_loading/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _get_batch_data(self, index) -> CoupledBatchData:
continue
i_window_start = i_start + self._start_indices[i_member]
samples.append(self._dataset[i_window_start])
return CoupledBatchData.collate_fn(
result = CoupledBatchData.collate_fn(
samples,
ocean_horizontal_dims=list(
self.properties.horizontal_coordinates.ocean.dims
Expand All @@ -182,6 +182,16 @@ def _get_batch_data(self, index) -> CoupledBatchData:
ocean_label_encoding=None,
atmosphere_label_encoding=None,
)
remaining_steps = self._total_coupled_steps - i_start
actual_coupled_steps = min(self._coupled_steps_in_memory, remaining_steps)
ocean_n_times = actual_coupled_steps + 1
atmos_n_times = actual_coupled_steps * self._properties.n_inner_steps + 1
return CoupledBatchData(
ocean_data=result.ocean_data.select_time_slice(slice(0, ocean_n_times)),
atmosphere_data=result.atmosphere_data.select_time_slice(
slice(0, atmos_n_times)
),
)

def __getitem__(self, index) -> CoupledBatchData:
dist = Distributed.get_instance()
Expand Down Expand Up @@ -247,9 +257,12 @@ def _make_dummy_ocean_forcing(
all_labels=set(),
)
ts = dataset_info.ocean.timestep
coupled_steps_in_memory = ocean_reqs.n_timesteps_schedule.get_value(0) - 1
n_windows = ceil(total_coupled_steps / coupled_steps_in_memory)
padded_coupled_steps = n_windows * coupled_steps_in_memory
ocean = DummyDataset(
start_time=initial_time.squeeze().values.flat[0],
end_time=initial_time.squeeze().values.flat[-1] + ts * total_coupled_steps,
end_time=initial_time.squeeze().values.flat[-1] + ts * padded_coupled_steps,
timestep=ts,
n_timesteps=ocean_reqs.n_timesteps_schedule,
horizontal_coordinates=dataset_info.ocean.horizontal_coordinates,
Expand Down
16 changes: 13 additions & 3 deletions fme/coupled/inference/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dataclasses
import os
import pathlib
from math import ceil

import numpy as np
import pytest
Expand Down Expand Up @@ -50,7 +51,9 @@ def _setup(
# create_coupled_data_on_disk already accounts for one initial condition
atmos_steps_per_ocean_step = 2
n_extra_initial_conditions = n_initial_conditions - 1
n_forward_times_ocean = n_coupled_steps + n_extra_initial_conditions
n_windows = ceil(n_coupled_steps / coupled_steps_in_memory)
padded_coupled_steps = n_windows * coupled_steps_in_memory
n_forward_times_ocean = padded_coupled_steps + n_extra_initial_conditions
n_forward_times_atmos = n_forward_times_ocean * atmos_steps_per_ocean_step
mock_data = create_coupled_data_on_disk(
data_dir,
Expand Down Expand Up @@ -220,15 +223,22 @@ def test_inference(
1,
],
)
@pytest.mark.parametrize(
("n_coupled_steps", "coupled_steps_in_memory"),
[
(2, 2),
(4, 3),
],
)
def test_inference_with_empty_ocean_forcing(
tmp_path: pathlib.Path,
atmosphere_times_offset: int,
n_coupled_steps: int,
coupled_steps_in_memory: int,
very_fast_only: bool,
):
if very_fast_only:
pytest.skip("Skipping non-fast tests")
n_coupled_steps = 2
coupled_steps_in_memory = 2
n_initial_conditions = 3
ocean_in_names = ["o_prog", "sst", "a_diag"]
ocean_out_names = ["o_prog", "sst", "o_diag"]
Expand Down