Skip to content

Commit 1a1f2db

Browse files
authored
Merge pull request #137 from SymbolicML/parse-without-interpolation
Parse expressions without interpolation
2 parents e46ee58 + e79b95c commit 1a1f2db

File tree

3 files changed

+111
-9
lines changed

3 files changed

+111
-9
lines changed

.github/workflows/CI.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,13 @@ jobs:
3939
julia-version: '1.10'
4040

4141
steps:
42-
- uses: actions/checkout@v2
42+
- uses: actions/checkout@v4
4343
- name: "Set up Julia"
44-
uses: julia-actions/setup-julia@v1
44+
uses: julia-actions/setup-julia@v2
4545
with:
4646
version: ${{ matrix.julia-version }}
4747
- name: "Cache artifacts"
48-
uses: julia-actions/cache@v1
48+
uses: julia-actions/cache@v2
4949
- name: "Build package"
5050
uses: julia-actions/julia-buildpkg@v1
5151
- name: "Run tests"
@@ -75,11 +75,11 @@ jobs:
7575
# - "enzyme" # flaky; seems to infinitely compile and fail the CI
7676
- "jet"
7777
steps:
78-
- uses: actions/checkout@v2
79-
- uses: julia-actions/setup-julia@v1
78+
- uses: actions/checkout@v4
79+
- uses: julia-actions/setup-julia@v2
8080
with:
8181
version: ${{ matrix.julia-version }}
82-
- uses: julia-actions/cache@v1
82+
- uses: julia-actions/cache@v2
8383
- uses: julia-actions/julia-buildpkg@v1
8484
- name: Run tests
8585
id: run-tests

src/Parse.jl

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,54 @@ end
242242
end
243243
end
244244

245+
@unstable parse_expression(ex::String; kws...) = parse_expression(Meta.parse(ex); kws...)
246+
247+
"""
248+
Find an operator function by its name in the OperatorEnum, considering the arity.
249+
Throws appropriate errors for ambiguous or missing matches.
250+
"""
251+
@unstable function _find_operator_by_name(func_symbol, degree, operators)
252+
matches = Tuple{Function,Int}[]
253+
254+
for arity in 1:length(operators.ops)
255+
for op in operators.ops[arity]
256+
if nameof(op) == func_symbol
257+
push!(matches, (op, arity))
258+
end
259+
end
260+
end
261+
262+
if isempty(matches)
263+
throw(
264+
ArgumentError(
265+
"Tried to interpolate function `$(func_symbol)` but failed. " *
266+
"Function not found in operators.",
267+
),
268+
)
269+
end
270+
271+
arity_matches = filter(m -> m[2] == degree, matches)
272+
273+
if length(arity_matches) > 1
274+
throw(
275+
ArgumentError(
276+
"Ambiguous operator `$(func_symbol)` with arity $(degree). " *
277+
"Multiple matches found: $(arity_matches)",
278+
),
279+
)
280+
elseif length(arity_matches) == 0
281+
available_arities = [m[2] for m in matches]
282+
throw(
283+
ArgumentError(
284+
"Operator `$(func_symbol)` found but not with arity $(degree). " *
285+
"Available arities: $(available_arities)",
286+
),
287+
)
288+
end
289+
290+
return arity_matches[1][1]::Function
291+
end
292+
245293
"""An empty module for evaluation without collisions."""
246294
module EmptyModule end
247295

@@ -264,9 +312,9 @@ module EmptyModule end
264312
func = try
265313
Core.eval(EmptyModule, first(ex.args))
266314
catch
267-
throw(
268-
ArgumentError("Tried to interpolate function `$(first(ex.args))` but failed."),
269-
)
315+
# Try to find the function in operators by name
316+
degree = length(args) - 1
317+
_find_operator_by_name(first(ex.args), degree, operators)
270318
end::Function
271319
return _parse_expression(
272320
func, args, operators, variable_names, N, E, evaluate_on; kws...

test/test_parse.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,57 @@ end
340340
)
341341
@test string_tree(ex) == "x"
342342
end
343+
344+
@testitem "custom operators without passing function object" begin
345+
using DynamicExpressions
346+
347+
custom_mul(x, y) = x * y
348+
custom_cos(x) = cos(x)
349+
custom_max(x, y, z) = max(x, y, z)
350+
351+
operators = OperatorEnum(
352+
1 => [custom_cos], 2 => [+, -, *, /, custom_mul], 3 => [custom_max]
353+
)
354+
355+
# Test nested custom operators
356+
ex = parse_expression(
357+
"custom_max(custom_cos(x1), custom_mul(x2, x3), x2 + 1.5)";
358+
operators=operators,
359+
node_type=Node{T,3} where {T},
360+
variable_names=["x1", "x2", "x3"],
361+
)
362+
@test typeof(ex) <: Expression{Float64}
363+
@test string_tree(ex) == "custom_max(custom_cos(x1), custom_mul(x2, x3), x2 + 1.5)"
364+
@test ex([1.0 2.0 3.0]') == [6.0]
365+
366+
# Test error cases for _find_operator_by_name
367+
@test_throws(
368+
ArgumentError(
369+
"Tried to interpolate function `unknown_func` but failed. Function not found in operators.",
370+
),
371+
parse_expression("unknown_func(x1)", operators=operators, variable_names=["x1"])
372+
)
373+
374+
# Test ambiguous operator - same name from different modules
375+
module TestMod1
376+
foo(x) = x + 1
377+
end
378+
module TestMod2
379+
foo(x) = x - 1
380+
end
381+
same_name_ops = OperatorEnum(1 => [TestMod1.foo, TestMod2.foo])
382+
@test_throws(
383+
r"Ambiguous operator `foo` with arity 1\. Multiple matches found: Tuple\{Function, Int64\}\[.*foo.*1.*foo.*1.*\]",
384+
parse_expression("foo(x1)", operators=same_name_ops, variable_names=["x1"])
385+
)
386+
387+
# Test wrong arity
388+
@test_throws(
389+
ArgumentError(
390+
"Operator `custom_cos` found but not with arity 2. Available arities: [1]"
391+
),
392+
parse_expression(
393+
"custom_cos(x1, x2)", operators=operators, variable_names=["x1", "x2"]
394+
)
395+
)
396+
end

0 commit comments

Comments
 (0)