Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test Flux 0.13 #699

Merged
merged 2 commits into from
Apr 21, 2022
Merged

test Flux 0.13 #699

merged 2 commits into from
Apr 21, 2022

Conversation

ChrisRackauckas
Copy link
Member

No description provided.

@ChrisRackauckas
Copy link
Member Author

MWE of the stuff ODE solver problem:

using Flux, Zygote, ForwardDiff
import ForwardDiff: Dual

y = Float32[1.2449092, 0.26629877]
p = Float32[0.14421135, -0.006150621, 0.0393358, 0.138404, 0.23749629, 0.06463469, -0.029445898, -0.33279192, -0.094798535, -0.257304, -0.3355695, -0.1959481, 0.12938745, 0.14058144, 0.32916018, -0.23945713, -0.18813372, -0.14978944, 0.18167028, -0.22040617, -0.16580728, -0.09962158, -0.12878253, 0.24638167, -0.03310824, 0.07440266, 0.03885393, 0.27210253, -0.053823117, -0.14623246, -0.034661364, 0.049675502, 0.16398363, -0.30591217, 0.18999895, -0.26469624, 0.28702003, 0.20897748, -0.32785562, -0.100942954, -0.32169065, 0.21481845, 0.09703442, 0.30915034, 0.09057236, -0.15546058, -0.24163458, -0.13516225, -0.06676043, -0.1966813, 0.12077151, 0.056194287, -0.16526969, -0.2222915, -0.19672059, -0.034455374, -0.24578816, 0.18768719, -0.23405759, 0.046496972, -0.258523, 0.058912445, 0.042145796, -0.13487151, -0.2644665, -0.33397835, -0.29189992, 0.13996881, -0.21306355, 0.15383047, 0.15763333, -0.27050394, 0.3312636, 0.32032087, -0.24478982, -0.1096856, 0.12329024, -0.33420125, -0.1529397, 0.013263283, -0.0321317, -0.28141057, -0.058830447, -0.033951838, -0.18657157, 0.20016932, 0.1548164, 0.028861579, 0.16291597, 0.22635445, -0.090969354, 0.1766979, -0.31983075, 0.07219792, 0.23401073, -0.07207494, 0.24587327, 0.26736307, -0.23342982, 0.08657169, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.19381283, 0.26416284, 0.15891802, -0.12095787, 0.1411279, -0.027013905, 0.30863065, 0.122840405, 0.21558835, -0.15353091, 0.29037133, 0.15146789, 0.045571294, 0.31934103, -0.013038424, 0.19338936, -0.08277726, 0.23436436, -0.25519383, 0.19290958, -0.07854137, 0.25316665, 0.068068326, -0.25551397, 0.29491913, -0.17783256, -0.28973922, -0.102145046, -0.052218888, -0.14526269, -0.20289962, -0.22463948, 0.24003603, -0.22635874, 0.22355433, -0.10727884, -0.27763215, 0.12205175, -0.33481315, 0.04747853, 0.22055429, 0.017725615, -0.14218004, -0.27020591, 0.27612484, 0.050210662, 0.041809335, -0.032814298, -0.21339102, 0.22898024, 0.0755317, -0.23465283, -0.109813884, -0.18060842, -0.066495314, 0.22580191, 0.3323758, 0.023281226, 0.07484222, 0.28912178, -0.27487472, 0.121484526, -0.2651789, 0.19090225, -0.003508792, -0.25500044, 0.05072003, -0.07643754, 0.24113968, 0.12844749, 0.24001858, 0.2613778, -0.2603248, 0.08254892, -0.111656696, 0.23785193, 0.32324004, 0.1750177, -0.09340208, -0.12355742, -0.25986317, -0.27915004, 0.07588966, 0.25872853, 0.21791716, 0.2401611, -0.2407115, -0.23268251, -0.30390444, -0.3009561, -0.02586944, 0.16676147, -0.110212825, -0.17888871, 0.33321387, -0.32094577, -0.25499186, 0.25705588, 0.15148534, -0.28999805, 0.0, 0.0]
t = 1.5f0
λ = Dual{ForwardDiff.Tag{OrdinaryDiffEq.OrdinaryDiffEqTag,Float32},Float32,12}[Dual{ForwardDiff.Tag{OrdinaryDiffEq.OrdinaryDiffEqTag,Float32}}(0.09447026, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), Dual{ForwardDiff.Tag{OrdinaryDiffEq.OrdinaryDiffEqTag,Float32}}(1.4116058, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)]

