diff --git a/TKET_VERSION b/TKET_VERSION index f674781b93..0228499b2d 100644 --- a/TKET_VERSION +++ b/TKET_VERSION @@ -1 +1 @@ -2.1.79 +2.1.80 diff --git a/pytket/binders/circuit/Circuit/add_op.cpp b/pytket/binders/circuit/Circuit/add_op.cpp index 3c49ac21d2..8d11a3bad0 100644 --- a/pytket/binders/circuit/Circuit/add_op.cpp +++ b/pytket/binders/circuit/Circuit/add_op.cpp @@ -146,6 +146,49 @@ Circuit *add_gate_method_any( } } } + + // check if there are wasm wires in the signature + + op_signature_t sig = new_op->get_signature(); + + unsigned count_wasm_sig = 0; + for (EdgeType e : sig) { + if (e == EdgeType::WASM) { + ++count_wasm_sig; + } + } + + unsigned count_wasm_args = 0; + for (UnitID uid : new_args) { + if (uid.type() == UnitType::WasmState) { + ++count_wasm_args; + } + } + + unsigned count_rng_sig = 0; + for (EdgeType e : sig) { + if (e == EdgeType::RNG) { + ++count_rng_sig; + } + } + + unsigned count_rng_args = 0; + for (UnitID uid : new_args) { + if (uid.type() == UnitType::RngState) { + ++count_rng_args; + } + } + + // potentially still effected by: + // https://github.com/Quantinuum/tket/issues/2154 + if (count_wasm_args != count_wasm_sig) { + new_args.push_back(WasmState(0)); + } + + if (count_rng_args != count_rng_sig) { + new_args.push_back(RngState(0)); + } + circ->add_op(new_op, new_args, opgroup); return circ; } @@ -577,6 +620,49 @@ void init_circuit_add_op(nb::class_ &c) { args.push_back(Bit(name, i)); } } + + // check if there are wasm wires in the signature + + op_signature_t sig = box.get_signature(); + + unsigned count_wasm_sig = 0; + for (EdgeType e : sig) { + if (e == EdgeType::WASM) { + ++count_wasm_sig; + } + } + + unsigned count_wasm_args = 0; + for (UnitID uid : args) { + if (uid.type() == UnitType::WasmState) { + ++count_wasm_args; + } + } + + unsigned count_rng_sig = 0; + for (EdgeType e : sig) { + if (e == EdgeType::RNG) { + ++count_rng_sig; + } + } + + unsigned count_rng_args = 0; + for (UnitID uid : args) { + if (uid.type() == UnitType::RngState) { + ++count_rng_args; + } + } + + // potentially still effected by: + // https://github.com/Quantinuum/tket/issues/2154 + if (count_wasm_args != count_wasm_sig) { + args.push_back(WasmState(0)); + } + + if (count_rng_args != count_rng_sig) { + args.push_back(RngState(0)); + } + return add_gate_method_any( circ, std::make_shared(box), args, kwargs); }, diff --git a/pytket/docs/changelog.md b/pytket/docs/changelog.md index 5bcc436670..a8a30435ac 100644 --- a/pytket/docs/changelog.md +++ b/pytket/docs/changelog.md @@ -1,5 +1,10 @@ # Changelog +# 2.17.0 (unreleased) + +Fixes: +- Fix handling of wasm function calls in circuit boxes + ## 2.16.0 (March 2026) Features: diff --git a/pytket/tests/classical_test.py b/pytket/tests/classical_test.py index 47f6f9850d..c3d59f09e8 100644 --- a/pytket/tests/classical_test.py +++ b/pytket/tests/classical_test.py @@ -69,7 +69,7 @@ reg_lt, reg_neq, ) -from pytket.passes import DecomposeClassicalExp +from pytket.passes import DecomposeBoxes, DecomposeClassicalExp curr_file_path = Path(__file__).resolve().parent @@ -904,6 +904,52 @@ def test_wasm_circuit_bits() -> None: DrawType = Callable[[SearchStrategy[T]], T] +def test_wasm_box() -> None: + + wasm_module = wasm.WasmFileHandler("testfile.wasm") + + A = BitRegister("A", 1) + B = BitRegister("B", 1) + + c0 = Circuit() + c0.add_c_register(A) + c0.add_c_register(B) + c0.add_wasm_to_reg("add_one", wasm_module, [A], [B]) + c0box = CircBox(c0) + + c1 = Circuit() + c1.add_c_register(A) + c1.add_c_register(B) + c1.add_circbox_regwise(c0box, qregs=[], cregs=[A, B]) + + DecomposeBoxes().apply(c1) + + assert c1.depth() == 1 + assert str(c1.get_commands()[0]) == "WASM A[0], B[0], _w[0];" + + +def test_rng_box() -> None: + + A = BitRegister("A", 32) + + c0 = Circuit() + c0.add_c_register(A) + c0.get_rng_num(A) + c0box = CircBox(c0) + + c1 = Circuit() + c1.add_c_register(A) + c1.add_circbox_regwise(c0box, qregs=[], cregs=[A]) + + DecomposeBoxes().apply(c1) + + assert c1.depth() == 1 + assert ( + str(c1.get_commands()[0]) + == "RNGNum A[0], A[1], A[2], A[3], A[4], A[5], A[6], A[7], A[8], A[9], A[10], A[11], A[12], A[13], A[14], A[15], A[16], A[17], A[18], A[19], A[20], A[21], A[22], A[23], A[24], A[25], A[26], A[27], A[28], A[29], A[30], A[31], _r[0];" + ) + + @strategies.composite def bit_register( draw: DrawType, diff --git a/pytket/tests/transform_test.py b/pytket/tests/transform_test.py index e6fae1f32d..2e71c8131a 100644 --- a/pytket/tests/transform_test.py +++ b/pytket/tests/transform_test.py @@ -1615,13 +1615,14 @@ def iregs(name: str, size: int) -> str: return "".join(f"{name}[{i}], " for i in range(size)) # both boxes should have the same args - # note that WASM and RNG states are not printed as part of the CircBox args assert c.depth() == 2 EXPECTED_BOX_ARGS = ( iregs("bound", 32) + iregs("index", 32) + iregs("num", 32) - + iregs("seed", 64)[:-2] + + iregs("seed", 64) + + iregs("_w", 1) + + iregs("_r", 1)[:-2] + ";" ) cmds = c.get_commands() diff --git a/tket/src/Circuit/Boxes.cpp b/tket/src/Circuit/Boxes.cpp index fa8151c4be..5e56037a6d 100644 --- a/tket/src/Circuit/Boxes.cpp +++ b/tket/src/Circuit/Boxes.cpp @@ -82,7 +82,11 @@ CircBox::CircBox(const Circuit &circ) : Box(OpType::CircBox) { } signature_ = op_signature_t(circ.n_qubits(), EdgeType::Quantum); op_signature_t bits(circ.n_bits(), EdgeType::Classical); + op_signature_t wasmwire(circ._number_of_wasm_wires, EdgeType::WASM); + op_signature_t rngwire(circ._number_of_rng_wires, EdgeType::RNG); signature_.insert(signature_.end(), bits.begin(), bits.end()); + signature_.insert(signature_.end(), wasmwire.begin(), wasmwire.end()); + signature_.insert(signature_.end(), rngwire.begin(), rngwire.end()); circ_ = std::make_shared(circ); } diff --git a/tket/test/src/test_wasm.cpp b/tket/test/src/test_wasm.cpp index e3f884737d..19f0d369c5 100644 --- a/tket/test/src/test_wasm.cpp +++ b/tket/test/src/test_wasm.cpp @@ -228,7 +228,12 @@ SCENARIO("generating circ with wasm") { CircBox circbox(u); Circuit major_circ(0, 1); - major_circ.add_box(circbox, {0}); + unit_vector_t new_args; + new_args.push_back(Bit(0)); + new_args.push_back(WasmState(0)); + new_args.push_back(WasmState(1)); + new_args.push_back(WasmState(2)); + major_circ.add_op(std::make_shared(circbox), new_args); REQUIRE(major_circ.depth() == 1); REQUIRE(major_circ.get_wasm_file_uid() == wasm_file);