Skip to content

Commit 57a4587

Browse files
Merge pull request #2514 from AayushSabharwal/as/big-bad-bug-fix
fix: fix variable_index
2 parents 26960c8 + 46af08d commit 57a4587

9 files changed

+186
-143
lines changed

src/systems/abstractsystem.jl

+42-53
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
165165
if !iscomplete(sys)
166166
error("A completed system is required. Call `complete` or `structural_simplify` on the system.")
167167
end
168-
p = reorder_parameters(sys, ps)
168+
p = reorder_parameters(sys, unwrap.(ps))
169169
isscalar = !(exprs isa AbstractArray)
170170
if wrap_code === nothing
171171
wrap_code = isscalar ? identity : (identity, identity)
@@ -347,25 +347,23 @@ end
347347

348348
#Treat the result as a vector of symbols always
349349
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
350-
if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num
351-
return unwrap(sym) in 1:length(variable_symbols(sys))
350+
sym = unwrap(sym)
351+
if sym isa Int # [x, 1] coerces 1 to a Num
352+
return sym in 1:length(variable_symbols(sys))
352353
end
353-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
354-
ic = get_index_cache(sys)
355-
h = getsymbolhash(sym)
356-
return haskey(ic.unknown_idx, h) ||
357-
haskey(ic.unknown_idx, getsymbolhash(default_toterm(sym))) ||
358-
(istree(sym) && operation(sym) === getindex &&
359-
is_variable(sys, first(arguments(sym))))
360-
else
361-
return any(isequal(sym), variable_symbols(sys)) ||
362-
hasname(sym) && is_variable(sys, getname(sym))
354+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
355+
return is_variable(ic, sym) ||
356+
istree(sym) && operation(sym) === getindex &&
357+
is_variable(ic, first(arguments(sym)))
363358
end
359+
return any(isequal(sym), variable_symbols(sys)) ||
360+
hasname(sym) && is_variable(sys, getname(sym))
364361
end
365362

366363
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
364+
sym = unwrap(sym)
367365
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
368-
return haskey(ic.unknown_idx, hash(sym))
366+
return is_variable(ic, sym)
369367
end
370368
return any(isequal(sym), getname.(variable_symbols(sys))) ||
371369
count('', string(sym)) == 1 &&
@@ -374,21 +372,19 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
374372
end
375373

376374
function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
377-
if unwrap(sym) isa Int
378-
return unwrap(sym)
375+
sym = unwrap(sym)
376+
if sym isa Int
377+
return sym
379378
end
380-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
381-
ic = get_index_cache(sys)
382-
h = getsymbolhash(sym)
383-
haskey(ic.unknown_idx, h) && return ic.unknown_idx[h]
384-
385-
h = getsymbolhash(default_toterm(sym))
386-
haskey(ic.unknown_idx, h) && return ic.unknown_idx[h]
387-
sym = unwrap(sym)
388-
istree(sym) && operation(sym) === getindex || return nothing
389-
idx = variable_index(sys, first(arguments(sym)))
390-
idx === nothing && return nothing
391-
return idx[arguments(sym)[(begin + 1):end]...]
379+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
380+
return if (idx = variable_index(ic, sym)) !== nothing
381+
idx
382+
elseif istree(sym) && operation(sym) === getindex &&
383+
(idx = variable_index(ic, first(arguments(sym)))) !== nothing
384+
idx[arguments(sym)[(begin + 1):end]...]
385+
else
386+
nothing
387+
end
392388
end
393389
idx = findfirst(isequal(sym), variable_symbols(sys))
394390
if idx === nothing && hasname(sym)
@@ -399,7 +395,7 @@ end
399395

400396
function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym::Symbol)
401397
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
402-
return get(ic.unknown_idx, h, nothing)
398+
return variable_index(ic, sym)
403399
end
404400
idx = findfirst(isequal(sym), getname.(variable_symbols(sys)))
405401
if idx !== nothing
@@ -418,30 +414,22 @@ function SymbolicIndexingInterface.variable_symbols(sys::AbstractSystem)
418414
end
419415

420416
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
417+
sym = unwrap(sym)
418+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
419+
return is_parameter(ic, sym) ||
420+
istree(sym) && operation(sym) === getindex &&
421+
is_parameter(ic, first(arguments(sym)))
422+
end
421423
if unwrap(sym) isa Int
422424
return unwrap(sym) in 1:length(parameter_symbols(sys))
423425
end
424-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
425-
ic = get_index_cache(sys)
426-
h = getsymbolhash(sym)
427-
return if haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
428-
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
429-
haskey(ic.nonnumeric_idx, h)
430-
true
431-
else
432-
h = getsymbolhash(default_toterm(sym))
433-
haskey(ic.param_idx, h) || haskey(ic.discrete_idx, h) ||
434-
haskey(ic.constant_idx, h) || haskey(ic.dependent_idx, h) ||
435-
haskey(ic.nonnumeric_idx, h)
436-
end
437-
end
438426
return any(isequal(sym), parameter_symbols(sys)) ||
439427
hasname(sym) && is_parameter(sys, getname(sym))
440428
end
441429

442430
function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
443431
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
444-
return ParameterIndex(ic, sym) !== nothing
432+
return is_parameter(ic, sym)
445433
end
446434
return any(isequal(sym), getname.(parameter_symbols(sys))) ||
447435
count('', string(sym)) == 1 &&
@@ -450,20 +438,21 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol
450438
end
451439

452440
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
453-
if unwrap(sym) isa Int
454-
return unwrap(sym)
455-
end
456-
if has_index_cache(sys) && get_index_cache(sys) !== nothing
457-
ic = get_index_cache(sys)
458-
return if (idx = ParameterIndex(ic, sym)) !== nothing
459-
idx
460-
elseif (idx = ParameterIndex(ic, default_toterm(sym))) !== nothing
441+
sym = unwrap(sym)
442+
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
443+
return if (idx = parameter_index(ic, sym)) !== nothing
461444
idx
445+
elseif istree(sym) && operation(sym) === getindex &&
446+
(idx = parameter_index(ic, first(arguments(sym)))) !== nothing
447+
ParameterIndex(idx.portion, (idx.idx..., arguments(sym)[(begin + 1):end]...))
462448
else
463449
nothing
464450
end
465451
end
466452

453+
if sym isa Int
454+
return sym
455+
end
467456
idx = findfirst(isequal(sym), parameter_symbols(sys))
468457
if idx === nothing && hasname(sym)
469458
idx = parameter_index(sys, getname(sym))
@@ -473,7 +462,7 @@ end
473462

474463
function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
475464
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
476-
return ParameterIndex(ic, sym)
465+
return parameter_index(ic, sym)
477466
end
478467
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
479468
if idx !== nothing

src/systems/diffeqs/abstractodesystem.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -786,7 +786,7 @@ function get_u0_p(sys,
786786
allobs = Set(getproperty.(observed(sys), :lhs))
787787
if any(in(allobs), keys(u0map))
788788
u0s_in_obs = filter(in(allobs), keys(u0map))
789-
@warn "Observed variables cannot assigned initial values. Initial values for $u0s_in_obs will be ignored."
789+
@warn "Observed variables cannot be assigned initial values. Initial values for $u0s_in_obs will be ignored."
790790
end
791791
end
792792
defs = mergedefaults(defs, u0map, dvs)

0 commit comments

Comments
 (0)