Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions include/cudaq/Optimizer/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ def BasisConversionPass : Pass<"basis-conversion", "mlir::ModuleOp"> {
- `x(1)` means targeting pauli-x operations with one control (aka, `cx`)
- `x(n)` means targeting pauli-x operation with unbounded number of controls
- `x,x(1)` means targeting both `not` and `cx` operations

The pass automatically selects the set of rewrite patterns, ensuring every
gate is decomposed to the specified basis in a unique way. Option
`disable-patterns` can be used to filter the selected rewrite patterns.
Option `enable-patterns` can be used to override the automatic pattern
selection.

If no `basis` is specified or the pass cannot decompose all operations to
the specified basis, the pass application will fail.
}];
let options = [
ListOption<"basis", "basis", "std::string", "Set of basis operations">,
Expand Down Expand Up @@ -309,17 +318,38 @@ def DeadStoreRemoval : Pass<"dead-store-removal"> {
def DecompositionPass: Pass<"decomposition", "mlir::ModuleOp"> {
let summary = "Break down quantum operations.";
let description = [{
This pass performs decomposition over a set of operations by iteratively
applying decomposition patterns until either a fixpoint is reached or the
maximum number of iterations/rewrites is exhausted. Decomposition is
best-effort and does not guarantee that the entire IR is decomposed after
running this pass.
This pass decomposes quantum operations by iteratively applying rewrite
patterns until all quantum operations are in basis, a fixpoint is reached or
the maximum number of iterations is exhausted.

When `basis` is specified, the pass automatically selects the set of rewrite
patterns, ensuring every gate is decomposed to the specified basis in a
unique way.

## Options

The following options are available and are all optional:

- `disable-patterns`: used to filter out specific rewrite patterns from the
selection.
- `enable-patterns`: when set, overrides the automatic pattern selection.
- `basis`: takes a comma-separated list of quantum operations with the
format: `<op-name>([<number-of-controls> | n])?`

If no `basis` is specified, as many patterns as possible are applied.

Examples of valid `basis` values:
- `x` — Pauli-X without controls (aka `not`)
- `x(1)` — Pauli-X with one control (aka `cx`)
- `x(n)` — Pauli-X with unbounded controls
- `x,x(1)` — Both `not` and `cx`

NOTE: The current implementation is conservative w.r.t global phase, which
means no decomposition will take place under the presence of controlled
`quake.apply` operations in the module.
}];
let options = [
ListOption<"basis", "basis", "std::string", "Set of basis operations">,
ListOption<"disabledPatterns", "disable-patterns", "std::string",
"Labels of decomposition patterns that should be filtered out">,
ListOption<"enabledPatterns", "enable-patterns", "std::string",
Expand Down
77 changes: 12 additions & 65 deletions lib/Optimizer/Transforms/BasisConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,66 +30,6 @@ namespace cudaq::opt {

namespace {

struct BasisTarget : public ConversionTarget {
struct OperatorInfo {
StringRef name;
size_t numControls;
};

BasisTarget(MLIRContext &context, ArrayRef<std::string> targetBasis)
: ConversionTarget(context) {
constexpr size_t unbounded = std::numeric_limits<size_t>::max();

// Parse the list of target operations and build a set of legal operations
for (const std::string &targetInfo : targetBasis) {
StringRef option = targetInfo;
auto nameEnd = option.find_first_of('(');
auto name = option.take_front(nameEnd);
if (nameEnd < option.size())
option = option.drop_front(nameEnd);

auto &info = legalOperatorSet.emplace_back(OperatorInfo{name, 0});
if (option.consume_front("(")) {
option = option.ltrim();
if (option.consume_front("n"))
info.numControls = unbounded;
else
option.consumeInteger(10, info.numControls);
assert(option.trim().consume_front(")"));
}
}

addLegalDialect<arith::ArithDialect, cf::ControlFlowDialect,
cudaq::cc::CCDialect, func::FuncDialect,
math::MathDialect>();
addDynamicallyLegalDialect<quake::QuakeDialect>([&](Operation *op) {
if (auto optor = dyn_cast<quake::OperatorInterface>(op)) {
auto name = optor->getName().stripDialect();
for (auto info : legalOperatorSet) {
if (info.name != name)
continue;
if (info.numControls == unbounded ||
optor.getControls().size() == info.numControls)
return info.numControls == optor.getControls().size();
}
return false;
}

// Handle quake.exp_pauli.
if (isa<quake::ExpPauliOp>(op)) {
// If the target defines it as a legal op, return true, else false.
return std::find_if(legalOperatorSet.begin(), legalOperatorSet.end(),
[](auto &&el) { return el.name == "exp_pauli"; }) !=
legalOperatorSet.end();
}

return true;
});
}

SmallVector<OperatorInfo, 8> legalOperatorSet;
};

//===----------------------------------------------------------------------===//
// Pass implementation
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -146,16 +86,23 @@ struct BasisConversion
return;

// Setup target and patterns
BasisTarget target(getContext(), basis);
auto target = cudaq::createBasisTarget(getContext(), basis);
RewritePatternSet owningPatterns(&getContext());
cudaq::populateWithAllDecompositionPatterns(owningPatterns);
auto patterns = FrozenRewritePatternSet(std::move(owningPatterns),
disabledPatterns, enabledPatterns);
FrozenRewritePatternSet patterns;
if (enabledPatterns.empty()) {
cudaq::selectDecompositionPatterns(owningPatterns, basis,
disabledPatterns);
patterns = FrozenRewritePatternSet(std::move(owningPatterns));
} else {
cudaq::populateWithAllDecompositionPatterns(owningPatterns);
patterns = FrozenRewritePatternSet(std::move(owningPatterns),
disabledPatterns, enabledPatterns);
}

// Process kernels in parallel
LogicalResult rewriteResult = failableParallelForEach(
module.getContext(), kernels, [&target, &patterns](Operation *op) {
return applyFullConversion(op, target, patterns);
return applyFullConversion(op, *target, patterns);
});

if (failed(rewriteResult))
Expand Down
1 change: 1 addition & 0 deletions lib/Optimizer/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ add_cudaq_library(OptTransforms
DeadStoreRemoval.cpp
Decomposition.cpp
DecompositionPatterns.cpp
DecompositionPatternSelection.cpp
DelayMeasurements.cpp
DeleteStates.cpp
DistributedDeviceCall.cpp
Expand Down
23 changes: 20 additions & 3 deletions lib/Optimizer/Transforms/Decomposition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "cudaq/Optimizer/Dialect/Quake/QuakeDialect.h"
#include "cudaq/Optimizer/Dialect/Quake/QuakeOps.h"
#include "cudaq/Optimizer/Transforms/Passes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Threading.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
Expand Down Expand Up @@ -40,9 +41,25 @@ struct Decomposition
/// execution.
LogicalResult initialize(MLIRContext *context) override {
RewritePatternSet owningPatterns(context);
cudaq::populateWithAllDecompositionPatterns(owningPatterns);
patterns = FrozenRewritePatternSet(std::move(owningPatterns),
disabledPatterns, enabledPatterns);

if (!basis.empty() && !enabledPatterns.empty()) {
mlir::emitWarning(
mlir::UnknownLoc::get(context),
"DecompositionPass: basis is ignored when enabledPatterns is "
"specified");
}

if (!basis.empty() && enabledPatterns.empty()) {
// Restrict to patterns useful for the target basis
cudaq::selectDecompositionPatterns(owningPatterns, basis,
disabledPatterns);
patterns = FrozenRewritePatternSet(std::move(owningPatterns));
} else {
cudaq::populateWithAllDecompositionPatterns(owningPatterns);
patterns = FrozenRewritePatternSet(std::move(owningPatterns),
disabledPatterns, enabledPatterns);
}

return success();
}

Expand Down
Loading
Loading