
Excessive allocations with VectorOfArray in broadcast #199

Open

eschnett opened this issue May 15, 2022 · 5 comments

Comments

@eschnett

I am using a VectorOfArray as the state vector in DifferentialEquations.jl. This leads to far too many allocations.

using DifferentialEquations
using RecursiveArrayTools

# Set up the state vector
U = VectorOfArray([zeros(100,100,100), zeros(100,100,100)]);
rhs!(U̇, U, p, t) = U̇ .= U

# Compile all code
prob = ODEProblem(rhs!, U, (0.0, 1.0));
sol = solve(prob, RK4(); adaptive=false, dt=1.0);

# Benchmark
prob = ODEProblem(rhs!, U, (0.0, 1.0));
@time sol = solve(prob, RK4(); adaptive=false, dt=1.0);

This outputs

  1.639028 seconds (42.00 M allocations: 991.828 MiB, 18.67% gc time)

Note the very large number of allocations. It seems as if some operation is performing one small allocation per array element: the state has 2 × 100³ = 2 × 10⁶ elements, so 42 M allocations comes out to about 20 per element for this single fixed step.

(I notice that this problem disappears if I switch to the midpoint rule.)

I tried with both Julia 1.7 and the current release branch of Julia 1.8.
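
For a baseline, the same solve with the state stored as one plain 4-D Array isolates whether the allocations come from the VectorOfArray broadcast path. (A minimal comparison sketch; the 4-D layout is just an assumed stand-in for the two components.)

using DifferentialEquations

# Baseline: the same two-component state stored as one plain 4-D Array
# (an assumed stand-in for the VectorOfArray layout) to check whether
# the allocations are specific to the VectorOfArray broadcast path.
V = zeros(100, 100, 100, 2);
rhs_plain!(V̇, V, p, t) = V̇ .= V

prob2 = ODEProblem(rhs_plain!, V, (0.0, 1.0));
sol2 = solve(prob2, RK4(); adaptive=false, dt=1.0);   # compile

prob2 = ODEProblem(rhs_plain!, V, (0.0, 1.0));
@time sol2 = solve(prob2, RK4(); adaptive=false, dt=1.0);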

@ChrisRackauckas
Member

The profiles show that it's all in broadcast. VectorOfArray just needs a better broadcast overload.

https://github.com/SciML/RecursiveArrayTools.jl/blob/v2.26.3/src/vector_of_array.jl#L288

For some reason, that overload seems to be allocating.
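
One possible shape for such an overload (a sketch under assumptions, not the package's actual code): flatten the expression into a single function call and materialize it component-by-component, so the inner broadcast only ever sees plain Arrays and scalars. `map_components!` below is a hypothetical helper, not part of RecursiveArrayTools:

using RecursiveArrayTools

# Hypothetical sketch: materialize a flat f.(args...) broadcast one
# component array at a time, so each inner broadcast only involves
# plain Arrays and scalar arguments.
function map_components!(f, dest::VectorOfArray, args...)
    getcomp(a, i) = a isa VectorOfArray ? a.u[i] : a   # scalars pass through
    for i in eachindex(dest.u)
        broadcast!(f, dest.u[i], map(a -> getcomp(a, i), args)...)
    end
    return dest
end

Applied to the RK4 update in the next comment, this would read map_components!((up, a, b, c, d, h) -> up + (h / 6) * (2 * (b + c) + (a + d)), u, uprev, k₁, k₂, k₃, k₄, dt).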

@ChrisRackauckas changed the title from “Excessive allocations with VectorOfArray in ODE solver” to “Excessive allocations with VectorOfArray in broadcast” on May 15, 2022
@ChrisRackauckas
Member

ChrisRackauckas commented May 15, 2022

Here is the offending call without the ODE solver:

using RecursiveArrayTools

u = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
uprev = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₁ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₂ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₃ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₄ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
dt = 0.5

@time @. u = uprev + (dt / 6) * (2 * (k₂ + k₃) + (k₁ + k₄))

It looks like it's triggered by having a scalar in the broadcast tree.
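
One way to quantify that without the solver (a quick probe; exact counts will vary with the Julia version):

# Warm up, then count allocations for the single fused broadcast.
# With 2 × 100³ = 2 × 10⁶ elements, n / 2e6 is the average number of
# bytes allocated per array element.
@. u = uprev + (dt / 6) * (2 * (k₂ + k₃) + (k₁ + k₄));
n = @allocated @. u = uprev + (dt / 6) * (2 * (k₂ + k₃) + (k₁ + k₄))
@show n
@show n / 2e6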

@Moelf

Moelf commented May 15, 2022

Does ArrayOfArrays.jl have a better specialization?

@mipals

mipals commented May 15, 2022

Switching to

@time @. u = uprev + (dt / 6.0) * (2.0 * (k₂ + k₃) + (k₁ + k₄))

seems to speed things up quite a bit, so there might be some type instability hiding somewhere.

Edit: Rewriting the expression as follows is even better (and does not result in any type instabilities)

@time @. u = uprev + (dt/3) * ((k₂ + k₃) + (k₁ + k₄)/2)
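
A direct way to hunt for that suspected instability (standard tooling, nothing specific to this package) is to wrap each variant in a function and compare. Note this only inspects the top-level call; Cthulhu's @descend, as in the next comment, digs into the nested calls.

# Wrap both variants so @code_warntype sees concrete argument types;
# red Vararg/Any entries in the output mark the instability.
f_mixed(u, uprev, k₁, k₂, k₃, k₄, dt) = @. u = uprev + (dt / 6) * (2 * (k₂ + k₃) + (k₁ + k₄))
f_rewritten(u, uprev, k₁, k₂, k₃, k₄, dt) = @. u = uprev + (dt / 3) * ((k₂ + k₃) + (k₁ + k₄) / 2)

@code_warntype f_mixed(u, uprev, k₁, k₂, k₃, k₄, dt)
@code_warntype f_rewritten(u, uprev, k₁, k₂, k₃, k₄, dt)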

@chengchingwen

chengchingwen commented May 15, 2022

It seems to be a problem when a scalar appears in the tree and the broadcast is too deeply nested:

using RecursiveArrayTools

u = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
uprev = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₁ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₂ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₃ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
k₄ = VectorOfArray([zeros(100, 100, 100), zeros(100, 100, 100)]);
dt = 0.5

A = Broadcast.broadcasted(*, 2, Broadcast.broadcasted(+, k₂, k₃))
B = Broadcast.broadcasted(+, k₁, k₄)
C = Broadcast.broadcasted(+, A, B)
C2 = Broadcast.broadcasted(identity, C)

julia> @time Broadcast.materialize!(u, C);
  0.006494 seconds

julia> @time Broadcast.materialize!(u, C2);
  2.928890 seconds (60.00 M allocations: 1.103 GiB, 5.41% gc time)
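
If nesting depth is the trigger, one more trivial layer is a cheap probe (an extra check along the same lines; the result is not assumed here):

# Probe: a third layer. If the cost grows again relative to C2, the
# problem scales with the depth of the broadcast tree rather than
# coming from a single bad node.
C3 = Broadcast.broadcasted(identity, C2)
@time Broadcast.materialize!(u, C3);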

With Cthulhu, the type instabilities can be traced back to base/broadcast.jl#403 and base/broadcast.jl#380:

julia> @descend Broadcast.materialize!(u, C2)
[...]
(::Base.Broadcast.var"#23#24")(head, tail::Vararg{Any, N}) where N in Base.Broadcast at broadcast.jl:403
Body::Tuple{Vararg{Float64}}
404 1 ─ %1 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"()), tail)::Tuple{Vararg{Float64}}
[...]
 • %1 = call #23(::Float64,::Float64...)::Tuple{Vararg{Float64}}

[...]

(::Base.Broadcast.var"#16#18")(args::Vararg{Any, N}) where N in Base.Broadcast at broadcast.jl:380
Body::Tuple{Float64, Vararg{Float64}}
381 1 ─ %1 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#13#14"{Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}}(Base.Broadcast.var"#13#14"{Base.Broadcast.var"#15#17"}(Base.Broadcast.var"#15#17"())), args)::Tuple{Float64, Float64, Vararg{Float64}}
    │   %3 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#23#24"{Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}}(Base.Broadcast.var"#23#24"{Base.Broadcast.var"#25#26"}(Base.Broadcast.var"#25#26"())), %1)::Tuple{Vararg{Float64}}
    │   %4 = Core._apply_iterate(Base.iterate, Base.Broadcast.var"#15#17"(), %3)::Tuple{Vararg{Float64}}
    │   %9 = Core._apply_iterate(Base.iterate, Core.tuple, %8, %4)::Tuple{Float64, Vararg{Float64}}
[...]
 • %1 = call #13(::Float64,::Float64...)::Tuple{Float64, Float64, Vararg{Float64}}
   %2 = call #19(::Float64,::Float64,::Float64...)::Tuple{Float64, Float64}
   %3 = call #23(::Float64,::Float64,::Float64...)::Tuple{Vararg{Float64}}
   %4 = call #15(::Float64...)::Tuple{Vararg{Float64}}
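
Consistent with all of the above, a workaround until the overload is fixed is to broadcast over each component array directly (a sketch; .u is the field holding VectorOfArray's component arrays):

# Workaround sketch: apply the update per component array, bypassing
# the VectorOfArray broadcast machinery entirely.
for i in eachindex(u.u)
    @. u.u[i] = uprev.u[i] + (dt / 6) * (2 * (k₂.u[i] + k₃.u[i]) + (k₁.u[i] + k₄.u[i]))
end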
   
