Skip to content

Commit e584a21

Browse files
fix: update InfiniteOpt extension to new semantics
1 parent 7b165fa commit e584a21

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

ext/MTKInfiniteOptExt.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -150,29 +150,28 @@ function set_jump_bounds!(model, sys, pmap)
150150
for (i, u) in enumerate(unknowns(sys))
151151
if MTK.hasbounds(u)
152152
lo, hi = MTK.getbounds(u)
153-
set_lower_bound(U[i], Symbolics.fixpoint_sub(lo, pmap))
154-
set_upper_bound(U[i], Symbolics.fixpoint_sub(hi, pmap))
153+
set_lower_bound(U[i], Symbolics.fast_substitute(lo, pmap))
154+
set_upper_bound(U[i], Symbolics.fast_substitute(hi, pmap))
155155
end
156156
end
157157

158158
V = model[:V]
159159
for (i, v) in enumerate(MTK.unbound_inputs(sys))
160160
if MTK.hasbounds(v)
161161
lo, hi = MTK.getbounds(v)
162-
set_lower_bound(V[i], Symbolics.fixpoint_sub(lo, pmap))
163-
set_upper_bound(V[i], Symbolics.fixpoint_sub(hi, pmap))
162+
set_lower_bound(V[i], Symbolics.fast_substitute(lo, pmap))
163+
set_upper_bound(V[i], Symbolics.fast_substitute(hi, pmap))
164164
end
165165
end
166166
end
167167

168168
function add_jump_cost_function!(model::InfiniteModel, sys, tspan, pmap; is_free_t = false)
169-
jcosts = MTK.get_costs(sys)
170-
consolidate = MTK.get_consolidate(sys)
171-
if isnothing(jcosts) || isempty(jcosts)
169+
jcosts = cost(sys)
170+
if Symbolics._iszero(jcosts)
172171
@objective(model, Min, 0)
173172
return
174173
end
175-
jcosts = substitute_jump_vars(model, sys, pmap, jcosts; is_free_t)
174+
jcosts = substitute_jump_vars(model, sys, pmap, [jcosts]; is_free_t)[1]
176175
tₛ = is_free_t ? model[:tf] : 1
177176

178177
# Substitute integral
@@ -187,17 +186,18 @@ function add_jump_cost_function!(model::InfiniteModel, sys, tspan, pmap; is_free
187186
hi = haskey(pmap, hi) ? 1 : MTK.value(hi)
188187
intmap[int] = tₛ * InfiniteOpt.(arg, model[:t], lo, hi)
189188
end
190-
jcosts = map(c -> Symbolics.substitute(c, intmap), jcosts)
191-
@objective(model, Min, consolidate(jcosts))
189+
jcosts = Symbolics.substitute(jcosts, intmap)
190+
@objective(model, Min, MTK.value(jcosts))
192191
end
193192

194193
function add_user_constraints!(model::InfiniteModel, sys, pmap; is_free_t = false)
195-
conssys = MTK.get_constraintsystem(sys)
196-
jconstraints = isnothing(conssys) ? nothing : MTK.get_constraints(conssys)
194+
jconstraints = MTK.get_constraints(sys)
197195
(isnothing(jconstraints) || isempty(jconstraints)) && return nothing
196+
cons_dvs, cons_ps = MTK.process_constraint_system(
197+
jconstraints, Set(unknowns(sys)), parameters(sys), MTK.get_iv(sys); validate = false)
198198

199199
if is_free_t
200-
for u in MTK.get_unknowns(conssys)
200+
for u in cons_dvs
201201
x = MTK.operation(u)
202202
t = only(arguments(u))
203203
if (MTK.symbolic_type(t) === MTK.NotSymbolic())
@@ -206,7 +206,7 @@ function add_user_constraints!(model::InfiniteModel, sys, pmap; is_free_t = fals
206206
end
207207
end
208208

209-
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in unknowns(conssys)])
209+
auxmap = Dict([u => MTK.default_toterm(MTK.value(u)) for u in cons_dvs])
210210
jconstraints = substitute_jump_vars(model, sys, pmap, jconstraints; auxmap, is_free_t)
211211

212212
# Substitute to-term'd variables
@@ -235,25 +235,25 @@ function substitute_jump_vars(model, sys, pmap, exprs; auxmap = Dict(), is_free_
235235
x_ops = [MTK.operation(MTK.unwrap(st)) for st in sts]
236236
c_ops = [MTK.operation(MTK.unwrap(ct)) for ct in cts]
237237

238-
exprs = map(c -> Symbolics.fixpoint_sub(c, auxmap), exprs)
239-
exprs = map(c -> Symbolics.fixpoint_sub(c, Dict(pmap)), exprs)
238+
exprs = map(c -> Symbolics.fast_substitute(c, auxmap), exprs)
239+
exprs = map(c -> Symbolics.fast_substitute(c, Dict(pmap)), exprs)
240240
if is_free_t
241241
tf = model[:tf]
242242
free_t_map = Dict([[x(tf) => U[i](1) for (i, x) in enumerate(x_ops)];
243243
[c(tf) => V[i](1) for (i, c) in enumerate(c_ops)]])
244-
exprs = map(c -> Symbolics.fixpoint_sub(c, free_t_map), exprs)
244+
exprs = map(c -> Symbolics.fast_substitute(c, free_t_map), exprs)
245245
end
246246

247247
# for variables like x(t)
248248
whole_interval_map = Dict([[v => U[i] for (i, v) in enumerate(sts)];
249249
[v => V[i] for (i, v) in enumerate(cts)]])
250-
exprs = map(c -> Symbolics.fixpoint_sub(c, whole_interval_map), exprs)
250+
exprs = map(c -> Symbolics.fast_substitute(c, whole_interval_map), exprs)
251251

252252
# for variables like x(1.0)
253253
fixed_t_map = Dict([[x_ops[i] => U[i] for i in 1:length(U)];
254254
[c_ops[i] => V[i] for i in 1:length(V)]])
255255

256-
exprs = map(c -> Symbolics.fixpoint_sub(c, fixed_t_map), exprs)
256+
exprs = map(c -> Symbolics.fast_substitute(c, fixed_t_map), exprs)
257257
exprs
258258
end
259259

0 commit comments

Comments
 (0)