Kondo Gate backward skip, and some other changes. #1
plugyawn wants to merge 5 commits into google-deepmind:main from
Conversation
Note that the compacted version is an approximation to the "true" gradient of the original implementation. On second thoughts... I think it may be better to think of it as a delight-based MoE-like router, systems-wise? Edit: Yep, that's definitely cleaner.

Hi! I agree that the current kondo loss is not actually skipping the backward step, but I guess it's meant to be a cleaner/clearer mathematical simulation, as opposed to a large-scale optimized training framework. This looks like a great implementation of that kind of scalable infra, and it's awesome to see you build it. I'd suggest that perhaps you create a fork to show this work, and maybe I can add a link to that fork, with a clear comment explaining this, in the existing kondo.py ... wdyt?

Sounds great! I'd be happy to maintain the fork. I have also been experimenting with backward-skip infra in NeMO-RL for the last month, but haven't ported it into egg yet, so maybe that could go into the fork as well.

And some questions, if you don't mind! This is a bit speculative, but in the real world I assume a "mature" model would find most actions not very delightful; I wonder if that's related to the recent observations in pretraining by Bytedance Nexus / Gan, Isola et al.'s Neural Thickets, etc. The cumulative delight (or maybe the decay of delight) over a run could then be a sign of how mature the model is for certain tasks, the same way the emergence of nearby "experts" seems to signal a mature pretrain.

Also, I could perhaps look at a smaller implementation of the compute-skip too! Looking at the code footprint afterwards, I realized this was probably just meant to be a mathematical demonstration, as you said. If I can find a more principled, small 200-line implementation, I'll try to put up a separate PR; that said, I think without the KV caching it'd already be much smaller. For LLMs, I suppose the kernel-friendly implementation would be to define token chunks over which we skip the backward, instead of a gating network over batches, since row-level gating affects the expected gradient too much (a lot of delightful tokens in otherwise undelightful rows get removed); a rough sketch of what I mean is below.

Also, thanks again! This was a delight (haha) to work with!
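Roughly what I have in mind for the chunk-level gate; this is only a sketch, and `chunk_keep_mask`, the chunk size, the threshold, and the per-token delight scores are hypothetical stand-ins rather than anything that exists in egg:

```python
import jax
import jax.numpy as jnp


def chunk_keep_mask(token_delight, chunk_size=128, threshold=0.5):
    """Gate fixed-size token chunks by their mean 'delight' score.

    Fixed chunk boundaries keep the eventual gather contiguous (and hence
    kernel-friendly), while dropping fewer delightful tokens than gating
    whole rows would.
    """
    batch, seq_len = token_delight.shape
    assert seq_len % chunk_size == 0, "pad the sequence to a multiple of chunk_size"
    chunks = token_delight.reshape(batch, seq_len // chunk_size, chunk_size)
    keep_chunk = jnp.mean(chunks, axis=-1) >= threshold   # [B, T // chunk_size]
    # Broadcast the per-chunk decision back to per-token resolution.
    return jnp.repeat(keep_chunk, chunk_size, axis=-1)    # [B, T] boolean mask


# Example: 2 rows of 512 tokens with random per-token delight scores.
delight = jax.random.uniform(jax.random.PRNGKey(0), (2, 512))
mask = chunk_keep_mask(delight, chunk_size=128, threshold=0.5)
# A compacted implementation would gather only the kept chunks before the
# backward pass, rather than multiplying the loss by this mask.
```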
I noticed the implementation of the Kondo Gate doesn't actually skip the backward pass, as mentioned in the paper; it instead masks the loss and still pays for the dense backward (sketched below). In addition, the actual benefits for large-scale training weren't becoming apparent due to the lack of caching.
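Concretely, the masking pattern looks roughly like the following. This is a minimal, self-contained sketch with a toy linear model and made-up names, not the actual kondo.py code:

```python
import jax
import jax.numpy as jnp


def apply_model(params, x):
    # Tiny linear stand-in for the transformer, just to keep the sketch runnable.
    return x @ params["w"] + params["b"]


def masked_kondo_loss(params, x, y, delight, threshold=0.5):
    # Per-row loss, gated by a per-row "delight" score in [0, 1].
    per_row = jnp.mean((apply_model(params, x) - y) ** 2, axis=-1)  # [B]
    keep = (delight >= threshold).astype(per_row.dtype)             # [B]
    # Masking zeroes the gated rows' contribution to the loss *value*, but
    # jax.grad still differentiates through every row, so the backward pass
    # stays dense regardless of how many rows were gated out.
    return jnp.sum(per_row * keep) / jnp.maximum(jnp.sum(keep), 1.0)


key = jax.random.PRNGKey(0)
batch, dim = 8, 4
params = {"w": jax.random.normal(key, (dim, dim)), "b": jnp.zeros(dim)}
x = jax.random.normal(key, (batch, dim))
y = jax.random.normal(key, (batch, dim))
delight = jax.random.uniform(key, (batch,))
grads = jax.grad(masked_kondo_loss)(params, x, y, delight)
```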
So, this PR adds a few changes:
Base Kondo at 70%/50% still goes through all of the backward tokens, but algorithmically it does roughly the same thing as the compacted version.

Edit: I'll add the ablations. The row-compaction does lead to an approximation to the "true" gradient of the original egg implementation, but I think it's closer to the paper's spirit?

Edit 2: On second thoughts, I might write this out as a MoE-like router over the training items. That should be cleaner; a rough sketch of that framing is below.
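The router framing over training items would look roughly like the sketch below. This is only an illustration: `routed_step`, `model_loss`, and the per-row delight scores are toy stand-ins, not the code in this PR.

```python
import jax
import jax.numpy as jnp


def model_loss(params, x, y):
    # Toy stand-in for the transformer loss.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


def routed_step(params, x, y, delight, keep_fraction=0.5):
    """Delight-based router: keep only the top keep_fraction of rows by delight
    and run forward + backward on the compacted batch, so the skipped rows
    never enter the backward pass (unlike loss masking)."""
    batch = x.shape[0]
    k = max(1, int(batch * keep_fraction))
    # k is a Python int, so the compacted shapes stay static.
    _, idx = jax.lax.top_k(delight, k)
    x_kept, y_kept = x[idx], y[idx]  # compacted [k, ...] batch
    return jax.value_and_grad(model_loss)(params, x_kept, y_kept)


key = jax.random.PRNGKey(0)
batch, dim = 16, 8
params = {"w": jax.random.normal(key, (dim, dim)), "b": jnp.zeros(dim)}
x = jax.random.normal(key, (batch, dim))
y = jax.random.normal(key, (batch, dim))
delight = jax.random.uniform(key, (batch,))  # per-row delight scores
loss, grads = routed_step(params, x, y, delight, keep_fraction=0.5)
```

Because the row selection is a hard gather, the resulting gradient only approximates the masked-loss gradient; that is the approximation mentioned above.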
The plots look a little too good, but they seem reproducible. I'll try with more configs to check.
Total step time across 5000-step runs drops from ~54 ms to ~38 ms on average on my M3 Pro for the default transformer config, due to the reduced backward cost. However, across the run this amortizes to a ~21% reduction in wallclock for the 50% Kondo gate over the same amount of data, including logging costs (estimated logging cost is ~19 ms per step; excluding it, the 50% gate speedup rises to ~27%).
Across a 5000-step run, timings were: