Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions fme/coupled/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@ class StepLossABC(abc.ABC):
def effective_loss_scaling(self) -> TensorDict: ...

@abc.abstractmethod
def step_is_optimized(self, step: int) -> bool:
"""Returns True if the step is less than to the number of
steps contributing to the loss.
def step_is_optimized(self, step: int, n_total_steps: int | None = None) -> bool:
"""Returns True if the given step should contribute to the loss.

Args:
step: The step index to check.
n_total_steps: The total number of steps for this component. Required
when ``optimize_last_step_only`` is True.
"""
...

Expand All @@ -53,11 +57,16 @@ class LossContributionsConfig:
starting from the first.
weight: (optional) Weight applied to each step loss for the given realm.
Each step contributes equally to the total loss.
optimize_last_step_only: If True, only the last step within the training
horizon defined by ``n_steps`` is optimized (i.e. contributes to the
loss and has gradients enabled). The optimized step index is
``min(n_steps, n_total_steps) - 1``.

"""

n_steps: float = float("inf")
weight: float = 1.0
optimize_last_step_only: bool = False

def build(
self,
Expand All @@ -69,6 +78,7 @@ def build(
return LossContributions(
n_steps=self.n_steps,
weight=self.weight,
optimize_last_step_only=self.optimize_last_step_only,
loss_obj=loss_obj,
time_dim=time_dim,
)
Expand All @@ -90,7 +100,7 @@ def __init__(
def effective_loss_scaling(self) -> TensorDict:
return self._loss.effective_loss_scaling

def step_is_optimized(self, step: int) -> bool:
def step_is_optimized(self, step: int, n_total_steps: int | None = None) -> bool:
return False

def __call__(
Expand All @@ -104,30 +114,39 @@ def __init__(
self,
n_steps: float,
weight: float,
optimize_last_step_only: bool,
loss_obj: StepLoss,
time_dim: int,
):
self._loss = loss_obj
self._n_steps = n_steps
self._weight = weight
self._optimize_last_step_only = optimize_last_step_only
self._time_dim = time_dim

@property
def effective_loss_scaling(self) -> TensorDict:
return self._loss.effective_loss_scaling

def step_is_optimized(self, step: int) -> bool:
"""Returns True if the step is less than to the number of steps and
weight is != 0. The first step number is assumed to be 0.
def step_is_optimized(self, step: int, n_total_steps: int | None = None) -> bool:
"""Returns True if the step should contribute to the loss.

When ``optimize_last_step_only`` is False (default), returns True for
steps ``0`` through ``n_steps - 1``. When True, returns True only for
the step at index ``min(n_steps, n_total_steps) - 1``.
"""
return step < self._n_steps and self._weight != 0.0
if self._weight == 0.0:
return False
if self._optimize_last_step_only:
if n_total_steps is None:
raise ValueError(
"n_total_steps is required when optimize_last_step_only is True"
)
last_optimized_step = min(self._n_steps, n_total_steps) - 1
return step == last_optimized_step
return step < self._n_steps

def __call__(
self, prediction: StepPredictionABC, target_data: TensorMapping
) -> torch.Tensor:
if self.step_is_optimized(prediction.step):
return self._weight * self._loss(
prediction.data, target_data, prediction.step
)
return torch.tensor(0.0, device=get_device())
return self._weight * self._loss(prediction.data, target_data, prediction.step)
118 changes: 81 additions & 37 deletions fme/coupled/stepper.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import dataclasses
import datetime
import logging
Expand Down Expand Up @@ -768,13 +769,22 @@ def effective_loss_scaling(self) -> CoupledTensorMapping:
atmosphere=self._loss_objs["atmosphere"].effective_loss_scaling,
)

def step_is_optimized(
    self,
    realm: Literal["ocean", "atmosphere"],
    step: int,
    n_total_steps: int,
) -> bool:
    """Returns True if the given realm's step should contribute to the loss.

    Delegates to the per-realm loss contributions object in
    ``self._loss_objs``.

    Args:
        realm: Which component's loss configuration to consult.
        step: The step index to check.
        n_total_steps: Total number of steps for the realm's component.
    """
    return self._loss_objs[realm].step_is_optimized(step, n_total_steps)

def __call__(
self,
prediction: ComponentStepPrediction,
target_data: TensorMapping,
n_total_steps: int | None = None,
) -> torch.Tensor | None:
loss_obj = self._loss_objs[prediction.realm]
if loss_obj.step_is_optimized(prediction.step):
if loss_obj.step_is_optimized(prediction.step, n_total_steps):
return loss_obj(prediction, target_data)
return None

Expand Down Expand Up @@ -1589,55 +1599,89 @@ def train_on_batch(

metrics = ComponentStepMetrics()
optimization.set_mode(self.modules)
n_outer_steps = data.ocean_data.n_timesteps - self.n_ic_timesteps
n_total_atmos_steps = n_outer_steps * self.n_inner_steps
with optimization.autocast():
output_generator = self._stepper.get_prediction_generator(
input_data,
data_ensemble,
optimization,
)
output_list = []
for gen_step in output_generator:
if gen_step.realm == "ocean":
# compute ocean step metrics
output_list: list[ComponentStepPrediction] = []
output_iterator = iter(output_generator)

for i_outer in range(n_outer_steps):
for i_inner in range(self.n_inner_steps):
global_atmos_step = i_outer * self.n_inner_steps + i_inner
optimize = self._loss.step_is_optimized(
"atmosphere",
global_atmos_step,
n_total_atmos_steps,
)
grad_context = (
contextlib.nullcontext() if optimize else torch.no_grad()
)
with grad_context:
gen_step = next(output_iterator)
target_step = {
k: v.select(self.atmosphere.TIME_DIM, gen_step.step)
for k, v in atmos_forward_data.data.items()
}
if n_ensemble > 1:
gen_step_data_unfolded = unfold_ensemble_dim(
gen_step.data, n_ensemble
)
gen_step_for_loss = ComponentStepPrediction(
realm=gen_step.realm,
data=gen_step_data_unfolded,
step=gen_step.step,
)
else:
gen_step_for_loss = gen_step
target_step_ensemble = add_ensemble_dim(target_step)
step_loss = self._loss(
gen_step_for_loss,
target_step_ensemble,
n_total_steps=n_total_atmos_steps,
)
if step_loss is not None:
label = f"loss/{gen_step.realm}_step_{gen_step.step}"
metrics.add_metric(
label, step_loss.detach(), gen_step.realm
)
if step_loss is not None:
optimization.accumulate_loss(step_loss)
if n_ensemble > 1:
gen_step_to_append = gen_step_for_loss.detach(optimization)
else:
gen_step_to_append = gen_step.detach(optimization)
output_list.append(gen_step_to_append)

optimize = self._loss.step_is_optimized(
"ocean",
i_outer,
n_outer_steps,
)
grad_context = contextlib.nullcontext() if optimize else torch.no_grad()
with grad_context:
gen_step = next(output_iterator)
target_step = {
k: v.select(self.ocean.TIME_DIM, gen_step.step)
for k, v in ocean_forward_data.data.items()
}
# Ocean predictions don't need ensemble handling
gen_step_for_loss = gen_step
else:
assert gen_step.realm == "atmosphere"
target_step = {
k: v.select(self.atmosphere.TIME_DIM, gen_step.step)
for k, v in atmos_forward_data.data.items()
}
# Unfold ensemble dimension for atmosphere loss computation
if n_ensemble > 1:
gen_step_data_unfolded = unfold_ensemble_dim(
gen_step.data, n_ensemble
)
gen_step_for_loss = ComponentStepPrediction(
realm=gen_step.realm,
data=gen_step_data_unfolded,
step=gen_step.step,
)
else:
gen_step_for_loss = gen_step
# Add ensemble dim to target (single member)
target_step_ensemble = add_ensemble_dim(target_step)
step_loss = self._loss(
gen_step_for_loss,
target_step_ensemble,
)
target_step_ensemble = add_ensemble_dim(target_step)
step_loss = self._loss(
gen_step_for_loss,
target_step_ensemble,
n_total_steps=n_outer_steps,
)
if step_loss is not None:
label = f"loss/{gen_step.realm}_step_{gen_step.step}"
metrics.add_metric(label, step_loss.detach(), gen_step.realm)
if step_loss is not None:
label = f"loss/{gen_step.realm}_step_{gen_step.step}"
metrics.add_metric(label, step_loss.detach(), gen_step.realm)
optimization.accumulate_loss(step_loss)
# For atmosphere with ensemble, append the unfolded step
if gen_step.realm == "atmosphere" and n_ensemble > 1:
gen_step_to_append = gen_step_for_loss.detach(optimization)
else:
gen_step_to_append = gen_step.detach(optimization)
gen_step_to_append = gen_step.detach(optimization)
output_list.append(gen_step_to_append)

loss = optimization.get_accumulated_loss().detach()
Expand Down
107 changes: 106 additions & 1 deletion fme/coupled/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
def effective_loss_scaling(self):
raise NotImplementedError()

def step_is_optimized(self, step: int) -> bool:
def step_is_optimized(self, step: int, n_total_steps: int | None = None) -> bool:
return step < 2

def __call__(
Expand Down Expand Up @@ -154,6 +154,111 @@ def mae_loss(gen, target, step: int):
assert_tensor_dicts_close(metrics, expected_metrics)


def test_loss_contributions_optimize_last_step_only(steps_thru_atmos_7):
    """With optimize_last_step_only, only step min(n_steps, n_total) - 1
    produces a loss; all other steps yield None."""

    # Toy loss whose value depends on the step index, so each step's
    # expected contribution is distinguishable from the others.
    def mae_loss(gen, target, step: int):
        loss = torch.tensor(0.0)
        for key in gen:
            loss += (gen[key] - target[key]).abs().mean() / (step + 1)
        return loss

    n_total_atmos = 8
    n_total_ocean = 4
    # Atmosphere: n_steps below the horizon, plus a non-unit weight so the
    # weight scaling of the single optimized step is also checked.
    atmos_loss_config = LossContributionsConfig(
        n_steps=6,
        weight=1 / 3,
        optimize_last_step_only=True,
    )
    mock_step_loss = Mock(spec=StepLoss, side_effect=mae_loss)
    atmosphere_loss = atmos_loss_config.build(
        loss_obj=mock_step_loss,
        time_dim=1,
    )
    # Ocean: default weight of 1.0.
    ocean_loss_config = LossContributionsConfig(
        n_steps=3,
        optimize_last_step_only=True,
    )
    ocean_loss = ocean_loss_config.build(
        loss_obj=Mock(spec=StepLoss, side_effect=mae_loss),
        time_dim=1,
    )
    loss_obj = CoupledStepperTrainLoss(
        ocean_loss=ocean_loss,
        atmosphere_loss=atmosphere_loss,
    )
    metrics = {}
    expected_metrics: dict[str, torch.Tensor | None] = {}
    for prediction, target_data in steps_thru_atmos_7:
        label = f"{prediction.realm}_{prediction.step}"
        if prediction.realm == "atmosphere":
            n_total = n_total_atmos
        else:
            n_total = n_total_ocean
        metrics[label] = loss_obj(prediction, target_data, n_total_steps=n_total)
        if prediction.realm == "atmosphere":
            # n_steps=6, n_total=8 → last optimized step = min(6,8)-1 = 5
            if prediction.step == 5:
                expected_metrics[label] = (
                    mae_loss(prediction.data, target_data, step=prediction.step) / 3
                )
            else:
                expected_metrics[label] = None
        elif prediction.realm == "ocean":
            # n_steps=3, n_total=4 → last optimized step = min(3,4)-1 = 2
            if prediction.step == 2:
                expected_metrics[label] = mae_loss(
                    prediction.data, target_data, step=prediction.step
                )
            else:
                expected_metrics[label] = None
    assert_tensor_dicts_close(metrics, expected_metrics)


@pytest.mark.parametrize(
    "n_steps, n_total_steps, expected_optimized_step",
    [
        (6, 8, 5),
        (10, 8, 7),
        (float("inf"), 8, 7),
        (1, 1, 0),
        (3, 3, 2),
    ],
)
def test_step_is_optimized_last_step_only(
    n_steps, n_total_steps, expected_optimized_step
):
    """Exactly one step — min(n_steps, n_total_steps) - 1 — is optimized."""
    cfg = LossContributionsConfig(n_steps=n_steps, optimize_last_step_only=True)
    contributions = cfg.build(loss_obj=Mock(spec=StepLoss), time_dim=1)
    for step in range(n_total_steps):
        is_optimized = contributions.step_is_optimized(step, n_total_steps)
        if step != expected_optimized_step:
            assert not is_optimized, f"step {step} should not be optimized"
        else:
            assert is_optimized, f"step {step} should be optimized"


def test_step_is_optimized_last_step_only_requires_n_total_steps():
    """Omitting n_total_steps raises when optimize_last_step_only is set."""
    contributions = LossContributionsConfig(optimize_last_step_only=True).build(
        loss_obj=Mock(spec=StepLoss), time_dim=1
    )
    with pytest.raises(ValueError, match="n_total_steps is required"):
        contributions.step_is_optimized(0)


def test_step_is_optimized_last_step_only_weight_zero():
    """A zero weight disables optimization even in last-step-only mode."""
    zero_weight_config = LossContributionsConfig(
        optimize_last_step_only=True, weight=0.0
    )
    contributions = zero_weight_config.build(loss_obj=Mock(spec=StepLoss), time_dim=1)
    # weight=0 → NullLossContributions, always returns False
    assert not contributions.step_is_optimized(0, n_total_steps=5)


@pytest.mark.parametrize("ocean_config_kwargs", [{"n_steps": 0}, {"weight": 0.0}])
def test_null_loss_contributions(steps_thru_atmos_7, ocean_config_kwargs):
# test LossContributionsConfig with n_steps = 0
Expand Down
Loading