Add optional LR tuning trials during training #930
Draft
mcgibbon wants to merge 13 commits into feature/validation-two-layer-refactor from
Enable reuse of the validation loop outside the Trainer (e.g. for upcoming LR tuning trials) by extracting it into a module-level function that accepts a stepper, data, aggregator, and optional EMA. The Trainer method retains its epoch-boundary assertion and flush_diagnostics responsibility, delegating only the core loop. Co-Authored-By: Claude Opus 4.6 <[email protected]>
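The extracted loop described above can be sketched roughly as follows. This is a minimal stand-in, not the real `fme` code: the method names `stepper.predict`, `aggregator.record_batch`, `aggregator.get_logs`, and the EMA context manager are assumptions, and the real function additionally switches the stepper to eval mode and runs under `torch.no_grad()`.

```python
from contextlib import nullcontext

def run_validation_loop(stepper, batches, aggregator, ema=None):
    # Evaluate each batch, optionally under EMA weights, and record
    # results to the aggregator. No flushing or logging happens here,
    # which is what makes the loop reusable for LR tuning trials.
    context = ema.swap_in_ema_weights(stepper) if ema is not None else nullcontext()
    with context:
        for batch in batches:
            outputs = stepper.predict(batch)
            aggregator.record_batch(batch, outputs)
    return aggregator.get_logs()
```

Keeping side effects (flush, WandB) out of this function is the key design point: callers that only need a validation loss can use it without touching diagnostics.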
Introduce an isolated trial function that creates two stepper forks (baseline at current LR, candidate at current_lr * lr_factor), trains both on the first N batches, validates both, and compares validation loss improvements. Returns the candidate LR only if both improve and the candidate exceeds a configurable threshold. The original model, optimizer, and EMA are never mutated. The trial function accepts copy_stepper and copy_ema callables so the caller controls how forks are created — this avoids copy.deepcopy on the stepper (which has deeply nested tensor structures) and allows the trainer to seed the fork EMAs from its own EMA state. Not yet wired into the Trainer — that will follow in a subsequent commit. Co-Authored-By: Claude Opus 4.6 <[email protected]>
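A rough sketch of the trial described above, assuming a hypothetical signature: the `validate` callable, the parameter names, and the exact acceptance rule (candidate must beat the baseline's improvement by a relative `threshold`) are guesses at one plausible reading of "both improve and the candidate exceeds a configurable threshold", not the actual implementation.

```python
def run_lr_tuning_trial(stepper, ema, current_lr, lr_factor,
                        train_batches, val_batches, prev_val_loss,
                        threshold, copy_stepper, copy_ema, validate):
    # Train two fully isolated forks -- baseline at the current LR,
    # candidate at current_lr * lr_factor -- on the same batches,
    # then compare validation-loss improvements over prev_val_loss.
    losses = {}
    for name, lr in (("baseline", current_lr),
                     ("candidate", current_lr * lr_factor)):
        fork = copy_stepper(stepper)   # caller controls forking, so the
        fork_ema = copy_ema(ema)       # originals are never mutated
        fork.set_learning_rate(lr)
        for batch in train_batches:
            fork.train_on_batch(batch)
        losses[name] = validate(fork, fork_ema, val_batches)
    baseline_gain = prev_val_loss - losses["baseline"]
    candidate_gain = prev_val_loss - losses["candidate"]
    # Adopt the candidate LR only if both forks improved and the
    # candidate beat the baseline by the configured margin.
    if (baseline_gain > 0 and candidate_gain > 0
            and candidate_gain > baseline_gain * (1.0 + threshold)):
        return current_lr * lr_factor
    return None
```

Passing `copy_stepper` and `copy_ema` as callables is the point the commit message emphasizes: the trial never decides how forks are made, so the caller can avoid a blind `copy.deepcopy` of deeply nested tensor structures.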
Hook the LR tuning trial into the Trainer's epoch loop, running _maybe_tune_lr before each training epoch. Extract run_validation into its own module (fme/core/generics/validation.py) to break the circular import between trainer.py and lr_tuning.py, replacing the previous TYPE_CHECKING workaround. Use OptimizationABC instead of the concrete Optimization class in lr_tuning.py to respect the generics layer's import conventions. Add lr_tuning config field to ace, coupled, and diffusion TrainConfig classes. Co-Authored-By: Claude Opus 4.6 <[email protected]>
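The Trainer hook might look roughly like the standalone function below. This is purely illustrative: the real code is a `Trainer` method (`_maybe_tune_lr`) with access to config, data, and `OptimizationABC.set_learning_rate`, and the names here are invented for the sketch.

```python
def maybe_tune_lr(epoch, trial_epochs, last_val_loss, current_lr,
                  run_trial, set_learning_rate):
    # Skip tuning until a validation loss exists to compare against,
    # and only run on the configured epochs.
    if last_val_loss is None or epoch not in trial_epochs:
        return current_lr
    winning_lr = run_trial(current_lr, last_val_loss)
    if winning_lr is None:
        return current_lr             # candidate lost: nothing changes
    set_learning_rate(winning_lr)     # candidate won: adopt its LR
    return winning_lr
```

Guarding on `last_val_loss is None` reflects why the Trainer tracks `_last_val_loss` across epochs: the first epoch has no reference loss, so no trial can run.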
The LR tuning trial creates baseline and candidate forks to evaluate a candidate learning rate. These forks must be fully isolated from the original stepper, optimizer, and EMA so that the trial has zero side effects when the candidate does not win. Two isolation failures were found and fixed:

1. Optimizer state (momentum buffers, Adam step counters) was shared between the forks and the original optimizer because `optimization.get_state()` returns references to live tensors. The forks' training incremented the original's step counter, corrupting Adam's bias correction and changing the effective learning rate for subsequent real training. Fixed by deepcopying the optimization state before loading it into forks.
2. The EMA `num_updates` tensor was shared because `get_state()` returned it by reference and `from_state()` assigned it directly. The forks' in-place `+=` mutated the original's counter, corrupting the EMA decay schedule. Fixed by cloning tensors in `get_state()`.

The test helpers were also updated to use the real Trainer copy patterns (deepcopy + load_state for steppers, from_state for EMA) instead of creating fresh objects, so these isolation bugs are now caught. Three new tests verify that the original EMA num_updates, EMA params, and optimizer state are not mutated by a trial. Co-Authored-By: Claude Opus 4.6 <[email protected]>
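The EMA reference-sharing bug and its fix can be illustrated with a toy class. Plain Python lists stand in for tensors here (`list(...)` plays the role of `tensor.clone()`), and the class shape is heavily simplified from the real `fme.core.ema.EMATracker`.

```python
class EMATracker:
    def __init__(self):
        # A one-element list stands in for the num_updates tensor.
        self.num_updates = [0]
        self._name_map = {}

    def get_state(self):
        # Before the fix, returning self.num_updates directly handed
        # out a live reference; a fork's in-place update then mutated
        # the original's counter. Copying here (tensor.clone() in the
        # real code, plus a copied dict for the name map) isolates
        # every consumer of the state.
        return {"num_updates": list(self.num_updates),
                "name_map": dict(self._name_map)}

    @classmethod
    def from_state(cls, state):
        # from_state still assigns directly; isolation is guaranteed
        # at the get_state boundary instead.
        tracker = cls()
        tracker.num_updates = state["num_updates"]
        tracker._name_map = state["name_map"]
        return tracker

    def update(self):
        self.num_updates[0] += 1  # in-place, like tensor +=
```

The optimizer-state bug was the analogous failure one level up, fixed the same way in spirit: deep-copying `optimization.get_state()` before loading it into a fork, so momentum buffers and step counters are no longer shared.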
Tests that the Trainer._copy_stepper pattern (deepcopy + load_state) creates a fully isolated copy whose training does not affect the original stepper's predictions. Co-Authored-By: Claude Opus 4.6 <[email protected]>
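An illustrative version of such a test, using a fake stepper instead of the real one; the `get_state`/`load_state` method names and the deepcopy-then-load pattern are taken from the commit descriptions, everything else is invented scaffolding.

```python
import copy

class FakeStepper:
    def __init__(self):
        self.weight = 2.0
    def get_state(self):
        return {"weight": self.weight}
    def load_state(self, state):
        self.weight = state["weight"]
    def predict(self, x):
        return self.weight * x
    def train_on_batch(self, x):
        self.weight -= 0.1 * x  # stand-in for a gradient step

def test_copied_stepper_is_isolated():
    stepper = FakeStepper()
    # The Trainer._copy_stepper pattern: deepcopy, then load_state
    # from a copied state dict.
    fork = copy.deepcopy(stepper)
    fork.load_state(copy.deepcopy(stepper.get_state()))
    before = stepper.predict(3.0)
    fork.train_on_batch(1.0)               # train only the fork
    assert stepper.predict(3.0) == before  # original unchanged
    assert fork.predict(3.0) != before     # fork actually trained

test_copied_stepper_is_isolated()
```

The double assertion matters: checking that the fork's predictions *did* change guards against a test that passes trivially because training was a no-op.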
Resolve conflicts in validation.py and test_trainer.py using a two-layer approach: run_validation_loop (minimal loop for LR tuning) and run_validation (high-level wrapper with flush/logging for Trainer and inference evaluator). Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
At the start of configurable epochs, the trainer forks the current model into a baseline and a candidate copy, trains both for a short number of batches (candidate at a scaled LR), validates both, and adopts the candidate LR if it wins by a configurable margin. This enables automatic learning rate adaptation without manual tuning or restarts.
Changes:
- `fme.core.generics.lr_tuning.LRTuningConfig` and `run_lr_tuning_trial`: new dataclass and function implementing isolated LR comparison trials with deep-copied stepper/optimization/EMA state
- `fme.core.generics.validation.run_validation_loop`: new low-level validation loop (eval, no_grad, EMA context, iterate batches, record to aggregator) extracted from `run_validation` for reuse by LR tuning without flush/WandB side effects
- `fme.core.generics.validation.run_validation`: now delegates to `run_validation_loop` and adds flush, get_logs, and WandB recording on top
- `fme.core.generics.trainer.Trainer`: adds `_maybe_tune_lr` called at the start of each epoch; `validate_one_epoch` now uses `run_validation`; tracks `_last_val_loss` across epochs for the tuning threshold
- `fme.core.generics.optimization.OptimizationABC.set_learning_rate`: new abstract method implemented by `Optimization` and `NullOptimization`
- `fme.core.ema.EMATracker.get_state`: clones tensors and copies the name mapping dict to avoid shared mutable state across forks
- `LRTuningConfig` added as an optional field on the `ace`, `coupled`, and `diffusion` `TrainConfig` dataclasses, and exported from `fme.ace`

Tests added
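For a sense of what the optional config field might hold, here is a guessed sketch of the dataclass; the PR does not list the actual field names or defaults, so everything below the class line is hypothetical.

```python
from dataclasses import dataclass

@dataclass
class LRTuningConfig:
    # Hypothetical fields; the real dataclass lives in
    # fme.core.generics.lr_tuning and may differ.
    lr_factor: float = 2.0              # candidate LR = current LR * lr_factor
    n_trial_batches: int = 16           # batches used to train each fork
    improvement_threshold: float = 0.1  # margin the candidate must win by
```

Making it an optional field on each `TrainConfig` (defaulting to `None`) means existing configs are unaffected and tuning is strictly opt-in.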
If dependencies changed, "deps only" image rebuilt and "latest_deps_only_image.txt" file updated