Skip to content

Commit 77eed07

Browse files
committed
add DecompositionPatternSelection, use it in Decomposition and BasisConversion
Signed-off-by: Luca Mondada <[email protected]>
1 parent b3701de commit 77eed07

File tree

8 files changed

+732
-12
lines changed

8 files changed

+732
-12
lines changed

include/cudaq/Optimizer/Transforms/Passes.td

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,14 @@ def BasisConversionPass : Pass<"basis-conversion", "mlir::ModuleOp"> {
136136
- `x(1)` means targeting pauli-x operations with one control (aka, `cx`)
137137
- `x(n)` means targeting pauli-x operation with unbounded number of controls
138138
- `x,x(1)` means targeting both `not` and `cx` operations
139+
140+
The pass automatically selects the set of rewrite patterns, ensuring every
141+
gate is decomposed to the specified basis in a unique way. Options
142+
`enable-patterns` and `disable-patterns` can be used to specify or further
143+
filter the selected rewrite patterns.
144+
145+
If no `basis` is specified or the pass cannot decompose all operations to
146+
the specified basis, the pass application will fail.
139147
}];
140148
let options = [
141149
ListOption<"basis", "basis", "std::string", "Set of basis operations">,
@@ -314,10 +322,9 @@ def DecompositionPass: Pass<"decomposition", "mlir::ModuleOp"> {
314322
the maximum number of iterations is exhausted.
315323

316324
When `basis` is specified, the pass automatically selects the set of rewrite
317-
patterns, ensuring an acyclic pattern set targeting the specified basis.
318-
Options `enable-patterns` and `disable-patterns` can further filter the
319-
selected rewrite patterns.
320-
325+
patterns, ensuring every gate is decomposed to the specified basis in a
326+
unique way. Options `enable-patterns` and `disable-patterns` can be used to
327+
specify or further filter the selected rewrite patterns.
321328

322329
The `basis` option takes a comma-separated list of quantum operations with
323330
the format: `<op-name>([<number-of-controls> | n])?`
@@ -328,6 +335,8 @@ def DecompositionPass: Pass<"decomposition", "mlir::ModuleOp"> {
328335
- `x(n)` — Pauli-X with unbounded controls
329336
- `x,x(1)` — Both `not` and `cx`
330337

338+
If no `basis` is specified, as many patterns as possible are applied.
339+
331340
NOTE: The current implementation is conservative w.r.t global phase, which
332341
means no decomposition will take place under the presence of controlled
333342
`quake.apply` operations in the module.

lib/Optimizer/Transforms/BasisConversion.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,9 @@ struct BasisConversion
8888
// Setup target and patterns
8989
auto target = cudaq::createBasisTarget(getContext(), basis);
9090
RewritePatternSet owningPatterns(&getContext());
91-
cudaq::populateWithAllDecompositionPatterns(owningPatterns);
92-
auto patterns = FrozenRewritePatternSet(std::move(owningPatterns),
93-
disabledPatterns, enabledPatterns);
91+
cudaq::selectDecompositionPatterns(owningPatterns, basis, disabledPatterns,
92+
enabledPatterns);
93+
auto patterns = FrozenRewritePatternSet(std::move(owningPatterns));
9494

9595
// Process kernels in parallel
9696
LogicalResult rewriteResult = failableParallelForEach(

lib/Optimizer/Transforms/Decomposition.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
2020
#include "mlir/Transforms/DialectConversion.h"
2121
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22+
#include <memory>
2223

2324
using namespace mlir;
2425

@@ -43,10 +44,20 @@ struct Decomposition
4344
/// Initialize the decomposer by building the set of patterns used during
4445
/// execution.
4546
LogicalResult initialize(MLIRContext *context) override {
47+
4648
RewritePatternSet owningPatterns(context);
47-
cudaq::populateWithAllDecompositionPatterns(owningPatterns);
48-
patterns = FrozenRewritePatternSet(std::move(owningPatterns),
49-
disabledPatterns, enabledPatterns);
49+
FrozenRewritePatternSet patterns;
50+
if (!basis.empty()) {
51+
// Restrict to patterns useful for the target basis
52+
cudaq::selectDecompositionPatterns(owningPatterns, basis,
53+
disabledPatterns, enabledPatterns);
54+
patterns = FrozenRewritePatternSet(std::move(owningPatterns));
55+
} else {
56+
cudaq::populateWithAllDecompositionPatterns(owningPatterns);
57+
patterns = FrozenRewritePatternSet(std::move(owningPatterns),
58+
disabledPatterns, enabledPatterns);
59+
}
60+
5061
return success();
5162
}
5263

0 commit comments

Comments
 (0)