Skip to content

Commit

Permalink
use Parameter set
Browse files Browse the repository at this point in the history
  • Loading branch information
joaquimg committed Jan 14, 2025
1 parent 91865cf commit 9ec668b
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 70 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ optimize!(model)

# differentiate w.r.t. p
direction_p = 3.0
MOI.set(model, DiffOpt.ForwardConstraintSet(), ParameterRef(p), direction_p)
MOI.set(model, DiffOpt.ForwardConstraintSet(), ParameterRef(p), Parameter(direction_p))
DiffOpt.forward_differentiate!(model)
@show MOI.get(model, DiffOpt.ForwardVariablePrimal(), x) == direction_p * 3 / pc_val

Expand All @@ -82,7 +82,7 @@ optimize!(model)
DiffOpt.empty_input_sensitivities!(model)
# differentiate w.r.t. pc
direction_pc = 10.0
MOI.set(model, DiffOpt.ForwardConstraintSet(), ParameterRef(pc), direction_pc)
MOI.set(model, DiffOpt.ForwardConstraintSet(), ParameterRef(pc), Parameter(direction_pc))
DiffOpt.forward_differentiate!(model)
@show abs(MOI.get(model, DiffOpt.ForwardVariablePrimal(), x) -
-direction_pc * 3 * p_val / pc_val^2) < 1e-5
Expand All @@ -93,8 +93,8 @@ DiffOpt.empty_input_sensitivities!(model)
direction_x = 10.0
MOI.set(model, DiffOpt.ReverseVariablePrimal(), x, direction_x)
DiffOpt.reverse_differentiate!(model)
@show MOI.get(model, DiffOpt.ReverseConstraintSet(), ParameterRef(p)) == direction_x * 3 / pc_val
@show abs(MOI.get(model, DiffOpt.ReverseConstraintSet(), ParameterRef(pc)) -
@show MOI.get(model, DiffOpt.ReverseConstraintSet(), ParameterRef(p)) == MOI.Parameter(direction_x * 3 / pc_val)
@show abs(MOI.get(model, DiffOpt.ReverseConstraintSet(), ParameterRef(pc)).value -
-direction_x * 3 * p_val / pc_val^2) < 1e-5
```

Expand Down
2 changes: 1 addition & 1 deletion src/DiffOpt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ include("utils.jl")
include("product_of_sets.jl")
include("diff_opt.jl")
include("moi_wrapper.jl")
include("jump_moi_overloads.jl")
include("parameters.jl")
include("jump_moi_overloads.jl")

include("copy_dual.jl")
include("bridges.jl")
Expand Down
31 changes: 31 additions & 0 deletions src/jump_moi_overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,37 @@ function MOI.get(
return _moi_get_result(JuMP.backend(model), attr, JuMP.index(var_ref))
end

# extras to handle model_dirty

function MOI.get(
model::JuMP.Model,
attr::ReverseConstraintSet,
var_ref::JuMP.ConstraintRef,
)
JuMP.check_belongs_to_model(var_ref, model)
return _moi_get_result(JuMP.backend(model), attr, JuMP.index(var_ref))
end

function MOI.set(
model::JuMP.Model,
attr::ForwardConstraintSet,
con_ref::JuMP.ConstraintRef,
set::MOI.AbstractScalarSet,
)
JuMP.check_belongs_to_model(con_ref, model)
return MOI.set(JuMP.backend(model), attr, JuMP.index(con_ref), set)
end

function MOI.set(
model::JuMP.Model,
attr::ForwardConstraintSet,
con_ref::JuMP.ConstraintRef,
set::JuMP.AbstractScalarSet,
)
JuMP.check_belongs_to_model(con_ref, model)
return MOI.set(JuMP.backend(model), attr, JuMP.index(con_ref), JuMP.moi_set(set))
end

"""
abstract type AbstractLazyScalarFunction <: MOI.AbstractScalarFunction end
Expand Down
19 changes: 5 additions & 14 deletions src/parameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,14 +308,14 @@ function MOI.set(
model::POI.Optimizer,
::ForwardConstraintSet,
ci::MOI.ConstraintIndex{MOI.VariableIndex,MOI.Parameter{T}},
value::Number,
set::MOI.Parameter,
) where {T}
variable = MOI.VariableIndex(ci.value)
if _is_variable(model, variable)
error("Trying to set a forward parameter sensitivity for a variable")
end
sensitivity_data = _get_sensitivity_data(model)
sensitivity_data.parameter_input_forward[variable] = value
sensitivity_data.parameter_input_forward[variable] = set.value
return
end

Expand Down Expand Up @@ -573,16 +573,7 @@ function MOI.get(
error("Trying to get a backward parameter sensitivity for a variable")
end
sensitivity_data = _get_sensitivity_data(model)
return get(sensitivity_data.parameter_output_backward, variable, 0.0)
end

# extras to handle model_dirty

function MOI.get(
model::JuMP.Model,
attr::ReverseConstraintSet,
var_ref::JuMP.ConstraintRef,
)
JuMP.check_belongs_to_model(var_ref, model)
return _moi_get_result(JuMP.backend(model), attr, JuMP.index(var_ref))
return MOI.Parameter{T}(
get(sensitivity_data.parameter_output_backward, variable, 0.0),
)
end
Loading

0 comments on commit 9ec668b

Please sign in to comment.