Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tket/src/metadata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,24 @@ impl Metadata for MaxQubits {
type Type<'hugr> = u32;
}

/// Metadata hinting the compiler that a function declaration should be inlined at its call sites.
///
/// When a function is not annotated, we use an heuristic to determine whether to inline.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, serde::Serialize, serde::Deserialize,
)]
#[serde(rename_all = "lowercase")]
pub enum InlineHint {
/// Always inline the function, even if it is large or called from many places.
Always,
/// Never inline the function.
Never,
}
impl Metadata for InlineHint {
const KEY: &'static str = "tket.hint.inline";
type Type<'hugr> = Self;
}

/// Metadata key for traced rewrites that were applied during circuit transformation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct CircuitRewriteTraces;
Expand Down
66 changes: 57 additions & 9 deletions tket/src/passes/inline_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use petgraph::algo::tarjan_scc;
use hugr_core::hugr::{hugrmut::HugrMut, patch::inline_call::InlineCall};
use hugr_core::module_graph::{ModuleGraph, StaticNode};

use crate::metadata::InlineHint;
use crate::passes::{ComposablePass, PassScope, WithScope};

/// Error raised by [inline_acyclic]
Expand Down Expand Up @@ -54,10 +55,9 @@ impl Default for InlineFuncsHeuristic {
pub struct InlineFunctionsPass {
/// Heuristic for deciding which functions to inline.
heuristic: InlineFuncsHeuristic,
/// Whether to follow compiler hints for inlining functions.
//
// Note that the inline hint metadata has not been defined yet, so this is currently unused.
// TODO: <https://github.com/Quantinuum/hugr/issues/2328>
/// Whether to follow compiler hints for inlining additional functions.
///
/// See the [InlineHint] metadata entry.
follow_inline_hints: bool,

scope: PassScope,
Expand All @@ -80,7 +80,10 @@ impl InlineFunctionsPass {
self
}

/// Sets whether to follow compiler hints for inlining functions.
/// Sets whether to follow compiler hints for inlining additional functions.
///
/// Note that functions annotated with [InlineHint::Never] will never
/// be inlined, even if this option is disabled.
pub fn follow_inline_hints(mut self, follow_inline_hints: bool) -> Self {
self.follow_inline_hints = follow_inline_hints;
self
Expand All @@ -97,9 +100,13 @@ impl<H: HugrMut> ComposablePass<H> for InlineFunctionsPass {
let Some(func) = h.static_source(call) else {
return false;
};
*should_inline_cache
.entry(func)
.or_insert_with(|| self.heuristic.should_inline(func, h))
*should_inline_cache.entry(func).or_insert_with(|| {
match h.get_metadata::<InlineHint>(func) {
Some(InlineHint::Never) => false,
Some(InlineHint::Always) if self.follow_inline_hints => true,
_ => self.heuristic.should_inline(func, h),
}
})
})
}
}
Expand Down Expand Up @@ -366,7 +373,7 @@ mod test {
#[case::size_zero(InlineFuncsHeuristic::MaxSize(0), vec!["f", "b"])]
#[case::size_unlimited(InlineFuncsHeuristic::MaxSize(usize::MAX), vec!["f"])]
#[case::all(InlineFuncsHeuristic::All, vec!["f"])]
fn inline_functions_pass_respects_max_inline_size(
fn inline_functions_pass_heuristic(
#[case] heuristic: InlineFuncsHeuristic,
#[case] g_targets: Vec<&'static str>,
) {
Expand All @@ -387,4 +394,45 @@ mod test {
HashSet::from_iter(g_targets),
);
}

#[rstest]
#[case::follow_hints(true, vec!["f", "c"])]
#[case::ignore_hints(false, vec!["f", "b"])]
fn inline_functions_pass_hints(
#[case] follow_hints: bool,
#[case] g_targets: Vec<&'static str>,
) {
use hugr::hugr::hugrmut::HugrMut;

use crate::metadata::InlineHint;

let mut h = make_test_hugr();
let b = find_func(&h, "b");
let c = find_func(&h, "c");
let f = find_func(&h, "f");
// This should be inlined
h.set_metadata::<InlineHint>(b, InlineHint::Always);
// This should never be inlined, even if `follow_hints` is false.
h.set_metadata::<InlineHint>(c, InlineHint::Never);
// This should be ignored, as `f` is in a double-recursive loop with `g`.
h.set_metadata::<InlineHint>(f, InlineHint::Always);

run_validating(
InlineFunctionsPass::default()
.with_heuristic(InlineFuncsHeuristic::MaxSize(0))
.follow_inline_hints(follow_hints),
&mut h,
)
.unwrap();

let cg = ModuleGraph::new(&h);
let g = find_func(&h, "g");
assert_eq!(
outgoing_calls(&cg, g)
.into_iter()
.map(|n| func_name(&h, n).as_str())
.collect::<HashSet<_>>(),
HashSet::from_iter(g_targets),
);
}
}
Loading