
Kondo Gate backward skip, and some other changes. #1

Open
plugyawn wants to merge 5 commits into google-deepmind:main from plugyawn:main

Conversation

@plugyawn

@plugyawn plugyawn commented Apr 16, 2026

I noticed the implementation of the Kondo Gate doesn't actually skip the backward pass, as described in the paper (it instead masks the loss and still pays for the dense backward). In addition, the actual benefit on large-scale training wasn't apparent due to the lack of caching.

So, this PR adds a few changes:

  • A backward-skipping Kondo Gate implementation that avoids the backward cost by screening the batch, compacting it, and then only differentiating through the kept subset. This incurs a small screening cost; I have attached the timings below. I'm sure it can be amortized further.
  • In the case that the learner and actor policies are the same, and training is therefore on-policy, the second forward pass is unnecessary and can be skipped. This reduces the cost from $F + k(F + B)$ to just $F + kB$ (for cases where we're at least nearly on-policy, as I understand it).
  • The wallclock savings from the Kondo Gate alone were hard to notice at longer prompt_lengths because caching wasn't implemented. With caching, the trainer is a much larger share of the per-step time, so skipping the backward is even more rewarding, Amdahl-wise.
[image: wallclock comparison plot] This is on a 4-vocab reversal task with prompt_length 12, averaged across 3 seeds, 5000 steps each. Base Kondo is the current implementation, with k=1.0; note that the wallclock includes the logging-at-every-step overhead, so the actual difference will probably be bigger.
  • Base Kondo 100% / base PG: ~364s, 19.2M backward tokens
  • Backward-skip Kondo 50%: ~287s, 9.6M backward tokens
  • Backward-skip Kondo 70%: ~324s, 13.45M backward tokens

Base Kondo at 70%/50% still goes through all of the backward tokens, but does roughly the same thing algorithmically.
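The screen-then-compact mechanism from the first bullet can be sketched roughly like this; it's purely illustrative (the names `kondo_gate`, `scores`, and `keep_fraction` are mine, not the PR's API), but it shows the shape of the idea: a cheap gradient-free scoring pass, then a dense compacted batch so the expensive backward only runs over kept rows.

```python
# Hypothetical sketch of the backward-skip gate: score each row cheaply (the
# "screen" pass, no gradients), keep the top keep_fraction of rows, and
# compact them into a dense batch. Only the compacted batch would then be
# fed to the (expensive) backward pass.
def kondo_gate(batch, scores, keep_fraction):
    """Return the compacted kept subset of `batch` and the kept indices."""
    n_keep = max(1, int(len(batch) * keep_fraction))
    # Indices of the n_keep highest-scoring rows, restored to original order
    # so the compacted batch preserves row ordering.
    kept = sorted(sorted(range(len(batch)), key=lambda i: -scores[i])[:n_keep])
    return [batch[i] for i in kept], kept

batch = ["row0", "row1", "row2", "row3"]
scores = [0.9, 0.1, 0.7, 0.3]  # per-row "delight" from the screen pass
compact, kept_idx = kondo_gate(batch, scores, keep_fraction=0.5)
# compact == ["row0", "row2"], kept_idx == [0, 2]
```

At 50% keep the backward then touches half the rows, which matches the halved backward-token counts in the table above.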

Edit: I'll add the ablations. The row-compaction does lead to an approximation of the "true" gradient of the original egg implementation, but I think it's closer to the paper's spirit?

Edit 2: On second thought, I might write this out as an MoE-like router over the training items. That should be cleaner.

The plots look a little too good, but they seem reproducible. I'll try with more configs to check.


The total step time across 5000-step runs drops from ~54 ms to ~38 ms on average on my M3 Pro, for the default transformer config, due to the reduced backward cost. Across the run, this amortizes to a ~21% reduction in wallclock for the 50% Kondo gate over the same amount of data (including logging costs; estimated logging overhead is ~19 ms per step, excluding which the 50% gate speedup rises to ~27%).
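The $F + k(F + B)$ versus $F + kB$ claim from the bullets above can be sanity-checked with a toy cost model. This is my reading of the PR text, not its actual code; `F`, `B`, and `k` are the forward cost, backward cost, and kept fraction:

```python
# Toy cost model: off-policy, the gate pays a second forward over the kept
# rows; on-policy, the sampling forward can be reused, so only the backward
# scales with k.
def off_policy_cost(F, B, k):
    return F + k * (F + B)   # F + k(F + B)

def on_policy_cost(F, B, k):
    return F + k * B         # F + kB

# With the common rule of thumb B ~= 2F, a 50% gate on-policy costs
# 1 + 0.5*2 = 2.0 units vs 1 + 0.5*(1+2) = 2.5 off-policy: a 20% saving
# on top of whatever the backward skip itself already buys.
saving = 1 - on_policy_cost(1.0, 2.0, 0.5) / off_policy_cost(1.0, 2.0, 0.5)
```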

Across a 5000 step run, timings were:

  • Base Kondo 100% / base PG
    • total step: 53.93 ms
    • sample: 17.93 ms
    • screen: 0 ms
    • compact: 0 ms
    • train: 36.00 ms
    • backward tokens / step: 3840
    • total wall clock: 364.32 s
  • Backward-skip Kondo 70%
    • total step: 46.17 ms
    • sample: 18.42 ms
    • screen: 0.489 ms
    • compact: 0.078 ms
    • train: 27.19 ms
    • backward tokens / step: 2690
    • backward fraction: 0.7005
    • total wall clock: 323.89 s
  • Backward-skip Kondo 50%
    • total step: 37.76 ms
    • sample: 17.90 ms
    • screen: 0.467 ms
    • compact: 0.063 ms
    • train: 19.33 ms
    • backward tokens / step: 1920
    • backward fraction: 0.5
    • total wall clock: 287.17 s

@google-cla

google-cla Bot commented Apr 16, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@plugyawn
Author

plugyawn commented Apr 16, 2026

Note that the compacted version is an approximation to the "true" gradient of the original implementation.

On second thought... I think it's maybe better to think of it as a delight-based, MoE-like router, systems-wise?

Edit: Yep, that's definitely cleaner.

@iosband
Collaborator

iosband commented Apr 21, 2026

Hi!
Having a look here, my belief is that this is working as intended...

I agree that the current kondo loss is not actually skipping the backward step, but I guess it's meant to be a cleaner/clearer mathematical simulation, as opposed to a large-scale optimized training framework.

This looks like a great implementation of that kind of scalable infra, and it's awesome to see you implement it.
However, at an additional 1k lines of code, I don't think implementing this in the main branch (and the extensive review it would need) is in the pipeline right now.

I'd suggest that perhaps you create a fork to show this work, and maybe I can add a link to that fork with a clear comment explaining this in the existing kondo.py ... wdyt?

@plugyawn
Author

plugyawn commented Apr 21, 2026

Sounds great! I'd be happy to maintain the fork. I have also been experimenting with backward-skip infra in NeMO-RL for the last month, but haven't ported it into egg yet, so maybe that could go into the fork as well.

And some questions, if you don't mind!
Do you think there's value in annealing the gate threshold... or maybe a learned gate, so that as training proceeds more and more backwards can be skipped (since I assume that's when the large benefits would come)?
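For concreteness, one simple form of the annealing idea (purely illustrative, not in the PR; the schedule shape and endpoints are made up) would be a keep-fraction that decays linearly over training:

```python
# Illustrative annealing schedule: the keep-fraction k decays linearly from
# k_start to k_end over the run, so more backwards are skipped as the model
# matures and finds fewer items "delightful".
def annealed_keep_fraction(step, total_steps, k_start=1.0, k_end=0.3):
    t = min(step / total_steps, 1.0)       # training progress in [0, 1]
    return k_start + (k_end - k_start) * t
```

A learned gate would replace this fixed schedule with a threshold trained against the delight scores themselves.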

This is a bit speculative, but in the real world, I assume, a "mature" model would find most actions not very delightful; I wonder if it's related to recent observations in pretraining by Bytedance Nexus / Gan, Isola et al.'s Neural Thickets, etc. The cumulative delight (or maybe the decay of delight) over a run could then be a sign of how mature the model is for certain tasks, in the same way the emergence of nearby "experts" seems to signal a mature pretrain.

Also, I could perhaps look into a smaller implementation of the compute skip too! I realized afterwards, looking at the code footprint, that this was probably just meant as a mathematical demonstration, as you said. If I can find a more principled, small ~200-line implementation, I'll try to put up a separate PR; that said, I think without the KV caching it'd already be much smaller.

Also, for LLMs, I suppose the kernel-friendly implementation would define token chunks over which we skip the backward, instead of a gating network over whole batch rows, since row-level gating affects the expected gradient too much (a lot of delightful tokens in otherwise undelightful rows get removed).

Also thanks again! This was a delight (haha) to work with!
