diff --git a/tket/src/metadata.rs b/tket/src/metadata.rs index b5bf44bbb..4f298b61a 100644 --- a/tket/src/metadata.rs +++ b/tket/src/metadata.rs @@ -12,6 +12,24 @@ impl Metadata for MaxQubits { type Type<'hugr> = u32; } +/// Metadata hinting the compiler that a function declaration should be inlined at its call sites. +/// +/// When a function is not annotated, we use an heuristic to determine whether to inline. +#[derive( + Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize, +)] +#[serde(rename_all = "lowercase")] +pub enum InlineHint { + /// Always inline the function, even if it is large or called from many places. + Always, + /// Never inline the function. + Never, +} +impl Metadata for InlineHint { + const KEY: &'static str = "tket.hint.inline"; + type Type<'hugr> = Self; +} + /// Metadata key for traced rewrites that were applied during circuit transformation. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct CircuitRewriteTraces; diff --git a/tket/src/passes/inline_funcs.rs b/tket/src/passes/inline_funcs.rs index 1de45d1bb..638f866e5 100644 --- a/tket/src/passes/inline_funcs.rs +++ b/tket/src/passes/inline_funcs.rs @@ -8,6 +8,7 @@ use petgraph::algo::tarjan_scc; use hugr_core::hugr::{hugrmut::HugrMut, patch::inline_call::InlineCall}; use hugr_core::module_graph::{ModuleGraph, StaticNode}; +use crate::metadata::InlineHint; use crate::passes::{ComposablePass, PassScope, WithScope}; /// Error raised by [inline_acyclic] @@ -54,10 +55,9 @@ impl Default for InlineFuncsHeuristic { pub struct InlineFunctionsPass { /// Heuristic for deciding which functions to inline. heuristic: InlineFuncsHeuristic, - /// Whether to follow compiler hints for inlining functions. - // - // Note that the inline hint metadata has not been defined yet, so this is currently unused. - // TODO: + /// Whether to follow compiler hints for inlining additional functions. + /// + /// See the [InlineHint] metadata entry. follow_inline_hints: bool, scope: PassScope, @@ -80,7 +80,10 @@ impl InlineFunctionsPass { self } - /// Sets whether to follow compiler hints for inlining functions. + /// Sets whether to follow compiler hints for inlining additional functions. + /// + /// Note that functions annotated with [InlineHint::Never] will never + /// be inlined, even if this option is disabled. pub fn follow_inline_hints(mut self, follow_inline_hints: bool) -> Self { self.follow_inline_hints = follow_inline_hints; self @@ -97,9 +100,13 @@ impl ComposablePass for InlineFunctionsPass { let Some(func) = h.static_source(call) else { return false; }; - *should_inline_cache - .entry(func) - .or_insert_with(|| self.heuristic.should_inline(func, h)) + *should_inline_cache.entry(func).or_insert_with(|| { + match h.get_metadata::(func) { + Some(InlineHint::Never) => false, + Some(InlineHint::Always) if self.follow_inline_hints => true, + _ => self.heuristic.should_inline(func, h), + } + }) }) } } @@ -366,7 +373,7 @@ mod test { #[case::size_zero(InlineFuncsHeuristic::MaxSize(0), vec!["f", "b"])] #[case::size_unlimited(InlineFuncsHeuristic::MaxSize(usize::MAX), vec!["f"])] #[case::all(InlineFuncsHeuristic::All, vec!["f"])] - fn inline_functions_pass_respects_max_inline_size( + fn inline_functions_pass_heuristic( #[case] heuristic: InlineFuncsHeuristic, #[case] g_targets: Vec<&'static str>, ) { @@ -387,4 +394,45 @@ mod test { HashSet::from_iter(g_targets), ); } + + #[rstest] + #[case::follow_hints(true, vec!["f", "c"])] + #[case::ignore_hints(false, vec!["f", "b"])] + fn inline_functions_pass_hints( + #[case] follow_hints: bool, + #[case] g_targets: Vec<&'static str>, + ) { + use hugr::hugr::hugrmut::HugrMut; + + use crate::metadata::InlineHint; + + let mut h = make_test_hugr(); + let b = find_func(&h, "b"); + let c = find_func(&h, "c"); + let f = find_func(&h, "f"); + // This should be inlined + h.set_metadata::(b, InlineHint::Always); + // This should never be inlined, even if `follow_hints` is false. + h.set_metadata::(c, InlineHint::Never); + // This should be ignored, as `f` is in a double-recursive loop with `g`. + h.set_metadata::(f, InlineHint::Always); + + run_validating( + InlineFunctionsPass::default() + .with_heuristic(InlineFuncsHeuristic::MaxSize(0)) + .follow_inline_hints(follow_hints), + &mut h, + ) + .unwrap(); + + let cg = ModuleGraph::new(&h); + let g = find_func(&h, "g"); + assert_eq!( + outgoing_calls(&cg, g) + .into_iter() + .map(|n| func_name(&h, n).as_str()) + .collect::>(), + HashSet::from_iter(g_targets), + ); + } }