Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap Boris solution into struct #124

Merged
merged 3 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/examples/basics/demo_boris.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# ---
# title: Boris method
# id: demo_boris
# date: 2023-11-13
# date: 2024-01-23
# author: "[Hongyang Zhou](https://github.com/henry2004y)"
# julia: 1.10.0
# description: Simple electron trajectory under uniform B and zero E
Expand Down Expand Up @@ -46,7 +46,7 @@ sol2 = solve(prob, Tsit5());

f = Figure(size=(700, 600))
ax = Axis(f[1, 1], aspect=1)
@views lines!(ax, traj[1][1,:], traj[1][2,:], label="Boris")
@views lines!(ax, traj[1].u[1,:], traj[1].u[2,:], label="Boris")
lines!(ax, sol1, linestyle=:dashdot, label="Tsit5 fixed")
lines!(ax, sol2, linestyle=:dot, label="Tsit5 adaptive")
axislegend(position=:lt, framevisible=false)
Expand Down
6 changes: 4 additions & 2 deletions docs/examples/basics/demo_ensemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ ax = Axis3(f[1, 1],
)

for i in eachindex(trajs)
@views lines!(ax, trajs[i][1,:], trajs[i][2,:], trajs[i][3,:], label="$i")
@views lines!(ax, trajs[i].u[1,:], trajs[i].u[2,:], trajs[i].u[3,:], label="$i")
end

f = DisplayAs.PNG(f) #hide
f = DisplayAs.PNG(f) #hide

# You may notice that the Boris outputs are more "discontinuous" than `Tsit5`. This is because algorithms in OrdinaryDiffEq.jl come with "free" interpolation schemes automatically applied for visualization, while we have not yet implemented this for the native Boris method.
2 changes: 1 addition & 1 deletion docs/examples/basics/demo_proton_dipole.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ dt = 1e-4
paramBoris = BorisMethod(param)
prob = TraceProblem(stateinit, tspan, dt, paramBoris)
traj = trace_trajectory(prob)
get_energy_ratio(traj[1])
get_energy_ratio(traj[1].u)

# The Boris method requires a fixed time step. It takes about 0.05s and consumes 53 MiB memory. In this specific case, the time step is determined empirically. If we increase the time step to `1e-2` seconds, the trajectory becomes completely off (but the energy is still conserved).
# Therefore, as a rule of thumb, we should not use the default `Tsit5()` scheme without decreasing `reltol`. Use adaptive `Vern9()` for an unfamiliar field configuration, then switch to more accurate schemes if needed. A more thorough test can be found [here](https://github.com/henry2004y/TestParticle.jl/issues/73).
2 changes: 1 addition & 1 deletion src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@
dt = 0.5
paramBoris = BorisMethod(param)
prob = TraceProblem(stateinit, tspan, dt, paramBoris)
traj = trace_trajectory(prob; savestepinterval=100)
sol = trace_trajectory(prob; savestepinterval=100)
end
end
32 changes: 20 additions & 12 deletions src/pusher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@
u0::TS
tspan::Tuple{T, T}
dt::T
param::TP
p::TP
prob_func::PF
end

struct TraceSolution{TU<:Array, T<:AbstractVector}
u::TU
t::T
end

DEFAULT_PROB_FUNC(prob, i, repeat) = prob

function TraceProblem(u0, tspan, dt, param; prob_func=DEFAULT_PROB_FUNC)
TraceProblem(u0, tspan, dt, param, prob_func)
function TraceProblem(u0, tspan, dt, p; prob_func=DEFAULT_PROB_FUNC)
TraceProblem(u0, tspan, dt, p, prob_func)
end

struct BorisMethod{T, TV}
Expand Down Expand Up @@ -105,11 +110,11 @@
function trace_trajectory(prob::TraceProblem; trajectories::Int=1,
savestepinterval::Int=1, isoutofdomain::Function=ODE_DEFAULT_ISOUTOFDOMAIN)

trajs = Vector{Array{eltype(prob.u0),2}}(undef, trajectories)
sols = Vector{TraceSolution}(undef, trajectories)

for i in 1:trajectories
new_prob = prob.prob_func(prob, i, false)
(; u0, tspan, dt, param) = new_prob
(; u0, tspan, dt, p) = new_prob
xv = copy(u0)
# prepare advancing
ttotal = tspan[2] - tspan[1]
Expand All @@ -120,10 +125,10 @@
traj[:,1] = xv

# push velocity back in time by 1/2 dt
update_velocity!(xv, param, -0.5*dt)
update_velocity!(xv, p, -0.5*dt)

for it in 1:nt
update_velocity!(xv, param, dt)
update_velocity!(xv, p, dt)
update_location!(xv, dt)
if it % savestepinterval == 0
iout += 1
Expand All @@ -133,18 +138,21 @@
end

if iout == nout # regular termination
# final step if needed
dtfinal = ttotal - nt*dt
if dtfinal > 1e-3
update_velocity!(xv, param, dtfinal)
if dtfinal > 1e-3 # final step if needed
update_velocity!(xv, p, dtfinal)

Check warning on line 143 in src/pusher.jl

View check run for this annotation

Codecov / codecov/patch

src/pusher.jl#L143

Added line #L143 was not covered by tests
update_location!(xv, dtfinal)
traj = hcat(traj, xv)
t = [collect(tspan[1]:dt:tspan[2])..., tspan[2]]

Check warning on line 146 in src/pusher.jl

View check run for this annotation

Codecov / codecov/patch

src/pusher.jl#L146

Added line #L146 was not covered by tests
else
t = collect(tspan[1]:dt:tspan[2])
end
else # early termination
traj = traj[:, 1:iout]
t = collect(range(tspan[1], tspan[1]+iout*dt, step=dt))

Check warning on line 152 in src/pusher.jl

View check run for this annotation

Codecov / codecov/patch

src/pusher.jl#L152

Added line #L152 was not covered by tests
end
trajs[i] = traj
sols[i] = TraceSolution(traj, t)
end

trajs
sols
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ end

traj = trace_trajectory(prob; savestepinterval=10)

@test traj[1][:, end] == [-7.84237771267459e-5, 5.263661571564935e-5, 0.0,
@test traj[1].u[:, end] == [-7.84237771267459e-5, 5.263661571564935e-5, 0.0,
-93512.6374393526, -35431.43574759836, 0.0]
end
end
Expand Down
Loading