model = Chain(x -> x .^ 3,
    Dense(2, 50, tanh),
    Dense(50, 2))
p, re = Flux.destructure(model)
f(u, p, t) = re(p)(u)

_dy, back = Zygote.pullback(y, p) do u, p
    vec(f(u, p, t))
end
tmp1, tmp2 = back(λ)

Found via:

using DiffEqFlux, OrdinaryDiffEq, Test

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end
t = range(tspan[1], tspan[2], length=datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))

model = Chain(x -> x .^ 3,
    Dense(2, 50, tanh),
    Dense(50, 2))
neuralde = NeuralODE(model, tspan, Rodas5(), saveat=t, reltol=1e-7, abstol=1e-9)

function predict_n_ode()
    neuralde(u0)
end
loss_n_ode() = sum(abs2, ode_data .- predict_n_ode())

data = Iterators.repeated((), 10)
opt = ADAM(0.1)
cb = function () #callback function to observe training
    display(loss_n_ode())
end

# Display the ODE with the initial parameter values.
cb()

neuralde = NeuralODE(model, tspan, Rodas5(), saveat=t, reltol=1e-7, abstol=1e-9)
ps = Flux.params(neuralde)
loss1 = loss_n_ode()

xx = Ref{Any}()

Flux.train!(loss_n_ode, ps, data, opt, cb=cb)

with the change:

function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::ZygoteVJP, dgrad, dy, W) where TS<:SensitivityFunction
  @unpack sensealg, f = S
  prob = getprob(S)

  isautojacvec = get_jacvec(sensealg)
  if inplace_sensitivity(S)
    if W===nothing
      _dy, back = Zygote.pullback(y, p) do u, p
        out_ = Zygote.Buffer(similar(u))
        f(out_, u, p, t)
        vec(copy(out_))
      end
    else
      _dy, back = Zygote.pullback(y, p) do u, p
        out_ = Zygote.Buffer(similar(u))
        f(out_, u, p, t, W)
        vec(copy(out_))
      end
    end
    tmp1,tmp2 = back(λ)
    dλ[:] .= vec(tmp1)
    dgrad !== nothing && tmp2 !== nothing && (dgrad[:] .= vec(tmp2))
    dy !== nothing && (dy[:] .= vec(_dy))
  else
    if W===nothing
      _dy, back = Zygote.pullback(y, p) do u, p
        vec(f(u, p, t))
      end
    else
      _dy, back = Zygote.pullback(y, p) do u, p
        vec(f(u, p, t, W))
      end
    end
    Main.xx[] = y,p,t,λ
    tmp1, tmp2 = back(λ)
    tmp1 !== nothing && (dλ[:] .= vec(tmp1))
    dy !== nothing && (dy[:] .= vec(_dy))
    dgrad !== nothing && tmp2 !== nothing && (dgrad[:] .= vec(tmp2))
  end
  return
end

ProjectManifest.zip

@ChrisRackauckas
Copy link
Member Author

