
Add optional LR tuning trials during training #930

Draft
mcgibbon wants to merge 13 commits into feature/validation-two-layer-refactor from feature/candidate_lr
Conversation


mcgibbon commented Mar 9, 2026

At the start of configured epochs, the trainer forks the current model into a baseline and a candidate copy, trains both for a small number of batches (the candidate at a scaled LR), validates both, and adopts the candidate LR if it wins by a configurable margin. This enables automatic learning-rate adaptation without manual tuning or restarts.
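The adoption rule above can be sketched as a small decision function. This is an illustrative assumption of how the margin check might work, not the actual fme implementation; `LRTrialResult`, `decide_candidate_lr`, and `improvement_threshold` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class LRTrialResult:
    """Validation losses from the two short trial runs (hypothetical shape)."""
    baseline_val_loss: float
    candidate_val_loss: float

def decide_candidate_lr(
    current_lr: float,
    lr_factor: float,
    last_val_loss: float,
    result: LRTrialResult,
    improvement_threshold: float,
) -> Optional[float]:
    """Return the scaled candidate LR if it wins by the margin, else None."""
    baseline_gain = last_val_loss - result.baseline_val_loss
    candidate_gain = last_val_loss - result.candidate_val_loss
    # Adopt only when both forks improved on the previous validation loss
    # and the candidate beat the baseline by more than the threshold.
    if (baseline_gain > 0 and candidate_gain > 0
            and candidate_gain - baseline_gain > improvement_threshold):
        return current_lr * lr_factor
    return None
```

A trial where the candidate fork validates well ahead of the baseline would return `current_lr * lr_factor`; a near-tie returns `None` and the trainer keeps its current LR.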

Changes:

  • fme.core.generics.lr_tuning.LRTuningConfig and run_lr_tuning_trial: new dataclass and function implementing isolated LR comparison trials with deep-copied stepper/optimization/EMA state

  • fme.core.generics.validation.run_validation_loop: new low-level validation loop (eval, no_grad, EMA context, iterate batches, record to aggregator) extracted from run_validation for reuse by LR tuning without flush/WandB side effects

  • fme.core.generics.validation.run_validation: now delegates to run_validation_loop and adds flush, get_logs, and WandB recording on top

  • fme.core.generics.trainer.Trainer: adds _maybe_tune_lr called at the start of each epoch; validate_one_epoch now uses run_validation; tracks _last_val_loss across epochs for tuning threshold

  • fme.core.generics.optimization.OptimizationABC.set_learning_rate: new abstract method implemented by Optimization and NullOptimization

  • fme.core.ema.EMATracker.get_state: clones tensors and copies the name mapping dict to avoid shared mutable state across forks

  • LRTuningConfig added as an optional field on ace, coupled, and diffusion TrainConfig dataclasses, and exported from fme.ace

  • Tests added

  • If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated
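Since `LRTuningConfig` is an optional field on the `TrainConfig` dataclasses, wiring it up might look like the sketch below. The field names (`lr_factor`, `n_trial_batches`, `improvement_threshold`) are assumptions for illustration; the real fields live in `fme.core.generics.lr_tuning` and may differ.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class LRTuningConfig:
    """Illustrative shape only, not the actual fme dataclass."""
    lr_factor: float = 2.0               # candidate LR = current LR * lr_factor
    n_trial_batches: int = 10            # short training budget for each fork
    improvement_threshold: float = 0.0   # margin the candidate must win by

@dataclass
class TrainConfig:
    # ... existing training fields elided ...
    lr_tuning: Optional[LRTuningConfig] = None  # optional; tuning off by default

# Enabled: try halving the LR each trial.
config = TrainConfig(lr_tuning=LRTuningConfig(lr_factor=0.5))
```

Leaving `lr_tuning` unset preserves the existing training behavior, which matches the "optional" framing in the PR title.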

mcgibbon and others added 13 commits March 9, 2026 21:13
Enable reuse of the validation loop outside the Trainer (e.g. for upcoming
LR tuning trials) by extracting it into a module-level function that accepts
a stepper, data, aggregator, and optional EMA. The Trainer method retains
its epoch-boundary assertion and flush_diagnostics responsibility, delegating
only the core loop.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Introduce an isolated trial function that creates two stepper forks
(baseline at current LR, candidate at current_lr * lr_factor), trains
both on the first N batches, validates both, and compares validation loss
improvements. Returns the candidate LR only if both improve and the
candidate exceeds a configurable threshold. The original model, optimizer,
and EMA are never mutated.

The trial function accepts copy_stepper and copy_ema callables so the
caller controls how forks are created — this avoids copy.deepcopy on the
stepper (which has deeply nested tensor structures) and allows the trainer
to seed the fork EMAs from its own EMA state.

Not yet wired into the Trainer — that will follow in a subsequent commit.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
Hook the LR tuning trial into the Trainer's epoch loop, running
_maybe_tune_lr before each training epoch. Extract run_validation into
its own module (fme/core/generics/validation.py) to break the circular
import between trainer.py and lr_tuning.py, replacing the previous
TYPE_CHECKING workaround. Use OptimizationABC instead of the concrete
Optimization class in lr_tuning.py to respect the generics layer's
import conventions. Add lr_tuning config field to ace, coupled, and
diffusion TrainConfig classes.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
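The epoch-loop hook described in this commit might be structured roughly as follows. All internals here (`_run_trial`, the placeholder validation loss) are stand-ins; only the `_maybe_tune_lr`-before-each-epoch ordering and the `_last_val_loss` tracking come from the PR description.

```python
from typing import Optional

class Trainer:
    """Minimal sketch of the tuning hook placement, not the real fme Trainer."""

    def __init__(self, lr_tuning_config=None):
        self.lr_tuning = lr_tuning_config
        self.current_lr = 1e-3
        self._last_val_loss: Optional[float] = None  # tracked across epochs
        self.lr_history = []

    def _maybe_tune_lr(self) -> None:
        # Run a trial only when tuning is configured and a previous
        # validation loss exists to serve as the comparison baseline.
        if self.lr_tuning is None or self._last_val_loss is None:
            return
        new_lr = self._run_trial()  # would call run_lr_tuning_trial
        if new_lr is not None:
            self.current_lr = new_lr

    def _run_trial(self) -> Optional[float]:
        return None  # placeholder; the real code forks and compares

    def train(self, n_epochs: int) -> None:
        for _ in range(n_epochs):
            self._maybe_tune_lr()          # before each training epoch
            self.lr_history.append(self.current_lr)
            self._last_val_loss = 1.0      # placeholder validation result
```

Note how the first epoch never triggers a trial: `_last_val_loss` starts as `None`, so tuning only begins once one real validation has run.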
The LR tuning trial creates baseline and candidate forks to evaluate a
candidate learning rate. These forks must be fully isolated from the
original stepper, optimizer, and EMA so that the trial has zero side
effects when the candidate does not win.

Two isolation failures were found and fixed:

1. Optimizer state (momentum buffers, Adam step counters) was shared
   between the forks and the original optimizer because
   `optimization.get_state()` returns references to live tensors.
   The forks' training incremented the original's step counter,
   corrupting Adam's bias correction and changing the effective
   learning rate for subsequent real training. Fixed by deepcopying
   the optimization state before loading into forks.

2. EMA `num_updates` tensor was shared because `get_state()` returned
   it by reference and `from_state()` assigned it directly. The forks'
   in-place `+=` mutated the original's counter, corrupting the EMA
   decay schedule. Fixed by cloning tensors in `get_state()`.

The test helpers were also updated to use the real Trainer copy patterns
(deepcopy + load_state for steppers, from_state for EMA) instead of
creating fresh objects, so these isolation bugs are now caught. Three
new tests verify that the original EMA num_updates, EMA params, and
optimizer state are not mutated by a trial.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
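The EMA reference-sharing bug described in fix 2 can be reproduced and fixed with a toy stand-in. `FakeTensor` mimics an in-place-mutable tensor; the classes below are sketches, not fme code, and the optimizer-state fix (fix 1) is analogous, replacing `clone()` with `copy.deepcopy` of the state dict.

```python
class FakeTensor:
    """Toy mutable tensor with torch-like clone() and in-place add_()."""
    def __init__(self, value: float):
        self.value = value
    def clone(self) -> "FakeTensor":
        return FakeTensor(self.value)
    def add_(self, delta: float) -> None:
        self.value += delta  # in-place, like torch's trailing-underscore ops

class EMATracker:
    """Sketch of the fixed get_state/from_state round trip."""
    def __init__(self):
        self.num_updates = FakeTensor(0.0)
    def get_state(self) -> dict:
        # The fix: clone tensors so callers never receive a live reference.
        # Before the fix, returning self.num_updates directly meant a fork's
        # in-place += corrupted the original's decay schedule.
        return {"num_updates": self.num_updates.clone()}
    def from_state(self, state: dict) -> None:
        self.num_updates = state["num_updates"]

original = EMATracker()
fork = EMATracker()
fork.from_state(original.get_state())
fork.num_updates.add_(1.0)  # fork trains; original must be untouched
```

With the pre-fix behavior (returning the tensor by reference), the final `add_` would have incremented both counters.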
Tests that the Trainer._copy_stepper pattern (deepcopy + load_state)
creates a fully isolated copy whose training does not affect the
original stepper's predictions.

Co-Authored-By: Claude Opus 4.6 <[email protected]>
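A self-contained analogue of the isolation test described in this commit might look like the following. `ToyStepper` is a hypothetical stand-in for the real stepper; only the deepcopy-plus-`load_state` copy pattern is taken from the commit message.

```python
import copy

class ToyStepper:
    """Hypothetical stepper with the get_state/load_state interface."""
    def __init__(self):
        self.weights = {"w": 1.0}
    def get_state(self) -> dict:
        return {"weights": copy.deepcopy(self.weights)}
    def load_state(self, state: dict) -> None:
        self.weights = copy.deepcopy(state["weights"])
    def train_step(self) -> None:
        self.weights["w"] += 0.1  # pretend gradient update

def test_copy_stepper_isolation():
    original = ToyStepper()
    # The Trainer copy pattern under test: deepcopy the object, then load
    # a state snapshot so the fork starts from identical weights.
    fork = copy.deepcopy(original)
    fork.load_state(original.get_state())
    fork.train_step()
    # Training the fork must leave the original's weights untouched.
    assert original.weights["w"] == 1.0
    assert fork.weights["w"] != original.weights["w"]
```

If `load_state` assigned the incoming dict by reference instead of copying, this test would fail because `train_step` would mutate shared state.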
Resolve conflicts in validation.py and test_trainer.py using a two-layer
approach: run_validation_loop (minimal loop for LR tuning) and
run_validation (high-level wrapper with flush/logging for Trainer and
inference evaluator).

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
climate-ci-github changed the title from "Auto tune learning rate based on validation loss" to "Add optional LR tuning trials during training" on Mar 25, 2026
mcgibbon changed the base branch from main to feature/validation-two-layer-refactor on March 25, 2026 20:45
