Skip to content

Commit 9a6800d

Browse files
committed
add DecompositionPatternSelection, use it in Decomposition and BasisConversion
Signed-off-by: Luca Mondada <[email protected]>
1 parent 4bec988 commit 9a6800d

File tree

8 files changed

+802
-16
lines changed

8 files changed

+802
-16
lines changed

include/cudaq/Optimizer/Transforms/Passes.td

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,14 @@ def BasisConversionPass : Pass<"basis-conversion", "mlir::ModuleOp"> {
136136
- `x(1)` means targeting pauli-x operations with one control (aka, `cx`)
137137
- `x(n)` means targeting pauli-x operation with unbounded number of controls
138138
- `x,x(1)` means targeting both `not` and `cx` operations
139+
140+
The pass automatically selects the set of rewrite patterns, ensuring every
141+
gate is decomposed to the specified basis in a unique way. Options
142+
`enable-patterns` and `disable-patterns` can be used to specify or further
143+
filter the selected rewrite patterns.
144+
145+
If no `basis` is specified or the pass cannot decompose all operations to
146+
the specified basis, the pass application will fail.
139147
}];
140148
let options = [
141149
ListOption<"basis", "basis", "std::string", "Set of basis operations">,
@@ -314,10 +322,9 @@ def DecompositionPass: Pass<"decomposition", "mlir::ModuleOp"> {
314322
the maximum number of iterations is exhausted.
315323

316324
When `basis` is specified, the pass automatically selects the set of rewrite
317-
patterns, ensuring an acyclic pattern set targeting the specified basis.
318-
Options `enable-patterns` and `disable-patterns` can further filter the
319-
selected rewrite patterns.
320-
325+
patterns, ensuring every gate is decomposed to the specified basis in a
326+
unique way. Options `enable-patterns` and `disable-patterns` can be used to
327+
specify or further filter the selected rewrite patterns.
321328

322329
The `basis` option takes a comma-separated list of quantum operations with
323330
the format: `<op-name>([<number-of-controls> | n])?`
@@ -328,6 +335,8 @@ def DecompositionPass: Pass<"decomposition", "mlir::ModuleOp"> {
328335
- `x(n)` — Pauli-X with unbounded controls
329336
- `x,x(1)` — Both `not` and `cx`
330337

338+
If no `basis` is specified, as many patterns as possible are applied.
339+
331340
NOTE: The current implementation is conservative w.r.t global phase, which
332341
means no decomposition will take place under the presence of controlled
333342
`quake.apply` operations in the module.

lib/Optimizer/Transforms/BasisConversion.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ struct BasisConversion
8888
// Setup target and patterns
8989
auto target = cudaq::createBasisTarget(getContext(), basis);
9090
RewritePatternSet owningPatterns(&getContext());
91-
cudaq::populateWithAllDecompositionPatterns(owningPatterns);
92-
auto patterns = FrozenRewritePatternSet(std::move(owningPatterns),
93-
disabledPatterns, enabledPatterns);
91+
cudaq::selectDecompositionPatterns(owningPatterns, basis, disabledPatterns,
92+
enabledPatterns);
93+
auto patterns = FrozenRewritePatternSet(std::move(owningPatterns));
9494

9595
// Process kernels in parallel
9696
LogicalResult rewriteResult = failableParallelForEach(

lib/Optimizer/Transforms/Decomposition.cpp

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,12 @@
88

99
#include "DecompositionPatterns.h"
1010
#include "cudaq/Frontend/nvqpp/AttributeNames.h"
11-
#include "cudaq/Optimizer/Dialect/CC/CCDialect.h"
1211
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
1312
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
1413
#include "cudaq/Optimizer/Transforms/Passes.h"
15-
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
16-
#include "mlir/Dialect/Math/IR/Math.h"
1714
#include "mlir/IR/PatternMatch.h"
1815
#include "mlir/IR/Threading.h"
1916
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
20-
#include "mlir/Transforms/DialectConversion.h"
2117
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2218

2319
using namespace mlir;
@@ -43,10 +39,19 @@ struct Decomposition
4339
/// Initialize the decomposer by building the set of patterns used during
4440
/// execution.
4541
LogicalResult initialize(MLIRContext *context) override {
42+
4643
RewritePatternSet owningPatterns(context);
47-
cudaq::populateWithAllDecompositionPatterns(owningPatterns);
48-
patterns = FrozenRewritePatternSet(std::move(owningPatterns),
49-
disabledPatterns, enabledPatterns);
44+
if (!basis.empty()) {
45+
// Restrict to patterns useful for the target basis
46+
cudaq::selectDecompositionPatterns(owningPatterns, basis,
47+
disabledPatterns, enabledPatterns);
48+
patterns = FrozenRewritePatternSet(std::move(owningPatterns));
49+
} else {
50+
cudaq::populateWithAllDecompositionPatterns(owningPatterns);
51+
patterns = FrozenRewritePatternSet(std::move(owningPatterns),
52+
disabledPatterns, enabledPatterns);
53+
}
54+
5055
return success();
5156
}
5257

0 commit comments

Comments
 (0)