diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index dafc3b5f..780019e2 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -39,13 +39,13 @@ jobs: julia-version: '1.10' steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: "Set up Julia" - uses: julia-actions/setup-julia@v1 + uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.julia-version }} - name: "Cache artifacts" - uses: julia-actions/cache@v1 + uses: julia-actions/cache@v2 - name: "Build package" uses: julia-actions/julia-buildpkg@v1 - name: "Run tests" @@ -75,11 +75,11 @@ jobs: # - "enzyme" # flaky; seems to infinitely compile and fail the CI - "jet" steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 with: version: ${{ matrix.julia-version }} - - uses: julia-actions/cache@v1 + - uses: julia-actions/cache@v2 - uses: julia-actions/julia-buildpkg@v1 - name: Run tests id: run-tests diff --git a/src/Parse.jl b/src/Parse.jl index d1aac2b6..e6a3805a 100644 --- a/src/Parse.jl +++ b/src/Parse.jl @@ -242,6 +242,54 @@ end end end +@unstable parse_expression(ex::String; kws...) = parse_expression(Meta.parse(ex); kws...) + +""" +Find an operator function by its name in the OperatorEnum, considering the arity. +Throws appropriate errors for ambiguous or missing matches. +""" +@unstable function _find_operator_by_name(func_symbol, degree, operators) + matches = Tuple{Function,Int}[] + + for arity in 1:length(operators.ops) + for op in operators.ops[arity] + if nameof(op) == func_symbol + push!(matches, (op, arity)) + end + end + end + + if isempty(matches) + throw( + ArgumentError( + "Tried to interpolate function `$(func_symbol)` but failed. " * + "Function not found in operators.", + ), + ) + end + + arity_matches = filter(m -> m[2] == degree, matches) + + if length(arity_matches) > 1 + throw( + ArgumentError( + "Ambiguous operator `$(func_symbol)` with arity $(degree). " * + "Multiple matches found: $(arity_matches)", + ), + ) + elseif length(arity_matches) == 0 + available_arities = [m[2] for m in matches] + throw( + ArgumentError( + "Operator `$(func_symbol)` found but not with arity $(degree). " * + "Available arities: $(available_arities)", + ), + ) + end + + return arity_matches[1][1]::Function +end + """An empty module for evaluation without collisions.""" module EmptyModule end @@ -264,9 +312,9 @@ module EmptyModule end func = try Core.eval(EmptyModule, first(ex.args)) catch - throw( - ArgumentError("Tried to interpolate function `$(first(ex.args))` but failed."), - ) + # Try to find the function in operators by name + degree = length(args) - 1 + _find_operator_by_name(first(ex.args), degree, operators) end::Function return _parse_expression( func, args, operators, variable_names, N, E, evaluate_on; kws... diff --git a/test/test_parse.jl b/test/test_parse.jl index f801a740..b57d953a 100644 --- a/test/test_parse.jl +++ b/test/test_parse.jl @@ -340,3 +340,57 @@ end ) @test string_tree(ex) == "x" end + +@testitem "custom operators without passing function object" begin + using DynamicExpressions + + custom_mul(x, y) = x * y + custom_cos(x) = cos(x) + custom_max(x, y, z) = max(x, y, z) + + operators = OperatorEnum( + 1 => [custom_cos], 2 => [+, -, *, /, custom_mul], 3 => [custom_max] + ) + + # Test nested custom operators + ex = parse_expression( + "custom_max(custom_cos(x1), custom_mul(x2, x3), x2 + 1.5)"; + operators=operators, + node_type=Node{T,3} where {T}, + variable_names=["x1", "x2", "x3"], + ) + @test typeof(ex) <: Expression{Float64} + @test string_tree(ex) == "custom_max(custom_cos(x1), custom_mul(x2, x3), x2 + 1.5)" + @test ex([1.0 2.0 3.0]') == [6.0] + + # Test error cases for _find_operator_by_name + @test_throws( + ArgumentError( + "Tried to interpolate function `unknown_func` but failed. Function not found in operators.", + ), + parse_expression("unknown_func(x1)", operators=operators, variable_names=["x1"]) + ) + + # Test ambiguous operator - same name from different modules + module TestMod1 + foo(x) = x + 1 + end + module TestMod2 + foo(x) = x - 1 + end + same_name_ops = OperatorEnum(1 => [TestMod1.foo, TestMod2.foo]) + @test_throws( + r"Ambiguous operator `foo` with arity 1\. Multiple matches found: Tuple\{Function, Int64\}\[.*foo.*1.*foo.*1.*\]", + parse_expression("foo(x1)", operators=same_name_ops, variable_names=["x1"]) + ) + + # Test wrong arity + @test_throws( + ArgumentError( + "Operator `custom_cos` found but not with arity 2. Available arities: [1]" + ), + parse_expression( + "custom_cos(x1, x2)", operators=operators, variable_names=["x1", "x2"] + ) + ) +end