diff --git a/fme/coupled/loss.py b/fme/coupled/loss.py index de6ac841e..acd616dd8 100644 --- a/fme/coupled/loss.py +++ b/fme/coupled/loss.py @@ -29,9 +29,13 @@ class StepLossABC(abc.ABC): def effective_loss_scaling(self) -> TensorDict: ... @abc.abstractmethod - def step_is_optimized(self, step: int) -> bool: - """Returns True if the step is less than to the number of - steps contributing to the loss. + def step_is_optimized(self, step: int, n_total_steps: int | None = None) -> bool: + """Returns True if the given step should contribute to the loss. + + Args: + step: The step index to check. + n_total_steps: The total number of steps for this component. Required + when ``optimize_last_step_only`` is True. """ ... @@ -53,11 +57,16 @@ class LossContributionsConfig: starting from the first. weight: (optional) Weight applied to each step loss for the given realm. Each step contributes equally to the total loss. + optimize_last_step_only: If True, only the last step within the training + horizon defined by ``n_steps`` is optimized (i.e. contributes to the + loss and has gradients enabled). The optimized step index is + ``min(n_steps, n_total_steps) - 1``. """ n_steps: float = float("inf") weight: float = 1.0 + optimize_last_step_only: bool = False def build( self, @@ -69,6 +78,7 @@ def build( return LossContributions( n_steps=self.n_steps, weight=self.weight, + optimize_last_step_only=self.optimize_last_step_only, loss_obj=loss_obj, time_dim=time_dim, ) @@ -90,7 +100,7 @@ def __init__( def effective_loss_scaling(self) -> TensorDict: return self._loss.effective_loss_scaling - def step_is_optimized(self, step: int) -> bool: + def step_is_optimized(self, step: int, n_total_steps: int | None = None) -> bool: return False def __call__( @@ -104,30 +114,39 @@ def __init__( self, n_steps: float, weight: float, + optimize_last_step_only: bool, loss_obj: StepLoss, time_dim: int, ): self._loss = loss_obj self._n_steps = n_steps self._weight = weight + self._optimize_last_step_only = optimize_last_step_only self._time_dim = time_dim @property def effective_loss_scaling(self) -> TensorDict: return self._loss.effective_loss_scaling - def step_is_optimized(self, step: int) -> bool: - """Returns True if the step is less than to the number of steps and - weight is != 0. The first step number is assumed to be 0. + def step_is_optimized(self, step: int, n_total_steps: int | None = None) -> bool: + """Returns True if the step should contribute to the loss. + When ``optimize_last_step_only`` is False (default), returns True for + steps ``0`` through ``n_steps - 1``. When True, returns True only for + the step at index ``min(n_steps, n_total_steps) - 1``. """ - return step < self._n_steps and self._weight != 0.0 + if self._weight == 0.0: + return False + if self._optimize_last_step_only: + if n_total_steps is None: + raise ValueError( + "n_total_steps is required when optimize_last_step_only is True" + ) + last_optimized_step = min(self._n_steps, n_total_steps) - 1 + return step == last_optimized_step + return step < self._n_steps def __call__( self, prediction: StepPredictionABC, target_data: TensorMapping ) -> torch.Tensor: - if self.step_is_optimized(prediction.step): - return self._weight * self._loss( - prediction.data, target_data, prediction.step - ) - return torch.tensor(0.0, device=get_device()) + return self._weight * self._loss(prediction.data, target_data, prediction.step) diff --git a/fme/coupled/stepper.py b/fme/coupled/stepper.py index c4ab0ad70..50e9a315f 100644 --- a/fme/coupled/stepper.py +++ b/fme/coupled/stepper.py @@ -1,3 +1,4 @@ +import contextlib import dataclasses import datetime import logging @@ -768,13 +769,22 @@ def effective_loss_scaling(self) -> CoupledTensorMapping: atmosphere=self._loss_objs["atmosphere"].effective_loss_scaling, ) + def step_is_optimized( + self, + realm: Literal["ocean", "atmosphere"], + step: int, + n_total_steps: int, + ) -> bool: + return self._loss_objs[realm].step_is_optimized(step, n_total_steps) + def __call__( self, prediction: ComponentStepPrediction, target_data: TensorMapping, + n_total_steps: int | None = None, ) -> torch.Tensor | None: loss_obj = self._loss_objs[prediction.realm] - if loss_obj.step_is_optimized(prediction.step): + if loss_obj.step_is_optimized(prediction.step, n_total_steps): return loss_obj(prediction, target_data) return None @@ -1589,55 +1599,89 @@ def train_on_batch( metrics = ComponentStepMetrics() optimization.set_mode(self.modules) + n_outer_steps = data.ocean_data.n_timesteps - self.n_ic_timesteps + n_total_atmos_steps = n_outer_steps * self.n_inner_steps with optimization.autocast(): output_generator = self._stepper.get_prediction_generator( input_data, data_ensemble, optimization, ) - output_list = [] - for gen_step in output_generator: - if gen_step.realm == "ocean": - # compute ocean step metrics + output_list: list[ComponentStepPrediction] = [] + output_iterator = iter(output_generator) + + for i_outer in range(n_outer_steps): + for i_inner in range(self.n_inner_steps): + global_atmos_step = i_outer * self.n_inner_steps + i_inner + optimize = self._loss.step_is_optimized( + "atmosphere", + global_atmos_step, + n_total_atmos_steps, + ) + grad_context = ( + contextlib.nullcontext() if optimize else torch.no_grad() + ) + with grad_context: + gen_step = next(output_iterator) + target_step = { + k: v.select(self.atmosphere.TIME_DIM, gen_step.step) + for k, v in atmos_forward_data.data.items() + } + if n_ensemble > 1: + gen_step_data_unfolded = unfold_ensemble_dim( + gen_step.data, n_ensemble + ) + gen_step_for_loss = ComponentStepPrediction( + realm=gen_step.realm, + data=gen_step_data_unfolded, + step=gen_step.step, + ) + else: + gen_step_for_loss = gen_step + target_step_ensemble = add_ensemble_dim(target_step) + step_loss = self._loss( + gen_step_for_loss, + target_step_ensemble, + n_total_steps=n_total_atmos_steps, + ) + if step_loss is not None: + label = f"loss/{gen_step.realm}_step_{gen_step.step}" + metrics.add_metric( + label, step_loss.detach(), gen_step.realm + ) + if step_loss is not None: + optimization.accumulate_loss(step_loss) + if n_ensemble > 1: + gen_step_to_append = gen_step_for_loss.detach(optimization) + else: + gen_step_to_append = gen_step.detach(optimization) + output_list.append(gen_step_to_append) + + optimize = self._loss.step_is_optimized( + "ocean", + i_outer, + n_outer_steps, + ) + grad_context = contextlib.nullcontext() if optimize else torch.no_grad() + with grad_context: + gen_step = next(output_iterator) target_step = { k: v.select(self.ocean.TIME_DIM, gen_step.step) for k, v in ocean_forward_data.data.items() } - # Ocean predictions don't need ensemble handling gen_step_for_loss = gen_step - else: - assert gen_step.realm == "atmosphere" - target_step = { - k: v.select(self.atmosphere.TIME_DIM, gen_step.step) - for k, v in atmos_forward_data.data.items() - } - # Unfold ensemble dimension for atmosphere loss computation - if n_ensemble > 1: - gen_step_data_unfolded = unfold_ensemble_dim( - gen_step.data, n_ensemble - ) - gen_step_for_loss = ComponentStepPrediction( - realm=gen_step.realm, - data=gen_step_data_unfolded, - step=gen_step.step, - ) - else: - gen_step_for_loss = gen_step - # Add ensemble dim to target (single member) - target_step_ensemble = add_ensemble_dim(target_step) - step_loss = self._loss( - gen_step_for_loss, - target_step_ensemble, - ) + target_step_ensemble = add_ensemble_dim(target_step) + step_loss = self._loss( + gen_step_for_loss, + target_step_ensemble, + n_total_steps=n_outer_steps, + ) + if step_loss is not None: + label = f"loss/{gen_step.realm}_step_{gen_step.step}" + metrics.add_metric(label, step_loss.detach(), gen_step.realm) if step_loss is not None: - label = f"loss/{gen_step.realm}_step_{gen_step.step}" - metrics.add_metric(label, step_loss.detach(), gen_step.realm) optimization.accumulate_loss(step_loss) - # For atmosphere with ensemble, append the unfolded step - if gen_step.realm == "atmosphere" and n_ensemble > 1: - gen_step_to_append = gen_step_for_loss.detach(optimization) - else: - gen_step_to_append = gen_step.detach(optimization) + gen_step_to_append = gen_step.detach(optimization) output_list.append(gen_step_to_append) loss = optimization.get_accumulated_loss().detach() diff --git a/fme/coupled/test_loss.py b/fme/coupled/test_loss.py index d72abbbec..0ace5ea5e 100644 --- a/fme/coupled/test_loss.py +++ b/fme/coupled/test_loss.py @@ -65,7 +65,7 @@ def __init__( def effective_loss_scaling(self): raise NotImplementedError() - def step_is_optimized(self, step: int) -> bool: + def step_is_optimized(self, step: int, n_total_steps: int | None = None) -> bool: return step < 2 def __call__( @@ -154,6 +154,111 @@ def mae_loss(gen, target, step: int): assert_tensor_dicts_close(metrics, expected_metrics) +def test_loss_contributions_optimize_last_step_only(steps_thru_atmos_7): + def mae_loss(gen, target, step: int): + loss = torch.tensor(0.0) + for key in gen: + loss += (gen[key] - target[key]).abs().mean() / (step + 1) + return loss + + n_total_atmos = 8 + n_total_ocean = 4 + atmos_loss_config = LossContributionsConfig( + n_steps=6, + weight=1 / 3, + optimize_last_step_only=True, + ) + mock_step_loss = Mock(spec=StepLoss, side_effect=mae_loss) + atmosphere_loss = atmos_loss_config.build( + loss_obj=mock_step_loss, + time_dim=1, + ) + ocean_loss_config = LossContributionsConfig( + n_steps=3, + optimize_last_step_only=True, + ) + ocean_loss = ocean_loss_config.build( + loss_obj=Mock(spec=StepLoss, side_effect=mae_loss), + time_dim=1, + ) + loss_obj = CoupledStepperTrainLoss( + ocean_loss=ocean_loss, + atmosphere_loss=atmosphere_loss, + ) + metrics = {} + expected_metrics: dict[str, torch.Tensor | None] = {} + for prediction, target_data in steps_thru_atmos_7: + label = f"{prediction.realm}_{prediction.step}" + if prediction.realm == "atmosphere": + n_total = n_total_atmos + else: + n_total = n_total_ocean + metrics[label] = loss_obj(prediction, target_data, n_total_steps=n_total) + if prediction.realm == "atmosphere": + # n_steps=6, n_total=8 → last optimized step = min(6,8)-1 = 5 + if prediction.step == 5: + expected_metrics[label] = ( + mae_loss(prediction.data, target_data, step=prediction.step) / 3 + ) + else: + expected_metrics[label] = None + elif prediction.realm == "ocean": + # n_steps=3, n_total=4 → last optimized step = min(3,4)-1 = 2 + if prediction.step == 2: + expected_metrics[label] = mae_loss( + prediction.data, target_data, step=prediction.step + ) + else: + expected_metrics[label] = None + assert_tensor_dicts_close(metrics, expected_metrics) + + +@pytest.mark.parametrize( + "n_steps, n_total_steps, expected_optimized_step", + [ + (6, 8, 5), + (10, 8, 7), + (float("inf"), 8, 7), + (1, 1, 0), + (3, 3, 2), + ], +) +def test_step_is_optimized_last_step_only( + n_steps, n_total_steps, expected_optimized_step +): + config = LossContributionsConfig(n_steps=n_steps, optimize_last_step_only=True) + loss = config.build( + loss_obj=Mock(spec=StepLoss), + time_dim=1, + ) + for step in range(n_total_steps): + result = loss.step_is_optimized(step, n_total_steps) + if step == expected_optimized_step: + assert result, f"step {step} should be optimized" + else: + assert not result, f"step {step} should not be optimized" + + +def test_step_is_optimized_last_step_only_requires_n_total_steps(): + config = LossContributionsConfig(optimize_last_step_only=True) + loss = config.build( + loss_obj=Mock(spec=StepLoss), + time_dim=1, + ) + with pytest.raises(ValueError, match="n_total_steps is required"): + loss.step_is_optimized(0) + + +def test_step_is_optimized_last_step_only_weight_zero(): + config = LossContributionsConfig(optimize_last_step_only=True, weight=0.0) + loss = config.build( + loss_obj=Mock(spec=StepLoss), + time_dim=1, + ) + # weight=0 → NullLossContributions, always returns False + assert not loss.step_is_optimized(0, n_total_steps=5) + + @pytest.mark.parametrize("ocean_config_kwargs", [{"n_steps": 0}, {"weight": 0.0}]) def test_null_loss_contributions(steps_thru_atmos_7, ocean_config_kwargs): # test LossContributionsConfig with n_steps = 0 diff --git a/fme/coupled/test_stepper.py b/fme/coupled/test_stepper.py index 517ac747c..90fed0481 100644 --- a/fme/coupled/test_stepper.py +++ b/fme/coupled/test_stepper.py @@ -41,6 +41,7 @@ CoupledHorizontalCoordinates, CoupledVerticalCoordinate, ) +from .loss import LossContributionsConfig from .stepper import ( ComponentConfig, ComponentTrainingConfig, @@ -1625,3 +1626,101 @@ def test_set_train_eval(): stepper.set_train() for module in stepper.modules: assert module.training + + +@pytest.mark.parametrize("optimize_last_step_only", [True, False]) +def test_train_on_batch_optimize_last_step_only(optimize_last_step_only: bool): + torch.manual_seed(0) + n_forward_times_ocean = 2 + n_forward_times_atmosphere = 4 + + train_stepper_config = CoupledTrainStepperConfig( + ocean=ComponentTrainingConfig( + loss=StepLossConfig(type="MSE"), + loss_contributions=LossContributionsConfig( + optimize_last_step_only=optimize_last_step_only, + ), + ), + atmosphere=ComponentTrainingConfig( + loss=StepLossConfig(type="MSE"), + loss_contributions=LossContributionsConfig( + optimize_last_step_only=optimize_last_step_only, + ), + ), + ) + train_stepper, coupled_data, _, _ = get_train_stepper_and_batch( + train_stepper_config=train_stepper_config, + ocean_in_names=["sst", "mask_0"], + ocean_out_names=["sst"], + atmosphere_in_names=["surface_temperature", "ocean_fraction"], + atmosphere_out_names=["surface_temperature"], + n_forward_times_ocean=n_forward_times_ocean, + n_forward_times_atmosphere=n_forward_times_atmosphere, + n_samples=3, + ) + optimization = Mock(wraps=NullOptimization()) + train_stepper.train_on_batch( + data=coupled_data.data, + optimization=optimization, + ) + n_total_atmos = n_forward_times_atmosphere + n_total_ocean = n_forward_times_ocean + if optimize_last_step_only: + # only the last atmosphere step and last ocean step are optimized + assert len(optimization.accumulate_loss.call_args_list) == 2 + else: + # all atmosphere and ocean steps are optimized + expected_calls = n_total_atmos + n_total_ocean + assert len(optimization.accumulate_loss.call_args_list) == expected_calls + + +@pytest.mark.parametrize("optimize_last_step_only", [True, False]) +def test_train_on_batch_optimize_last_step_only_with_n_steps( + optimize_last_step_only: bool, +): + torch.manual_seed(0) + n_forward_times_ocean = 2 + n_forward_times_atmosphere = 4 + atmos_n_steps = 3 + ocean_n_steps = 1 + + train_stepper_config = CoupledTrainStepperConfig( + ocean=ComponentTrainingConfig( + loss=StepLossConfig(type="MSE"), + loss_contributions=LossContributionsConfig( + n_steps=ocean_n_steps, + optimize_last_step_only=optimize_last_step_only, + ), + ), + atmosphere=ComponentTrainingConfig( + loss=StepLossConfig(type="MSE"), + loss_contributions=LossContributionsConfig( + n_steps=atmos_n_steps, + optimize_last_step_only=optimize_last_step_only, + ), + ), + ) + train_stepper, coupled_data, _, _ = get_train_stepper_and_batch( + train_stepper_config=train_stepper_config, + ocean_in_names=["sst", "mask_0"], + ocean_out_names=["sst"], + atmosphere_in_names=["surface_temperature", "ocean_fraction"], + atmosphere_out_names=["surface_temperature"], + n_forward_times_ocean=n_forward_times_ocean, + n_forward_times_atmosphere=n_forward_times_atmosphere, + n_samples=3, + ) + optimization = Mock(wraps=NullOptimization()) + train_stepper.train_on_batch( + data=coupled_data.data, + optimization=optimization, + ) + if optimize_last_step_only: + # atmos: only step min(3,4)-1=2 is optimized + # ocean: only step min(1,2)-1=0 is optimized + assert len(optimization.accumulate_loss.call_args_list) == 2 + else: + # atmos: steps 0,1,2 are optimized (n_steps=3) + # ocean: step 0 is optimized (n_steps=1) + expected_calls = atmos_n_steps + ocean_n_steps + assert len(optimization.accumulate_loss.call_args_list) == expected_calls