From 8de486671599d844337e56cd5166d27cf339661f Mon Sep 17 00:00:00 2001
From: Mateusz Kaduk
Date: Thu, 14 Nov 2024 13:32:44 +0100
Subject: [PATCH 1/5] src/rules.jl: Add Muon optimiser.

---
 src/Optimisers.jl |  5 ++-
 src/rules.jl      | 95 +++++++++++++++++++++++++++++++++++++++++++----
 test/rules.jl     | 12 +++---
 3 files changed, 97 insertions(+), 15 deletions(-)

diff --git a/src/Optimisers.jl b/src/Optimisers.jl
index 99fc162f..0bce300d 100644
--- a/src/Optimisers.jl
+++ b/src/Optimisers.jl
@@ -1,9 +1,10 @@
 module Optimisers
 
-using Functors: functor, fmap, fmap_with_path, 
+using Functors: functor, fmap, fmap_with_path,
                 KeyPath, haskeypath, getkeypath,
                 isleaf, @functor, fmapstructure, children, AbstractWalk
 using LinearAlgebra
+import LinearAlgebra: norm
 
 include("interface.jl")
 export AbstractRule
@@ -23,7 +24,7 @@ include("rules.jl")
 export Descent, Adam, Momentum, Nesterov, Rprop, RMSProp,
        AdaGrad, AdaMax, AdaDelta, AMSGrad, NAdam, AdamW, RAdam, OAdam, AdaBelief,
        WeightDecay, SignDecay, ClipGrad, ClipNorm, OptimiserChain, Lion,
-       AccumGrad
+       AccumGrad, Muon
 
 VERSION >= v"1.11.0-DEV.469" && eval(Meta.parse("public apply!, init, setup, update, update!"))
 
diff --git a/src/rules.jl b/src/rules.jl
index 0cd8d30c..93149d67 100644
--- a/src/rules.jl
+++ b/src/rules.jl
@@ -181,7 +181,7 @@ init(o::Rprop, x::AbstractArray) = (zero(x), onevalue(o.eta, x))
 function apply!(o::Rprop, state, x::AbstractArray{T}, dx) where T
   ℓ, Γ = T.(o.ell), T.(o.gamma)
   g, η = state
-  
+
   η = broadcast(g, η, dx) do g, η, dx
     g * dx > 0 ? min(η * ℓ[2], Γ[2]) : g * dx < 0 ? max(η * ℓ[1], Γ[1]) : η
   end
@@ -256,6 +256,87 @@ function apply!(o::Lion, state, x::AbstractArray{T}, dx) where T
   return state, dx′
 end
 
