Benefits of this are summarized as:
- Mid-epoch preemption recovery — the generic trainer saves _current_epoch_num_batches_seen and can resume mid-epoch via subset_loader
- Signal handling (SIGTERM/SIGINT) with safe checkpoint saves
- Per-batch logging frequency control (log_train_every_n_batches)
- Per-batch checkpoint frequency (checkpoint_every_n_batches)
- First-batch metric logging before training starts
- Post-training evaluation on independent shuffle of training data
- Epoch and EMA checkpoint saving at configurable epoch slices
- Best-inference-error checkpoint tracking
- Per-iteration scheduler stepping
- Proper EMA via EMATracker.applied_params context manager (the downscaling trainer uses the older store/copy_to/restore pattern which is more
error-prone)
Full description:
Assessment: Can downscaling use the generic Trainer?
Yes, it's feasible, but it requires targeted work in five areas. The mismatch is smaller than it appears at first glance.
What the generic Trainer actually needs
Looking at the concrete method calls in Trainer, the stepper is used in exactly three ways:
1. train_one_epoch (line 455): stepper.train_on_batch(batch, optimization) -> TO where TO.get_metrics() returns a TensorDict. This is the core training
loop.
2. validate_one_epoch (line 566): stepper.train_on_batch(batch, NullOptimization(), compute_derived_variables=True) -> TO, then
aggregator.record_batch(TO). Same interface, no optimization applied.
3. inference_one_epoch (line 591): Calls stepper.predict_paired(initial_condition, forcing) via the Looper which steps forward repeatedly.
The generic Trainer also needs a GriddedDataABC[BD] for training/validation data (with loader, set_epoch, subset_loader, alternate_shuffle, batch_size)
and an InferenceDataABC[PS, FD] for inference.
Gap analysis
Gap 1: train_on_batch signature — easily bridgeable
The downscaling model's train_on_batch takes (batch, static_inputs, optimization) while TrainStepperABC.train_on_batch takes (data, optimization). The fix
is straightforward: bundle static_inputs into the batch data type BD. A DownscalingTrainStepper adapter would hold the DiffusionModel + StaticInputs, and
its train_on_batch(data: PairedBatchData, optimization) would forward to self.model.train_on_batch(data, self.static_inputs, optimization). The return
type needs a thin wrapper implementing TrainOutputABC.get_metrics() around the existing ModelOutputs.
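The adapter and wrapper described above can be sketched roughly as follows. The class names, the ModelOutputs fields, and the way static inputs are held are assumptions for illustration, not the actual fme API:

```python
from dataclasses import dataclass
from typing import Any, Mapping

# Simplified stand-in for the existing downscaling ModelOutputs type.
@dataclass
class ModelOutputs:
    loss: float


@dataclass
class DownscalingTrainOutput:
    """Thin TrainOutputABC-style wrapper around the existing ModelOutputs."""

    outputs: ModelOutputs

    def get_metrics(self) -> Mapping[str, float]:
        return {"loss": self.outputs.loss}


class DownscalingTrainStepper:
    """Adapter bundling the DiffusionModel with its StaticInputs so that
    train_on_batch matches the generic (data, optimization) signature."""

    def __init__(self, model: Any, static_inputs: Any):
        self.model = model
        self.static_inputs = static_inputs

    def train_on_batch(self, data: Any, optimization: Any) -> DownscalingTrainOutput:
        # Forward to the existing three-argument method on the model.
        outputs = self.model.train_on_batch(data, self.static_inputs, optimization)
        return DownscalingTrainOutput(outputs)
```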
Gap 2: predict_paired and inference — can be made optional
This is the real conceptual mismatch. The generic trainer's inference_one_epoch uses predict_paired(initial_condition, forcing) -> (stepped_data,
new_state) via a Looper that chains steps. Downscaling has no multi-step state propagation.
However, inference_one_epoch is only called when the epoch is in config.get_inference_epochs(). The downscaling trainer currently doesn't do inline
inference at all — its valid_one_epoch handles generation evaluation. So the simplest path is:
- Have get_inference_epochs() return an empty list, making inference_one_epoch never execute.
- Implement predict_paired to raise NotImplementedError, or as a trivial single-step wrapper (the method exists on TrainStepperABC but would never be called).
The downscaling-specific generation evaluation (calling generate_on_batch with stochastic sampling during validation) can be handled via the
end_of_epoch_callback hook, which already exists in the generic Trainer and returns additional logs to merge.
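The two bullets above amount to very little code. A minimal sketch, where the class name and error message are assumptions and only the method names mirror the interfaces discussed:

```python
from typing import Any, List


class DownscalingInferenceStub:
    """Illustrative inference-facing surface when inline inference is
    disabled; exists only to satisfy the TrainStepperABC contract."""

    def predict_paired(self, initial_condition: Any, forcing: Any) -> Any:
        # Never reached when get_inference_epochs() returns an empty list.
        raise NotImplementedError(
            "Downscaling performs no autoregressive inference."
        )


def get_inference_epochs() -> List[int]:
    # Empty list => the generic Trainer never calls inference_one_epoch.
    return []
```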
Gap 3: GriddedDataABC — moderate adaptation needed
The generic Trainer expects GriddedDataABC[BD] with methods: loader, set_epoch, subset_loader, alternate_shuffle, batch_size, n_samples, n_batches,
log_info. The downscaling PairedGriddedData currently only has loader and get_generator/get_patched_generator.
Missing methods that need implementing:
- set_epoch(epoch) — currently not tracked, likely a no-op or pass-through to sampler
- subset_loader(start_batch, stop_batch) — needed for mid-epoch resume and train evaluation
- alternate_shuffle() — needed for independent train evaluation shuffle
- batch_size, n_samples, n_batches — simple properties
- log_info(name) — simple logging
This is real work but mechanical — none of these require fundamental design changes.
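A rough sketch of what that adapter could look like, assuming the wrapped PairedGriddedData exposes a list-like loader; the method names follow the GriddedDataABC contract described above, everything else (constructor shape, no-op set_epoch, pass-through alternate_shuffle) is an assumption:

```python
import itertools
from typing import Any, Iterable, Iterator


class PairedGriddedDataAdapter:
    """Hypothetical GriddedDataABC-style adapter around PairedGriddedData."""

    def __init__(self, wrapped: Any, batch_size: int):
        self._wrapped = wrapped
        self._batch_size = batch_size
        self._epoch = 0

    @property
    def loader(self) -> Iterable:
        return self._wrapped.loader

    def set_epoch(self, epoch: int) -> None:
        # No distributed sampler to reseed in this sketch; just record it.
        self._epoch = epoch

    def subset_loader(self, start_batch: int, stop_batch: int) -> Iterator:
        # Mid-epoch resume: skip to start_batch, stop before stop_batch.
        return itertools.islice(iter(self.loader), start_batch, stop_batch)

    def alternate_shuffle(self) -> Iterable:
        # Independent shuffle for post-training train evaluation; a real
        # implementation would reshuffle with a different seed.
        return self.loader

    @property
    def batch_size(self) -> int:
        return self._batch_size

    @property
    def n_batches(self) -> int:
        return len(self._wrapped.loader)

    @property
    def n_samples(self) -> int:
        return self.n_batches * self._batch_size

    def log_info(self, name: str) -> None:
        print(f"{name}: {self.n_samples} samples in {self.n_batches} batches")
```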
Gap 4: Validation with generation — solvable via the aggregator builder
The downscaling trainer's unique feature is that valid_one_epoch runs both train_on_batch (for loss) and generate_on_batch (for stochastic sampling
metrics). The generic trainer's validate_one_epoch only calls train_on_batch.
Two clean approaches:
Option A: The validation aggregator itself calls generate_on_batch. This means the AggregatorABC[TO] implementation for downscaling would hold a reference
to the model and call generation internally when record_batch is called. This is a bit unusual but keeps the Trainer interface unchanged.
Option B (cleaner): Use the end_of_epoch_callback to run generation evaluation after validation completes. The callback has access to the model (via
closure) and already returns Mapping[str, Any] that gets merged into epoch logs.
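Option B can be sketched as a closure factory; generate_on_batch is the existing downscaling method, but its return shape, the "crps" key, and the log key names here are assumptions for illustration:

```python
from typing import Any, Callable, Iterable, Mapping


def make_generation_callback(
    model: Any, valid_batches: Iterable, n_samples: int = 4
) -> Callable[[int], Mapping[str, Any]]:
    """Build an end_of_epoch_callback that runs stochastic generation
    evaluation after validation and returns logs to merge into the epoch."""

    def end_of_epoch_callback(epoch: int) -> Mapping[str, Any]:
        scores = []
        for batch in valid_batches:
            # Hypothetical: generation returns a dict of sampling metrics.
            generated = model.generate_on_batch(batch, n_samples=n_samples)
            scores.append(generated["crps"])
        return {"valid/generation_crps": sum(scores) / len(scores)}

    return end_of_epoch_callback
```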
Gap 5: Patching — solvable within the data layer
The downscaling trainer has spatial patching logic (get_patched_generator). This can live entirely inside the GriddedDataABC adapter — the loader property
would return patched batches when configured, transparent to the Trainer.
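Keeping patching inside the data layer could look like this; get_patched_generator is the existing downscaling method, while the adapter class and the use_patching flag are assumptions:

```python
from typing import Any, Iterable


class PatchedDataAdapter:
    """Sketch: the Trainer only ever sees the loader property, so spatial
    patching stays an internal concern of the data adapter."""

    def __init__(self, wrapped: Any, use_patching: bool):
        self._wrapped = wrapped
        self._use_patching = use_patching

    @property
    def loader(self) -> Iterable:
        if self._use_patching:
            return self._wrapped.get_patched_generator()
        return self._wrapped.loader
```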
What the generic Trainer gives you for free
Things the downscaling trainer currently lacks or implements less robustly:
- Mid-epoch preemption recovery — the generic trainer saves _current_epoch_num_batches_seen and can resume mid-epoch via subset_loader
- Signal handling (SIGTERM/SIGINT) with safe checkpoint saves
- Per-batch logging frequency control (log_train_every_n_batches)
- Per-batch checkpoint frequency (checkpoint_every_n_batches)
- First-batch metric logging before training starts
- Post-training evaluation on independent shuffle of training data
- Epoch and EMA checkpoint saving at configurable epoch slices
- Best-inference-error checkpoint tracking
- Per-iteration scheduler stepping
- Proper EMA via EMATracker.applied_params context manager (the downscaling trainer uses the older store/copy_to/restore pattern which is more
error-prone)
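The EMA point in the last bullet can be illustrated with a minimal context-manager sketch. SimpleEMATracker below is a simplified stand-in for the real EMATracker (plain float lists instead of tensors, assumed decay value); the point is that the context manager guarantees restoration even if evaluation raises, which the store/copy_to/restore sequence does not:

```python
from contextlib import contextmanager
from typing import Iterator, List


class SimpleEMATracker:
    """Simplified EMA tracker demonstrating the applied_params pattern."""

    def __init__(self, params: List[float], decay: float = 0.999):
        self.decay = decay
        self.shadow = [float(p) for p in params]

    def update(self, params: List[float]) -> None:
        # Standard exponential moving average of the live parameters.
        for i, p in enumerate(params):
            self.shadow[i] = self.decay * self.shadow[i] + (1 - self.decay) * float(p)

    @contextmanager
    def applied_params(self, params: List[float]) -> Iterator[List[float]]:
        # Swap EMA weights in; the finally block restores the originals
        # even if the body raises, unlike manual store/copy_to/restore.
        backup = list(params)
        try:
            params[:] = self.shadow
            yield params
        finally:
            params[:] = backup
```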
Summary: what's needed
┌─────────────────────────────────────────────────────────────────────┬──────────┬─────────────────────────────────────────────────────────────────┐
│ Work Item │ Effort │ Notes │
├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
│ DownscalingTrainStepper adapter wrapping DiffusionModel │ Small │ Bundle static_inputs, wrap ModelOutputs in TrainOutputABC │
├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
│ TrainOutputABC wrapper for ModelOutputs │ Trivial │ Implement get_metrics() returning {"loss": outputs.loss} │
├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
│ GriddedDataABC adapter for PairedGriddedData │ Moderate │ Add set_epoch, subset_loader, alternate_shuffle, property stubs │
├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
│ Generation evaluation via end_of_epoch_callback │ Moderate │ Move generate_on_batch loop into callback, return metrics dict │
├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
│ Stub predict_paired + empty get_inference_epochs() │ Trivial │ Never called, just satisfies the interface │
├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
│ set_train/set_eval/get_state/load_state/update_training_history │ Small │ Delegate to model, mostly one-liners │
├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
│ Migrate EMA from store/copy_to/restore to EMATracker.applied_params │ Small │ Comes free with the generic trainer │
└─────────────────────────────────────────────────────────────────────┴──────────┴─────────────────────────────────────────────────────────────────┘
The total is probably 200-300 lines of adapter code, replacing ~350 lines of the current custom Trainer (lines 90-396). The net LOC change is small, but
the benefits are significant: you get preemption recovery, signal handling, per-batch checkpointing, and all future improvements to the generic Trainer
for free.
Recommendation
Do it. The conceptual mismatch ("downscaling isn't time-stepping") is real but doesn't actually block adoption. The generic Trainer's core contract is
simpler than it looks: call train_on_batch in a loop, call an aggregator, manage checkpoints. The time-stepping inference (predict_paired / Looper) is an
optional feature controlled by get_inference_epochs() — just return empty. The generation evaluation unique to downscaling fits cleanly into
end_of_epoch_callback. The GriddedDataABC adaptation is the most work but it's purely mechanical and also gives downscaling mid-epoch resume capability it
currently lacks.
❯ The main purpose of "inference" in ace is to run "the problem we really care about", in this case long-term stepping. In diffusion there is an analog of
this: while we train and validate on a single output of the model, what we care about is the multi-step autoregressive iteration of the diffusion model
which "generates" an output sample (we call it generation in that code base). Is it possible to have "inference" in the case of diffusion training perform
generation instead, using the Trainer?
---
The diffusion module already does exactly this. The architecture is:
- Training (train_on_batch): trains a single denoising step — the network learns to denoise one noisy sample. This is fast and is what the loss is
computed on.
- Inference (predict_paired → called by the Trainer's Looper): each call to step() runs the full EDM sampling loop (edm_sampler, ~18 denoising iterations,
~36 network calls) to produce one clean output timestep. The Looper then chains these across multiple timesteps autoregressively, feeding each output as
input to the next.
So when the generic Trainer calls inference_one_epoch(), it already runs multi-step generation — the predict_paired interface naturally accommodates this
because:
1. step() (line 739) creates random latents and calls edm_sampler() — this is the generation process
2. _predict() (line 786) loops step() across n_forward_steps timesteps autoregressively
3. predict_paired() (line 896) wraps this with reference data pairing
4. The Looper chains predict_paired calls across data windows
The config controls this via inference_n_forward_steps and get_inference_epochs() in fme/diffusion/train_config.py.
For downscaling, the same pattern could work if you define "inference" as "run the full EDM generation process and evaluate the result." The difference is
that downscaling doesn't chain steps autoregressively — it's one shot per sample. But you can still use the inference machinery:
- predict_paired would run edm_sampler to generate one output (one full generation), pair it with the target, and return
- The Looper would iterate over evaluation samples (not timesteps), calling predict_paired once per sample
- InferenceDataABC would provide evaluation batches as "forcing" windows, with a dummy initial condition
- The InferenceAggregatorABC would collect generation-quality metrics (CRPS, histogram tails, etc.)
The n_forward_steps=1 case works fine with the Looper — it just means each window produces one step of output. The Looper doesn't care whether the steps
are autoregressive or independent; it just calls predict_paired and records the result.
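The one-shot case described above could be sketched like this; edm_sampler, the pairing shape, and the model.sample call are placeholders standing in for the real generation code:

```python
from typing import Any, Tuple


class GenerationStepper:
    """Sketch of one-shot downscaling 'inference': each predict_paired call
    runs one full generation (the edm_sampler loop in the real code) and
    pairs it with the reference data carried in forcing."""

    def __init__(self, model: Any):
        self.model = model

    def predict_paired(self, initial_condition: Any, forcing: Any) -> Tuple[Any, Any]:
        # One full sampling loop => one clean output per call.
        generated = self.model.sample(forcing)
        paired = {"prediction": generated, "target": forcing}
        # n_forward_steps == 1: the returned state is simply unused.
        return paired, generated
```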