diff --git a/tket-py/src/passes.rs b/tket-py/src/passes.rs index c69de7947..76ef77ddb 100644 --- a/tket-py/src/passes.rs +++ b/tket-py/src/passes.rs @@ -1,6 +1,7 @@ //! Passes for optimising circuits. pub mod chunks; +mod inline_always; mod inline_funcs; mod scope; pub mod tket1; @@ -27,12 +28,14 @@ pub fn module(py: Python<'_>) -> PyResult> { m.add_function(wrap_pyfunction!(greedy_depth_reduce, &m)?)?; m.add_function(wrap_pyfunction!(badger_optimise, &m)?)?; m.add_function(wrap_pyfunction!(normalize_guppy, &m)?)?; + m.add_function(wrap_pyfunction!(self::inline_always::inline_always, &m)?)?; m.add_function(wrap_pyfunction!(self::inline_funcs::inline_functions, &m)?)?; m.add_class::()?; m.add_function(wrap_pyfunction!(self::chunks::chunks, &m)?)?; m.add_function(wrap_pyfunction!(self::tket1::tket1_pass, &m)?)?; m.add_function(wrap_pyfunction!(resolve_modifiers, &m)?)?; m.add("PullForwardError", py.get_type::())?; + m.add("InlineAlwaysError", py.get_type::())?; m.add( "InlineFunctionsError", py.get_type::(), @@ -59,6 +62,12 @@ create_py_exception!( "Errors from the modifer resolver pass." ); +create_py_exception!( + tket::passes::inline_always::InlineAlwaysError, + PyInlineAlwaysError, + "Error from `InlineAlwaysPass`" +); + create_py_exception!( tket::passes::inline_funcs::InlineFuncsError, PyInlineFunctionsError, diff --git a/tket-py/src/passes/inline_always.rs b/tket-py/src/passes/inline_always.rs new file mode 100644 index 000000000..c2fe5a6c3 --- /dev/null +++ b/tket-py/src/passes/inline_always.rs @@ -0,0 +1,22 @@ +//! Python bindings for the `InlineFunctions` pass. + +use pyo3::prelude::*; + +use tket::passes::{ComposablePass, WithScope}; + +use super::PyPassScope; +use crate::state::CompilationState; +use crate::utils::ConvertPyErr; + +/// Inline functions marked with the `inline="always"` decorator below the selected scope. +#[pyfunction] +#[pyo3(signature = (circ, *, scope = None))] +pub(super) fn inline_always( + circ: &mut CompilationState, + scope: Option, +) -> PyResult<()> { + let py_scope = scope.unwrap_or_default(); + let pass = tket::passes::InlineAlwaysPass::default_with_scope(py_scope.scope); + pass.run(&mut circ.hugr).convert_pyerrs()?; + Ok(()) +} diff --git a/tket-py/test/test_pass.py b/tket-py/test/test_pass.py index 8f4bc3ae4..75b157d11 100644 --- a/tket-py/test/test_pass.py +++ b/tket-py/test/test_pass.py @@ -7,6 +7,8 @@ from tket.passes import ( _badger_optimise, _greedy_depth_reduce, + InlineAlwaysError, + InlineAlwaysPass, InlineFunctions, inline_funcs, NormalizeGuppy, @@ -23,6 +25,7 @@ from tket.passes import PytketHugrPass from pytket.passes import CliffordSimp, SquashRzPhasedX, SequencePass from hugr.build.base import Hugr +import hugr.tys as tys import numpy as np import pytest @@ -285,6 +288,52 @@ def test_modifier_execution() -> None: np.testing.assert_allclose(computed_statevector, expected_statevector) +@pytest.mark.parametrize("annotate", [True, False]) +def test_inline_always(annotate: bool) -> None: + import hugr.ops as ops + from hugr.build.dfg import Dfg + + d = Dfg(tys.Tuple(tys.Qubit, tys.Qubit)) + + f_id = d.module_root_builder().define_function( + "id", + [tys.Qubit], + ) + f_id.set_outputs(f_id.input_node[0]) + + if annotate: + f_id.metadata["tket.inline"] = "always" + + (tup,) = d.inputs() + (q1, q2) = d.add(ops.UnpackTuple()(tup)) + call1 = d.call(f_id, q1) + call2 = d.call(f_id, q2) + (tup,) = d.add(ops.MakeTuple()(call1, call2)) + + d.set_outputs(tup) + + # validate(d.hugr) + + InlineAlwaysPass()(d.hugr) + # validate(d.hugr) + assert _count_ops(d.hugr, "Call") == 0 if annotate else 2 + + +def test_inline_always_cycle() -> None: + from hugr.build.function import Module + + mod = Module() + + f_recursive = mod.define_function("recurse", [tys.Qubit]) + f_recursive.declare_outputs([tys.Qubit]) + call = f_recursive.call(f_recursive, f_recursive.input_node[0]) + f_recursive.set_outputs(call) + + f_recursive.metadata["tket.inline"] = "always" + with pytest.raises(InlineAlwaysError): + InlineAlwaysPass()(mod.hugr) + + def test_inline_functions() -> None: hugr = _hugr_from_path("test_files/guppy_examples/fn_calls.hugr") diff --git a/tket-py/tket/_tket/passes.pyi b/tket-py/tket/_tket/passes.pyi index 699a3beea..0636577bc 100644 --- a/tket-py/tket/_tket/passes.pyi +++ b/tket-py/tket/_tket/passes.pyi @@ -18,6 +18,9 @@ class CircuitChunks: class PullForwardError(Exception): """Error from a `PullForward` operation.""" +class InlineAlwaysError(Exception): + """Error from `InlineAlwaysPass`.""" + def normalize_guppy( circ: CompilationState, *, @@ -43,6 +46,13 @@ def normalize_guppy( - remove_redundant_order_edges: Whether to remove redundant order edges. """ +def inline_always( + circ: CompilationState, + *, + scope: PassScope = GlobalScope.PRESERVE_PUBLIC, +) -> None: + """Inline functions marked with the `inline="always"` decorator below the selected scope.""" + def inline_functions( circ: CompilationState, *, diff --git a/tket-py/tket/passes/__init__.py b/tket-py/tket/passes/__init__.py index e21f78e71..c674d7e01 100644 --- a/tket-py/tket/passes/__init__.py +++ b/tket-py/tket/passes/__init__.py @@ -12,7 +12,7 @@ from tket import _state from . import inline_funcs from .._tket import passes as _passes, optimiser as _optimiser - +from .._tket.passes import InlineAlwaysError from hugr.passes.composable import ( ComposablePass, ComposedPass, @@ -25,6 +25,8 @@ __all__ = [ "PytketHugrPass", "PassResult", + "InlineAlwaysError", + "InlineAlwaysPass", "InlineFuncsHeuristic", "InlineFunctions", "NormalizeGuppy", @@ -147,6 +149,39 @@ def _run_tk(self, program: _state.CompilationState) -> _state.CompilationState: return program +@dataclass +class InlineAlwaysPass(ComposablePass): + """Inline functions marked with the `inline="always"` decorator below the selected scope.""" + + _scope: PassScope = GlobalScope.PRESERVE_PUBLIC + + def run(self, hugr: Hugr, *, inplace: bool = True) -> PassResult: + return implement_pass_run( + self, + hugr=hugr, + inplace=inplace, + copy_call=lambda h: self._inline_always(h, inplace), + ) + + def with_scope(self, scope: PassScope) -> InlineAlwaysPass: + """Set the scope of this pass and return self.""" + self._scope = scope + return self + + def _inline_always(self, hugr: Hugr, inplace: bool) -> PassResult: + tk_program = _state.CompilationState.from_python(hugr) + + _passes.inline_always( + tk_program._inner, + scope=self._scope, + ) + + package = tk_program.to_python() + return PassResult.for_pass( + self, hugr=package.modules[0], inplace=inplace, result=None + ) + + @dataclass class InlineFunctions(ComposablePass): """Inline acyclic function calls below the selected scope. diff --git a/tket/src/metadata.rs b/tket/src/metadata.rs index c60036aca..637764ccc 100644 --- a/tket/src/metadata.rs +++ b/tket/src/metadata.rs @@ -12,6 +12,24 @@ impl Metadata for MaxQubits { type Type<'hugr> = u32; } +/// Metadata that may be supplied for a function to indicate +/// that/when calls to it should be inlined. +#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +#[non_exhaustive] +pub enum InlineAnnotation { + /// Always inline calls to this function. + /// + /// If this cannot be done, an error will be raised. + Always, +} + +impl Metadata for InlineAnnotation { + type Type<'hugr> = Self; + + const KEY: &'static str = "tket.inline"; +} + /// Metadata key for traced rewrites that were applied during circuit transformation. #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct CircuitRewriteTraces; diff --git a/tket/src/passes.rs b/tket/src/passes.rs index 45c10584d..69cc066b9 100644 --- a/tket/src/passes.rs +++ b/tket/src/passes.rs @@ -42,6 +42,8 @@ pub mod inline_dfgs; pub use inline_dfgs::InlineDFGsPass; // Inline function calls. +pub mod inline_always; +pub use inline_always::InlineAlwaysPass; pub mod inline_funcs; pub use inline_funcs::InlineFunctionsPass; #[expect(deprecated)] diff --git a/tket/src/passes/inline_always.rs b/tket/src/passes/inline_always.rs new file mode 100644 index 000000000..8a9d80df8 --- /dev/null +++ b/tket/src/passes/inline_always.rs @@ -0,0 +1,376 @@ +//! Pass to inline calls to functions, controlled by [InlineAnnotation] metadata. +use std::collections::{HashSet, VecDeque}; + +use crate::metadata::InlineAnnotation; +use crate::passes::{ComposablePass, PassScope, composable::WithScope}; +use hugr::hugr::patch::inline_call::InlineCall; +use hugr_core::module_graph::{ModuleGraph, StaticNode}; +use hugr_core::{Node, hugr::hugrmut::HugrMut}; + +use itertools::Itertools; +use petgraph::algo::tarjan_scc; +use petgraph::data::DataMap; +use petgraph::visit::{ + Dfs, IntoNeighbors, IntoNodeIdentifiers, NodeFiltered, NodeIndexable, Visitable, Walker, +}; + +/// Errors that may be raised by [`InlineAlwaysPass`] +#[derive(Clone, Debug, PartialEq, Eq, derive_more::Display)] +pub enum InlineAlwaysError { + /// Functions annotated with [InlineAnnotation::Always] form a cycle + /// so inlining would produce an infinitely-big program + #[display("Cycle detected in functions marked to Always inline: {_0:?}")] + AlwaysCycle(Vec), +} + +impl std::error::Error for InlineAlwaysError {} + +/// A [ComposablePass] that inlines `Call`s to functions +/// according to [InlineAnnotation]s. +#[derive(Default, Clone, Debug)] +pub struct InlineAlwaysPass { + scope: PassScope, +} + +impl WithScope for InlineAlwaysPass { + fn with_scope(self, scope: impl Into) -> Self { + Self { + scope: scope.into(), + } + } +} + +impl ComposablePass for InlineAlwaysPass { + type Error = InlineAlwaysError; + type Result = (); + + fn run(&self, hugr: &mut H) -> Result<(), InlineAlwaysError> { + let Some(root) = self.scope.root(hugr) else { + return Ok(()); // Nothing to do + }; + let cg = ModuleGraph::new(hugr); + let always_funcs = hugr.children(hugr.module_root()).filter(|n| { + hugr.get_optype(*n).is_func_defn() + && hugr.get_metadata::(*n) == Some(InlineAnnotation::Always) + }); + // We're going to object if there's a cycle of functions marked Always, as that would + // lead to an infinitely big Hugr. However, don't object unless such a cycle is reachable + // from the entrypoint... + let reachable_always: HashSet = match &self.scope { + PassScope::Global(_) => always_funcs.collect(), + PassScope::EntrypointFlat | PassScope::EntrypointRecursive => { + let reachable = Dfs::new(cg.graph(), cg.node_index(hugr.entrypoint()).unwrap()) + .iter(&cg.graph()) + .collect::>(); + always_funcs + .filter(|n| { + let ni = cg.node_index(*n).unwrap(); + reachable.contains(&ni) + }) + .collect() + } + }; + let always_cg = + NodeFiltered::from_fn(cg.graph(), |n| match cg.graph().node_weight(n).unwrap() { + StaticNode::FuncDefn(func) => reachable_always.contains(func), + _ => false, + }); + if let Some(cycle) = cycles(&always_cg).next() { + return Err(InlineAlwaysError::AlwaysCycle(cycle)); + } + // Proceed with inlining. Do outermost first within the scope root, as we cannot + // inline into functions that are outside the scope until they themselves are inlined + // beneath the root. + let mut parents = VecDeque::from([root]); + let mut seen = HashSet::new(); + while let Some(parent) = parents.pop_front() { + if hugr.get_optype(parent).is_func_defn() { + seen.insert(parent); + } + let mut to_inline = Vec::new(); + for child in hugr.children(parent) { + if hugr.first_child(child).is_some() { + parents.push_back(child); + } else if hugr.get_optype(child).is_call() + && let Some(func) = hugr.static_source(child) + && reachable_always.contains(&func) + { + to_inline.push((child, func)); + } + } + while let Some((call, func)) = to_inline.pop() { + // We've already checked the error conditions. + hugr.apply_patch(InlineCall::new(call)).unwrap(); + if !seen.contains(&func) { + // We have not inlined everything into `func` yet, + // so there may still be some work to do in the inlined copy. + parents.push_back(call); + } + } + } + // Remove the always-inlined functions themselves, as they are now unreachable. + let funcs_to_preserve = self.scope.preserve_interface(hugr).collect::>(); + if root == hugr.module_root() { + for func in reachable_always { + debug_assert!(hugr.static_targets(func).unwrap().next().is_none()); + if !funcs_to_preserve.contains(&func) { + hugr.remove_subtree(func); + } + } + } + Ok(()) + } +} + +fn cycles<'a, N: Copy>( + g: impl Copy + + Visitable + + DataMap> + + IntoNeighbors + + IntoNodeIdentifiers + + NodeIndexable + + 'a, +) -> impl Iterator> + 'a { + tarjan_scc(g) + .into_iter() + .filter(move |ns| { + ns.iter() + .exactly_one() + .ok() + .is_none_or(|n| // multi-node, or single-node cycle + g.neighbors(*n).contains(n)) + }) + .map(move |cycle| { + cycle + .into_iter() + .map(|n| match g.node_weight(n).unwrap() { + StaticNode::FuncDefn(fd) => *fd, + _ => panic!("Expected only FuncDefns in sccs"), + }) + .collect() + }) +} + +#[cfg(test)] +mod test { + use itertools::Itertools; + use rstest::rstest; + use std::collections::HashSet; + + use crate::TketOp; + use crate::passes::{ + ComposablePass, InlineDFGsPass, PassScope, RemoveDeadFuncsPass, WithScope, + composable::Preserve, + }; + + use hugr::builder::{ + Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + }; + use hugr::{ + HugrView, + extension::prelude::{qb_t, usize_t}, + hugr::hugrmut::HugrMut, + ops::handle::NodeHandle, + types::Signature, + }; + + use super::{InlineAlwaysError, InlineAlwaysPass, InlineAnnotation}; + + #[test] + fn test_single_cycle() { + let mut main = FunctionBuilder::new("main", Signature::new_endo([qb_t(), qb_t()])).unwrap(); + let mut mb = main.module_root_builder(); + let mut fb = mb + .define_function("self-call", Signature::new_endo([qb_t()])) + .unwrap(); + let c = fb + .call::(&fb.container_node().into(), &[], fb.input_wires()) + .unwrap(); + let fb = fb.finish_with_outputs(c.outputs()).unwrap(); + let inputs = main.input_wires(); + let mut hugr = main.finish_hugr_with_outputs(inputs).unwrap(); + hugr.set_metadata::(fb.node(), InlineAnnotation::Always); + let backup = hugr.clone(); + + // We error even though the function is not called + let e = InlineAlwaysPass::default().run(&mut hugr).unwrap_err(); + assert_eq!(e, InlineAlwaysError::AlwaysCycle(vec![fb.node()])); + assert_eq!(hugr, backup); + + RemoveDeadFuncsPass::default().run(&mut hugr).unwrap(); + assert_eq!( + hugr.children(hugr.module_root()).collect::>(), + [hugr.entrypoint()] + ); + let backup = hugr.clone(); + InlineAlwaysPass::default().run(&mut hugr).unwrap(); + assert_eq!(hugr, backup); + } + + #[test] + fn cycle() { + let mut main = FunctionBuilder::new("main", Signature::new_endo([usize_t()])).unwrap(); + let main_h = main.container_node().into(); + let mut mb = main.module_root_builder(); + let mut fb1 = mb + .define_function("f1", Signature::new_endo([usize_t()])) + .unwrap(); + let c1 = fb1.call::(&main_h, &[], fb1.input_wires()).unwrap(); + let fb1 = fb1.finish_with_outputs(c1.outputs()).unwrap(); + let c2 = main.call(fb1.handle(), &[], main.input_wires()).unwrap(); + let mut hugr = main.finish_hugr_with_outputs(c2.outputs()).unwrap(); + hugr.set_metadata::(hugr.entrypoint(), InlineAnnotation::Always); + InlineAlwaysPass::default().run(&mut hugr.clone()).unwrap(); // Ok + + hugr.set_metadata::(fb1.node(), InlineAnnotation::Always); + let e = InlineAlwaysPass::default().run(&mut hugr).unwrap_err(); + assert_eq!( + e, + InlineAlwaysError::AlwaysCycle(vec![fb1.node(), hugr.entrypoint()]) + ); + } + + #[rstest] + fn test_one_deep(#[values(1, 2, 5)] num_calls: usize) { + let mut main = + FunctionBuilder::new("main", Signature::new_endo([qb_t(), qb_t(), qb_t()])).unwrap(); + + let mut mb = main.module_root_builder(); + let swap = mb + .define_function("swap", Signature::new_endo([qb_t(), qb_t()])) + .unwrap(); + let [a, b] = swap.input_wires_arr(); + let swap = swap.finish_with_outputs([b, a]).unwrap(); + + let [mut a, mut b, c] = main.input_wires_arr(); + for _ in 0..num_calls { + [a, b] = main.call(swap.handle(), &[], [a, b]).unwrap().outputs_arr(); + } + let mut hugr = main.finish_hugr_with_outputs([a, b, c]).unwrap(); + hugr.set_metadata::(swap.node(), InlineAnnotation::Always); + + InlineAlwaysPass::default().run(&mut hugr).unwrap(); + hugr.validate().unwrap(); + + let swap_present = + hugr.contains_node(swap.node()) && hugr.get_optype(swap.node()).is_func_defn(); + assert!(!swap_present); + InlineDFGsPass::default().run(&mut hugr).unwrap(); + hugr.validate().unwrap(); + let [inp, outp] = hugr.get_io(hugr.entrypoint()).unwrap(); + assert_eq!( + HashSet::from_iter(hugr.input_neighbours(outp)), + HashSet::from([inp]) + ); + } + + #[rstest] + #[case(PassScope::EntrypointFlat)] + #[case(PassScope::EntrypointRecursive)] + #[case(Preserve::All)] + #[case(Preserve::Public)] + #[case(Preserve::Entrypoint)] + fn entrypoint_scope(#[case] ps: impl Into) { + let mut entry = FunctionBuilder::new("entry", Signature::new_endo([qb_t()])).unwrap(); + let mut mb = entry.module_root_builder(); + let mut cyclic = mb + .define_function("cyclic", Signature::new_endo([qb_t()])) + .unwrap(); + let c = cyclic + .call::(&cyclic.container_node().into(), &[], cyclic.input_wires()) + .unwrap(); + let cyclic = cyclic.finish_with_outputs(c.outputs()).unwrap(); + let mut other = mb + .define_function("other", Signature::new_endo([qb_t()])) + .unwrap(); + let c = other + .call(cyclic.handle(), &[], other.input_wires()) + .unwrap(); + other.finish_with_outputs(c.outputs()).unwrap(); + + let id = mb + .define_function("id", Signature::new_endo([qb_t()])) + .unwrap(); + let inps = id.input_wires(); + let id = id.finish_with_outputs(inps).unwrap(); + let c = entry + .call::(id.handle(), &[], entry.input_wires()) + .unwrap(); + let mut h = entry.finish_hugr_with_outputs(c.outputs()).unwrap(); + assert_eq!(h.static_targets(cyclic.node()).unwrap().count(), 2); // cyclic and entry + h.set_metadata::(cyclic.node(), InlineAnnotation::Always); + h.set_metadata::(id.node(), InlineAnnotation::Always); + let ps = ps.into(); + let e = InlineAlwaysPass::default_with_scope(ps.clone()).run(&mut h); + if let PassScope::EntrypointFlat | PassScope::EntrypointRecursive = ps { + assert_eq!(e, Ok(())); + assert_eq!(h.static_targets(cyclic.node()).unwrap().count(), 2); // cyclic and entry + assert_eq!(h.static_targets(id.node()).unwrap().collect_vec(), []); // No calls, but can't be removed as outside scope + InlineDFGsPass::default_with_scope(ps).run(&mut h).unwrap(); + let [inp, out] = h.get_io(h.entrypoint()).unwrap(); + assert_eq!(h.output_neighbours(inp).collect_vec(), [out]); + } else { + assert_eq!(e, Err(InlineAlwaysError::AlwaysCycle(vec![cyclic.node()]))); + }; + } + + #[test] + fn cycle_part_always() { + let mut main = FunctionBuilder::new("main", Signature::new_endo([qb_t()])).unwrap(); + let main_h = main.container_node().into(); + let mut mb = main.module_root_builder(); + let mut f = mb + .define_function("f", Signature::new_endo([qb_t()])) + .unwrap(); + let hada = f.add_dataflow_op(TketOp::H, f.input_wires()).unwrap(); + let c = f.call::(&main_h, &[], hada.outputs()).unwrap(); + let f = f.finish_with_outputs(c.outputs()).unwrap(); + let c = main.call(f.handle(), &[], main.input_wires()).unwrap(); + let backup = main.finish_hugr_with_outputs(c.outputs()).unwrap(); + + // 1. Mark private callee f as Always, so it can be inlined into main and removed + let mut hugr = backup.clone(); + hugr.set_metadata::(f.node(), InlineAnnotation::Always); + InlineAlwaysPass::default().run(&mut hugr).unwrap(); + assert_eq!( + hugr.children(hugr.module_root()).collect_vec(), + [main_h.node()] + ); + InlineDFGsPass::default().run(&mut hugr).unwrap(); + hugr.validate().unwrap(); + let [_in, _out, h, call] = hugr.children(main_h.node()).collect_array().unwrap(); + assert_eq!( + hugr.get_optype(h).as_extension_op(), + Some(&TketOp::H.into_extension_op()) + ); + assert!(hugr.get_optype(call).is_call()); + assert_eq!(hugr.static_source(call), Some(main_h.node())); + + // 2. Mark entrypoint function main as Always; it can be inlined into f, but not removed. + let mut hugr = backup.clone(); + hugr.set_metadata::(main_h.node(), InlineAnnotation::Always); + InlineAlwaysPass::default().run(&mut hugr).unwrap(); + // No functions can be removed + assert_eq!( + hugr.children(hugr.module_root()).collect_vec(), + backup.children(backup.module_root()).collect_vec() + ); + for n in hugr.children(hugr.module_root()) { + assert_eq!(hugr.get_optype(n), backup.get_optype(n)); + } + + // main inlined into f (inside DFG): + let [_in, _out, h, dfg] = hugr.children(f.node()).collect_array().unwrap(); + assert_eq!( + hugr.get_optype(h).as_extension_op(), + Some(&TketOp::H.into_extension_op()) + ); + assert!(hugr.get_optype(dfg).is_dfg()); + let [_in, _out, call] = hugr.children(dfg).collect_array().unwrap(); + assert!(hugr.get_optype(call).is_call()); + assert_eq!(hugr.static_source(call), Some(f.node())); + + // No calls to main: + assert_eq!(hugr.static_targets(main_h.node()).unwrap().next(), None); + } +}