+"""
+    Muon(η = 0.02, ρ = 0.95; steps = 5)
+    Muon(; [eta, rho, steps])
+
+Muon - MomentUm Orthogonalized by Newton-schulz
+
+Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step,
+in which each 2D parameter's update is replaced with the nearest orthogonal matrix using Newton-Schulz iteration.
+
+# Parameters
+- Learning rate (`η == eta`): Amount by which gradients are discounted before updating the weights
+- Momentum (`ρ == rho`): Controls the acceleration of gradient descent in the prominent direction
+- Steps: Number of Newton-Schulz iteration steps for orthogonalization
+
+Note: This optimizer only acts on arrays with 2 or more dimensions (matrices, tensors).
+Parameters with fewer dimensions are silently ignored. Works best with large batch sizes
+and may not be suitable for fine-tuning.
+"""
+@def struct Muon <: AbstractRule
+  eta = 0.02
+  rho = 0.95
+  steps = 5
+end
+
+function init(o::Muon, x::AbstractArray)
+  ndims(x) < 2 ? nothing : zero(x)
+end
+
+function apply!(o::Muon, state, x::AbstractArray{T}, dx) where T
+  # Silently pass through 1D arrays unchanged
+  if ndims(x) < 2
+    return nothing, dx
+  end
+
+  η, ρ = T(o.eta), T(o.rho)
+
+  # Update momentum buffer
+  @.. 
state = ρ * state + dx + + # For higher dimensional tensors, reshape to matrix, orthogonalize, then reshape back + original_size = size(state) + if ndims(state) > 2 + state_mat = reshape(state, size(state,1), :) + dx_orth = _newton_schulz_orthogonalize(state_mat, o.steps) + dx_orth = reshape(dx_orth, original_size) + else + dx_orth = _newton_schulz_orthogonalize(state, o.steps) + end + + # Scale based on matrix dimensions + scale = sqrt(max(1, size(dx_orth,1)/size(dx_orth,2))) + dx′ = @lazy η * scale * dx_orth + + return state, dx′ +end + +# _newton_schulz_orthogonalize remains unchanged +function _newton_schulz_orthogonalize(G::AbstractMatrix, steps::Int) + a, b, c = (3.4445f0, -4.7750f0, 2.0315f0) + + X = G + X = X / (norm(X) + eps()) + + transposed = size(G, 1) > size(G, 2) + if transposed + X = X' + end + + for _ in 1:steps + A = X * X' + B = b * A + c * A * A + X = a * X + B * X + end + + if transposed + X = X' + end + + X +end + """ RAdam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8) RAdam(; [eta, beta, epsilon]) @@ -603,10 +684,10 @@ end WeightDecay(λ = 5e-4) WeightDecay(; [lambda]) -Implements ``L_2`` regularisation, also known as ridge regression, +Implements ``L_2`` regularisation, also known as ridge regression, when composed with other rules as the first transformation in an [`OptimiserChain`](@ref). -It does this by adding `λ .* x` to the gradient. This is equivalent to adding +It does this by adding `λ .* x` to the gradient. This is equivalent to adding `λ/2 * sum(abs2, x) == λ/2 * norm(x)^2` to the loss. See also [`SignDecay`] for ``L_1`` normalisation. @@ -644,7 +725,7 @@ function adjust(r::WeightDecay; gamma = nothing, kw...) Implements ``L_1`` regularisation, also known as LASSO regression, when composed with other rules as the first transformation in an [`OptimiserChain`](@ref). -It does this by adding `λ .* sign(x)` to the gradient. This is equivalent to adding +It does this by adding `λ .* sign(x)` to the gradient. This is equivalent to adding `λ * sum(abs, x) == λ * norm(x, 1)` to the loss. See also [`WeightDecay`] for ``L_2`` normalisation. @@ -783,7 +864,7 @@ function apply!(o::OptimiserChain, states, x, dx, dxs...) foldl(tuple.(o.opts, states); init = ((), dx)) do (states′, dx′), (opt, state) if dx′ isa Zero return (states′..., state), dx′ - else + else state′, dx′ = apply!(opt, state, x, dx′, dxs...) return (states′..., state′), dx′ end @@ -831,10 +912,10 @@ julia> m # n=2 gradients applied at once """ struct AccumGrad <: AbstractRule n::Int - + function AccumGrad(n::Int) n > 0 || throw(ArgumentError("AccumGrad must accumulate at least one gradient")) - return new(n) + return new(n) end end diff --git a/test/rules.jl b/test/rules.jl index 499902ca..2cbc3efc 100644 --- a/test/rules.jl +++ b/test/rules.jl @@ -8,7 +8,7 @@ RULES = [ # All the rules at default settings: Descent(), Adam(), Momentum(), Nesterov(), Rprop(), RMSProp(), AdaGrad(), AdaMax(), AdaDelta(), AMSGrad(), NAdam(), - AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(), + AdamW(), RAdam(), OAdam(), AdaBelief(), Lion(), Muon(), # A few chained combinations: OptimiserChain(SignDecay(0.001), Adam(0.001)), OptimiserChain(ClipNorm(), Adam(0.001)), @@ -183,7 +183,7 @@ end # The Flux PR had 1e-2 for all. But AdaDelta(ρ) needs ρ≈0.9 not small. 
And it helps to make ε not too small too: Adam(1e-2), RMSProp(1e-2), RAdam(1e-2), OAdam(1e-2), AdaGrad(1e-2), AdaDelta(0.9, 1e-5), NAdam(1e-2), AdaBelief(1e-2), # These weren't in Flux PR: - Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), AdamW(1e-2), + Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), AdamW(1e-2), ] # Our "model" is just a complex number model = (w = zeros(ComplexF64, 1),) @@ -226,7 +226,7 @@ end @test static_loss(static_model) < last_loss last_loss = static_loss(static_model) end - @test static_loss(static_model) < 1.9 + @test static_loss(static_model) < 1.9 end end @@ -254,16 +254,16 @@ end g1 = rand(5) tree, x1 = Optimisers.update(tree, x, g1) @test x1 ≈ x - @test x1 ≈ x0 + @test x1 ≈ x0 g2 = rand(5) tree, x2 = Optimisers.update(tree, x1, g2) @test x2 ≈ x - @test x2 ≈ x0 + @test x2 ≈ x0 g3 = rand(5) tree, x3 = Optimisers.update(tree, x2, g3) @test x3 ≈ x0 .- lr .* (g1 .+ g2 .+ g3) ./ 3 g4 = rand(5) - + tree, x4 = Optimisers.update(tree, x3, g4) @test x4 ≈ x3 end From 58a9bef5c89db670f78f68186fdf6c0fd8a3f6b3 Mon Sep 17 00:00:00 2001 From: murrellb Date: Sat, 21 Dec 2024 15:38:05 +0100 Subject: [PATCH 2/5] Muon with fallback --- src/rules.jl | 153 ++++++++++++++++++++++++--------------------------- 1 file changed, 73 insertions(+), 80 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 93149d67..d3531483 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -256,86 +256,6 @@ function apply!(o::Lion, state, x::AbstractArray{T}, dx) where T return state, dx′ end -""" - Muon(η = 0.02, ρ = 0.95; steps = 5) - Muon(; [eta, rho, steps]) - -Muon - MomentUm Orthogonalized by Newton-schulz - -Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, -in which each 2D parameter's update is replaced with the nearest orthogonal matrix using Newton-Schulz iteration. - -# Parameters -- Learning rate (`η == eta`): Amount by which gradients are discounted before updating the weights -- Momentum (`ρ == rho`): Controls the acceleration of gradient descent in the prominent direction -- Steps: Number of Newton-Schulz iteration steps for orthogonalization - -Note: This optimizer only acts on arrays with 2 or more dimensions (matrices, tensors). -Parameters with fewer dimensions are silently ignored. Works best with large batch sizes -and may not be suitable for fine-tuning. -""" -@def struct Muon <: AbstractRule - eta = 0.02 - rho = 0.95 - steps = 5 -end - -function init(o::Muon, x::AbstractArray) - ndims(x) < 2 ? nothing : zero(x) -end - -function apply!(o::Muon, state, x::AbstractArray{T}, dx) where T - # Silently pass through 1D arrays unchanged - if ndims(x) < 2 - return nothing, dx - end - - η, ρ = T(o.eta), T(o.rho) - - # Update momentum buffer - @.. 
state = ρ * state + dx - - # For higher dimensional tensors, reshape to matrix, orthogonalize, then reshape back - original_size = size(state) - if ndims(state) > 2 - state_mat = reshape(state, size(state,1), :) - dx_orth = _newton_schulz_orthogonalize(state_mat, o.steps) - dx_orth = reshape(dx_orth, original_size) - else - dx_orth = _newton_schulz_orthogonalize(state, o.steps) - end - - # Scale based on matrix dimensions - scale = sqrt(max(1, size(dx_orth,1)/size(dx_orth,2))) - dx′ = @lazy η * scale * dx_orth - - return state, dx′ -end - -# _newton_schulz_orthogonalize remains unchanged -function _newton_schulz_orthogonalize(G::AbstractMatrix, steps::Int) - a, b, c = (3.4445f0, -4.7750f0, 2.0315f0) - - X = G - X = X / (norm(X) + eps()) - - transposed = size(G, 1) > size(G, 2) - if transposed - X = X' - end - - for _ in 1:steps - A = X * X' - B = b * A + c * A * A - X = a * X + B * X - end - - if transposed - X = X' - end - - X -end """ RAdam(η = 0.001, β = (0.9, 0.999), ϵ = 1e-8) @@ -680,6 +600,79 @@ function apply!(o::AdaBelief, state, x::AbstractArray{T}, dx) where T return (mt, st, βt .* β), dx′ end +nonfirstdims(x) = prod(size(x)[2:end]) + +""" + Muon(opt = AdamW(eta = 0.0003, beta = (0.9,0.95), lambda = 0.01), η = 0.02, μ = 0.95, λ = 0.01, fallback = Returns(false)) + Muon(; [opt, eta, mu, lambda, fallback]) + +Muon - MomentUm Orthogonalized by Newton-schulz (https://github.com/KellerJordan/Muon) + +Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-processing step, +in which each 2D parameter's update is replaced with the nearest orthogonal matrix using Newton-Schulz iteration. + +# Parameters +- Fallback optimizer (`opt`): Optimizer to use for 1D parameters or when the `fallback` function returns true +- Learning rate (`η == eta`): Amount by which gradients are discounted before updating the weights +- Momentum (`μ == mu`): Controls the acceleration of gradient descent in the prominent direction +- Weight decay (`λ == lambda`): Controls the strength of ``L_2`` regularisation. +- Fallback function (`fallback`): Function to control when, in addition to 1D arrays, the fallback optimizer should be used. Will be passed the parameter array and must return a boolean. + +Note: Works best with large batch sizes and may not be suitable for fine-tuning. +In nanoGPT speedrun experiments, Muon is used for the internal layer >2D weights, and AdamW is used for the 1D weights, embeddings, and heads. + +`Optimisers.adjust!(optimiser_state, η::Real)` will adjust the fallback optimizer's `eta` to `η * (opt.eta / eta)`, and Muon's `eta` to `η`, preserving their ratio, +but `Optimisers.adjust!(optimiser, eta = η)` will only adjust Muon's learning rate (allowing you to adjust the fallback optimizer's learning rate separately). +""" +@def struct Muon <: AbstractRule + opt = AdamW(eta = 0.0003, beta = (0.9,0.95), lambda = 0.01) + eta = 0.02 + mu = 0.95 + lambda = 0.01 + fallback = Returns(false) +end + +function init(o::Muon, x::AbstractArray) + if nonfirstdims(x) == 1 || o.fallback(x) + return init(o.opt, x) + else + return zero(x) + end +end + +function apply!(o::Muon, state, x::AbstractArray{T}, dx) where T + if nonfirstdims(x) == 1 || o.fallback(x) + return apply!(o.opt, state, x, dx) + else + η, μ, λ = T(o.eta), T(o.mu), T(o.lambda) + @.. 
state = μ * state + dx + Ot = _newton_schulz5(μ .* state .+ dx) + dx′ = @lazy η * (Ot + λ * x) + return state, dx′ + end +end + +function _newton_schulz5(G::AbstractMatrix{T}) where T + a, b, c = (T(3.4445f0), T(-4.7750f0), T(2.0315f0)) + X = G / (norm(G) + eps(T)) + transposed = size(G, 1) > size(G, 2) + if transposed + X = X' + end + for _ in 1:5 + A = X * X' + B = b * A + c * A * A + X = a * X + B * X + end + if transposed + X = X' + end + X +end +_newton_schulz5(G::AbstractArray) = reshape(_newton_schulz5(reshape(G, size(G,1), :)), size(G)) + +adjust(r::Muon, η::Real) = adjust(r, eta = η, opt = adjust(r.opt, eta = (r.opt.eta / r.eta) * η)) + """ WeightDecay(λ = 5e-4) WeightDecay(; [lambda]) From dec3d762c0f068c5888d74790fd7718884ebd506 Mon Sep 17 00:00:00 2001 From: murrellb Date: Sat, 21 Dec 2024 16:54:21 +0100 Subject: [PATCH 3/5] Adding in a scaling factor that was missed --- src/rules.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rules.jl b/src/rules.jl index d3531483..6258d791 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -646,7 +646,7 @@ function apply!(o::Muon, state, x::AbstractArray{T}, dx) where T else η, μ, λ = T(o.eta), T(o.mu), T(o.lambda) @.. state = μ * state + dx - Ot = _newton_schulz5(μ .* state .+ dx) + Ot = _newton_schulz5(μ .* state .+ dx) * T(sqrt(max(1, size(x,1)/nonfirstdims(x)))) dx′ = @lazy η * (Ot + λ * x) return state, dx′ end From ca00d5f1070f7f4f5ab2dfe53eae60c6aa0b5249 Mon Sep 17 00:00:00 2001 From: murrellb Date: Sat, 21 Dec 2024 18:58:40 +0100 Subject: [PATCH 4/5] Fixing fallback+constructor --- src/rules.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index 6258d791..d9aecdb2 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -624,14 +624,16 @@ In nanoGPT speedrun experiments, Muon is used for the internal layer >2D weights `Optimisers.adjust!(optimiser_state, η::Real)` will adjust the fallback optimizer's `eta` to `η * (opt.eta / eta)`, and Muon's `eta` to `η`, preserving their ratio, but `Optimisers.adjust!(optimiser, eta = η)` will only adjust Muon's learning rate (allowing you to adjust the fallback optimizer's learning rate separately). 
""" -@def struct Muon <: AbstractRule - opt = AdamW(eta = 0.0003, beta = (0.9,0.95), lambda = 0.01) - eta = 0.02 - mu = 0.95 - lambda = 0.01 - fallback = Returns(false) +struct Muon <: AbstractRule + opt::AbstractRule + eta::Float64 + mu::Float64 + lambda::Float64 + fallback::Function end +Muon(;opt = AdamW(eta = 0.0003, beta = (0.9,0.95), lambda = 0.01), eta = 0.02, mu = 0.95, lambda = 0.01, fallback = x -> false) = Muon(opt, eta, mu, lambda, fallback) + function init(o::Muon, x::AbstractArray) if nonfirstdims(x) == 1 || o.fallback(x) return init(o.opt, x) From cbb02be34999a8dac6f1addfc6dd3762f6e7827e Mon Sep 17 00:00:00 2001 From: murrellb Date: Sat, 21 Dec 2024 21:54:57 +0100 Subject: [PATCH 5/5] Tweaking Newton Schulz --- src/rules.jl | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/src/rules.jl b/src/rules.jl index d9aecdb2..db74e8ed 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -654,22 +654,22 @@ function apply!(o::Muon, state, x::AbstractArray{T}, dx) where T end end +function _inner_newton_schulz5(X::AbstractMatrix{T}) where T + a, b, c = (T(3.4445f0), T(-4.7750f0), T(2.0315f0)) + for _ in 1:5 + A = X * X' + B = b * A + c * A * A + X = a * X + B * X + end + X +end function _newton_schulz5(G::AbstractMatrix{T}) where T - a, b, c = (T(3.4445f0), T(-4.7750f0), T(2.0315f0)) X = G / (norm(G) + eps(T)) - transposed = size(G, 1) > size(G, 2) - if transposed - X = X' - end - for _ in 1:5 - A = X * X' - B = b * A + c * A * A - X = a * X + B * X - end - if transposed - X = X' + if size(G, 1) > size(G, 2) + transpose(_inner_newton_schulz5(transpose(X))) + else + _inner_newton_schulz5(X) end - X end _newton_schulz5(G::AbstractArray) = reshape(_newton_schulz5(reshape(G, size(G,1), :)), size(G))