diff --git a/diffsl/src/ast/mod.rs b/diffsl/src/ast/mod.rs index e369b4f..d1c7149 100644 --- a/diffsl/src/ast/mod.rs +++ b/diffsl/src/ast/mod.rs @@ -815,10 +815,13 @@ impl<'a> Ast<'a> { AstKind::Name(Name { name: found_name, indices, - indice: _, + indice, is_tangent: _, }) => { deps.insert((*found_name, indices.clone())); + if let Some(indice) = indice { + indice.collect_deps(deps); + } } AstKind::NamedGradient(gradient) => { gradient.gradient_of.collect_deps(deps); @@ -837,6 +840,9 @@ impl<'a> Ast<'a> { } } AstKind::TensorElmt(elmt) => { + if let Some(indices) = &elmt.indices { + indices.collect_deps(deps); + } elmt.expr.collect_deps(deps); } AstKind::DsModel(_m) => (), @@ -850,8 +856,17 @@ impl<'a> Ast<'a> { AstKind::Domain(_) => (), AstKind::IntRange(_) => (), AstKind::Assignment(_) => (), - AstKind::Vector(_) => (), - AstKind::Indice(_) => (), + AstKind::Vector(vector) => { + for item in &vector.data { + item.collect_deps(deps); + } + } + AstKind::Indice(indice) => { + indice.first.collect_deps(deps); + if let Some(last) = &indice.last { + last.collect_deps(deps); + } + } } } @@ -935,6 +950,48 @@ impl<'a> Ast<'a> { } } +#[cfg(test)] +mod tests { + use super::{Ast, AstKind, Name}; + + #[test] + fn test_get_dependents_includes_name_index_dependencies() { + let expr = Ast { + kind: AstKind::Name(Name { + name: "pace", + indices: vec!['i'], + indice: Some(Box::new(Ast { + kind: AstKind::new_indice( + Ast { + kind: AstKind::new_binop( + '%', + Ast { + kind: AstKind::new_name("N"), + span: None, + }, + Ast { + kind: AstKind::new_integer(2), + span: None, + }, + ), + span: None, + }, + None, + None, + ), + span: None, + })), + is_tangent: false, + }), + span: None, + }; + + let deps = expr.get_dependents(); + assert!(deps.contains("pace")); + assert!(deps.contains("N")); + } +} + impl fmt::Display for Ast<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match &self.kind { diff --git a/diffsl/src/discretise/discrete_model.rs b/diffsl/src/discretise/discrete_model.rs index d71565c..fab1b09 100644 --- a/diffsl/src/discretise/discrete_model.rs +++ b/diffsl/src/discretise/discrete_model.rs @@ -1436,7 +1436,7 @@ mod tests { let model = parse_ds_string(model_text.as_str()).unwrap(); match DiscreteModel::build("$name", &model) { Ok(model) => { - let tensor = model.constant_defns().iter().chain(model.time_dep_defns.iter()).find(|t| t.name() == $tensor_name).unwrap(); + let tensor = model.constant_defns().iter().chain(model.input_dep_defns().iter()).chain(model.time_dep_defns().iter()).find(|t| t.name() == $tensor_name).unwrap(); let tensor_string = format!("{}", tensor).chars().filter(|c| !c.is_whitespace()).collect::(); let tensor_check_string = $tensor_string.chars().filter(|c| !c.is_whitespace()).collect::(); assert_eq!(tensor_string, tensor_check_string); diff --git a/diffsl/src/execution/compiler.rs b/diffsl/src/execution/compiler.rs index f7a2c5f..fab7d8b 100644 --- a/diffsl/src/execution/compiler.rs +++ b/diffsl/src/execution/compiler.rs @@ -1993,6 +1993,82 @@ mod tests { assert_relative_eq!(rr1[1], T::from_f64(19.0).unwrap()); } + #[cfg(feature = "llvm")] + #[test] + fn test_model_indexed_scalar_tensor_does_not_panic_llvm() { + let full_text = " + pace_i { 0, 10 } + rate { pace_i[N % 2] } + u { x = 1 } + F { rate - x } + "; + let model = parse_ds_string(full_text).unwrap(); + let rate_expr = model + .tensors + .iter() + .find_map(|tensor| { + let tensor = tensor.kind.as_array()?; + if tensor.name() == "rate" { + Some( + tensor.elmts()[0] + .kind + .as_tensor_elmt() + .unwrap() + .expr + .as_ref(), + ) + } else { + None + } + }) + .unwrap(); + assert!(rate_expr.get_dependents().contains("N")); + + let discrete_model = + DiscreteModel::build("test_model_indexed_scalar_tensor", &model).unwrap(); + + assert_eq!( + discrete_model + .constant_defns() + .iter() + .map(|tensor| tensor.name()) + .collect::>(), + ["pace"] + ); + assert_eq!( + discrete_model + .input_dep_defns() + .iter() + .map(|tensor| tensor.name()) + .collect::>(), + ["rate"] + ); + + let compiler = Compiler::::from_discrete_model( + &discrete_model, + Default::default(), + Some(full_text), + ) + .unwrap(); + + let mut u0 = vec![0.0; 1]; + let mut rr0 = vec![0.0; 1]; + let mut rr1 = vec![0.0; 1]; + let mut data = compiler.get_new_data(); + + compiler.set_inputs(&[], data.as_mut_slice(), 0); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.rhs(0.0, u0.as_slice(), data.as_mut_slice(), rr0.as_mut_slice()); + + compiler.set_inputs(&[], data.as_mut_slice(), 1); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.rhs(0.0, u0.as_slice(), data.as_mut_slice(), rr1.as_mut_slice()); + + assert_relative_eq!(u0.as_slice(), vec![1.0].as_slice()); + assert_relative_eq!(rr0.as_slice(), vec![-1.0].as_slice()); + assert_relative_eq!(rr1.as_slice(), vec![9.0].as_slice()); + } + #[allow(dead_code)] fn test_out_depends_on_internal_tensor< M: CodegenModuleCompile + CodegenModuleJit,