Skip to content

Commit 9f39581

Browse files
Merge pull request #50 from SciML/as/inplace-getp
feat: add support for inplace getp
2 parents 46070b2 + 8befb01 commit 9f39581

File tree

2 files changed

+32
-5
lines changed

2 files changed

+32
-5
lines changed

src/parameter_indexing.jl

+20-5
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,15 @@ set_parameter!(sys, val, idx) = set_parameter!(parameter_values(sys), val, idx)
3535
"""
3636
getp(sys, p)
3737
38-
Return a function that takes an array representing the parameter vector or an integrator
38+
Return a function that takes an array representing the parameter object or an integrator
3939
or solution of `sys`, and returns the value of the parameter `p`. Note that `p` can be a
4040
direct index or a symbolic value, or an array/tuple of the aforementioned.
4141
42+
If `p` is an array/tuple of parameters, then the returned function can also be used
43+
as an in-place getter function. The first argument is the buffer to which the parameter
44+
values should be written, and the second argument is the parameter object/integrator/
45+
solution from which the values are obtained.
46+
4247
Requires that the integrator or solution implement [`parameter_values`](@ref). This function
4348
typically does not need to be implemented, and has a default implementation relying on
4449
[`parameter_values`](@ref).
@@ -57,8 +62,10 @@ end
5762

5863
function _getp(sys, ::ScalarSymbolic, ::SymbolicTypeTrait, p)
5964
idx = parameter_index(sys, p)
60-
return function getter(sol)
61-
return parameter_values(sol, idx)
65+
return let idx = idx
66+
function getter(sol)
67+
return parameter_values(sol, idx)
68+
end
6269
end
6370
end
6471

@@ -70,8 +77,16 @@ for (t1, t2) in [
7077
@eval function _getp(sys, ::NotSymbolic, ::$t1, p::$t2)
7178
getters = getp.((sys,), p)
7279

73-
return function getter(sol)
74-
map(g -> g(sol), getters)
80+
return let getters = getters
81+
function getter(sol)
82+
map(g -> g(sol), getters)
83+
end
84+
function getter(buffer, sol)
85+
for (i, g) in zip(eachindex(buffer), getters)
86+
buffer[i] = g(sol)
87+
end
88+
buffer
89+
end
7590
end
7691
end
7792
end

test/parameter_indexing_test.jl

+12
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,15 @@ for (sym, oldval, newval, check_inference) in [
5656
set!(p, oldval)
5757
@test get(p) == oldval
5858
end
59+
60+
for (sym, val) in [
61+
([:a, :b, :c], p),
62+
([:c, :a], p[[3, 1]]),
63+
((:b, :a), p[[2, 1]]),
64+
((1, :c), p[[1, 3]])
65+
]
66+
buffer = zeros(length(sym))
67+
get = getp(sys, sym)
68+
@inferred get(buffer, fi)
69+
@test buffer == val
70+
end

0 commit comments

Comments
 (0)