diff --git a/test_files/guppy_examples/modifiers.hugr b/test_files/guppy_examples/modifiers.hugr index 15487e87f..3184b437d 100644 Binary files a/test_files/guppy_examples/modifiers.hugr and b/test_files/guppy_examples/modifiers.hugr differ diff --git a/test_files/guppy_examples/modifiers.py b/test_files/guppy_examples/modifiers.py index e1b968fd7..97b0d07bd 100644 --- a/test_files/guppy_examples/modifiers.py +++ b/test_files/guppy_examples/modifiers.py @@ -1,7 +1,7 @@ # /// script # requires-python = ">=3.13" # dependencies = [ -# "guppylang ==0.21.13", +# "guppylang @ git+https://github.com/Quantinuum/guppylang.git@main#subdirectory=guppylang", # ] # /// """A simple controlled gate using modifiers""" diff --git a/test_files/modified_hugrs/ctrl_on_call2_solved.hugr b/test_files/modified_hugrs/ctrl_on_call2_solved.hugr deleted file mode 100644 index 82ffbb076..000000000 Binary files a/test_files/modified_hugrs/ctrl_on_call2_solved.hugr and /dev/null differ diff --git a/test_files/modified_hugrs/ctrl_on_call_solved.hugr b/test_files/modified_hugrs/ctrl_on_call_solved.hugr deleted file mode 100644 index e871bceb2..000000000 Binary files a/test_files/modified_hugrs/ctrl_on_call_solved.hugr and /dev/null differ diff --git a/test_files/modified_hugrs/ctrl_on_x_solved.hugr b/test_files/modified_hugrs/ctrl_on_x_solved.hugr deleted file mode 100644 index 8e5b671b6..000000000 Binary files a/test_files/modified_hugrs/ctrl_on_x_solved.hugr and /dev/null differ diff --git a/test_files/modified_hugrs/dagger_on_call_solved.hugr b/test_files/modified_hugrs/dagger_on_call_solved.hugr deleted file mode 100644 index 412e27e30..000000000 Binary files a/test_files/modified_hugrs/dagger_on_call_solved.hugr and /dev/null differ diff --git a/test_files/modifier_examples/classical_function1.hugr b/test_files/modifier_examples/classical_function1.hugr new file mode 100644 index 000000000..a5a670239 Binary files /dev/null and b/test_files/modifier_examples/classical_function1.hugr differ diff --git a/test_files/modifier_examples/classical_function1.py b/test_files/modifier_examples/classical_function1.py new file mode 100644 index 000000000..7435e12ff --- /dev/null +++ b/test_files/modifier_examples/classical_function1.py @@ -0,0 +1,42 @@ +# /// script +# requires-python = ">=3.13" +# dependencies = [ +# "guppylang ==0.21.13", +# ] +# /// +"""A simple controlled gate using modifiers""" + +from pathlib import Path +from sys import argv +import sys + +from guppylang import guppy +from guppylang.std.builtins import dagger +from guppylang.std.debug import state_result +from guppylang.std.quantum import discard, qubit, angle +from guppylang.std.quantum import rx + +sys.path.append(str(Path(__file__).resolve().parents[1])) + +from guppylang.experimental import enable_experimental_features + +enable_experimental_features() + + +@guppy +def fuu(i: int) -> int: + return i + 1 + + +@guppy +def main() -> None: + q = qubit() + with dagger: + rx(q, angle(1 / fuu(2))) + + state_result("r", q) + discard(q) + + +program = main.compile() +Path(argv[0]).with_suffix(".hugr").write_bytes(program.to_bytes()) diff --git a/test_files/modifier_examples/classical_function2.hugr b/test_files/modifier_examples/classical_function2.hugr new file mode 100644 index 000000000..cc63667fe Binary files /dev/null and b/test_files/modifier_examples/classical_function2.hugr differ diff --git a/test_files/modifier_examples/classical_function2.py b/test_files/modifier_examples/classical_function2.py new file mode 100644 index 000000000..d7ed90a0e --- /dev/null +++ b/test_files/modifier_examples/classical_function2.py @@ -0,0 +1,51 @@ +# /// script +# requires-python = ">=3.13" +# dependencies = [ +# "guppylang @ git+https://github.com/Quantinuum/guppylang.git@main#subdirectory=guppylang", +# ] +# /// +"""A simple controlled gate using modifiers""" + +from pathlib import Path +from sys import argv +import sys + +from guppylang import guppy +from guppylang.std.builtins import control, dagger +from guppylang.std.debug import state_result +from guppylang.std.quantum import discard, qubit, angle, measure +from guppylang.std.quantum import h, rx, x + +sys.path.append(str(Path(__file__).resolve().parents[1])) + +from guppylang.experimental import enable_experimental_features + +enable_experimental_features() + + +@guppy +def fuu(i: int) -> int: + q = qubit() + x(q) + if measure(q): + i = i + 1 + return i + + +@guppy +def main() -> None: + t = qubit() + c1 = qubit() + h(c1) + with control(c1): + d = fuu(2) + with dagger: + rx(t, angle(1 / d)) + + state_result("r", c1, t) + discard(c1) + discard(t) + + +program = main.compile() +Path(argv[0]).with_suffix(".hugr").write_bytes(program.to_bytes()) diff --git a/test_files/modifier_examples/classical_function3.hugr b/test_files/modifier_examples/classical_function3.hugr new file mode 100644 index 000000000..0f72c6b8b Binary files /dev/null and b/test_files/modifier_examples/classical_function3.hugr differ diff --git a/test_files/modifier_examples/classical_function3.py b/test_files/modifier_examples/classical_function3.py new file mode 100644 index 000000000..2b51e1343 --- /dev/null +++ b/test_files/modifier_examples/classical_function3.py @@ -0,0 +1,55 @@ +# /// script +# requires-python = ">=3.13" +# dependencies = [ +# "guppylang @ git+https://github.com/Quantinuum/guppylang.git@main#subdirectory=guppylang", +# ] +# /// +"""A simple controlled gate using modifiers""" + +from pathlib import Path +from sys import argv +import sys + +from guppylang import guppy +from guppylang.std.builtins import control, dagger +from guppylang.std.debug import state_result +from guppylang.std.quantum import discard, qubit, angle, measure +from guppylang.std.quantum import h, rx, x + +sys.path.append(str(Path(__file__).resolve().parents[1])) + +from guppylang.experimental import enable_experimental_features + +enable_experimental_features() + + +@guppy +def fuu(i: int) -> int: + q = qubit() + x(q) + if measure(q): + i = i + 1 + return i + + +@guppy +def main() -> None: + t = qubit() + c1 = qubit() + c2 = qubit() + h(c1) + h(c2) + with control(c1): + d = fuu(2) + with control(c2): + with dagger: + rx(t, angle(1 / d)) + + state_result("r", c1, c2, t) + discard(c1) + discard(c2) + discard(t) + + +program = main.compile() +Path(argv[0]).with_suffix(".hugr").write_bytes(program.to_bytes()) diff --git a/test_files/modifier_examples/ctrl_array_controller.hugr b/test_files/modifier_examples/ctrl_array_controller.hugr new file mode 100644 index 000000000..5702d2ed2 Binary files /dev/null and b/test_files/modifier_examples/ctrl_array_controller.hugr differ diff --git a/test_files/modifier_examples/ctrl_array_controller.py b/test_files/modifier_examples/ctrl_array_controller.py new file mode 100644 index 000000000..f48021e9e --- /dev/null +++ b/test_files/modifier_examples/ctrl_array_controller.py @@ -0,0 +1,50 @@ +# /// script +# requires-python = ">=3.13" +# dependencies = [ +# "guppylang ==0.21.13", +# ] +# /// +"""A controlled gate where the controller is an array of qubits""" + +from pathlib import Path +from sys import argv +import sys + +from guppylang import guppy +from guppylang.std.builtins import array, control +from guppylang.std.debug import state_result +from guppylang.std.quantum import discard, discard_array, qubit +from guppylang.std.quantum import h, x + +sys.path.append(str(Path(__file__).resolve().parents[1])) + +from guppylang.experimental import enable_experimental_features + +enable_experimental_features() + + +@guppy(unitary=True) +def bar(q: qubit) -> None: + x(q) + + +@guppy +def main() -> None: + controllers: array[qubit, 3] = array(qubit(), qubit(), qubit()) + t = qubit() + + h(controllers[0]) + h(controllers[1]) + h(controllers[2]) + + with control(controllers): + bar(t) + + state_result("r", controllers[0], controllers[1], controllers[2], t) + + discard_array(controllers) + discard(t) + + +program = main.compile() +Path(argv[0]).with_suffix(".hugr").write_bytes(program.to_bytes()) diff --git a/test_files/modifier_examples/ctrl_on_call1.hugr b/test_files/modifier_examples/ctrl_on_call1.hugr new file mode 100644 index 000000000..343692010 Binary files /dev/null and b/test_files/modifier_examples/ctrl_on_call1.hugr differ diff --git a/test_files/modifier_examples/ctrl_on_call1.py b/test_files/modifier_examples/ctrl_on_call1.py new file mode 100644 index 000000000..9b8f1a17e --- /dev/null +++ b/test_files/modifier_examples/ctrl_on_call1.py @@ -0,0 +1,45 @@ +# /// script +# requires-python = ">=3.13" +# dependencies = [ +# "guppylang ==0.21.13", +# ] +# /// +"""A simple controlled gate using modifiers""" + +from pathlib import Path +from sys import argv +import sys + +from guppylang import guppy +from guppylang.std.builtins import control +from guppylang.std.debug import state_result +from guppylang.std.quantum import discard, qubit +from guppylang.std.quantum import h, x + +sys.path.append(str(Path(__file__).resolve().parents[1])) + +from guppylang.experimental import enable_experimental_features + +enable_experimental_features() + + +@guppy(unitary=True) +def bar(q: qubit) -> None: + x(q) + + +@guppy +def main() -> None: + q1 = qubit() + q2 = qubit() + h(q1) + with control(q1): + bar(q2) + + state_result("r", q1, q2) + discard(q1) + discard(q2) + + +program = main.compile() +Path(argv[0]).with_suffix(".hugr").write_bytes(program.to_bytes()) diff --git a/test_files/modifier_examples/double_modifier.hugr b/test_files/modifier_examples/double_modifier.hugr new file mode 100644 index 000000000..b94054dd8 Binary files /dev/null and b/test_files/modifier_examples/double_modifier.hugr differ diff --git a/test_files/modifier_examples/double_modifier.py b/test_files/modifier_examples/double_modifier.py new file mode 100644 index 000000000..c60d75604 --- /dev/null +++ b/test_files/modifier_examples/double_modifier.py @@ -0,0 +1,41 @@ +# /// script +# requires-python = ">=3.13" +# dependencies = [ +# "guppylang @ git+https://github.com/Quantinuum/guppylang.git@main#subdirectory=guppylang", +# ] +# /// +"""A simple controlled gate using modifiers""" + +from pathlib import Path +from sys import argv +import sys + +from guppylang import guppy +from guppylang.std.builtins import control, dagger +from guppylang.std.debug import state_result +from guppylang.std.quantum import discard, qubit, angle +from guppylang.std.quantum import h, rx + +sys.path.append(str(Path(__file__).resolve().parents[1])) + +from guppylang.experimental import enable_experimental_features + +enable_experimental_features() + + +@guppy +def main() -> None: + c1 = qubit() + t = qubit() + h(c1) + with control(c1): + with dagger: + rx(t, angle(1 / 3)) + + state_result("r", c1, t) + discard(c1) + discard(t) + + +program = main.compile() +Path(argv[0]).with_suffix(".hugr").write_bytes(program.to_bytes()) diff --git a/test_files/modifier_examples/multiple_dagger.hugr b/test_files/modifier_examples/multiple_dagger.hugr new file mode 100644 index 000000000..7f257c909 Binary files /dev/null and b/test_files/modifier_examples/multiple_dagger.hugr differ diff --git a/test_files/modifier_examples/multiple_dagger.py b/test_files/modifier_examples/multiple_dagger.py new file mode 100644 index 000000000..9cfb4d07c --- /dev/null +++ b/test_files/modifier_examples/multiple_dagger.py @@ -0,0 +1,47 @@ +# /// script +# requires-python = ">=3.13" +# dependencies = [ +# "guppylang ==0.21.13", +# ] +# /// +"""An example with an even number of daggers, which should cancel out""" + +from pathlib import Path +from sys import argv +import sys + +from guppylang import guppy +from guppylang.std.builtins import dagger +from guppylang.std.debug import state_result +from guppylang.std.quantum import discard, qubit, angle +from guppylang.std.quantum import rx + +sys.path.append(str(Path(__file__).resolve().parents[1])) + +from guppylang.experimental import enable_experimental_features + +enable_experimental_features() + + +@guppy(unitary=True) +def rotation(q: qubit) -> None: + rx(q, angle(1 / 4)) + + +@guppy +def main() -> None: + t = qubit() + + with dagger: + with dagger: + rotation(t) + + with dagger, dagger, dagger: + rotation(t) + + state_result("r", t) + discard(t) + + +program = main.compile() +Path(argv[0]).with_suffix(".hugr").write_bytes(program.to_bytes()) diff --git a/test_files/modifier_examples/nested_ctrl_dagger1.hugr b/test_files/modifier_examples/nested_ctrl_dagger1.hugr new file mode 100644 index 000000000..a5f6512ba Binary files /dev/null and b/test_files/modifier_examples/nested_ctrl_dagger1.hugr differ diff --git a/test_files/modifier_examples/nested_ctrl_dagger1.py b/test_files/modifier_examples/nested_ctrl_dagger1.py new file mode 100644 index 000000000..e01a23247 --- /dev/null +++ b/test_files/modifier_examples/nested_ctrl_dagger1.py @@ -0,0 +1,62 @@ +# /// script +# requires-python = ">=3.13" +# dependencies = [ +# "guppylang @ git+https://github.com/Quantinuum/guppylang.git@main#subdirectory=guppylang", +# ] +# /// +"""Nested control and dagger modifiers in various combinations""" + +from pathlib import Path +from sys import argv +import sys + +from guppylang import guppy +from guppylang.std.builtins import control, dagger +from guppylang.std.debug import state_result +from guppylang.std.quantum import discard, qubit, angle +from guppylang.std.quantum import h, rx, x + +sys.path.append(str(Path(__file__).resolve().parents[1])) + +from guppylang.experimental import enable_experimental_features + +enable_experimental_features() + + +@guppy(unitary=True) +def rotation(q: qubit) -> None: + rx(q, angle(-1 / 3)) + + +@guppy(unitary=True) +def flip(q: qubit) -> None: + x(q) + + +@guppy +def main() -> None: + c1 = qubit() + c2 = qubit() + t1 = qubit() + t2 = qubit() + + h(c1) + h(c2) + + with control(c1): + with dagger: + rotation(t2) + + with dagger: + with control(c2): + rotation(t1) + + state_result("r", c1, c2, t1, t2) + discard(c1) + discard(c2) + discard(t1) + discard(t2) + + +program = main.compile() +Path(argv[0]).with_suffix(".hugr").write_bytes(program.to_bytes()) diff --git a/test_files/run_modifier_examples/apply_passes.py b/test_files/run_modifier_examples/apply_passes.py index a278289ed..82d52f33f 100644 --- a/test_files/run_modifier_examples/apply_passes.py +++ b/test_files/run_modifier_examples/apply_passes.py @@ -1,17 +1,12 @@ +import sys +from pathlib import Path + +from hugr.build.base import Hugr from tket.passes import ( - NormalizeGuppy, ModifierResolverPass, ) - -from hugr.build.base import Hugr - - -from pathlib import Path -import sys - - -normalize = NormalizeGuppy() +mr_pass = ModifierResolverPass() def _hugr_from_path(str_path: str) -> Hugr: @@ -21,23 +16,23 @@ def _hugr_from_path(str_path: str) -> Hugr: return h -mr_pass = ModifierResolverPass() -modifier_examples_dir = Path(__file__).resolve().parents[1] / "modifier_examples" -modified_hugrs_dir = Path(__file__).resolve().parents[1] / "modified_hugrs" -modified_hugrs_dir.mkdir(parents=True, exist_ok=True) - +def apply_passes(input_paths: list[Path], output_dir: Path) -> None: + for input_path in input_paths: + print(f"Processing {input_path.name}") + hugr = _hugr_from_path(str(input_path)) + resolved: Hugr = mr_pass(hugr) -input_paths = ( - [modifier_examples_dir / (sys.argv[1] + ".hugr")] - if len(sys.argv) > 1 - else modifier_examples_dir.glob("*.hugr") -) + output_path = output_dir / f"{input_path.stem}_solved.hugr" + output_path.write_bytes(resolved.to_bytes()) -for input_path in input_paths: - print(f"Processing {input_path.name}") - modifier_hugr = _hugr_from_path(str(input_path)) - normalized = normalize(modifier_hugr) - resolved: Hugr = mr_pass(normalized) - output_path = modified_hugrs_dir / f"{input_path.stem}_solved.hugr" - output_path.write_bytes(resolved.to_bytes()) +if __name__ == "__main__": + modifier_examples_dir = Path(__file__).resolve().parents[1] / "modifier_examples" + modified_hugrs_dir = Path(__file__).resolve().parent / "modified_hugrs" + modified_hugrs_dir.mkdir(parents=True, exist_ok=True) + input_paths = ( + [modifier_examples_dir / (sys.argv[1] + ".hugr")] + if len(sys.argv) > 1 + else modifier_examples_dir.glob("*.hugr") + ) + apply_passes(input_paths, modified_hugrs_dir) diff --git a/test_files/run_modifier_examples/hugr_results.txt b/test_files/run_modifier_examples/hugr_results.txt index e4b55ef34..1777a0cc3 100644 --- a/test_files/run_modifier_examples/hugr_results.txt +++ b/test_files/run_modifier_examples/hugr_results.txt @@ -1,3 +1,33 @@ +classical_function1_solved: + 0 -> 0.866+0j + 1 -> 0+0.5j +----- +classical_function2_solved: + 00 -> 0.7071+0j + 10 -> 0.6124+0j + 11 -> 0+0.3536j +----- +classical_function3_solved: + 000 -> 0.5+0j + 010 -> 0.5+0j + 100 -> 0.5+0j + 110 -> 0.433+0j + 111 -> 0+0.25j +----- +ctrl_array_controller_solved: + 0000 -> 0.3536+0j + 0010 -> 0.3536+0j + 0100 -> 0.3536+0j + 0110 -> 0.3536+0j + 1000 -> 0.3536+0j + 1010 -> 0.3536+0j + 1100 -> 0.3536+0j + 1111 -> 0.3536+0j +----- +ctrl_on_call1_solved: + 00 -> 0.7071+0j + 11 -> 0.7071+0j +----- ctrl_on_call2_solved: 0110 -> 0.7071+0j 1110 -> 0.6124+0j @@ -14,3 +44,22 @@ ctrl_on_x_solved: dagger_on_call_solved: 0 -> 0.7071+0j 1 -> 0-0.7071j +----- +double_modifier_solved: + 00 -> 0.7071+0j + 10 -> 0.6124+0j + 11 -> 0+0.3536j +----- +multiple_dagger_solved: + 0 -> 1+0j +----- +nested_ctrl_dagger1_solved: + 0000 -> 0.5+0j + 0100 -> 0.433+0j + 0110 -> 0-0.25j + 1000 -> 0.433+0j + 1001 -> 0-0.25j + 1100 -> 0.375+0j + 1101 -> 0-0.2165j + 1110 -> 0-0.2165j + 1111 -> -0.125+0j diff --git a/test_files/run_modifier_examples/hugr_results/classical_function1_solved.npy b/test_files/run_modifier_examples/hugr_results/classical_function1_solved.npy new file mode 100644 index 000000000..19aab2972 Binary files /dev/null and b/test_files/run_modifier_examples/hugr_results/classical_function1_solved.npy differ diff --git a/test_files/run_modifier_examples/hugr_results/classical_function2_solved.npy b/test_files/run_modifier_examples/hugr_results/classical_function2_solved.npy new file mode 100644 index 000000000..aceda8cdf Binary files /dev/null and b/test_files/run_modifier_examples/hugr_results/classical_function2_solved.npy differ diff --git a/test_files/run_modifier_examples/hugr_results/classical_function3_solved.npy b/test_files/run_modifier_examples/hugr_results/classical_function3_solved.npy new file mode 100644 index 000000000..88939d912 Binary files /dev/null and b/test_files/run_modifier_examples/hugr_results/classical_function3_solved.npy differ diff --git a/test_files/run_modifier_examples/hugr_results/ctrl_array_controller_solved.npy b/test_files/run_modifier_examples/hugr_results/ctrl_array_controller_solved.npy new file mode 100644 index 000000000..7a83ca43d Binary files /dev/null and b/test_files/run_modifier_examples/hugr_results/ctrl_array_controller_solved.npy differ diff --git a/test_files/run_modifier_examples/hugr_results/ctrl_on_call1_solved.npy b/test_files/run_modifier_examples/hugr_results/ctrl_on_call1_solved.npy new file mode 100644 index 000000000..1b4e66782 Binary files /dev/null and b/test_files/run_modifier_examples/hugr_results/ctrl_on_call1_solved.npy differ diff --git a/test_files/run_modifier_examples/hugr_results/double_modifier_solved.npy b/test_files/run_modifier_examples/hugr_results/double_modifier_solved.npy new file mode 100644 index 000000000..aceda8cdf Binary files /dev/null and b/test_files/run_modifier_examples/hugr_results/double_modifier_solved.npy differ diff --git a/test_files/run_modifier_examples/hugr_results/multiple_dagger_solved.npy b/test_files/run_modifier_examples/hugr_results/multiple_dagger_solved.npy new file mode 100644 index 000000000..3ecbf3051 Binary files /dev/null and b/test_files/run_modifier_examples/hugr_results/multiple_dagger_solved.npy differ diff --git a/test_files/run_modifier_examples/hugr_results/nested_ctrl_dagger1_solved.npy b/test_files/run_modifier_examples/hugr_results/nested_ctrl_dagger1_solved.npy new file mode 100644 index 000000000..ec7489319 Binary files /dev/null and b/test_files/run_modifier_examples/hugr_results/nested_ctrl_dagger1_solved.npy differ diff --git a/test_files/run_modifier_examples/justfile b/test_files/run_modifier_examples/justfile index 743dc98ac..86f90713a 100644 --- a/test_files/run_modifier_examples/justfile +++ b/test_files/run_modifier_examples/justfile @@ -11,7 +11,6 @@ run-hugrs: uv run apply_passes.py @echo "---- Running hugrs ----" uv run --no-project "run_hugrs.py" - # Re-generate a single hugr file. rh example: diff --git a/test_files/run_modifier_examples/run_hugrs.py b/test_files/run_modifier_examples/run_hugrs.py index e46b4bc0c..92aebad7d 100644 --- a/test_files/run_modifier_examples/run_hugrs.py +++ b/test_files/run_modifier_examples/run_hugrs.py @@ -7,6 +7,7 @@ """Run on selene the passed hugrs""" from pathlib import Path +import shutil import sys import numpy as np import numpy.typing as npt @@ -14,6 +15,7 @@ from hugr import Hugr from guppylang.emulator import EmulatorBuilder + sys.path.append(str(Path(__file__).resolve().parents[1])) @@ -35,7 +37,7 @@ def format_statevector( return "\n".join(parts) if parts else "all amplitudes below threshold" -modifier_examples_dir = Path(__file__).resolve().parents[1] / "modified_hugrs" +modifier_examples_dir = Path(__file__).resolve().parent / "modified_hugrs" result_execution_dir = Path(__file__).resolve().parent / "hugr_results" print(modifier_examples_dir) @@ -79,3 +81,5 @@ def format_statevector( result_path = Path(__file__).resolve().parent / "hugr_results.txt" result_path.parent.mkdir(parents=True, exist_ok=True) result_path.write_text("\n-----\n".join(all_results) + "\n") + +shutil.rmtree(modifier_examples_dir) diff --git a/tket-py/test/test_pass.py b/tket-py/test/test_pass.py index 8f4bc3ae4..f5481923f 100644 --- a/tket-py/test/test_pass.py +++ b/tket-py/test/test_pass.py @@ -1,3 +1,4 @@ +import importlib.util import tempfile from pytket import Circuit, OpType @@ -239,32 +240,45 @@ def test_modifier_resolver() -> None: mr_pass = ModifierResolverPass() modifier_hugr: Hugr = _hugr_from_path("test_files/guppy_examples/modifiers.hugr") - normalized = normalize(modifier_hugr) + modifier_hugr = normalize(modifier_hugr) - assert _count_ops(normalized, "tket.modifier.ControlModifier") == 1 - assert _count_ops(normalized, "tket.modifier.DaggerModifier") == 1 + assert _count_ops(modifier_hugr, "tket.modifier.ControlModifier") == 1 + assert _count_ops(modifier_hugr, "tket.modifier.DaggerModifier") == 1 - resolved: Hugr = mr_pass(normalized) + resolved: Hugr = mr_pass(modifier_hugr) assert _count_ops(resolved, "tket.modifier.ControlModifier") == 0 assert _count_ops(resolved, "tket.modifier.DaggerModifier") == 0 def test_modifier_execution() -> None: - modified_hugrs_dir = Path("test_files/modified_hugrs") + modifier_examples_dir = Path("test_files/modifier_examples") hugr_results_dir = Path("test_files/run_modifier_examples/hugr_results") run_hugrs_dir = Path("test_files/run_modifier_examples") + apply_passes_path = run_hugrs_dir / "apply_passes.py" + spec = importlib.util.spec_from_file_location( + "run_modifier_examples_apply_passes", apply_passes_path + ) + assert spec is not None + assert spec.loader is not None + apply_passes_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(apply_passes_module) + apply_passes = apply_passes_module.apply_passes expected_results = { expected_path.stem: np.load(expected_path).copy() for expected_path in sorted(hugr_results_dir.glob("*.npy")) } - - for hugr_path in sorted(modified_hugrs_dir.glob("*.hugr")): - hugr_name = hugr_path.stem.removesuffix("_solved") - expected_statevector = expected_results[hugr_path.stem] + for hugr_path in sorted(modifier_examples_dir.glob("*.hugr")): + hugr_name = hugr_path.stem + expected_statevector = expected_results[f"{hugr_name}_solved"] with tempfile.TemporaryDirectory() as tmp_dir: + generated_hugrs_dir = Path(tmp_dir) / "modified_hugrs" + generated_hugrs_dir.mkdir() + apply_passes([hugr_path], generated_hugrs_dir) + + (run_hugrs_dir / "modified_hugrs").mkdir(exist_ok=True) tmp_path = Path(tmp_dir) / f"{hugr_name}.npy" subprocess.run( [ @@ -274,7 +288,7 @@ def test_modifier_execution() -> None: "--python", "3.13", "run_hugrs.py", - hugr_name, + str((generated_hugrs_dir / hugr_name).resolve()), str(tmp_path), ], cwd=run_hugrs_dir, diff --git a/tket/src/modifier.rs b/tket/src/modifier.rs index 07a0186ff..aada78cc7 100644 --- a/tket/src/modifier.rs +++ b/tket/src/modifier.rs @@ -64,8 +64,8 @@ impl ModifierFlags { fn from_metadata(h: &impl HugrView, n: N) -> Option { h.get_metadata::(n) .map(|num| ModifierFlags { - dagger: (num & 1) != 0, - control: (num & 2) != 0, + control: (num & 1) != 0, + dagger: (num & 2) != 0, power: (num & 4) != 0, }) } diff --git a/tket/src/modifier/control.rs b/tket/src/modifier/control.rs index 02f495588..e67c9653a 100644 --- a/tket/src/modifier/control.rs +++ b/tket/src/modifier/control.rs @@ -13,7 +13,7 @@ pub struct ModifierControl(usize); impl ModifierControl { /// Create a new ModifierControl with a specific number of controls. - pub fn new(num: usize) -> Self { + fn new(num: usize) -> Self { ModifierControl(num) } } @@ -24,7 +24,7 @@ impl Default for ModifierControl { } impl ModifierControl { /// Signature for the control modifier. - pub fn signature() -> SignatureFunc { + pub(crate) fn signature() -> SignatureFunc { PolyFuncTypeRV::new( [ TypeParam::max_nat_type(), diff --git a/tket/src/modifier/dagger.rs b/tket/src/modifier/dagger.rs index d4e27b425..f226ad51f 100644 --- a/tket/src/modifier/dagger.rs +++ b/tket/src/modifier/dagger.rs @@ -12,7 +12,7 @@ pub struct ModifierDagger; impl ModifierDagger { /// Create a new ModifierDagger. - pub fn new() -> Self { + fn new() -> Self { ModifierDagger } } @@ -34,7 +34,7 @@ impl FromStr for ModifierDagger { } impl ModifierDagger { /// Signature for the dagger modifier. - pub fn signature() -> SignatureFunc { + pub(crate) fn signature() -> SignatureFunc { PolyFuncTypeRV::new( [ TypeParam::new_list_type(TypeBound::Linear), diff --git a/tket/src/modifier/modifier_resolver.rs b/tket/src/modifier/modifier_resolver.rs index 2392f7ac6..6f31c8f8c 100644 --- a/tket/src/modifier/modifier_resolver.rs +++ b/tket/src/modifier/modifier_resolver.rs @@ -42,7 +42,7 @@ //! When dagger is applied, the order of nodes to be processed is reversed, //! since the control qubits are passed in the reverse order. //! After visiting all children, `modify_dfg_body` calls -//! [`connect_all`](ModifierResolver::connect_all) to connect all wires that are registered +//! ModifierResolver::connect_all to connect all wires that are registered //! in the correspondence map. //! //! Importantly, when dagger is applied, not only the order of nodes is reversed, @@ -112,7 +112,10 @@ pub mod global_phase_modify; pub mod tket_op_modify; use super::{CombinedModifier, ModifierFlags}; -use crate::passes::utils::unpack_container::TypeUnpacker; +use crate::passes::{ + ComposablePass, RemoveDeadFuncsPass, WithScope, composable::Preserve, + utils::unpack_container::TypeUnpacker, +}; use crate::{TketOp, extension::global_phase::GlobalPhase, modifier::Modifier}; use global_phase_modify::delete_phase; @@ -144,12 +147,12 @@ impl std::fmt::Display for DirWire { impl DirWire { /// Create a new DirWire. - pub fn new(node: N, port: Port) -> Self { + fn new(node: N, port: Port) -> Self { DirWire(node, port) } /// Reverse the direction of the wire. - pub fn reverse(self) -> Self { + pub(crate) fn reverse(self) -> Self { let index = self.1.index(); let port = match self.1.as_directed() { Either::Left(_in) => OutgoingPort::from(index).into(), @@ -328,7 +331,7 @@ pub struct ModifierResolver { impl ModifierResolver { /// Create a new modifier resolver. - pub fn new() -> Self { + fn new() -> Self { ModifierResolver { modifiers: CombinedModifier::default(), corresp_map: HashMap::default(), @@ -414,12 +417,12 @@ pub enum ModifierResolverErrors { impl ModifierResolverErrors { /// Create an unreachable error. - pub fn unreachable(msg: impl Into) -> Self { + fn unreachable(msg: impl Into) -> Self { Self::Unreachable { msg: msg.into() } } /// Create an unresolvable error. - pub fn unresolvable(node: N, msg: impl Into, optype: OpType) -> Self { + fn unresolvable(node: N, msg: impl Into, optype: OpType) -> Self { Self::UnResolvable { node, msg: msg.into(), @@ -461,6 +464,7 @@ impl ModifierResolver { *self.worklist() = worklist; r } + fn with_modifiers( &mut self, modifiers: CombinedModifier, @@ -471,6 +475,7 @@ impl ModifierResolver { *self.modifiers_mut() = modifiers; r } + fn with_ancilla( &mut self, wire: &mut Wire, @@ -602,7 +607,7 @@ impl ModifierResolver { } /// connects all the wires in the builder. - pub fn connect_all( + fn connect_all( &mut self, h: &impl HugrView, new_dfg: &mut impl Container, @@ -668,7 +673,7 @@ impl ModifierResolver { // Verify that the rewrite can be applied. self.verify(hugr, modifier_node)?; - // the ports that takes inputs from the modified function. + // The ports that takes inputs from the modified function to the IndirectCall node. let modified_fn_loader: Vec<(_, Vec<_>)> = hugr .node_outputs(modifier_node) .map(|p| (p, hugr.linked_inputs(modifier_node, p).collect())) @@ -680,7 +685,7 @@ impl ModifierResolver { let new_load = self.with_modifiers(modifiers, |this| { this.apply_modifier_chain_to_loaded_fn(hugr, modifier_node) })?; - + // NICOLA: the fail is before here! // Connect the modified function to the inputs for (out_port, inputs) in modified_fn_loader { for (recv, recv_port) in inputs { @@ -688,7 +693,6 @@ impl ModifierResolver { hugr.connect(new_load, out_port, recv, recv_port); } } - Ok(()) } @@ -696,7 +700,7 @@ impl ModifierResolver { /// flatten = true means that control qubits are represented as individual wires, /// while false means that they are packed to some arrays. /// This false mode is used for function definitions, - pub fn modify_signature(&self, signature: &mut Signature, flatten: bool) { + fn modify_signature(&self, signature: &mut Signature, flatten: bool) { let FuncTypeBase { input, output } = signature; if flatten { @@ -724,39 +728,42 @@ impl ModifierResolver { match optype { // Skip input/output nodes: it should be handled by its parent as it sets control qubits. OpType::Input(_) | OpType::Output(_) => {} - // CFG OpType::CFG(cfg) => self.modify_cfg(h, target_node, cfg, new_dfg)?, - // DFGs OpType::DFG(dfg) => self.modify_dfg(h, target_node, dfg, new_dfg)?, + // TailLoop OpType::TailLoop(tail_loop) => { self.modify_tail_loop(h, target_node, tail_loop, new_dfg)? } + // Conditional OpType::Conditional(conditional) => { self.modify_conditional(h, target_node, conditional, new_dfg)? } - // Function calls OpType::Call(_) => self.modify_call(h, target_node, optype, new_dfg)?, + // Indirect call OpType::CallIndirect(indir_call) => { self.modify_indirect_call(h, target_node, indir_call, new_dfg)? } + // Load function OpType::LoadFunction(load) => { self.modify_load_function(h, target_node, load, new_dfg)? } - // Operations OpType::ExtensionOp(_) => { self.modify_extension_op(h, target_node, optype, new_dfg)?; } + // Constants OpType::Const(constant) => { self.modify_constant(target_node, constant, new_dfg)?; } + // Load constant OpType::LoadConstant(_) | OpType::OpaqueOp(_) | OpType::Tag(_) => { self.add_node_no_modification(h, target_node, optype.clone(), new_dfg)?; } + // Invalid nodes OpType::FuncDefn(_) | OpType::FuncDecl(_) | OpType::Module(_) => { return Err(ModifierResolverErrors::unreachable(format!( "Invalid node found inside modified function (OpType = {})", @@ -796,8 +803,8 @@ impl ModifierResolver { /// If the dagger is not applied, the ports are mapped directly. /// If the dagger is applied, the quantum input/output ports are swapped. /// Inputs: - /// * `n`: the old node - /// * `node`: the new node + /// * `old_node`: the old node + /// * `new_node`: the new node /// * `inputs`/`outputs`: the types of the input/output ports of the old node /// * `input_offset`/`output_offset`: the offset of the ports of the old and new node /// - e.g., for IndirectCall, the first input port is the loaded function, which we want to ignore here. @@ -822,8 +829,8 @@ impl ModifierResolver { /// TODO: Handle state order edges. fn wire_node_inout<'a>( &mut self, - n: N, - node: Node, + old_node: N, + new_node: Node, (inputs, outputs): ( impl Iterator, impl Iterator, @@ -831,8 +838,8 @@ impl ModifierResolver { (input_offset, output_offset, new_offset): (usize, usize, usize), ) -> Result<(), ModifierResolverErrors> { self.wire_inout( - (n, n), - (node, node), + (old_node, old_node), + (new_node, new_node), (inputs, outputs), (input_offset, output_offset, new_offset), ) @@ -878,7 +885,8 @@ impl ModifierResolver { out_ty = outputs.next(); } - // If both are quantum types, wire them in the opposite direction until the next non-quantum type + // If both are quantum types, wire them in the opposite direction (if dagger is applied) + // until the next non-quantum type while let Some(ty) = in_ty { if !self.qubit_finder.contains_element_type(ty) { break; @@ -1017,19 +1025,19 @@ impl ModifierResolver { fn modify_cfg( &mut self, h: &mut impl HugrMut, - n: N, + cfg_node: N, cfg: &CFG, new_dfg: &mut impl Container, ) -> Result<(), ModifierResolverErrors> { // Check if the CFG contains only one block. let children: Vec = h - .children(n) + .children(cfg_node) .filter(|child| h.get_optype(*child).is_dataflow_block()) .collect(); // NOTE: this check prevents breaking modifier application to branching or loops if children.len() != 1 { return Err(ModifierResolverErrors::unresolvable( - n, + cfg_node, "CFG with more than one node found.".to_string(), cfg.clone().into(), )); @@ -1038,24 +1046,30 @@ impl ModifierResolver { let mut signature = cfg.signature.clone(); self.modify_signature(&mut signature, true); + let mut new_cfg = CFGBuilder::new(signature.clone())?; let mut new_bb = new_cfg.entry_builder([type_row![]], signature.output.clone())?; self.modify_dfg_body(h, old_bb, &mut new_bb)?; + let bb_id = new_bb.finish_sub_container()?; new_cfg.branch(&bb_id, 0, &new_cfg.exit_block())?; - let new = self.insert_sub_dfg(new_dfg, new_cfg)?; + let new_node = self.insert_sub_dfg(new_dfg, new_cfg)?; // connect the controls and register the IOs for (i, c) in self.controls().iter_mut().enumerate() { - new_dfg.hugr_mut().connect(c.node(), c.source(), new, i); - *c = Wire::new(new, i); + new_dfg + .hugr_mut() + .connect(c.node(), c.source(), new_node, i); + *c = Wire::new(new_node, i); } + let offset = self.control_num(); + self.wire_node_inout( - n, - new, - (signature.input.iter(), signature.output.iter()), + cfg_node, + new_node, + (cfg.signature.input.iter(), cfg.signature.output.iter()), (0, 0, offset), )?; // self.wire_others(n, cfg.into(), new, new_dfg.hugr().get_optype(new))?; @@ -1076,6 +1090,7 @@ pub fn resolve_modifier_with_entrypoints( ) -> Result<(), ModifierResolverErrors> { use ModifierResolverErrors::*; + // Collect entry points into a deque so they can be cloned for later cleanup passes. let entry_points: VecDeque<_> = entry_points.into_iter().collect(); // Walk all nodes reachable from the entry points (children and neighbours) @@ -1083,6 +1098,7 @@ pub fn resolve_modifier_with_entrypoints( let mut resolver = ModifierResolver::new(); let mut worklist = entry_points.clone(); let mut visited = vec![]; + while let Some(node) = worklist.pop_front() { // Skip nodes that have been removed during previous rewrites or already visited. if !h.contains_node(node) || visited.contains(&node) { @@ -1153,6 +1169,13 @@ pub fn resolve_modifier_with_entrypoints( // were produced or left behind by the resolution passes above. delete_phase(h, entry_points)?; + // At end we delete dead code: i.e. old function blocks that have been replaced by modified + // versions but are still present as unreachable code. + RemoveDeadFuncsPass::default() + .with_scope(Preserve::Public) + .run(h) + .unwrap(); + h.validate() .map_err(|e| ModifierResolverErrors::BuildError(e.into()))?; @@ -1163,6 +1186,8 @@ pub fn resolve_modifier_with_entrypoints( #[cfg(test)] mod tests { + use std::{fs, io::BufReader, path::Path}; + use cool_asserts::assert_matches; use hugr::{ Hugr, @@ -1198,12 +1223,20 @@ mod tests { /// ``` /// where `foo` is supplied by the caller. /// + /// Parameters: + /// * `target_num` – number of plain qubit (target) arguments that `foo` accepts. + /// * `ctrl_num` – number of control qubits to wrap around `foo`. + /// * `foo` – closure that inserts the function-under-test into the module and + /// returns its `FuncID`. + /// * `dagger` – if `true`, a `Dagger` modifier is inserted before the `Control` + /// modifier, so the full chain is `Dagger → Control`. pub(crate) fn test_modifier_resolver( target_num: usize, ctrl_num: u64, foo: impl FnOnce(&mut ModuleBuilder, usize) -> FuncID, dagger: bool, ) { + // --- Build the module --- let mut module = ModuleBuilder::new(); // Signature used by the CallIndirect node: @@ -1262,6 +1295,8 @@ mod tests { // Build the "main" function body --- let _main = { let mut func = module.define_function("main", main_sig).unwrap(); + + // Load the function value; this is the wire that will be passed through modifiers. let mut call = func.load_func(&foo, &[]).unwrap(); if dagger { @@ -1272,13 +1307,13 @@ mod tests { .out_wire(0); } - // Wrap with the Control modifier. + // Wrap the (possibly daggered) function reference with the Control modifier. call = func .add_dataflow_op(control_op, vec![call]) .unwrap() .out_wire(0); - // Allocate ctrl_num qubits + // Allocate ctrl_num fresh qubits to serve as control qubits. let mut controls = Vec::new(); for _ in 0..ctrl_num { controls.push( @@ -1287,7 +1322,8 @@ mod tests { .out_wire(0), ); } - // Allocate target_num qubits + + // Allocate target_num fresh qubits to serve as target qubits. let mut targ = Vec::new(); for _ in 0..target_num { targ.push( @@ -1320,6 +1356,55 @@ mod tests { let entrypoint = h.entrypoint(); resolve_modifier_with_entrypoints(&mut h, [entrypoint]).unwrap(); + // The resolved hugr must still be structurally valid. + assert_matches!(h.validate(), Ok(())); + } + + const GUPPY_EXAMPLES_DIR: &str = "../test_files/modifier_examples"; + + fn load_guppy_example(name: &str) -> std::io::Result { + let file = Path::new(GUPPY_EXAMPLES_DIR).join(format!("{name}.hugr")); + let reader = fs::File::open(file)?; + let reader = BufReader::new(reader); + Ok(Hugr::load(reader, None).unwrap()) + } + + fn load_guppy_examples() -> std::io::Result> { + let mut files = fs::read_dir(GUPPY_EXAMPLES_DIR)? + .filter_map(|entry| { + let path = entry.ok()?.path(); + path.extension() + .is_some_and(|ext| ext == "hugr") + .then_some(path) + }) + .collect::>(); + files.sort_unstable(); + + files + .into_iter() + .map(|file| { + let name = file.file_stem().unwrap().to_string_lossy().into_owned(); + let h = load_guppy_example(&name)?; + Ok((name, h)) + }) + .collect() + } + + /// Resolve modifiers in `h` + fn test_resolve(h: &mut Hugr) { + assert_matches!(h.validate(), Ok(())); + + let entrypoint = h.entrypoint(); + resolve_modifier_with_entrypoints(h, [entrypoint]).unwrap(); + assert_matches!(h.validate(), Ok(())); } + + #[rstest::rstest] + fn test_saved_hugr() { + for (name, mut h) in load_guppy_examples().unwrap() { + println!("Resolving example: {name}"); + test_resolve(&mut h); + } + } } diff --git a/tket/src/modifier/modifier_resolver/call_modify.rs b/tket/src/modifier/modifier_resolver/call_modify.rs index e88d291c9..3af237932 100644 --- a/tket/src/modifier/modifier_resolver/call_modify.rs +++ b/tket/src/modifier/modifier_resolver/call_modify.rs @@ -1,5 +1,4 @@ //! Modify nodes related to function calls. - use hugr::{ IncomingPort, Wire, builder::{BuildError, Dataflow}, @@ -31,7 +30,7 @@ impl ModifierResolver { .unwrap(); // wire the callee - let Some(new_callee) = self.modify_fn_if_needed(h, callee.0, &call.signature())? else { + let Some(new_callee) = self.modify_fn_if_needed(h, callee.0)? else { // If the function need not be modified, just copy the Call node as is. let new = self.add_node_no_modification(h, call_node, call.clone(), new_dfg)?; self.call_map() @@ -81,14 +80,17 @@ impl ModifierResolver { // The final target of modifiers to apply. // Collection of modifiers to apply. let modifiers_and_targ = self.trace_modifiers_chain(h, modifier_node)?; + let targ = modifiers_and_targ .last() .cloned() .ok_or(ModifierError::NoTarget(modifier_node))?; - // The function to apply the modifier to. + // The function to apply the modifier to. This is expected to be a LoadFunction node let (func, load) = Self::get_loaded_function(h, modifier_node, targ, h.get_optype(targ))?; + h.remove_node(targ); + // Modify the function let modified_fn = self.modify_fn(h, func)?; @@ -136,7 +138,7 @@ impl ModifierResolver { /// Given a target node `targ` which is expected to be a `LoadFunction`, retrieve the function node it loads. pub(super) fn get_loaded_function( h: &impl HugrMut, - n: N, + modifier_node: N, targ: N, optype: &OpType, ) -> Result<(N, LoadFunction), ModifierError> { @@ -145,19 +147,25 @@ impl ModifierResolver { let (fn_node, _) = h.single_linked_output(targ, 0).unwrap(); let fn_optype = h.get_optype(fn_node); let OpType::FuncDefn(_) = fn_optype else { - return Err(ModifierError::ModifierNotApplicable(n, fn_optype.clone())); + return Err(ModifierError::ModifierNotApplicable( + modifier_node, + fn_optype.clone(), + )); }; // TODO: We want some machinery to prevent generating a lot of copies of modified functions // from the same function. Ok((fn_node, load.clone())) } - OpType::Input(_) => Err(ModifierError::NoTarget(n)), + OpType::Input(_) => Err(ModifierError::NoTarget(modifier_node)), // If the target is a function, we need to create a new dataflow block of it. _ => { // TODO: // In the future, we might want to handle modifiers provided from other nodes. // For example, conditionals? - Err(ModifierError::ModifierNotApplicable(n, optype.clone())) + Err(ModifierError::ModifierNotApplicable( + modifier_node, + optype.clone(), + )) } } } @@ -191,7 +199,7 @@ impl ModifierResolver { Self::get_loaded_function(h, n, targ, h.get_optype(targ)).map_err(wrap_err)?; // Modify the function - let modified_fn = match self.modify_fn_if_needed(h, func, &load.signature())? { + let modified_fn = match self.modify_fn_if_needed(h, func)? { Some(node) => node, None => self.wrap_fn_with_controls(h, func, &load.type_args)?, }; @@ -293,8 +301,49 @@ mod tests { .call(callee.handle(), &[], vec![inputs[0]]) .unwrap() .out_wire(0); + // inputs[0] = func + // .add_dataflow_op(TketOp::X, vec![inputs[0]]) + // .unwrap() + // .out_wire(0); + *func.finish_with_outputs(inputs).unwrap().handle() + } + + /// Nested call pattern: `foo(q) = foo1(q)`, `foo1(q) = bar(q)`, `bar(q) = X(q)`. + /// Tests that the resolver correctly propagates modifiers through a three-level call chain. + fn foo_modifier_on_function(module: &mut ModuleBuilder, t_num: usize) -> FuncID { + // bar: applies X to its single qubit argument. + let bar = { + let bar_sig = Signature::new_endo(vec![qb_t()]); + let mut bar_builder = module.define_function("inner", bar_sig).unwrap(); + bar_builder.set_unitary(); + let mut inputs: Vec = bar_builder.input_wires().collect(); + inputs[0] = bar_builder + .add_dataflow_op(TketOp::X, vec![inputs[0]]) + .unwrap() + .out_wire(0); + bar_builder.finish_with_outputs(inputs).unwrap() + }; + + // foo1: delegates entirely to bar. + let foo1 = { + let foo1_sig = Signature::new_endo(vec![qb_t()]); + let mut foo1_builder = module.define_function("outer", foo1_sig).unwrap(); + foo1_builder.set_unitary(); + let mut inputs: Vec = foo1_builder.input_wires().collect(); + inputs[0] = foo1_builder + .call(bar.handle(), &[], vec![inputs[0]]) + .unwrap() + .out_wire(0); + foo1_builder.finish_with_outputs(inputs).unwrap() + }; + + // foo: delegates entirely to foo1. + let foo_sig = Signature::new_endo(iter::repeat_n(qb_t(), t_num).collect::>()); + let mut func = module.define_function("foo", foo_sig).unwrap(); + func.set_unitary(); + let mut inputs: Vec<_> = func.input_wires().collect(); inputs[0] = func - .add_dataflow_op(TketOp::X, vec![inputs[0]]) + .call(foo1.handle(), &[], vec![inputs[0]]) .unwrap() .out_wire(0); *func.finish_with_outputs(inputs).unwrap().handle() @@ -405,13 +454,14 @@ mod tests { } #[rstest::rstest] + #[case::call_twice(1, 1, foo_modifier_on_function, false)] #[case::call(1, 1, foo_call, false)] #[case::call_dagger(1, 1, foo_call, true)] #[case::indir_call(1, 1, foo_indir_call, false)] #[case::indir_call_dagger(1, 1, foo_indir_call, true)] #[case::load_fn(1, 1, foo_load_fn, false)] #[case::nested_modifier(2, 2, foo_nested_modifier, false)] - pub fn test_call_modify( + fn test_call_modify( #[case] target_num: usize, #[case] ctrl_num: u64, #[case] foo: fn(&mut ModuleBuilder, usize) -> FuncID, diff --git a/tket/src/modifier/modifier_resolver/dfg_modify.rs b/tket/src/modifier/modifier_resolver/dfg_modify.rs index fd016d272..47dabc6a5 100644 --- a/tket/src/modifier/modifier_resolver/dfg_modify.rs +++ b/tket/src/modifier/modifier_resolver/dfg_modify.rs @@ -15,7 +15,7 @@ use hugr::{ hugr::hugrmut::HugrMut, ops::{Call, Conditional, DFG, DataflowBlock, DataflowOpTrait, OpType, TailLoop}, std_extensions::collections::array::ArrayOpBuilder, - types::{FuncTypeBase, Signature, TypeArg, TypeRow}, + types::{FuncTypeBase, TypeArg, TypeRow}, }; use hugr_core::hugr::internal::PortgraphNodeMap; use petgraph::visit::{Topo, Walker}; @@ -38,11 +38,14 @@ impl ModifierResolver { // Modify the input/output nodes beforehand. self.modify_in_out_node(h, parent_node, new_dfg)?; + // Modify the children nodes. self.modify_dfg_children(h, parent_node, new_dfg)?; self.wire_control_to_output(h, parent_node, new_dfg)?; + self.connect_all(h, new_dfg, parent_node)?; + mem::swap(self.controls(), &mut controls); mem::swap(self.corresp_map(), &mut corresp_map); @@ -68,6 +71,7 @@ impl ModifierResolver { worklist.push_back(node_map.from_portgraph(old_n_id)); } } + self.with_worklist(worklist, |this| { while let Some(working_node) = this.worklist().pop_front() { this.modify_op(h, working_node, new_dfg)?; @@ -295,26 +299,22 @@ impl ModifierResolver { // if only dagger, just check signature // // Also, it may be better to check with the usage (how it is instantiated). - pub fn modify_fn_if_needed( + pub(crate) fn modify_fn_if_needed( &mut self, h: &mut impl HugrMut, func: N, - signature: &Signature, ) -> Result, ModifierResolverErrors> { let satisfies = ModifierFlags::from_metadata(h, func) .is_some_and(|flags| flags.satisfies(&self.modifiers)); + if !satisfies { - let in_out_match = signature.input == signature.output; - if in_out_match { - // If the flag is not set and the signature does not show an evident problem, skip the modification. - return Ok(None); - } + return Ok(None); } Ok(Some(self.modify_fn(h, func)?)) } /// Generates a new function modified by the combined modifier. - pub fn modify_fn( + pub(crate) fn modify_fn( &mut self, h: &mut impl HugrMut, func: N, @@ -356,6 +356,7 @@ impl ModifierResolver { Ok(new_function_node) } + /// NICOLA: seemes to be useless /// Generates a new function that does not essentially modify the function itself /// but changes the signature to match the modified calls. /// The generated function just calls the original function. @@ -393,7 +394,10 @@ impl ModifierResolver { let call_port = call.called_function_port(); let call_node = builder.add_child_node(call); - // connect wires + // connect wires: + // - first `offset` inputs are control arrays, passed through directly to output + // - remaining inputs are forwarded to the inner call + // - call outputs are forwarded to the remaining output ports for i in 0..offset { builder.hugr_mut().connect(in_node, i, out_node, i); } @@ -801,7 +805,7 @@ mod test { #[case::conditional_dagger(1, 1, foo_conditional, true)] #[case::cfg(1, 1, foo_cfg, false)] #[case::cfg_dagger(1, 1, foo_cfg, true)] - pub fn test_dfg_modify( + fn test_dfg_modify( #[case] t_num: usize, #[case] c_num: u64, #[case] foo: fn(&mut ModuleBuilder, usize) -> FuncID, diff --git a/tket/src/modifier/modifier_resolver/global_phase_modify.rs b/tket/src/modifier/modifier_resolver/global_phase_modify.rs index 4963952a1..2026b3acd 100644 --- a/tket/src/modifier/modifier_resolver/global_phase_modify.rs +++ b/tket/src/modifier/modifier_resolver/global_phase_modify.rs @@ -19,7 +19,7 @@ use crate::{ impl ModifierResolver { /// Modify a global phase operation. /// This returns the incoming port for the rotation of the modified operation. - pub fn modify_global_phase( + pub(crate) fn modify_global_phase( &mut self, n: N, new_fn: &mut impl Dataflow, @@ -103,7 +103,7 @@ impl ModifierResolver { } /// Delete all global phase operations in the subgraph reachable from the given entry points. -pub fn delete_phase( +pub(crate) fn delete_phase( h: &mut impl HugrMut, entry_points: impl IntoIterator, ) -> Result<(), ModifierResolverErrors> { @@ -151,7 +151,7 @@ mod tests { #[case(1, foo, true)] #[case(5, foo, false)] #[case(5, foo, true)] - pub fn test_global_phase_modify( + fn test_global_phase_modify( #[case] c_num: u64, #[case] foo: fn(&mut ModuleBuilder, usize) -> FuncID, #[case] dagger: bool, diff --git a/tket/src/modifier/modifier_resolver/tket_op_modify.rs b/tket/src/modifier/modifier_resolver/tket_op_modify.rs index f766069c6..14e3b8acb 100644 --- a/tket/src/modifier/modifier_resolver/tket_op_modify.rs +++ b/tket/src/modifier/modifier_resolver/tket_op_modify.rs @@ -14,7 +14,7 @@ use crate::{ impl ModifierResolver { /// Modify a TketOp operation. The returned `PortVector` contains the incoming and outgoing ports of the modified operation. /// Ancilla qubits are dirty qubits that are used to store intermediate results. - pub fn modify_tket_op( + pub(crate) fn modify_tket_op( &mut self, op_node: N, tket_op: TketOp, @@ -602,7 +602,7 @@ mod test { #[case(3, false)] #[case(3, true)] #[case(7, false)] - pub fn test_single_tket_op(#[case] c_num: u64, #[case] dagger: bool) { + fn test_single_tket_op(#[case] c_num: u64, #[case] dagger: bool) { for op in TketOp::iter() { let Some((size, has_angle)) = size(op) else { continue; diff --git a/tket/src/modifier/power.rs b/tket/src/modifier/power.rs index c16d1ecfa..753a575e4 100644 --- a/tket/src/modifier/power.rs +++ b/tket/src/modifier/power.rs @@ -13,7 +13,7 @@ pub struct ModifierPower; impl ModifierPower { /// Create a new ModifierPower. - pub fn new() -> Self { + fn new() -> Self { ModifierPower } } @@ -36,7 +36,7 @@ impl FromStr for ModifierPower { impl ModifierPower { /// signature for the power modifier. /// The Copyable bound of the second parameter is needed while constructing `TailLoop`. - pub fn signature() -> SignatureFunc { + pub(crate) fn signature() -> SignatureFunc { PolyFuncTypeRV::new( [ TypeParam::new_list_type(TypeBound::Linear),