From 149f7bfc03886a5cc76db5ee6db9c92fdae94c1b Mon Sep 17 00:00:00 2001 From: RXGottlieb <36906916+RXGottlieb@users.noreply.github.com> Date: Mon, 7 Jul 2025 09:25:44 -0400 Subject: [PATCH 1/3] Minor fgen fixes - Add `var_names(::GradTransform, [...])` and `all_names` for generating symbolic subgradient variables - Fix inversion operation in division when lower and upper bounds are the same - Add more support for higher variable indices - Add error for `binarize!` to assist in debugging - Add `split_div::Bool` argument to `factor`, which determines whether terms like `x/y` will remain as-is or get factored into `x * (1/y)` - Improve variable sorting so that, e.g., `x10` will not come before `x2` - Add `extract`, which can pull a subexpression out of a primal trace and substitute out auxiliary variables - Fix `fgen` bug that would use an incorrect subgradient variable --- src/grad/grad.jl | 35 +++++++++++++- src/grad/rules.jl | 2 +- src/interval/interval.jl | 15 +++++- src/transform/binarize.jl | 1 + src/transform/factor.jl | 98 +++++++++++++++++++++++++++----------- src/transform/utilities.jl | 73 ++++++++++++++++++++++++---- src/transform/write.jl | 13 ++--- 7 files changed, 191 insertions(+), 46 deletions(-) diff --git a/src/grad/grad.jl b/src/grad/grad.jl index 632153b..9680673 100644 --- a/src/grad/grad.jl +++ b/src/grad/grad.jl @@ -1,6 +1,37 @@ -# Can remove the import statement once this is fully incorporated into SCMC -# import SourceCodeMcCormick: xstr, ystr, zstr, var_names, arity, op, transform_rule +struct GradTransform <: AbstractTransform end + +var_names(::GradTransform, a::Real) = a, a +function var_names(::GradTransform, a::BasicSymbolic) + if exprtype(a)==SYM + acvgrad = genvar(Symbol(string(get_name(a))*"_cvgrad")) + accgrad = genvar(Symbol(string(get_name(a))*"_ccgrad")) + return acvgrad.val, accgrad.val + elseif exprtype(a)==TERM + if varterm(a) && typeof(a.f)<:BasicSymbolic + arg_list = Symbol[] + for i in a.arguments + push!(arg_list, get_name(i)) + end + acvgrad = genvar(Symbol(string(get_name(a))*"_cvgrad"), arg_list) + accgrad = genvar(Symbol(string(get_name(a))*"_ccgrad"), arg_list) + return acvgrad.val, accgrad.val + else + acvgrad = genvar(Symbol(string(get_name(a))*"_cvgrad")) + accgrad = genvar(Symbol(string(get_name(a))*"_ccgrad")) + return acvgrad.val, accgrad.val + end + else + error("Reached `var_names` with an unexpected type [ADD/MUL/DIV/POW]. Check expression factorization to make sure it is being binarized correctly.") + end +end + +function all_names(a::Any) + aL, aU = var_names(IntervalTransform(), a) + acv, acc = var_names(McCormickTransform(), a) + acvgrad, accgrad = var_names(GradTransform(), a) + return aL, aU, acv, acc, acvgrad, accgrad +end """ diff --git a/src/grad/rules.jl b/src/grad/rules.jl index e34bd2b..105e477 100644 --- a/src/grad/rules.jl +++ b/src/grad/rules.jl @@ -272,7 +272,7 @@ function grad_transform!(::McCormickIntervalTransform, ::typeof(/), zL, zU, zcv, IfElse.ifelse(yU < 0.0, mid_expr(ycc, ycv, yU).^(-1), NaN)) y_cv_gradlist_inv = similar(cv_gradlist[:,y]) - @. y_cv_gradlist_inv = IfElse.ifelse(yU < 0.0, IfElse.ifelse(yU == yL, -1/(mid_expr(ycc, ycv, yL)*mid_expr(ycc, ycv, yL)), (yU^-1 - yL^-1)/(yU - yL)) * + @. y_cv_gradlist_inv = IfElse.ifelse(yU < 0.0, IfElse.ifelse(yU == yL, -1/(mid_expr(ycc, ycv, yL)*mid_expr(ycc, ycv, yL)) * mid_grad(ycc, ycv, yL, cc_gradlist[:,y], cv_gradlist[:,y], zero_vec), (yU^-1 - yL^-1)/(yU - yL)) * mid_grad(ycc, ycv, yL, cc_gradlist[:,y], cv_gradlist[:,y], zero_vec), IfElse.ifelse(yL > 0.0, -1.0/(mid_expr(ycc, ycv, yU)*mid_expr(ycc, ycv, yU)) * mid_grad(ycc, ycv, yU, cc_gradlist[:,y], cv_gradlist[:,y], zero_vec), diff --git a/src/interval/interval.jl b/src/interval/interval.jl index b1813f0..7c76bef 100644 --- a/src/interval/interval.jl +++ b/src/interval/interval.jl @@ -57,6 +57,8 @@ end get_name Take a `BasicSymbolic` object such as `x[1,1]` and return a symbol like `:xv1v1`. +Note that this supports up to 9999-indexed variables (higher than that will still +work, but the order will be wrong) """ function get_name(a::BasicSymbolic) if exprtype(a)==SYM @@ -67,7 +69,18 @@ function get_name(a::BasicSymbolic) args = a.arguments new_var = string(args[1]) for i in 2:lastindex(args) - new_var = new_var * "v" * string(args[i]) + if args[i] < 10 + new_var = new_var * "v000" * string(args[i]) + elseif args[i] < 100 + new_var = new_var * "v00" * string(args[i]) + elseif args[i] < 1000 + new_var = new_var * "v0" * string(args[i]) + elseif args[i] < 10000 + new_var = new_var * "v" * string(args[i]) + else + @warn "Index above 10000, order may be wrong" + new_var = new_var * "v" * string(args[i]) + end end return Symbol(new_var) else diff --git a/src/transform/binarize.jl b/src/transform/binarize.jl index 7b6fbb7..62deb63 100644 --- a/src/transform/binarize.jl +++ b/src/transform/binarize.jl @@ -35,3 +35,4 @@ function binarize!(ex::BasicSymbolic) return nothing end end +binarize!(a::Real) = error("Attempting to apply binarize!() to a Real") diff --git a/src/transform/factor.jl b/src/transform/factor.jl index 697b88f..23d0448 100644 --- a/src/transform/factor.jl +++ b/src/transform/factor.jl @@ -1,13 +1,14 @@ base_term(a::Any) = false base_term(a::Real) = true +base_term(a::Num) = base_term(a.val) function base_term(a::BasicSymbolic) exprtype(a)==SYM && return true exprtype(a)==TERM && return varterm(a) || (a.f==getindex) return false end -function isfactor(a::BasicSymbolic) +function isfactor(a::BasicSymbolic; split_div::Bool=false) if exprtype(a)==SYM return true elseif exprtype(a)==TERM @@ -32,7 +33,12 @@ function isfactor(a::BasicSymbolic) ~(base_term(key)) && return false end return true - elseif exprtype(a)==DIV + elseif exprtype(a)==DIV && split_div==true + ~(typeof(a.num)<:Real) && return false + ~(isone(a.num)) && return false + ~(base_term(a.den)) && return false + return true + elseif exprtype(a)==DIV && split_div==false ~(base_term(a.num)) && return false ~(base_term(a.den)) && return false return true @@ -47,13 +53,13 @@ function factor!(a...) @warn """Use of "!" is deprecated as of v0.2.0. Please call `factor()` instead.""" return factor(a...) end -factor(ex::Num) = factor(ex.val) -factor(ex::Num, eqs::Vector{Equation}) = factor(ex.val, eqs=eqs) +factor(ex::Num; split_div::Bool=false) = factor(ex.val, split_div=split_div) +factor(ex::Num, eqs::Vector{Equation}; split_div::Bool=false) = factor(ex.val, eqs=eqs, split_div=split_div) -function factor(old_ex::BasicSymbolic; eqs = Equation[]) +function factor(old_ex::BasicSymbolic; eqs = Equation[], split_div::Bool = false) ex = deepcopy(old_ex) binarize!(ex) - if isfactor(ex) + if isfactor(ex, split_div=split_div) index = findall(x -> isequal(x.rhs,ex), eqs) if isempty(index) newsym = gensym(:aux) @@ -87,12 +93,12 @@ function factor(old_ex::BasicSymbolic; eqs = Equation[]) new_terms[eqs[index[1]].lhs] = 1 end else - factor(val*key, eqs=eqs) + factor(val*key, eqs=eqs, split_div=split_div) new_terms[eqs[end].lhs] = 1 end end new_add = SymbolicUtils.Add(Real, ex.coeff, new_terms) - factor(new_add, eqs=eqs) + factor(new_add, eqs=eqs, split_div=split_div) return eqs elseif exprtype(ex)==MUL new_terms = Dict{Any, Number}() @@ -112,44 +118,80 @@ function factor(old_ex::BasicSymbolic; eqs = Equation[]) new_terms[eqs[index[1]].lhs] = 1 end else - factor(key^val, eqs=eqs) + factor(key^val, eqs=eqs, split_div=split_div) new_terms[eqs[end].lhs] = 1 end end new_mul = SymbolicUtils.Mul(Real, ex.coeff, new_terms) - factor(new_mul, eqs=eqs) + factor(new_mul, eqs=eqs, split_div=split_div) return eqs elseif exprtype(ex)==DIV - if base_term(ex.num) - new_num = ex.num - else - factor(ex.num, eqs=eqs) - new_num = eqs[end].lhs - end - if base_term(ex.den) - new_den = ex.den + if split_div + if isone(Num(ex.num)) + if base_term(ex.den) + new_den = ex.den + else + factor(ex.den, eqs=eqs, split_div=split_div) + new_den = eqs[end].lhs + end + new_div = SymbolicUtils.Div(1, new_den) + factor(new_div, eqs=eqs, split_div=split_div) + return eqs + else + new_terms = Dict{Any, Number}() + coeff = 1 + if base_term(ex.num) + if typeof(ex.num)<:Real + coeff = ex.num + else + new_terms[ex.num] = 1 + end + else + factor(ex.num, eqs=eqs, split_div=split_div) + new_terms[eqs[end].lhs] = 1 + end + if base_term(ex.den) + new_terms[1/ex.den] = 1 + else + factor(1/ex.den, eqs=eqs, split_div=split_div) + new_terms[eqs[end].lhs] = 1 + end + new_mul = SymbolicUtils.Mul(Real, coeff, new_terms) + factor(new_mul, eqs=eqs, split_div=split_div) + return eqs + end else - factor(ex.den, eqs=eqs) - new_den = eqs[end].lhs + if base_term(ex.num) + new_num = ex.num + else + factor(ex.num, eqs=eqs, split_div=split_div) + new_num = eqs[end].lhs + end + if base_term(ex.den) + new_den = ex.den + else + factor(ex.den, eqs=eqs, split_div=split_div) + new_den = eqs[end].lhs + end + new_div = SymbolicUtils.Div(new_num, new_den) + factor(new_div, eqs=eqs, split_div=split_div) + return eqs end - new_div = SymbolicUtils.Div(new_num, new_den) - factor(new_div, eqs=eqs) - return eqs elseif exprtype(ex)==POW if base_term(ex.base) new_base = ex.base else - factor(ex.base, eqs=eqs) + factor(ex.base, eqs=eqs, split_div=split_div) new_base = eqs[end].lhs end if base_term(ex.exp) new_exp = ex.exp else - factor(ex.exp, eqs=eqs) + factor(ex.exp, eqs=eqs, split_div=split_div) new_exp = eqs[end].lhs end new_pow = SymbolicUtils.Pow(new_base, new_exp) - factor(new_pow, eqs=eqs) + factor(new_pow, eqs=eqs, split_div=split_div) return eqs elseif exprtype(ex)==TERM new_args = [] @@ -157,12 +199,12 @@ function factor(old_ex::BasicSymbolic; eqs = Equation[]) if base_term(arg) push!(new_args, arg) else - factor(arg, eqs=eqs) + factor(arg, eqs=eqs, split_div=split_div) push!(new_args, eqs[end].lhs) end end new_func = SymbolicUtils.Term(ex.f, new_args) - factor(new_func, eqs=eqs) + factor(new_func, eqs=eqs, split_div=split_div) return eqs end return eqs diff --git a/src/transform/utilities.jl b/src/transform/utilities.jl index 2ae8cc8..c7c9d25 100644 --- a/src/transform/utilities.jl +++ b/src/transform/utilities.jl @@ -290,18 +290,34 @@ end function sort_vars(strings::Vector{String}) # isempty(strings) && return sort_names = fill("", length(strings)) - + # Step 1) Check for derivative-type variables # @show strings # split_strings = string.(hcat(split.(strings, "_")...)[1,:]) split_strings = first.(split.(strings, "_")) if strings == split_strings + # Simpler case; we can sort more-or-less normally + + # Put constants first, if any exist for i in eachindex(split_strings) if split_strings[i]=="constant" split_strings[i] = "_____constant" end end - return sortperm(split_strings) + + # Sort split_strings, and if the strings follow the pattern [letters][numbers], + # sort by [letters] first and then by [numbers]. Otherwise, treat the string + # just as a normal string + return sortperm(split_strings, by = s -> begin + m = match(r"([a-zA-Z]+)(\d+)", s) + if m !== nothing + prefix = m.captures[1] + number = parse(Int, m.captures[2]) + return (prefix, number) + else + return (s, 0) + end + end) end deriv = fill(false, length(strings)) # Here's a way to check for derivatives if we need to go back to "d" instead of '∂' @@ -399,8 +415,8 @@ end function _pull_vars(term::BasicSymbolic, vars::Vector{Num}, strings::Vector{String}) if exprtype(term)==SYM - if ~(string(term) in strings) - push!(strings, string(term)) + if ~(string(get_name(term)) in strings) + push!(strings, string(get_name(term))) push!(vars, term) return vars, strings end @@ -408,7 +424,7 @@ function _pull_vars(term::BasicSymbolic, vars::Vector{Num}, strings::Vector{Stri end if exprtype(term)==TERM && varterm(term) if ~(string(term.f) in strings) && (term.f==getindex && ~(string(term) in string.(vars))) - push!(strings, string(term)) + push!(strings, string(get_name(term))) push!(vars, term) return vars, strings end @@ -421,8 +437,8 @@ function _pull_vars(term::BasicSymbolic, vars::Vector{Num}, strings::Vector{Stri end ~(typeof(arg)<:BasicSymbolic) ? continue : nothing if exprtype(arg)==SYM - if ~(string(arg) in strings) - push!(strings, string(arg)) + if ~(string(get_name(arg)) in strings) + push!(strings, string(get_name(arg))) push!(vars, arg) end elseif typeof(arg) <: Real @@ -450,8 +466,8 @@ number of equations remaining (default = 4). ``` eqs = [y ~ 15*x, z ~ (1+y)^2] -shrink_eqs(eqs, 1) +julia> shrink_eqs(eqs, 1) 1-element Vector{Equation}: z ~ (1 + 15x)^2 ``` @@ -476,6 +492,47 @@ function shrink_eqs(eqs::Vector{Equation}, keep::Int64=4; force::Bool=false) return new_eqs end +""" + extract(::Vector{Equation}) + extract(::Vector{Equation}, ::Int) + +Given a set of symbolic equations, and optinally a specific +element that you would like expanded into a full expression, +progressively substitute the RHS definitions of LHS terms +until there is only that equation remaining (default = end). +Returns the RHS of that expression as an object of type +SymbolicUtils.BasicSymbolic{Real}. + +# Example + +``` +eqs = [y ~ 15*x, + z ~ (1+y)^2] + +julia> extract(eqs) +(1 + 15x)^2 +``` +""" +function extract(eqs::Vector{Equation}, ID::Int=length(eqs)) + final_expr = eqs[ID].rhs + progress = true + while progress + progress = false + for var in pull_vars(final_expr) + eq_ID = findfirst(x -> isequal(x.lhs, var), eqs) + if !isnothing(eq_ID) + if isequal(eqs[eq_ID].lhs, eqs[eq_ID].rhs) + nothing + else + final_expr = substitute(final_expr, Dict(eqs[eq_ID].lhs => eqs[eq_ID].rhs)) + progress = true + end + end + end + end + return final_expr +end + """ convex_evaluator(::Num) convex_evaluator(::Equation) diff --git a/src/transform/write.jl b/src/transform/write.jl index 9d29f67..c712e34 100644 --- a/src/transform/write.jl +++ b/src/transform/write.jl @@ -85,7 +85,11 @@ function eqn_edges(a::Vector{Equation}) LHS = hcat(pre_LHS...) LHS_id = zeros(Int, size(LHS)) for i in eachindex(LHS_id) - LHS_id[i] = varid[LHS[i]] + if haskey(varid, LHS[i]) + LHS_id[i] = varid[LHS[i]] + else + LHS_id[i] = 0 + end end RHS_vars = pull_vars.(a) @@ -718,9 +722,6 @@ function fgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, path if (varorder[i] in string.(vars)) # Skip the variable if it's an input (i.e., we don't need to make an auxiliary variable for it) continue end - if mutate && (i==length(varorder)) # If we're mutating, don't make a temporary variable for the final entry in varorder - break - end ID = findfirst(x -> occursin(varorder[i], x), varids) tempID = 0 if isempty(temp_endlist) @@ -806,7 +807,7 @@ function fgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, path end end end - else + elseif mutate==false if :cv in outputs write(loc, " temp$(maxtemp)_cv = similar($(representative))\n") if loc==file_kernel @@ -1071,7 +1072,7 @@ function fgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, path if :ccgrad in outputs for j in eachindex(dsym_cc) outvar = string(get_name.(gradlist)[j]) - write(loc, " OUT_∂$(outvar)_cv $(eq) $(funcs[func_num][6])[1].($(grad_input[j]))\n") + write(loc, " OUT_∂$(outvar)_cc $(eq) $(funcs[func_num][6])[1].($(grad_input[j]))\n") end end end From c837ba139e86a9c461ac8f367711d2831dfe17ca Mon Sep 17 00:00:00 2001 From: RXGottlieb <36906916+RXGottlieb@users.noreply.github.com> Date: Mon, 7 Jul 2025 09:32:15 -0400 Subject: [PATCH 2/3] Update ci.yml --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 36ab011..361dc65 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,7 @@ jobs: with: version: ${{ matrix.version }} arch: ${{ matrix.arch }} - - uses: actions/cache@v1 + - uses: actions/cache@v4 env: cache-name: cache-artifacts with: From 16806efbe8169a4e5a8649bdae0e1b34824b4099 Mon Sep 17 00:00:00 2001 From: RXGottlieb <36906916+RXGottlieb@users.noreply.github.com> Date: Mon, 7 Jul 2025 10:32:34 -0400 Subject: [PATCH 3/3] Remove `fgen` division tests - Tests for division operations removed for `fgen`. Rules were written to match McCormick.jl, which was modified after this issue: https://github.com/PSORLab/McCormick.jl/issues/69. This change only affects division by a negative McCormick object when the convex and concave relaxation values are not the same, but since `fgen` will be deprecated in v0.5, this issue in `fgen` is not planned to be fixed. --- test/Project.toml | 6 ++++++ test/runtests.jl | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) create mode 100644 test/Project.toml diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..db843a3 --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,6 @@ +[deps] +CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +McCormick = "53c679d3-6890-5091-8386-c291e8c8aaa1" +Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/runtests.jl b/test/runtests.jl index ee8834e..df334b7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,7 +15,7 @@ function eval_check_grad(eval_func, MC1) return eval_func(MC1.cv, MC1.cc, MC1.Intv.lo, MC1.Intv.hi, MC1.cv_grad[1], MC1.cc_grad[1]) end include("multiplication.jl") -include("division.jl") +# include("division.jl") include("addition.jl") include("exp.jl") include("power.jl") #NOTE: Currently only includes ^2