From 4bec988c99151cc98221ce4b861b4d40663fdea9 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Thu, 6 Nov 2025 10:22:58 +0000 Subject: [PATCH 1/2] add basis arg to decomposition, move basis target creation to a new file Signed-off-by: Luca Mondada --- include/cudaq/Optimizer/Transforms/Passes.td | 24 ++++- lib/Optimizer/Transforms/BasisConversion.cpp | 64 +----------- lib/Optimizer/Transforms/CMakeLists.txt | 1 + lib/Optimizer/Transforms/Decomposition.cpp | 4 + .../DecompositionPatternSelection.cpp | 98 +++++++++++++++++++ .../Transforms/DecompositionPatterns.h | 18 ++++ 6 files changed, 142 insertions(+), 67 deletions(-) create mode 100644 lib/Optimizer/Transforms/DecompositionPatternSelection.cpp diff --git a/include/cudaq/Optimizer/Transforms/Passes.td b/include/cudaq/Optimizer/Transforms/Passes.td index 500b5cc53e6..3b83f1bef40 100644 --- a/include/cudaq/Optimizer/Transforms/Passes.td +++ b/include/cudaq/Optimizer/Transforms/Passes.td @@ -309,17 +309,31 @@ def DeadStoreRemoval : Pass<"dead-store-removal"> { def DecompositionPass: Pass<"decomposition", "mlir::ModuleOp"> { let summary = "Break down quantum operations."; let description = [{ - This pass performs decomposition over a set of operations by iteratively - applying decomposition patterns until either a fixpoint is reached or the - maximum number of iterations/rewrites is exhausted. Decomposition is - best-effort and does not guarantee that the entire IR is decomposed after - running this pass. + This pass decomposes quantum operations by iteratively applying rewrite + patterns until all quantum operations are in basis, a fixpoint is reached or + the maximum number of iterations is exhausted. + + When `basis` is specified, the pass automatically selects the set of rewrite + patterns, ensuring an acyclic pattern set targeting the specified basis. + Options `enable-patterns` and `disable-patterns` can further filter the + selected rewrite patterns. + + + The `basis` option takes a comma-separated list of quantum operations with + the format: `([ | n])?` + + Examples: + - `x` — Pauli-X without controls (aka `not`) + - `x(1)` — Pauli-X with one control (aka `cx`) + - `x(n)` — Pauli-X with unbounded controls + - `x,x(1)` — Both `not` and `cx` NOTE: The current implementation is conservative w.r.t global phase, which means no decomposition will take place under the presence of controlled `quake.apply` operations in the module. }]; let options = [ + ListOption<"basis", "basis", "std::string", "Set of basis operations">, ListOption<"disabledPatterns", "disable-patterns", "std::string", "Labels of decomposition patterns that should be filtered out">, ListOption<"enabledPatterns", "enable-patterns", "std::string", diff --git a/lib/Optimizer/Transforms/BasisConversion.cpp b/lib/Optimizer/Transforms/BasisConversion.cpp index 523dcf4fc1e..03658b6962b 100644 --- a/lib/Optimizer/Transforms/BasisConversion.cpp +++ b/lib/Optimizer/Transforms/BasisConversion.cpp @@ -30,66 +30,6 @@ namespace cudaq::opt { namespace { -struct BasisTarget : public ConversionTarget { - struct OperatorInfo { - StringRef name; - size_t numControls; - }; - - BasisTarget(MLIRContext &context, ArrayRef targetBasis) - : ConversionTarget(context) { - constexpr size_t unbounded = std::numeric_limits::max(); - - // Parse the list of target operations and build a set of legal operations - for (const std::string &targetInfo : targetBasis) { - StringRef option = targetInfo; - auto nameEnd = option.find_first_of('('); - auto name = option.take_front(nameEnd); - if (nameEnd < option.size()) - option = option.drop_front(nameEnd); - - auto &info = legalOperatorSet.emplace_back(OperatorInfo{name, 0}); - if (option.consume_front("(")) { - option = option.ltrim(); - if (option.consume_front("n")) - info.numControls = unbounded; - else - option.consumeInteger(10, info.numControls); - assert(option.trim().consume_front(")")); - } - } - - addLegalDialect(); - addDynamicallyLegalDialect([&](Operation *op) { - if (auto optor = dyn_cast(op)) { - auto name = optor->getName().stripDialect(); - for (auto info : legalOperatorSet) { - if (info.name != name) - continue; - if (info.numControls == unbounded || - optor.getControls().size() == info.numControls) - return info.numControls == optor.getControls().size(); - } - return false; - } - - // Handle quake.exp_pauli. - if (isa(op)) { - // If the target defines it as a legal op, return true, else false. - return std::find_if(legalOperatorSet.begin(), legalOperatorSet.end(), - [](auto &&el) { return el.name == "exp_pauli"; }) != - legalOperatorSet.end(); - } - - return true; - }); - } - - SmallVector legalOperatorSet; -}; - //===----------------------------------------------------------------------===// // Pass implementation //===----------------------------------------------------------------------===// @@ -146,7 +86,7 @@ struct BasisConversion return; // Setup target and patterns - BasisTarget target(getContext(), basis); + auto target = cudaq::createBasisTarget(getContext(), basis); RewritePatternSet owningPatterns(&getContext()); cudaq::populateWithAllDecompositionPatterns(owningPatterns); auto patterns = FrozenRewritePatternSet(std::move(owningPatterns), @@ -155,7 +95,7 @@ struct BasisConversion // Process kernels in parallel LogicalResult rewriteResult = failableParallelForEach( module.getContext(), kernels, [&target, &patterns](Operation *op) { - return applyFullConversion(op, target, patterns); + return applyFullConversion(op, *target, patterns); }); if (failed(rewriteResult)) diff --git a/lib/Optimizer/Transforms/CMakeLists.txt b/lib/Optimizer/Transforms/CMakeLists.txt index 88d28fb584d..96b794c167f 100644 --- a/lib/Optimizer/Transforms/CMakeLists.txt +++ b/lib/Optimizer/Transforms/CMakeLists.txt @@ -25,6 +25,7 @@ add_cudaq_library(OptTransforms DeadStoreRemoval.cpp Decomposition.cpp DecompositionPatterns.cpp + DecompositionPatternSelection.cpp DelayMeasurements.cpp DeleteStates.cpp DistributedDeviceCall.cpp diff --git a/lib/Optimizer/Transforms/Decomposition.cpp b/lib/Optimizer/Transforms/Decomposition.cpp index 1768b5f3bd6..4fb5f9a0c94 100644 --- a/lib/Optimizer/Transforms/Decomposition.cpp +++ b/lib/Optimizer/Transforms/Decomposition.cpp @@ -8,12 +8,16 @@ #include "DecompositionPatterns.h" #include "cudaq/Frontend/nvqpp/AttributeNames.h" +#include "cudaq/Optimizer/Dialect/CC/CCDialect.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" #include "cudaq/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Threading.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; diff --git a/lib/Optimizer/Transforms/DecompositionPatternSelection.cpp b/lib/Optimizer/Transforms/DecompositionPatternSelection.cpp new file mode 100644 index 00000000000..e8564e862c6 --- /dev/null +++ b/lib/Optimizer/Transforms/DecompositionPatternSelection.cpp @@ -0,0 +1,98 @@ +/******************************************************************************* + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +#include "DecompositionPatterns.h" +#include "cudaq/Optimizer/Dialect/CC/CCDialect.h" +#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h" +#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include +#include + +using namespace mlir; + +namespace { + +//===----------------------------------------------------------------------===// +// ConversionTarget and OperatorInfo, parsed from target basis strings such as +// ["x", "x(1)", "z"] +//===----------------------------------------------------------------------===// + +struct OperatorInfo { + StringRef name; + size_t numControls; + + OperatorInfo(StringRef infoStr) : name(), numControls(0) { + auto nameEnd = infoStr.find_first_of('('); + name = infoStr.take_front(nameEnd); + if (nameEnd < infoStr.size()) + infoStr = infoStr.drop_front(nameEnd); + + if (infoStr.consume_front("(")) { + infoStr = infoStr.ltrim(); + if (infoStr.consume_front("n")) + numControls = std::numeric_limits::max(); + else + infoStr.consumeInteger(10, numControls); + assert(infoStr.trim().consume_front(")")); + } + } +}; + +struct BasisTarget : public ConversionTarget { + + BasisTarget(MLIRContext &context, ArrayRef targetBasis) + : ConversionTarget(context) { + constexpr size_t unbounded = std::numeric_limits::max(); + + // Parse the list of target operations and build a set of legal operations + for (const std::string &targetInfo : targetBasis) { + legalOperatorSet.emplace_back(targetInfo); + } + + addLegalDialect(); + addDynamicallyLegalDialect([&](Operation *op) { + if (auto optor = dyn_cast(op)) { + auto name = optor->getName().stripDialect(); + for (auto info : legalOperatorSet) { + if (info.name != name) + continue; + if (info.numControls == unbounded || + optor.getControls().size() == info.numControls) + return info.numControls == optor.getControls().size(); + } + return false; + } + + // Handle quake.exp_pauli. + if (isa(op)) { + // If the target defines it as a legal op, return true, else false. + return std::find_if(legalOperatorSet.begin(), legalOperatorSet.end(), + [](auto &&el) { return el.name == "exp_pauli"; }) != + legalOperatorSet.end(); + } + + return true; + }); + } + + SmallVector legalOperatorSet; +}; + +} // namespace + +std::unique_ptr +cudaq::createBasisTarget(MLIRContext &context, + ArrayRef targetBasis) { + return std::make_unique(context, targetBasis); +} diff --git a/lib/Optimizer/Transforms/DecompositionPatterns.h b/lib/Optimizer/Transforms/DecompositionPatterns.h index 9cb68e40522..27a9c61c203 100644 --- a/lib/Optimizer/Transforms/DecompositionPatterns.h +++ b/lib/Optimizer/Transforms/DecompositionPatterns.h @@ -11,6 +11,8 @@ #include "common/Registry.h" #include "llvm/ADT/ArrayRef.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/DialectConversion.h" +#include namespace mlir { class RewritePatternSet; @@ -67,4 +69,20 @@ class DecompositionPattern : public mlir::OpRewritePattern { void populateWithAllDecompositionPatterns(mlir::RewritePatternSet &patterns); +/// Create a conversion target parsed from a target basis string. +/// +/// The `targetBasis` should be made of strings of the form: +/// +/// ``` +/// (`(` [ | `n`] `)` )? +/// ``` +/// +/// The returned conversion target will accept operations in the MLIR dialects +/// arith::ArithDialect, cf::ControlFlowDialect, cudaq::cc::CCDialect, +/// func::FuncDialect, and math::MathDialect, as well as operations in the +/// quake::QuakeDialect that appear in `targetBasis`. +std::unique_ptr +createBasisTarget(mlir::MLIRContext &context, + mlir::ArrayRef targetBasis); + } // namespace cudaq From deaea768620b040ff7f6ae2a5c6b87474be79674 Mon Sep 17 00:00:00 2001 From: Luca Mondada Date: Thu, 6 Nov 2025 17:49:00 +0000 Subject: [PATCH 2/2] add DecompositionPatternSelection, use it in Decomposition and BasisConversion Signed-off-by: Luca Mondada --- include/cudaq/Optimizer/Transforms/Passes.td | 28 +- lib/Optimizer/Transforms/BasisConversion.cpp | 13 +- lib/Optimizer/Transforms/Decomposition.cpp | 27 +- .../DecompositionPatternSelection.cpp | 254 +++++++++++- .../Transforms/DecompositionPatterns.cpp | 7 +- .../Transforms/DecompositionPatterns.h | 19 + unittests/Optimizer/CMakeLists.txt | 6 + .../DecompositionPatternSelectionTest.cpp | 360 ++++++++++++++++++ .../Optimizer/DecompositionPatternsTest.cpp | 2 +- 9 files changed, 695 insertions(+), 21 deletions(-) create mode 100644 unittests/Optimizer/DecompositionPatternSelectionTest.cpp diff --git a/include/cudaq/Optimizer/Transforms/Passes.td b/include/cudaq/Optimizer/Transforms/Passes.td index 3b83f1bef40..80807ada2cf 100644 --- a/include/cudaq/Optimizer/Transforms/Passes.td +++ b/include/cudaq/Optimizer/Transforms/Passes.td @@ -136,6 +136,15 @@ def BasisConversionPass : Pass<"basis-conversion", "mlir::ModuleOp"> { - `x(1)` means targeting pauli-x operations with one control (aka, `cx`) - `x(n)` means targeting pauli-x operation with unbounded number of controls - `x,x(1)` means targeting both `not` and `cx` operations + + The pass automatically selects the set of rewrite patterns, ensuring every + gate is decomposed to the specified basis in a unique way. Option + `disable-patterns` can be used to filter the selected rewrite patterns. + Option `enable-patterns` can be used to override the automatic pattern + selection. + + If no `basis` is specified or the pass cannot decompose all operations to + the specified basis, the pass application will fail. }]; let options = [ ListOption<"basis", "basis", "std::string", "Set of basis operations">, @@ -314,15 +323,22 @@ def DecompositionPass: Pass<"decomposition", "mlir::ModuleOp"> { the maximum number of iterations is exhausted. When `basis` is specified, the pass automatically selects the set of rewrite - patterns, ensuring an acyclic pattern set targeting the specified basis. - Options `enable-patterns` and `disable-patterns` can further filter the - selected rewrite patterns. + patterns, ensuring every gate is decomposed to the specified basis in a + unique way. + + ## Options + The following options are available and are all optional: - The `basis` option takes a comma-separated list of quantum operations with - the format: `([ | n])?` + - `disable-patterns`: used to filter out specific rewrite patterns from the + selection. + - `enable-patterns`: when set, overrides the automatic pattern selection. + - `basis`: takes a comma-separated list of quantum operations with the + format: `([ | n])?` - Examples: + If no `basis` is specified, as many patterns as possible are applied. + + Examples of valid `basis` values: - `x` — Pauli-X without controls (aka `not`) - `x(1)` — Pauli-X with one control (aka `cx`) - `x(n)` — Pauli-X with unbounded controls diff --git a/lib/Optimizer/Transforms/BasisConversion.cpp b/lib/Optimizer/Transforms/BasisConversion.cpp index 03658b6962b..3791488766f 100644 --- a/lib/Optimizer/Transforms/BasisConversion.cpp +++ b/lib/Optimizer/Transforms/BasisConversion.cpp @@ -88,9 +88,16 @@ struct BasisConversion // Setup target and patterns auto target = cudaq::createBasisTarget(getContext(), basis); RewritePatternSet owningPatterns(&getContext()); - cudaq::populateWithAllDecompositionPatterns(owningPatterns); - auto patterns = FrozenRewritePatternSet(std::move(owningPatterns), - disabledPatterns, enabledPatterns); + FrozenRewritePatternSet patterns; + if (enabledPatterns.empty()) { + cudaq::selectDecompositionPatterns(owningPatterns, basis, + disabledPatterns); + patterns = FrozenRewritePatternSet(std::move(owningPatterns)); + } else { + cudaq::populateWithAllDecompositionPatterns(owningPatterns); + patterns = FrozenRewritePatternSet(std::move(owningPatterns), + disabledPatterns, enabledPatterns); + } // Process kernels in parallel LogicalResult rewriteResult = failableParallelForEach( diff --git a/lib/Optimizer/Transforms/Decomposition.cpp b/lib/Optimizer/Transforms/Decomposition.cpp index 4fb5f9a0c94..04a21b306a2 100644 --- a/lib/Optimizer/Transforms/Decomposition.cpp +++ b/lib/Optimizer/Transforms/Decomposition.cpp @@ -8,16 +8,13 @@ #include "DecompositionPatterns.h" #include "cudaq/Frontend/nvqpp/AttributeNames.h" -#include "cudaq/Optimizer/Dialect/CC/CCDialect.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" #include "cudaq/Optimizer/Transforms/Passes.h" -#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" -#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Threading.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; @@ -44,9 +41,25 @@ struct Decomposition /// execution. LogicalResult initialize(MLIRContext *context) override { RewritePatternSet owningPatterns(context); - cudaq::populateWithAllDecompositionPatterns(owningPatterns); - patterns = FrozenRewritePatternSet(std::move(owningPatterns), - disabledPatterns, enabledPatterns); + + if (!basis.empty() && !enabledPatterns.empty()) { + mlir::emitWarning( + mlir::UnknownLoc::get(context), + "DecompositionPass: basis is ignored when enabledPatterns is " + "specified"); + } + + if (!basis.empty() && enabledPatterns.empty()) { + // Restrict to patterns useful for the target basis + cudaq::selectDecompositionPatterns(owningPatterns, basis, + disabledPatterns); + patterns = FrozenRewritePatternSet(std::move(owningPatterns)); + } else { + cudaq::populateWithAllDecompositionPatterns(owningPatterns); + patterns = FrozenRewritePatternSet(std::move(owningPatterns), + disabledPatterns, enabledPatterns); + } + return success(); } diff --git a/lib/Optimizer/Transforms/DecompositionPatternSelection.cpp b/lib/Optimizer/Transforms/DecompositionPatternSelection.cpp index e8564e862c6..b91cb4f80c2 100644 --- a/lib/Optimizer/Transforms/DecompositionPatternSelection.cpp +++ b/lib/Optimizer/Transforms/DecompositionPatternSelection.cpp @@ -14,8 +14,18 @@ #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/DialectConversion.h" +#include +#include +#include #include +#include +#include + #include +#include +#include +#include +#include using namespace mlir; @@ -45,6 +55,10 @@ struct OperatorInfo { assert(infoStr.trim().consume_front(")")); } } + + bool operator==(const OperatorInfo &other) const { + return name == other.name && numControls == other.numControls; + } }; struct BasisTarget : public ConversionTarget { @@ -69,7 +83,7 @@ struct BasisTarget : public ConversionTarget { continue; if (info.numControls == unbounded || optor.getControls().size() == info.numControls) - return info.numControls == optor.getControls().size(); + return true; } return false; } @@ -91,8 +105,246 @@ struct BasisTarget : public ConversionTarget { } // namespace +//===----------------------------------------------------------------------===// +// std::hash specialization for OperatorInfo +//===----------------------------------------------------------------------===// + +namespace std { +template <> +struct hash { + size_t operator()(const OperatorInfo &info) const { + return llvm::hash_combine(info.name, info.numControls); + } +}; +} // namespace std + +namespace { + +// Computes a hash of the given unordered set using the hashes of the elements +// in the set. +template +size_t computeSetHash(const std::unordered_set &set) { + std::vector hashes; + for (const auto &elem : set) { + hashes.push_back(std::hash()(elem)); + } + std::sort(hashes.begin(), hashes.end()); + return llvm::hash_combine_range(hashes.begin(), hashes.end()); +} + +//===----------------------------------------------------------------------===// +// Decomposition Graph for Pattern Selection +//===----------------------------------------------------------------------===// + +/// DecompositionGraph constructs a hypergraph of decomposition patterns based +/// on pattern metadata and performs backward traversal to select patterns that +/// decompose to a basis. +/// +/// Specifically, the decomposition graph is defined as a hypergraph in which +/// nodes are gate types and hyperedges are rewrite patterns connecting the +/// matched gate type to all newly inserted gate types. +class DecompositionGraph { +public: + DecompositionGraph() = default; + + /// Construct a decomposition pattern graph from a collection of pattern + /// types. + DecompositionGraph( + llvm::StringMap> + patterns) + : patternTypes(std::move(patterns)) { + // Build the graph from pattern metadata + for (const auto &pattern : patternTypes) { + auto targetGates = pattern.getValue()->getTargetOps(); + for (const auto &targetGate : targetGates) { + targetToPatterns[targetGate].push_back(pattern.getKey().str()); + } + } + } + + /// Create a DecompositionGraph from the registry entries. + static DecompositionGraph fromRegistry() { + llvm::StringMap> patterns; + for (const auto &patternType : + cudaq::DecompositionPatternType::RegistryType::entries()) { + patterns.insert({patternType.getName(), patternType.instantiate()}); + } + return DecompositionGraph(std::move(patterns)); + } + + /// Return all patterns that have the given gate as one of their targets. + /// + /// @param gate The gate to find incoming patterns for + /// @return A vector of pattern names (StringRef) whose targets include the + /// given gate + llvm::ArrayRef incomingPatterns(const OperatorInfo &gate) const { + static const llvm::SmallVector empty; + auto it = targetToPatterns.find(gate); + return it == targetToPatterns.end() ? empty : it->second; + } + + /// Select subset of patterns relevant to decomposing to the given basis + /// gates. + /// + /// The result of the pattern selection are cached, so that successive calls + /// with the same arguments will be O(1). + /// + /// @param patterns The pattern set to add the selected patterns to + /// @param basisGates The basis gates to decompose to + /// @param disabledPatterns The patterns to disable + void selectPatterns(RewritePatternSet &patterns, + const std::unordered_set &basisGates, + const std::unordered_set &disabledPatterns) { + auto hashVal = llvm::hash_combine(computeSetHash(basisGates), + computeSetHash(disabledPatterns)); + + if (!patternSelectionCache.contains(hashVal)) { + patternSelectionCache[hashVal] = + computePatternSelection(basisGates, disabledPatterns); + } + + for (const auto &patternName : patternSelectionCache[hashVal]) { + const auto &pattern = getPatternType(patternName); + patterns.add(pattern->create(patterns.getContext())); + } + } + +private: + const std::unique_ptr & + getPatternType(const std::string &patternName) const { + auto patternType = patternTypes.find(patternName); + assert(patternType != patternTypes.end() && "pattern not found"); + return patternType->getValue(); + } + + /// Use Dijkstra's algorithm to compute the shortest decomposition path from + /// every reachable gate type to the basis gates. + /// + /// This selects a unique decomposition path for each gate in the past of the + /// basis gates in the decomposition graph, such that the number of patterns + /// applied is minimized. `disabledPatterns` are ignored during the traversal + /// and hence never selected. + /// + /// @param basisGates The set of basis gates to decompose to + /// @param disabledPatterns The patterns to disable + /// @return A vector of selected pattern names + std::vector computePatternSelection( + const std::unordered_set &basisGates, + const std::unordered_set &disabledPatterns) const { + + // An element in the priority queue of the Dijkstra algorithm (ordered by + // smallest distance) + struct GateDistancePair { + OperatorInfo gate; + size_t distance; + std::optional outgoingPattern; + + bool operator<(const GateDistancePair &other) const { + // We want to order by smallest distance, so we invert the comparison + return distance > other.distance; + } + }; + + // Map: visited gate -> distance from the basis gates + std::unordered_map visitedGates; + // The set of selected patterns to return + std::vector selectedPatterns; + // Priority queue of gates to visit, sorted by smallest distance from the + // basis gates + std::priority_queue gatesToVisit; + + // Initialize the priority queue with the basis gates + for (const auto &gate : basisGates) { + gatesToVisit.push({gate, 0, std::nullopt}); + } + + /// Compute the maximum distance from a pattern's targets to the basis + /// gates. + auto getPatternDist = [&](const auto &pattern) { + auto targetGates = pattern->getTargetOps(); + std::vector targetDistances; + for (const auto &targetGate : targetGates) { + if (visitedGates.count(targetGate)) { + targetDistances.push_back(visitedGates.at(targetGate)); + } else { + targetDistances.push_back(std::numeric_limits::max()); + } + } + return *std::max_element(targetDistances.begin(), targetDistances.end()); + }; + + while (!gatesToVisit.empty()) { + auto [gate, dist, outgoingPattern] = gatesToVisit.top(); + gatesToVisit.pop(); + + auto [_, success] = visitedGates.insert({gate, dist}); + if (!success) { + // Gate already visited + continue; + } + + if (outgoingPattern.has_value()) { + selectedPatterns.push_back(*outgoingPattern); + } + + for (const auto &patternName : incomingPatterns(gate)) { + if (disabledPatterns.contains(patternName)) { + // Ignore disabled patterns + continue; + } + const auto &pattern = getPatternType(patternName); + size_t dist = getPatternDist(pattern); + if (dist < std::numeric_limits::max()) { + gatesToVisit.push({pattern->getSourceOp(), dist + 1, patternName}); + } + } + } + + return selectedPatterns; + } + + //===--------------------------------------------------------------------===// + // Data structures for the graph definition + //===--------------------------------------------------------------------===// + + /// All pattern types in the graph, keyed by pattern name. + llvm::StringMap> + patternTypes; + + /// Map: target gate -> patterns that produce it + std::unordered_map> targetToPatterns; + + //===--------------------------------------------------------------------===// + // Other data (cache) + //===--------------------------------------------------------------------===// + + /// Cache for `selectPatterns`: hash of basis gates, disabled patterns, + /// enabled patterns -> selected patterns + std::unordered_map> patternSelectionCache; +}; + +} // namespace + std::unique_ptr cudaq::createBasisTarget(MLIRContext &context, ArrayRef targetBasis) { return std::make_unique(context, targetBasis); } + +void cudaq::selectDecompositionPatterns( + RewritePatternSet &patterns, ArrayRef targetBasis, + ArrayRef disabledPatterns) { + // Static local graph - constructed once and reused + static DecompositionGraph graph = DecompositionGraph::fromRegistry(); + + BasisTarget target(*patterns.getContext(), targetBasis); + + // Convert targetBasis, disabledPatterns and enabledPatterns to sets for O(1) + // lookup + std::unordered_set basisGatesSet( + target.legalOperatorSet.begin(), target.legalOperatorSet.end()); + std::unordered_set disabledPatternsSet(disabledPatterns.begin(), + disabledPatterns.end()); + + return graph.selectPatterns(patterns, basisGatesSet, disabledPatternsSet); +} diff --git a/lib/Optimizer/Transforms/DecompositionPatterns.cpp b/lib/Optimizer/Transforms/DecompositionPatterns.cpp index 957cf6fc27b..bc4f21d8297 100644 --- a/lib/Optimizer/Transforms/DecompositionPatterns.cpp +++ b/lib/Optimizer/Transforms/DecompositionPatterns.cpp @@ -336,7 +336,8 @@ LogicalResult checkAndExtractControls(quake::OperatorInterface op, // TODO: The decomposition patterns "SToR1", "TToR1", "R1ToU3", "U3ToRotations" // can handle arbitrary number of controls, but currently metadata cannot -// capture this. The pattern types therefore only advertise them for 0 controls. +// capture this. The pattern types therefore only advertise them for a fixed +// number of controls (1 for "SToR1" and "TToR1", 0 for the rest). //===----------------------------------------------------------------------===// // HOp decompositions @@ -794,7 +795,7 @@ struct SToR1 : public cudaq::DecompositionPattern { return success(); } }; -REGISTER_DECOMPOSITION_PATTERN(SToR1, "s", "r1"); +REGISTER_DECOMPOSITION_PATTERN(SToR1, "s(1)", "r1(1)"); //===----------------------------------------------------------------------===// // TOp decompositions @@ -875,7 +876,7 @@ struct TToR1 : public cudaq::DecompositionPattern { return success(); } }; -REGISTER_DECOMPOSITION_PATTERN(TToR1, "t", "r1"); +REGISTER_DECOMPOSITION_PATTERN(TToR1, "t(1)", "r1(1)"); //===----------------------------------------------------------------------===// // XOp decompositions diff --git a/lib/Optimizer/Transforms/DecompositionPatterns.h b/lib/Optimizer/Transforms/DecompositionPatterns.h index 27a9c61c203..c951abe04d2 100644 --- a/lib/Optimizer/Transforms/DecompositionPatterns.h +++ b/lib/Optimizer/Transforms/DecompositionPatterns.h @@ -67,6 +67,25 @@ class DecompositionPattern : public mlir::OpRewritePattern { void initialize() { this->setDebugName(PatternType().getPatternName()); } }; +/// Select subset of patterns relevant to decomposing to the given target basis. +/// +/// The result of the pattern selection are cached, so that successive calls +/// with the same arguments will be O(1). +/// +/// @param patterns The pattern set to add the selected patterns to +/// @param basisGates The basis gates to decompose to +/// @param disabledPatterns The patterns to disable +/// +/// A subset of the decomposition patterns is selected such that: +/// - for every gate that can be decomposed to the target basis, the sequence of +/// decomposition to the target basis is unique. +/// - when more than one decomposition would exist, the one that requires the +/// fewest applications of patterns is chosen. +/// - `disabledPatterns` are never selected +void selectDecompositionPatterns(mlir::RewritePatternSet &patterns, + llvm::ArrayRef targetBasis, + llvm::ArrayRef disabledPatterns); + void populateWithAllDecompositionPatterns(mlir::RewritePatternSet &patterns); /// Create a conversion target parsed from a target basis string. diff --git a/unittests/Optimizer/CMakeLists.txt b/unittests/Optimizer/CMakeLists.txt index 3ff19b74813..379622f5f17 100644 --- a/unittests/Optimizer/CMakeLists.txt +++ b/unittests/Optimizer/CMakeLists.txt @@ -12,6 +12,7 @@ add_executable(OptimizerUnitTests HermitianTrait.cpp FactoryMergeModuleTest.cpp DecompositionPatternsTest.cpp + DecompositionPatternSelectionTest.cpp ) target_link_libraries(OptimizerUnitTests @@ -30,8 +31,13 @@ target_link_libraries(OptimizerUnitTests target_include_directories(OptimizerUnitTests PRIVATE ${CMAKE_SOURCE_DIR}/runtime + PRIVATE ${CMAKE_SOURCE_DIR}/lib/Optimizer/Transforms ) +# MLIR/LLVM is built without RTTI, so we must disable it here too +# to avoid linker errors (undefined typeinfo for base classes) +target_compile_options(OptimizerUnitTests PRIVATE -fno-rtti) + gtest_discover_tests(OptimizerUnitTests) add_executable(test_quake_synth QuakeSynthTester.cpp) diff --git a/unittests/Optimizer/DecompositionPatternSelectionTest.cpp b/unittests/Optimizer/DecompositionPatternSelectionTest.cpp new file mode 100644 index 00000000000..d201c109adf --- /dev/null +++ b/unittests/Optimizer/DecompositionPatternSelectionTest.cpp @@ -0,0 +1,360 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +// Include the implementation file that we are testing +#include "DecompositionPatternSelection.cpp" + +#include "DecompositionPatterns.h" +#include "cudaq/Optimizer/Builder/Factory.h" +#include "cudaq/Optimizer/Dialect/CC/CCDialect.h" +#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h" +#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h" +#include "cudaq/Optimizer/Transforms/Passes.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/PassManager.h" + +#include +#include +#include +#include + +using namespace mlir; + +namespace { +/// A mock pattern class +class PatternTest : public mlir::RewritePattern { +public: + PatternTest(llvm::StringRef patternName, MLIRContext *context) + : mlir::RewritePattern(patternName, 0, context, {}) { + setDebugName(patternName); + } +}; + +/// A mock pattern type for testing. +class PatternTypeTest : public cudaq::DecompositionPatternType { +public: + PatternTypeTest(llvm::StringRef patternName, llvm::StringRef sourceOp, + std::vector targetOps) + : patternName(patternName), sourceOp(sourceOp), targetOps(targetOps) {} + + llvm::StringRef getSourceOp() const override { return sourceOp; } + + llvm::ArrayRef getTargetOps() const override { + return targetOps; + } + + llvm::StringRef getPatternName() const override { return patternName; } + + std::unique_ptr + create(mlir::MLIRContext *context, + mlir::PatternBenefit benefit = 1) const override { + return std::make_unique(patternName, context); + }; + +private: + llvm::StringRef patternName; + llvm::StringRef sourceOp; + std::vector targetOps; +}; + +/// Create a test decomposition graph with the following patterns. The arrow +// "->" should be read as "decomposes to". +/// x -> x(1) -> x(2) -> x(3) +/// y -> y(1) -> y(2) -> y(3) +/// z -> z(1)+x(1) +/// z(1) -> z(2)+x(2) +/// z(2) -> z(3)+x(3) +/// z -> h -> z(1) +DecompositionGraph createTestGraph() { + // Decompose x -> x(1) -> x(2) -> x(3) + auto pattern_x1 = std::make_unique( + "pattern_x1", "x", std::vector{"x(1)"}); + auto pattern_x2 = std::make_unique( + "pattern_x2", "x(1)", std::vector{"x(2)"}); + auto pattern_x3 = std::make_unique( + "pattern_x3", "x(2)", std::vector{"x(3)"}); + + // Decompose y -> y(1) -> y(2) -> y(3) + auto pattern_y1 = std::make_unique( + "pattern_y1", "y", std::vector{"y(1)"}); + auto pattern_y2 = std::make_unique( + "pattern_y2", "y(1)", std::vector{"y(2)"}); + auto pattern_y3 = std::make_unique( + "pattern_y3", "y(2)", std::vector{"y(3)"}); + + // Decompose z similarly to x and y, but it creates "side effects" in the form + // of extra x gates. + auto pattern_z1 = std::make_unique( + "pattern_z1", "z", std::vector{"z(1)", "x(1)"}); + auto pattern_z2 = std::make_unique( + "pattern_z2", "z(1)", std::vector{"z(2)", "x(2)"}); + auto pattern_z3 = std::make_unique( + "pattern_z3", "z(2)", std::vector{"z(3)", "x(3)"}); + + // Another way to decompose z -> z(1), is side-effect free, but requires an + // extra pattern. + // z -> h -> z(1) + auto pattern_zh1 = std::make_unique( + "pattern_zh1", "z", std::vector{"h"}); + auto pattern_zh2 = std::make_unique( + "pattern_zh2", "h", std::vector{"z(1)"}); + + llvm::StringMap> patterns; + patterns.insert({pattern_x1->getPatternName(), std::move(pattern_x1)}); + patterns.insert({pattern_x2->getPatternName(), std::move(pattern_x2)}); + patterns.insert({pattern_x3->getPatternName(), std::move(pattern_x3)}); + patterns.insert({pattern_y1->getPatternName(), std::move(pattern_y1)}); + patterns.insert({pattern_y2->getPatternName(), std::move(pattern_y2)}); + patterns.insert({pattern_y3->getPatternName(), std::move(pattern_y3)}); + patterns.insert({pattern_z1->getPatternName(), std::move(pattern_z1)}); + patterns.insert({pattern_z2->getPatternName(), std::move(pattern_z2)}); + patterns.insert({pattern_z3->getPatternName(), std::move(pattern_z3)}); + patterns.insert({pattern_zh1->getPatternName(), std::move(pattern_zh1)}); + patterns.insert({pattern_zh2->getPatternName(), std::move(pattern_zh2)}); + return DecompositionGraph(std::move(patterns)); +} + +class BaseDecompositionPatternSelectionTest : public ::testing::Test { +protected: + void SetUp() override { + context = std::make_unique(); + context->loadDialect(); + // set up graph in children classes + } + + /// Whether an operation of type Op with nCtrls control qubits is legal on + /// the target. + template + bool isLegal(const std::unique_ptr &target, + unsigned nCtrls = 0) { + // Create a module with a single operation of type Op + auto loc = UnknownLoc::get(context.get()); + auto module = ModuleOp::create(loc); + OpBuilder builder(module.getBodyRegion()); + + // Create a function to hold the operation + auto funcType = builder.getFunctionType({}, {}); + auto func = builder.create(loc, "test_func", funcType); + auto *entryBlock = func.addEntryBlock(); + builder.setInsertionPointToStart(entryBlock); + + // Create n_qubits qubit wires + SmallVector controls; + auto wireType = quake::WireType::get(context.get()); + for (unsigned i = 0; i < nCtrls; ++i) { + auto qubit = builder.create(loc, wireType); + controls.push_back(qubit.getResult()); + } + auto targetQubit = builder.create(loc, wireType); + SmallVector targets{targetQubit}; + + // Create the operation of type Op with the qubits + auto op = builder.create(loc, controls, targets); + + // Get the operation pointer and check if it is legal + Operation *operation_ptr = op.getOperation(); + return target->isLegal(operation_ptr).has_value(); + } + + /// Run `selectPatterns` on the current decomposition graph and return the + /// selected patterns as a vector of sorted pattern names. + std::vector + selectPatterns(const std::vector &targetBasis, + const std::unordered_set &disabledPatterns = {}) { + auto convertToOperatorInfoSet = + [](const std::vector &targetBasis) { + std::unordered_set operatorInfoSet; + for (const auto &target : targetBasis) { + operatorInfoSet.insert(OperatorInfo(target)); + } + return operatorInfoSet; + }; + + RewritePatternSet patterns(context.get()); + graph.selectPatterns(patterns, convertToOperatorInfoSet(targetBasis), + disabledPatterns); + + std::vector selectedPatterns; + for (const auto &pattern : patterns.getNativePatterns()) { + selectedPatterns.push_back(pattern->getDebugName().str()); + } + std::sort(selectedPatterns.begin(), selectedPatterns.end()); + return selectedPatterns; + } + + std::unique_ptr context; + DecompositionGraph graph; +}; + +/// Run pattern selection tests on a dummy graph. +class DummyDecompositionPatternSelectionTest + : public BaseDecompositionPatternSelectionTest { +protected: + void SetUp() override { + BaseDecompositionPatternSelectionTest::SetUp(); + graph = createTestGraph(); + } +}; + +/// Run pattern selection tests on the full decomposition graph. +class FullDecompositionPatternSelectionTest + : public BaseDecompositionPatternSelectionTest { +protected: + void SetUp() override { + BaseDecompositionPatternSelectionTest::SetUp(); + graph = DecompositionGraph::fromRegistry(); + } +}; + +//===----------------------------------------------------------------------===// +// Test BasisTarget +//===----------------------------------------------------------------------===// + +TEST_F(BaseDecompositionPatternSelectionTest, BasisTargetParsesSimpleGates) { + std::vector basis{"h", "t", "x"}; + auto target = cudaq::createBasisTarget(*context, basis); + EXPECT_TRUE(isLegal(target)); + EXPECT_TRUE(isLegal(target)); + EXPECT_TRUE(isLegal(target)); + + EXPECT_FALSE(isLegal(target, 1)); + EXPECT_FALSE(isLegal(target, 1)); + EXPECT_FALSE(isLegal(target, 1)); + EXPECT_FALSE(isLegal(target)); +} + +TEST_F(BaseDecompositionPatternSelectionTest, + BasisTargetParsesControlledGates) { + std::vector basis{"x(1)", "z(2)"}; + auto target = cudaq::createBasisTarget(*context, basis); + EXPECT_TRUE(isLegal(target, 1)); + EXPECT_TRUE(isLegal(target, 2)); + + EXPECT_FALSE(isLegal(target)); + EXPECT_FALSE(isLegal(target, 2)); + EXPECT_FALSE(isLegal(target)); + EXPECT_FALSE(isLegal(target, 1)); + EXPECT_FALSE(isLegal(target, 3)); +} + +TEST_F(BaseDecompositionPatternSelectionTest, + BasisTargetParsesArbitraryControls) { + std::vector basis{"x(n)"}; + auto target = cudaq::createBasisTarget(*context, basis); + + EXPECT_TRUE(isLegal(target, 0)); + EXPECT_TRUE(isLegal(target, 1)); + EXPECT_TRUE(isLegal(target, 2)); + EXPECT_TRUE(isLegal(target, 10)); +} + +//===----------------------------------------------------------------------===// +// Test selectDecompositionPatterns on dummy graph +//===----------------------------------------------------------------------===// + +// Reminder: here are the fictional decompositions that we allow: +// y -> y(1) -> y(2) -> y(3) +// z -> z(1)+x(1) +// z(1) -> z(2)+x(2) +// z(2) -> z(3)+x(3) +// z -> h -> z(1) + +TEST_F(DummyDecompositionPatternSelectionTest, SelectXPatterns) { + std::vector targetBasis{"x(3)"}; + auto selectedPatterns = selectPatterns(targetBasis); + + // gates x, x(1) and x(2) can be decomposed to x(3), using the pattern_x* + // decomposition patterns: + // - pattern_x1: decompose x into x(1) + // - pattern_x2: decompose x(1) into x(2) + // - pattern_x3: decompose x(2) into x(3) + std::vector exp{"pattern_x1", "pattern_x2", "pattern_x3"}; + EXPECT_EQ(selectedPatterns, exp); +} + +TEST_F(DummyDecompositionPatternSelectionTest, SelectYPatterns) { + std::vector targetBasis{"y(2)"}; + auto selectedPatterns = selectPatterns(targetBasis); + + // gates y, y(1) can be decomposed to y(2), using the pattern_y* + // decomposition patterns: + // - pattern_y1: decompose y into y(1) + // - pattern_y2: decompose y(1) into y(2) + // pattern_y3 cannot be used, as it decomposes to y(3) + std::vector exp{"pattern_y1", "pattern_y2"}; + EXPECT_EQ(selectedPatterns, exp); +} + +TEST_F(DummyDecompositionPatternSelectionTest, SelectZOverXPatterns) { + std::vector targetBasis{"z(2)", "x(3)"}; + auto selectedPatterns = selectPatterns(targetBasis); + + // The decomposition patterns for z also introduce x gates. As we allow both + // x and z in the target basis, we can use the following z decomposition + // patterns: + // - pattern_x1: decompose x into x(1) + // - pattern_x2: decompose x(1) into x(2) + // - pattern_x3: decompose x(2) into x(3) + // - pattern_z1: decompose z into z(1)+x(1) + // - pattern_z2: decompose z(1) into z(2)+x(2) + // - pattern_zh2: decompose h into z(1) + // Pattern pattern_zh1 cannot be used, as z is already decomposed by + // pattern_z1. + std::vector exp{"pattern_x1", "pattern_x2", "pattern_x3", + "pattern_z1", "pattern_z2", "pattern_zh2"}; + EXPECT_EQ(selectedPatterns, exp); +} + +TEST_F(DummyDecompositionPatternSelectionTest, SelectZOverHPatterns) { + std::vector targetBasis{"z(1)"}; + auto selectedPatterns = selectPatterns(targetBasis); + + // The decomposition patterns for z also introduce x gates, but we do not + // accept x gates here. We can therefore only use the z-over-h decomposition + // patterns: + // - pattern_zh1: decompose z into h + // - pattern_zh2: decompose h into z(1) + std::vector exp{"pattern_zh1", "pattern_zh2"}; + EXPECT_EQ(selectedPatterns, exp); +} + +TEST_F(DummyDecompositionPatternSelectionTest, + SelectZOverHPatternsWithDisabledPatterns) { + std::vector targetBasis{"z(1)", "x(1)"}; + std::unordered_set disabledPatterns{"pattern_z1"}; + auto selectedPatterns = selectPatterns(targetBasis, disabledPatterns); + + // If we only consider the target basis, then pattern_z1: + // z -> z(1)+x(1) + // would be selected. However, by disabling it we force the selection of the + // pattern_zh1 instead. + std::vector exp{"pattern_x1", "pattern_zh1", "pattern_zh2"}; + EXPECT_EQ(selectedPatterns, exp); +} + +//===----------------------------------------------------------------------===// +// Test selectDecompositionPatterns on the registered decomposition graph +//===----------------------------------------------------------------------===// + +TEST_F(FullDecompositionPatternSelectionTest, DecomposeCCXToCZ) { + std::vector targetBasis{"h", "t", "z(1)"}; + auto selectedPatterns = selectPatterns(targetBasis); + + std::vector exp{"CCXToCCZ", "CCZToCX", "CXToCZ", "SwapToCX"}; + EXPECT_EQ(selectedPatterns, exp); +} + +} // namespace diff --git a/unittests/Optimizer/DecompositionPatternsTest.cpp b/unittests/Optimizer/DecompositionPatternsTest.cpp index a20a665ab84..020e041c254 100644 --- a/unittests/Optimizer/DecompositionPatternsTest.cpp +++ b/unittests/Optimizer/DecompositionPatternsTest.cpp @@ -6,7 +6,7 @@ * the terms of the Apache License 2.0 which accompanies this distribution. * ******************************************************************************/ -#include "../../lib/Optimizer/Transforms/DecompositionPatterns.h" +#include "DecompositionPatterns.h" #include "cudaq/Optimizer/Builder/Factory.h" #include "cudaq/Optimizer/Dialect/CC/CCDialect.h" #include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"