Skip to content
32 changes: 32 additions & 0 deletions fme/ace/aggregator/inference/test_zonal_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,38 @@ def test_zonal_mean_time_coarsening(n_time):
)


def test_zonal_mean_time_coarsening_25_steps_per_window():
    """
    Regression: forward_steps_in_memory=25 with factor=2 causes 13 vs 12
    size mismatch.

    The bug occurs when there is no buffer and i_time_start is not aligned to the
    coarsening factor. E.g. i_time_start=1, 25 steps: original code uses
    time_slice length (1+25)//2 - 0 = 13 but _coarsen_tensor(25 steps) returns 12.
    We must start at i_time_start=1 so the first batch has no buffer.
    """
    n_sample, ny, nx = 3, 10, 20
    n_time = 100
    window_steps = 25
    # Factor 2: first batch at i_time_start=1 has no buffer -> 13 vs 12 mismatch
    agg = ZonalMeanAggregator(
        zonal_mean,
        n_timesteps=n_time,
        zonal_mean_max_size=50,  # ceil(100/50)=2
    )
    assert agg.time_coarsening_factor == 2
    # Start at 1 so the first batch triggers the bug (no buffer, misaligned
    # i_time_start); subsequent windows exercise the buffered carry-over path.
    for i_time_start in range(1, n_time, window_steps):
        steps = min(window_steps, n_time - i_time_start)
        # Latitude-indexed values broadcast over (sample, time, lat, lon) so the
        # zonal mean is well-defined and nonzero.
        arr = torch.arange(ny, dtype=torch.float32, device=get_device())
        arr = arr[None, None, :, None].expand(n_sample, steps, ny, nx)
        agg.record_batch(
            {"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=i_time_start
        )
    assert agg._target_data is not None
    assert agg._gen_data is not None
    # The recorded time axis must match the coarsened length exactly; the bug
    # manifested as a shape mismatch inside record_batch before reaching here.
    assert agg._target_data["a"].shape[1] == n_time // agg.time_coarsening_factor


@pytest.mark.parametrize("zonal_mean_max_size", [4, 2**14, 2**16])
def test_zonal_mean_time_coarsening_override(zonal_mean_max_size):
n_time = 2**16
Expand Down
45 changes: 24 additions & 21 deletions fme/ace/aggregator/inference/zonal_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,21 +173,17 @@ def record_batch(
# if we have a buffer that means we didnt record the last batch
if self._buffer_gen:
start_idx = self.last_step
buffer_size = (
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't confidently review this file, the code is hard for me to understand. The for-if statements are nested up to 4 levels, and there's communication I don't understand happening between the first outer loop (for target_data) and the second outer loop (for gen_data). I also don't really know what "buffer" means in this context.

We chatted on Slack about potentially refactoring this to a helper function and cleaning it up before merging this, since this isn't a bug that usually affects us.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like the buffer is some residual that gets patched on to the start of the next coarsen? If so, a model like

def coarsen(data: T1, prefix_data: T2 | None) -> tuple[T1, T2 | None]:

for some suitable types T1 and T2 would be a good pattern to use

i_time_start + window_steps
) - self.last_step * self.time_coarsening_factor
else:
start_idx = i_time_start // self.time_coarsening_factor

time_slice = slice(
start_idx,
(i_time_start + window_steps) // self.time_coarsening_factor,
)

self.last_step = (i_time_start + window_steps) // self.time_coarsening_factor

buffer_size = (
i_time_start + window_steps
) - self.last_step * self.time_coarsening_factor
buffer_size = (i_time_start + window_steps) - (
(i_time_start + window_steps) // self.time_coarsening_factor
) * (self.time_coarsening_factor)

buffer = {}
time_slice = None
for name, tensor in target_data.items():
if name in self._target_data:
if self._buffer_target:
Expand All @@ -198,11 +194,19 @@ def record_batch(
],
dim=self._time_dim,
)
self._target_data[name][:, time_slice, :] += self._coarsen_tensor(
self._zonal_mean(tensor)
)
if buffer_size > 0:
buffer[name] = tensor[:, -buffer_size:, :]
coarsened = self._coarsen_tensor(self._zonal_mean(tensor))
if time_slice is None:
# Use actual coarsened size so slice matches when i_time_start is
# misaligned with coarsening factor
n_coarsened = coarsened.shape[self._time_dim]
time_slice = slice(start_idx, start_idx + n_coarsened)
self.last_step = start_idx + n_coarsened
new_buffer_size = (
i_time_start + window_steps
) - self.last_step * self.time_coarsening_factor
self._target_data[name][:, time_slice, :] += coarsened
if new_buffer_size > 0:
buffer[name] = tensor[:, -new_buffer_size:, :]
self._buffer_target = buffer

buffer = {}
Expand All @@ -216,11 +220,10 @@ def record_batch(
],
dim=self._time_dim,
)
self._gen_data[name][:, time_slice, :] += self._coarsen_tensor(
self._zonal_mean(tensor)
)
if buffer_size > 0:
buffer[name] = tensor[:, -buffer_size:, :]
coarsened = self._coarsen_tensor(self._zonal_mean(tensor))
self._gen_data[name][:, time_slice, :] += coarsened
if new_buffer_size > 0:
buffer[name] = tensor[:, -new_buffer_size:, :]
self._buffer_gen = buffer

self._n_batches[:, time_slice, :] += 1
Expand Down