
Re-allocate Buffer if the incoming eltypes don't match preallocated one #65


Closed
DhairyaLGandhi wants to merge 1 commit

Conversation

@DhairyaLGandhi (Member) commented Apr 8, 2022

Currently, the flat vector is stored inside the Restructure struct, so it assumes that incoming parameters match the eltype of the model. This fails in mixed-mode AD, with structs-of-arrays, etc. To combat that, try to reallocate a buffer that can properly hold the actual new parameters. Note that with mixed precision we also pay for conversion (and therefore allocation) on every operation. cc @ChrisRackauckas. MWE:

using DiffEqFlux, OrdinaryDiffEq, Flux, Test

u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)

function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u .^ 3)'true_A)'
end
t = range(tspan[1], tspan[2], length=datasize)
prob = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob, Tsit5(), saveat=t))

model = Chain(x -> x .^ 3,
              Dense(2, 50, tanh),
              Dense(50, 2))
neuralde = NeuralODE(model, tspan, Rodas5(), saveat=t, reltol=1e-7, abstol=1e-9)

function predict_n_ode()
  neuralde(u0)
end
loss_n_ode() = sum(abs2, ode_data .- predict_n_ode())

data = Iterators.repeated((), 10)
opt = ADAM(0.1)
cb = function () #callback function to observe training
   display(loss_n_ode())
end

# Display the ODE with the initial parameter values.
cb()

neuralde = NeuralODE(model, tspan, Rodas5(), saveat=t, reltol=1e-7, abstol=1e-9)
ps = Flux.params(neuralde)
loss1 = loss_n_ode()
Flux.train!(loss_n_ode, ps, data, opt, cb=cb)
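
For reference, the failure reduces to the rebuild receiving a flat vector whose eltype differs from the buffer captured at destructure time. A minimal standalone sketch of that mismatch (hypothetical illustration using Flux's destructure, not part of the MWE above):

using Flux

m = Dense(2, 2)                   # Float32 parameters by default
flat, re = Flux.destructure(m)    # `re` captures a flat Vector{Float32}

p = Float64.(flat) .+ 0.1         # e.g. parameters promoted to Float64 by the solver or AD
m2 = re(p)                        # eltype(p) no longer matches the stored buffer; this is
                                  # exactly the case the re-allocation is meant to handle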

It might be good to update the actual struct as well, since otherwise it would lie about the actual contents of the parameters. This would mean making Restructure mutable.

This still needs tests before merging, and some test failures are expected since we still have to accumulate the gradients properly.
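
A rough sketch of the re-allocation idea (not the actual diff in this PR; the field names and the `_rebuild` helper are illustrative stand-ins):

# Illustrative sketch only; `model`, `flat` and `_rebuild` are hypothetical names.
mutable struct Restructure{M}
    model::M
    flat::AbstractVector   # flat parameter vector, loosely typed so it can be swapped out
end

function (re::Restructure)(ps::AbstractVector)
    # Re-use the stored buffer when the eltypes match, otherwise allocate a fresh one
    buf = eltype(ps) === eltype(re.flat) ? re.flat : similar(ps)
    copyto!(buf, ps)
    re.flat = buf                    # keep the struct honest about its contents (needs mutability)
    return _rebuild(re.model, buf)   # `_rebuild` stands in for the actual reconstruction logic
end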

@DhairyaLGandhi (Member, Author)

Of course doing it out of place is less efficient, but it seems like at least some of the tests show that there were indeed cases of implicit conversion going on. It might also be the more correct implementation, since we can't assume that the types of the primal and the pullback will always match.
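
Concretely, in mixed-mode AD the rebuild can be handed ForwardDiff dual numbers while the stored buffer is Float32, so the two element types genuinely can't be assumed to match. A hypothetical sketch (assuming Flux and ForwardDiff; it only works if the rebuild allocates for the incoming eltype, which is what this PR does):

using Flux, ForwardDiff

m = Dense(2, 2)                  # Float32 parameters
flat, re = Flux.destructure(m)
x = rand(Float32, 2)

# ForwardDiff perturbs `flat` with Dual numbers, so `re` is called with a vector whose
# eltype is a ForwardDiff.Dual rather than Float32; that can't be written into the
# preallocated Float32 buffer in place.
g = ForwardDiff.gradient(p -> sum(abs2, re(p)(x)), flat)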

@ChrisRackauckas (Member)

Fixes SciML/DiffEqFlux.jl#699

@CarloLucibello (Member)

Solved in #66
