
Downscaling could re-use generic Trainer, gaining its features #941

@mcgibbon

Description


Benefits of this are listed under "What the generic Trainer gives you for free" in the full description below.

Full description:

  Assessment: Can downscaling use the generic Trainer?

  Yes, it's feasible, but it requires targeted work in a handful of areas. The mismatch is smaller than it appears at first glance.

  What the generic Trainer actually needs

  Looking at the concrete method calls in Trainer, the stepper is used in exactly three ways:

  1. train_one_epoch (line 455): stepper.train_on_batch(batch, optimization) -> TO where TO.get_metrics() returns a TensorDict. This is the core training
  loop.
  2. validate_one_epoch (line 566): stepper.train_on_batch(batch, NullOptimization(), compute_derived_variables=True) -> TO, then
  aggregator.record_batch(TO). Same interface, no optimization applied.
  3. inference_one_epoch (line 591): Calls stepper.predict_paired(initial_condition, forcing) via the Looper which steps forward repeatedly.

  The generic Trainer also needs a GriddedDataABC[BD] for training/validation data (with loader, set_epoch, subset_loader, alternate_shuffle, batch_size)
  and an InferenceDataABC[PS, FD] for inference.

  Gap analysis

  Gap 1: train_on_batch signature — easily bridgeable

  The downscaling model's train_on_batch takes (batch, static_inputs, optimization) while TrainStepperABC.train_on_batch takes (data, optimization). The fix
   is straightforward: bundle static_inputs into the batch data type BD. A DownscalingTrainStepper adapter would hold the DiffusionModel + StaticInputs, and
   its train_on_batch(data: PairedBatchData, optimization) would forward to self.model.train_on_batch(data, self.static_inputs, optimization). The return
  type needs a thin wrapper implementing TrainOutputABC.get_metrics() around the existing ModelOutputs.
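  A minimal sketch of what this adapter could look like; the class and attribute names here are illustrative stand-ins for the real fme types, not the actual API:

```python
from dataclasses import dataclass
from typing import Any, Mapping


@dataclass
class DownscalingTrainOutput:
    """Thin TrainOutputABC-style wrapper around the model's outputs."""

    loss: Any

    def get_metrics(self) -> Mapping[str, Any]:
        # The generic Trainer only needs a metrics mapping per batch.
        return {"loss": self.loss}


class DownscalingTrainStepper:
    """Adapter bundling the model with its static inputs so that
    train_on_batch matches the generic (data, optimization) signature."""

    def __init__(self, model, static_inputs):
        # `model` and `static_inputs` stand in for the downscaling
        # DiffusionModel and StaticInputs.
        self.model = model
        self.static_inputs = static_inputs

    def train_on_batch(self, data, optimization):
        # Forward to the three-argument downscaling signature and wrap
        # the resulting ModelOutputs-like object.
        outputs = self.model.train_on_batch(data, self.static_inputs, optimization)
        return DownscalingTrainOutput(loss=outputs.loss)
```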

  Gap 2: predict_paired and inference — can be made optional

  This is the real conceptual mismatch. The generic trainer's inference_one_epoch uses predict_paired(initial_condition, forcing) -> (stepped_data,
  new_state) via a Looper that chains steps. Downscaling has no multi-step state propagation.

  However, inference_one_epoch is only called when the epoch is in config.get_inference_epochs(). The downscaling trainer currently doesn't do inline
  inference at all — its valid_one_epoch handles generation evaluation. So the simplest path is:

  - Have get_inference_epochs() return an empty list, making inference_one_epoch never execute.
  - Implement predict_paired to raise NotImplementedError, or as a trivial single-step wrapper (the method exists on TrainStepperABC but would never be called).

  The downscaling-specific generation evaluation (calling generate_on_batch with stochastic sampling during validation) can be handled via the
  end_of_epoch_callback hook, which already exists in the generic Trainer and returns additional logs to merge.
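  The two pieces can be sketched as follows; the class names are hypothetical and only illustrate the shape of the contract:

```python
class NoInferenceConfig:
    """Illustrative config fragment: returning no inference epochs means
    the Trainer never calls inference_one_epoch."""

    def get_inference_epochs(self) -> list:
        return []


class PredictPairedStub:
    """Satisfies the TrainStepperABC inference surface without ever
    being exercised while the inference epoch list is empty."""

    def predict_paired(self, initial_condition, forcing):
        # Raising keeps the contract honest: if a config change ever
        # re-enables inference, this fails loudly instead of silently.
        raise NotImplementedError(
            "downscaling evaluates generation via end_of_epoch_callback"
        )
```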

  Gap 3: GriddedDataABC — moderate adaptation needed

  The generic Trainer expects GriddedDataABC[BD] with methods: loader, set_epoch, subset_loader, alternate_shuffle, batch_size, n_samples, n_batches,
  log_info. The downscaling PairedGriddedData currently only has loader and get_generator/get_patched_generator.

  Missing methods that need implementing:
  - set_epoch(epoch) — currently not tracked; likely a no-op or a pass-through to the sampler
  - subset_loader(start_batch, stop_batch) — needed for mid-epoch resume and train evaluation
  - alternate_shuffle() — needed for independent train evaluation shuffle
  - batch_size, n_samples, n_batches — simple properties
  - log_info(name) — simple logging

  This is real work but mechanical — none of these require fundamental design changes.
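  A rough sketch of the adapter's shape, assuming the wrapped PairedGriddedData exposes a `loader` sequence of batches (the real GriddedDataABC may require more than shown here):

```python
import itertools


class PairedGriddedDataAdapter:
    """Sketch of the GriddedDataABC-shaped adapter around
    PairedGriddedData; method names follow the list above."""

    def __init__(self, data, batch_size):
        self._data = data
        self._batch_size = batch_size
        self._epoch = 0

    @property
    def loader(self):
        return self._data.loader

    def set_epoch(self, epoch):
        # Placeholder: a real implementation would reseed the sampler's
        # shuffle so each epoch sees a different batch order.
        self._epoch = epoch

    def subset_loader(self, start_batch, stop_batch):
        # Mid-epoch resume: replay only batches [start_batch, stop_batch).
        return itertools.islice(iter(self.loader), start_batch, stop_batch)

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def n_batches(self):
        return len(self.loader)

    @property
    def n_samples(self):
        return self.n_batches * self._batch_size
```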

  Gap 4: Validation with generation — solvable via the aggregator builder

  The downscaling trainer's unique feature is that valid_one_epoch runs both train_on_batch (for loss) and generate_on_batch (for stochastic sampling
  metrics). The generic trainer's validate_one_epoch only calls train_on_batch.

  Two clean approaches:

  Option A: The validation aggregator itself calls generate_on_batch. This means the AggregatorABC[TO] implementation for downscaling would hold a reference
   to the model and call generation internally when record_batch is called. This is a bit unusual but keeps the Trainer interface unchanged.

  Option B (cleaner): Use the end_of_epoch_callback to run generation evaluation after validation completes. The callback has access to the model (via
  closure) and already returns Mapping[str, Any] that gets merged into epoch logs.
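  Option B could be sketched as a closure factory; `generate_on_batch` and the metric name are illustrative here, not the confirmed downscaling API:

```python
def make_generation_callback(model, valid_loader, static_inputs):
    """Builds an end-of-epoch callback that runs stochastic generation
    after validation and returns extra logs for the Trainer to merge."""

    def end_of_epoch_callback(epoch):
        scores = []
        for batch in valid_loader:
            # Stochastic sampling pass, separate from the deterministic
            # train_on_batch loss computed during validate_one_epoch.
            generated = model.generate_on_batch(batch, static_inputs)
            scores.append(generated.crps)
        # The returned mapping is merged into the epoch's logs.
        return {"valid/generation_crps": sum(scores) / len(scores)}

    return end_of_epoch_callback
```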

  Gap 5: Patching — solvable within the data layer

  The downscaling trainer has spatial patching logic (get_patched_generator). This can live entirely inside the GriddedDataABC adapter — the loader property
   would return patched batches when configured, transparent to the Trainer.
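  The idea reduces to a small generator; `patch_fn` is a hypothetical function that splits one full-field batch into spatial patches:

```python
def patched_loader(base_loader, patch_fn, patching_enabled):
    """Keep patching inside the data layer: the adapter's `loader`
    yields patched batches when configured, so the Trainer never
    knows patching exists."""
    for batch in base_loader:
        if patching_enabled:
            # One full-field batch can expand into several patch batches.
            yield from patch_fn(batch)
        else:
            yield batch
```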

  What the generic Trainer gives you for free

  Things the downscaling trainer currently lacks or implements less robustly:

  - Mid-epoch preemption recovery — the generic trainer saves _current_epoch_num_batches_seen and can resume mid-epoch via subset_loader
  - Signal handling (SIGTERM/SIGINT) with safe checkpoint saves
  - Per-batch logging frequency control (log_train_every_n_batches)
  - Per-batch checkpoint frequency (checkpoint_every_n_batches)
  - First-batch metric logging before training starts
  - Post-training evaluation on independent shuffle of training data
  - Epoch and EMA checkpoint saving at configurable epoch slices
  - Best-inference-error checkpoint tracking
  - Per-iteration scheduler stepping
  - Proper EMA via EMATracker.applied_params context manager (the downscaling trainer uses the older store/copy_to/restore pattern which is more
  error-prone)
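  The contrast in the last bullet can be illustrated with a minimal context-manager sketch; the real EMATracker has more machinery, and plain floats stand in for parameter tensors here:

```python
from contextlib import contextmanager


class EMASketch:
    """Minimal illustration of the applied_params pattern."""

    def __init__(self, params, decay=0.999):
        self.ema = dict(params)
        self.decay = decay

    def update(self, params):
        # Standard exponential moving average of the live parameters.
        for name, value in params.items():
            self.ema[name] = self.decay * self.ema[name] + (1 - self.decay) * value

    @contextmanager
    def applied_params(self, params):
        # Swap the EMA weights in, and guarantee restoration even if
        # evaluation raises -- the advantage over a manual
        # store/copy_to/restore sequence, which can leave the model
        # stuck in the EMA state after an exception.
        backup = dict(params)
        params.update(self.ema)
        try:
            yield params
        finally:
            params.update(backup)
```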

  Summary: what's needed

  ┌─────────────────────────────────────────────────────────────────────┬──────────┬─────────────────────────────────────────────────────────────────┐
  │                              Work Item                              │  Effort  │                              Notes                              │
  ├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
  │ DownscalingTrainStepper adapter wrapping DiffusionModel             │ Small    │ Bundle static_inputs, wrap ModelOutputs in TrainOutputABC       │
  ├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
  │ TrainOutputABC wrapper for ModelOutputs                             │ Trivial  │ Implement get_metrics() returning {"loss": outputs.loss}        │
  ├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
  │ GriddedDataABC adapter for PairedGriddedData                        │ Moderate │ Add set_epoch, subset_loader, alternate_shuffle, property stubs │
  ├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
  │ Generation evaluation via end_of_epoch_callback                     │ Moderate │ Move generate_on_batch loop into callback, return metrics dict  │
  ├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
  │ Stub predict_paired + empty get_inference_epochs()                  │ Trivial  │ Never called, just satisfies the interface                      │
  ├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
  │ set_train/set_eval/get_state/load_state/update_training_history     │ Small    │ Delegate to model, mostly one-liners                            │
  ├─────────────────────────────────────────────────────────────────────┼──────────┼─────────────────────────────────────────────────────────────────┤
  │ Migrate EMA from store/copy_to/restore to EMATracker.applied_params │ Small    │ Comes free with the generic trainer                             │
  └─────────────────────────────────────────────────────────────────────┴──────────┴─────────────────────────────────────────────────────────────────┘

  The total is probably 200-300 lines of adapter code, replacing ~350 lines of the current custom Trainer (lines 90-396). The net LOC change is small, but
  the benefits are significant: you get preemption recovery, signal handling, per-batch checkpointing, and all future improvements to the generic Trainer
  for free.

  Recommendation

  Do it. The conceptual mismatch ("downscaling isn't time-stepping") is real but doesn't actually block adoption. The generic Trainer's core contract is
  simpler than it looks: call train_on_batch in a loop, call an aggregator, manage checkpoints. The time-stepping inference (predict_paired / Looper) is an
  optional feature controlled by get_inference_epochs() — just return empty. The generation evaluation unique to downscaling fits cleanly into
  end_of_epoch_callback. The GriddedDataABC adaptation is the most work but it's purely mechanical and also gives downscaling mid-epoch resume capability it
   currently lacks.


❯ The main purpose of "inference" in ACE is to run "the problem we really care about", which in that case is long-term stepping. In diffusion there's an
analog of this: while we train and validate on a single output of the model, what we care about is the multi-step autoregressive iteration of the diffusion
model which "generates" an output sample (we call it generation in that code base). Is it possible to have "inference" in the case of diffusion training
perform generation instead, using the Trainer?


● The answer is: the diffusion module already does this. Let me verify one detail to be sure about the downscaling case.

● Read 1 file

● This confirms it. Here's the answer:

  ---
  The diffusion module already does exactly this. The architecture is:

  - Training (train_on_batch): trains a single denoising step — the network learns to denoise one noisy sample. This is fast and is what the loss is
  computed on.
  - Inference (predict_paired → called by the Trainer's Looper): each call to step() runs the full EDM sampling loop (edm_sampler, ~18 denoising iterations,
   ~36 network calls) to produce one clean output timestep. The Looper then chains these across multiple timesteps autoregressively, feeding each output as
  input to the next.

  So when the generic Trainer calls inference_one_epoch(), it already runs multi-step generation — the predict_paired interface naturally accommodates this
  because:

  1. step() (line 739) creates random latents and calls edm_sampler() — this is the generation process
  2. _predict() (line 786) loops step() across n_forward_steps timesteps autoregressively
  3. predict_paired() (line 896) wraps this with reference data pairing
  4. The Looper chains predict_paired calls across data windows

  The config controls this via inference_n_forward_steps and get_inference_epochs() in fme/diffusion/train_config.py.

  For downscaling, the same pattern could work if you define "inference" as "run the full EDM generation process and evaluate the result." The difference is
   that downscaling doesn't chain steps autoregressively — it's one shot per sample. But you can still use the inference machinery:

  - predict_paired would run edm_sampler to generate one output (one full generation), pair it with the target, and return
  - The Looper would iterate over evaluation samples (not timesteps), calling predict_paired once per sample
  - InferenceDataABC would provide evaluation batches as "forcing" windows, with a dummy initial condition
  - The InferenceAggregatorABC would collect generation-quality metrics (CRPS, histogram tails, etc.)

  The n_forward_steps=1 case works fine with the Looper — it just means each window produces one step of output. The Looper doesn't care whether the steps
  are autoregressive or independent; it just calls predict_paired and records the result.
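  Under those assumptions, the one-shot downscaling predict_paired could look like this sketch, where `edm_generate` and the forcing keys are hypothetical names rather than the real interface:

```python
def predict_paired_one_shot(model, initial_condition, forcing):
    """One full EDM generation per window, paired with the target; with
    n_forward_steps == 1 nothing propagates between Looper calls."""
    # A single full generation: random latents denoised by the sampler.
    sample = model.edm_generate(forcing["coarse_fields"])
    stepped_data = {"generated": sample, "target": forcing["target"]}
    # The "state" handed back to the Looper is an unused passthrough,
    # since consecutive samples are independent.
    new_state = initial_condition
    return stepped_data, new_state
```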

