diff --git a/fme/ace/aggregator/inference/test_zonal_mean.py b/fme/ace/aggregator/inference/test_zonal_mean.py
index 6fc159cbb..9cc7f9771 100644
--- a/fme/ace/aggregator/inference/test_zonal_mean.py
+++ b/fme/ace/aggregator/inference/test_zonal_mean.py
@@ -176,6 +176,38 @@ def test_zonal_mean_time_coarsening(n_time):
     )
 
 
+def test_zonal_mean_time_coarsening_25_steps_per_window():
+    """Regression: forward_steps_in_memory=25 with factor=2 causes 13 vs 12
+    size mismatch.
+
+    The bug occurs when there is no buffer and i_time_start is not aligned to the
+    coarsening factor. E.g. i_time_start=1, 25 steps: original code uses
+    time_slice length (1+25)//2 - 0 = 13 but _coarsen_tensor(25 steps) returns 12.
+    We must start at i_time_start=1 so the first batch has no buffer.
+    """
+    n_sample, ny, nx = 3, 10, 20
+    n_time = 100
+    window_steps = 25
+    # Factor 2: first batch at i_time_start=1 has no buffer -> 13 vs 12 mismatch
+    agg = ZonalMeanAggregator(
+        zonal_mean,
+        n_timesteps=n_time,
+        zonal_mean_max_size=50,  # ceil(100/50)=2
+    )
+    assert agg.time_coarsening_factor == 2
+    # Start at 1 so first batch triggers the bug (no buffer, misaligned i_time_start)
+    for i_time_start in range(1, n_time, window_steps):
+        steps = min(window_steps, n_time - i_time_start)
+        arr = torch.arange(ny, dtype=torch.float32, device=get_device())
+        arr = arr[None, None, :, None].expand(n_sample, steps, ny, nx)
+        agg.record_batch(
+            {"a": arr}, {"a": arr}, {"a": arr}, {"a": arr}, i_time_start=i_time_start
+        )
+    assert agg._target_data is not None
+    assert agg._gen_data is not None
+    assert agg._target_data["a"].shape[1] == n_time // agg.time_coarsening_factor
+
+
 @pytest.mark.parametrize("zonal_mean_max_size", [4, 2**14, 2**16])
 def test_zonal_mean_time_coarsening_override(zonal_mean_max_size):
     n_time = 2**16
diff --git a/fme/ace/aggregator/inference/zonal_mean.py b/fme/ace/aggregator/inference/zonal_mean.py
index 9a2058c01..8529ebcf7 100644
--- a/fme/ace/aggregator/inference/zonal_mean.py
+++ b/fme/ace/aggregator/inference/zonal_mean.py
@@ -173,21 +173,19 @@ def record_batch(
         # if we have a buffer that means we didnt record the last batch
         if self._buffer_gen:
             start_idx = self.last_step
+            buffer_size = (
+                i_time_start + window_steps
+            ) - self.last_step * self.time_coarsening_factor
         else:
             start_idx = i_time_start // self.time_coarsening_factor
-
-        time_slice = slice(
-            start_idx,
-            (i_time_start + window_steps) // self.time_coarsening_factor,
-        )
-
-        self.last_step = (i_time_start + window_steps) // self.time_coarsening_factor
-
-        buffer_size = (
-            i_time_start + window_steps
-        ) - self.last_step * self.time_coarsening_factor
+            buffer_size = (i_time_start + window_steps) - (
+                (i_time_start + window_steps) // self.time_coarsening_factor
+            ) * (self.time_coarsening_factor)
 
         buffer = {}
+        time_slice = None
+        # fallback; overwritten below from the actual coarsened window size
+        new_buffer_size = buffer_size
         for name, tensor in target_data.items():
             if name in self._target_data:
                 if self._buffer_target:
@@ -198,11 +196,19 @@ def record_batch(
                         ],
                         dim=self._time_dim,
                     )
-                self._target_data[name][:, time_slice, :] += self._coarsen_tensor(
-                    self._zonal_mean(tensor)
-                )
-                if buffer_size > 0:
-                    buffer[name] = tensor[:, -buffer_size:, :]
+                coarsened = self._coarsen_tensor(self._zonal_mean(tensor))
+                if time_slice is None:
+                    # Use actual coarsened size so slice matches when i_time_start is
+                    # misaligned with coarsening factor
+                    n_coarsened = coarsened.shape[self._time_dim]
+                    time_slice = slice(start_idx, start_idx + n_coarsened)
+                    self.last_step = start_idx + n_coarsened
+                    new_buffer_size = (
+                        i_time_start + window_steps
+                    ) - self.last_step * self.time_coarsening_factor
+                self._target_data[name][:, time_slice, :] += coarsened
+                if new_buffer_size > 0:
+                    buffer[name] = tensor[:, -new_buffer_size:, :]
         self._buffer_target = buffer
 
         buffer = {}
@@ -216,11 +222,10 @@ def record_batch(
                         ],
                         dim=self._time_dim,
                     )
-                self._gen_data[name][:, time_slice, :] += self._coarsen_tensor(
-                    self._zonal_mean(tensor)
-                )
-                if buffer_size > 0:
-                    buffer[name] = tensor[:, -buffer_size:, :]
+                coarsened = self._coarsen_tensor(self._zonal_mean(tensor))
+                self._gen_data[name][:, time_slice, :] += coarsened
+                if new_buffer_size > 0:
+                    buffer[name] = tensor[:, -new_buffer_size:, :]
         self._buffer_gen = buffer
 
         self._n_batches[:, time_slice, :] += 1