using ForwardDiff, Zygote, Flux
using ForwardDiff: Dual
y = Float32[0.8564646, 0.21083355]
p = Float32[-0.2548858, -0.264061, 0.06902494, -0.23288882, -0.13166176, 0.25982612, -0.26543534, -0.29349443, 0.31963557, 0.21243489, -0.2755482, -0.04317024, 0.2678376, -0.32618907, -0.11215708, -0.20082082, -0.075056225, -0.3250112, -0.20113565, -0.2580761, 0.03797583, -0.1354496, 0.18161258, 0.3180589, 0.283674, 0.05116003, -0.07082515, 0.12914972, 0.09830813, 0.29125124, 0.32423735, 0.045021717, 0.09604585, 0.007445923, 0.12431481, 0.063025564, 0.30161184, 0.23123802, 0.30304855, -0.18616274, 0.06983177, 0.13229537, 0.26679033, 0.29119095, 0.2044387, -0.1310391, 0.06418764, -0.05145624, 0.28958446, 0.08143681, -0.26594874, 0.258198, -0.16387275, 0.23627394, -0.0025739619, 0.12877232, 0.28468516, 0.14945742, -0.09824067, 0.22391124, 0.2722607, 0.034997866, 0.021131594, -0.058169674, -0.20168333, 0.3310362, 0.29977754, 0.27228144, 0.088294245, 0.17472656, 0.030819716, 0.27218765, 0.042448767, 0.25967237, 0.18181679, 0.2810931, -0.16689181, 0.17927635, 0.32586476, -0.25481033, 0.009913104, 0.20943141, -0.13506782, -0.30059853, -0.084571846, -0.31261674, 0.11608189, 0.084546946, -0.21448077, -0.19288287, -0.22511461, 0.27675447, 0.26279518, 0.061226156, -0.2828123, -0.1394083, -0.16996919, 0.2784961, -0.0039018209, -0.1362619, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.048874624, 0.11889865, 0.01040518, 0.12694769, 0.32327807, 0.13581258, 0.10043003, -0.12258695, 0.32029858, -0.05385616, -0.28262973, -0.29426816, 0.11472986, 0.014853499, -0.055616893, -0.24432188, -0.23522359, -0.07780609, 0.16605335, 0.29451388, -0.32305816, 0.03262463, -0.28862894, 0.054972157, 0.2411704, 0.31518432, 0.2221482, -0.12357236, 0.25466782, 0.03921116, -0.087710164, 0.1594814, -0.33685195, -0.13411506, 0.04239876, 0.260748, 0.15104404, 0.24697773, -0.06698533, -0.039195247, 0.29528958, -0.19330974, -0.32768622, 0.07959501, -0.11285911, -0.031941384, -0.108291335, -0.24830729, -0.08987814, -0.04234308, 0.255426, 0.3337179, 0.18690939, -0.32503495, -0.06603645, -0.17818044, 0.10007081, -0.22569874, 0.030490262, -0.014429291, 0.13864784, 0.100892544, -0.28683808, 0.05345175, -0.12727126, 0.31637886, 0.27381366, 0.026415939, 0.20263642, 0.33452004, -0.3351626, 0.0063842274, -0.26546854, -0.24439275, -0.19636214, 0.3032137, 0.13219267, 0.20853092, -0.05988348, -0.30968776, -0.1278926, 0.33035672, -0.32249796, 0.14322737, -0.29625347, -0.17458698, -0.0010983021, 0.14215776, -0.07308902, -0.19241002, 0.1702171, 0.32165667, 0.27042934, 0.068846, 0.19114906, 0.06528145, -0.31603774, 0.049985882, -0.05847536, 0.04034526, 0.0, 0.0]
t = 1.5f0
λ = ForwardDiff.Dual{ForwardDiff.Tag{Nothing,Float32},Float32,12}[Dual{ForwardDiff.Tag{Nothing,Float32}}(0.87135935, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0), Dual{ForwardDiff.Tag{Nothing,Float32}}(1.5225363, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0)]

model = Chain(x -> x .^ 3,
    Dense(2, 50, tanh),
    Dense(50, 2))

p,re = Flux.destructure(model)
f(u, p, t) = re(p)(u)
_dy, back = Zygote.pullback(y, p) do u, p
    vec(f(u, p, t))
end
tmp1, tmp2 = back(λ)

@ChrisRackauckas
Copy link
Member Author

Should work once FluxML/Optimisers.jl#65

@ChrisRackauckas
Copy link
Member Author

SciML/NeuralPDE.jl#508 and SciML/DeepEquilibriumNetworks.jl#44 are dependent on this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant