diff --git a/fme/ace/__init__.py b/fme/ace/__init__.py index 4e7a5038f..adeedd782 100644 --- a/fme/ace/__init__.py +++ b/fme/ace/__init__.py @@ -72,6 +72,7 @@ from fme.core.dataset.time import RepeatedInterval, TimeSlice from fme.core.dataset.utils import FillNaNsConfig from fme.core.dataset.xarray import OverwriteConfig, XarrayDataConfig +from fme.core.generics.lr_tuning import LRTuningConfig from fme.core.gridded_ops import GriddedOperations from fme.core.loss import StepLossConfig from fme.core.masking import StaticMaskingConfig diff --git a/fme/ace/train/train_config.py b/fme/ace/train/train_config.py index c236ffc95..ff9e50f63 100644 --- a/fme/ace/train/train_config.py +++ b/fme/ace/train/train_config.py @@ -33,6 +33,7 @@ from fme.core.dataset_info import DatasetInfo from fme.core.distributed import Distributed from fme.core.ema import EMAConfig, EMATracker +from fme.core.generics.lr_tuning import LRTuningConfig from fme.core.generics.trainer import EndOfBatchCallback, EndOfEpochCallback from fme.core.logging_utils import LoggingConfig from fme.core.optimization import Optimization, OptimizationConfig @@ -267,6 +268,7 @@ class TrainConfig: ) evaluate_before_training: bool = False save_best_inference_epoch_checkpoints: bool = False + lr_tuning: LRTuningConfig | None = None resume_results: ResumeResultsConfig | None = None def __post_init__(self): diff --git a/fme/core/ema.py b/fme/core/ema.py index e3b7e09ff..f965aeede 100644 --- a/fme/core/ema.py +++ b/fme/core/ema.py @@ -199,10 +199,10 @@ def get_state(self): The state of the EMA tracker. 
""" return { - "decay": self.decay, - "num_updates": self.num_updates, + "decay": self.decay.clone(), + "num_updates": self.num_updates.clone(), "faster_decay_at_start": self._faster_decay_at_start, - "module_name_to_ema_name": self._module_name_to_ema_name, + "module_name_to_ema_name": dict(self._module_name_to_ema_name), "ema_params": { name: param.clone().detach() for name, param in self._ema_params.items() },
diff --git a/fme/core/generics/lr_tuning.py b/fme/core/generics/lr_tuning.py
new file mode 100644
index 000000000..5f727155f
--- /dev/null
+++ b/fme/core/generics/lr_tuning.py
@@ -0,0 +1,176 @@
+import copy
+import dataclasses
+import logging
+from collections.abc import Callable
+
+import torch
+
+from fme.core.ema import EMATracker
+from fme.core.generics.aggregator import AggregatorABC
+from fme.core.generics.data import GriddedDataABC
+from fme.core.generics.optimization import OptimizationABC
+from fme.core.generics.train_stepper import TrainStepperABC
+from fme.core.generics.validation import run_validation
+
+
+@dataclasses.dataclass
+class LRTuningConfig:
+    """
+    Configuration for periodic learning rate tuning trials.
+
+    At the start of every ``epoch_frequency`` epochs, the trainer forks the
+    current model into a baseline and a candidate copy. Both are trained on
+    the first ``num_batches`` batches of the epoch; the candidate uses a
+    learning rate of ``current_lr * lr_factor``. Both are then validated. If
+    both improve the validation loss and the candidate's improvement exceeds
+    the baseline's by more than a fraction ``improvement_threshold``, the
+    trainer adopts the candidate's learning rate.
+
+    Parameters:
+        epoch_frequency: Run a trial every N epochs. Must be at least 1.
+        lr_factor: Multiply the current LR by this to get the candidate LR.
+            Must be positive.
+        num_batches: Number of training batches for each fork in the trial.
+            Must be at least 1.
+        improvement_threshold: The candidate's improvement in validation loss
+            must exceed the baseline's by more than this fraction (e.g. 0.1
+            means the candidate must improve more than 10% more).
+    """
+
+    epoch_frequency: int
+    lr_factor: float
+    num_batches: int
+    improvement_threshold: float
+
+    def __post_init__(self):
+        # Fail at configuration time rather than mid-training: a zero
+        # epoch_frequency would otherwise raise ZeroDivisionError in the
+        # trainer's modulo check, and a non-positive lr_factor would yield a
+        # non-positive candidate learning rate.
+        if self.epoch_frequency < 1:
+            raise ValueError(
+                f"epoch_frequency must be at least 1, got {self.epoch_frequency}"
+            )
+        if self.num_batches < 1:
+            raise ValueError(f"num_batches must be at least 1, got {self.num_batches}")
+        if self.lr_factor <= 0:
+            raise ValueError(f"lr_factor must be positive, got {self.lr_factor}")
+
+
+def run_lr_tuning_trial(
+    train_data: GriddedDataABC,
+    valid_data: GriddedDataABC,
+    optimization: OptimizationABC,
+    copy_stepper: Callable[[], TrainStepperABC],
+    build_optimization: Callable[[torch.nn.ModuleList], OptimizationABC],
+    copy_ema: Callable[[torch.nn.ModuleList], EMATracker],
+    config: LRTuningConfig,
+    current_lr: float,
+    pre_trial_val_loss: float,
+    get_validation_aggregator: Callable[[], AggregatorABC],
+    validate_using_ema: bool,
+) -> float | None:
+    """
+    Run an isolated LR tuning trial comparing the current LR against a candidate.
+
+    Creates two stepper forks, trains both, validates both, and compares
+    validation loss improvements. Does not mutate the original stepper or
+    optimization. Does not log to wandb.
+
+    Args:
+        train_data: Training data; ``subset_loader`` is used for the first N
+            batches. The caller must have already called ``set_epoch``.
+        valid_data: Validation data.
+        optimization: The current optimization (used to copy momentum state
+            into the forks).
+        copy_stepper: Factory that returns a new stepper initialized from the
+            current stepper's state. Called twice (baseline and candidate).
+            The caller is responsible for ensuring proper deep copy semantics
+            (e.g. using get_state/load_state rather than copy.deepcopy).
+        build_optimization: Factory to build a fresh optimization for a
+            given ModuleList.
+        copy_ema: Factory that returns a new EMA tracker initialized from the
+            current EMA state but tracking the given modules. Called twice.
+        config: The LR tuning configuration.
+        current_lr: The current learning rate.
+        pre_trial_val_loss: Validation loss from the end of the previous epoch.
+        get_validation_aggregator: Factory for validation aggregators.
+        validate_using_ema: Whether to use EMA parameters during validation.
+
+    Returns:
+        The candidate learning rate if the candidate wins, otherwise None.
+    """
+    candidate_lr = current_lr * config.lr_factor
+    optimization_state = optimization.get_state()
+
+    baseline_stepper = copy_stepper()
+    candidate_stepper = copy_stepper()
+
+    # Give each fork its own deep copy of the optimization state. load_state
+    # is not guaranteed to copy the tensors it receives, so sharing one state
+    # object could leave both forks (or the original optimization) updating
+    # the same optimizer buffers in place.
+    baseline_opt = build_optimization(baseline_stepper.modules)
+    baseline_opt.load_state(copy.deepcopy(optimization_state))
+    baseline_opt.set_learning_rate(current_lr)
+
+    candidate_opt = build_optimization(candidate_stepper.modules)
+    candidate_opt.load_state(copy.deepcopy(optimization_state))
+    candidate_opt.set_learning_rate(candidate_lr)
+
+    baseline_ema = copy_ema(baseline_stepper.modules)
+    candidate_ema = copy_ema(candidate_stepper.modules)
+
+    # Train both forks
+    baseline_stepper.set_train()
+    candidate_stepper.set_train()
+    for batch in train_data.subset_loader(stop_batch=config.num_batches):
+        baseline_stepper.train_on_batch(batch, baseline_opt)
+        baseline_ema(baseline_stepper.modules)
+
+        candidate_stepper.train_on_batch(batch, candidate_opt)
+        candidate_ema(candidate_stepper.modules)
+
+    # Validate both forks
+    baseline_val_logs = run_validation(
+        stepper=baseline_stepper,
+        valid_data=valid_data,
+        aggregator=get_validation_aggregator(),
+        ema=baseline_ema,
+        validate_using_ema=validate_using_ema,
+    )
+    candidate_val_logs = run_validation(
+        stepper=candidate_stepper,
+        valid_data=valid_data,
+        aggregator=get_validation_aggregator(),
+        ema=candidate_ema,
+        validate_using_ema=validate_using_ema,
+    )
+
+    baseline_val_loss = baseline_val_logs["val/mean/loss"]
+    candidate_val_loss = candidate_val_logs["val/mean/loss"]
+
+    baseline_improvement = pre_trial_val_loss - baseline_val_loss
+    candidate_improvement = pre_trial_val_loss - candidate_val_loss
+
+    logging.info(
+        f"LR tuning trial: baseline LR={current_lr}, candidate LR={candidate_lr}, "
+        f"pre-trial val loss={pre_trial_val_loss:.6f}, "
+        f"baseline val loss={baseline_val_loss:.6f} "
+        f"(improvement={baseline_improvement:.6f}), "
+        f"candidate val loss={candidate_val_loss:.6f} "
+        f"(improvement={candidate_improvement:.6f})"
+    )
+
+    if baseline_improvement > 0 and candidate_improvement > 0:
+        threshold = baseline_improvement * (1 + config.improvement_threshold)
+        if candidate_improvement > threshold:
+            logging.info(
+                f"LR tuning trial: candidate wins "
+                f"(improvement {candidate_improvement:.6f} > "
+                f"threshold {threshold:.6f})"
+            )
+            return candidate_lr
+
+    logging.info("LR tuning trial: baseline wins, keeping current LR")
+    return None
diff --git a/fme/core/generics/optimization.py b/fme/core/generics/optimization.py index b74e6ea89..8702ef573 100644 --- a/fme/core/generics/optimization.py +++ b/fme/core/generics/optimization.py @@ -94,6 +94,13 @@ def get_state(self): """ ... + @abc.abstractmethod + def set_learning_rate(self, lr: float): + """ + Set the learning rate for all parameter groups. + """ + ... + @abc.abstractmethod def load_state(self, state): """ diff --git a/fme/core/generics/test_lr_tuning.py b/fme/core/generics/test_lr_tuning.py new file mode 100644 index 000000000..a5800b0d7 --- /dev/null +++ b/fme/core/generics/test_lr_tuning.py @@ -0,0 +1,551 @@ +"""Tests for run_lr_tuning_trial and LRTuningConfig.""" + +import copy +import itertools +import unittest.mock +from typing import Any + +import torch + +from fme.core.device import get_device +from fme.core.ema import EMATracker +from fme.core.generics.aggregator import AggregatorABC +from fme.core.generics.data import DataLoader, GriddedDataABC +from fme.core.generics.lr_tuning import LRTuningConfig, run_lr_tuning_trial +from fme.core.generics.optimization import OptimizationABC +from fme.core.generics.train_stepper import TrainOutputABC, TrainStepperABC +from fme.core.optimization import Optimization +from fme.core.scheduler import SchedulerConfig +from fme.core.training_history import TrainingJob +from fme.core.typing_ import TensorDict + + +class _TrainOutput(TrainOutputABC): + def get_metrics(self) -> TensorDict: + return {} + + +class
_BatchData: + def __init__(self, i: int): + self.i = i + + +class _TrainData(GriddedDataABC["_BatchData"]): + """Minimal GriddedDataABC for testing.""" + + def __init__(self, n_batches: int): + self._n_batches = n_batches + + @property + def loader(self) -> DataLoader["_BatchData"]: + return [_BatchData(i) for i in range(self._n_batches)] + + @property + def n_samples(self) -> int: + return self._n_batches + + @property + def n_batches(self) -> int: + return self._n_batches + + @property + def batch_size(self) -> int: + return 1 + + def set_epoch(self, epoch: int): + pass + + def alternate_shuffle(self): + pass + + def subset_loader( + self, start_batch: int | None = None, stop_batch: int | None = None + ) -> DataLoader["_BatchData"]: + return [_BatchData(i) for i in range(self._n_batches)][ + slice(start_batch, stop_batch) + ] + + def log_info(self, name: str): + pass + + +class _Stepper(TrainStepperABC["None", "_BatchData", "None", "None", "_TrainOutput"]): + def __init__(self): + self._modules = torch.nn.ModuleList([torch.nn.Linear(1, 1, bias=False)]).to( + get_device() + ) + + def train_on_batch( + self, + data: "_BatchData", + optimization: OptimizationABC, + compute_derived_variables: bool = False, + ) -> _TrainOutput: + # Produce a loss with a grad_fn so Optimization.step_weights can backward + x = torch.ones(1, 1, device=get_device()) + loss = self._modules[0](x).sum() + optimization.accumulate_loss(loss) + optimization.step_weights() + return _TrainOutput() + + def predict_paired( + self, initial_condition, forcing, compute_derived_variables=False + ): + return None, None + + @property + def modules(self) -> torch.nn.ModuleList: + return self._modules + + def get_state(self) -> dict[str, Any]: + return {"modules": self._modules.state_dict()} + + def load_state(self, state: dict[str, Any]) -> None: + self._modules.load_state_dict(state["modules"]) + + def set_eval(self) -> None: + pass + + def set_train(self) -> None: + pass + + def 
update_training_history(self, training_job: TrainingJob) -> None: + pass + + +class _ValidationAggregator(AggregatorABC["_TrainOutput"]): + def __init__(self, loss: float): + self._loss = loss + + def record_batch(self, batch: "_TrainOutput") -> None: + pass + + def get_logs(self, label: str) -> dict[str, float]: + return {f"{label}/mean/loss": self._loss} + + def flush_diagnostics(self, subdir: str | None) -> None: + pass + + +def _build_optimization(modules: torch.nn.ModuleList) -> Optimization: + return Optimization( + parameters=itertools.chain(*[m.parameters() for m in modules]), + optimizer_type="Adam", + lr=0.01, + max_epochs=10, + scheduler=SchedulerConfig(), + enable_automatic_mixed_precision=False, + kwargs={}, + ) + + +def _make_copy_ema(ema: EMATracker): + """Return a callable that creates a copy of the EMA via state APIs, + matching the real Trainer._copy_ema pattern.""" + + def copy_ema(modules: torch.nn.ModuleList) -> EMATracker: + return EMATracker.from_state(ema.get_state(), modules) + + return copy_ema + + +def _make_copy_stepper(stepper: _Stepper): + """Return a callable that creates a copy of the stepper via state APIs, + matching the real Trainer._copy_stepper pattern (deepcopy + load_state).""" + + def copy_stepper() -> _Stepper: + new = copy.deepcopy(stepper) + new.load_state(copy.deepcopy(stepper.get_state())) + return new + + return copy_stepper + + +def _make_aggregator_factory(*losses: float): + """Return a get_validation_aggregator callable that yields the given losses.""" + it = iter(losses) + + def factory(): + return _ValidationAggregator(next(it)) + + return factory + + +def test_candidate_wins(): + """Candidate wins when its improvement exceeds the threshold.""" + stepper = _Stepper() + train_data = _TrainData(n_batches=5) + valid_data = _TrainData(n_batches=3) + optimization = _build_optimization(stepper.modules) + config = LRTuningConfig( + epoch_frequency=1, + lr_factor=0.5, + num_batches=3, + improvement_threshold=0.1, + ) + + # 
pre_trial=1.0, baseline->0.8 (improvement=0.2), candidate->0.5 (improvement=0.5) + # candidate_improvement (0.5) > baseline_improvement * 1.1 (0.22) → candidate wins + result = run_lr_tuning_trial( + train_data=train_data, + valid_data=valid_data, + optimization=optimization, + copy_stepper=_make_copy_stepper(stepper), + build_optimization=_build_optimization, + copy_ema=_make_copy_ema(EMATracker(stepper.modules, decay=0.9999)), + config=config, + current_lr=0.01, + pre_trial_val_loss=1.0, + get_validation_aggregator=_make_aggregator_factory(0.8, 0.5), + validate_using_ema=False, + ) + + assert result == 0.01 * 0.5 + + +def test_candidate_below_threshold(): + """Candidate improves but not enough to exceed the threshold.""" + stepper = _Stepper() + train_data = _TrainData(n_batches=5) + valid_data = _TrainData(n_batches=3) + optimization = _build_optimization(stepper.modules) + config = LRTuningConfig( + epoch_frequency=1, + lr_factor=0.5, + num_batches=3, + improvement_threshold=0.5, + ) + + # pre_trial=1.0, baseline->0.8 (improvement=0.2), candidate->0.75 (improvement=0.25) + # candidate_improvement (0.25) < baseline_improvement * 1.5 (0.3) → baseline wins + result = run_lr_tuning_trial( + train_data=train_data, + valid_data=valid_data, + optimization=optimization, + copy_stepper=_make_copy_stepper(stepper), + build_optimization=_build_optimization, + copy_ema=_make_copy_ema(EMATracker(stepper.modules, decay=0.9999)), + config=config, + current_lr=0.01, + pre_trial_val_loss=1.0, + get_validation_aggregator=_make_aggregator_factory(0.8, 0.75), + validate_using_ema=False, + ) + + assert result is None + + +def test_candidate_worsens(): + """Candidate worsens validation loss → baseline wins.""" + stepper = _Stepper() + train_data = _TrainData(n_batches=5) + valid_data = _TrainData(n_batches=3) + optimization = _build_optimization(stepper.modules) + config = LRTuningConfig( + epoch_frequency=1, + lr_factor=0.5, + num_batches=3, + improvement_threshold=0.1, + ) + + # 
pre_trial=1.0, baseline->0.9 (improvement=0.1), candidate->1.1 (improvement=-0.1) + result = run_lr_tuning_trial( + train_data=train_data, + valid_data=valid_data, + optimization=optimization, + copy_stepper=_make_copy_stepper(stepper), + build_optimization=_build_optimization, + copy_ema=_make_copy_ema(EMATracker(stepper.modules, decay=0.9999)), + config=config, + current_lr=0.01, + pre_trial_val_loss=1.0, + get_validation_aggregator=_make_aggregator_factory(0.9, 1.1), + validate_using_ema=False, + ) + + assert result is None + + +def test_both_worsen(): + """Both worsen → baseline wins.""" + stepper = _Stepper() + train_data = _TrainData(n_batches=5) + valid_data = _TrainData(n_batches=3) + optimization = _build_optimization(stepper.modules) + config = LRTuningConfig( + epoch_frequency=1, + lr_factor=0.5, + num_batches=3, + improvement_threshold=0.1, + ) + + # pre_trial=1.0, baseline->1.2, candidate->1.3 + result = run_lr_tuning_trial( + train_data=train_data, + valid_data=valid_data, + optimization=optimization, + copy_stepper=_make_copy_stepper(stepper), + build_optimization=_build_optimization, + copy_ema=_make_copy_ema(EMATracker(stepper.modules, decay=0.9999)), + config=config, + current_lr=0.01, + pre_trial_val_loss=1.0, + get_validation_aggregator=_make_aggregator_factory(1.2, 1.3), + validate_using_ema=False, + ) + + assert result is None + + +def test_baseline_worsens_candidate_improves(): + """Baseline worsens but candidate improves → still returns None + (requirement: both must improve).""" + stepper = _Stepper() + train_data = _TrainData(n_batches=5) + valid_data = _TrainData(n_batches=3) + optimization = _build_optimization(stepper.modules) + config = LRTuningConfig( + epoch_frequency=1, + lr_factor=0.5, + num_batches=3, + improvement_threshold=0.1, + ) + + # pre_trial=1.0, baseline->1.1 (worsens), candidate->0.5 (improves) + result = run_lr_tuning_trial( + train_data=train_data, + valid_data=valid_data, + optimization=optimization, + 
copy_stepper=_make_copy_stepper(stepper), + build_optimization=_build_optimization, + copy_ema=_make_copy_ema(EMATracker(stepper.modules, decay=0.9999)), + config=config, + current_lr=0.01, + pre_trial_val_loss=1.0, + get_validation_aggregator=_make_aggregator_factory(1.1, 0.5), + validate_using_ema=False, + ) + + assert result is None + + +def test_does_not_mutate_original_stepper(): + """The original stepper's parameters must not be modified by the trial.""" + stepper = _Stepper() + stepper.modules[0].weight.data.fill_(42.0) + original_weight = stepper.modules[0].weight.data.clone() + + train_data = _TrainData(n_batches=5) + valid_data = _TrainData(n_batches=3) + optimization = _build_optimization(stepper.modules) + config = LRTuningConfig( + epoch_frequency=1, + lr_factor=0.5, + num_batches=3, + improvement_threshold=0.1, + ) + + run_lr_tuning_trial( + train_data=train_data, + valid_data=valid_data, + optimization=optimization, + copy_stepper=_make_copy_stepper(stepper), + build_optimization=_build_optimization, + copy_ema=_make_copy_ema(EMATracker(stepper.modules, decay=0.9999)), + config=config, + current_lr=0.01, + pre_trial_val_loss=1.0, + get_validation_aggregator=_make_aggregator_factory(0.8, 0.5), + validate_using_ema=False, + ) + + assert torch.allclose(stepper.modules[0].weight.data, original_weight) + + +def test_uses_subset_loader_with_num_batches(): + """The trial should train on exactly config.num_batches batches.""" + stepper = _Stepper() + train_data = _TrainData(n_batches=10) + train_data.subset_loader = unittest.mock.MagicMock( # type: ignore + wraps=train_data.subset_loader + ) + valid_data = _TrainData(n_batches=3) + optimization = _build_optimization(stepper.modules) + config = LRTuningConfig( + epoch_frequency=1, + lr_factor=0.5, + num_batches=4, + improvement_threshold=0.1, + ) + + run_lr_tuning_trial( + train_data=train_data, + valid_data=valid_data, + optimization=optimization, + copy_stepper=_make_copy_stepper(stepper), + 
build_optimization=_build_optimization, + copy_ema=_make_copy_ema(EMATracker(stepper.modules, decay=0.9999)), + config=config, + current_lr=0.01, + pre_trial_val_loss=1.0, + get_validation_aggregator=_make_aggregator_factory(0.9, 0.8), + validate_using_ema=False, + ) + + train_data.subset_loader.assert_called_once_with(stop_batch=4) + + +def test_with_ema_validation(): + """Trial should pass validate_using_ema through to run_validation.""" + stepper = _Stepper() + train_data = _TrainData(n_batches=5) + valid_data = _TrainData(n_batches=3) + optimization = _build_optimization(stepper.modules) + config = LRTuningConfig( + epoch_frequency=1, + lr_factor=0.5, + num_batches=3, + improvement_threshold=0.1, + ) + + # This should not raise even with validate_using_ema=True + result = run_lr_tuning_trial( + train_data=train_data, + valid_data=valid_data, + optimization=optimization, + copy_stepper=_make_copy_stepper(stepper), + build_optimization=_build_optimization, + copy_ema=_make_copy_ema(EMATracker(stepper.modules, decay=0.9999)), + config=config, + current_lr=0.01, + pre_trial_val_loss=1.0, + get_validation_aggregator=_make_aggregator_factory(0.8, 0.5), + validate_using_ema=True, + ) + + assert result == 0.01 * 0.5 + + +def test_trial_does_not_mutate_original_ema_num_updates(): + """The original EMA's num_updates must not be modified by the trial.""" + stepper = _Stepper() + ema = EMATracker(stepper.modules, decay=0.9999) + # Simulate some prior training updates + for _ in range(5): + ema(stepper.modules) + original_num_updates = ema.num_updates.clone() + + train_data = _TrainData(n_batches=5) + valid_data = _TrainData(n_batches=3) + optimization = _build_optimization(stepper.modules) + config = LRTuningConfig( + epoch_frequency=1, + lr_factor=0.5, + num_batches=3, + improvement_threshold=0.1, + ) + + run_lr_tuning_trial( + train_data=train_data, + valid_data=valid_data, + optimization=optimization, + copy_stepper=_make_copy_stepper(stepper), + 
build_optimization=_build_optimization, + copy_ema=_make_copy_ema(ema), + config=config, + current_lr=0.01, + pre_trial_val_loss=1.0, + get_validation_aggregator=_make_aggregator_factory(0.8, 0.5), + validate_using_ema=False, + ) + + assert torch.equal(ema.num_updates, original_num_updates) + + +def test_trial_does_not_mutate_original_ema_params(): + """The original EMA's tracked parameters must not be modified by the trial.""" + stepper = _Stepper() + ema = EMATracker(stepper.modules, decay=0.9999) + for _ in range(5): + ema(stepper.modules) + original_ema_params = { + name: param.clone() for name, param in ema._ema_params.items() + } + + train_data = _TrainData(n_batches=5) + valid_data = _TrainData(n_batches=3) + optimization = _build_optimization(stepper.modules) + config = LRTuningConfig( + epoch_frequency=1, + lr_factor=0.5, + num_batches=3, + improvement_threshold=0.1, + ) + + run_lr_tuning_trial( + train_data=train_data, + valid_data=valid_data, + optimization=optimization, + copy_stepper=_make_copy_stepper(stepper), + build_optimization=_build_optimization, + copy_ema=_make_copy_ema(ema), + config=config, + current_lr=0.01, + pre_trial_val_loss=1.0, + get_validation_aggregator=_make_aggregator_factory(0.8, 0.5), + validate_using_ema=False, + ) + + for name in original_ema_params: + assert torch.equal(ema._ema_params[name], original_ema_params[name]) + + +def test_trial_does_not_mutate_original_optimizer_state(): + """The original optimizer's momentum buffers must not be modified by the trial.""" + stepper = _Stepper() + ema = EMATracker(stepper.modules, decay=0.9999) + optimization = _build_optimization(stepper.modules) + + # Train a few batches to populate optimizer momentum buffers + train_data = _TrainData(n_batches=5) + for batch in train_data.loader: + stepper.train_on_batch(batch, optimization) + + original_state = copy.deepcopy(optimization.get_state()) + + valid_data = _TrainData(n_batches=3) + config = LRTuningConfig( + epoch_frequency=1, + 
lr_factor=0.5, + num_batches=3, + improvement_threshold=0.1, + ) + + run_lr_tuning_trial( + train_data=train_data, + valid_data=valid_data, + optimization=optimization, + copy_stepper=_make_copy_stepper(stepper), + build_optimization=_build_optimization, + copy_ema=_make_copy_ema(ema), + config=config, + current_lr=0.01, + pre_trial_val_loss=1.0, + get_validation_aggregator=_make_aggregator_factory(0.8, 0.5), + validate_using_ema=False, + ) + + current_state = optimization.get_state() + for key in original_state["optimizer_state_dict"]["state"]: + for buf_name, buf in original_state["optimizer_state_dict"]["state"][ + key + ].items(): + if isinstance(buf, torch.Tensor): + assert torch.equal( + buf, + current_state["optimizer_state_dict"]["state"][key][buf_name], + ), f"Optimizer state[{key}][{buf_name}] was mutated by the trial" + else: + assert ( + buf == current_state["optimizer_state_dict"]["state"][key][buf_name] + ), f"Optimizer state[{key}][{buf_name}] was mutated by the trial" diff --git a/fme/core/generics/test_trainer.py b/fme/core/generics/test_trainer.py index c91fc33d0..5be5343ea 100644 --- a/fme/core/generics/test_trainer.py +++ b/fme/core/generics/test_trainer.py @@ -19,6 +19,7 @@ InferenceLogs, ) from fme.core.generics.data import DataLoader, GriddedDataABC, InferenceDataABC +from fme.core.generics.lr_tuning import LRTuningConfig from fme.core.generics.optimization import OptimizationABC from fme.core.generics.trainer import ( AggregatorBuilderABC, @@ -232,6 +233,7 @@ class Config: segment_epochs: int | None = None evaluate_before_training: bool = False save_best_inference_epoch_checkpoints: bool = False + lr_tuning: LRTuningConfig | None = None def __post_init__(self): start_epoch = 0 if self.evaluate_before_training else 1 @@ -338,6 +340,7 @@ def get_trainer( scheduler_config: SchedulerConfig | None = None, n_validation_batches: int = 5, save_checkpoint: bool = True, + lr_tuning: LRTuningConfig | None = None, ) -> tuple[TrainConfigProtocol, 
Trainer]: if checkpoint_dir is None: checkpoint_dir = os.path.join(tmp_path, "checkpoints") @@ -413,6 +416,7 @@ def build_ema(modules: torch.nn.ModuleList) -> EMATracker: evaluate_before_training=evaluate_before_training, save_best_inference_epoch_checkpoints=save_best_inference_epoch_checkpoints, save_checkpoint=save_checkpoint, + lr_tuning=lr_tuning, ) aggregator_builder = AggregatorBuilder( train_losses=train_losses, @@ -1184,3 +1188,164 @@ def test_ema_state_preserved_after_resume(tmp_path: str): resumed_ema_state["ema_params"][key], ema_state["ema_params"][key], ) + + +def test_lr_tuning_disabled_by_default(tmp_path: str): + """When lr_tuning is None, training proceeds normally.""" + with mock_wandb(): + config, trainer = get_trainer( + tmp_path, + max_epochs=2, + lr_tuning=None, + ) + initial_lr = trainer.optimization.learning_rate + trainer.train() + # LR should only change if the scheduler changes it, not LR tuning + assert trainer.optimization.learning_rate == initial_lr + + +def test_lr_tuning_runs_and_keeps_lr(tmp_path: str): + """When the candidate doesn't win, the LR stays the same.""" + max_epochs = 2 + # epoch_frequency=1, max_epochs=2: + # Epoch 0 tune: _last_val_loss=None → validate(0.8), trial(0.7, 0.75) + # baseline improvement=0.1, candidate improvement=0.05 → baseline wins + # Epoch 0: train + validate(0.6) + # Epoch 1 tune: trial(0.5, 0.55) + # baseline improvement=0.1, candidate improvement=0.05 → baseline wins + # Epoch 1: train + validate(0.4) + validation_losses = np.array([0.8, 0.7, 0.75, 0.6, 0.5, 0.55, 0.4]) + with mock_wandb(): + config, trainer = get_trainer( + tmp_path, + max_epochs=max_epochs, + validation_losses=validation_losses, + lr_tuning=LRTuningConfig( + epoch_frequency=1, + lr_factor=0.5, + num_batches=2, + improvement_threshold=0.1, + ), + ) + initial_lr = trainer.optimization.learning_rate + trainer.train() + assert trainer.optimization.learning_rate == initial_lr + + +def test_lr_tuning_adopts_candidate_lr(tmp_path: str): 
+ """When the candidate wins, the LR is updated.""" + max_epochs = 2 + # Epoch 0 tune: _last_val_loss=None → validate(1.0), trial(0.9, 0.3) + # baseline improvement=0.1, candidate improvement=0.7 → candidate wins + # Epoch 0: train + validate(0.5) + # Epoch 1 tune: trial(0.45, 0.44) + # baseline improvement=0.05, candidate improvement=0.06 + # candidate needs > 0.055 → 0.06 > 0.055 → wins again + # Epoch 1: train + validate(0.3) + validation_losses = np.array([1.0, 0.9, 0.3, 0.5, 0.45, 0.44, 0.3]) + with mock_wandb(): + config, trainer = get_trainer( + tmp_path, + max_epochs=max_epochs, + validation_losses=validation_losses, + lr_tuning=LRTuningConfig( + epoch_frequency=1, + lr_factor=0.5, + num_batches=2, + improvement_threshold=0.1, + ), + ) + initial_lr = trainer.optimization.learning_rate + trainer.train() + # Candidate won at both epochs + assert trainer.optimization.learning_rate == initial_lr * 0.5 * 0.5 + + +def test_lr_tuning_respects_epoch_frequency(tmp_path: str): + """LR tuning only runs on epochs matching the frequency.""" + max_epochs = 4 + # epoch_frequency=2, so tuning runs at epoch 0 and 2 + # Epoch 0 tune: needs _last_val_loss=None → runs validate_one_epoch first + # validate_one_epoch: 0.8 + # trial baseline: 0.7, trial candidate: 0.3 → candidate wins + # Epoch 0 train + validate: 0.6 + # Epoch 1: no tuning. train + validate: 0.5 + # Epoch 2 tune: + # trial baseline: 0.4, trial candidate: 0.1 → candidate wins + # Epoch 2 train + validate: 0.3 + # Epoch 3: no tuning. 
train + validate: 0.2 + validation_losses = np.array( + [ + 0.8, # _maybe_tune_lr validate_one_epoch at epoch 0 + 0.7, + 0.3, # trial at epoch 0 (baseline, candidate) + 0.6, # epoch 0 validate + 0.5, # epoch 1 validate + 0.4, + 0.1, # trial at epoch 2 (baseline, candidate) + 0.3, # epoch 2 validate + 0.2, # epoch 3 validate + ] + ) + with mock_wandb(): + config, trainer = get_trainer( + tmp_path, + max_epochs=max_epochs, + train_losses=np.zeros(max_epochs), + validation_losses=validation_losses, + inference_losses=np.zeros(max_epochs), + stepper_module_values=np.zeros(max_epochs), + lr_tuning=LRTuningConfig( + epoch_frequency=2, + lr_factor=0.5, + num_batches=2, + improvement_threshold=0.1, + ), + ) + initial_lr = trainer.optimization.learning_rate + trainer.train() + # Tuning ran at epoch 0 and 2, candidate won both times + assert trainer.optimization.learning_rate == initial_lr * 0.5 * 0.5 + + +def test_lr_tuning_with_evaluate_before_training(tmp_path: str): + """When evaluate_before_training=True, the pre-training validation + loss is used as _last_val_loss so _maybe_tune_lr doesn't re-validate.""" + max_epochs = 2 + # evaluate_before_training: val=0.9 + # epoch 0 tune (uses _last_val_loss=0.9): + # trial baseline: 0.8, candidate: 0.3 → candidate wins + # epoch 0 train + validate: 0.5 + # epoch 1 tune (uses _last_val_loss=0.5): + # trial baseline: 0.4 (improvement=0.1), candidate: 0.45 (improvement=0.05) + # candidate worse than baseline → baseline wins + # epoch 1 train + validate: 0.3 + validation_losses = np.array( + [ + 0.9, # evaluate_before_training + 0.8, + 0.3, # trial at epoch 0 + 0.5, # epoch 0 validate + 0.4, + 0.45, # trial at epoch 1 (baseline wins) + 0.3, # epoch 1 validate + ] + ) + with mock_wandb(): + config, trainer = get_trainer( + tmp_path, + max_epochs=max_epochs, + validation_losses=validation_losses, + inference_losses=np.zeros(max_epochs + 1), + evaluate_before_training=True, + lr_tuning=LRTuningConfig( + epoch_frequency=1, + 
lr_factor=0.5, + num_batches=2, + improvement_threshold=0.1, + ), + ) + initial_lr = trainer.optimization.learning_rate + trainer.train() + # Only epoch 0 candidate won + assert trainer.optimization.learning_rate == initial_lr * 0.5 diff --git a/fme/core/generics/test_validation.py b/fme/core/generics/test_validation.py new file mode 100644 index 000000000..fccfd24e1 --- /dev/null +++ b/fme/core/generics/test_validation.py @@ -0,0 +1,89 @@ +import unittest.mock + +import torch + +from fme.core.ema import EMATracker +from fme.core.generics.test_trainer import TrainData, TrainStepper, ValidationAggregator +from fme.core.generics.validation import run_validation + + +def test_run_validation_returns_logs(): + stepper = TrainStepper() + valid_data = TrainData(n_batches=3, shuffle=False) + aggregator = ValidationAggregator(validation_loss=0.5) + ema = EMATracker(stepper.modules, decay=0.9999) + + logs = run_validation( + stepper=stepper, + valid_data=valid_data, + aggregator=aggregator, + ema=ema, + validate_using_ema=False, + ) + + assert "val/mean/loss" in logs + assert logs["val/mean/loss"] == 0.5 + assert stepper.validation_batches_seen == [0, 1, 2] + + +def test_run_validation_with_ema(): + """When validate_using_ema=True, EMA params are applied during validation.""" + stepper = TrainStepper() + valid_data = TrainData(n_batches=2, shuffle=False) + aggregator = ValidationAggregator(validation_loss=0.3) + + # Set a non-zero weight so EMA differs from the initial zero weight + stepper.modules[0].weight.data.fill_(1.0) + ema = EMATracker(stepper.modules, decay=0.5) + # Update EMA with current params, then change model weight + ema(stepper.modules) + stepper.modules[0].weight.data.fill_(2.0) + + weight_before = stepper.modules[0].weight.data.clone() + + logs = run_validation( + stepper=stepper, + valid_data=valid_data, + aggregator=aggregator, + ema=ema, + validate_using_ema=True, + ) + + # After run_validation, the original weights should be restored + assert 
torch.allclose(stepper.modules[0].weight.data, weight_before) + assert "val/mean/loss" in logs + + +def test_run_validation_without_ema_none(): + """When ema is None and validate_using_ema=False, validation still works.""" + stepper = TrainStepper() + valid_data = TrainData(n_batches=2, shuffle=False) + aggregator = ValidationAggregator(validation_loss=0.7) + + logs = run_validation( + stepper=stepper, + valid_data=valid_data, + aggregator=aggregator, + ema=None, + validate_using_ema=False, + ) + + assert logs["val/mean/loss"] == 0.7 + + +def test_run_validation_does_not_flush_diagnostics(): + """run_validation should not call flush_diagnostics on the aggregator.""" + stepper = TrainStepper() + valid_data = TrainData(n_batches=2, shuffle=False) + aggregator = ValidationAggregator(validation_loss=0.1) + aggregator.flush_diagnostics = unittest.mock.MagicMock() # type: ignore + + run_validation( + stepper=stepper, + valid_data=valid_data, + aggregator=aggregator, + ema=None, + validate_using_ema=False, + ) + + aggregator.flush_diagnostics.assert_not_called() diff --git a/fme/core/generics/trainer.py b/fme/core/generics/trainer.py index 7bedc1fe9..9cfb35917 100644 --- a/fme/core/generics/trainer.py +++ b/fme/core/generics/trainer.py @@ -68,8 +68,10 @@ from fme.core.generics.aggregator import AggregatorABC, InferenceAggregatorABC from fme.core.generics.data import GriddedDataABC, InferenceDataABC from fme.core.generics.inference import run_inference +from fme.core.generics.lr_tuning import LRTuningConfig, run_lr_tuning_trial from fme.core.generics.metrics_aggregator import MetricsAggregator from fme.core.generics.train_stepper import TrainOutputABC, TrainStepperABC +from fme.core.generics.validation import run_validation from fme.core.optimization import NullOptimization, Optimization from fme.core.timing import GlobalTimer from fme.core.training_history import TrainingJob @@ -134,6 +136,9 @@ def evaluate_before_training(self) -> bool: ... 
@property def save_best_inference_epoch_checkpoints(self) -> bool: ... + @property + def lr_tuning(self) -> LRTuningConfig | None: ... + def get_inference_epochs(self) -> list[int]: ... @@ -244,12 +249,15 @@ def __init__( self.stepper = stepper self.stepper.update_training_history(TrainingJob.from_env()) + self._build_optimization = build_optimization + self._build_ema = build_ema self.optimization = build_optimization(stepper.modules) self._end_of_batch_callback = end_of_batch_callback self._end_of_epoch_callback = end_of_epoch_callback self._no_optimization = NullOptimization() self._aggregator_builder = aggregator_builder self._ema = build_ema(stepper.modules) # build before restore_checkpoint + self._last_val_loss: float | None = None resuming = os.path.isfile(self.paths.latest_checkpoint_path) if resuming: @@ -300,6 +308,53 @@ def _should_save_checkpoints(self) -> bool: dist = Distributed.get_instance() return self.config.save_checkpoint and dist.is_root() + def _copy_stepper(self) -> TrainStepperABC: + """Create a copy of the stepper via its state serialization API.""" + import copy + + new_stepper = copy.deepcopy(self.stepper) + new_stepper.load_state(copy.deepcopy(self.stepper.get_state())) + return new_stepper + + def _copy_ema(self, modules: torch.nn.ModuleList) -> EMATracker: + """Create a new EMATracker initialized from the current EMA state.""" + return EMATracker.from_state(self._ema.get_state(), modules) + + def _maybe_tune_lr(self): + cfg = self.config.lr_tuning + if cfg is None: + return + if self._current_epoch_num_batches_seen > 0: + return # resumed mid-epoch, tuning already ran (or wasn't needed) + if self._epochs_trained % cfg.epoch_frequency != 0: + return + if self._last_val_loss is None: + # No prior validation (start of training, evaluate_before_training=False). + # Run validation now to establish a baseline. 
+ val_logs = self.validate_one_epoch() + self._last_val_loss = val_logs["val/mean/loss"] + + # set_epoch so the trial sees the same first N batches as the real epoch + self.train_data.set_epoch(self._epochs_trained + 1) + new_lr = run_lr_tuning_trial( + train_data=self.train_data, + valid_data=self.valid_data, + optimization=self.optimization, + copy_stepper=self._copy_stepper, + build_optimization=self._build_optimization, + copy_ema=self._copy_ema, + config=cfg, + current_lr=self.optimization.learning_rate, + pre_trial_val_loss=self._last_val_loss, + get_validation_aggregator=( + self._aggregator_builder.get_validation_aggregator + ), + validate_using_ema=self.config.validate_using_ema, + ) + if new_lr is not None: + logging.info(f"LR tuning: adopting candidate LR {new_lr}") + self.optimization.set_learning_rate(new_lr) + def train(self): logging.info("Starting Training Loop...") @@ -324,6 +379,7 @@ def train(self): else: inference_logs = {} valid_loss = valid_logs["val/mean/loss"] + self._last_val_loss = valid_loss logging.info(f"Validation loss before training: {valid_loss}") logging.info("Logging to wandb") all_logs = valid_logs | inference_logs | {"epoch": self._epochs_trained} @@ -338,6 +394,7 @@ def train(self): logging.info( f"Beginning epoch after {self._epochs_trained} complete epochs" ) + self._maybe_tune_lr() start_time = time.time() train_logs = self.train_one_epoch() train_end = time.time() @@ -352,6 +409,7 @@ def train(self): train_loss = train_logs.get("train/mean/loss") valid_loss = valid_logs["val/mean/loss"] + self._last_val_loss = valid_loss inference_error = inference_logs.get( "inference/time_mean_norm/rmse/channel_mean", None ) @@ -558,23 +616,19 @@ def validate_one_epoch(self): f"{self._epochs_trained} epochs" ) self.valid_data.set_epoch(self._epochs_trained) - self.stepper.set_eval() aggregator = self._aggregator_builder.get_validation_aggregator() logging.info("Starting loop over validation data") - with torch.no_grad(), 
self.validation_context(), GlobalTimer(): - for batch in self.valid_data.loader: - stepped = self.stepper.train_on_batch( - batch, - optimization=NullOptimization(), - compute_derived_variables=True, - ) - aggregator.record_batch( - batch=stepped, - ) + logs = run_validation( + stepper=self.stepper, + valid_data=self.valid_data, + aggregator=aggregator, + ema=self._ema, + validate_using_ema=self.config.validate_using_ema, + ) logging.info("Starting flush of reduced diagnostics to disk") aggregator.flush_diagnostics(subdir=f"epoch_{self._epochs_trained:04d}") logging.info("Getting validation aggregator logs") - return aggregator.get_logs(label="val") + return logs def inference_one_epoch( self, diff --git a/fme/core/generics/validation.py b/fme/core/generics/validation.py new file mode 100644 index 000000000..9d416f526 --- /dev/null +++ b/fme/core/generics/validation.py @@ -0,0 +1,51 @@ +import contextlib + +import torch + +from fme.core.ema import EMATracker +from fme.core.generics.aggregator import AggregatorABC +from fme.core.generics.data import GriddedDataABC +from fme.core.generics.train_stepper import TrainStepperABC +from fme.core.optimization import NullOptimization +from fme.core.timing import GlobalTimer + + +def run_validation( + stepper: TrainStepperABC, + valid_data: GriddedDataABC, + aggregator: AggregatorABC, + ema: EMATracker | None, + validate_using_ema: bool, +) -> dict[str, float]: + """ + Run validation on the given data and return logs. + + This is the core validation loop used by both the Trainer and + LR tuning trials. It does NOT call aggregator.flush_diagnostics — + the caller is responsible for flushing if needed. + + Args: + stepper: The train stepper to evaluate. + valid_data: The validation dataset. + aggregator: The aggregator to record batch results into. + ema: The EMA tracker, or None if EMA is not used. + validate_using_ema: Whether to use EMA parameters during validation. + + Returns: + Validation logs dict (e.g. 
{"val/mean/loss": ...}). + """ + stepper.set_eval() + ema_context: contextlib.AbstractContextManager = ( + ema.applied_params(stepper.modules) + if validate_using_ema and ema is not None + else contextlib.nullcontext() + ) + with torch.no_grad(), ema_context, GlobalTimer(): + for batch in valid_data.loader: + stepped = stepper.train_on_batch( + batch, + optimization=NullOptimization(), + compute_derived_variables=True, + ) + aggregator.record_batch(batch=stepped) + return aggregator.get_logs(label="val") diff --git a/fme/core/optimization.py b/fme/core/optimization.py index 8885633ce..0b16afdd2 100644 --- a/fme/core/optimization.py +++ b/fme/core/optimization.py @@ -204,6 +204,10 @@ def get_state(self): } return state + def set_learning_rate(self, lr: float): + for param_group in self.optimizer.param_groups: + param_group["lr"] = lr + def load_state(self, state): """ Loads state from a serializable data structure. @@ -293,6 +297,9 @@ def autocast(self): def learning_rate(self) -> float: return float("nan") + def set_learning_rate(self, lr: float): + pass + def checkpoint(self, module: nn.Module, step: int) -> nn.Module: return module diff --git a/fme/core/test_optimization.py b/fme/core/test_optimization.py index 0a1958abb..97e25b8e1 100644 --- a/fme/core/test_optimization.py +++ b/fme/core/test_optimization.py @@ -416,6 +416,150 @@ def test_sequential_scheduler_reload(): assert torch.allclose(model_first_final_state[k], model_second_final_state[k]) +def test_set_learning_rate(): + model = nn.Linear(1, 1).to(fme.get_device()) + optimization = Optimization( + parameters=model.parameters(), + optimizer_type="Adam", + lr=0.001, + max_epochs=10, + scheduler=SchedulerConfig(), + enable_automatic_mixed_precision=False, + kwargs={}, + ) + assert optimization.learning_rate == 0.001 + optimization.set_learning_rate(0.01) + assert optimization.learning_rate == 0.01 + + +def test_set_learning_rate_null(): + optimization = NullOptimization() + 
optimization.set_learning_rate(0.01) # should not raise + + +def test_load_state_into_different_parameters(): + """ + Test that optimizer state (including momentum) can be loaded from one + Optimization into another built with different parameter objects but + the same structure. This is the pattern used by LR tuning trials, + where we deepcopy a model and need the fork's optimizer to start + with the original's momentum. + """ + torch.manual_seed(0) + model = nn.Linear(2, 2).to(fme.get_device()) + x = torch.randn(10, 2).to(fme.get_device()) + + optimization = Optimization( + parameters=model.parameters(), + optimizer_type="Adam", + lr=0.001, + max_epochs=10, + scheduler=SchedulerConfig(), + enable_automatic_mixed_precision=False, + kwargs={}, + ) + + # Train a few steps to build up momentum state + for _ in range(3): + loss = model(x).sum() + optimization.accumulate_loss(loss) + optimization.step_weights() + + saved_state = optimization.get_state() + + # Create a new model with the same structure but different parameter objects + model2 = copy.deepcopy(model) + optimization2 = Optimization( + parameters=model2.parameters(), + optimizer_type="Adam", + lr=0.001, + max_epochs=10, + scheduler=SchedulerConfig(), + enable_automatic_mixed_precision=False, + kwargs={}, + ) + optimization2.load_state(saved_state) + + # Train both for one more step on identical data and verify identical results + x2 = x.clone() + loss1 = model(x).sum() + optimization.accumulate_loss(loss1) + optimization.step_weights() + + loss2 = model2(x2).sum() + optimization2.accumulate_loss(loss2) + optimization2.step_weights() + + for p1, p2 in zip(model.parameters(), model2.parameters()): + assert torch.allclose( + p1, p2 + ), "Parameters should match after identical training" + + +def test_load_state_then_set_learning_rate(): + """ + Test that set_learning_rate works correctly after loading state, + which is the pattern used to create a candidate fork at a different LR. 
+ """ + torch.manual_seed(0) + model = nn.Linear(2, 2).to(fme.get_device()) + x = torch.randn(10, 2).to(fme.get_device()) + + optimization = Optimization( + parameters=model.parameters(), + optimizer_type="Adam", + lr=0.001, + max_epochs=10, + scheduler=SchedulerConfig(), + enable_automatic_mixed_precision=False, + kwargs={}, + ) + + # Train a few steps + for _ in range(3): + loss = model(x).sum() + optimization.accumulate_loss(loss) + optimization.step_weights() + + saved_state = optimization.get_state() + + # Build a new optimization, load state, then override LR + model2 = copy.deepcopy(model) + optimization2 = Optimization( + parameters=model2.parameters(), + optimizer_type="Adam", + lr=0.001, + max_epochs=10, + scheduler=SchedulerConfig(), + enable_automatic_mixed_precision=False, + kwargs={}, + ) + optimization2.load_state(saved_state) + optimization2.set_learning_rate(0.0005) + + assert optimization2.learning_rate == 0.0005 + + # Verify it actually trains at the new LR (different from original) + x2 = x.clone() + + loss1 = model(x).sum() + optimization.accumulate_loss(loss1) + optimization.step_weights() + + loss2 = model2(x2).sum() + optimization2.accumulate_loss(loss2) + optimization2.step_weights() + + # With different LRs, parameters should diverge + params_match = all( + torch.allclose(p1, p2) + for p1, p2 in zip(model.parameters(), model2.parameters()) + ) + assert ( + not params_match + ), "Parameters should differ when trained at different learning rates" + + def test_scheduler_step_timing(): """ Test that schedulers step at the correct timing based on diff --git a/fme/coupled/train/train_config.py b/fme/coupled/train/train_config.py index 64df79ede..b0ed97b3b 100644 --- a/fme/coupled/train/train_config.py +++ b/fme/coupled/train/train_config.py @@ -7,6 +7,7 @@ from fme.core.cli import ResumeResultsConfig from fme.core.distributed import Distributed from fme.core.ema import EMAConfig, EMATracker +from fme.core.generics.lr_tuning import LRTuningConfig 
from fme.core.generics.trainer import EndOfBatchCallback from fme.core.logging_utils import LoggingConfig from fme.core.optimization import Optimization, OptimizationConfig @@ -155,6 +156,7 @@ class TrainConfig: save_per_epoch_diagnostics: bool = False evaluate_before_training: bool = True save_best_inference_epoch_checkpoints: bool = False + lr_tuning: LRTuningConfig | None = None resume_results: ResumeResultsConfig | None = None @property diff --git a/fme/diffusion/train_config.py b/fme/diffusion/train_config.py index 917411125..21b44c596 100644 --- a/fme/diffusion/train_config.py +++ b/fme/diffusion/train_config.py @@ -17,6 +17,7 @@ from fme.core.cli import ResumeResultsConfig from fme.core.coordinates import VerticalCoordinate from fme.core.ema import EMAConfig, EMATracker +from fme.core.generics.lr_tuning import LRTuningConfig from fme.core.generics.trainer import EndOfBatchCallback, EndOfEpochCallback from fme.core.gridded_ops import GriddedOperations from fme.core.logging_utils import LoggingConfig @@ -114,6 +115,7 @@ class TrainConfig: save_per_epoch_diagnostics: bool = False evaluate_before_training: bool = False save_best_inference_epoch_checkpoints: bool = False + lr_tuning: LRTuningConfig | None = None resume_results: ResumeResultsConfig | None = None def __post_init__(self):