diff --git a/include/cudaq/Optimizer/CodeGen/QIRFunctionNames.h b/include/cudaq/Optimizer/CodeGen/QIRFunctionNames.h index 6fc350d6d39..db32636a82f 100644 --- a/include/cudaq/Optimizer/CodeGen/QIRFunctionNames.h +++ b/include/cudaq/Optimizer/CodeGen/QIRFunctionNames.h @@ -118,6 +118,8 @@ static constexpr const char QISApplyKrausChannel[] = static constexpr const char QISTrap[] = "__quantum__qis__trap"; +static constexpr const char QISSaveState[] = "__quantum__qis__save_state"; + /// Since apply noise is actually a call back to `C++` code, the `QIR` data type /// `Array` of `Qubit*` must be converted into a `cudaq::qvector`, which is /// presently a `std::vector` but with an extremely restricted diff --git a/include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td b/include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td index 6fcd053f80c..1842a6fcfd9 100644 --- a/include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td +++ b/include/cudaq/Optimizer/Dialect/Quake/QuakeOps.td @@ -530,6 +530,17 @@ def quake_ApplyNoiseOp : QuakeOp<"apply_noise", [AttrSizedOperandSegments]> { }]; } +def quake_SaveStateOp : QuakeOp<"save_state"> { + let summary = "Save quantum state representation compatible with the simulator."; + let description = [{ + This operation provides support for the `cudaq::save_state` + function. This function is only valid in simulation contexts where the + simulator is part of the same process as the C++ host executable itself. + }]; + + // No arguments are needed. 
+} + //===----------------------------------------------------------------------===// // Memory and register conversion instructions: These operations are useful for // intermediate conversions between memory-SSA and value-SSA semantics and vice diff --git a/include/cudaq/Optimizer/Transforms/Passes.td b/include/cudaq/Optimizer/Transforms/Passes.td index e5fbe75a96b..b25e4e76158 100644 --- a/include/cudaq/Optimizer/Transforms/Passes.td +++ b/include/cudaq/Optimizer/Transforms/Passes.td @@ -459,6 +459,15 @@ def EraseNoise : Pass<"erase-noise"> { }]; } +def EraseSaveState : Pass<"erase-save-state"> { + let summary = "Erase the injection of save_state."; + let description = [{ + Although CUDA-Q allows the user to save state representations, + these are not needed and must be removed if the code is to + run on quantum hardware, for example. + }]; +} + def EraseNopCalls : Pass<"erase-nop-calls"> { let summary = "Erase calls to any builtin intrinsics that are NOPs."; let description = [{ diff --git a/lib/Frontend/nvqpp/ConvertExpr.cpp b/lib/Frontend/nvqpp/ConvertExpr.cpp index e1f70b84839..c1ea277738f 100644 --- a/lib/Frontend/nvqpp/ConvertExpr.cpp +++ b/lib/Frontend/nvqpp/ConvertExpr.cpp @@ -1634,6 +1634,11 @@ bool QuakeBridgeVisitor::VisitCallExpr(clang::CallExpr *x) { return false; } + if (funcName == "save_state") { + builder.create(loc); + return true; + } + if (funcName == "mx" || funcName == "my" || funcName == "mz") { // Measurements always return a bool or a std::vector. 
bool useStdvec = diff --git a/lib/Optimizer/Builder/Intrinsics.cpp b/lib/Optimizer/Builder/Intrinsics.cpp index fd7622981fb..5f698e47c2e 100644 --- a/lib/Optimizer/Builder/Intrinsics.cpp +++ b/lib/Optimizer/Builder/Intrinsics.cpp @@ -538,6 +538,7 @@ static constexpr IntrinsicCode intrinsicTable[] = { func.func private @__quantum__qis__convert_array_to_stdvector(!qir_array) -> !qir_array func.func private @__quantum__qis__free_converted_stdvector(!qir_array) + func.func private @__quantum__qis__save_state() llvm.func @generalizedInvokeWithRotationsControlsTargets(i64, i64, i64, i64, !qir_llvmptr, ...) attributes {sym_visibility = "private"} llvm.func @__quantum__qis__apply_kraus_channel_generalized(i64, i64, i64, i64, i64, ...) attributes {sym_visibility = "private"} )#"}, diff --git a/lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp b/lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp index 7ddb4aae51e..c06186f5c72 100644 --- a/lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp +++ b/lib/Optimizer/CodeGen/ConvertToQIRAPI.cpp @@ -1528,6 +1528,18 @@ struct ReturnOpPattern : public OpConversionPattern { } }; +struct SaveStateOpRewrite : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(quake::SaveStateOp saveState, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + saveState, TypeRange{}, cudaq::opt::QISSaveState, ValueRange{}); + return success(); + } +}; + /// Convert the quake types in `func::FuncOp` signatures. struct FuncSignaturePattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -1792,7 +1804,7 @@ struct FullQIR { /* Irregular quantum operators. */ CustomUnitaryOpPattern, ExpPauliOpPattern, MeasurementOpPattern, ResetOpPattern, - ApplyNoiseOpRewrite, + ApplyNoiseOpRewrite, SaveStateOpRewrite, /* Regular quantum operators. 
*/ QuantumGatePattern, diff --git a/lib/Optimizer/Transforms/CMakeLists.txt b/lib/Optimizer/Transforms/CMakeLists.txt index 9e505a1e138..8c41108d424 100644 --- a/lib/Optimizer/Transforms/CMakeLists.txt +++ b/lib/Optimizer/Transforms/CMakeLists.txt @@ -30,6 +30,7 @@ add_cudaq_library(OptTransforms DistributedDeviceCall.cpp EraseNoise.cpp EraseNopCalls.cpp + EraseSaveState.cpp EraseVectorCopyCtor.cpp ExpandControlVeqs.cpp ExpandMeasurements.cpp diff --git a/lib/Optimizer/Transforms/EraseSaveState.cpp b/lib/Optimizer/Transforms/EraseSaveState.cpp new file mode 100644 index 00000000000..4351d85cab1 --- /dev/null +++ b/lib/Optimizer/Transforms/EraseSaveState.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. * + ******************************************************************************/ + +#include "PassDetails.h" +#include "cudaq/Optimizer/Builder/Intrinsics.h" +#include "cudaq/Optimizer/Transforms/Passes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" + +namespace cudaq::opt { +#define GEN_PASS_DEF_ERASESAVESTATE +#include "cudaq/Optimizer/Transforms/Passes.h.inc" +} // namespace cudaq::opt + +#define DEBUG_TYPE "erase-save-state" + +using namespace mlir; + +/// \file +/// This pass exists simply to remove all the quake.save_state (and related) +/// Ops from the IR. 
+ +namespace { +template +class EraseSaveStatePattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(Op saveState, + PatternRewriter &rewriter) const override { + rewriter.eraseOp(saveState); + return success(); + } +}; + +class EraseSaveStatePass + : public cudaq::opt::impl::EraseSaveStateBase { +public: + using EraseSaveStateBase::EraseSaveStateBase; + + void runOnOperation() override { + auto *op = getOperation(); + LLVM_DEBUG(llvm::dbgs() << "Before erasure:\n" << *op << "\n\n"); + auto *ctx = &getContext(); + RewritePatternSet patterns(ctx); + patterns.insert>(ctx); + if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) + signalPassFailure(); + LLVM_DEBUG(llvm::dbgs() << "After erasure:\n" << *op << "\n\n"); + } +}; +} // namespace diff --git a/lib/Optimizer/Transforms/Pipelines.cpp b/lib/Optimizer/Transforms/Pipelines.cpp index dd84cfb3c33..ee5247d297d 100644 --- a/lib/Optimizer/Transforms/Pipelines.cpp +++ b/lib/Optimizer/Transforms/Pipelines.cpp @@ -18,6 +18,9 @@ struct TargetPrepPipelineOptions PassOptions::Option eraseNoise{ *this, "erase-noise", llvm::cl::desc("Erase apply noise calls."), llvm::cl::init(true)}; + PassOptions::Option eraseSaveState{ + *this, "erase-save-state", llvm::cl::desc("Erase save state calls."), + llvm::cl::init(true)}; PassOptions::Option applyConstProp{ *this, "apply-const-prop", llvm::cl::desc("Enable constant propagation in apply specialization."), diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index aceab7ec75a..1b4f52d38e3 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -2491,6 +2491,543 @@ def bodyBuilder(iterVal): # Handled in the Attribute visit, # since `numpy` arrays have a size attribute self.visit(node.func) + namedArgs = {} + for keyword in node.keywords: + self.visit(keyword.value) + namedArgs[keyword.arg] = self.popValue() + + value = self.popValue() + + if 
node.func.attr == 'array': + # `np.array(vec, )` + arrayType = value.type + if cc.PointerType.isinstance(value.type): + arrayType = cc.PointerType.getElementType(value.type) + + if cc.StdvecType.isinstance(arrayType): + eleTy = cc.StdvecType.getElementType(arrayType) + dTy = eleTy + if len(namedArgs) > 0: + dTy = namedArgs['dtype'] + + # Convert the vector to the provided data type if needed. + self.pushValue( + self.__copyVectorAndCastElements(value, dTy)) + return + + raise self.emitFatalError( + f"unexpected numpy array initializer type: {value.type}", + node) + + value = self.ifPointerThenLoad(value) + + if node.func.attr in ['complex128', 'complex64']: + if node.func.attr == 'complex128': + ty = self.getComplexType() + eleTy = self.getFloatType() + if node.func.attr == 'complex64': + ty = self.getComplexType(width=32) + eleTy = self.getFloatType(width=32) + + value = self.changeOperandToType(ty, value) + if (ty == value.type): + self.pushValue(value) + return + + real = complex.ReOp(value).result + imag = complex.ImOp(value).result + real = self.changeOperandToType(eleTy, real) + imag = self.changeOperandToType(eleTy, imag) + + self.pushValue(complex.CreateOp(ty, real, imag).result) + return + + if node.func.attr in ['float64', 'float32']: + if node.func.attr == 'float64': + ty = self.getFloatType() + if node.func.attr == 'float32': + ty = self.getFloatType(width=32) + + value = self.changeOperandToType(ty, value) + self.pushValue(value) + return + + # Promote argument's types for `numpy.func` calls to match python's semantics + if node.func.attr in ['sin', 'cos', 'sqrt', 'ceil', 'exp']: + if ComplexType.isinstance(value.type): + value = self.changeOperandToType( + self.getComplexType(), value) + if IntegerType.isinstance(value.type): + value = self.changeOperandToType( + self.getFloatType(), value) + + if node.func.attr == 'cos': + if ComplexType.isinstance(value.type): + self.pushValue(complex.CosOp(value).result) + return + 
self.pushValue(math.CosOp(value).result) + return + if node.func.attr == 'sin': + if ComplexType.isinstance(value.type): + self.pushValue(complex.SinOp(value).result) + return + self.pushValue(math.SinOp(value).result) + return + if node.func.attr == 'sqrt': + if ComplexType.isinstance(value.type): + self.pushValue(complex.SqrtOp(value).result) + return + self.pushValue(math.SqrtOp(value).result) + return + if node.func.attr == 'exp': + if ComplexType.isinstance(value.type): + # Note: using `complex.ExpOp` results in a + # "can't legalize `complex.exp`" error. + # Using Euler's formula instead: + # + # "e^(x+i*y) = (e^x) * (cos(y)+i*sin(y))" + complexType = ComplexType(value.type) + floatType = complexType.element_type + real = complex.ReOp(value).result + imag = complex.ImOp(value).result + left = self.changeOperandToType(complexType, + math.ExpOp(real).result) + re2 = math.CosOp(imag).result + im2 = math.SinOp(imag).result + right = complex.CreateOp(ComplexType.get(floatType), + re2, im2).result + res = complex.MulOp(left, right).result + self.pushValue(res) + return + self.pushValue(math.ExpOp(value).result) + return + if node.func.attr == 'ceil': + if ComplexType.isinstance(value.type): + self.emitFatalError( + f"numpy call ({node.func.attr}) is not supported for complex numbers", + node) + return + self.pushValue(math.CeilOp(value).result) + return + + self.emitFatalError( + f"unsupported NumPy call ({node.func.attr})", node) + + self.generic_visit(node) + + if node.func.value.id == 'cudaq': + if node.func.attr == 'complex': + self.pushValue(self.simulationDType()) + return + + if node.func.attr == 'amplitudes': + value = self.popValue() + arrayType = value.type + if cc.PointerType.isinstance(value.type): + arrayType = cc.PointerType.getElementType(value.type) + if cc.StdvecType.isinstance(arrayType): + self.pushValue(value) + return + + self.emitFatalError( + f"unsupported amplitudes argument type: {value.type}", + node) + + if node.func.attr == 'qvector': +
if len(self.valueStack) == 0: + self.emitFatalError( + 'qvector does not have default constructor. Init from size or existing state.', + node) + + valueOrPtr = self.popValue() + initializerTy = valueOrPtr.type + + if cc.PointerType.isinstance(initializerTy): + initializerTy = cc.PointerType.getElementType( + initializerTy) + + if (IntegerType.isinstance(initializerTy)): + # handle `cudaq.qvector(n)` + value = self.ifPointerThenLoad(valueOrPtr) + ty = self.getVeqType() + qubits = quake.AllocaOp(ty, size=value).result + self.pushValue(qubits) + return + if cc.StdvecType.isinstance(initializerTy): + # handle `cudaq.qvector(initState)` + + # Validate the length in case of a constant initializer: + # `cudaq.qvector([1., 0., ...])` + # `cudaq.qvector(np.array([1., 0., ...]))` + value = self.ifPointerThenLoad(valueOrPtr) + listScalar = None + arrNode = node.args[0] + if isinstance(arrNode, ast.List): + listScalar = arrNode.elts + + if isinstance(arrNode, ast.Call) and isinstance( + arrNode.func, ast.Attribute): + if arrNode.func.value.id in [ + 'numpy', 'np' + ] and arrNode.func.attr == 'array': + lst = node.args[0].args[0] + if isinstance(lst, ast.List): + listScalar = lst.elts + + if listScalar != None: + size = len(listScalar) + numQubits = np.log2(size) + if not numQubits.is_integer(): + self.emitFatalError( + "Invalid input state size for qvector init (not a power of 2)", + node) + + eleTy = cc.StdvecType.getElementType(value.type) + size = cc.StdvecSizeOp(self.getIntegerType(), + value).result + numQubits = math.CountTrailingZerosOp(size).result + + # TODO: Dynamically check if number of qubits is power of 2 + # and if the state is normalized + + ptrTy = cc.PointerType.get(eleTy) + arrTy = cc.ArrayType.get(eleTy) + ptrArrTy = cc.PointerType.get(arrTy) + veqTy = quake.VeqType.get() + + qubits = quake.AllocaOp(veqTy, size=numQubits).result + data = cc.StdvecDataOp(ptrArrTy, value).result + init = quake.InitializeStateOp(veqTy, qubits, + data).result + 
self.pushValue(init) + return + + if cc.StateType.isinstance(initializerTy): + # handle `cudaq.qvector(state)` + statePtr = self.ifNotPointerThenStore(valueOrPtr) + + i64Ty = self.getIntegerType() + numQubits = quake.GetNumberOfQubitsOp(i64Ty, + statePtr).result + + veqTy = quake.VeqType.get() + qubits = quake.AllocaOp(veqTy, size=numQubits).result + init = quake.InitializeStateOp(veqTy, qubits, + statePtr).result + + self.pushValue(init) + return + + self.emitFatalError( + f"unsupported qvector argument type: {value.type}", + node) + return + + if node.func.attr == "qubit": + if len(self.valueStack) == 1 and IntegerType.isinstance( + self.valueStack[0].type): + self.emitFatalError( + 'cudaq.qubit() constructor does not take any arguments. To construct a vector of qubits, use `cudaq.qvector(N)`.' + ) + self.pushValue(quake.AllocaOp(self.getRefType()).result) + return + + if node.func.attr == 'adjoint': + # Handle cudaq.adjoint(kernel, ...) + otherFuncName = node.args[0].id + if otherFuncName in self.symbolTable: + # This is a callable block argument + values = [ + self.popValue() + for _ in range(len(self.valueStack) - 2) + ] + quake.ApplyOp([], [self.popValue()], [], values) + return + + if otherFuncName not in globalKernelRegistry: + self.emitFatalError( + f"{otherFuncName} is not a known quantum kernel (was it annotated?)." + ) + + values = [ + self.popValue() for _ in range(len(self.valueStack)) + ] + values.reverse() + if len(values) != len( + globalKernelRegistry[otherFuncName].arguments): + self.emitFatalError( + f"incorrect number of runtime arguments for cudaq.control({otherFuncName},..) call.", + node) + quake.ApplyOp([], [], [], + values, + callee=FlatSymbolRefAttr.get(nvqppPrefix + + otherFuncName), + is_adj=True) + return + + if node.func.attr == 'control': + # Handle cudaq.control(kernel, ...) 
+ otherFuncName = node.args[0].id + if otherFuncName in self.symbolTable: + # This is a callable argument + values = [ + self.popValue() + for _ in range(len(self.valueStack) - 2) + ] + controls = self.popValue() + a = quake.ApplyOp([], [self.popValue()], [controls], + values) + return + + if otherFuncName not in globalKernelRegistry: + self.emitFatalError( + f"{otherFuncName} is not a known quantum kernel (was it annotated?).", + node) + values = [ + self.popValue() + for _ in range(len(self.valueStack) - 1) + ] + values.reverse() + if len(values) != len( + globalKernelRegistry[otherFuncName].arguments): + self.emitFatalError( + f"incorrect number of runtime arguments for cudaq.control({otherFuncName},..) call.", + node) + controls = self.popValue() + self.checkControlAndTargetTypes([controls], []) + quake.ApplyOp([], [], [controls], + values, + callee=FlatSymbolRefAttr.get(nvqppPrefix + + otherFuncName)) + return + + if node.func.attr == 'apply_noise': + # Pop off all the arguments we need + values = [ + self.popValue() for _ in range(len(self.valueStack)) + ] + # They are in reverse order + values.reverse() + # First one should be the number of Kraus channel parameters + numParamsVal = values[0] + # Shrink the arguments down + values = values[1:] + + # Need to get the number of parameters as an integer + concreteIntAttr = IntegerAttr( + numParamsVal.owner.attributes['value']) + numParams = concreteIntAttr.value + + # Next Value is our generated key for the channel + # Get it and shrink the list + key = values[0] + values = values[1:] + + # Now we know the next `numParams` arguments are + # our Kraus channel parameters + params = values[:numParams] + for i, p in enumerate(params): + # If we have a F64 value, we want to + # store it to a pointer + if F64Type.isinstance(p.type): + alloca = cc.AllocaOp(cc.PointerType.get(p.type), + TypeAttr.get(p.type)).result + cc.StoreOp(p, alloca) + params[i] = alloca + + # The remaining arguments are the qubits + asVeq = 
quake.ConcatOp(self.getVeqType(), + values[numParams:]).result + quake.ApplyNoiseOp(params, [asVeq], key=key) + return + + if node.func.attr == 'save_state': + quake.SaveStateOp() + return + + if node.func.attr == 'compute_action': + # There can only be 2 arguments here. + action = None + compute = None + actionArg = node.args[1] + if isinstance(actionArg, ast.Name): + actionName = actionArg.id + if actionName in self.symbolTable: + action = self.symbolTable[actionName] + else: + self.emitFatalError( + "could not find action lambda / function in the symbol table.", + node) + else: + action = self.popValue() + + computeArg = node.args[0] + if isinstance(computeArg, ast.Name): + computeName = computeArg.id + if computeName in self.symbolTable: + compute = self.symbolTable[computeName] + else: + self.emitFatalError( + "could not find compute lambda / function in the symbol table.", + node) + else: + compute = self.popValue() + + quake.ComputeActionOp(compute, action) + return + + self.emitFatalError( + f'Invalid function or class type requested from the cudaq module ({node.func.attr})', + node) + + if node.func.value.id in self.symbolTable: + # Method call on one of our variables + var = self.symbolTable[node.func.value.id] + if quake.VeqType.isinstance(var.type): + if node.func.attr == 'size': + # Handled already in the Attribute visit + return + + # `qreg` or `qview` method call + if node.func.attr == 'back': + qrSize = quake.VeqSizeOp(self.getIntegerType(), + var).result + one = self.getConstantInt(1) + endOff = arith.SubIOp(qrSize, one) + if len(node.args): + # extract the `subveq` + startOff = arith.SubIOp(qrSize, self.popValue()) + dyna = IntegerAttr.get(self.getIntegerType(), -1) + self.pushValue( + quake.SubVeqOp(self.getVeqType(), + var, + dyna, + dyna, + lower=startOff, + upper=endOff).result) + else: + # extract the qubit... 
+ self.pushValue( + quake.ExtractRefOp(self.getRefType(), + var, + -1, + index=endOff).result) + return + if node.func.attr == 'front': + zero = self.getConstantInt(0) + if len(node.args): + # extract the `subveq` + qrSize = self.popValue() + one = self.getConstantInt(1) + offset = arith.SubIOp(qrSize, one) + dyna = IntegerAttr.get(self.getIntegerType(), -1) + self.pushValue( + quake.SubVeqOp(self.getVeqType(), + var, + dyna, + dyna, + lower=zero, + upper=offset).result) + else: + # extract the qubit... + self.pushValue( + quake.ExtractRefOp(self.getRefType(), + var, + -1, + index=zero).result) + return + + def maybeProposeOpAttrFix(opName, attrName): + """ + Check the quantum operation attribute name and + propose a smart fix message if possible. For example, + if we have `x.control(...)` then remind the programmer the + correct attribute is `x.ctrl(...)`. + """ + # TODO Add more possibilities in the future... + if attrName in ['control' + ] or 'control' in attrName or 'ctrl' in attrName: + return f'Did you mean {opName}.ctrl(...)?' + + if attrName in ['adjoint' + ] or 'adjoint' in attrName or 'adj' in attrName: + return f'Did you mean {opName}.adj(...)?' 
+ + return '' + + # We have a `func_name.ctrl` + if node.func.value.id in ['h', 'x', 'y', 'z', 's', 't']: + if node.func.attr == 'ctrl': + target = self.popValue() + # Should be number of arguments minus one for the controls + controls = [ + self.popValue() for i in range(len(node.args) - 1) + ] + if not controls: + self.emitFatalError( + 'controlled operation requested without any control argument(s).', + node) + negatedControlQubits = None + if len(self.controlNegations): + negCtrlBools = [None] * len(controls) + for i, c in enumerate(controls): + negCtrlBools[i] = c in self.controlNegations + negatedControlQubits = DenseBoolArrayAttr.get( + negCtrlBools) + self.controlNegations.clear() + + opCtor = getattr(quake, + '{}Op'.format(node.func.value.id.title())) + self.checkControlAndTargetTypes(controls, [target]) + opCtor([], [], + controls, [target], + negated_qubit_controls=negatedControlQubits) + return + if node.func.attr == 'adj': + target = self.popValue() + self.checkControlAndTargetTypes([], [target]) + opCtor = getattr(quake, + '{}Op'.format(node.func.value.id.title())) + if quake.VeqType.isinstance(target.type): + + def bodyBuilder(iterVal): + q = quake.ExtractRefOp(self.getRefType(), + target, + -1, + index=iterVal).result + opCtor([], [], [], [q], is_adj=True) + + veqSize = quake.VeqSizeOp(self.getIntegerType(), + target).result + self.createInvariantForLoop(veqSize, bodyBuilder) + return + elif quake.RefType.isinstance(target.type): + opCtor([], [], [], [target], is_adj=True) + return + else: + self.emitFatalError( + 'adj quantum operation on incorrect type {}.'. + format(target.type), node) + + self.emitFatalError( + f'Unknown attribute on quantum operation {node.func.value.id} ({node.func.attr}). 
{maybeProposeOpAttrFix(node.func.value.id, node.func.attr)}' + ) + + # We have a `func_name.ctrl` + if node.func.value.id == 'swap' and node.func.attr == 'ctrl': + targetB = self.popValue() + targetA = self.popValue() + controls = [ + self.popValue() for i in range(len(self.valueStack)) + ] + if not controls: + self.emitFatalError( + 'controlled operation requested without any control argument(s).', + node) + opCtor = getattr(quake, + '{}Op'.format(node.func.value.id.title())) + self.checkControlAndTargetTypes(controls, [targetA, targetB]) + opCtor([], [], controls, [targetA, targetB]) return if self.__isSupportedVectorFunction(node.func.attr): diff --git a/python/tests/builder/test_save_state.py b/python/tests/builder/test_save_state.py new file mode 100644 index 00000000000..e59bd82a5d3 --- /dev/null +++ b/python/tests/builder/test_save_state.py @@ -0,0 +1,39 @@ +# ============================================================================ # +# Copyright (c) 2025 NVIDIA Corporation & Affiliates. # +# All rights reserved. # +# # +# This source code and the accompanying materials are made available under # +# the terms of the Apache License 2.0 which accompanies this distribution. 
# +# ============================================================================ # + +import os + +import pytest +import numpy as np + +import cudaq + + +@pytest.mark.parametrize('target', ['density-matrix-cpu', 'stim']) +def test_save_state_builtin(target: str): + cudaq.set_target(target) + + noise = cudaq.NoiseModel() + + @cudaq.kernel + def bell_depol2(d: float, flag: bool): + q, r = cudaq.qubit(), cudaq.qubit() + h(q) + cudaq.save_state() + + x.ctrl(q, r) + cudaq.save_state() + + if flag: + cudaq.apply_noise(cudaq.Depolarization2, d, q, r) + else: + cudaq.apply_noise(cudaq.Depolarization2, [d], q, r) + + counts = cudaq.sample(bell_depol2, 0.2, True, noise_model=noise) + assert len(counts) == 4 + print(counts) diff --git a/runtime/common/ExecutionContext.h b/runtime/common/ExecutionContext.h index 69776848a96..6ade8eabdcf 100644 --- a/runtime/common/ExecutionContext.h +++ b/runtime/common/ExecutionContext.h @@ -15,14 +15,167 @@ #include "Trace.h" #include "cudaq/algorithms/optimizer.h" #include "cudaq/operators.h" +#include #include #include +#include "nvqir/stim/StimState.h" + namespace cudaq { +using ErrorByShotLogEntry = std::pair>, + std::vector>>; +using ErrorLogType = std::vector>; + +struct RecordStorage { + + size_t memory_limit; + size_t current_memory; + ErrorLogType error_data; + RecordStorage(size_t limit = 1e9) : memory_limit(limit), current_memory(0) {} + + std::vector> recordedStates; + + void save_state(SimulationState *state) { + recordedStates.push_back(clone_state(state)); + } + const std::vector> & + get_recorded_states() const { + return recordedStates; + } + + void clear() { recordedStates.clear(); } + void dump_recorded_states() const { + for (std::size_t i = 0; i < recordedStates.size(); i++) { + recordedStates[i]->dump(std::cout); + } + } + + void dump_error_data() const { + printf("=== Error Data Dump ===\n"); + if (error_data.empty()) { + printf("(no error data)\n"); + return; + } + + for (const auto &[index, entry] : error_data) { + 
const auto &[x_errors, z_errors] = + entry; // both are nested vectors: one inner vector per set + + printf("\n---------------------------------------\n"); + printf(" Error Index: %zu\n", index); + printf("---------------------------------------\n"); + + // X error Shots + printf(" X Error Shots (%zu):\n", x_errors.size()); + if (x_errors.empty()) { + printf(" (none)\n"); + } else { + for (std::size_t i = 0; i < x_errors.size(); ++i) { + printf(" - Set %zu (%zu elements): ", i, x_errors[i].size()); + for (const auto &q : x_errors[i]) + printf("%zu ", q); + printf("\n"); + } + } + + // Z error Shots + printf(" Z Error Shots (%zu):\n", z_errors.size()); + if (z_errors.empty()) { + printf(" (none)\n"); + } else { + for (std::size_t i = 0; i < z_errors.size(); ++i) { + printf(" - Set %zu (%zu elements): ", i, z_errors[i].size()); + for (const auto &q : z_errors[i]) + printf("%zu ", q); + printf("\n"); + } + } + } + + printf("\n=== End of Error Data ===\n"); + } + + void record_error_data(const size_t index, const ErrorByShotLogEntry &entry) { + error_data.emplace_back(index, entry); + } + ~RecordStorage() { + std::cout << "Destroying RecordStorage with " << recordedStates.size() + << " recorded states.\n"; + } + +private: + std::unique_ptr clone_state(SimulationState *state) { + if (state->isArrayLike()) { + // Handle array-like states (CusvState, etc.) + return clone_array_like_state(state); + } else { + // Handle specialized states (CuDensityMatState, StimState, etc.)
+ return clone_specialized_state(state); + } + } + + std::unique_ptr + clone_array_like_state(SimulationState *state) { + auto numQubits = state->getNumQubits(); + if (numQubits > 20) { // Prevent exponential explosion + throw std::runtime_error("State too large to clone via amplitudes"); + } + + // Generate all basis states + auto totalStates = 1ULL << numQubits; + std::vector> basisStates; + for (size_t i = 0; i < totalStates; ++i) { + std::vector basis(numQubits); + for (size_t j = 0; j < numQubits; ++j) { + basis[j] = (i >> j) & 1; + } + basisStates.push_back(basis); + } + + auto amplitudes = state->getAmplitudes(basisStates); + + // Create new state with appropriate precision + if (state->getPrecision() == SimulationState::precision::fp32) { + std::vector> floatAmps; + for (const auto & : amplitudes) { + floatAmps.emplace_back(static_cast(amp.real()), + static_cast(amp.imag())); + } + return state->createFromData(floatAmps); + } else { + return state->createFromData(amplitudes); + } + } + + std::unique_ptr + clone_specialized_state(SimulationState *state) { + // Try dynamic_cast to known types that have clone methods + // This triggered a fatal error: library_types.h: No such file or directory + // if (auto* densityState = dynamic_cast(state)) { + // return CuDensityMatState::clone(*densityState); + //} + + if (auto *cloneable = dynamic_cast(state)) { + return cloneable->clone(); + } + + // Fallback for non-cloneable specialized states + throw std::runtime_error("Specialized state type does not support cloning"); + // For unknown specialized types, try createFromSizeAndPtr as fallback + // This might work for some specialized states + // auto tensor = state->getTensor(0); + // return state->createFromSizeAndPtr(tensor.get_num_elements(), + // tensor.data, 1); + } +}; /// The ExecutionContext is an abstraction to indicate how a CUDA-Q kernel /// should be executed.
class ExecutionContext { + + ///@brief record storage for the states saved during execution + RecordStorage recordStorage; + public: /// @brief The Constructor, takes the name of the context /// @param n The name of the context @@ -142,5 +295,39 @@ class ExecutionContext { /// Note: Measurement Syndrome Matrix is defined in /// https://arxiv.org/pdf/2407.13826. std::optional> msm_dimensions; + + std::size_t randomSeed = 0; + + std::size_t replay_columns = 0; + + /// @brief Save the current simulation state in the recorded states storage. + void save_state(SimulationState *state) { recordStorage.save_state(state); } + + /// @brief Get the recorded states saved during execution. + const std::vector> & + get_recorded_states() const { + return recordStorage.get_recorded_states(); + } + + /// @brief Clear the recorded states saved during execution. + void clear_recorded_states() { recordStorage.clear(); } + + /// @brief Dump the recorded states saved during execution. + void dump_recorded_states() const { recordStorage.dump_recorded_states(); } + + void dump_error_data() const { recordStorage.dump_error_data(); } + + void record_error_data(const size_t index, const ErrorByShotLogEntry &entry) { + recordStorage.record_error_data(index, entry); + } + + const auto &get_error_data() const { return recordStorage.error_data; } + void set_error_data(const ErrorLogType &data) { + recordStorage.error_data = data; + } + void update_replay_columns(std::size_t cols) { replay_columns = cols; } + std::size_t get_replay_columns() const { return replay_columns; } + + void set_seed(std::size_t seed) { randomSeed = seed; } }; } // namespace cudaq diff --git a/runtime/common/SimulationState.h b/runtime/common/SimulationState.h index 2855d765aaa..e55c04ef43c 100644 --- a/runtime/common/SimulationState.h +++ b/runtime/common/SimulationState.h @@ -17,10 +17,19 @@ namespace cudaq { class SimulationState; +class ClonableState; /// Enum to specify the initial quantum state. 
enum class InitialState { ZERO, UNIFORM }; +/// @brief StimData now stores a list of (pointer, size) pairs +/// according to convention: +/// 0: pointer to num_qubits, size = 1 +/// 1: pointer to msm_err_count, size = 1 +/// 2: x_output array, coming from the frame simulator, size = x_output_size +/// 3: z_output array, coming from the frame simulator, size = z_output_size +using StimData = std::vector>; + /// @brief Encapsulates a list of tensors (data pointer and dimensions). // Note: tensor data is expected in column-major. using TensorStateData = @@ -31,7 +40,7 @@ using TensorStateData = using state_data = std::variant< std::vector>, std::vector>, std::pair *, std::size_t>, - std::pair *, std::size_t>, TensorStateData>; + std::pair *, std::size_t>, TensorStateData, StimData>; /// @brief The `SimulationState` interface provides and extension point /// for concrete circuit simulation sub-types to describe their @@ -131,6 +140,15 @@ class SimulationState { const_cast(dataCasted.data()), data.index()); } + if (std::holds_alternative(data)) { + if (isArrayLike()) + throw std::runtime_error( + "Cannot initialize state vector/density matrix state by stim " + "data. Please use stabilizer simulator backends."); + auto &dataCasted = std::get(data); + return createFromSizeAndPtr( + dataCasted.size(), const_cast(&dataCasted), data.index()); + } // Flat array state data // Check the precision first. Get the size and // data pointer from the input data. @@ -249,4 +267,12 @@ class SimulationState { /// @brief Destructor virtual ~SimulationState() {} }; + +/// @brief Interface for SimulationState subtypes that support cloning. 
+class ClonableState { +public: + virtual ~ClonableState() = default; + virtual std::unique_ptr clone() const = 0; +}; + } // namespace cudaq diff --git a/runtime/cudaq/builder/kernels.h b/runtime/cudaq/builder/kernels.h index 7673fae7587..e99109f9dc1 100644 --- a/runtime/cudaq/builder/kernels.h +++ b/runtime/cudaq/builder/kernels.h @@ -59,9 +59,9 @@ std::vector getAlphaY(const std::span data, /// to its internal representation. This implementation follows the algorithm /// defined in `https://arxiv.org/pdf/quant-ph/0407010.pdf`. template -void from_state(Kernel &&kernel, QuakeValue &qubits, - const std::span> data, - std::size_t inNumQubits = 0) { +inline void from_state(Kernel &&kernel, QuakeValue &qubits, + const std::span> data, + std::size_t inNumQubits = 0) { std::make_signed_t numQubits = qubits.constantSize().value_or(inNumQubits); if (numQubits <= 0) @@ -113,7 +113,7 @@ void from_state(Kernel &&kernel, QuakeValue &qubits, /// @brief Construct a CUDA-Q kernel that produces the /// given state. This overload will return the `kernel_builder` as a /// `unique_ptr`. -auto from_state(const std::span> data) { +inline auto from_state(const std::span> data) { auto numQubits = std::log2(data.size()); std::vector empty; auto kernel = std::make_unique>(empty); diff --git a/runtime/cudaq/platform/default/DefaultQuantumPlatform.cpp b/runtime/cudaq/platform/default/DefaultQuantumPlatform.cpp index d29ade10e96..938581df31c 100644 --- a/runtime/cudaq/platform/default/DefaultQuantumPlatform.cpp +++ b/runtime/cudaq/platform/default/DefaultQuantumPlatform.cpp @@ -56,6 +56,9 @@ class DefaultQPU : public cudaq::QPU { /// Overrides resetExecutionContext to forward to /// the ExecutionManager. Also handles observe post-processing void resetExecutionContext() override { + // check if it was reset by other preceding routines + if (!executionContext) + return; ScopedTraceWithContext( executionContext->name == "observe" ? 
cudaq::TIMING_OBSERVE : 0, "DefaultPlatform::resetExecutionContext", executionContext->name); diff --git a/runtime/cudaq/qis/execution_manager.h b/runtime/cudaq/qis/execution_manager.h index 2413305b016..6193a25e295 100644 --- a/runtime/cudaq/qis/execution_manager.h +++ b/runtime/cudaq/qis/execution_manager.h @@ -101,6 +101,12 @@ class ExecutionManager { /// Checker for qudits that were not deallocated bool memoryLeaked() { return !tracker.allDeallocated(); } + /// Save the current simulation state. Virtual so that a call made through an + /// ExecutionManager pointer reaches the most-derived implementation. + virtual void save_state() { + printf("ExecutionManager::save_state() not implemented.\n"); + return; + } /// Provide an ExecutionContext for the current cudaq kernel virtual void setExecutionContext(cudaq::ExecutionContext *ctx) = 0; diff --git a/runtime/cudaq/qis/managers/BasicExecutionManager.h b/runtime/cudaq/qis/managers/BasicExecutionManager.h index ca33e7fab24..d2b7a0f30ec 100644 --- a/runtime/cudaq/qis/managers/BasicExecutionManager.h +++ b/runtime/cudaq/qis/managers/BasicExecutionManager.h @@ -252,6 +252,11 @@ class BasicExecutionManager : public cudaq::ExecutionManager { return; } + void save_state() override { + printf("BasicExecutionManager::save_state() not implemented.\n"); + return; + } + void synchronize() override { for (auto &instruction : instructionQueue) { if (!isInTracerMode()) { diff --git a/runtime/cudaq/qis/managers/default/DefaultExecutionManager.cpp b/runtime/cudaq/qis/managers/default/DefaultExecutionManager.cpp index 84750453c45..f2ee146badf 100644 --- a/runtime/cudaq/qis/managers/default/DefaultExecutionManager.cpp +++ b/runtime/cudaq/qis/managers/default/DefaultExecutionManager.cpp @@ -221,6 +221,18 @@ class DefaultExecutionManager : public cudaq::BasicExecutionManager { })(); } + /// Hand the simulator's current state to the active execution context. + /// Does nothing when the simulator has no execution context. + void save_state() override { + auto *ctx = nvqir::getCircuitSimulatorInternal()->getExecutionContext(); + if (!ctx) + return; + + std::unique_ptr state = + nvqir::getCircuitSimulatorInternal()->getCurrentSimulationState(); + ctx->save_state(state.get()); + } + void applyNoise(const kraus_channel &channel, const std::vector 
&targets) override { if (isInTracerMode()) diff --git a/runtime/cudaq/qis/qubit_qis.h b/runtime/cudaq/qis/qubit_qis.h index 0b1835e6fa8..4ea8c155fb6 100644 --- a/runtime/cudaq/qis/qubit_qis.h +++ b/runtime/cudaq/qis/qubit_qis.h @@ -1386,6 +1386,8 @@ void apply_noise(Args &&...args) { details::tuple_slice_last(std::forward_as_tuple(args...))); } +inline void save_state() { getExecutionManager()->save_state(); } + } // namespace cudaq #define __qop__ __attribute__((annotate("user_custom_quantum_operation"))) diff --git a/runtime/nvqir/CircuitSimulator.h b/runtime/nvqir/CircuitSimulator.h index 057109336ce..163dd7a153a 100644 --- a/runtime/nvqir/CircuitSimulator.h +++ b/runtime/nvqir/CircuitSimulator.h @@ -151,6 +151,16 @@ class CircuitSimulator { /// https://arxiv.org/pdf/2407.13826. virtual void generateMSM() {} + /// @brief Return the internal state representation. This + /// is meant for subtypes to override + virtual std::unique_ptr getSimulationState() = 0; + + /// @brief Get the current simulation state. + /// The method returns the current state of the simulation without flushing + /// the gate queue. + virtual std::unique_ptr + getCurrentSimulationState() = 0; + /// @brief Apply exp(-i theta PauliTensorProd) to the underlying state. /// This must be provided by subclasses. virtual void applyExpPauli(double theta, @@ -542,6 +552,14 @@ class CircuitSimulatorBase : public CircuitSimulator { "Simulation data not available for this simulator backend."); } + /// @brief Get the current simulation state. + /// The method returns the current state of the simulation without flushing + /// the gate queue. + virtual std::unique_ptr getCurrentSimulationState() { + throw std::runtime_error( + "Simulation data not available for this simulator backend."); + } + /// @brief Handle basic sampling tasks by storing the qubit index for /// processing in resetExecutionContext. Return true to indicate this is /// sampling and to exit early. False otherwise. 
diff --git a/runtime/nvqir/NVQIR.cpp b/runtime/nvqir/NVQIR.cpp index 16f8d6d2045..0fbe3e1f1b5 100644 --- a/runtime/nvqir/NVQIR.cpp +++ b/runtime/nvqir/NVQIR.cpp @@ -840,6 +840,25 @@ void __quantum__qis__apply_kraus_channel_generalized( va_end(args); } +extern "C" void __quantum__qis__save_state() { + CUDAQ_INFO("NVQIR:: saving state"); + auto *ctx = nvqir::getCircuitSimulatorInternal()->getExecutionContext(); + if (!ctx) { + CUDAQ_INFO("NVQIR::No execution context, cannot save state"); + return; + } + + CUDAQ_INFO("NVQIR::Context name: {}", ctx->name.c_str()); + + std::unique_ptr state = + nvqir::getCircuitSimulatorInternal()->getCurrentSimulationState(); + + CUDAQ_INFO("NVQIR::simulator name : {}", + nvqir::getCircuitSimulatorInternal()->name().c_str()); + + ctx->save_state(state.get()); +} + namespace details { struct FakeQubit { std::int8_t *id; diff --git a/runtime/nvqir/cudensitymat/CuDensityMatState.cpp b/runtime/nvqir/cudensitymat/CuDensityMatState.cpp index 6b55e010251..a73b651841c 100644 --- a/runtime/nvqir/cudensitymat/CuDensityMatState.cpp +++ b/runtime/nvqir/cudensitymat/CuDensityMatState.cpp @@ -387,6 +387,10 @@ CuDensityMatState::clone(const CuDensityMatState &other) { return std::unique_ptr(state); } +std::unique_ptr CuDensityMatState::clone() const { + return CuDensityMatState::clone(*this); +} + CuDensityMatState::CuDensityMatState(CuDensityMatState &&other) noexcept : isDensityMatrix(other.isDensityMatrix), dimension(other.dimension), devicePtr(other.devicePtr), cudmState(other.cudmState), diff --git a/runtime/nvqir/cudensitymat/CuDensityMatState.h b/runtime/nvqir/cudensitymat/CuDensityMatState.h index 14a12db61b7..37d0cf87c8f 100644 --- a/runtime/nvqir/cudensitymat/CuDensityMatState.h +++ b/runtime/nvqir/cudensitymat/CuDensityMatState.h @@ -15,7 +15,7 @@ namespace cudaq { /// @cond // This is an internal class, no API documentation. // Simulation state implementation for `CuDensityMatState`. 
-class CuDensityMatState : public cudaq::SimulationState { +class CuDensityMatState : public cudaq::SimulationState, public ClonableState { private: bool isDensityMatrix = false; std::size_t dimension = 0; @@ -114,6 +114,10 @@ class CuDensityMatState : public cudaq::SimulationState { // Clone a state static std::unique_ptr clone(const CuDensityMatState &other); + + // Clone a state + std::unique_ptr clone() const override; + // Prevent copies (avoids double free issues) CuDensityMatState(const CuDensityMatState &) = delete; CuDensityMatState &operator=(const CuDensityMatState &) = delete; diff --git a/runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp b/runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp index efe081701ee..77fb3566d3b 100644 --- a/runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp +++ b/runtime/nvqir/custatevec/CuStateVecCircuitSimulator.cpp @@ -736,6 +736,18 @@ class CuStateVecCircuitSimulator deviceStateVector); } + std::unique_ptr getCurrentSimulationState() override { + void *copiedDeviceVector; + HANDLE_CUDA_ERROR( + cudaMalloc(&copiedDeviceVector, stateDimension * sizeof(CudaDataType))); + HANDLE_CUDA_ERROR(cudaMemcpy(copiedDeviceVector, deviceStateVector, + stateDimension * sizeof(CudaDataType), + cudaMemcpyDeviceToDevice)); + + return std::make_unique>(stateDimension, + copiedDeviceVector); + } + bool isStateVectorSimulator() const override { return true; } std::string name() const override; diff --git a/runtime/nvqir/stim/StimCircuitSimulator.cpp b/runtime/nvqir/stim/StimCircuitSimulator.cpp index 3e7390a43a2..069effa257d 100644 --- a/runtime/nvqir/stim/StimCircuitSimulator.cpp +++ b/runtime/nvqir/stim/StimCircuitSimulator.cpp @@ -6,6 +6,7 @@ * the terms of the Apache License 2.0 which accompanies this distribution. 
* ******************************************************************************/ +#include "StimState.h" #include "nvqir/CircuitSimulator.h" #include "stim.h" @@ -60,6 +61,21 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { /// for speed) bool is_msm_mode = false; + /// @brief Whether or not the execution context name is "generate_data" (value + /// is cached for speed) + bool is_generate_data_mode = false; + std::size_t replay_columns = 0; + + bool is_replay_errors_mode = false; + + size_t error_log_vec_index = 0; + size_t noise_application_index = 0; + size_t last_column_touched = 0; + + ExecutionContext *exe_ctx; + + ErrorLogType error_log; + std::optional isValidStimNoiseChannel(const kraus_channel &channel) const { @@ -147,6 +163,108 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { return std::nullopt; } + std::pair>, + std::vector>> + get_x_z_tables_as_vectors(stim::FrameSimulator *sampleSim) { + auto *executionContext = getExecutionContext(); + auto batch_size = executionContext->shots; + auto num_qubits = sampleSim->num_qubits; + + std::vector> x_run; + std::vector> z_run; + + x_run.reserve(batch_size); + z_run.reserve(batch_size); + + for (size_t shot = 0; shot < batch_size; shot++) { + std::vector x_shot(num_qubits); + std::vector z_shot(num_qubits); + + for (std::size_t q = 0; q < num_qubits; q++) { + x_shot[q] = sampleSim->x_table[q][shot] ? 1 : 0; + z_shot[q] = sampleSim->z_table[q][shot] ? 
1 : 0; + } + + x_run.push_back(std::move(x_shot)); + z_run.push_back(std::move(z_shot)); + } + + return std::make_pair(std::move(x_run), std::move(z_run)); + } + + std::vector xor_vectors(const std::vector &a, + const std::vector &b) { + // Determine the length of the longer vector + size_t max_size = std::max(a.size(), b.size()); + + // Create copies of the vectors, padded with zeros if necessary + std::vector a_padded = a; + std::vector b_padded = b; + + a_padded.resize(max_size, 0); + b_padded.resize(max_size, 0); + + // XOR element-wise + std::vector result(max_size); + for (size_t i = 0; i < max_size; ++i) { + result[i] = a_padded[i] ^ b_padded[i]; + } + + return result; + } + + /// @brief Find indices where vector elements equal 1 + std::vector find_one_indices(const std::vector &vec) const { + std::vector indices; + indices.reserve(vec.size()); // Optimize allocation + for (size_t i = 0; i < vec.size(); i++) { + if (vec[i] == 1) { + indices.push_back(i); + } + } + return indices; + } + + StimData serialize_frame_simulator(stim::FrameSimulator *sampleSim) { + StimData data; + CUDAQ_INFO("Serializing Stim Frame Simulator data"); + + auto *executionContext = getExecutionContext(); + auto batch_size = executionContext->shots; + CUDAQ_INFO("batch_size: {}", batch_size); + std::size_t num_qubits = sampleSim->num_qubits; + + // 0: num_qubits + std::size_t *num_qubits_ptr = new std::size_t(num_qubits); + CUDAQ_INFO("num_qubits: {}", *num_qubits_ptr); + data.push_back({num_qubits_ptr, 1}); + + // 1: msm_err_count + std::size_t *msm_err_count_ptr = new std::size_t(msm_err_count); + CUDAQ_INFO("msm_err_count: {}", *msm_err_count_ptr); + data.push_back({msm_err_count_ptr, 1}); + + // 2,3: x_output and z_output + uint8_t *x_output = new uint8_t[num_qubits * batch_size]; + uint8_t *z_output = new uint8_t[num_qubits * batch_size]; + + for (int shot = 0; shot < batch_size; shot++) { + for (std::size_t q = 0; q < num_qubits; q++) { + CUDAQ_INFO("q {}: x = {}, z = {}", q, + 
static_cast(sampleSim->x_table[q][shot]), + static_cast(sampleSim->z_table[q][shot])); + + x_output[q + shot * num_qubits] = sampleSim->x_table[q][shot] ? 1 : 0; + z_output[q + shot * num_qubits] = sampleSim->z_table[q][shot] ? 1 : 0; + } + } + + data.push_back({x_output, num_qubits * batch_size}); + data.push_back({z_output, num_qubits * batch_size}); + + return data; + } + /// @brief Grow the state vector by one qubit. void addQubitToState() override { addQubitsToState(1); } @@ -162,6 +280,10 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { batch_size = executionContext->msm_dimensions.value_or(std::make_pair(1, 1)) .second; + else if (executionContext && executionContext->name == "generate_data") + batch_size = executionContext->shots; + else if (executionContext && executionContext->name == "replay_errors") + batch_size = executionContext->replay_columns; return batch_size; } @@ -228,6 +350,24 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { executionContext->msm_prob_err_id.emplace(); executionContext->msm_prob_err_id->reserve(num_msm_cols); } + is_generate_data_mode = + executionContext && executionContext->name == "generate_data"; + is_replay_errors_mode = executionContext && + executionContext->name == "replay_errors" && + !executionContext->get_error_data().empty(); + + if (is_generate_data_mode) { + CUDAQ_INFO("Generate data mode enabled"); + noise_application_index = 0; // Reset for generation + } + + if (is_replay_errors_mode) { + CUDAQ_INFO("Replay errors mode enabled with {} logged errors", + executionContext->get_error_data().size()); + noise_application_index = 0; // Reset for replay matching + error_log_vec_index = 0; // Reset replay cursor to start + last_column_touched = 0; + } // If possible, provide a non-empty stim::CircuitStats in order to avoid // reallocations during execution. @@ -241,15 +381,19 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { // simulator. 
randomEngine.discard( std::uniform_int_distribution(1, 30)(randomEngine)); + sampleSim = std::make_unique>( circuit_stats, stim::FrameSimulatorMode::STORE_MEASUREMENTS_TO_MEMORY, batch_size, std::mt19937_64(randomEngine)); - if (is_msm_mode) { + if (is_msm_mode || is_generate_data_mode || is_replay_errors_mode) { sampleSim->guarantee_anticommutation_via_frame_randomization = false; } sampleSim->reset_all(); msm_err_count = 0; msm_id_counter = 0; + error_log_vec_index = 0; + noise_application_index = 0; + last_column_touched = 0; } } @@ -265,6 +409,9 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { msm_err_count = 0; msm_id_counter = 0; is_msm_mode = false; + error_log_vec_index = 0; + noise_application_index = 0; + last_column_touched = 0; } /// @brief Apply operation to all Stim simulators. @@ -338,13 +485,13 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { // If we have a valid operation, apply it if (auto res = isValidStimNoiseChannel(channel)) { + auto max_qubit = *std::max_element(qubits.begin(), qubits.end()); if (is_msm_mode) { // If the noise operation is the first operation done to a qubit, the // x_table and z_table may not be sized for the qubits. If that is the // case, then we simply perform a reset on the qubit to essentially // allocate it, which ensures the tables are resized to the correct // size. 
- auto max_qubit = *std::max_element(qubits.begin(), qubits.end()); if (sampleSim->num_qubits < max_qubit + 1) applyOpToSims("R", std::vector{max_qubit}); @@ -368,10 +515,163 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { } } msm_id_counter++; + } else if (is_generate_data_mode) { + CUDAQ_INFO("Generating data for noise operation ID {}", + error_log_vec_index); + + // allocate the qubits if needed + if (sampleSim->num_qubits < max_qubit + 1) + applyOpToSims("R", std::vector{max_qubit}); + + CUDAQ_INFO("Applying noise operation {} to qubits {}", res->stim_name, + fmt::join(qubits, ", ")); + + stim::Circuit noiseOps; + noiseOps.safe_append_u(res.value().stim_name, qubits, + channel.parameters); + + auto tables = get_x_z_tables_as_vectors(sampleSim.get()); + // Only apply the noise operations to the sample simulator (not the + // Tableau simulator). + sampleSim->safe_do_circuit(noiseOps); + + auto tables_after = get_x_z_tables_as_vectors(sampleSim.get()); + + // Initialize the log entry for this noise operation + ErrorByShotLogEntry error_log_entry; + + // Get batch size from the execution context + auto batch_size = executionContext->shots; + + // tables[0] = all X data, tables[1] = all Z data + // Each is a vector of shots, each shot is a vector of qubits + const auto &x_before = tables.first; // vector of shots + const auto &z_before = tables.second; + const auto &x_after = tables_after.first; + const auto &z_after = tables_after.second; + + // Process each shot + for (size_t shot = 0; shot < batch_size; shot++) { + // XOR to find which qubits flipped in this shot + auto x_diff = xor_vectors(x_before[shot], x_after[shot]); + auto z_diff = xor_vectors(z_before[shot], z_after[shot]); + + // Extract the qubit indices where flips occurred + auto x_flipped_qubits = find_one_indices(x_diff); + auto z_flipped_qubits = find_one_indices(z_diff); + + // Debug output + if (!x_flipped_qubits.empty() || !z_flipped_qubits.empty()) { + CUDAQ_INFO( + "Shot 
{}: X errors on qubits [{}], Z errors on qubits [{}]", + shot, fmt::join(x_flipped_qubits, ", "), + fmt::join(z_flipped_qubits, ", ")); + } + + // Store the indices for this shot (even if empty - maintains shot + // alignment) + if (!x_flipped_qubits.empty() && !z_flipped_qubits.empty()) { + error_log_entry.first.push_back(x_flipped_qubits); + error_log_entry.second.push_back(z_flipped_qubits); + replay_columns += 2; + } + } + + // Record the entire batch of errors under one ID + if (error_log_entry.first.empty() && error_log_entry.second.empty()) { + CUDAQ_INFO("No errors occurred for noise operation ID {}", + error_log_vec_index); + } else { + CUDAQ_INFO("Recording errors for noise operation ID {}", + error_log_vec_index); + executionContext->record_error_data(error_log_vec_index, + error_log_entry); + executionContext->update_replay_columns(replay_columns); + } + + error_log_vec_index++; + } else if (is_replay_errors_mode) { + CUDAQ_INFO("In replay mode: Noise application index: {}", + noise_application_index); + CUDAQ_INFO("Replaying errors for noise operation ID {}", + error_log_vec_index); + + if (sampleSim->num_qubits < max_qubit + 1) + applyOpToSims("R", std::vector{max_qubit}); + + // Ensure we have errors to replay + const auto &errors_to_replay = executionContext->get_error_data(); + if (errors_to_replay.empty()) { + CUDAQ_INFO("No errors to replay"); + return; + } + + // Check if we've exhausted the error log + if (noise_application_index >= errors_to_replay.size()) { + CUDAQ_INFO("All logged errors have been replayed"); + return; + } + + // wrong: CUDAQ_INFO("Replaying error entry {} of {}", + // noise_application_index, errors_to_replay.size()); + + // Fetch the error entry for this noise operation + const auto &error_entry = errors_to_replay[error_log_vec_index]; + const size_t logged_error_id = std::get<0>(error_entry); + const auto &error_log_entry = std::get<1>(error_entry); + + // Verify this is the correct error to replay + if 
(noise_application_index != logged_error_id) { + CUDAQ_INFO("Skipping - current noise application index {} doesn't " + "match logged ID {}", + noise_application_index, logged_error_id); + noise_application_index++; + return; + } + + // Extract X and Z error data + const auto &x_errors_per_shot = error_log_entry.first; + const auto &z_errors_per_shot = error_log_entry.second; + + // Get the number of shots to replay + size_t num_shots = x_errors_per_shot.size(); + CUDAQ_INFO("Replaying {} shots for error ID {}", num_shots, + error_log_vec_index); + + if (last_column_touched + num_shots > getBatchSize()) { + throw std::runtime_error(fmt::format( + "Not enough columns in Stim FrameSimulator to replay errors. " + "Needed {}, but only have {}.", + last_column_touched + num_shots, getBatchSize())); + } + // Apply the logged errors to each shot + for (size_t shot = 0; shot < num_shots; shot++) { + const auto &x_qubits = x_errors_per_shot[shot]; + const auto &z_qubits = z_errors_per_shot[shot]; + + // Apply X errors + for (uint8_t qubit : x_qubits) { + sampleSim->x_table[qubit][last_column_touched + shot] ^= 1; + CUDAQ_INFO(" Shot {}: Applied X error to qubit {}", shot, qubit); + } + + // Apply Z errors + for (uint8_t qubit : z_qubits) { + sampleSim->z_table[qubit][last_column_touched + shot] ^= 1; + CUDAQ_INFO(" Shot {}: Applied Z error to qubit {}", shot, qubit); + } + } + last_column_touched += num_shots; + // Move to the next error entry + noise_application_index++; + error_log_vec_index++; + CUDAQ_INFO("Finished replaying errors for noise operation ID {}", + error_log_vec_index - 1); } else { stim::Circuit noiseOps; noiseOps.safe_append_u(res.value().stim_name, qubits, channel.parameters); + // Only apply the noise operations to the sample simulator (not the // Tableau simulator). 
sampleSim->safe_do_circuit(noiseOps); @@ -460,6 +760,13 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { // simulator knows how to buffer the results across multiple sample() // invocations. supportsBufferedSample = true; + + auto exe_ctx = getExecutionContext(); + if (exe_ctx && exe_ctx->randomSeed != 0) { + setRandomSeed(exe_ctx->randomSeed); + CUDAQ_INFO("Setting random seed to {} in Stim simulator", + exe_ctx->randomSeed); + } } virtual ~StimCircuitSimulator() = default; @@ -478,6 +785,19 @@ class StimCircuitSimulator : public nvqir::CircuitSimulatorBase { "R", std::vector{static_cast(index)}); } + std::unique_ptr getSimulationState() override { + flushGateQueue(); + StimData data = serialize_frame_simulator(sampleSim.get()); + return std::make_unique(data); + } + + std::unique_ptr getCurrentSimulationState() override { + CUDAQ_INFO("Getting current simulation state from stim simulator"); + flushGateQueue(); + StimData data = serialize_frame_simulator(sampleSim.get()); + return std::make_unique(data); + } + /// @brief Sample the multi-qubit state. If \p qubits is empty and /// explicitMeasurements is set, this returns all previously saved /// measurements. diff --git a/runtime/nvqir/stim/StimState.h b/runtime/nvqir/stim/StimState.h new file mode 100644 index 00000000000..a2f6507aebf --- /dev/null +++ b/runtime/nvqir/stim/StimState.h @@ -0,0 +1,227 @@ +/****************************************************************-*- C++ -*-**** + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. 
* + ******************************************************************************/ +#pragma once + +#include "common/SimulationState.h" +#include +#include +#include +#include +#include +#include + +namespace cudaq { + +/// @brief Provides stabilizer simulation state representation using StimData. +class StimState : public SimulationState, public ClonableState { +private: + StimData data_; + + template + struct variant_type_index; + + template + struct variant_type_index> { + static constexpr std::size_t value = []() { + std::size_t index = 0; + bool found = + ((std::is_same::value ? true : (++index, false)) || ...); + return found ? index : throw "Type not in variant"; + }(); + }; + + // Disable copying to prevent shallow copy issues + StimState(const StimState &) = delete; + StimState &operator=(const StimState &) = delete; + + // Moving is disabled as well + StimState(StimState &&) = delete; + StimState &operator=(StimState &&) = delete; + +public: + /// @brief Construct from StimData (may copy). + explicit StimState(const StimData &d) : data_(d) {} + + /// @brief Construct from an rvalue StimData + explicit StimState(StimData &&d) : data_(std::move(d)) {} + + /// @brief Factory for this type from state_data. + std::unique_ptr + createFromData(const state_data &d) override { + if (!std::holds_alternative(d)) + throw std::runtime_error( + "[StimState] only supports StimData for initialization."); + return std::make_unique(std::get(d)); + } + +protected: + /// @brief Create from data pointer. 
+ std::unique_ptr + createFromSizeAndPtr(std::size_t, void *ptr, std::size_t dataType) override { + if (dataType != variant_type_index::value) + throw std::runtime_error( + "[StimState] only supports StimData for initialization."); + + auto stim_data = static_cast(ptr); + return std::make_unique(*stim_data); + } + +public: + std::unique_ptr clone() const override { + // Note: This performs shallow copy of pointers in StimData + // Caller must ensure proper lifetime management of underlying data + return std::make_unique(data_); + } + + /// @brief This simulator is not array-like (must use Pauli frame APIs). + bool isArrayLike() const override { return false; } + + /// @brief Return the number of qubits. + std::size_t getNumQubits() const override { + if (data_.empty() || !data_[0].first || data_[0].second == 0) + throw std::runtime_error( + "[StimState] Invalid StimData: missing num_qubits."); + return *static_cast(data_[0].first); + } + + /// @brief Get MSM error count. + std::size_t getMsmErrorCount() const { + if (data_.size() < 2 || !data_[1].first || data_[1].second == 0) + throw std::runtime_error( + "[StimState] Invalid StimData: missing msm_err_count."); + return *static_cast(data_[1].first); + } + + /// @brief Get batch size (number of shots) + std::size_t getBatchSize() const { + std::size_t nq = getNumQubits(); + if (data_.size() > 2 && data_[2].first && data_[2].second > 0) { + return data_[2].second / nq; + } + return 0; + } + + /// @brief Get X value for a specific shot and qubit + uint8_t getXValue(std::size_t shot, std::size_t qubit) const { + if (data_.size() <= 2 || !data_[2].first) { + throw std::runtime_error("[StimState] No X output data"); + } + std::size_t nq = getNumQubits(); + std::size_t idx = qubit + shot * nq; + if (idx >= data_[2].second) { + throw std::out_of_range("[StimState] Index out of bounds"); + } + const auto *x_data = static_cast(data_[2].first); + return x_data[idx]; + } + + /// @brief Get Z value for a specific shot and 
qubit + uint8_t getZValue(std::size_t shot, std::size_t qubit) const { + if (data_.size() <= 3 || !data_[3].first) { + throw std::runtime_error("[StimState] No Z output data"); + } + std::size_t nq = getNumQubits(); + std::size_t idx = qubit + shot * nq; + if (idx >= data_[3].second) { + throw std::out_of_range("[StimState] Index out of bounds"); + } + const auto *z_data = static_cast(data_[3].first); + return z_data[idx]; + } + + /// @brief Tensor interface not supported for StimState. + Tensor getTensor(std::size_t idx = 0) const override { + throw std::runtime_error("[StimState] Tensor interface not supported."); + } + + std::vector getTensors() const override { return {}; } + std::size_t getNumTensors() const override { return 0; } + + /// @brief Overlap is not implemented for stabilizer states. + std::complex overlap(const SimulationState &other) override { + throw std::runtime_error( + "[StimState] overlap not implemented for stabilizer data."); + } + + /// @brief Amplitude access not supported for StimState. 
+ std::complex getAmplitude(const std::vector &) override { + throw std::runtime_error( + "[StimState] amplitudes not supported for stabilizer states."); + } + + void dump(std::ostream &os) const override { + if (data_.size() < 2) { + os << "StimState { Invalid/empty data }"; + return; + } + + os << "StimState { qubits=" << getNumQubits() + << ", msm_err_count=" << getMsmErrorCount(); + + // Display batch size if data is available + std::size_t batch_size = getBatchSize(); + if (batch_size > 0) { + os << ", batch_size=" << batch_size; + } + os << " }\n"; + + // Display X and Z output sizes + if (data_.size() > 2 && data_[2].first) { + os << "X output size: " << data_[2].second << "\n"; + } + if (data_.size() > 3 && data_[3].first) { + os << "Z output size: " << data_[3].second << "\n"; + } + + // Display sample data for first few shots (if available) + if (batch_size > 0) { + std::size_t nq = getNumQubits(); + std::size_t shots_to_display = std::min(batch_size, std::size_t(5)); + + os << "\nSample data (first " << shots_to_display << " shots):\n"; + for (std::size_t shot = 0; shot < shots_to_display; ++shot) { + os << "Shot " << shot << ": X=["; + for (std::size_t qubit = 0; qubit < nq; ++qubit) { + if (qubit > 0) + os << " "; + os << static_cast(getXValue(shot, qubit)); + } + os << "] Z=["; + for (std::size_t qubit = 0; qubit < nq; ++qubit) { + if (qubit > 0) + os << " "; + os << static_cast(getZValue(shot, qubit)); + } + os << "]\n\n\n"; + } + } + } + + /// @brief Precision is always double for stabilizer/Stim data. + precision getPrecision() const override { return precision::fp64; } + + /// @brief Destroy any resources. 
+ void destroyState() override { + // Release the buffers owned by this state. By convention entries 0 and 1 + // each hold a single heap-allocated std::size_t and entries 2 and 3 hold + // uint8_t arrays; any further entries are left untouched. + for (std::size_t i = 0; i < data_.size() && i < 4; ++i) { + auto *ptr = data_[i].first; + if (!ptr) + continue; + if (i < 2) + delete static_cast<std::size_t *>(ptr); + else // also frees zero-length arrays, which still require delete[] + delete[] static_cast<uint8_t *>(ptr); + } + // Forget the now-dangling pointers so a repeated call is a no-op. + data_.clear(); + } +}; + +} // namespace cudaq diff --git a/unittests/CMakeLists.txt b/unittests/CMakeLists.txt index ce38a584d8e..a50489fe58a 100644 --- a/unittests/CMakeLists.txt +++ b/unittests/CMakeLists.txt @@ -42,6 +42,7 @@ set(CUDAQ_RUNTIME_TEST_SOURCES integration/kernels_tester.cpp common/MeasureCountsTester.cpp common/NoiseModelTester.cpp + integration/save_state_tester.cpp integration/tracer_tester.cpp integration/gate_library_tester.cpp ) diff --git a/unittests/integration/save_state_tester.cpp b/unittests/integration/save_state_tester.cpp new file mode 100644 index 00000000000..fa7fff4e1db --- /dev/null +++ b/unittests/integration/save_state_tester.cpp @@ -0,0 +1,101 @@ +/******************************************************************************* + * Copyright (c) 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. 
* + ******************************************************************************/ + +#include "CUDAQTestUtils.h" +#include "cudaq/builder/kernels.h" +#include +#include + +#if defined(CUDAQ_BACKEND_STIM) +CUDAQ_TEST(SaveStateTester, checkStimState) { + + struct multi_round_kernel { + void operator()(int num_qubits, int num_rounds, + double noise_probability) __qpu__ { + cudaq::qvector q(num_qubits); + for (int round = 0; round < num_rounds; round++) { + h(q[0]); + for (int qi = 0; qi < num_qubits; qi++) + cudaq::apply_noise(noise_probability, + q[qi]); + + cudaq::save_state(); + } + mz(q); + for (int qi = 0; qi < num_qubits; qi++) + reset(q[qi]); + } + }; + + int num_qubits = 5; + int num_rounds = 3; + double noise_bf_prob = 0.0625; + + cudaq::noise_model noise; + cudaq::bit_flip_channel bf(noise_bf_prob); + for (std::size_t i = 0; i < num_qubits; i++) + noise.add_channel("mz", {i}, bf); + cudaq::set_noise(noise); + + cudaq::ExecutionContext ctx_msm_size("msm_size"); + auto &platform = cudaq::get_platform(); + platform.set_exec_ctx(&ctx_msm_size); + multi_round_kernel{}(num_qubits, num_rounds, noise_bf_prob); + platform.reset_exec_ctx(); + + cudaq::ExecutionContext ctx_msm("msm"); + ctx_msm.noiseModel = &noise; + ctx_msm.msm_dimensions = ctx_msm_size.msm_dimensions; + platform.set_exec_ctx(&ctx_msm); + multi_round_kernel{}(num_qubits, num_rounds, noise_bf_prob); + platform.reset_exec_ctx(); + + // accessing the execution context to get the recorded states + assert(ctx_msm.get_recorded_states().size() == num_rounds && + "Expected 3 state snapshots"); + + for (std::size_t round = 0; round < ctx_msm.get_recorded_states().size(); + ++round) { + const auto &state = ctx_msm.get_recorded_states()[round]; + + assert(state.getNumQubits() == num_qubits && + "Number of qubits must equal num_qubits"); + /* + for (std::size_t qi = 0; qi < state.getNumQubits(); ++qi) { + const auto &x_row = state.getTableau().x_output[qi]; + const auto &z_row = state.getTableau().z_output[qi]; 
+ + // Each row should have the same number of qubits. + assert(x_row.num_qubits == state.getNumQubits()); + assert(z_row.num_qubits == state.getNumQubits()); + + // Check that each character in the row is a valid Pauli symbol. + for (char c : x_row.str()) { + assert((c == 'I' || c == 'X' || c == 'Y' || c == 'Z') && + "Invalid symbol in tableau X-output row"); + } + for (char c : z_row.str()) { + assert((c == 'I' || c == 'X' || c == 'Y' || c == 'Z') && + "Invalid symbol in tableau Z-output row"); + } + } + const auto &frame = state.getPauliFrame(); + + // Frame size should match number of qubits. + assert(frame.size() == state.getNumQubits()); + + // Each frame entry must be a valid Pauli operator. + for (char c : frame) { + assert((c == 'I' || c == 'X' || c == 'Y' || c == 'Z') && + "Invalid symbol in Pauli frame"); + } + } + */ + } +} +#endif diff --git a/unittests/integration/test.cpp b/unittests/integration/test.cpp new file mode 100644 index 00000000000..0ad04d7b67f --- /dev/null +++ b/unittests/integration/test.cpp @@ -0,0 +1,239 @@ +/******************************************************************************* + * Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. * + * All rights reserved. * + * * + * This source code and the accompanying materials are made available under * + * the terms of the Apache License 2.0 which accompanies this distribution. 
* + ******************************************************************************/ +#include + +/* + +std::vector> detection_matrix( + operation statePrep, + bool run_mz_circuit, bool keep_x_stabilizers, bool keep_z_stabilizers) { + + if (!code.contains_operation(statePrep)) + throw std::runtime_error("prep kernel not found."); + + if (!keep_x_stabilizers && !keep_z_stabilizers) + throw std::runtime_error(" no stabilizers to keep."); + + std::vector> detection_matrix; // to return + + std::size_t numCols = numAncx + numAncz; + + cudaq::ExecutionContext ctx_msm_size("msm_size"); + ctx_msm_size.noiseModel = &noise; + auto &platform = cudaq::get_platform(); + platform.set_exec_ctx(&ctx_msm_size); + + // Run the memory circuit experiment + if (run_mz_circuit) { + memory_circuit_mz(stabRound, prep, numData, numAncx, numAncz, numRounds, + xVec, zVec); + } else { + memory_circuit_mx(stabRound, prep, numData, numAncx, numAncz, numRounds, + xVec, zVec); + } + + platform.reset_exec_ctx(); + + if (!ctx_msm_size.msm_dimensions.has_value()) { + throw std::runtime_error("dem_from_memory_circuit error: no MSM dimensions " + "found. One reason could be missing a target."); + } + if (ctx_msm_size.msm_dimensions.value().second == 0) { + throw std::runtime_error( + "dem_from_memory_circuit error: no noise mechanisms found in circuit. " + "Cannot generate a DEM. 
Did you forget to enable noise?"); + } + + cudaq::ExecutionContext ctx_msm("msm"); + ctx_msm.noiseModel = &noise; + ctx_msm.msm_dimensions = ctx_msm_size.msm_dimensions; + platform.set_exec_ctx(&ctx_msm); + + // Run the memory circuit experiment + if (run_mz_circuit) { + memory_circuit_mz(stabRound, prep, numData, numAncx, numAncz, numRounds, + xVec, zVec); + } else { + memory_circuit_mx(stabRound, prep, numData, numAncx, numAncz, numRounds, + xVec, zVec); + } + + platform.reset_exec_ctx(); + + // Populate error rates and error IDs + dem.error_rates = std::move(ctx_msm.msm_probabilities.value()); + dem.error_ids = std::move(ctx_msm.msm_prob_err_id.value()); + + auto msm_as_strings = ctx_msm.result.sequential_data(); + cudaqx::tensor msm_data( + std::vector({ctx_msm_size.msm_dimensions->first, + ctx_msm_size.msm_dimensions->second})); + cudaqx::tensor mzTable(msm_as_strings); + mzTable = mzTable.transpose(); + std::size_t numNoiseMechs = mzTable.shape()[1]; + + std::size_t numSyndromesPerRound = numXStabs + numZStabs; + + // Populate dem.detector_error_matrix by XORing consecutive rounds. Generally + // speaking, this is calculating H = D*Ω, where H is the Detector Error + // Matrix, D is the Detector Matrix, and Ω is Measurement Syndrome Matrix. + // However, D is very sparse, and is it represents simple XORs of a syndrome + // with the prior round's syndrome. + // Reference: https://arxiv.org/pdf/2407.13826 + + auto numReturnSynPerRound = numSyndromesPerRound; + + if (keep_x_stabilizers && !keep_z_stabilizers) { + numReturnSynPerRound = numXStabs; + } else if (!keep_x_stabilizers && keep_z_stabilizers) { + numReturnSynPerRound = numZStabs; + } + + // If we are returning only x-stabilizers, we need to offset the syndrome + // indices of mzTable by numSyndromesPerRound / 2. + auto offset = keep_x_stabilizers && !keep_z_stabilizers ? 
numZStabs : 0; + dem.detector_error_matrix = cudaqx::tensor( + {numRounds * numReturnSynPerRound, numNoiseMechs}); + for (std::size_t round = 0; round < numRounds; round++) { + if (round == 0) { + for (std::size_t syndrome = 0; syndrome < numReturnSynPerRound; + syndrome++) { + for (std::size_t noise_mech = 0; noise_mech < numNoiseMechs; + noise_mech++) { + dem.detector_error_matrix.at( + {round * numReturnSynPerRound + syndrome, noise_mech}) = + mzTable.at({round * numSyndromesPerRound + syndrome + offset, + noise_mech}); + } + } + } else { + for (std::size_t syndrome = 0; syndrome < numReturnSynPerRound; + syndrome++) { + for (std::size_t noise_mech = 0; noise_mech < numNoiseMechs; + noise_mech++) { + dem.detector_error_matrix.at( + {round * numReturnSynPerRound + syndrome, noise_mech}) = + mzTable.at({round * numSyndromesPerRound + syndrome + offset, + noise_mech}) ^ + mzTable.at( + {(round - 1) * numSyndromesPerRound + syndrome + offset, + noise_mech}); + } + } + } + } + + // Uncomment for debugging: + // printf("dem.detector_error_matrix:\n"); + // dem.detector_error_matrix.dump_bits(); + + // Populate dem.observables_flips_matrix by converting the physical data qubit + // measurements to logical observables. + auto first_data_row = numRounds * numSyndromesPerRound; + assert(first_data_row < mzTable.shape()[0]); + + cudaqx::tensor msm_obs( + {mzTable.shape()[0] - first_data_row, numNoiseMechs}); + for (std::size_t row = first_data_row; row < mzTable.shape()[0]; row++) + for (std::size_t col = 0; col < numNoiseMechs; col++) + msm_obs.at({row - first_data_row, col}) = mzTable.at({row, col}); + + // Populate dem.observables_flips_matrix by converting the physical data qubit + // measurements to logical observables. 
+ dem.observables_flips_matrix = obs_matrix.dot(msm_obs) % 2; + + // printf("getting obs_matrix : \n"); + // obs_matrix.dump_bits(); + + // printf("getting msm_obs : \n"); + // msm_obs.dump_bits(); + + // Uncomment print statements for debugging: + // printf("dem.detector_error_matrix Before canonicalization:\n"); + // dem.detector_error_matrix.dump_bits(); + // printf("dem.observables_flips_matrix Before canonicalization:\n"); + // dem.observables_flips_matrix.dump_bits(); + dem.canonicalize_for_rounds(numReturnSynPerRound); + // printf("dem.detector_error_matrix After canonicalization:\n"); + // dem.detector_error_matrix.dump_bits(); + // printf("dem.observables_flips_matrix After canonicalization:\n"); + // dem.observables_flips_matrix.dump_bits(); + + return dem; +} +*/ + +__qpu__ int sx(cudaq::qview<> qubits, cudaq::qview<> ancillas) {} + +__qpu__ int sz(cudaq::qview<> qubits, cudaq::qview<> ancillas) {} + +__qpu__ auto kernel(int num_qubits, int num_rounds) { + cudaq::qvector q(num_qubits); + + for (int i = 0; i < num_qubits; i++) { + cudaq::apply_noise(0.1, 0.1, 0.1, q[i]); + } + + for (int i = 0; i < num_rounds; i++) { + h(q[1]); + cudaq::save_state(); + x(q[1], q[0]); + cudaq::save_state(); + } + return cudaq::to_integer(mz(q)); +} + +int main() { + + int num_qubits = 6; + int num_rounds = 2; + + double noise_bf_prob = 1.; + + cudaq::noise_model noise; + cudaq::depolarization_channel depolarization(noise_bf_prob); + + noise.add_channel({0}, depolarization); + noise.add_channel({1}, depolarization); + noise.add_channel({0}, depolarization); + + cudaq::set_noise(noise); + + cudaq::ExecutionContext ctx_gen("generate_data", 2); + auto &platform2 = cudaq::get_platform(); + ctx_gen.set_seed(41); + + platform2.set_exec_ctx(&ctx_gen); + + auto m = kernel(num_qubits, num_rounds); + printf(" ------------------ Result: %d\n", m); + + ctx_gen.dump_recorded_states(); + ctx_gen.dump_error_data(); + auto errors = ctx_gen.get_error_data(); + auto replay_cols = 
ctx_gen.get_replay_columns(); + printf("Total replay columns: %zu\n", replay_cols); + platform2.reset_exec_ctx(); + printf("Total errors recorded: %zu\n", errors.size()); + printf("+++++++++++++++++++++++++++++++++\n"); + printf("+++++++++++++++++++++++++++++++++\n"); + + cudaq::ExecutionContext ctx_rep("replay_errors", 2); + platform2.set_exec_ctx(&ctx_rep); + ctx_rep.set_error_data(errors); + ctx_rep.update_replay_columns(replay_cols); + ctx_rep.set_seed(41); + + auto m2 = kernel(num_qubits, num_rounds); + printf(" ------------------ Result after replaying errors: %d\n", m2); + + ctx_rep.dump_recorded_states(); + ctx_rep.dump_error_data(); + platform2.reset_exec_ctx(); + return 0; +}