diff --git a/include/cudaq/Optimizer/Dialect/CC/CCTypes.h b/include/cudaq/Optimizer/Dialect/CC/CCTypes.h index 3e7a17f36b7..33d476e3305 100644 --- a/include/cudaq/Optimizer/Dialect/CC/CCTypes.h +++ b/include/cudaq/Optimizer/Dialect/CC/CCTypes.h @@ -38,10 +38,16 @@ inline bool SpanLikeType::classof(mlir::Type type) { return mlir::isa(type); } -/// Return true if and only if \p ty has dynamic extent. This is a recursive +/// Returns true if and only if \p ty has dynamic extent. This is a recursive /// test on composable types. bool isDynamicType(mlir::Type ty); +/// Returns true if and only if the memory needed to store a value of type +/// \p ty is not known at compile time. This is a recursive test on composable +/// types. In contrast to `isDynamicType`, the size of the type is statically +/// known even if it contains pointers that may point to memory of dynamic size. +bool isDynamicallySizedType(mlir::Type ty); + /// Determine the number of hidden arguments, which is 0, 1, or 2. inline unsigned numberOfHiddenArgs(bool thisPtr, bool sret) { return (thisPtr ? 1 : 0) + (sret ? 1 : 0); diff --git a/lib/Frontend/nvqpp/ConvertStmt.cpp b/lib/Frontend/nvqpp/ConvertStmt.cpp index 964a2169c16..7be52808877 100644 --- a/lib/Frontend/nvqpp/ConvertStmt.cpp +++ b/lib/Frontend/nvqpp/ConvertStmt.cpp @@ -359,6 +359,9 @@ bool QuakeBridgeVisitor::VisitReturnStmt(clang::ReturnStmt *x) { if (!cudaq::cc::isDynamicType(eleTy)) tySize = irb.getByteSizeOfType(loc, eleTy); if (!tySize) { + // TODO: we need to recursively create copies of all + // dynamic memory used within the type. See the + // implementation of `visit_Return` in the Python bridge. 
TODO_x(toLocation(x), x, mangler, "unhandled vector element type"); return false; } diff --git a/lib/Optimizer/CodeGen/CCToLLVM.cpp b/lib/Optimizer/CodeGen/CCToLLVM.cpp index 3f8bbdbaa01..f710b4e94d5 100644 --- a/lib/Optimizer/CodeGen/CCToLLVM.cpp +++ b/lib/Optimizer/CodeGen/CCToLLVM.cpp @@ -531,8 +531,8 @@ class SizeOfOpPattern : public ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto inputTy = sizeOfOp.getInputType(); auto resultTy = sizeOfOp.getType(); - if (quake::isQuakeType(inputTy) || cudaq::cc::isDynamicType(inputTy)) { - // Types that cannot be reified produce the poison op. + if (quake::isQuakeType(inputTy) || + cudaq::cc::isDynamicallySizedType(inputTy)) { rewriter.replaceOpWithNewOp(sizeOfOp, resultTy); return success(); } diff --git a/lib/Optimizer/Dialect/CC/CCTypes.cpp b/lib/Optimizer/Dialect/CC/CCTypes.cpp index 3c9c50f1329..d600247d68d 100644 --- a/lib/Optimizer/Dialect/CC/CCTypes.cpp +++ b/lib/Optimizer/Dialect/CC/CCTypes.cpp @@ -213,6 +213,22 @@ bool isDynamicType(Type ty) { return false; } +bool isDynamicallySizedType(Type ty) { + if (isa(ty)) + return false; + if (auto strTy = dyn_cast(ty)) { + for (auto memTy : strTy.getMembers()) + if (isDynamicallySizedType(memTy)) + return true; + return false; + } + if (auto arrTy = dyn_cast(ty)) + return arrTy.isUnknownSize() || + isDynamicallySizedType(arrTy.getElementType()); + // Note: this isn't considering quake, builtin, etc. types. 
+ return false; +} + void CCDialect::registerTypes() { addTypes(); diff --git a/python/cudaq/kernel/ast_bridge.py b/python/cudaq/kernel/ast_bridge.py index c808e8a99c4..189d080780d 100644 --- a/python/cudaq/kernel/ast_bridge.py +++ b/python/cudaq/kernel/ast_bridge.py @@ -7,15 +7,14 @@ # ============================================================================ # import ast -import inspect import importlib -import graphlib +import inspect import textwrap import numpy as np import os import sys +import types from collections import deque -from types import FunctionType from cudaq.mlir._mlir_libs._quakeDialects import (cudaq_runtime, load_intrinsic, gen_vector_of_complex_constant @@ -61,44 +60,45 @@ class PyScopedSymbolTable(object): def __init__(self): - self.symbolTable = {} + self.symbolTable = deque() def pushScope(self): - pass + self.symbolTable.append({}) def popScope(self): - pass + self.symbolTable.pop() def numLevels(self): - return 1 + return len(self.symbolTable) - def add(self, symbol, value, unused=None): + def add(self, symbol, value, level=-1): """ Add a symbol to the scoped symbol table at any scope level. 
""" + self.symbolTable[level][symbol] = value def __contains__(self, symbol): - return symbol in self.symbolTable + for st in reversed(self.symbolTable): + if symbol in st: + return True + + return False def __setitem__(self, symbol, value): # default to nearest surrounding scope self.add(symbol, value) def __getitem__(self, symbol): - if symbol in self.symbolTable: - return self.symbolTable[symbol] + for st in reversed(self.symbolTable): + if symbol in st: + return st[symbol] + raise RuntimeError( f"{symbol} is not a valid variable name in this scope.") def clear(self): - self.symbolTable.clear() - - def __str__(self): - s = "" - for sym in self.symbolTable: - s += str(sym) + ": " + str(self.symbolTable[sym]) + "\n" - return s + while len(self.symbolTable): + self.symbolTable.pop() class CompilerError(RuntimeError): @@ -110,6 +110,95 @@ def __init__(self, *args, **kwargs): RuntimeError.__init__(self, *args, **kwargs) +class PyStack(object): + ''' + Takes care of managing values produced while visiting Python + AST nodes. Each visit to a node is expected to match one + stack frame. Values produced (meaning pushed) by child frames + are accessible (meaning can be popped) by the parent. A frame + cannot access the value it produced (it is owned by the parent). + ''' + + class Frame(object): + + def __init__(self, parent=None): + self.entries = None + self.parent = parent + + def __init__(self, error_handler=None): + + def default_error_handler(msg): + raise RuntimeError(msg) + + self._frame = None + self.emitError = error_handler or default_error_handler + + def pushFrame(self): + ''' + A new frame should be pushed to process a new node in the AST. + ''' + if self._frame and not self._frame.entries: + self._frame.entries = deque() + self._frame = PyStack.Frame(parent=self._frame) + + def popFrame(self): + ''' + A frame should be popped once a node in the AST has been processed. 
+ ''' + if not self._frame: + self.emitError("stack has no frames to pop") + elif self._frame.entries: + self.emitError( + "all values must be processed before popping a frame") + else: + self._frame = self._frame.parent + + def pushValue(self, value): + ''' + Pushes a value to make it available to the parent frame. + ''' + if not self._frame: + self.emitError("cannot push value to empty stack") + elif not self._frame.parent: + self.emitError("no parent frame is defined to push values to") + else: + self._frame.parent.entries.append(value) + + def popValue(self): + ''' + Pops the most recently produced (pushed) value by a child frame. + ''' + if not self._frame: + self.emitError("value stack is empty") + elif not self._frame.entries: + # This is the only error that may be directly user-facing even when + # the bridge is doing its processing correctly. + # We hence give a somewhat general error. + # For internal purposes, the error might be better stated as something like: + # either this frame has not had a child or the child did not produce any values + self.emitError("no valid value was created") + else: + return self._frame.entries.pop() + + @property + def isEmpty(self): + ''' + Returns true if and only if there are no remaining stack frames. + ''' + return not self._frame + + @property + def currentNumValues(self): + ''' + Returns the number of values that are accessible for processing by the current frame. + ''' + if not self._frame: + self.emitError("no frame defined for empty stack") + elif self._frame.entries: + return len(self._frame.entries) + return 0 + + def recover_kernel_decorator(name): from .kernel_decorator import isa_kernel_decorator for frameinfo in inspect.stack(): @@ -144,7 +233,8 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs): track of a symbol table, which maps variable names to constructed `mlir.Values`. 
""" - self.valueStack = deque() + self.valueStack = PyStack(lambda msg: self.emitFatalError( + f'processing error - {msg}', self.currentNode)) self.knownResultType = kwargs[ 'knownResultType'] if 'knownResultType' in kwargs else None self.uniqueId = kwargs['uniqueId'] if 'uniqueId' in kwargs else None @@ -192,6 +282,7 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs): self.disableNvqppPrefix = kwargs[ 'disableNvqppPrefix'] if 'disableNvqppPrefix' in kwargs else False self.symbolTable = PyScopedSymbolTable() + # FIXME: NEEDS TO BE RESET ON FUNCTION DEF (LOCAL FUNC) self.controlHeight = 0 self.indent_level = 0 self.indent = 4 * " " @@ -201,8 +292,8 @@ def __init__(self, capturedDataStorage: CapturedDataStorage, **kwargs): self.currentAssignVariableName = None self.walkingReturnNode = False self.controlNegations = [] - self.subscriptPushPointerValue = False - self.attributePushPointerValue = False + self.pushPointerValue = False + self.isSubscriptRoot = False self.verbose = 'verbose' in kwargs and kwargs['verbose'] self.currentNode = None self.firstLiftedPos = None @@ -211,9 +302,13 @@ def debug_msg(self, msg, node=None): if self.verbose: print(f'{self.indent * self.indent_level}{msg()}') if node is not None: - print( - textwrap.indent(ast.unparse(node), - (self.indent * (self.indent_level + 1)))) + try: + print( + textwrap.indent(ast.unparse(node), + (self.indent * + (self.indent_level + 1)))) + except Exception: + pass def emitWarning(self, msg, astNode=None): """ @@ -242,15 +337,19 @@ def emitFatalError(self, msg, astNode=None): codeFile = os.path.basename(self.locationOffset[0]) if astNode == None: astNode = self.currentNode - lineNumber = '' if astNode == None else astNode.lineno + self.locationOffset[ - 1] - 1 + lineNumber = '' if astNode == None or not hasattr( + astNode, 'lineno') else astNode.lineno + self.locationOffset[1] - 1 + + try: + offending_source = "\n\t (offending source -> " + ast.unparse( + astNode) + ")" + except Exception: + 
offending_source 
= '' print(Color.BOLD, end='') msg = codeFile + ":" + str( lineNumber - ) + ": " + Color.RED + "error: " + Color.END + Color.BOLD + msg + ( - "\n\t (offending source -> " + ast.unparse(astNode) + ")" if - hasattr(ast, 'unparse') and astNode is not None else '') + Color.END + ) + ": " + Color.RED + "error: " + Color.END + Color.BOLD + msg + offending_source + Color.END raise CompilerError(msg) def getVeqType(self, size=None): @@ -288,6 +387,23 @@ def isMeasureResultType(self, ty, value): return False return IntegerType.isinstance(ty) and ty == IntegerType.get_signless(1) + def isFunctionArgument(self, value): + return (BlockArgument.isinstance(value) and + isinstance(value.owner.owner, func.FuncOp)) + + def containsList(self, ty, innerListsOnly=False): + """ + Returns true if the given type is a vector or contains + items that are vectors. + """ + if cc.StdvecType.isinstance(ty): + return (not innerListsOnly or + self.containsList(cc.StdvecType.getElementType(ty))) + if not cc.StructType.isinstance(ty): + return False + eleTys = cc.StructType.getTypes(ty) + return any((self.containsList(t) for t in eleTys)) + def getIntegerType(self, width=64): """ Return an MLIR `IntegerType` of the given bit width (defaults to 64 @@ -386,6 +502,29 @@ def getConstantInt(self, value, width=64): ty = self.getIntegerType(width) return arith.ConstantOp(ty, self.getIntegerAttr(ty, value)).result + def __arithmetic_to_bool(self, value): + """ + Converts an integer or floating point value to a bool by + comparing it to zero. 
+ """ + if self.getIntegerType(1) == value.type: + return value + if IntegerType.isinstance(value.type): + zero = self.getConstantInt(0, width=IntegerType(value.type).width) + condPred = IntegerAttr.get(self.getIntegerType(), 1) + return arith.CmpIOp(condPred, value, zero).result + elif F32Type.isinstance(value.type): + zero = self.getConstantFloat(0, width=32) + condPred = IntegerAttr.get(self.getIntegerType(), 13) + return arith.CmpFOp(condPred, value, zero).result + elif F64Type.isinstance(value.type): + zero = self.getConstantFloat(0, width=64) + condPred = IntegerAttr.get(self.getIntegerType(), 13) + return arith.CmpFOp(condPred, value, zero).result + else: + self.emitFatalError("value cannot be converted to bool", + self.currentNode) + def changeOperandToType(self, ty, operand, allowDemotion=False): """ Change the type of an operand to a specified type. This function primarily @@ -395,6 +534,13 @@ def changeOperandToType(self, ty, operand, allowDemotion=False): """ if ty == operand.type: return operand + if cc.CallableType.isinstance(ty): + fctTy = cc.CallableType.getFunctionType(ty) + if fctTy == operand.type: + return operand + self.emitFatalError( + f'cannot convert value of type {operand.type} to the requested type {fctTy}', + self.currentNode) if ComplexType.isinstance(ty): complexType = ComplexType(ty) @@ -405,21 +551,46 @@ def changeOperandToType(self, ty, operand, allowDemotion=False): if (floatType != otherFloatType): real = self.changeOperandToType( floatType, - complex.ReOp(operand).result) + complex.ReOp(operand).result, + allowDemotion=allowDemotion) imag = self.changeOperandToType( floatType, - complex.ImOp(operand).result) + complex.ImOp(operand).result, + allowDemotion=allowDemotion) return complex.CreateOp(complexType, real, imag).result else: - real = self.changeOperandToType(floatType, operand) + real = self.changeOperandToType(floatType, + operand, + allowDemotion=allowDemotion) imag = self.getConstantFloatWithType(0.0, floatType) return 
complex.CreateOp(complexType, real, imag).result if (cc.StdvecType.isinstance(ty)): - eleTy = cc.StdvecType.getElementType(ty) if cc.StdvecType.isinstance(operand.type): - return self.__copyVectorAndCastElements( - operand, eleTy, allowDemotion=allowDemotion) + eleTy = cc.StdvecType.getElementType(ty) + return self.__copyVectorAndConvertElements( + operand, + eleTy, + allowDemotion=allowDemotion, + alwaysCopy=False) + + if (cc.StructType.isinstance(ty)): + if cc.StructType.isinstance(operand.type): + expectedEleTys = cc.StructType.getTypes(ty) + currentEleTys = cc.StructType.getTypes(operand.type) + if len(expectedEleTys) == len(currentEleTys): + + def conversion(idx, value): + return self.changeOperandToType( + expectedEleTys[idx], + value, + allowDemotion=allowDemotion) + + return self.__copyStructAndConvertElements( + operand, + expectedTy=ty, + allowDemotion=allowDemotion, + conversion=conversion) if F64Type.isinstance(ty): if F32Type.isinstance(operand.type): @@ -447,6 +618,8 @@ def changeOperandToType(self, ty, operand, allowDemotion=False): if requested_width == operand_width: return operand elif requested_width < operand_width: + if requested_width == 1: + return self.__arithmetic_to_bool(operand) return cc.CastOp(ty, operand).result return cc.CastOp(ty, operand, @@ -480,16 +653,26 @@ def pushValue(self, value): visit method. """ self.debug_msg(lambda: f'push {value}') - self.valueStack.append(value) + self.valueStack.pushValue(value) def popValue(self): """ Pop an MLIR Value from the stack. 
""" - val = self.valueStack.pop() + val = self.valueStack.popValue() self.debug_msg(lambda: f'pop {val}') return val + def popAllValues(self, expectedNumVals): + values = [ + self.popValue() for _ in range(self.valueStack.currentNumValues) + ] + if len(values) != expectedNumVals: + self.emitFatalError( + "processing error - expression did not produce a valid value in this context", + self.currentNode) + return values + def pushForBodyStack(self, bodyBlockArgs): """ Indicate that we are entering a for loop body block. @@ -579,45 +762,16 @@ def __isUnitaryGate(self, id): id in ['swap', 'u3', 'exp_pauli'] or id in globalRegisteredOperations) - def __isNoiseChannelClass(self, value): - """ - Return True if the given value is a Kraus channel class. - """ - return isinstance(value, type) and issubclass( - value, cudaq_runtime.KrausChannel) - - def ifPointerThenLoad(self, value): - """ - If the given value is of pointer type, load the pointer and return that - new value. - """ - if cc.PointerType.isinstance(value.type): - return cc.LoadOp(value).result - return value - - def ifNotPointerThenStore(self, value): - """ - If the given value is not of a pointer type, allocate a slot on the - stack, store the the value in the slot, and return the slot address. - """ - if not cc.PointerType.isinstance(value.type): - slot = cc.AllocaOp(cc.PointerType.get(value.type), - TypeAttr.get(value.type)).result - cc.StoreOp(value, slot) - return slot - return value + def __createStdvecWithKnownValues(self, listElementValues): - def __createStdvecWithKnownValues(self, size, listElementValues): - # Turn this List into a StdVec - arrSize = self.getConstantInt(size) + assert (len(set((v.type for v in listElementValues))) == 1) + arrSize = self.getConstantInt(len(listElementValues)) elemTy = listElementValues[0].type # If this is an `i1`, turns it into an `i8` array. 
isBool = elemTy == self.getIntegerType(1) if isBool: elemTy = self.getIntegerType(8) - - arrTy = cc.ArrayType.get(elemTy) - alloca = cc.AllocaOp(cc.PointerType.get(arrTy), + alloca = cc.AllocaOp(cc.PointerType.get(cc.ArrayType.get(elemTy)), TypeAttr.get(elemTy), seqSize=arrSize).result @@ -631,15 +785,10 @@ def __createStdvecWithKnownValues(self, size, listElementValues): v = self.changeOperandToType(self.getIntegerType(8), v) cc.StoreOp(v, eleAddr) - # Create the `StdVec` from the `alloca` - # We still use `i1` as the vector element type if the - # original list was of `bool` type. - vecTy = elemTy if not isBool else self.getIntegerType(1) - if cc.PointerType.isinstance(vecTy): - vecTy = cc.PointerType.getElementType(vecTy) - - return cc.StdvecInitOp(cc.StdvecType.get(vecTy), alloca, - length=arrSize).result + # We still use `i1` as the vector element type for `cc.StdvecInitOp`. + vecTy = cc.StdvecType.get(elemTy) if not isBool else cc.StdvecType.get( + self.getIntegerType(1)) + return cc.StdvecInitOp(vecTy, alloca, length=arrSize).result def getStructMemberIdx(self, memberName, structTy): """ @@ -651,6 +800,8 @@ def getStructMemberIdx(self, memberName, structTy): else: structName = quake.StruqType.getName(structTy) structIdx = None + if structName == 'tuple': + self.emitFatalError('`tuple` does not support attribute access') if not globalRegisteredTypes.isRegisteredClass(structName): self.emitFatalError(f'Dataclass is not registered: {structName})') @@ -661,84 +812,189 @@ def getStructMemberIdx(self, memberName, structTy): break if structIdx == None: self.emitFatalError( - f'Invalid struct member: {structName}.{memberName} (members=' - f'{[k for k,_ in userType.items()]})') - return structIdx, mlirTypeFromPyType(userType[memberName], self.ctx) - - # Create a new vector with source elements converted to the target element - # type if needed. 
- def __copyVectorAndCastElements(self, - source, - targetEleType, - allowDemotion=False): - if not cc.PointerType.isinstance(source.type): - if cc.StdvecType.isinstance(source.type): - # Exit early if no copy is needed to avoid an unneeded store. - sourceEleType = cc.StdvecType.getElementType(source.type) - if (sourceEleType == targetEleType): - return source - - sourcePtr = source - if not cc.PointerType.isinstance(sourcePtr.type): - sourcePtr = self.ifNotPointerThenStore(sourcePtr) - - sourceType = cc.PointerType.getElementType(sourcePtr.type) - if not cc.StdvecType.isinstance(sourceType): - raise RuntimeError( - f"expected vector type to copy and cast elements but received {sourceType}" + f'Invalid struct member: {structName}.{memberName} (members={[k for k,_ in userType.items()]})' ) + return structIdx, mlirTypeFromPyType(userType[memberName], self.ctx) - sourceEleType = cc.StdvecType.getElementType(sourceType) - if (sourceEleType == targetEleType): - return sourcePtr - + def __copyStructAndConvertElements(self, + struct, + expectedTy=None, + allowDemotion=False, + conversion=None): + """ + Creates a new struct on the stack. If a conversion is provided, applies the conversion on each + element before changing its type to match the corresponding element type in `expectedTy`. 
+ """ + assert cc.StructType.isinstance(struct.type) + if not expectedTy: + expectedTy = struct.type + assert cc.StructType.isinstance(expectedTy) + eleTys = cc.StructType.getTypes(struct.type) + expectedEleTys = cc.StructType.getTypes(expectedTy) + assert len(eleTys) == len(expectedEleTys) + + returnVal = cc.UndefOp(expectedTy) + for idx, eleTy in enumerate(eleTys): + element = cc.ExtractValueOp( + eleTy, struct, [], + DenseI32ArrayAttr.get([idx], context=self.ctx)).result + element = conversion(idx, element) if conversion else element + element = self.changeOperandToType(expectedEleTys[idx], + element, + allowDemotion=allowDemotion) + returnVal = cc.InsertValueOp( + expectedTy, returnVal, element, + DenseI64ArrayAttr.get([idx], context=self.ctx)).result + return returnVal + + # Create a new vector with source elements converted to the target element type if needed. + def __copyVectorAndConvertElements(self, + source, + targetEleType=None, + allowDemotion=False, + alwaysCopy=False, + conversion=None): + ''' + Creates a new vector with the requested element type. + Returns the original vector if the requested element type already matches + the current element type unless `alwaysCopy` is set to True. + If a conversion is provided, applies the conversion to each element before + changing its type to match the `targetEleType`. + If `alwaysCopy` is set to True, return a shallow copy of the vector by + default (conversion can be used to create a deep copy). 
+ ''' + + assert cc.StdvecType.isinstance(source.type) + sourceEleType = cc.StdvecType.getElementType(source.type) + if not targetEleType: + targetEleType = sourceEleType + if not alwaysCopy and sourceEleType == targetEleType: + return source isSourceBool = sourceEleType == self.getIntegerType(1) if isSourceBool: sourceEleType = self.getIntegerType(8) - - sourceArrType = cc.ArrayType.get(sourceEleType) - sourceElePtrTy = cc.PointerType.get(sourceEleType) - sourceArrElePtrTy = cc.PointerType.get(sourceArrType) - sourceValue = self.ifPointerThenLoad(sourcePtr) - sourceDataPtr = cc.StdvecDataOp(sourceArrElePtrTy, sourceValue).result - sourceSize = cc.StdvecSizeOp(self.getIntegerType(), sourceValue).result - isTargetBool = targetEleType == self.getIntegerType(1) - # Vector type reflects the true type, including `i1` - targetVecTy = cc.StdvecType.get(targetEleType) - if isTargetBool: targetEleType = self.getIntegerType(8) - targetElePtrType = cc.PointerType.get(targetEleType) - targetTy = cc.ArrayType.get(targetEleType) - targetArrElePtrTy = cc.PointerType.get(targetTy) - targetPtr = cc.AllocaOp(targetArrElePtrTy, + sourceArrPtrTy = cc.PointerType.get(cc.ArrayType.get(sourceEleType)) + sourceDataPtr = cc.StdvecDataOp(sourceArrPtrTy, source).result + sourceSize = cc.StdvecSizeOp(self.getIntegerType(), source).result + targetPtr = cc.AllocaOp(cc.PointerType.get( + cc.ArrayType.get(targetEleType)), TypeAttr.get(targetEleType), seqSize=sourceSize).result rawIndex = DenseI32ArrayAttr.get([kDynamicPtrIndex], context=self.ctx) def bodyBuilder(iterVar): - eleAddr = cc.ComputePtrOp(sourceElePtrTy, sourceDataPtr, [iterVar], - rawIndex).result + eleAddr = cc.ComputePtrOp(cc.PointerType.get(sourceEleType), + sourceDataPtr, [iterVar], rawIndex).result loadedEle = cc.LoadOp(eleAddr).result - castedEle = self.changeOperandToType(targetEleType, - loadedEle, - allowDemotion=allowDemotion) - targetEleAddr = cc.ComputePtrOp(targetElePtrType, targetPtr, - [iterVar], rawIndex).result - 
cc.StoreOp(castedEle, targetEleAddr) + convertedEle = conversion(iterVar, + loadedEle) if conversion else loadedEle + convertedEle = self.changeOperandToType(targetEleType, + convertedEle, + allowDemotion=allowDemotion) + targetEleAddr = cc.ComputePtrOp(cc.PointerType.get(targetEleType), + targetPtr, [iterVar], + rawIndex).result + cc.StoreOp(convertedEle, targetEleAddr) + + self.createInvariantForLoop(bodyBuilder, sourceSize) + + # We still use `i1` as the vector element type for `cc.StdvecInitOp`. + vecTy = cc.StdvecType.get( + targetEleType) if not isTargetBool else cc.StdvecType.get( + self.getIntegerType(1)) + return cc.StdvecInitOp(vecTy, targetPtr, length=sourceSize).result + + def __copyAndValidateContainer(self, value, pyVal, deepCopy, dataType=None): + """ + Helper function to implement deep and shallow copies for structs and vectors. + Arguments: + `value`: The MLIR value to copy + `pyVal`: The Python AST node to use for validation of the container entries. + `deepCopy`: Whether to perform a deep or shallow copy. + `dataType`: Must be None unless the value to copy is a vector. + If the value is a vector, then the element type of the new vector. + """ + # NOTE: Creating a copy means we are creating a new container. + # As such, all elements in the container need to pass the validation + # in `__validate_container_entry`. 
+ if deepCopy: + + def conversion(idx, structItem): + if cc.StdvecType.isinstance(structItem.type): + structItem = self.__copyVectorAndConvertElements( + structItem, alwaysCopy=True, conversion=conversion) + elif (cc.StructType.isinstance(structItem.type) and + self.containsList(structItem.type)): + structItem = self.__copyStructAndConvertElements( + structItem, conversion=conversion) + self.__validate_container_entry(structItem, pyVal) + return structItem + else: + + def conversion(idx, structItem): + self.__validate_container_entry(structItem, pyVal) + return structItem + + if cc.StdvecType.isinstance(value.type): + listVal = self.__copyVectorAndConvertElements(value, + dataType, + alwaysCopy=True, + conversion=conversion) + return listVal + + if cc.StructType.isinstance(value.type): + if dataType: + self.emitFatalError("unsupported data type argument", + self.currentNode) + struct = self.__copyStructAndConvertElements(value, + conversion=conversion) + return struct + + self.emitFatalError( + f'copy is not supported on value of type {value.type}', + self.currentNode) + + def __migrateLists(self, value, migrate): + """ + Replaces all lists in the given value by the list returned + by the `migrate` function, including inner lists. Does an + in-place replacement for list elements. 
+ """ + if cc.StdvecType.isinstance(value.type): + eleTy = cc.StdvecType.getElementType(value.type) + if self.containsList(eleTy): + size = cc.StdvecSizeOp(self.getIntegerType(), value).result + ptrTy = cc.PointerType.get(cc.ArrayType.get(eleTy)) + iterable = cc.StdvecDataOp(ptrTy, value).result + + def bodyBuilder(iterVar): + eleAddr = cc.ComputePtrOp( + cc.PointerType.get(eleTy), iterable, [iterVar], + DenseI32ArrayAttr.get([kDynamicPtrIndex], + context=self.ctx)) + loadedEle = cc.LoadOp(eleAddr).result + element = self.__migrateLists(loadedEle, migrate) + cc.StoreOp(element, eleAddr) - self.createInvariantForLoop(sourceSize, bodyBuilder) - return cc.StdvecInitOp(targetVecTy, targetPtr, length=sourceSize).result + self.createInvariantForLoop(bodyBuilder, size) + return migrate(value) + if (cc.StructType.isinstance(value.type) and + self.containsList(value.type)): + return self.__copyStructAndConvertElements( + value, conversion=lambda _, v: self.__migrateLists(v, migrate)) + assert not self.containsList(value.type) + return value def __insertDbgStmt(self, value, dbgStmt): """ Insert a debug print out statement if the programmer requested. Handles statements like `cudaq.dbg.ast.print_i64(i)`. """ - value = self.ifPointerThenLoad(value) printFunc = None printStr = '[cudaq-ast-dbg] ' argsTy = [cc.PointerType.get(self.getIntegerType(8))] @@ -793,25 +1049,6 @@ def __insertDbgStmt(self, value, dbgStmt): func.CallOp(printFunc, [strLit, value]) return - def __get_vector_size(self, vector): - """ - Get the size of a vector or array type. 
- - Args: - vector: MLIR Value of vector/array type - - Returns: - MLIR Value containing the size as an integer - """ - if cc.StdvecType.isinstance(vector.type): - return cc.StdvecSizeOp(self.getIntegerType(), vector).result - elif cc.ArrayType.isinstance(vector.type): - return self.getConstantInt( - cc.ArrayType.getSize(cc.PointerType.getElementType( - vector.type))) - self.emitFatalError("cannot get the size for a value of type {}".format( - vector.type)) - def __load_vector_element(self, vector, index): """ Load an element from a vector or array at the given index. @@ -914,83 +1151,84 @@ def mlirTypeFromAnnotation(self, annotation): if msg is not None: self.emitFatalError(msg, annotation) - def createInvariantForLoop(self, - endVal, - bodyBuilder, - startVal=None, - stepVal=None, - isDecrementing=False, - elseStmts=None): - """ - Create an invariant loop using the CC dialect. - """ - startVal = self.getConstantInt(0) if startVal == None else startVal - stepVal = self.getConstantInt(1) if stepVal == None else stepVal - - iTy = self.getIntegerType() - inputs = [startVal] - resultTys = [iTy] + def createForLoop(self, + argTypes, + bodyBuilder, + inputs, + evalCond, + evalStep, + orElseBuilder=None): - loop = cc.LoopOp(resultTys, inputs, BoolAttr.get(False)) + # post-conditional would be a do-while loop + isPostConditional = BoolAttr.get(False) + loop = cc.LoopOp(argTypes, inputs, isPostConditional) - whileBlock = Block.create_at_start(loop.whileRegion, [iTy]) + whileBlock = Block.create_at_start(loop.whileRegion, argTypes) with InsertionPoint(whileBlock): - condPred = IntegerAttr.get( - iTy, 2) if not isDecrementing else IntegerAttr.get(iTy, 4) - cc.ConditionOp( - arith.CmpIOp(condPred, whileBlock.arguments[0], endVal).result, - whileBlock.arguments) + condVal = evalCond(whileBlock.arguments) + cc.ConditionOp(condVal, whileBlock.arguments) - bodyBlock = Block.create_at_start(loop.bodyRegion, [iTy]) + bodyBlock = Block.create_at_start(loop.bodyRegion, argTypes) with 
InsertionPoint(bodyBlock): self.symbolTable.pushScope() self.pushForBodyStack(bodyBlock.arguments) - bodyBuilder(bodyBlock.arguments[0]) + bodyBuilder(bodyBlock.arguments) if not self.hasTerminator(bodyBlock): cc.ContinueOp(bodyBlock.arguments) self.popForBodyStack() self.symbolTable.popScope() - stepBlock = Block.create_at_start(loop.stepRegion, [iTy]) + stepBlock = Block.create_at_start(loop.stepRegion, argTypes) with InsertionPoint(stepBlock): - incr = arith.AddIOp(stepBlock.arguments[0], stepVal).result - cc.ContinueOp([incr]) + stepVals = evalStep(stepBlock.arguments) + cc.ContinueOp(stepVals) - if elseStmts: - elseBlock = Block.create_at_start(loop.elseRegion, [iTy]) + if orElseBuilder: + elseBlock = Block.create_at_start(loop.elseRegion, argTypes) with InsertionPoint(elseBlock): self.symbolTable.pushScope() - for stmt in elseStmts: - self.visit(stmt) + orElseBuilder(elseBlock.arguments) if not self.hasTerminator(elseBlock): cc.ContinueOp(elseBlock.arguments) self.symbolTable.popScope() - loop.attributes.__setitem__('invariant', UnitAttr.get()) - return + return loop - def __applyQuantumOperation(self, opName, parameters, targets): - opCtor = getattr(quake, '{}Op'.format(opName.title())) - for quantumValue in targets: - if quake.VeqType.isinstance(quantumValue.type): + def createMonotonicForLoop(self, + bodyBuilder, + startVal, + stepVal, + endVal, + isDecrementing=False, + orElseBuilder=None): - def bodyBuilder(iterVal): - q = quake.ExtractRefOp(self.getRefType(), - quantumValue, - -1, - index=iterVal).result - opCtor([], parameters, [], [q]) + iTy = self.getIntegerType() + assert startVal.type == iTy + assert stepVal.type == iTy + assert endVal.type == iTy - veqSize = quake.VeqSizeOp(self.getIntegerType(), - quantumValue).result - self.createInvariantForLoop(veqSize, bodyBuilder) - elif quake.RefType.isinstance(quantumValue.type): - opCtor([], parameters, [], [quantumValue]) - else: - self.emitFatalError( - f'quantum operation {opName} on incorrect quantum 
type {quantumValue.type}.' - ) - return + condPred = IntegerAttr.get( + iTy, 4) if isDecrementing else IntegerAttr.get(iTy, 2) + return self.createForLoop( + [iTy], lambda args: bodyBuilder(args[0]), [startVal], + lambda args: arith.CmpIOp(condPred, args[0], endVal).result, + lambda args: [arith.AddIOp(args[0], stepVal).result], + None if orElseBuilder is None else + (lambda args: orElseBuilder(args[0]))) + + def createInvariantForLoop(self, bodyBuilder, endVal): + """ + Create an invariant loop using the CC dialect. + """ + + startVal = self.getConstantInt(0) + stepVal = self.getConstantInt(1) + + loop = self.createMonotonicForLoop(bodyBuilder, + startVal=startVal, + stepVal=stepVal, + endVal=endVal) + loop.attributes.__setitem__('invariant', UnitAttr.get()) def __deconstructAssignment(self, target, value, process=None): if process is not None: @@ -1002,44 +1240,45 @@ def __deconstructAssignment(self, target, value, process=None): if (isinstance(value, ast.Tuple) or isinstance(value, ast.List)): nrArgs = len(value.elts) getItem = lambda idx: value.elts[idx] + elif (isinstance(value, tuple) or isinstance(value, list)): + nrArgs = len(value) + getItem = lambda idx: value[idx] + elif cc.StructType.isinstance(value.type): + argTypes = cc.StructType.getTypes(value.type) + nrArgs = len(argTypes) + getItem = lambda idx: cc.ExtractValueOp( + argTypes[idx], value, [], + DenseI32ArrayAttr.get([idx], context=self.ctx)).result + elif quake.StruqType.isinstance(value.type): + argTypes = quake.StruqType.getTypes(value.type) + nrArgs = len(argTypes) + getItem = lambda idx: quake.GetMemberOp( + argTypes[idx], value, + IntegerAttr.get(self.getIntegerType(32), idx)).result + elif cc.StdvecType.isinstance(value.type): + # We will get a runtime error for out of bounds access + eleTy = cc.StdvecType.getElementType(value.type) + elePtrTy = cc.PointerType.get(eleTy) + arrTy = cc.ArrayType.get(eleTy) + ptrArrTy = cc.PointerType.get(arrTy) + vecPtr = cc.StdvecDataOp(ptrArrTy, 
value).result + attr = DenseI32ArrayAttr.get([kDynamicPtrIndex], + context=self.ctx) + nrArgs = len(target.elts) + getItem = lambda idx: cc.LoadOp( + cc.ComputePtrOp(elePtrTy, vecPtr, [ + self.getConstantInt(idx) + ], attr).result).result + elif quake.VeqType.isinstance(value.type): + # We will get a runtime error for out of bounds access + nrArgs = len(target.elts) + getItem = lambda idx: quake.ExtractRefOp( + quake.RefType.get(), + value, + -1, + index=self.getConstantInt(idx)).result else: - value = self.ifPointerThenLoad(value) - if cc.StructType.isinstance(value.type): - argTypes = cc.StructType.getTypes(value.type) - nrArgs = len(argTypes) - getItem = lambda idx: cc.ExtractValueOp( - argTypes[idx], value, [], - DenseI32ArrayAttr.get([idx], context=self.ctx)).result - elif quake.StruqType.isinstance(value.type): - argTypes = quake.StruqType.getTypes(value.type) - nrArgs = len(argTypes) - getItem = lambda idx: quake.GetMemberOp( - argTypes[idx], value, - IntegerAttr.get(self.getIntegerType(32), idx)).result - elif cc.StdvecType.isinstance(value.type): - # We will get a runtime error for out of bounds access - eleTy = cc.StdvecType.getElementType(value.type) - elePtrTy = cc.PointerType.get(eleTy) - arrTy = cc.ArrayType.get(eleTy) - ptrArrTy = cc.PointerType.get(arrTy) - vecPtr = cc.StdvecDataOp(ptrArrTy, value).result - attr = DenseI32ArrayAttr.get([kDynamicPtrIndex], - context=self.ctx) - nrArgs = len(target.elts) - getItem = lambda idx: cc.LoadOp( - cc.ComputePtrOp(elePtrTy, vecPtr, [ - self.getConstantInt(idx) - ], attr).result).result - elif quake.VeqType.isinstance(value.type): - # We will get a runtime error for out of bounds access - nrArgs = len(target.elts) - getItem = lambda idx: quake.ExtractRefOp( - quake.RefType.get(), - value, - -1, - index=self.getConstantInt(idx)).result - else: - nrArgs = 0 + nrArgs = 0 if nrArgs != len(target.elts): self.emitFatalError("shape mismatch in tuple deconstruction", self.currentNode) @@ -1051,7 +1290,7 @@ def 
__deconstructAssignment(self, target, value, process=None): self.emitFatalError("unsupported target in tuple deconstruction", self.currentNode) - def __processRangeLoopIterationBounds(self, argumentNodes): + def __processRangeLoopIterationBounds(self, pyVals): """ Analyze `range(...)` bounds and return the start, end, and step values, as well as whether or not this a decrementing range. @@ -1059,135 +1298,225 @@ def __processRangeLoopIterationBounds(self, argumentNodes): iTy = self.getIntegerType(64) zero = arith.ConstantOp(iTy, IntegerAttr.get(iTy, 0)) one = arith.ConstantOp(iTy, IntegerAttr.get(iTy, 1)) + values = self.__groupValues(pyVals, [(1, 3)]) + isDecrementing = False - if len(argumentNodes) == 3: + if len(pyVals) == 3: # Find the step val and we need to know if its decrementing can be # incrementing or decrementing - stepVal = self.popValue() - if isinstance(argumentNodes[2], ast.UnaryOp): - self.debug_msg(lambda: f'[(Inline) Visit UnaryOp]', - argumentNodes[2]) - if isinstance(argumentNodes[2].op, ast.USub): - if isinstance(argumentNodes[2].operand, ast.Constant): - self.debug_msg(lambda: f'[(Inline) Visit Constant]', - argumentNodes[2].operand) - if argumentNodes[2].operand.value > 0: - isDecrementing = True - else: - self.emitFatalError( - 'CUDA-Q requires step value on range() to be a constant.' 
- ) + stepVal = values[2] + if isinstance(pyVals[2], ast.Constant): + pyStepVal = pyVals[2].value + elif (isinstance(pyVals[2], ast.UnaryOp) and + isinstance(pyVals[2].op, ast.USub) and + isinstance(pyVals[2].operand, ast.Constant)): + pyStepVal = -pyVals[2].operand.value + else: + self.emitFatalError('range step value must be a constant', + self.currentNode) + if pyStepVal == 0: + self.emitFatalError("range step value must be non-zero", + self.currentNode) + isDecrementing = pyStepVal < 0 # exclusive end - endVal = self.popValue() + endVal = values[1] # inclusive start - startVal = self.popValue() + startVal = values[0] - elif len(argumentNodes) == 2: + elif len(pyVals) == 2: stepVal = one - endVal = self.popValue() - startVal = self.popValue() + endVal = values[1] + startVal = values[0] else: stepVal = one - endVal = self.popValue() + endVal = values[0] startVal = zero - startVal = self.ifPointerThenLoad(startVal) - endVal = self.ifPointerThenLoad(endVal) - stepVal = self.ifPointerThenLoad(stepVal) - for idx, v in enumerate([startVal, endVal, stepVal]): if not IntegerType.isinstance(v.type): # matching Python behavior to error on non-integer values - self.emitFatalError( - "non-integer value in range expression", - argumentNodes[idx if len(argumentNodes) > 1 else 0]) + self.emitFatalError("non-integer value in range expression", + pyVals[idx if len(pyVals) > 1 else 0]) return startVal, endVal, stepVal, isDecrementing - def __visitStructAttribute(self, node, structValue): - """ - Handle struct member extraction from either a pointer to struct or - direct struct value. Uses the most efficient approach for each case. 
- """ - if cc.PointerType.isinstance(structValue.type): - # Handle pointer to struct - use ComputePtrOp - eleType = cc.PointerType.getElementType(structValue.type) - if cc.StructType.isinstance(eleType): - structIdx, memberTy = self.getStructMemberIdx( - node.attr, eleType) - eleAddr = cc.ComputePtrOp(cc.PointerType.get(memberTy), - structValue, [], - DenseI32ArrayAttr.get([structIdx - ])).result - - if self.attributePushPointerValue: - self.pushValue(eleAddr) - return - - # Load the value - eleAddr = cc.LoadOp(eleAddr).result - self.pushValue(eleAddr) - return - elif cc.StructType.isinstance(structValue.type): - # Handle direct struct value - use ExtractValueOp (more efficient) - structIdx, memberTy = self.getStructMemberIdx( - node.attr, structValue.type) - extractedValue = cc.ExtractValueOp( - memberTy, structValue, [], - DenseI32ArrayAttr.get([structIdx])).result - - if self.attributePushPointerValue: - # If we need a pointer, we have to create a temporary slot - tempSlot = cc.AllocaOp(cc.PointerType.get(memberTy), - TypeAttr.get(memberTy)).result - cc.StoreOp(extractedValue, tempSlot) - self.pushValue(tempSlot) - return - - self.pushValue(extractedValue) - return + def __groupValues(self, pyvals, groups: list[int | tuple[int, int]]): + ''' + Helper function that visits the given AST nodes (`pyvals`), + and groups them according to the specified list. + The list contains integers or tuples of two integers. + Integer values have to be positive or -1, where -1 + indicates that any number of values is acceptable. + Tuples of two integers (min, max) indicate that any number + of values in [min, max] is acceptable. + The list may only contain at most one negative integer or + tuple (enforced via assert only). + + Emits a fatal error if any of the given `pyvals` did not + generate a value. Emits a fatal error if there are too + many or too few values to satisfy the requested grouping. + + Returns a tuple of value groups. 
Each value group is + either a single value (if the corresponding entry in `groups` + equals 1), or a list of values. + ''' + + def group_values(numExpected, values, reverse): + groupedVals = [] + current_idx = 0 + for nArgs in numExpected: + if (isinstance(nArgs, int) and nArgs == 1 and + current_idx < len(values)): + groupedVals.append(values[current_idx]) + current_idx += 1 + continue + if isinstance(nArgs, tuple): + minNumArgs, maxNumArgs = nArgs + if minNumArgs == maxNumArgs: + nArgs = minNumArgs + if not isinstance(nArgs, int) or nArgs < 0: + break + if current_idx + nArgs > len(values): + self.emitFatalError("missing value", self.currentNode) + groupedVals.append(values[current_idx:current_idx + nArgs]) + if reverse: + groupedVals[-1].reverse() + current_idx += nArgs + remaining = values[current_idx:] + numExpected = numExpected[len(groupedVals):] + if reverse: + remaining.reverse() + groupedVals.reverse() + numExpected.reverse() + return groupedVals, numExpected, remaining + + [self.visit(arg) for arg in pyvals] + values = self.popAllValues(len(pyvals)) + groups.reverse() + backVals, groups, values = group_values(groups, values, reverse=True) + frontVals, groups, values = group_values(groups, values, reverse=False) + if not groups: + if values: + self.emitFatalError("too many values", self.currentNode) + groupedVals = *frontVals, *backVals else: + assert len(groups) == 1 # ambiguous otherwise + if isinstance(groups[0], tuple): + minNumArgs, maxNumArgs = groups[0] + assert 0 <= minNumArgs and (minNumArgs <= maxNumArgs or + maxNumArgs < 0) + if len(values) < minNumArgs: + self.emitFatalError("missing value", self.currentNode) + if len(values) > maxNumArgs and maxNumArgs > 0: + self.emitFatalError("too many values", self.currentNode) + groupedVals = *frontVals, values, *backVals + return groupedVals[0] if len(groupedVals) == 1 else groupedVals + + def __get_root_value(self, pyVal): + ''' + Strips any attribute and subscript expressions from the node + to get 
the root node that the expression accesses. + Returns the symbol table entry for the root node, if such an + entry exists, and return None otherwise. + ''' + pyValRoot = pyVal + while (isinstance(pyValRoot, ast.Subscript) or + isinstance(pyValRoot, ast.Attribute)): + pyValRoot = pyValRoot.value + if (isinstance(pyValRoot, ast.Name) and + pyValRoot.id in self.symbolTable): + return self.symbolTable[pyValRoot.id] + return None + + def __validate_container_entry(self, mlirVal, pyVal): + ''' + Helper function that should be invoked for any elements that are stored in + tuple, dataclass, or list. Note that the `pyVal` argument is only used to + determine the root of `mlirVal` and as such could be either the Python + AST node matching the container item (`mlirVal`) or the AST node for the + container itself. + ''' + + rootVal = self.__get_root_value(pyVal) + assert rootVal or not self.isFunctionArgument(mlirVal) + + if cc.PointerType.isinstance(mlirVal.type): + # We do not allow to create container that contain pointers. + valTy = cc.PointerType.getElementType(mlirVal.type) + assert cc.StateType.isinstance(valTy) + if cc.StateType.isinstance(valTy): + self.emitFatalError( + "cannot use `cudaq.State` as element in lists, tuples, or dataclasses", + self.currentNode) self.emitFatalError( - f"Cannot access attribute '{node.attr}' on type {structValue.type}" - ) + "lists, tuples, and dataclasses must not contain modifiable values", + self.currentNode) - def needsStackSlot(self, type): - """ - Return true if this is a type that has been "passed by value" and - needs a stack slot created (i.e. a `cc.alloca`) for use throughout the - function. 
- """ - # FIXME add more as we need them - return ComplexType.isinstance(type) or F64Type.isinstance( - type) or F32Type.isinstance(type) or IntegerType.isinstance( - type) or cc.StructType.isinstance(type) + if cc.StructType.isinstance(mlirVal.type): + structName = cc.StructType.getName(mlirVal.type) + # We need to give a proper error if we try to assign + # a mutable dataclass to an item in another container. + # Allowing this would lead to incorrect behavior (i.e. + # inconsistent with Python) unless we change the + # representation of structs to be like `StdvecType` + # where we have a container that is passed by value + # wrapping the actual pointer, thus ensuring that the + # reference behavior actually works across function + # boundaries. + if structName != 'tuple' and rootVal: + msg = "only dataclass literals may be used as items in other container values" + self.emitFatalError( + f"{msg} - use `.copy(deep)` to create a new {structName}", + self.currentNode) + + if (self.knownResultType and self.containsList(self.knownResultType) and + self.containsList(mlirVal.type)): + # For lists that were created inside a kernel, we have to + # copy the stack allocated array to the heap when we return such a list. + # In the case where the list was created by the caller, this copy leads + # to incorrect behavior (i.e. not matching Python behavior). We hence + # want to make sure that we can know when a host allocated list is returned. + # If we allow to assign lists passed as function arguments to inner items + # of other lists and dataclasses, we loose the information that this list + # was allocated by the parent. We hence forbid this. All of this applies + # regardless of how the list was passed (e.g. the list might be an inner + # item in a tuple or dataclass that was passed) or how it is assigned + # (e.g. the assigned value might be a tuple or dataclass that contains a list). 
+ if rootVal and self.isFunctionArgument(rootVal): + msg = "lists passed as or contained in function arguments cannot be inner items in other container values when a list is returned" + self.emitFatalError( + f"{msg} - use `.copy(deep)` to create a new list", + self.currentNode) def visit(self, node): self.debug_msg(lambda: f'[Visit {type(node).__name__}]', node) self.indent_level += 1 parentNode = self.currentNode self.currentNode = node + numVals = 0 if isinstance( + node, ast.Module) else self.valueStack.currentNumValues + self.valueStack.pushFrame() super().visit(node) + self.valueStack.popFrame() + if isinstance(node, ast.Module): + if not self.valueStack.isEmpty: + self.emitFatalError( + "processing error - unprocessed frame(s) in value stack", + node) + elif self.valueStack.currentNumValues - numVals > 1: + # Do **NOT** change this to be more permissive and allow + # multiple values to be pushed without pushing proper + # frames for sub-nodes. If visiting a single node + # potentially produces more than one value, the bridge + # quickly will be a mess because we will easily end up + # with values in the wrong places. + self.emitFatalError( + "must not generate more one value at a time in each frame", + node) self.currentNode = parentNode self.indent_level -= 1 - # FIXME: using generic_visit the way we do seems incredibly dangerous; - # we use this and make assumptions about what values are on the value stack - # without any validation that we got the right values. - # The whole value stack needs to be revised; we need to properly push and pop - # not just individual values but groups of values to ensure that the right - # pieces get the right arguments (and give a proper error otherwise). 
- def generic_visit(self, node): - self.debug_msg(lambda: f'[Generic Visit]', node) - for field, value in reversed(list(ast.iter_fields(node))): - if isinstance(value, list): - for item in value: - if isinstance(item, ast.AST): - self.visit(item) - elif isinstance(value, ast.AST): - self.visit(value) - def visit_FunctionDef(self, node): """ Create an MLIR `func.FuncOp` for the given FunctionDef AST node. For the @@ -1201,6 +1530,7 @@ def visit_FunctionDef(self, node): We keep track of the top-level function name as well as its internal MLIR name, prefixed with the __nvqpp__mlirgen__ prefix. """ + if self.buildingEntryPoint: # This is an inner function def, we will # treat it as a cc.callable (cc.create_lambda) @@ -1241,23 +1571,18 @@ def visit_FunctionDef(self, node): (node.returns.value is None)): self.knownResultType = self.mlirTypeFromAnnotation(node.returns) - # Get the argument names - argNames = [arg.arg for arg in node.args.args] - # Add uniqueness. In MLIR, we require unique symbols (bijective # function between symbols and artifacts) even if Python allows # hiding symbols and replacing symbols (dynamic injective function # between scoped symbols and artifacts). - node_name = node.name + ".." + hex(self.uniqueId) - self.name = node_name + self.name = node.name + ".." 
+ hex(self.uniqueId) self.capturedDataStorage.name = self.name - # the full function name in MLIR is `__nvqpp__mlirgen__` + the - # function name + # the full function name in MLIR is `__nvqpp__mlirgen__` + the function name if self.disableNvqppPrefix: - fullName = node_name + fullName = self.name else: - fullName = nvqppPrefix + node_name + fullName = nvqppPrefix + self.name # Create the FuncOp f = func.FuncOp(fullName, (self.argTypes, [] if self.knownResultType @@ -1276,45 +1601,44 @@ def visit_FunctionDef(self, node): # Set the insertion point to the start of the entry block with InsertionPoint(self.entry): - self.buildingEntryPoint = True self.symbolTable.pushScope() - # Add the block arguments to the symbol table, create a stack - # slot for value arguments - blockArgs = self.entry.arguments - for i, b in enumerate(blockArgs): - if self.needsStackSlot(b.type): - stackSlot = cc.AllocaOp(cc.PointerType.get(b.type), - TypeAttr.get(b.type)).result - cc.StoreOp(b, stackSlot) - self.symbolTable[argNames[i]] = stackSlot + # Process function arguments like any other assignments. + if node.args.args: + assignNode = ast.Assign() + if len(node.args.args) == 1: + assignNode.targets = [ast.Name(node.args.args[0].arg)] + assignNode.value = self.entry.arguments[0] else: - self.symbolTable[argNames[i]] = b - - # Visit the function - startIdx = 0 - # Search for the potential documentation string, and - # if found, start the body visitation after it. 
- if len(node.body) and isinstance(node.body[0], ast.Expr): - self.debug_msg(lambda: f'[(Inline) Visit Expr]', - node.body[0]) - expr = node.body[0] - if hasattr(expr, 'value') and isinstance( - expr.value, ast.Constant): - self.debug_msg(lambda: f'[(Inline) Visit Constant]', - expr.value) - constant = expr.value - if isinstance(constant.value, str): - startIdx = 1 - [self.visit(n) for n in node.body[startIdx:]] + assignNode.targets = [ + ast.Tuple( + [ast.Name(arg.arg) for arg in node.args.args]) + ] + assignNode.value = [ + self.entry.arguments[idx] + for idx in range(len(self.entry.arguments.types)) + ] + assignNode.lineno = node.lineno + self.visit_Assign(assignNode) + + # Intentionally set after we process the argument assignment, + # since we currently treat value vs reference semantics slightly + # differently when we have arguments vs when we have local values. + # To not make this distinction, we would need to add support + # for having proper reference arguments, which we don't want to. + # Barring that, we at least try to be nice and give errors on + # assignments that may lead to unexpected behavior (i.e. behavior + # not following expected Python behavior). 
+ self.buildingEntryPoint = True + [self.visit(n) for n in node.body] # Add the return operation if not self.hasTerminator(self.entry): # If the function has a known (non-None) return type, emit # an `undef` of that type and return it; else return void if self.knownResultType is not None: undef = cc.UndefOp(self.knownResultType).result - ret = func.ReturnOp([undef]) + func.ReturnOp([undef]) else: - ret = func.ReturnOp([]) + func.ReturnOp([]) self.buildingEntryPoint = False self.symbolTable.popScope() @@ -1332,8 +1656,6 @@ def visit_FunctionDef(self, node): globalKernelRegistry[node.name] = f self.symbolTable.clear() - self.valueStack.clear() - self.knownResultType = parentResultType def visit_Expr(self, node): @@ -1351,6 +1673,13 @@ def visit_Expr(self, node): return self.visit(node.value) + if self.valueStack.currentNumValues > 0: + # An `ast.Expr` object is created when an expression + # is used as a statement. This expression may produce + # a value, which is ignored (not assigned) in the + # Python code. We hence need to pop that value to + # match that behavior and ignore it. + self.popValue() def visit_Lambda(self, node): """ @@ -1394,25 +1723,47 @@ def functor(qubits): def visit_Assign(self, node): """ Map an assign operation in the AST to an equivalent variable value - assignment in the MLIR. This method will first see if this is a tuple - assignment, enabling one to assign multiple values in a single - statement. + assignment in the MLIR. This method handles assignments, item updates, + as well as deconstruction. For all assignments, the variable name will be used as a key for the - symbol table, mapping to the corresponding MLIR Value. For values of - `ref` / `veq`, `i1`, or `cc.callable`, the values will be stored - directly in the table. For all other values, the variable will be - allocated with a `cc.alloca` op, and the loaded value will be stored in + symbol table, mapping to the corresponding MLIR Value. 
Quantum values, + measurement results, `cc.callable`, and `cc.stdvec` will be stored as + values in the symbol table. For all other values, the variable will be + allocated with a `cc.alloca` op, and the pointer will be stored in the symbol table. """ - def check_not_captured(name): - if name in self.liftedArgs: - self.emitFatalError( - f"CUDA-Q does not allow assignment to nonlocal variable " - f"{{name}}.", node) + # FIXME: Measurement results are stored as values + # to preserve their origin from discriminate. + # This should be revised when we introduce the proper + # type distinction. + def storedAsValue(val): + varTy = val.type + if cc.PointerType.isinstance(varTy): + varTy = cc.PointerType.getElementType(varTy) + # If `buildingEntryPoint` is not set we are processing function + # arguments. Function arguments are always passed by value, + # except states. We can treat non-container function arguments + # like any local variable and create a stack slot for them. + # For container types, on the other hand, we need to preserve + # them as values in the symbol table to make sure we can detect + # any access to reference types that are function arguments, or + # function argument items. + containerFuncArg = (not self.buildingEntryPoint and + (cc.StructType.isinstance(varTy) or + cc.StdvecType.isinstance(varTy))) + storeAsVal = (containerFuncArg or self.isQuantumType(varTy) or + cc.CallableType.isinstance(varTy) or + cc.StdvecType.isinstance(varTy) or + self.isMeasureResultType(varTy, val)) + # Nothing should ever produce a pointer + # to a type we store as value in the symbol table. 
+ assert (not storeAsVal or not cc.PointerType.isinstance(val.type)) + return storeAsVal def process_assignment(target, value): + if isinstance(target, ast.Tuple): if (isinstance(value, ast.Tuple) or @@ -1420,122 +1771,245 @@ def process_assignment(target, value): return target, value if isinstance(value, ast.AST): + # Measurements need to push their values to the stack, + # so we set a non-None variable name here. + self.currentAssignVariableName = '' + # NOTE: The way the assignment logic is processed, + # including that we load this value for the purpose + # of deconstruction, does not preserve any inner + # references. There are a bunch of issues that + # prevent us from properly dealing with any + # reference types stored as items in lists and + # dataclasses. We hence currently prevent the + # creation of such lists and dataclasses, and would + # need to change the representation for dataclasses + # to allow that. self.visit(value) - if len(self.valueStack) == 0: - self.emitFatalError("invalid assignment detected.", - node) - return target, self.popValue() + value = self.popValue() + self.currentAssignVariableName = None + return target, value return target, value - # Handle simple `var = expr` - elif isinstance(target, ast.Name): - check_not_captured(target.id) + # Make sure we process arbitrary combinations + # of subscript and attributes + target_root = target + while (isinstance(target_root, ast.Subscript) or + isinstance(target_root, ast.Attribute)): + target_root = target_root.value + if not isinstance(target_root, ast.Name): + self.emitFatalError("invalid target for assignment", node) + target_root_defined_in_parent_scope = ( + target_root.id in self.symbolTable and + target_root.id not in self.symbolTable.symbolTable[-1]) + value_root = self.__get_root_value(value) + + def update_in_parent_scope(destination, value): + assert not cc.PointerType.isinstance(value.type) + if cc.StructType.isinstance( + value.type) and cc.StructType.getName( + 
value.type) != 'tuple': + # We can't properly deal with this case if the value we are assigning + # is not an `rvalue`. Consider the case were we have `v1` defined in + # the parent scope, `v2` in a child scope, and we are assigning v2 to + # v1 in the child scope. To do this assignment properly, we would need to + # make sure that the pointers for both v1 and v2 points to the same memory + # location such that any changes to v1 after the assignment are reflected + # in v2 and vice versa (v2 could be changed in the child while v1 is still + # alive). Since we merely store the raw pointer in the symbol table for + # dataclasses, we have no way of updating that pointer conditionally on + # the child scope being executed. + # To determine whether the value we assign is an `rvalue`, it is + # sufficient to check whether its root is a value in the symbol table + # (values returned from calls are never `lvalues`). + if value_root: + # Note that this check also makes sure that function arguments are + # not assigned to local variables, since function arguments are in + # the symbol table. + self.emitFatalError( + "only literals can be assigned to variables defined in parent scope - use `.copy(deep)` to create a new value that can be assigned", + node) + if cc.StdvecType.isinstance(destination.type): + # In this case, we are assigning a list to a variable in a parent scope. + assert isinstance(target, ast.Name) + # If the value we are assigning is an `rvalue` then we can do an in-place + # update of the data in the parent; the restrictions for container items + # in `__validate_container_entry` ensure that the value we are assigning + # does not contain any references to dataclass values, and any lists + # contained in the value behave like proper references since they contain + # a data pointer (i.e. in-place update only does a shallow copy). 
+ # TODO: The only reason we cannot currently support this is because we + # have no way of updating the size of an existing vector... + self.emitFatalError( + "variable defined in parent scope cannot be modified", + node) + # Allowing to assign vectors to container items in the parent scope + # should be fine regardless of whether the assigned value is an `rvalue` + # or not; replacing the item in the container with the value leads to the + # correct behavior much like it does for the case where the target is defined + # in the same scope. + # NOTE: The assignment is subject to the usual restrictions for container + # items - these should be validated before calling update_in_parent_scope. + if not cc.StdvecType.isinstance( + value.type) and storedAsValue(destination): + # We can't properly deal with this, since there is no way to ensure that + # the target in the symbol table is updated conditionally on the child + # scope executing. + self.emitFatalError( + "variable defined in parent scope cannot be modified", + node) + assert cc.PointerType.isinstance(destination.type) + expectedTy = cc.PointerType.getElementType(destination.type) + value = self.changeOperandToType(expectedTy, + value, + allowDemotion=False) + cc.StoreOp(value, destination) + + # Handle assignment `var = expr` + if isinstance(target, ast.Name): + # This is so that we properly preserve the references + # to local variables. These variables can be of a reference + # type and other values in the symbol table may be assigned + # to the same reference. It is hence important to keep the + # reference as is, since otherwise changes to it would not + # be reflected in other values. + # NOTE: we don't need to worry about any references in + # values that are not `ast.Name` objects, since we don't + # allow containers to contain references. 
+ value_is_name = False + if (isinstance(value, ast.Name) and + value.id in self.symbolTable): + value_is_name = True + value = self.symbolTable[value.id] if isinstance(value, ast.AST): - # FIXME: this feature is going away. - # Retain the variable name for potential children (like - # `mz(q, registerName=...)`) + # Retain the variable name for potential children (like `mz(q, registerName=...)`) self.currentAssignVariableName = target.id self.visit(value) - self.currentAssignVariableName = None - if len(self.valueStack) == 0: - self.emitFatalError("invalid assignment detected.", - node) value = self.popValue() + self.currentAssignVariableName = None + storeAsVal = storedAsValue(value) + + if value_root and self.isFunctionArgument(value_root): + # If we assign a function argument or argument item to + # a local variable, we need to be careful to not lose + # the information about contained lists that have been + # allocated by the caller, if the return value contains + # any lists. This is problematic for reasons commented + # in `__validate_container_entry`. + if (cc.StdvecType.isinstance(value.type) and + self.knownResultType and + self.containsList(self.knownResultType)): + # We lose this information if we assign an item of + # a function argument. + if not value_is_name: + self.emitFatalError( + "lists passed as or contained in function arguments cannot be assigned to to a local variable when a list is returned - use `.copy(deep)` to create a new value that can be assigned", + node) + # We also lose this information if we assign to + # a value in the parent scope. 
+ elif target_root_defined_in_parent_scope: + self.emitFatalError( + "lists passed as or contained in function arguments cannot be assigned to variables in the parent scope when a list is returned - use `.copy(deep)` to create a new value that can be assigned", + node) + if cc.StructType.isinstance(value.type): + structName = cc.StructType.getName(value.type) + # For dataclasses, we have to do an additional check + # to ensure that their behavior (for cases that don't + # give an error) is consistent with Python; + # since we pass them by value across functions, we + # either have to force that an explicit copy is made + # when using them as call arguments, or we have to + # force that an explicit copy is made when a dataclass + # argument is assigned to a local variable (as long as + # it is not assigned, it will not be possible to make + # any modification to it since the argument itself is + # represented as an immutable value). The latter seems + # more comprehensive and also ensures that there is no + # unexpected behavior with regards to kernels not being + # able to modify dataclass values in host code. + # NOTE: It is sufficient to check the value itself (not + # its root) is a function argument, (only!) since inner + # items are never references to dataclasses (enforced + # in `__validate_container_entry`). 
+ if value_is_name and structName != 'tuple': + self.emitFatalError( + f"cannot assign dataclass passed as function argument to a local variable - use `.copy(deep)` to create a new value that can be assigned", + node) + elif (self.knownResultType and + self.containsList(self.knownResultType) and + self.containsList(value.type)): + self.emitFatalError( + f"cannot assign tuple or dataclass passed as function argument to a local variable if it contains a list when a list is returned - use `.copy(deep)` to create a new value that can be assigned", + node) - if self.isQuantumType(value.type) or cc.CallableType.isinstance( - value.type): - return target, value - elif self.isMeasureResultType(value.type, value): - value = self.ifPointerThenLoad(value) - if target.id in self.symbolTable: - addr = self.ifNotPointerThenStore( - self.symbolTable[target.id]) - cc.StoreOp(value, addr) - return target, value - elif target.id in self.symbolTable: - value = self.ifPointerThenLoad(value) - cc.StoreOp(value, self.symbolTable[target.id]) + if target_root_defined_in_parent_scope: + if cc.PointerType.isinstance(value.type): + # This is fine since/as long as update_in_parent_scope + # validates that `lvalues` of reference types cannot be + # assigned. Note that tuples and states are value types. + value = cc.LoadOp(value).result + destination = self.symbolTable[target.id] + update_in_parent_scope(destination, value) return target, None - elif cc.PointerType.isinstance(value.type): - return target, value - elif cc.StructType.isinstance(value.type) and isinstance( - value.owner.opview, cc.InsertValueOp): - # If we have a new struct from `cc.undef` and `cc.insert_value`, we don't - # want to allocate new memory. + + # The target variable has either not been defined + # or is defined within the current scope; + # we can simply modify the symbol table entry. 
+ if storeAsVal or cc.PointerType.isinstance(value.type): return target, value - else: - # We should allocate and store - alloca = cc.AllocaOp(cc.PointerType.get(value.type), - TypeAttr.get(value.type)).result - cc.StoreOp(value, alloca) - return target, alloca - - # Handle assignments like `listVar[IDX] = expr` - elif (isinstance(target, ast.Subscript) and - isinstance(target.value, ast.Name) and - target.value.id in self.symbolTable): - check_not_captured(target.value.id) - - # Visit_Subscript will try to load any pointer and return it - # but here we want the pointer, so flip that flag - self.subscriptPushPointerValue = True - # Visit the subscript node, get the pointer value - self.visit(target) - # Reset the push pointer value flag - self.subscriptPushPointerValue = False - ptrVal = self.popValue() - if not cc.PointerType.isinstance(ptrVal.type): - self.emitFatalError( - "Invalid CUDA-Q subscript assignment, variable must be a pointer.", - node) - # See if this is a pointer to an array, if so cast it - # to a pointer on the array type - ptrEleType = cc.PointerType.getElementType(ptrVal.type) - if cc.ArrayType.isinstance(ptrEleType): - ptrVal = cc.CastOp( - cc.PointerType.get( - cc.ArrayType.getElementType(ptrEleType)), - ptrVal).result - - # Visit the value being assigned - self.visit(node.value) - valueToStore = self.popValue() - # Cast if necessary - valueToStore = self.changeOperandToType(ptrEleType, - valueToStore) - # Store the value - cc.StoreOp(valueToStore, ptrVal) - return target.value, None - - # Handle assignments like `classVar.attr = expr` - elif (isinstance(target, ast.Attribute) and - isinstance(target.value, ast.Name) and - target.value.id in self.symbolTable): - check_not_captured(target.value.id) - - self.attributePushPointerValue = True - # Visit the attribute node, get the pointer value - self.visit(target) - # Reset the push pointer value flag - self.attributePushPointerValue = False - ptrVal = self.popValue() - if not 
cc.PointerType.isinstance(ptrVal.type): - self.emitFatalError("invalid CUDA-Q attribute assignment", - node) - # Visit the value being assigned - self.visit(node.value) - valueToStore = self.popValue() - # Cast if necessary - valueToStore = self.changeOperandToType( - cc.PointerType.getElementType(ptrVal.type), valueToStore) - # Store the value - cc.StoreOp(valueToStore, ptrVal) - return target.value, None - else: - self.emitFatalError("Invalid target for assignment", node) + address = cc.AllocaOp(cc.PointerType.get(value.type), + TypeAttr.get(value.type)).result + cc.StoreOp(value, address) + return target, address + + # Handle updates of existing variables + # (target is a combination of attribute and subscript) + self.pushPointerValue = True + self.visit(target) + destination = self.popValue() + self.pushPointerValue = False + + # We should have a pointer since we requested a pointer. + assert cc.PointerType.isinstance(destination.type) + expectedTy = cc.PointerType.getElementType(destination.type) + # We prevent the creation of lists and structs that + # contain pointers, and prevent obtaining pointers to + # quantum types. + assert not cc.PointerType.isinstance(expectedTy) + assert not self.isQuantumType(expectedTy) + + if not isinstance(value, ast.AST): + # Can arise if have something like `l[0], l[1] = getTuple()` + self.emitFatalError( + "updating lists or dataclasses as part of deconstruction is not supported", + node) + + # Measurements need to push their values to the stack, + # so we set a non-None variable name here. + self.currentAssignVariableName = '' + self.visit(value) + mlirVal = self.popValue() + self.currentAssignVariableName = None + assert not cc.PointerType.isinstance(mlirVal.type) + + # Must validate the container entry regardless of what scope the + # target is defined in. 
+ self.__validate_container_entry(mlirVal, value) + + if target_root_defined_in_parent_scope: + update_in_parent_scope(destination, mlirVal) + return target_root, None + + mlirVal = self.changeOperandToType(expectedTy, + mlirVal, + allowDemotion=False) + cc.StoreOp(mlirVal, destination) + # The returned target root has no effect here since no value + # is returned to push to the symbol table. We merely need to make + # sure that it is an `ast.Name` object to break the recursion. + return target_root, None if len(node.targets) > 1: # I am not entirely sure what kinds of Python language constructs @@ -1592,62 +2066,99 @@ def visit_Attribute(self, node): 'ZError', 'XError', 'YError', 'Pauli1', 'Pauli2', 'Depolarization1', 'Depolarization2' ]: - cudaq_module = importlib.import_module('cudaq') - channel_class = getattr(cudaq_module, node.attr) - self.pushValue( - self.getConstantInt(channel_class.num_parameters)) - self.pushValue(self.getConstantInt(hash(channel_class))) - return + self.emitFatalError( + "noise channels may only be used as part of call expressions", + node) - # Any other cudaq attributes should be handled by the parent + # must be handled by the parent return if node.attr == 'ctrl' or node.attr == 'adj': # to be processed by the caller return - def process_potential_ptr_types(value): - """ - Helper function to process anything that the parent may assign to, - depending on whether value is a pointer or not. - """ - valType = value.type - if cc.PointerType.isinstance(valType): - valType = cc.PointerType.getElementType(valType) - - if quake.StruqType.isinstance(valType): - # Need to extract value instead of load from compute pointer. 
- structIdx, memberTy = self.getStructMemberIdx( - node.attr, value.type) - attr = IntegerAttr.get(self.getIntegerType(32), structIdx) - self.pushValue(quake.GetMemberOp(memberTy, value, attr).result) - return True + if node.attr == 'copy': + if self.pushPointerValue: + self.emitFatalError( + "function call does not produce a modifiable value", node) + # needs to be handled by the caller + return - if cc.StructType.isinstance(valType): - # Handle the case where we have a struct member extraction, memory semantics - self.__visitStructAttribute(node, value) - return True + # Only variable names, subscripts and attributes can + # produce modifiable values. Anything else produces an + # immutable value. We make sure the visit gets processed + # such that the rest of the code can give a proper error. + value_root = node.value + while (isinstance(value_root, ast.Subscript) or + isinstance(value_root, ast.Attribute)): + value_root = value_root.value + if self.pushPointerValue and not isinstance(value_root, ast.Name): + self.pushPointerValue = False + self.visit(node.value) + value = self.popValue() + self.pushPointerValue = True + else: + self.visit(node.value) + value = self.popValue() - elif (quake.VeqType.isinstance(valType) or - cc.StdvecType.isinstance(valType) or - cc.ArrayType.isinstance(valType)): - return self.__isSupportedVectorFunction(node.attr) + valType = value.type + if cc.PointerType.isinstance(valType): + valType = cc.PointerType.getElementType(valType) - return False + if quake.StruqType.isinstance(valType): + if self.pushPointerValue: + self.emitFatalError( + "accessing attribute of quantum tuple or dataclass does not produce a modifiable value", + node) + # Need to extract value instead of load from compute pointer. 
+ structIdx, memberTy = self.getStructMemberIdx(node.attr, value.type) + attr = IntegerAttr.get(self.getIntegerType(32), structIdx) + self.pushValue(quake.GetMemberOp(memberTy, value, attr).result) + return - # Make sure we preserve pointers for structs - if isinstance(node.value, - ast.Name) and node.value.id in self.symbolTable: - value = self.symbolTable[node.value.id] - processed = process_potential_ptr_types(value) - if processed: + if (cc.PointerType.isinstance(value.type) and + cc.StructType.isinstance(valType)): + assert self.pushPointerValue + structIdx, memberTy = self.getStructMemberIdx(node.attr, valType) + eleAddr = cc.ComputePtrOp(cc.PointerType.get(memberTy), value, [], + DenseI32ArrayAttr.get([structIdx])).result + + if self.pushPointerValue: + self.pushValue(eleAddr) return - self.visit(node.value) - if len(self.valueStack) == 0: - self.emitFatalError("failed to create value to access attribute", - node) - value = self.ifPointerThenLoad(self.popValue()) + eleAddr = cc.LoadOp(eleAddr).result + self.pushValue(eleAddr) + return + + if cc.StructType.isinstance(value.type): + if self.pushPointerValue: + self.emitFatalError( + "value cannot be modified - use `.copy(deep)` to create a new value that can be modified", + node) + + # Handle direct struct value - use ExtractValueOp (more efficient) + structIdx, memberTy = self.getStructMemberIdx(node.attr, value.type) + extractedValue = cc.ExtractValueOp( + memberTy, value, [], DenseI32ArrayAttr.get([structIdx])).result + + self.pushValue(extractedValue) + return + + if (quake.VeqType.isinstance(valType) or + cc.StdvecType.isinstance(valType)): + if self.__isSupportedVectorFunction(node.attr): + if self.pushPointerValue: + self.emitFatalError( + "function call does not produce a modifiable value", + node) + # needs to be handled by the caller + return + + # everything else does not produce a modifiable value + if self.pushPointerValue: + self.emitFatalError( + "attribute expression does not produce a 
modifiable value") if ComplexType.isinstance(value.type): if (node.attr == 'real'): @@ -1663,16 +2174,13 @@ def process_potential_ptr_types(value): if quake.VeqType.isinstance(value.type): self.pushValue( quake.VeqSizeOp(self.getIntegerType(), value).result) - return True - if cc.StdvecType.isinstance(value.type) or cc.ArrayType.isinstance( - value.type): - self.pushValue(self.__get_vector_size(value)) - return True + return + if cc.StdvecType.isinstance(value.type): + self.pushValue( + cc.StdvecSizeOp(self.getIntegerType(), value).result) + return - processed = process_potential_ptr_types(value) - if not processed: - self.emitFatalError("unrecognized attribute {}".format(node.attr), - node) + self.emitFatalError("unrecognized attribute {}".format(node.attr), node) def find_unique_decorator_name(self, name): mod = sys.modules[self.kernelModuleName] @@ -1686,39 +2194,13 @@ def find_unique_decorator_name(self, name): def visit_Call(self, node): """ - Map a Python Call operation to equivalent MLIR. This method will first - check for call operations that are `ast.Name` nodes in the tree (the - name of a function to call). It will handle the Python `range(start, - stop, step)` function by creating an array of integers to loop through - via an invariant CC loop operation. Subsequent users of the `range()` - result can iterate through the elements of the returned `cc.array`. It - will handle the Python `enumerate(iterable)` function by constructing - another invariant loop that builds up and array of `cc.struct`, - representing the counter and the element. - - It will next handle any quantum operation (optionally with a rotation - parameter). Single target operations can be represented that take a - single qubit reference, multiple single qubits, or a vector of qubits, - where the latter two will apply the operation to every qubit in the - vector: - - Valid single qubit operations are `h`, `x`, `y`, `z`, `s`, `t`, `rx`, - `ry`, `rz`, `r1`. 
- - Measurements `mx`, `my`, `mz` are mapped to corresponding quake - operations and the return i1 value is added to the value - stack. Measurements of single qubit reference and registers of qubits - are supported. - - General calls to previously seen CUDA-Q kernels are supported. By this - we mean that an kernel can not be invoked from a kernel unless it was - defined before the current kernel. Kernels can also be reversed or - controlled with `cudaq.adjoint(kernel, ...)` and `cudaq.control(kernel, - ...)`. - - Finally, general operation modifiers are supported, specifically - `OPERATION.adj` and `OPERATION.ctrl` for adjoint and control synthesis - of the operation. + Map a Python Call operation to equivalent MLIR. This method handles + functions that are `ast.Name` and `ast.Attribute` objects. + + This function handles all built-in unitary and measurement gates + as well as all the ways to adjoint and control them. + General calls to previously seen CUDA-Q kernels or registered + operations are supported. 
```python q, r = cudaq.qubit(), cudaq.qubit() @@ -1732,22 +2214,34 @@ def visit_Call(self, node): """ global globalRegisteredOperations + def copy_list_to_stack(value): + symName = '__nvqpp_vectorCopyToStack' + load_intrinsic(self.module, symName) + elemTy = cc.StdvecType.getElementType(value.type) + if elemTy == self.getIntegerType(1): + elemTy = self.getIntegerType(8) + ptrTy = cc.PointerType.get(self.getIntegerType(8)) + resBuf = cc.StdvecDataOp(cc.PointerType.get(elemTy), value).result + eleSize = cc.SizeOfOp(self.getIntegerType(), + TypeAttr.get(elemTy)).result + dynSize = cc.StdvecSizeOp(self.getIntegerType(), value).result + stackCopy = cc.AllocaOp(cc.PointerType.get( + cc.ArrayType.get(elemTy)), + TypeAttr.get(elemTy), + seqSize=dynSize).result + func.CallOp([], symName, [ + cc.CastOp(ptrTy, stackCopy).result, + cc.CastOp(ptrTy, resBuf).result, + arith.MulIOp(dynSize, eleSize).result + ]) + return cc.StdvecInitOp(value.type, stackCopy, length=dynSize).result + def convertArguments(expectedArgTypes, values): - fName = 'function' - if hasattr(node.func, 'id'): - fName = node.func.id - elif hasattr(node.func, 'attr'): - fName = node.func.attr - if len(expectedArgTypes) != len(values): - self.emitFatalError( - f"invalid number of arguments passed in call to {fName} ({len(values)} vs required {len(expectedArgTypes)})", - node) + assert len(expectedArgTypes) == len(values) args = [] - for idx, value in enumerate(values): - arg = self.ifPointerThenLoad(value) - expectedTy = expectedArgTypes[idx] + for idx, expectedTy in enumerate(expectedArgTypes): arg = self.changeOperandToType(expectedTy, - arg, + values[idx], allowDemotion=True) args.append(arg) return args @@ -1762,129 +2256,20 @@ def getNegatedControlQubits(controls): self.controlNegations.clear() return negatedControlQubits - def processControlOrAdjoint(attrName): - # NOTE: CUDA-Q does not return a new function with these operations. 
- # Instead they are defined to immediately call an autogenerated - # variant of the callable (first argument). - if not node.args: - self.emitFatalError(attrName, "requires at least 1 argument", - node) - astName = node.args[0] - if not isinstance(astName, ast.Name): - self.emitFatalError( - f"unsupported argument in call to {attrName} - first " - f"argument must be a symbol name", node) - otherFuncName = astName.id - values = [self.popValue() for _ in range(len(self.valueStack))] - values.reverse() - kwargs = {"is_adj": attrName == 'adjoint'} - - if otherFuncName in self.symbolTable: - indirectCallee = [self.symbolTable[otherFuncName]] - values = values[1:] - else: - # First time seeing this symbol. Lambda lift it. It must be a - # callable. - decorator = recover_kernel_decorator(otherFuncName) - if not decorator: - self.emitFatalError( - "unprocessed kernel reference not yet supported", node) - self.appendToLiftedArgs(otherFuncName) - entryPoint = recover_func_op(decorator.qkeModule, - nvqppPrefix + decorator.uniqName) - funcTy = FunctionType( - TypeAttr(entryPoint.attributes['function_type']).value) - if decorator.firstLiftedPos: - moduloInTys = funcTy.inputs[:decorator.firstLiftedPos] - else: - moduloInTys = funcTy.inputs - callableTy = cc.CallableType.get(self.ctx, moduloInTys, - funcTy.results) - # indirectCallee[0] will be a new BlockArgument - indirectCallee = [ - cudaq_runtime.appendKernelArgument(self.kernelFuncOp, - callableTy) - ] - self.argTypes.append(callableTy) - self.symbolTable.add(otherFuncName, indirectCallee[0]) - - if not cc.CallableType.isinstance(indirectCallee[0].type): - self.emitFatalError(f"{otherFuncName} must be a callable", node) - functionTy = FunctionType( - cc.CallableType.getFunctionType(indirectCallee[0].type)) - inputTys, outputTys = functionTy.inputs, functionTy.results - numControlArgs = 1 if attrName == 'control' else 0 - - if len(values) < numControlArgs: - self.emitFatalError( - "missing control qubit(s) argument in 
cudaq.control", node) - controls = values[:numControlArgs] - invert_controls = lambda: None - if len(controls) != 0: - assert (len(controls) == 1) - if numControlArgs == 1: - if (not quake.RefType.isinstance(controls[0].type) and - not quake.VeqType.isinstance(controls[0].type)): - self.emitFatalError( - 'invalid argument type for control operand', node) - # TODO: it would be cleaner to add support for negated control - # qubits to `quake.ApplyOp` - if controls[0] in self.controlNegations: - invert_controls = lambda: self.__applyQuantumOperation( - 'x', [], controls) - self.controlNegations.clear() - args = convertArguments(inputTys, values[numControlArgs:]) - if len(outputTys) != 0: - self.emitFatalError( - f'cannot take {attrName} of kernel {otherFuncName} that ' - f'returns a value', node) - invert_controls() - quake.ApplyOp([], indirectCallee, controls, args, **kwargs) - invert_controls() - - def processFunctionCall(fType, nrValsToPop): - if len(fType.inputs) != nrValsToPop: - fName = 'function' - if hasattr(node.func, 'id'): - fName = node.func.id - elif hasattr(node.func, 'attr'): - fName = node.func.attr - self.emitFatalError( - f"invalid number of arguments passed in call to {fName} " - f"({nrValsToPop} vs required {len(fType.inputs)})", node) - values = [self.popValue() for _ in node.args] - values.reverse() - values = convertArguments([t for t in fType.inputs], values) - if len(fType.results) == 0: - func.CallOp(otherKernel, values) - else: - result = func.CallOp(otherKernel, values).result - # Copy to stack if necessary - if cc.StdvecType.isinstance(result.type): - elemTy = cc.StdvecType.getElementType(result.type) - if elemTy == self.getIntegerType(1): - elemTy = self.getIntegerType(8) - data = cc.StdvecDataOp(cc.PointerType.get(elemTy), - result).result - i64Ty = self.getIntegerType(64) - length = cc.StdvecSizeOp(i64Ty, result).result - elemSize = cc.SizeOfOp(i64Ty, TypeAttr.get(elemTy)).result - buffer = cc.AllocaOp(cc.PointerType.get( - 
cc.ArrayType.get(elemTy)), - TypeAttr.get(elemTy), - seqSize=length).result - i8PtrTy = cc.PointerType.get(self.getIntegerType(8)) - cbuffer = cc.CastOp(i8PtrTy, buffer).result - cdata = cc.CastOp(i8PtrTy, data).result - symName = '__nvqpp_vectorCopyToStack' - load_intrinsic(self.module, symName) - sizeInBytes = arith.MulIOp(length, elemSize).result - func.CallOp([], symName, [cbuffer, cdata, sizeInBytes]) - # Replace result with the stack buffer-backed vector - result = cc.StdvecInitOp(result.type, buffer, - length=length).result - - self.pushValue(result) + def processFunctionCall(kernel): + nrArgs = len(kernel.type.inputs) + values = self.__groupValues(node.args, [(nrArgs, nrArgs)]) + values = convertArguments([t for t in kernel.type.inputs], values) + if len(kernel.type.results) == 0: + func.CallOp(kernel, values) + return + # The logic for calls that return values must + # match the logic in `visit_Return`; anything + # copied to the heap during return must be copied + # back to the stack. Compiler optimizations should + # take care of eliminating unnecessary copies. 
+ result = func.CallOp(kernel, values).result + return self.__migrateLists(result, copy_list_to_stack) def checkControlAndTargetTypes(controls, targets): """ @@ -1909,190 +2294,238 @@ def is_qvec_or_qubits(vals): self.emitFatalError(f'invalid argument type for target operand', node) - def processDecoratorCall(decorator, name): - if name in self.symbolTable: - callee = self.symbolTable[name] - assert (cc.CallableType.isinstance(callee.type)) - funcTy = FunctionType( - cc.CallableType.getFunctionType(callee.type)) + def processQuantumOperation(opName, + controls, + targets, + *args, + broadcast=lambda q: [q], + **kwargs): + opCtor = getattr(quake, f'{opName}Op') + checkControlAndTargetTypes(controls, targets) + if not broadcast: + return opCtor(*args, controls, targets, **kwargs) + elif quake.VeqType.isinstance(targets[0].type): + assert len(targets) == 1 + + def bodyBuilder(iterVal): + q = quake.ExtractRefOp(self.getRefType(), + targets[0], + -1, + index=iterVal).result + opCtor(*args, controls, broadcast(q), **kwargs) + + veqSize = quake.VeqSizeOp(self.getIntegerType(), + targets[0]).result + self.createInvariantForLoop(bodyBuilder, veqSize) + else: + for target in targets: + opCtor(*args, controls, broadcast(target), **kwargs) + + def processQuakeCtor(opName, + pyArgs, + isCtrl, + isAdj, + numParams=0, + numTargets=1): + kwargs = {} + if isCtrl: + argGroups = [(numParams, numParams), (1, -1), + (numTargets, numTargets)] + # FIXME: we could allow this as long as we have 1 target + kwargs['broadcast'] = False + elif numTargets == 1: + # when we have a single target and no controls, we generally + # support any version of `x(qubit)`, `x(qvector)`, `x(q, r)` + argGroups = [(numParams, numParams), 0, (1, -1)] + else: + argGroups = [(numParams, numParams), 0, + (numTargets, numTargets)] + kwargs['broadcast'] = False + + params, controls, targets = self.__groupValues(pyArgs, argGroups) + if isCtrl: + negatedControlQubits = getNegatedControlQubits(controls) + 
kwargs['negated_qubit_controls'] = negatedControlQubits + if isAdj: + kwargs['is_adj'] = True + params = [ + self.changeOperandToType(self.getFloatType(), param) + for param in params + ] + processQuantumOperation(opName, controls, targets, [], params, + **kwargs) + + def processDecorator(name, path=None): + if path: + name = f"{path}.{name}" + decorator = resolve_qualified_symbol(name) else: + decorator = recover_kernel_decorator(name) + + if decorator and not name in self.symbolTable: self.appendToLiftedArgs(name) entryPoint = recover_func_op(decorator.qkeModule, nvqppPrefix + decorator.uniqName) funcTy = FunctionType( TypeAttr(entryPoint.attributes['function_type']).value) - if decorator.firstLiftedPos: - moduloInTys = funcTy.inputs[:decorator.firstLiftedPos] - else: - moduloInTys = funcTy.inputs - callableTy = cc.CallableType.get(self.ctx, moduloInTys, - funcTy.results) + callableTy = cc.CallableType.get( + self.ctx, funcTy.inputs[:decorator.firstLiftedPos], + funcTy.results) + # callee will be a new BlockArgument callee = cudaq_runtime.appendKernelArgument( self.kernelFuncOp, callableTy) self.argTypes.append(callableTy) - self.symbolTable.add(name, callee) - return callee, funcTy + self.symbolTable.add(name, callee, 0) + + return name if decorator else None + + # FIXME: unify with processFunctionCall? 
+ def processDecoratorCall(symName): + assert symName in self.symbolTable + self.visit(ast.Name(symName)) + kernel = self.popValue() + if not cc.CallableType.isinstance(kernel.type): + self.emitFatalError( + f"`{symName}` object is not callable, found symbol of type {kernel.type}", + node) + functionTy = FunctionType( + cc.CallableType.getFunctionType(kernel.type)) + nrArgs = len(functionTy.inputs) + values = self.__groupValues(node.args, [(nrArgs, nrArgs)]) + values = convertArguments([t for t in functionTy.inputs], values) + call = cc.CallCallableOp(functionTy.results, kernel, values) + call.attributes.__setitem__('symbol', StringAttr.get(symName)) + + if len(functionTy.results) == 0: + return + if len(functionTy.results) == 1: + result = call.results[0] + else: + # FIXME: SPLIT OUT INTO HELPER FUNCTION + for res in call.results: + self.__validate_container_entry(res, node) + structTy = mlirTryCreateStructType(functionTy.results, + name='tuple', + context=self.ctx) + if structTy is None: + self.emitFatalError( + "Hybrid quantum-classical data types and nested quantum structs are not allowed.", + node) + if quake.StruqType.isinstance(structTy): + result = quake.MakeStruqOp(structTy, call.results).result + else: + result = cc.UndefOp(structTy) + for idx, element in enumerate(call.results): + result = cc.InsertValueOp( + structTy, result, element, + DenseI64ArrayAttr.get([idx], + context=self.ctx)).result + # The logic for calls that return values must + # match the logic in `visit_Return`; anything + # copied to the heap during return must be copied + # back to the stack. Compiler optimizations should + # take care of eliminating unnecessary copies. 
+ return self.__migrateLists(result, copy_list_to_stack) # do not walk the FunctionDef decorator_list arguments if isinstance(node.func, ast.Attribute): self.debug_msg(lambda: f'[(Inline) Visit Attribute]', node.func) - if (hasattr(node.func.value, 'id') and - node.func.value.id == 'cudaq' and - node.func.attr == 'kernel'): + if hasattr( + node.func.value, 'id' + ) and node.func.value.id == 'cudaq' and node.func.attr == 'kernel': return # If we have a `func = ast.Attribute``, then it could be that we # have a previously defined kernel function call with manually # specified module names. - # e.g. `cudaq.lib.test.hello.fermionic_swap``. In this case, we - # assume FindDepKernels has found something like this, loaded it, - # and now we just want to get the function name and call it. - - # First let's check for registered C++ kernels - cppDevModNames = [] + moduleNames = [] value = node.func.value - if isinstance(value, ast.Name) and value.id != 'cudaq': - self.debug_msg(lambda: f'[(Inline) Visit Name]', value) - cppDevModNames = [node.func.attr, value.id] - else: + while isinstance(value, ast.Attribute): self.debug_msg(lambda: f'[(Inline) Visit Attribute]', value) - while isinstance(value, ast.Attribute): - cppDevModNames.append(value.attr) - value = value.value - if isinstance(value, ast.Name): - self.debug_msg(lambda: f'[(Inline) Visit Name]', value) - cppDevModNames.append(value.id) - break - - devKey = '.'.join(cppDevModNames[::-1]) - - def get_full_module_path(partial_path): - parts = partial_path.split('.') + moduleNames.append(value.attr) + value = value.value + if isinstance(value, ast.Name): + self.debug_msg(lambda: f'[(Inline) Visit Name]', value) + moduleNames.append(value.id) + moduleNames.reverse() + + devKey = '.'.join(moduleNames) for module_name, module in sys.modules.items(): - if module_name.endswith(parts[0]): + if module_name.split('.')[-1] == moduleNames[0]: try: obj = module - for part in parts[1:]: + for part in moduleNames[1:]: obj = 
getattr(obj, part) - return f"{module_name}.{'.'.join(parts[1:])}" + devKey = f"{module_name}.{'.'.join(moduleNames[1:])}" except AttributeError: continue - return partial_path - devKey = get_full_module_path(devKey) - if cudaq_runtime.isRegisteredDeviceModule(devKey): - maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel( - self.module, devKey + '.' + node.func.attr) - if maybeKernelName == None: + # Handle registered C++ kernels + if cudaq_runtime.isRegisteredDeviceModule(devKey): maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel( - self.module, devKey) - if maybeKernelName != None: - otherKernel = SymbolTable( - self.module.operation)[maybeKernelName] + self.module, devKey + '.' + node.func.attr) + if maybeKernelName == None: + maybeKernelName = cudaq_runtime.checkRegisteredCppDeviceKernel( + self.module, devKey) + if maybeKernelName != None: + otherKernel = SymbolTable( + self.module.operation)[maybeKernelName] + res = processFunctionCall(otherKernel) + if res is not None: + self.pushValue(res) + return - [self.visit(arg) for arg in node.args] - processFunctionCall(otherKernel.type, len(node.args)) + # Handle debug functions + if devKey == 'cudaq.dbg.ast': + # Handle a debug print statement + arg = self.__groupValues(node.args, [1]) + self.__insertDbgStmt(arg, node.func.attr) return - # Start by seeing if we have mod1.mod2.mod3... 
- moduleNames = [] - value = node.func.value - while isinstance(value, ast.Attribute): - self.debug_msg(lambda: f'[(Inline) Visit Attribute]', value) - moduleNames.append(value.attr) - value = value.value - if isinstance(value, ast.Name): - self.debug_msg(lambda: f'[(Inline) Visit Name]', value) - moduleNames.append(value.id) - break - - if all(x in moduleNames for x in ['cudaq', 'dbg', 'ast']): - # FIXME: the above allows random permutations of these words - # Handle a debug print statement - [self.visit(arg) for arg in node.args] - if len(self.valueStack) != 1: - self.emitFatalError( - f"cudaq.dbg.ast.{node.func.attr} call invalid - " - f"too many arguments passed.", node) + # Handle kernels defined in other modules + symName = processDecorator(node.func.attr, path=devKey) + if symName: + node.func = ast.Name(symName) - self.__insertDbgStmt(self.popValue(), node.func.attr) + if isinstance(node.func, ast.Name): + symName = (node.func.id if node.func.id in self.symbolTable else + processDecorator(node.func.id)) + if symName: + result = processDecoratorCall(symName) + if result: + self.pushValue(result) return - # If we did have module names, then this is what we are looking for - if len(moduleNames): - name = node.func.attr - moduleNames.reverse() - if not moduleNames[0] in self.symbolTable: - decorator = recover_kernel_decorator(name) - if decorator: - callee, fType = processDecoratorCall(decorator, name) - if len(fType.inputs) != len(node.args): - funcName = node.func.id if hasattr( - node.func, 'id') else node.func.attr - self.emitFatalError( - f"invalid number of arguments passed to " - f"callable {funcName} ({len(node.args)} vs " - f"required {len(fType.inputs)})", node) - [self.visit(arg) for arg in node.args] - values = [self.popValue() for _ in node.args] - values.reverse() - values = [self.ifPointerThenLoad(v) for v in values] - call = cc.CallCallableOp(fType.results, callee, - values) - sa = StringAttr.get(name) - call.attributes.__setitem__('symbol', sa) 
- for r in call.results: - self.pushValue(r) - return + if node.func.id == 'complex': - # FIXME: This whole thing is widely inconsistent; - # For example; we pop all values on the value stack for a simple gate - # and allow x(q1, q2, q3, ...) here, but for a simple adjoint gate we - # only ever pop a single value. Then there are the control qubits, - # where we also allow to pass individual qubits instead of a vector. - # I'll tackle this as part of revising the value stack. - # FIXME: Expand the tests in test_control_negations as needed after - # revising this. - if isinstance(node.func, ast.Name): - # Just visit the arguments, we know the name - [self.visit(arg) for arg in node.args] - - namedArgs = {} - for keyword in node.keywords: - self.visit(keyword.value) - namedArgs[keyword.arg] = self.popValue() - - self.debug_msg(lambda: f'[(Inline) Visit Name]', node.func) - - decorator = recover_kernel_decorator(node.func.id) - if decorator: - # This is a call to a device kernel. - callee, funcTy = processDecoratorCall(decorator, node.func.id) - values = [self.popValue() for _ in node.args] - values.reverse() - values = [self.ifPointerThenLoad(v) for v in values] - call = cc.CallCallableOp(funcTy.results, callee, values) - sa = StringAttr.get(node.func.id) - call.attributes.__setitem__('symbol', sa) - for r in call.results: - self.pushValue(r) - return + keywords = [kw.arg for kw in node.keywords] + kwreal = 'real' in keywords + kwimag = 'imag' in keywords + real, imag = self.__groupValues(node.args, + [not kwreal, not kwimag]) + for keyword in node.keywords: + self.visit(keyword.value) + kwval = self.popValue() + if keyword.arg == 'real': + real = kwval + elif keyword.arg == 'imag': + imag = kwval + else: + self.emitFatalError(f"unknown keyword `{keyword.arg}`", + node) + if not real or not imag: + self.emitFatalError("missing value", node) - if node.func.id in self.symbolTable: - callee = self.symbolTable[node.func.id] - values = [self.popValue() for _ in node.args] 
- values.reverse() - values = [self.ifPointerThenLoad(v) for v in values] - cfTy = cc.CallableType.getFunctionType(callee.type) - resTys = FunctionType(cfTy).results - call = cc.CallCallableOp(resTys, callee, values) - for r in call.results: - self.pushValue(r) + imag = self.changeOperandToType(self.getFloatType(), imag) + real = self.changeOperandToType(self.getFloatType(), real) + self.pushValue( + complex.CreateOp(self.getComplexType(), real, imag).result) return if node.func.id == 'len': - listVal = self.ifPointerThenLoad(self.popValue()) + listVal = self.__groupValues(node.args, [1]) + if cc.StdvecType.isinstance(listVal.type): self.pushValue( cc.StdvecSizeOp(self.getIntegerType(), listVal).result) @@ -2113,17 +2546,15 @@ def get_full_module_path(partial_path): zero = arith.ConstantOp(iTy, IntegerAttr.get(iTy, 0)) one = arith.ConstantOp(iTy, IntegerAttr.get(iTy, 1)) - # The total number of elements in the iterable - # we are generating should be `N == endVal - startVal` - actualSize = arith.SubIOp(endVal, startVal).result - totalSize = math.AbsIOp(actualSize).result - - # If the step is not == 1, then we also have - # to update the total size for the range iterable - actualSize = arith.DivSIOp(actualSize, - math.AbsIOp(stepVal).result).result - totalSize = arith.DivSIOp(totalSize, - math.AbsIOp(stepVal).result).result + totalSize = arith.SubIOp(endVal, startVal).result + if isDecrementing: + roundingOffset = arith.AddIOp(stepVal, one) + else: + roundingOffset = arith.SubIOp(stepVal, one) + totalSize = arith.AddIOp(totalSize, roundingOffset) + totalSize = arith.MaxSIOp( + zero, + arith.DivSIOp(totalSize, stepVal).result).result # Create an array of i64 of the total size arrTy = cc.ArrayType.get(iTy) @@ -2142,91 +2573,67 @@ def get_full_module_path(partial_path): def bodyBuilder(iterVar): loadedCounter = cc.LoadOp(counter).result - tmp = arith.MulIOp(loadedCounter, stepVal).result - arrElementVal = arith.AddIOp(startVal, tmp).result eleAddr = cc.ComputePtrOp( 
cc.PointerType.get(iTy), iterable, [loadedCounter], DenseI32ArrayAttr.get([kDynamicPtrIndex], context=self.ctx)) - cc.StoreOp(arrElementVal, eleAddr) + cc.StoreOp(iterVar, eleAddr) incrementedCounter = arith.AddIOp(loadedCounter, one).result cc.StoreOp(incrementedCounter, counter) - self.createInvariantForLoop(endVal, - bodyBuilder, + self.createMonotonicForLoop(bodyBuilder, startVal=startVal, stepVal=stepVal, + endVal=endVal, isDecrementing=isDecrementing) - self.pushValue(iterable) - self.pushValue(actualSize) + vect = cc.StdvecInitOp(cc.StdvecType.get(iTy), + iterable, + length=totalSize).result + self.pushValue(vect) return if node.func.id == 'enumerate': # We have to have something "iterable" on the stack, # could be coming from `range()` or an iterable like `qvector` - totalSize = None - iterable = None - iterEleTy = None - extractFunctor = None - if len(self.valueStack) == 1: - # `qreg`-like or `stdvec`-like thing thing - iterable = self.ifPointerThenLoad(self.popValue()) - # Create a new iterable, `alloca cc.struct` - totalSize = None - if quake.VeqType.isinstance(iterable.type): - iterEleTy = self.getRefType() - totalSize = quake.VeqSizeOp(self.getIntegerType(), - iterable).result - - def extractFunctor(idxVal): - return quake.ExtractRefOp(iterEleTy, - iterable, - -1, - index=idxVal).result - elif cc.StdvecType.isinstance(iterable.type): - iterEleTy = cc.StdvecType.getElementType(iterable.type) - totalSize = cc.StdvecSizeOp(self.getIntegerType(), - iterable).result - - def extractFunctor(idxVal): - arrEleTy = cc.ArrayType.get(iterEleTy) - elePtrTy = cc.PointerType.get(iterEleTy) - arrPtrTy = cc.PointerType.get(arrEleTy) - vecPtr = cc.StdvecDataOp(arrPtrTy, iterable).result - eleAddr = cc.ComputePtrOp( - elePtrTy, vecPtr, [idxVal], - DenseI32ArrayAttr.get([kDynamicPtrIndex], - context=self.ctx)).result - return cc.LoadOp(eleAddr).result - else: - self.emitFatalError( - "could not infer enumerate tuple type ({})".format( - iterable.type), node) - else: - if 
len(self.valueStack) != 2: - msg = 'Error in AST processing, should have 2 values on the stack for enumerate' - self.emitFatalError(msg, node) + iterable = self.__groupValues(node.args, [1]) - totalSize = self.popValue() - iterable = self.popValue() - arrTy = cc.PointerType.getElementType(iterable.type) - iterEleTy = cc.ArrayType.getElementType(arrTy) + # Create a new iterable, `alloca cc.struct` + if quake.VeqType.isinstance(iterable.type): + iterEleTy = self.getRefType() + totalSize = quake.VeqSizeOp(self.getIntegerType(), + iterable).result + + def extractFunctor(idxVal): + return quake.ExtractRefOp(iterEleTy, + iterable, + -1, + index=idxVal).result + elif cc.StdvecType.isinstance(iterable.type): + iterEleTy = cc.StdvecType.getElementType(iterable.type) + totalSize = cc.StdvecSizeOp(self.getIntegerType(), + iterable).result - def localFunc(idxVal): + def extractFunctor(idxVal): + arrEleTy = cc.ArrayType.get(iterEleTy) + elePtrTy = cc.PointerType.get(iterEleTy) + arrPtrTy = cc.PointerType.get(arrEleTy) + vecPtr = cc.StdvecDataOp(arrPtrTy, iterable).result eleAddr = cc.ComputePtrOp( - cc.PointerType.get(iterEleTy), iterable, [idxVal], + elePtrTy, vecPtr, [idxVal], DenseI32ArrayAttr.get([kDynamicPtrIndex], context=self.ctx)).result return cc.LoadOp(eleAddr).result - - extractFunctor = localFunc + else: + self.emitFatalError( + "could not infer enumerate tuple type ({})".format( + iterable.type), node) # Enumerate returns a iterable of tuple(i64, T) for type T # Allocate an array of struct == tuple (for us) structTy = cc.StructType.get([self.getIntegerType(), iterEleTy]) - arrTy = cc.ArrayType.get(structTy) - enumIterable = cc.AllocaOp(cc.PointerType.get(arrTy), + enumIterable = cc.AllocaOp(cc.PointerType.get( + cc.ArrayType.get(structTy)), TypeAttr.get(structTy), seqSize=totalSize).result @@ -2251,123 +2658,48 @@ def bodyBuilder(iterVar): DenseI64ArrayAttr.get([1], context=self.ctx)).result cc.StoreOp(element, eleAddr) - self.createInvariantForLoop(totalSize, 
bodyBuilder) - self.pushValue(enumIterable) - self.pushValue(totalSize) + self.createInvariantForLoop(bodyBuilder, totalSize) + vect = cc.StdvecInitOp(cc.StdvecType.get(structTy), + enumIterable, + length=totalSize).result + self.pushValue(vect) return - if node.func.id == 'complex': - if len(namedArgs) == 0: - imag = self.popValue() - real = self.popValue() - else: - imag = namedArgs['imag'] - real = namedArgs['real'] - imag = self.changeOperandToType(self.getFloatType(), imag) - real = self.changeOperandToType(self.getFloatType(), real) - self.pushValue( - complex.CreateOp(self.getComplexType(), real, imag).result) + if self.__isSimpleGate(node.func.id): + processQuakeCtor(node.func.id.title(), + node.args, + isCtrl=False, + isAdj=False) return - if self.__isSimpleGate(node.func.id): - # Here we enable application of the op on all the provided - # arguments, e.g. `x(qubit)`, `x(qvector)`, `x(q, r)`, etc. - numValues = len(self.valueStack) - qubitTargets = [self.popValue() for _ in range(numValues)] - qubitTargets.reverse() - checkControlAndTargetTypes([], qubitTargets) - self.__applyQuantumOperation(node.func.id, [], qubitTargets) + if self.__isAdjointSimpleGate(node.func.id): + processQuakeCtor(node.func.id[0].title(), + node.args, + isCtrl=False, + isAdj=True) return if self.__isControlledSimpleGate(node.func.id): - # These are single target controlled quantum operations - MAX_ARGS = 2 - numValues = len(self.valueStack) - if numValues != MAX_ARGS: - raise RuntimeError( - "invalid number of arguments passed to callable {} " - "({} vs required {})".format(node.func.id, - len(node.args), MAX_ARGS)) - target = self.popValue() - control = self.popValue() - negatedControlQubits = getNegatedControlQubits([control]) - checkControlAndTargetTypes([control], [target]) - # Map `cx` to `XOp`... 
- opCtor = getattr( - quake, '{}Op'.format(node.func.id.title()[1:].upper())) - opCtor([], [], [control], [target], - negated_qubit_controls=negatedControlQubits) + processQuakeCtor(node.func.id[1:].title(), + node.args, + isCtrl=True, + isAdj=False) return if self.__isRotationGate(node.func.id): - numValues = len(self.valueStack) - if numValues < 2: - self.emitFatalError( - f'invalid number of arguments ({numValues}) passed to {node.func.id} (requires at least 2 arguments)', - node) - qubitTargets = [self.popValue() for _ in range(numValues - 1)] - qubitTargets.reverse() - param = self.popValue() - if IntegerType.isinstance(param.type): - param = arith.SIToFPOp(self.getFloatType(), param).result - elif not F64Type.isinstance(param.type): - self.emitFatalError( - 'rotational parameter must be a float, or int.', node) - checkControlAndTargetTypes([], qubitTargets) - self.__applyQuantumOperation(node.func.id, [param], - qubitTargets) + processQuakeCtor(node.func.id.title(), + node.args, + isCtrl=False, + isAdj=False, + numParams=1) return if self.__isControlledRotationGate(node.func.id): - ## These are single target, one parameter, controlled quantum operations - MAX_ARGS = 3 - numValues = len(self.valueStack) - if numValues != MAX_ARGS: - raise RuntimeError( - "invalid number of arguments passed to callable {} ({} vs required {})" - .format(node.func.id, len(node.args), MAX_ARGS)) - target = self.popValue() - control = self.popValue() - negatedControlQubits = getNegatedControlQubits([control]) - checkControlAndTargetTypes([control], [target]) - param = self.popValue() - if IntegerType.isinstance(param.type): - param = arith.SIToFPOp(self.getFloatType(), param).result - elif not F64Type.isinstance(param.type): - self.emitFatalError( - 'rotational parameter must be a float, or int.', node) - # Map `crx` to `RxOp`... 
- opCtor = getattr( - quake, '{}Op'.format(node.func.id.title()[1:].capitalize())) - opCtor([], [param], [control], [target], - negated_qubit_controls=negatedControlQubits) - return - - if self.__isAdjointSimpleGate(node.func.id): - target = self.popValue() - checkControlAndTargetTypes([], [target]) - # Map `sdg` to `SOp`... - opCtor = getattr(quake, '{}Op'.format(node.func.id.title()[0])) - if quake.VeqType.isinstance(target.type): - - def bodyBuilder(iterVal): - q = quake.ExtractRefOp(self.getRefType(), - target, - -1, - index=iterVal).result - opCtor([], [], [], [q], is_adj=True) - - veqSize = quake.VeqSizeOp(self.getIntegerType(), - target).result - self.createInvariantForLoop(veqSize, bodyBuilder) - return - elif quake.RefType.isinstance(target.type): - opCtor([], [], [], [target], is_adj=True) - return - else: - self.emitFatalError( - 'adj quantum operation on incorrect type {}.'.format( - target.type), node) + processQuakeCtor(node.func.id[1:].title(), + node.args, + isCtrl=True, + isAdj=False, + numParams=1) return if self.__isMeasurementGate(node.func.id): @@ -2392,105 +2724,68 @@ def bodyBuilder(iterVal): self.debug_msg(lambda: f'[(Inline) Visit Constant]', userProvidedRegName.value) registerName = userProvidedRegName.value.value - qubits = [self.popValue() for _ in range(len(self.valueStack))] - checkControlAndTargetTypes([], qubits) - opCtor = getattr(quake, '{}Op'.format(node.func.id.title())) - i1Ty = self.getIntegerType(1) - resTy = i1Ty if len(qubits) == 1 and quake.RefType.isinstance( - qubits[0].type) else cc.StdvecType.get(i1Ty) - measTy = quake.MeasureType.get( - ) if len(qubits) == 1 and quake.RefType.isinstance( - qubits[0].type) else cc.StdvecType.get( - quake.MeasureType.get()) - label = registerName - if not label: - label = None - measureResult = opCtor(measTy, [], qubits, - registerName=label).result + + qubits = self.__groupValues(node.args, [(1, -1)]) + label = registerName or None + if len(qubits) == 1 and quake.RefType.isinstance( + 
qubits[0].type): + measTy = quake.MeasureType.get() + resTy = self.getIntegerType(1) + else: + measTy = cc.StdvecType.get(quake.MeasureType.get()) + resTy = cc.StdvecType.get(self.getIntegerType(1)) + measureResult = processQuantumOperation( + node.func.id.title(), [], + qubits, + measTy, + broadcast=False, + registerName=label).result + + # FIXME: needs to be revised when we properly distinguish measurement types if pushResultToStack: self.pushValue( quake.DiscriminateOp(resTy, measureResult).result) return if node.func.id == 'swap': - qubitB = self.popValue() - qubitA = self.popValue() - checkControlAndTargetTypes([], [qubitA, qubitB]) - opCtor = getattr(quake, '{}Op'.format(node.func.id.title())) - opCtor([], [], [], [qubitA, qubitB]) + processQuakeCtor(node.func.id.title(), + node.args, + isCtrl=False, + isAdj=False, + numTargets=2) return if node.func.id == 'reset': - target = self.popValue() - checkControlAndTargetTypes([], [target]) - if quake.RefType.isinstance(target.type): - quake.ResetOp([], target) - return - if quake.VeqType.isinstance(target.type): - - def bodyBuilder(iterVal): - q = quake.ExtractRefOp( - self.getRefType(), - target, - -1, # `kDynamicIndex` - index=iterVal).result - quake.ResetOp([], q) - - veqSize = quake.VeqSizeOp(self.getIntegerType(), - target).result - self.createInvariantForLoop(veqSize, bodyBuilder) - return - self.emitFatalError( - 'reset quantum operation on incorrect type {}.'.format( - target.type), node) + targets = self.__groupValues(node.args, [(1, -1)]) + processQuantumOperation(node.func.id.title(), [], + targets, + broadcast=lambda q: q) + return if node.func.id == 'u3': - # Single target, three parameters `u3(θ,φ,λ)` - all_args = [ - self.popValue() for _ in range(len(self.valueStack)) - ] - if len(all_args) < 4: - self.emitFatalError( - f'invalid number of arguments ({len(all_args)}) passed to {node.func.id} (requires at least 4 arguments)', - node) - qubitTargets = all_args[:-3] - qubitTargets.reverse() - 
checkControlAndTargetTypes([], qubitTargets) - params = all_args[-3:] - params.reverse() - for idx, val in enumerate(params): - if IntegerType.isinstance(val.type): - params[idx] = arith.SIToFPOp(self.getFloatType(), - val).result - elif not F64Type.isinstance(val.type): - self.emitFatalError( - 'rotational parameter must be a float, or int.', - node) - self.__applyQuantumOperation(node.func.id, params, qubitTargets) + processQuakeCtor(node.func.id.title(), + node.args, + isCtrl=False, + isAdj=False, + numParams=3) return if node.func.id == 'exp_pauli': - pauliWord = self.popValue() - qubits = self.popValue() - checkControlAndTargetTypes([], [qubits]) - theta = self.popValue() - if IntegerType.isinstance(theta.type): - theta = arith.SIToFPOp(self.getFloatType(), theta).result - quake.ExpPauliOp([], [theta], [], [qubits], pauli=pauliWord) + # Note: C++ also has a constructor that takes an `f64`, `string`, + # any any number of qubits. We don't support this here. + theta, target, pauliWord = self.__groupValues( + node.args, [1, 1, 1]) + theta = self.changeOperandToType(self.getFloatType(), theta) + processQuantumOperation("ExpPauli", [], [target], [], [theta], + broadcast=False, + pauli=pauliWord) return if node.func.id in globalRegisteredOperations: unitary = globalRegisteredOperations[node.func.id] numTargets = int(np.log2(np.sqrt(unitary.size))) - - numValues = len(self.valueStack) - if numValues != numTargets: - self.emitFatalError( - f'invalid number of arguments ({numValues}) passed to {node.func.id} (requires {numTargets} arguments)', - node) - - targets = [self.popValue() for _ in range(numTargets)] - targets.reverse() + targets = self.__groupValues(node.args, + [(numTargets, numTargets)]) for i, t in enumerate(targets): if not quake.RefType.isinstance(t.type): @@ -2515,94 +2810,73 @@ def bodyBuilder(iterVal): is_adj=False) return - if node.func.id == 'exp_pauli': - pauliWord = self.popValue() - qubits = self.popValue() - self.checkControlAndTargetTypes([], 
[qubits]) - theta = self.popValue() - if IntegerType.isinstance(theta.type): - theta = arith.SIToFPOp(self.getFloatType(), theta).result - quake.ExpPauliOp([], [theta], [], [qubits], pauli=pauliWord) - return - - if node.func.id == 'int': + elif node.func.id == 'int': # cast operation - value = self.popValue() + value = self.__groupValues(node.args, [1]) casted = self.changeOperandToType(IntegerType.get_signless(64), value, allowDemotion=True) self.pushValue(casted) return - if node.func.id == 'list': - if len(self.valueStack) == 2: - maybeIterableSize = self.popValue() - maybeIterable = self.popValue() - - # Make sure that we have a list + size - if IntegerType.isinstance(maybeIterableSize.type): - if cc.PointerType.isinstance(maybeIterable.type): - ptrEleTy = cc.PointerType.getElementType( - maybeIterable.type) - if cc.ArrayType.isinstance(ptrEleTy): - # We're good, just pass this back through. - self.pushValue(maybeIterable) - self.pushValue(maybeIterableSize) - return - if len(self.valueStack) == 1: - arrayTy = self.valueStack[0].type - if cc.PointerType.isinstance(arrayTy): - arrayTy = cc.PointerType.getElementType(arrayTy) - if cc.StdvecType.isinstance(arrayTy): - return - if cc.ArrayType.isinstance(arrayTy): - return - - self.emitFatalError('Invalid list() cast requested.', node) + elif node.func.id == 'list': + # The expected Python behavior is that a constructor call + # to list creates a new list (a shallow copy). 
+ value = self.__groupValues(node.args, [1]) + copy = self.__copyAndValidateContainer(value, node.args[0], + False) + self.pushValue(copy) + return - if node.func.id in ['print_i64', 'print_f64']: - self.__insertDbgStmt(self.popValue(), node.func.id) + elif node.func.id in ['print_i64', 'print_f64']: + value = self.__groupValues(node.args, [1]) + self.__insertDbgStmt(value, node.func.id) return - if node.func.id in globalRegisteredTypes.classes: + elif node.func.id in globalRegisteredTypes.classes: # Handle User-Custom Struct Constructor cls, annotations = globalRegisteredTypes.getClassAttributes( node.func.id) if '__slots__' not in cls.__dict__: self.emitWarning( - "Adding new fields in data classes is not yet " - "supported. The dataclass must be declared with " - "@dataclass(slots=True) or " - "@dataclasses.dataclass(slots=True).", node) + f"Adding new fields in data classes is not yet supported. The dataclass must be declared with @dataclass(slots=True) or @dataclasses.dataclass(slots=True).", + node) + + if node.keywords: + self.emitFatalError( + "keyword arguments for data classes are not yet supported", + node) - # Alloca the struct structTys = [ mlirTypeFromPyType(v, self.ctx) for _, v in annotations.items() ] + numArgs = len(structTys) + ctorArgs = self.__groupValues(node.args, [(numArgs, numArgs)]) + ctorArgs = convertArguments(structTys, ctorArgs) + for idx, arg in enumerate(ctorArgs): + self.__validate_container_entry(arg, node.args[idx]) + structTy = mlirTryCreateStructType(structTys, name=node.func.id, context=self.ctx) if structTy is None: self.emitFatalError( - "Hybrid quantum-classical data types and nested quantum" - " structs are not allowed.", node) + "Hybrid quantum-classical data types and nested quantum structs are not allowed.", + node) # Disallow user specified methods on structs - for k, v in cls.__dict__.items(): - if callable(v): - if not (k.startswith('__') and k.endswith('__')): - self.emitFatalError( - 'struct types with user 
specified methods are ' - 'not allowed.', node) - - ctorArgs = [ - self.popValue() for _ in range(len(self.valueStack)) - ] - ctorArgs.reverse() - ctorArgs = convertArguments(structTys, ctorArgs) + if len({ + k: v + for k, v in cls.__dict__.items() + if not (k.startswith('__') and k.endswith('__')) and + isinstance(v, types.FunctionType) + }) != 0: + self.emitFatalError( + 'struct types with user specified methods are not allowed.', + node) if quake.StruqType.isinstance(structTy): # If we have a quantum struct. We cannot allocate classical @@ -2611,20 +2885,18 @@ def bodyBuilder(iterVal): self.pushValue(quake.MakeStruqOp(structTy, ctorArgs).result) return - stackSlot = cc.AllocaOp(cc.PointerType.get(structTy), - TypeAttr.get(structTy)).result - - # loop over each type and `compute_ptr` / store - for i, ty in enumerate(structTys): - eleAddr = cc.ComputePtrOp( - cc.PointerType.get(ty), stackSlot, [], - DenseI32ArrayAttr.get([i], context=self.ctx)).result - cc.StoreOp(ctorArgs[i], eleAddr) - self.pushValue(stackSlot) + struct = cc.UndefOp(structTy) + for idx, element in enumerate(ctorArgs): + struct = cc.InsertValueOp( + structTy, struct, element, + DenseI64ArrayAttr.get([idx], context=self.ctx)).result + self.pushValue(struct) return - self.emitFatalError(f"unhandled function call - {node.func.id}", - node) + else: + self.emitFatalError( + "unhandled function call - {}, known kernels are {}".format( + node.func.id, globalKernelRegistry.keys()), node) elif isinstance(node.func, ast.Attribute): self.debug_msg(lambda: f'[(Inline) Visit Attribute]', node.func) @@ -2637,6 +2909,40 @@ def bodyBuilder(iterVal): # Handled in the Attribute visit, # since `numpy` arrays have a size attribute self.visit(node.func) + self.pushValue(self.popValue()) + return + + if node.func.attr == 'copy': + self.visit(node.func.value) + funcVal = self.popValue() + deepCopy, dTy = None, None + + for keyword in node.keywords: + if keyword.arg == 'deep': + deepCopy = keyword.value + elif 
keyword.arg == 'dtype': + self.visit(keyword.value) + dTy = self.popValue() + else: + self.emitFatalError(f"unknown keyword `{keyword.arg}`", + node) + + if len(node.args) == 1 and deepCopy is None: + deepCopy = node.args[0] + else: + self.__groupValues(node.args, [0]) + if deepCopy: + if not isinstance(deepCopy, ast.Constant): + self.emitFatalError( + "argument to `copy` must be a constant", node) + deepCopy = deepCopy.value + + # If we created a deep copy, we can set the parent node + # of the value to copy to be this node for validation purposes. + pyVal = node if deepCopy else node.func.value + copy = self.__copyAndValidateContainer(funcVal, pyVal, deepCopy, + dTy) + self.pushValue(copy) return if self.__isSupportedVectorFunction(node.func.attr): @@ -2646,7 +2952,7 @@ def bodyBuilder(iterVal): # we make the functions we support on values explicit # somewhere, there is no way around that. self.visit(node.func.value) - funcVal = self.ifPointerThenLoad(self.popValue()) + funcVal = self.popValue() # Just to be nice and give a dedicated error. 
if (node.func.attr == 'append' and @@ -2665,12 +2971,9 @@ def bodyBuilder(iterVal): node) funcArg = None - if len(node.args) > 1: - self.emitFatalError( - f'call to {node.func.attr} supports at most one value') - elif len(node.args) == 1: - self.visit(node.args[0]) - funcArg = self.ifPointerThenLoad(self.popValue()) + args = self.__groupValues(node.args, [(0, 1)]) + if args: + funcArg = args[0] if not IntegerType.isinstance(funcArg.type): self.emitFatalError( f'expecting an integer argument for call to {node.func.attr}', @@ -2734,39 +3037,19 @@ def bodyBuilder(iterVal): if isinstance(node.func.value, ast.Name): if node.func.value.id in ['numpy', 'np']: - [self.visit(arg) for arg in node.args] - - namedArgs = {} - for keyword in node.keywords: - self.visit(keyword.value) - namedArgs[keyword.arg] = self.popValue() - value = self.popValue() + value = self.__groupValues(node.args, [1]) if node.func.attr == 'array': - # `np.array(vec, )` - arrayType = value.type - if cc.PointerType.isinstance(value.type): - arrayType = cc.PointerType.getElementType( - value.type) - - if cc.StdvecType.isinstance(arrayType): - eleTy = cc.StdvecType.getElementType(arrayType) - dTy = eleTy - if len(namedArgs) > 0: - dTy = namedArgs['dtype'] - - # Convert the vector to the provided data type if needed. - self.pushValue( - self.__copyVectorAndCastElements( - value, dTy, allowDemotion=True)) - return - - raise self.emitFatalError( - f"unexpected numpy array initializer type: {value.type}", - node) - - value = self.ifPointerThenLoad(value) + # The expected Python behavior is that a constructor call + # to array creates a new array (a shallow copy). Additionally, since + # a new value is created, we need to make sure container entries + # are properly validated. To not duplicate the logic, we simply + # call `copy` here. 
+ self.visit_Call( + ast.Call(ast.Attribute(node.args[0], 'copy'), [], + node.keywords)) + return if node.func.attr in ['complex128', 'complex64']: if node.func.attr == 'complex128': @@ -2798,8 +3081,7 @@ def bodyBuilder(iterVal): self.pushValue(value) return - # Promote argument's types for `numpy.func` calls to match - # python's semantics + # Promote argument's types for `numpy.func` calls to match python's semantics if self.__isSupportedNumpyFunction(node.func.attr): if ComplexType.isinstance(value.type): value = self.changeOperandToType( @@ -2857,8 +3139,8 @@ def bodyBuilder(iterVal): if node.func.attr == 'ceil': if ComplexType.isinstance(value.type): self.emitFatalError( - f"numpy call ({node.func.attr}) is not supported " - f"for complex numbers", node) + f"numpy call ({node.func.attr}) is not supported for complex numbers", + node) return self.pushValue(math.CeilOp(value).result) return @@ -2866,20 +3148,19 @@ def bodyBuilder(iterVal): self.emitFatalError( f"unsupported NumPy call ({node.func.attr})", node) - [self.visit(arg) for arg in node.args] - if node.func.value.id == 'cudaq': if node.func.attr == 'complex': + self.__groupValues(node.args, [0]) self.pushValue(self.simulationDType()) return if node.func.attr == 'amplitudes': - value = self.popValue() - arrayType = value.type + value = self.__groupValues(node.args, [1]) + + valueTy = value.type if cc.PointerType.isinstance(value.type): - arrayType = cc.PointerType.getElementType( - value.type) - if cc.StdvecType.isinstance(arrayType): + valueTy = cc.PointerType.getElementType(value.type) + if cc.StdvecType.isinstance(valueTy): self.pushValue(value) return @@ -2888,27 +3169,23 @@ def bodyBuilder(iterVal): node) if node.func.attr == 'qvector': - if len(self.valueStack) == 0: + if len(node.args) == 0: self.emitFatalError( 'qvector does not have default constructor. 
Init from size or existing state.', node) - valueOrPtr = self.popValue() - initializerTy = valueOrPtr.type - - if cc.PointerType.isinstance(initializerTy): - initializerTy = cc.PointerType.getElementType( - initializerTy) + value = self.__groupValues(node.args, [1]) - if (IntegerType.isinstance(initializerTy)): + if (IntegerType.isinstance(value.type)): # handle `cudaq.qvector(n)` - value = self.ifPointerThenLoad(valueOrPtr) ty = self.getVeqType() qubits = quake.AllocaOp(ty, size=value).result self.pushValue(qubits) return - if cc.StdvecType.isinstance(initializerTy): + if cc.StdvecType.isinstance(value.type): + + # handle `cudaq.qvector(initState)` def check_vector_init(): """ Run semantics checks. @@ -2940,19 +3217,15 @@ def check_vector_init(): "qvector init (not a power of 2)", node) - # handle `cudaq.qvector(initState)` - value = self.ifPointerThenLoad(valueOrPtr) check_vector_init() eleTy = cc.StdvecType.getElementType(value.type) - ptrTy = cc.PointerType.get(eleTy) arrTy = cc.ArrayType.get(eleTy) ptrArrTy = cc.PointerType.get(arrTy) data = cc.StdvecDataOp(ptrArrTy, value).result size = cc.StdvecSizeOp(self.getIntegerType(), value).result - # Dynamic checking of the number of elements being - # a power of 2 and that the state is normalized is + # Dynamic checking that the state is normalized is # done at the library layer. veqTy = quake.VeqType.get() stateTy = cc.PointerType.get(cc.StateType.get()) @@ -2961,30 +3234,27 @@ def check_vector_init(): size.type, statePtr).result qubits = quake.AllocaOp(veqTy, size=numQubits).result - ini = quake.InitializeStateOp( + init = quake.InitializeStateOp( veqTy, qubits, statePtr).result quake.DeleteStateOp(statePtr) - self.pushValue(ini) + self.pushValue(init) return - if cc.PointerType.isinstance(initializerTy): - # The pointer to the state object may be stored in a - # local variable. Deref the variable address to get - # the pointer value. 
- initializerTy = cc.PointerType.getElementType( - initializerTy) - valueOrPtr = self.ifPointerThenLoad(valueOrPtr) - if cc.StateType.isinstance(initializerTy): + if (cc.PointerType.isinstance(value.type) and + cc.StateType.isinstance( + cc.PointerType.getElementType(value.type))): # handle `cudaq.qvector(state)` - statePtr = self.ifNotPointerThenStore(valueOrPtr) + i64Ty = self.getIntegerType() - numQubits = quake.GetNumberOfQubitsOp( - i64Ty, statePtr).result + numQubits = quake.GetNumberOfQubitsOp(i64Ty, + value).result + veqTy = quake.VeqType.get() qubits = quake.AllocaOp(veqTy, size=numQubits).result init = quake.InitializeStateOp( - veqTy, qubits, statePtr).result + veqTy, qubits, value).result + self.pushValue(init) return @@ -2993,8 +3263,7 @@ def check_vector_init(): node) if node.func.attr == "qubit": - if len(self.valueStack) == 1 and IntegerType.isinstance( - self.valueStack[0].type): + if len(node.args) != 0: self.emitFatalError( 'cudaq.qubit() constructor does not take any arguments. To construct a vector of qubits, use `cudaq.qvector(N)`.' ) @@ -3002,34 +3271,130 @@ def check_vector_init(): return if node.func.attr == 'adjoint' or node.func.attr == 'control': - processControlOrAdjoint(node.func.attr) + + # NOTE: We currently generally don't have the means in the + # compiler to handle composition of control and adjoint, since + # control and adjoint are not proper functors (i.e. there is + # no way to obtain a new callable object that is the adjoint + # or controlled version of another callable). + # Since we don't really treat callables as first-class values, + # the first argument to control and adjoint indeed has to be + # a Name object. 
+ + # FIXME: WE SHOULD NOW BE ABLE TO DEAL WITH ADJOINT OF QUALIFIED NAME + if not node.args or not isinstance( + node.args[0], ast.Name): + self.emitFatalError( + f'unsupported argument in call to {node.func.attr} - first argument must be a symbol name', + node) + otherFuncName = node.args[0].id + kwargs = {"is_adj": node.func.attr == 'adjoint'} + processDecorator(otherFuncName) + + if otherFuncName in self.symbolTable: + self.visit(node.args[0]) + fctArg = self.popValue() + if not cc.CallableType.isinstance(fctArg.type): + self.emitFatalError( + f"{otherFuncName} is not a quantum kernel", + node) + functionTy = FunctionType( + cc.CallableType.getFunctionType(fctArg.type)) + inputTys, outputTys = functionTy.inputs, functionTy.results + indirectCallee = [fctArg] + elif otherFuncName in globalRegisteredOperations: + self.emitFatalError( + f"calling cudaq.control or cudaq.adjoint on a globally registered operation is not supported", + node) + elif self.__isUnitaryGate( + otherFuncName) or self.__isMeasurementGate( + otherFuncName): + self.emitFatalError( + f"calling cudaq.control or cudaq.adjoint on a built-in gate is not supported", + node) + else: + self.emitFatalError( + f"{otherFuncName} is not a known quantum kernel - maybe a cudaq.kernel attribute is missing?.", + node) + + numArgs = len(inputTys) + invert_controls = lambda: None + if node.func.attr == 'control': + controls, args = self.__groupValues( + node.args[1:], [(1, -1), (numArgs, numArgs)]) + qvec_or_qubits = ( + all((quake.RefType.isinstance(v.type) + for v in controls)) or + (len(controls) == 1 and + quake.VeqType.isinstance(controls[0].type))) + if not qvec_or_qubits: + self.emitFatalError( + f'invalid argument type for control operand', + node) + # TODO: it would be cleaner to add support for negated control + # qubits to `quake.ApplyOp` + negatedControlQubits = self.controlNegations.copy() + self.controlNegations.clear() + if negatedControlQubits: + invert_controls = lambda: 
processQuantumOperation( + 'X', [], negatedControlQubits, [], []) + else: + controls, args = self.__groupValues( + node.args[1:], [(0, 0), (numArgs, numArgs)]) + + args = convertArguments(inputTys, args) + if len(outputTys) != 0: + self.emitFatalError( + f'cannot take {node.func.attr} of kernel {otherFuncName} that returns a value', + node) + invert_controls() + quake.ApplyOp([], indirectCallee, controls, args, + **kwargs) + invert_controls() return if node.func.attr == 'apply_noise': - # Pop off all the arguments we need - values = [ - self.popValue() for _ in range(len(self.valueStack)) + + supportedChannels = [ + 'DepolarizationChannel', 'AmplitudeDampingChannel', + 'PhaseFlipChannel', 'BitFlipChannel', + 'PhaseDamping', 'ZError', 'XError', 'YError', + 'Pauli1', 'Pauli2', 'Depolarization1', + 'Depolarization2' ] - # They are in reverse order - values.reverse() - # First one should be the number of Kraus channel parameters - numParamsVal = values[0] - # Shrink the arguments down - values = values[1:] - - # Need to get the number of parameters as an integer - concreteIntAttr = IntegerAttr( - numParamsVal.owner.attributes['value']) - numParams = concreteIntAttr.value - - # Next Value is our generated key for the channel - # Get it and shrink the list - key = values[0] - values = values[1:] - - # Now we know the next `numParams` arguments are - # our Kraus channel parameters - params = values[:numParams] + + # The first argument must be the Kraus channel + numParams, key = 0, None + if (isinstance(node.args[0], ast.Attribute) and + node.args[0].value.id == 'cudaq' and + node.args[0].attr in supportedChannels): + + cudaq_module = importlib.import_module('cudaq') + channel_class = getattr(cudaq_module, + node.args[0].attr) + numParams = channel_class.num_parameters + key = self.getConstantInt(hash(channel_class)) + elif isinstance(node.args[0], ast.Name): + arg = recover_value_of_or_none( + node.args[0].id, None) + if (arg and isinstance(arg, type) and issubclass( + 
arg, cudaq_runtime.KrausChannel)): + if not hasattr(arg, 'num_parameters'): + self.emitFatalError( + 'apply_noise kraus channels must have `num_parameters` constant class attribute specified.' + ) + numParams = arg.num_parameters + key = self.getConstantInt(hash(arg)) + if key is None: + self.emitFatalError( + "unsupported argument for Kraus channel in apply_noise", + node) + + # This currently requires at least one qubit argument + params, values = self.__groupValues( + node.args[1:], [(numParams, numParams), (1, -1)]) + checkControlAndTargetTypes([], values) + for i, p in enumerate(params): # If we have a F64 value, we want to # store it to a pointer @@ -3040,63 +3405,24 @@ def check_vector_init(): cc.StoreOp(p, alloca) params[i] = alloca - # The remaining arguments are the qubits - asVeq = quake.ConcatOp(self.getVeqType(), - values[numParams:]).result + asVeq = quake.ConcatOp(self.getVeqType(), values).result quake.ApplyNoiseOp(params, [asVeq], key=key) return if node.func.attr == 'compute_action': - # There can only be 2 arguments here. 
- action = None - compute = None - actionArg = node.args[1] - if isinstance(actionArg, ast.Name): - actionName = actionArg.id - if actionName in self.symbolTable: - action = self.symbolTable[actionName] - else: - self.emitFatalError( - "could not find action lambda / function in the symbol table.", - node) - else: - action = self.popValue() - - computeArg = node.args[0] - if isinstance(computeArg, ast.Name): - computeName = computeArg.id - if computeName in self.symbolTable: - compute = self.symbolTable[computeName] - else: - self.emitFatalError( - "could not find compute lambda / function in the symbol table.", - node) - else: - compute = self.popValue() - + compute, action = self.__groupValues(node.args, [2]) quake.ComputeActionOp(compute, action) return if node.func.attr == 'to_integer': - boolVec = self.popValue() - boolVec = self.ifPointerThenLoad(boolVec) - if not cc.StdvecType.isinstance(boolVec.type): - self.emitFatalError( - "to_integer expects a vector of booleans. Got type {}" - .format(boolVec.type), node) - elemTy = cc.StdvecType.getElementType(boolVec.type) - if elemTy != self.getIntegerType(1): - self.emitFatalError( - "to_integer expects a vector of booleans. 
Got type {}" - .format(boolVec.type), node) + boolVec = self.__groupValues(node.args, [1]) + args = convertArguments( + [cc.StdvecType.get(self.getIntegerType(1))], + [boolVec]) cudaqConvertToInteger = "__nvqpp_cudaqConvertToInteger" - # Load the intrinsic load_intrinsic(self.module, cudaqConvertToInteger) - # Signature: - # `func.func private @__nvqpp_cudaqConvertToInteger(%arg : !cc.stdvec) -> i64` - resultTy = self.getIntegerType(64) - result = func.CallOp([resultTy], cudaqConvertToInteger, - [boolVec]).result + result = func.CallOp([self.getIntegerType(64)], + cudaqConvertToInteger, args).result self.pushValue(result) return @@ -3125,209 +3451,63 @@ def maybeProposeOpAttrFix(opName, attrName): # We have a `func_name.ctrl` if self.__isSimpleGate(node.func.value.id): if node.func.attr == 'ctrl': - target = self.popValue() - # Should be number of arguments minus one for the controls - controls = [ - self.popValue() for i in range(len(node.args) - 1) - ] - if not controls: - self.emitFatalError( - 'controlled operation requested without any control argument(s).', - node) - negatedControlQubits = getNegatedControlQubits(controls) - - opCtor = getattr( - quake, '{}Op'.format(node.func.value.id.title())) - checkControlAndTargetTypes(controls, [target]) - opCtor([], [], - controls, [target], - negated_qubit_controls=negatedControlQubits) + processQuakeCtor(node.func.value.id.title(), + node.args, + isCtrl=True, + isAdj=False) return if node.func.attr == 'adj': - target = self.popValue() - checkControlAndTargetTypes([], [target]) - opCtor = getattr( - quake, '{}Op'.format(node.func.value.id.title())) - if quake.VeqType.isinstance(target.type): - - def bodyBuilder(iterVal): - q = quake.ExtractRefOp(self.getRefType(), - target, - -1, - index=iterVal).result - opCtor([], [], [], [q], is_adj=True) - - veqSize = quake.VeqSizeOp(self.getIntegerType(), - target).result - self.createInvariantForLoop(veqSize, bodyBuilder) - return - elif quake.RefType.isinstance(target.type): - 
opCtor([], [], [], [target], is_adj=True) - return - else: - self.emitFatalError( - 'adj quantum operation on incorrect type {}.'. - format(target.type), node) - - self.emitFatalError( - f'Unknown attribute on quantum operation {node.func.value.id} ({node.func.attr}). {maybeProposeOpAttrFix(node.func.value.id, node.func.attr)}' - ) - - # We have a `func_name.ctrl` - if node.func.value.id == 'swap' and node.func.attr == 'ctrl': - targetB = self.popValue() - targetA = self.popValue() - controls = [ - self.popValue() for i in range(len(self.valueStack)) - ] - if not controls: - self.emitFatalError( - 'controlled operation requested without any control argument(s).', - node) - negatedControlQubits = getNegatedControlQubits(controls) - opCtor = getattr(quake, - '{}Op'.format(node.func.value.id.title())) - checkControlAndTargetTypes(controls, [targetA, targetB]) - opCtor([], [], - controls, [targetA, targetB], - negated_qubit_controls=negatedControlQubits) - return + processQuakeCtor(node.func.value.id.title(), + node.args, + isCtrl=False, + isAdj=True) + return + self.emitFatalError( + f'Unknown attribute on quantum operation {node.func.value.id} ({node.func.attr}). 
{maybeProposeOpAttrFix(node.func.value.id, node.func.attr)}' + ) if self.__isRotationGate(node.func.value.id): if node.func.attr == 'ctrl': - target = self.popValue() - controls = [ - self.popValue() for i in range(len(self.valueStack)) - ] - param = controls[-1] - controls = controls[:-1] - if not controls: - self.emitFatalError( - 'controlled operation requested without any control argument(s).', - node) - negatedControlQubits = getNegatedControlQubits(controls) - if IntegerType.isinstance(param.type): - param = arith.SIToFPOp(self.getFloatType(), - param).result - elif not F64Type.isinstance(param.type): - self.emitFatalError( - 'rotational parameter must be a float, or int.', - node) - opCtor = getattr( - quake, '{}Op'.format(node.func.value.id.title())) - checkControlAndTargetTypes(controls, [target]) - opCtor([], [param], - controls, [target], - negated_qubit_controls=negatedControlQubits) + processQuakeCtor(node.func.value.id.title(), + node.args, + isCtrl=True, + isAdj=False, + numParams=1) return - if node.func.attr == 'adj': - target = self.popValue() - param = self.popValue() - if IntegerType.isinstance(param.type): - param = arith.SIToFPOp(self.getFloatType(), - param).result - elif not F64Type.isinstance(param.type): - self.emitFatalError( - 'rotational parameter must be a float, or int.', - node) - opCtor = getattr( - quake, '{}Op'.format(node.func.value.id.title())) - checkControlAndTargetTypes([], [target]) - if quake.VeqType.isinstance(target.type): - - def bodyBuilder(iterVal): - q = quake.ExtractRefOp(self.getRefType(), - target, - -1, - index=iterVal).result - opCtor([], [param], [], [q], is_adj=True) - - veqSize = quake.VeqSizeOp(self.getIntegerType(), - target).result - self.createInvariantForLoop(veqSize, bodyBuilder) - return - elif quake.RefType.isinstance(target.type): - opCtor([], [param], [], [target], is_adj=True) - return - else: - self.emitFatalError( - 'adj quantum operation on incorrect type {}.'. 
- format(target.type), node) - + processQuakeCtor(node.func.value.id.title(), + node.args, + isCtrl=False, + isAdj=True, + numParams=1) + return self.emitFatalError( - f'Unknown attribute on quantum operation ' - f'{node.func.value.id} ({node.func.attr}). ' - f'{maybeProposeOpAttrFix(node.func.value.id, node.func.attr)}' + f'Unknown attribute on quantum operation {node.func.value.id} ({node.func.attr}). {maybeProposeOpAttrFix(node.func.value.id, node.func.attr)}' ) - if node.func.value.id == 'u3': - numValues = len(self.valueStack) - target = self.popValue() - other_args = [self.popValue() for _ in range(numValues - 1)] - - opCtor = getattr(quake, - '{}Op'.format(node.func.value.id.title())) + if node.func.value.id == 'swap' and node.func.attr == 'ctrl': + processQuakeCtor(node.func.value.id.title(), + node.args, + isCtrl=True, + isAdj=False, + numTargets=2) + return + if node.func.value.id == 'u3': if node.func.attr == 'ctrl': - controls = other_args[:-3] - if not controls: - self.emitFatalError( - 'controlled operation requested without any ' - 'control argument(s).', node) - negatedControlQubits = getNegatedControlQubits(controls) - params = other_args[-3:] - params.reverse() - for idx, val in enumerate(params): - if IntegerType.isinstance(val.type): - params[idx] = arith.SIToFPOp( - self.getFloatType(), val).result - elif not F64Type.isinstance(val.type): - self.emitFatalError( - 'rotational parameter must be a float, or ' - 'int.', node) - - checkControlAndTargetTypes(controls, [target]) - opCtor([], - params, - controls, [target], - negated_qubit_controls=negatedControlQubits) + processQuakeCtor(node.func.value.id.title(), + node.args, + isCtrl=True, + isAdj=False, + numParams=3) return - if node.func.attr == 'adj': - params = other_args - params.reverse() - for idx, val in enumerate(params): - if IntegerType.isinstance(val.type): - params[idx] = arith.SIToFPOp( - self.getFloatType(), val).result - elif not F64Type.isinstance(val.type): - self.emitFatalError( 
- 'rotational parameter must be a float, or ' - 'int.', node) - - checkControlAndTargetTypes([], [target]) - if quake.VeqType.isinstance(target.type): - - def bodyBuilder(iterVal): - q = quake.ExtractRefOp(self.getRefType(), - target, - -1, - index=iterVal).result - opCtor([], params, [], [q], is_adj=True) - - veqSize = quake.VeqSizeOp(self.getIntegerType(), - target).result - self.createInvariantForLoop(veqSize, bodyBuilder) - return - elif quake.RefType.isinstance(target.type): - opCtor([], params, [], [target], is_adj=True) - return - else: - self.emitFatalError( - 'adj quantum operation on incorrect type {}.'. - format(target.type), node) - + processQuakeCtor(node.func.value.id.title(), + node.args, + isCtrl=False, + isAdj=True, + numParams=3) + return self.emitFatalError( f'unknown attribute {node.func.attr} on u3', node) @@ -3335,23 +3515,12 @@ def bodyBuilder(iterVal): if node.func.value.id in globalRegisteredOperations: if not node.func.attr == 'ctrl' and not node.func.attr == 'adj': self.emitFatalError( - f'Unknown attribute on custom operation ' - f'{node.func.value.id} ({node.func.attr}).') + f'Unknown attribute on custom operation {node.func.value.id} ({node.func.attr}).' 
+ ) unitary = globalRegisteredOperations[node.func.value.id] numTargets = int(np.log2(np.sqrt(unitary.size))) - numValues = len(self.valueStack) - targets = [self.popValue() for _ in range(numTargets)] - targets.reverse() - - for i, t in enumerate(targets): - if not quake.RefType.isinstance(t.type): - self.emitFatalError( - f'invalid target operand {i}, broadcasting is not ' - f'supported on custom operations.') - globalName = f'{nvqppPrefix}{node.func.value.id}_generator_{numTargets}.rodata' - currentST = SymbolTable(self.module.operation) if not globalName in currentST: with InsertionPoint(self.module.body): @@ -3359,24 +3528,26 @@ def bodyBuilder(iterVal): self.loc, self.module, globalName, unitary.tolist()) - negatedControlQubits = None - controls = [] - is_adj = False - if node.func.attr == 'ctrl': - controls = [ - self.popValue() - for _ in range(numValues - numTargets) - ] - if not controls: - self.emitFatalError( - 'controlled operation requested without any ' - 'control argument(s).', node) + controls, targets = self.__groupValues( + node.args, [(1, -1), (numTargets, numTargets)]) negatedControlQubits = getNegatedControlQubits(controls) + is_adj = False if node.func.attr == 'adj': + controls, targets = self.__groupValues( + node.args, [0, (numTargets, numTargets)]) + negatedControlQubits = None is_adj = True checkControlAndTargetTypes(controls, targets) + # The check above makes sure targets are either a list + # of individual qubits, or a single `qvector`. Since + # a `qvector` is not allowed, we check this here: + if not quake.RefType.isinstance(targets[0].type): + self.emitFatalError( + f'invalid target operand - target must not be a qvector' + ) + quake.CustomUnitarySymbolOp( [], generator=FlatSymbolRefAttr.get(globalName), @@ -3387,44 +3558,7 @@ def bodyBuilder(iterVal): negated_qubit_controls=negatedControlQubits) return - # See if this is a path-qualified reference to a kernel. 
- def getCallFullName(n): - parts = [] - while isinstance(n, ast.Attribute): - parts.append(n.attr) - n = n.value - if isinstance(n, ast.Name): - parts.append(n.id) - else: - return None - return ".".join(reversed(parts)) - - fullName = getCallFullName(node.func) - if fullName is not None: - dec = resolve_qualified_symbol(fullName) - if dec is not None: - callee, fType = processDecoratorCall(dec, fullName) - declArgs = (dec.firstLiftedPos if dec.firstLiftedPos - is not None else len(fType.inputs)) - if declArgs != len(node.args): - funcName = (node.func.id if hasattr(node.func, 'id') - else node.func.attr) - self.emitFatalError( - f"invalid number of arguments passed to callable " - f"{funcName} ({len(node.args)} vs " - f"required {declArgs})", node) - [self.visit(arg) for arg in node.args] - values = [self.popValue() for _ in node.args] - values.reverse() - values = [self.ifPointerThenLoad(v) for v in values] - call = cc.CallCallableOp(fType.results, callee, values) - sa = StringAttr.get(name) - call.attributes.__setitem__('symbol', sa) - for r in call.results: - self.pushValue(r) - return - - self.emitFatalError("unknown function call", node) + self.emitFatalError(f"unknown function call", node) def visit_ListComp(self, node): """ @@ -3438,42 +3572,18 @@ def visit_ListComp(self, node): "CUDA-Q only supports single generators for list comprehension.", node) - # Let's handle the following `listVar` types - # ` %9 = cc.alloca !cc.array x 2> -> ptr x N>` - # or - # ` %3 = cc.alloca T[%2 : i64] -> ptr>` self.visit(node.generators[0].iter) - - if len(self.valueStack) == 1: - iterable = self.ifPointerThenLoad(self.popValue()) - iterableSize = None - if cc.StdvecType.isinstance(iterable.type): - iterableSize = cc.StdvecSizeOp(self.getIntegerType(), - iterable).result - iterTy = cc.StdvecType.getElementType(iterable.type) - iterArrPtrTy = cc.PointerType.get(cc.ArrayType.get(iterTy)) - iterable = cc.StdvecDataOp(iterArrPtrTy, iterable).result - elif 
quake.VeqType.isinstance(iterable.type): - iterableSize = quake.VeqSizeOp(self.getIntegerType(), - iterable).result - iterTy = quake.RefType.get() - if iterableSize is None: - self.emitFatalError( - "CUDA-Q only supports list comprehension on ranges and arrays", - node) - elif len(self.valueStack) == 2: - iterableSize = self.popValue() - iterable = self.popValue() - if not cc.PointerType.isinstance(iterable.type): - self.emitFatalError( - "CUDA-Q only supports list comprehension on ranges and arrays", - node) - iterArrTy = cc.PointerType.getElementType(iterable.type) - if not cc.ArrayType.isinstance(iterArrTy): - self.emitFatalError( - "CUDA-Q only supports list comprehension on ranges and arrays", - node) - iterTy = cc.ArrayType.getElementType(iterArrTy) + iterable = self.popValue() + if cc.StdvecType.isinstance(iterable.type): + iterableSize = cc.StdvecSizeOp(self.getIntegerType(), + iterable).result + iterTy = cc.StdvecType.getElementType(iterable.type) + iterArrPtrTy = cc.PointerType.get(cc.ArrayType.get(iterTy)) + iterable = cc.StdvecDataOp(iterArrPtrTy, iterable).result + elif quake.VeqType.isinstance(iterable.type): + iterableSize = quake.VeqSizeOp(self.getIntegerType(), + iterable).result + iterTy = quake.RefType.get() else: self.emitFatalError( "CUDA-Q only supports list comprehension on ranges and arrays", @@ -3493,11 +3603,13 @@ def process_void_list(): forNode.target = node.generators[0].target forNode.body = [node.elt] forNode.orelse = [] + forNode.lineno = node.lineno + # this loop could be marked as invariant if we didn't use `visit_For` self.visit_For(forNode) target_types = {} - def get_item_type(target, targetType): + def get_target_type(target, targetType): if isinstance(target, ast.Name): if target.id in target_types: self.emitFatalError( @@ -3514,12 +3626,12 @@ def get_item_type(target, targetType): self.emitFatalError( "shape mismatch in tuple deconstruction", node) for i, ty in enumerate(types): - get_item_type(target.elts[i], ty) + 
get_target_type(target.elts[i], ty) else: self.emitFatalError( "unsupported target in tuple deconstruction", node) - get_item_type(node.generators[0].target, iterTy) + get_target_type(node.generators[0].target, iterTy) # We need to know the element type of the list we are creating. # Unfortunately, dynamic typing makes this a bit painful. @@ -3539,7 +3651,12 @@ def get_item_type(pyval): elts = [get_item_type(v) for v in pyval.elts] if None in elts: return None - return cc.PointerType.get(cc.StructType.getNamed("tuple", elts)) + structTy = mlirTryCreateStructType(elts, context=self.ctx) + if not structTy: + # we return anything here since, or rather to make sure that, + # a comprehensive error is generated when `elt` is walked below. + return cc.StructType.getNamed("tuple", elts) + return structTy elif (isinstance(pyval, ast.Subscript) and IntegerType.isinstance(get_item_type(pyval.slice))): parentType = get_item_type(pyval.value) @@ -3571,13 +3688,8 @@ def get_item_type(pyval): return cc.StdvecType.get(base_elTy) elif isinstance(pyval, ast.Call): if isinstance(pyval.func, ast.Name): - # supported for calls but not here: - # 'range', 'enumerate', 'list' - if pyval.func.id == 'len' or pyval.func.id == 'int': - return IntegerType.get_signless(64) - elif pyval.func.id == 'complex': - return self.getComplexType() - elif self.__isUnitaryGate( + # supported for calls but not here: 'range', 'enumerate' + if self.__isUnitaryGate( pyval.func.id) or pyval.func.id == 'reset': process_void_list() return None @@ -3630,25 +3742,31 @@ def get_item_type(pyval): elif pyval.func.id in globalRegisteredTypes.classes: _, annotations = globalRegisteredTypes.getClassAttributes( pyval.func.id) - structTys = [ + elts = [ mlirTypeFromPyType(v, self.ctx) for _, v in annotations.items() ] - # no need to do much verification on the validity of the type here - - # this will be handled when we build the body - isStruq = any( - (self.isQuantumType(t) for t in structTys)) - if isStruq: - return 
quake.StruqType.getNamed( - pyval.func.id, structTys) - else: - return cc.PointerType.get( - cc.StructType.getNamed(pyval.func.id, - structTys)) - elif (isinstance(pyval.func, ast.Attribute) and - (pyval.func.attr == 'ctrl' or pyval.func.attr == 'adj')): - process_void_list() - return None + structTy = mlirTryCreateStructType(elts, + pyval.func.id, + context=self.ctx) + if not structTy: + # we return anything here since, or rather to make sure that, + # a comprehensive error is generated when `elt` is walked below. + return cc.StructType.getNamed(pyval.func.id, elts) + return structTy + elif pyval.func.id == 'len' or pyval.func.id == 'int': + return IntegerType.get_signless(64) + elif pyval.func.id == 'complex': + return self.getComplexType() + elif pyval.func.id == 'list' and len(pyval.args) == 1: + return get_item_type(pyval.args[0]) + elif isinstance(pyval.func, ast.Attribute): + if (pyval.func.attr == 'copy' and + 'dtype' not in pyval.keywords): + return get_item_type(pyval.func.value) + if pyval.func.attr == 'ctrl' or pyval.func.attr == 'adj': + process_void_list() + return None self.emitFatalError("unsupported call in list comprehension", node) elif isinstance(pyval, ast.Compare): @@ -3682,8 +3800,7 @@ def get_item_type(pyval): return resultVecTy = cc.StdvecType.get(listElemTy) - isBool = listElemTy == self.getIntegerType(1) - if isBool: + if listElemTy == self.getIntegerType(1): listElemTy = self.getIntegerType(8) listTy = cc.ArrayType.get(listElemTy) listValue = cc.AllocaOp(cc.PointerType.get(listTy), @@ -3699,31 +3816,38 @@ def get_item_type(pyval): def bodyBuilder(iterVar): self.symbolTable.pushScope() if quake.VeqType.isinstance(iterable.type): - loadedEle = quake.ExtractRefOp(iterTy, - iterable, - -1, - index=iterVar).result + iterVal = quake.ExtractRefOp(iterTy, + iterable, + -1, + index=iterVar).result else: eleAddr = cc.ComputePtrOp( cc.PointerType.get(iterTy), iterable, [iterVar], DenseI32ArrayAttr.get([kDynamicPtrIndex], context=self.ctx)) - 
loadedEle = cc.LoadOp(eleAddr).result - self.__deconstructAssignment(node.generators[0].target, loadedEle) + iterVal = cc.LoadOp(eleAddr).result + + # We don't do support anything within list comprehensions that would + # require being careful about assigning references, so simply + # adding them to the symbol table is enough for list comprehension. + self.__deconstructAssignment(node.generators[0].target, iterVal) self.visit(node.elt) - result = self.popValue() + element = self.popValue() + # We do need to be careful, however, about validating the list elements. + self.__validate_container_entry(element, node.elt) + listValueAddr = cc.ComputePtrOp( cc.PointerType.get(listElemTy), listValue, [iterVar], DenseI32ArrayAttr.get([kDynamicPtrIndex], context=self.ctx)) - - if isBool: - result = self.changeOperandToType(self.getIntegerType(8), - result) - cc.StoreOp(result, listValueAddr) + element = self.changeOperandToType(listElemTy, + element, + allowDemotion=False) + cc.StoreOp(element, listValueAddr) self.symbolTable.popScope() - self.createInvariantForLoop(iterableSize, bodyBuilder) - self.pushValue( - cc.StdvecInitOp(resultVecTy, listValue, length=iterableSize).result) + self.createInvariantForLoop(bodyBuilder, iterableSize) + res = cc.StdvecInitOp(resultVecTy, listValue, + length=iterableSize).result + self.pushValue(res) return def visit_List(self, node): @@ -3752,14 +3876,14 @@ def visit_List(self, node): listElementValues.append(evalElem) else: self.visit(element) - # We do not store lists of pointers - evalElem = self.ifPointerThenLoad(self.popValue()) + evalElem = self.popValue() if self.isQuantumType( evalElem.type) and not quake.RefType.isinstance( evalElem.type): self.emitFatalError( "list must not contain a qvector or quantum struct - use `*` operator to unpack qvectors", node) + self.__validate_container_entry(evalElem, element) listElementValues.append(evalElem) numQuantumTs = sum( @@ -3800,9 +3924,7 @@ def visit_List(self, node): ] # Turn this List 
into a StdVec - self.pushValue( - self.__createStdvecWithKnownValues(len(node.elts), - listElementValues)) + self.pushValue(self.__createStdvecWithKnownValues(listElementValues)) def visit_Constant(self, node): """ @@ -3874,11 +3996,16 @@ def fix_negative_idx(idx, get_size): # handle complex slice, VAR[lower:upper] if isinstance(node.slice, ast.Slice): self.debug_msg(lambda: f'[(Inline) Visit Slice]', node.slice) + if self.pushPointerValue: + self.emitFatalError( + "slicing a list or qvector does not produce a modifiable value", + node) + self.visit(node.value) - var = self.ifPointerThenLoad(self.popValue()) + var = self.popValue() vectorSize = get_size(var) - lowerVal, upperVal, stepVal = (None, None, None) + lowerVal, upperVal = None, None if node.slice.lower is not None: self.visit(node.slice.lower) lowerVal = fix_negative_idx(self.popValue(), lambda: vectorSize) @@ -3936,16 +4063,41 @@ def fix_negative_idx(idx, get_size): return - self.generic_visit(node) - - assert len(self.valueStack) > 1 + # Only variable names, subscripts and attributes can + # produce modifiable values. Anything else produces an + # immutable value. We make sure the visit gets processed + # such that the rest of the code can give a proper error. 
+ value_root = node.value + while (isinstance(value_root, ast.Subscript) or + isinstance(value_root, ast.Attribute)): + value_root = value_root.value + if self.pushPointerValue and not isinstance(value_root, ast.Name): + self.pushPointerValue = False + self.visit(node.value) + var = self.popValue() + self.pushPointerValue = True + else: + # `isSubscriptRoot` is only used/needed to enable + # modification of items in lists and dataclasses + # contained in a tuple + subscriptRoot = self.isSubscriptRoot + self.isSubscriptRoot = True + self.visit(node.value) + var = self.popValue() + self.isSubscriptRoot = subscriptRoot - # get the last name, should be name of var being subscripted - var = self.ifPointerThenLoad(self.popValue()) + pushPtr = self.pushPointerValue + self.pushPointerValue = False + self.visit(node.slice) idx = self.popValue() + self.pushPointerValue = pushPtr - # Support `VAR[-1]` as the last element of `VAR` if quake.VeqType.isinstance(var.type): + if self.pushPointerValue: + self.emitFatalError( + "indexing into a qvector does not produce a modifyable value", + node) + if not IntegerType.isinstance(idx.type): self.emitFatalError( f'invalid index variable type used for qvector extraction ({idx.type})', @@ -3956,6 +4108,44 @@ def fix_negative_idx(idx, get_size): index=idx).result) return + if cc.PointerType.isinstance(var.type): + # We should only ever get a pointer if we + # explicitly asked for it. + assert self.pushPointerValue + varType = cc.PointerType.getElementType(var.type) + if cc.StdvecType.isinstance(varType): + # We can get a pointer to a vector (only) if we + # are updating a struct item that is a pointer. + if self.pushPointerValue: + # In this case, it should be save to load + # the vector, since the underlying data is + # not loaded. 
+ var = cc.LoadOp(var).result + + if cc.StructType.isinstance(varType): + structName = cc.StructType.getName(varType) + if not self.isSubscriptRoot and structName == 'tuple': + self.emitFatalError("tuple value cannot be modified", node) + if not isinstance(node.slice, ast.Constant): + if self.pushPointerValue: + if structName == 'tuple': + self.emitFatalError( + "tuple value cannot be modified via non-constant subscript", + node) + self.emitFatalError( + f"{structName} value cannot be modified via non-constant subscript - use attribute access instead", + node) + + idxVal = node.slice.value + structTys = cc.StructType.getTypes(varType) + eleAddr = cc.ComputePtrOp(cc.PointerType.get(structTys[idxVal]), + var, [], + DenseI32ArrayAttr.get([idxVal + ])).result + if self.pushPointerValue: + self.pushValue(eleAddr) + return + if cc.StdvecType.isinstance(var.type): idx = fix_negative_idx(idx, lambda: get_size(var)) eleTy = cc.StdvecType.getElementType(var.type) @@ -3970,7 +4160,7 @@ def fix_negative_idx(idx, get_size): elePtrTy, vecPtr, [idx], DenseI32ArrayAttr.get([kDynamicPtrIndex], context=self.ctx)).result - if self.subscriptPushPointerValue: + if self.pushPointerValue: self.pushValue(eleAddr) return val = cc.LoadOp(eleAddr).result @@ -3979,24 +4169,6 @@ def fix_negative_idx(idx, get_size): self.pushValue(val) return - if cc.PointerType.isinstance(var.type): - ptrEleTy = cc.PointerType.getElementType(var.type) - # Return the pointer if someone asked for it - if self.subscriptPushPointerValue: - self.pushValue(var) - return - if cc.ArrayType.isinstance(ptrEleTy): - # Here we want subscript on `ptr>` - arrayEleTy = cc.ArrayType.getElementType(ptrEleTy) - ptrEleTy = cc.PointerType.get(arrayEleTy) - casted = cc.CastOp(ptrEleTy, var).result - eleAddr = cc.ComputePtrOp( - ptrEleTy, casted, [idx], - DenseI32ArrayAttr.get([kDynamicPtrIndex], - context=self.ctx)).result - self.pushValue(cc.LoadOp(eleAddr).result) - return - def get_idx_value(upper_bound): idxValue = None if 
hasattr(idx.owner, 'opview') and isinstance( @@ -4018,33 +4190,29 @@ def get_idx_value(upper_bound): # We allow subscripts into `Structs`, but only if we don't need a pointer # (i.e. no updating of Tuples). if cc.StructType.isinstance(var.type): - if self.subscriptPushPointerValue: + if self.pushPointerValue: + structName = cc.StructType.getName(var.type) + if structName == 'tuple': + self.emitFatalError("tuple value cannot be modified", node) self.emitFatalError( - "indexing into tuple or dataclass must not modify value", + f"{structName} value cannot be modified - use `.copy(deep)` to create a new value that can be modified", node) - # Handle the case where we have a tuple member extraction, memory semantics memberTys = cc.StructType.getTypes(var.type) idxValue = get_idx_value(len(memberTys)) - structPtr = self.ifNotPointerThenStore(var) - eleAddr = cc.ComputePtrOp( - cc.PointerType.get(memberTys[idxValue]), structPtr, [], - DenseI32ArrayAttr.get([idxValue], context=self.ctx)).result + member = cc.ExtractValueOp(memberTys[idxValue], var, [], + DenseI32ArrayAttr.get([idxValue])).result - # Return the pointer if someone asked for it - if self.subscriptPushPointerValue: - self.pushValue(eleAddr) - return - self.pushValue(cc.LoadOp(eleAddr).result) + self.pushValue(member) return - # Let's allow subscripts into `Struqs`, but only if we don't need a pointer + # We allow subscripts into `Struqs`, but only if we don't need a pointer # (i.e. no updating of `Struqs`). if quake.StruqType.isinstance(var.type): - if self.subscriptPushPointerValue: + if self.pushPointerValue: self.emitFatalError( - "indexing into quantum tuple or dataclass must not modify value", + "indexing into quantum tuple or dataclass does not produce a modifiable value", node) memberTys = quake.StruqType.getTypes(var.type) @@ -4070,6 +4238,8 @@ def actually_visit_For(self, node): `veq` type, the `stdvec` type, and the result of range() and enumerate(). 
""" + + getValues = None if isinstance(node.iter, ast.Call): self.debug_msg(lambda: f'[(Inline) Visit Call]', node.iter) @@ -4077,195 +4247,101 @@ def actually_visit_For(self, node): # by just building a for loop with N as the upper value, # no need to generate an array from the `range` call. if node.iter.func.id == 'range': - # This is a range(N) for loop, we just need - # the upper bound N for this loop - [self.visit(arg) for arg in node.iter.args] + iterable = None startVal, endVal, stepVal, isDecrementing = self.__processRangeLoopIterationBounds( node.iter.args) - - if not isinstance(node.target, ast.Name): - self.emitFatalError( - "iteration variable must be a single name", node) - - def bodyBuilder(iterVar): - self.symbolTable.pushScope() - self.symbolTable.add(node.target.id, iterVar) - [self.visit(b) for b in node.body] - self.symbolTable.popScope() - - self.createInvariantForLoop(endVal, - bodyBuilder, - startVal=startVal, - stepVal=stepVal, - isDecrementing=isDecrementing, - elseStmts=node.orelse) - - return + getValues = lambda iterVar: iterVar # We can simplify `for i,j in enumerate(L)` MLIR code immensely # by just building a for loop over the iterable object L and using # the index into that iterable and the element. 
- if node.iter.func.id == 'enumerate': - [self.visit(arg) for arg in node.iter.args] - if len(self.valueStack) == 2: - iterable = self.popValue() - self.popValue() - else: - assert len(self.valueStack) == 1 - iterable = self.popValue() - iterable = self.ifPointerThenLoad(iterable) - totalSize = None - extractFunctor = None - - beEfficient = False - if quake.VeqType.isinstance(iterable.type): - totalSize = quake.VeqSizeOp(self.getIntegerType(), - iterable).result - - def functor(seq, idx): - q = quake.ExtractRefOp(self.getRefType(), - seq, - -1, - index=idx).result - return [idx, q] - - extractFunctor = functor - beEfficient = True - elif cc.StdvecType.isinstance(iterable.type): - totalSize = cc.StdvecSizeOp(self.getIntegerType(), - iterable).result - - def functor(seq, idx): - vecTy = cc.StdvecType.getElementType(seq.type) - dataTy = cc.PointerType.get(vecTy) - arrTy = vecTy - if not cc.ArrayType.isinstance(arrTy): - arrTy = cc.ArrayType.get(vecTy) - dataArrTy = cc.PointerType.get(arrTy) - data = cc.StdvecDataOp(dataArrTy, seq).result - v = cc.ComputePtrOp( - dataTy, data, [idx], - DenseI32ArrayAttr.get([kDynamicPtrIndex], - context=self.ctx)).result - return [idx, v] - - extractFunctor = functor - beEfficient = True - - if beEfficient: + elif node.iter.func.id == 'enumerate': + if len(node.iter.args) != 1: + self.emitFatalError( + "invalid number of arguments to enumerate - expecting 1 argument", + node) - if (not isinstance(node.target, ast.Tuple) or - len(node.target.elts) != 2): - self.emitFatalError( - "iteration variable must be a tuple of two items", - node) + self.visit(node.iter.args[0]) + iterable = self.popValue() + getValues = lambda iterVar, v: (iterVar, v) - def bodyBuilder(iterVar): - self.symbolTable.pushScope() - values = extractFunctor(iterable, iterVar) - assert (len(values) == 2) - for i, v in enumerate(values): - self.__deconstructAssignment(node.target.elts[i], v) - [self.visit(b) for b in node.body] - self.symbolTable.popScope() - - 
self.createInvariantForLoop(totalSize, - bodyBuilder, - elseStmts=node.orelse) - return + if not getValues: + self.visit(node.iter) + iterable = self.popValue() - self.visit(node.iter) - assert len(self.valueStack) > 0 and len(self.valueStack) < 3 + if iterable: - totalSize = None - iterable = None - extractFunctor = None + isDecrementing = False + startVal = self.getConstantInt(0) + stepVal = self.getConstantInt(1) + relevantVals = getValues or (lambda iterVar, v: v) - # It could be that its the only value we have, - # in which case we know we have for var in iterable, - # but we could also have another value on the stack, - # the total size of the iterable, produced by range() / enumerate() - if len(self.valueStack) == 1: - # Get the iterable from the stack - iterable = self.ifPointerThenLoad(self.popValue()) # we currently handle `veq` and `stdvec` types if quake.VeqType.isinstance(iterable.type): size = quake.VeqType.getSize(iterable.type) if quake.VeqType.hasSpecifiedSize(iterable.type): - totalSize = self.getConstantInt(size) + endVal = self.getConstantInt(size) else: - totalSize = quake.VeqSizeOp(self.getIntegerType(64), - iterable).result + endVal = quake.VeqSizeOp(self.getIntegerType(), + iterable).result - def functor(iter, idx): - return quake.ExtractRefOp(self.getRefType(), - iter, - -1, - index=idx).result + def loadElement(iterVar): + val = quake.ExtractRefOp(self.getRefType(), + iterable, + -1, + index=iterVar).result + return relevantVals(iterVar, val) + + getValues = loadElement - extractFunctor = functor elif cc.StdvecType.isinstance(iterable.type): iterEleTy = cc.StdvecType.getElementType(iterable.type) - totalSize = cc.StdvecSizeOp(self.getIntegerType(), - iterable).result isBool = iterEleTy == self.getIntegerType(1) if isBool: iterEleTy = self.getIntegerType(8) + endVal = cc.StdvecSizeOp(self.getIntegerType(), iterable).result - def functor(iter, idxVal): + def loadElement(iterVar): elePtrTy = cc.PointerType.get(iterEleTy) arrTy = 
cc.ArrayType.get(iterEleTy) ptrArrTy = cc.PointerType.get(arrTy) - vecPtr = cc.StdvecDataOp(ptrArrTy, iter).result + vecPtr = cc.StdvecDataOp(ptrArrTy, iterable).result eleAddr = cc.ComputePtrOp( - elePtrTy, vecPtr, [idxVal], + elePtrTy, vecPtr, [iterVar], DenseI32ArrayAttr.get([kDynamicPtrIndex], context=self.ctx)).result - result = cc.LoadOp(eleAddr).result + val = cc.LoadOp(eleAddr).result if isBool: - result = self.changeOperandToType( - self.getIntegerType(1), result) - return result + val = self.changeOperandToType(self.getIntegerType(1), + val) + return relevantVals(iterVar, val) - extractFunctor = functor + getValues = loadElement else: self.emitFatalError('{} iterable type not supported.', node) - else: - # In this case, we are coming from range() or enumerate(), - # and the iterable is a cc.array and the total size of the - # array is on the stack, pop it here - totalSize = self.popValue() - # Get the iterable from the stack - iterable = self.popValue() - - # Double check our types are right - assert cc.PointerType.isinstance(iterable.type) - arrayType = cc.PointerType.getElementType(iterable.type) - assert cc.ArrayType.isinstance(arrayType) - elementType = cc.ArrayType.getElementType(arrayType) - - def functor(iter, idx): - eleAddr = cc.ComputePtrOp( - cc.PointerType.get(elementType), iter, [idx], - DenseI32ArrayAttr.get([kDynamicPtrIndex], - context=self.ctx)).result - return cc.LoadOp(eleAddr).result - - extractFunctor = functor - - def bodyBuilder(iterVar): + def blockBuilder(iterVar, stmts): self.symbolTable.pushScope() - # we set the extract functor above, use it here - value = extractFunctor(iterable, iterVar) - self.__deconstructAssignment(node.target, value) - [self.visit(b) for b in node.body] + values = getValues(iterVar) + # We need to create proper assignments to the loop + # iteration variable(s) to have consistent behavior. 
+ assignNode = ast.Assign() + assignNode.targets = [node.target] + assignNode.value = values + assignNode.lineno = node.lineno + self.visit(assignNode) + [self.visit(b) for b in stmts] self.symbolTable.popScope() - self.createInvariantForLoop(totalSize, - bodyBuilder, - elseStmts=node.orelse) + self.createMonotonicForLoop( + lambda iterVar: blockBuilder(iterVar, node.body), + startVal=startVal, + stepVal=stepVal, + endVal=endVal, + isDecrementing=isDecrementing, + orElseBuilder=None if not node.orelse else + lambda iterVar: blockBuilder(iterVar, node.orelse)) def visit_While(self, node): self.controlHeight = self.controlHeight + 1 @@ -4276,94 +4352,75 @@ def actually_visit_While(self, node): """ Convert Python while statements into the equivalent CC `LoopOp`. """ - loop = cc.LoopOp([], [], BoolAttr.get(False)) - whileBlock = Block.create_at_start(loop.whileRegion, []) - with InsertionPoint(whileBlock): + + def evalCond(args): # BUG you cannot print MLIR values while building the cc `LoopOp` while region. 
# verify will get called, no terminator yet, CCOps.cpp:520 v = self.verbose self.verbose = False self.visit(node.test) - condition = self.popValue() - if self.getIntegerType(1) != condition.type: - # not equal to 0, then compare with 1 - condPred = IntegerAttr.get(self.getIntegerType(), 1) - condition = arith.CmpIOp(condPred, condition, - self.getConstantInt(0)).result - cc.ConditionOp(condition, []) + condition = self.__arithmetic_to_bool(self.popValue()) self.verbose = v + return condition - bodyBlock = Block.create_at_start(loop.bodyRegion, []) - with InsertionPoint(bodyBlock): - self.symbolTable.pushScope() - self.pushForBodyStack([]) - [self.visit(b) for b in node.body] - if not self.hasTerminator(bodyBlock): - cc.ContinueOp([]) - self.popForBodyStack() - self.symbolTable.popScope() - - stepBlock = Block.create_at_start(loop.stepRegion, []) - with InsertionPoint(stepBlock): - cc.ContinueOp([]) - - if node.orelse: - elseBlock = Block.create_at_start(loop.elseRegion, []) - with InsertionPoint(elseBlock): - self.symbolTable.pushScope() - for stmt in node.orelse: - self.visit(stmt) - if not self.hasTerminator(elseBlock): - cc.ContinueOp(elseBlock.arguments) - self.symbolTable.popScope() + self.createForLoop([], lambda _: [self.visit(b) for b in node.body], [], + evalCond, lambda _: [], None if not node.orelse else + lambda _: [self.visit(stmt) for stmt in node.orelse]) def visit_BoolOp(self, node): """ Convert boolean operations into equivalent MLIR operations using the Arith Dialect. """ - shortCircuitWhenTrue = isinstance(node.op, ast.Or) if isinstance(node.op, ast.And) or isinstance(node.op, ast.Or): + # Visit the LHS and pop the value # Note we want any `mz(q)` calls to push their # result value to the stack, so we set a non-None # variable name here. 
self.currentAssignVariableName = '' self.visit(node.values[0]) - lhs = self.popValue() - zero = self.getConstantInt(0, IntegerType(lhs.type).width) + cond = self.__arithmetic_to_bool(self.popValue()) + + def process_boolean_op(prior, values): - cond = arith.CmpIOp( - self.getIntegerAttr(self.getIntegerType(), - 1 if shortCircuitWhenTrue else 0), lhs, - zero).result + if len(values) == 0: + return prior - ifOp = cc.IfOp([cond.type], cond, []) - thenBlock = Block.create_at_start(ifOp.thenRegion, []) - with InsertionPoint(thenBlock): if isinstance(node.op, ast.And): - constantFalse = arith.ConstantOp(cond.type, - BoolAttr.get(False)) - cc.ContinueOp([constantFalse]) - else: - cc.ContinueOp([cond]) + prior = arith.XOrIOp(prior, self.getConstantInt(1, + 1)).result + + ifOp = cc.IfOp([prior.type], prior, []) + thenBlock = Block.create_at_start(ifOp.thenRegion, []) + with InsertionPoint(thenBlock): + if isinstance(node.op, ast.And): + constantFalse = arith.ConstantOp( + prior.type, BoolAttr.get(False)) + cc.ContinueOp([constantFalse]) + else: + cc.ContinueOp([prior]) - elseBlock = Block.create_at_start(ifOp.elseRegion, []) - with InsertionPoint(elseBlock): - self.symbolTable.pushScope() - self.pushIfStmtBlockStack() - self.visit(node.values[1]) - rhs = self.popValue() - cc.ContinueOp([rhs]) - self.popIfStmtBlockStack() - self.symbolTable.popScope() + elseBlock = Block.create_at_start(ifOp.elseRegion, []) + with InsertionPoint(elseBlock): + self.symbolTable.pushScope() + self.pushIfStmtBlockStack() + self.visit(values[0]) + rhs = process_boolean_op( + self.__arithmetic_to_bool(self.popValue()), values[1:]) + cc.ContinueOp([rhs]) + self.popIfStmtBlockStack() + self.symbolTable.popScope() + return ifOp.result + + self.pushValue(process_boolean_op(cond, node.values[1:])) # Reset the assign variable name self.currentAssignVariableName = None - - self.pushValue(ifOp.result) return + self.emitFatalError(f'unsupported boolean expression {node.op}', node) + def visit_Compare(self, 
node): """ Visit while loop compare operations and translate to equivalent MLIR. @@ -4374,14 +4431,6 @@ def visit_Compare(self, node): self.emitFatalError("only single comparators are supported", node) iTy = self.getIntegerType() - - if isinstance(node.left, ast.Name): - self.debug_msg(lambda: f'[(Inline) Visit Name]', node.left) - if node.left.id not in self.symbolTable: - self.emitFatalError( - f"{node.left.id} was not initialized before use in compare expression", - node) - self.visit(node.left) left = self.popValue() self.visit(node.comparators[0]) @@ -4400,6 +4449,9 @@ def convert_arithmetic_types(item1, item2): allowDemotion=False) return item1, item2 + # To understand the integer attributes used here (the predicates) + # see `arith::CmpIPredicate` and `arith::CmpFPredicate`. + def compare_equality(item1, item2): # TODO: the In/NotIn case should be recursive such # that we can search for a list in a list of lists. @@ -4493,31 +4545,41 @@ def compare_equality(item1, item2): if isinstance(op, (ast.In, ast.NotIn)): # Type validation and vector initialization - if not (cc.StdvecType.isinstance(right.type) or - cc.ArrayType.isinstance(right.type)): + if not cc.StdvecType.isinstance(right.type): self.emitFatalError( "Right operand must be a list/vector for 'in' comparison") + vectSize = cc.StdvecSizeOp(self.getIntegerType(), right).result # Loop setup i1_type = self.getIntegerType(1) + trueVal = self.getConstantInt(1, 1) accumulator = cc.AllocaOp(cc.PointerType.get(i1_type), TypeAttr.get(i1_type)).result - cc.StoreOp(self.getConstantInt(0, 1), accumulator) + cc.StoreOp(trueVal, accumulator) # Element comparison loop - def check_element(idx): - element = self.__load_vector_element(right, idx) + def check_element(args): + element = self.__load_vector_element(right, args[0]) compRes = compare_equality(left, element) + neqRes = arith.XOrIOp(compRes, trueVal).result current = cc.LoadOp(accumulator).result - cc.StoreOp(arith.OrIOp(current, compRes), accumulator) + 
cc.StoreOp(arith.AndIOp(current, neqRes), accumulator) - self.createInvariantForLoop(self.__get_vector_size(right), - check_element) + def check_condition(args): + notListEnd = arith.CmpIOp(IntegerAttr.get(iTy, 2), args[0], + vectSize).result + notFound = cc.LoadOp(accumulator).result + return arith.AndIOp(notListEnd, notFound).result + + # Break early if we found the item + self.createForLoop( + [self.getIntegerType()], check_element, + [self.getConstantInt(0)], check_condition, lambda args: + [arith.AddIOp(args[0], self.getConstantInt(1)).result]) final_result = cc.LoadOp(accumulator).result - if isinstance(op, ast.NotIn): - final_result = arith.XOrIOp(final_result, - self.getConstantInt(1, 1)).result + if isinstance(op, ast.In): + final_result = arith.XOrIOp(final_result, trueVal).result self.pushValue(final_result) return @@ -4538,26 +4600,7 @@ def actually_visit_If(self, node): self.currentAssignVariableName = '' self.visit(node.test) self.currentAssignVariableName = None - - condition = self.popValue() - condition = self.ifPointerThenLoad(condition) - - # To understand the integer attributes used here (the predicates) - # see `arith::CmpIPredicate` and `arith::CmpFPredicate`. 
- - if self.getIntegerType(1) != condition.type: - if IntegerType.isinstance(condition.type): - condPred = IntegerAttr.get(self.getIntegerType(), 1) - condition = arith.CmpIOp(condPred, condition, - self.getConstantInt(0)).result - - elif F64Type.isinstance(condition.type): - condPred = IntegerAttr.get(self.getIntegerType(), 13) - condition = arith.CmpFOp(condPred, condition, - self.getConstantFloat(0)).result - else: - self.emitFatalError("condition cannot be converted to bool", - node) + condition = self.__arithmetic_to_bool(self.popValue()) ifOp = cc.IfOp([], condition, []) thenBlock = Block.create_at_start(ifOp.thenRegion, []) @@ -4590,38 +4633,79 @@ def visit_Return(self, node): self.visit(node.value) self.walkingReturnNode = False - if len(self.valueStack) == 0: + if self.valueStack.currentNumValues == 0: return - result = self.ifPointerThenLoad(self.popValue()) - result = self.ifPointerThenLoad(result) result = self.changeOperandToType(self.knownResultType, - result, + self.popValue(), allowDemotion=True) - if cc.StdvecType.isinstance(result.type): + # Generally, anything that was allocated locally on the stack + # needs to be copied to the heap to ensure it lives past the + # the function. This holds recursively; if we have a struct + # that contains a list, then the list data may need to be + # copied if it was allocated inside the function. 
+ def copy_list_to_heap(value): symName = '__nvqpp_vectorCopyCtor' load_intrinsic(self.module, symName) - eleTy = cc.StdvecType.getElementType(result.type) + elemTy = cc.StdvecType.getElementType(value.type) + if elemTy == self.getIntegerType(1): + elemTy = self.getIntegerType(8) ptrTy = cc.PointerType.get(self.getIntegerType(8)) arrTy = cc.ArrayType.get(self.getIntegerType(8)) ptrArrTy = cc.PointerType.get(arrTy) - resBuf = cc.StdvecDataOp(ptrArrTy, result).result - # TODO Revisit this calculation - byteWidth = 16 if ComplexType.isinstance(eleTy) else 8 - eleSize = self.getConstantInt(byteWidth) - dynSize = cc.StdvecSizeOp(self.getIntegerType(), result).result + resBuf = cc.StdvecDataOp(ptrArrTy, value).result + eleSize = cc.SizeOfOp(self.getIntegerType(), + TypeAttr.get(elemTy)).result + dynSize = cc.StdvecSizeOp(self.getIntegerType(), value).result resBuf = cc.CastOp(ptrTy, resBuf) heapCopy = func.CallOp([ptrTy], symName, [resBuf, dynSize, eleSize]).result - res = cc.StdvecInitOp(result.type, heapCopy, length=dynSize).result - if self.controlHeight > 0: - cc.UnwindReturnOp([res]) - return - func.ReturnOp([res]) - return + return cc.StdvecInitOp(value.type, heapCopy, length=dynSize).result + + rootVal = self.__get_root_value(node.value) + if rootVal and self.isFunctionArgument(rootVal): + # If we allow assigning a value that contains a list to an item + # of a function argument (which we do with the exceptions + # commented below), then we necessarily need to make a copy when + # we return function arguments, or function argument elements, + # that contain lists, since we have to assume that their data may + # be allocated on the stack. However, this leads to incorrect + # behavior if a returned list was indeed caller-side allocated + # (and should correspondingly have been returned by reference). 
+ # Rather than preventing that lists in function arguments can be + # updated, we instead ensure that lists contained in function + # arguments stay recognizable as such, and prevent that function + # arguments that contain list are returned. + # NOTE: Why is seems straightforward in principle to fail only + # for when we return *inner* lists of function arguments, this + # is still not a good option for two reasons: + # 1) Even if we return the reference to the outer list correctly, + # any caller-side assignment of the return value would no longer + # be recognizable as being the same reference given as argument, + # which is a problem if the list was an argument to the caller. + # I.e. while this works for one function indirection, it does + # not work for two (see assignment tests). + # 2) To ensure that we don't have any memory leaks, we copy any + # lists returned from function calls to the stack. This copy (as + # of the time of writing this) results in a segfault when the + # list is not on the heap. As it is, we hence indeed have to copy + # every returned list to the heap, followed by a copy to the stack + # in the caller. Subsequent optimization passes should largely + # eliminate unnecessary copies. + if (self.containsList(result.type)): + self.emitFatalError( + "return value must not contain a list that is a function argument or an item in a function argument" + + + " - for device kernels, lists passed as arguments will be modified in place; " + + + "remove the return value or use .copy(deep) to create a copy", + node) + else: + result = self.__migrateLists(result, copy_list_to_heap) if self.controlHeight > 0: + # We are in an inner scope, release all scopes before returning cc.UnwindReturnOp([result]) return func.ReturnOp([result]) @@ -4630,24 +4714,12 @@ def visit_Tuple(self, node): """ Map tuples in the Python AST to equivalents in MLIR. 
""" - # FIXME: The handling of tuples in Python likely needs to be examined carefully; - # The corresponding issue to clarify the expected behavior is - # https://github.com/NVIDIA/cuda-quantum/issues/3031 - # I re-enabled the tuple support in kernel signatures, given that we were already - # allowing the use of data classes everywhere, and supporting tuple use within a - # kernel. It hence seems that any issues with tuples also apply to named structs. - self.generic_visit(node) - - elementValues = [self.popValue() for _ in range(len(node.elts))] + [self.visit(el) for el in node.elts] + elementValues = self.popAllValues(len(node.elts)) elementValues.reverse() - - # We do not store structs of pointers - elementValues = [ - cc.LoadOp(ele).result - if cc.PointerType.isinstance(ele.type) else ele - for ele in elementValues - ] + for idx, value in enumerate(elementValues): + self.__validate_container_entry(value, node.elts[idx]) structTys = [v.type for v in elementValues] structTy = mlirTryCreateStructType(structTys, context=self.ctx) @@ -4658,27 +4730,23 @@ def visit_Tuple(self, node): if quake.StruqType.isinstance(structTy): self.pushValue(quake.MakeStruqOp(structTy, elementValues).result) - else: - stackSlot = cc.AllocaOp(cc.PointerType.get(structTy), - TypeAttr.get(structTy)).result - - # loop over each type and `compute_ptr` / store - for i, ty in enumerate(structTys): - eleAddr = cc.ComputePtrOp( - cc.PointerType.get(ty), stackSlot, [], - DenseI32ArrayAttr.get([i], context=self.ctx)).result - cc.StoreOp(elementValues[i], eleAddr) - - self.pushValue(stackSlot) return + struct = cc.UndefOp(structTy) + for idx, element in enumerate(elementValues): + struct = cc.InsertValueOp( + structTy, struct, element, + DenseI64ArrayAttr.get([idx], context=self.ctx)).result + self.pushValue(struct) + def visit_UnaryOp(self, node): """ Map unary operations in the Python AST to equivalents in MLIR. 
""" - self.generic_visit(node) + self.visit(node.operand) operand = self.popValue() + # Handle qubit negations if isinstance(node.op, ast.Invert): if quake.RefType.isinstance(operand.type): @@ -4713,13 +4781,9 @@ def visit_UnaryOp(self, node): return if isinstance(node.op, ast.Not): - if not IntegerType.isinstance(operand.type): - self.emitFatalError("UnaryOp Not() on non-integer value.", node) - - zero = self.getConstantInt(0, IntegerType(operand.type).width) self.pushValue( - arith.CmpIOp(IntegerAttr.get(self.getIntegerType(), 0), operand, - zero).result) + arith.XOrIOp(self.__arithmetic_to_bool(operand), + self.getConstantInt(1, 1)).result) return self.emitFatalError("unhandled UnaryOp.", node) @@ -4758,9 +4822,6 @@ def __process_binary_op(self, left, right, nodeType): MLIR. This method handles arithmetic operations between values. """ - left = self.ifPointerThenLoad(left) - right = self.ifPointerThenLoad(right) - # type promotion for anything except pow to match Python behavior if not issubclass(nodeType, ast.Pow): superiorTy = self.__get_superior_type(left.type, right.type) @@ -4934,36 +4995,45 @@ def visit_BinOp(self, node): MLIR. This method handles arithmetic operations between values. """ - # Get the left and right parts of this expression self.visit(node.left) left = self.popValue() self.visit(node.right) right = self.popValue() + # pushes to the value stack self.__process_binary_op(left, right, type(node.op)) def visit_AugAssign(self, node): """ Visit augment-assign operations (e.g. +=). 
""" - target = None - if isinstance(node.target, - ast.Name) and node.target.id in self.symbolTable: - self.debug_msg(lambda: f'[(Inline) Visit Name]', node.target) - target = self.symbolTable[node.target.id] - else: + self.pushPointerValue = True + self.visit(node.target) + self.pushPointerValue = False + target = self.popValue() + + if not cc.PointerType.isinstance(target.type): self.emitFatalError( "augment-assign target variable is not defined or " "cannot be assigned to.", node) self.visit(node.value) value = self.popValue() - loaded = cc.LoadOp(target).result - self.__process_binary_op(loaded, value, type(node.op)) + # NOTE: `aug-assign` is usually defined as producing a value, + # which we are not doing here. However, if this produces + # a value, then we need to start worrying that arbitrary + # expressions might contain assignments, which would require + # updates to the bridge in a bunch of places and add some + # complexity. We hence effectively disallow using + # any kind of assignment as expression. 
+ self.valueStack.pushFrame() + self.__process_binary_op(loaded, value, type(node.op)) + self.valueStack.popFrame() res = self.popValue() + if res.type != loaded.type: self.emitFatalError( "augment-assign must not change the variable type", node) @@ -4994,32 +5064,65 @@ def visit_Name(self, node): if node.id in self.symbolTable: value = self.symbolTable[node.id] - if cc.PointerType.isinstance(value.type): - eleTy = cc.PointerType.getElementType(value.type) - if cc.ArrayType.isinstance(eleTy): - self.pushValue(value) - return - # Retain `ptr` - if IntegerType.isinstance(eleTy) and IntegerType( - eleTy).width == 8: - self.pushValue(value) - return - if cc.StdvecType.isinstance(eleTy): - self.pushValue(value) - return - if cc.StateType.isinstance(eleTy): - self.pushValue(value) - return - loaded = cc.LoadOp(value).result - self.pushValue(loaded) - elif (cc.CallableType.isinstance(value.type) and - not BlockArgument.isinstance(value)): + if (self.pushPointerValue or + not cc.PointerType.isinstance(value.type)): + self.pushValue(value) return - else: - self.pushValue(self.symbolTable[node.id]) + + eleTy = cc.PointerType.getElementType(value.type) + + # Retain state types as pointers + # (function arguments of `StateType` are passed as pointers) + if cc.StateType.isinstance(eleTy): + self.pushValue(value) + return + + loaded = cc.LoadOp(value).result + self.pushValue(loaded) return + # Check if a nonlocal symbol, and process it. + value = recover_value_of_or_none(node.id, None) + if is_recovered_value_ok(value): + from .kernel_decorator import isa_kernel_decorator + from .kernel_builder import isa_dynamic_kernel + if isa_kernel_decorator(value) or isa_dynamic_kernel(value): + # Not a data variable. Symbol bound to kernel object. This case is + # handled elsewhere. + return + + # node.id is a nonlocal symbol. Lift it to a formal argument. 
+ self.dependentCaptureVars[node.id] = value + # If `node.id` is in liftedArgs, it should already + # be in the symbol table and processed. + assert not node.id in self.liftedArgs + self.appendToLiftedArgs(node.id) + + # Append as a new argument + argTy = mlirTypeFromPyType(type(value), self.ctx, argInstance=value) + mlirVal = cudaq_runtime.appendKernelArgument( + self.kernelFuncOp, argTy) + self.argTypes.append(argTy) + + assignNode = ast.Assign() + assignNode.targets = [node] + assignNode.value = mlirVal + assignNode.lineno = node.lineno + self.visit_Assign(assignNode) + + self.visit(node) + self.pushValue( + self.popValue()) # propagating the pushed value through + return + ''' + if (node.id in globalKernelRegistry or + node.id in globalRegisteredOperations): + # FIXME: newly changed to node.id in globalRegisteredOperations only?? + return + ''' if node.id in globalRegisteredOperations: + # FIXME: WAS + # (node.id in globalKernelRegistry or node.id in globalRegisteredOperations): return if (self.__isUnitaryGate(node.id) or self.__isMeasurementGate(node.id)): @@ -5033,66 +5136,23 @@ def visit_Name(self, node): self.pushValue(self.getFloatType()) return - # Check if a nonlocal symbol, and process it. - value = recover_value_of_or_none(node.id, None) - if not is_recovered_value_ok(value): - # Throw an exception for the case that the name is not in the - # kernel's symbol table and not a valid nonlocal symbol. - self.emitFatalError( - f"Invalid variable name requested - '{node.id}' is not defined " - f"in the quantum kernel where it is used.", node) - - from .kernel_decorator import isa_kernel_decorator - from .kernel_builder import isa_dynamic_kernel - if isa_kernel_decorator(value) or isa_dynamic_kernel(value): - # Not a data variable. Symbol bound to kernel object. This case is - # handled elsewhere. - return - - if self.__isNoiseChannelClass(value): - # This is a custom noise model, not a builtin one (see `visit_Attribute`). - # Handle it in the same way. 
- if not hasattr(value, 'num_parameters'): - self.emitFatalError( - 'apply_noise kraus channels must have `num_parameters` constant class attribute specified.' - ) - self.pushValue(self.getConstantInt(value.num_parameters)) - self.pushValue(self.getConstantInt(hash(value))) - return - - # node.id is a nonlocal symbol. Lift it to a formal argument. - self.dependentCaptureVars[node.id] = value - if node.id in self.liftedArgs: - self.pushValue(self.symbolTable[node.id]) - return - - self.appendToLiftedArgs(node.id) - # Append as a new argument - argTy = mlirTypeFromPyType(type(value), self.ctx, argInstance=value) - mlirVal = cudaq_runtime.appendKernelArgument(self.kernelFuncOp, argTy) - self.argTypes.append(argTy) - - # Save the lifted argument as a local variable. - with InsertionPoint.at_block_begin(self.entry): - stackSlot = cc.AllocaOp(cc.PointerType.get(mlirVal.type), - TypeAttr.get(mlirVal.type)).result - cc.StoreOp(mlirVal, stackSlot) - self.symbolTable.add(node.id, stackSlot, 0) - self.pushValue(stackSlot) - return + self.emitFatalError( + f"Invalid variable name requested - '{node.id}' is not defined within the scope it is used in.", + node) def compile_to_mlir(uniqueId, astModule, capturedDataStorage: CapturedDataStorage, **kwargs): """ - Compile the given Python AST Module for the CUDA-Q kernel FunctionDef to an - MLIR `ModuleOp`. Return both the `ModuleOp` and the list of function - argument types as MLIR Types. - - This function will first check to see if there are any dependent kernels - that are required by this function. If so, those kernels will also be - compiled into the `ModuleOp`. The AST will be stored later for future - potential dependent kernel lookups. + Compile the given Python AST Module for the CUDA-Q + kernel FunctionDef to an MLIR `ModuleOp`. + Return both the `ModuleOp` and the list of function + argument types as MLIR Types. 
+ + This function will first check to see if there are any dependent + kernels that are required by this function. If so, those kernels + will also be compiled into the `ModuleOp`. The AST will be stored + later for future potential dependent kernel lookups. """ global globalAstRegistry @@ -5102,7 +5162,6 @@ def compile_to_mlir(uniqueId, astModule, parentVariables = kwargs[ 'parentVariables'] if 'parentVariables' in kwargs else {} preCompile = kwargs['preCompile'] if 'preCompile' in kwargs else False - shortName = kwargs['kernelName'] if 'kernelName' in kwargs else None kernelModuleName = kwargs[ 'kernelModuleName'] if 'kernelModuleName' in kwargs else None diff --git a/python/cudaq/kernel/utils.py b/python/cudaq/kernel/utils.py index 55479cc5dfd..747fdbdbd22 100644 --- a/python/cudaq/kernel/utils.py +++ b/python/cudaq/kernel/utils.py @@ -14,7 +14,7 @@ import traceback import importlib import numpy as np -from typing import get_origin, Callable, List +from typing import get_origin, get_args, Callable, List import types from cudaq.mlir.execution_engine import ExecutionEngine from cudaq.mlir.dialects import func @@ -128,7 +128,7 @@ def resolve_qualified_symbol(y): decorator name. """ parts = y.split('.') - for i in range(len(parts), 0, -1): + for i in range(len(parts)): modName = ".".join(parts[:i]) try: mod = importlib.import_module(modName) @@ -281,6 +281,8 @@ def isQuantumType(ty): numQuantumMembers = sum((isQuantumType(t) for t in mlirEleTypes)) if numQuantumMembers == 0: + if any((cc.PointerType.isinstance(t) for t in mlirEleTypes)): + return None return cc.StructType.getNamed(name, mlirEleTypes, context=context) if numQuantumMembers != len(mlirEleTypes) or \ any((quake.StruqType.isinstance(t) for t in mlirEleTypes)): @@ -346,18 +348,24 @@ def emitFatalErrorOverride(msg): f"Callable type must have signature specified ({ast.unparse(annotation) if hasattr(ast, 'unparse') else annotation})." 
) - if hasattr(annotation.slice, 'elts'): - firstElement = annotation.slice.elts[0] + if hasattr(annotation.slice, 'elts') and len( + annotation.slice.elts) == 2: + args = annotation.slice.elts[0] + ret = annotation.slice.elts[1] elif hasattr(annotation.slice, 'value') and hasattr( - annotation.slice.value, 'elts'): - firstElement = annotation.slice.value.elts[0] + annotation.slice.value, 'elts') and len( + annotation.slice.value.elts) == 2: + args = annotation.slice.value.elts[0] + ret = annotation.slice.value.elts[1] else: localEmitFatalError( f"Unable to get list elements when inferring type from annotation ({ast.unparse(annotation) if hasattr(ast, 'unparse') else annotation})." ) - argTypes = [ - mlirTypeFromAnnotation(a, ctx) for a in firstElement.elts - ] + argTypes = [mlirTypeFromAnnotation(a, ctx) for a in args.elts] + if not isinstance(ret, ast.Constant) or ret.value: + localEmitFatalError( + "passing kernels as arguments that return a value is not currently supported" + ) return cc.CallableType.get(ctx, argTypes, []) if isinstance(annotation, @@ -529,12 +537,11 @@ def mlirTypeFromPyType(argType, ctx, **kwargs): return cc.PointerType.get(cc.StateType.get(ctx), ctx) if get_origin(argType) == list: - result = re.search(r'ist\[(.*)\]', str(argType)) - eleTyName = result.group(1) + pyEleTy = get_args(argType) + if len(pyEleTy) == 1: + eleTy = mlirTypeFromPyType(pyEleTy[0], ctx) + return cc.StdvecType.get(eleTy, ctx) argType = list - inst = pyInstanceFromName(eleTyName) - if (inst != None): - kwargs['argInstance'] = [inst] if argType in [list, np.ndarray, List]: if 'argInstance' not in kwargs: @@ -544,8 +551,8 @@ def mlirTypeFromPyType(argType, ctx, **kwargs): return cc.StdvecType.get(mlirTypeFromPyType(float, ctx), ctx) argInstance = kwargs['argInstance'] - argTypeToCompareTo = (kwargs['argTypeToCompareTo'] - if 'argTypeToCompareTo' in kwargs else None) + argTypeToCompareTo = kwargs[ + 'argTypeToCompareTo'] if 'argTypeToCompareTo' in kwargs else None if 
len(argInstance) == 0: if argTypeToCompareTo == None: @@ -567,30 +574,22 @@ def mlirTypeFromPyType(argType, ctx, **kwargs): ctx) if get_origin(argType) == tuple: - result = re.search(r'uple\[(?P.*)\]', str(argType)) - eleTyNames = result.group('names') eleTypes = [] - while eleTyNames != None: - result = re.search(r'(?P.*),\s*(?P.*)', eleTyNames) - eleTyName = result.group('name') if result != None else eleTyNames - eleTyNames = result.group('names') if result != None else None - pyInstance = pyInstanceFromName(eleTyName) - if pyInstance == None: - emitFatalError(f'Invalid tuple element type ({eleTyName})') - eleTypes.append(mlirTypeFromPyType(type(pyInstance), ctx)) - eleTypes.reverse() + for pyEleTy in get_args(argType): + eleTypes.append(mlirTypeFromPyType(pyEleTy, ctx)) tupleTy = mlirTryCreateStructType(eleTypes, context=ctx) if tupleTy is None: - emitFatalError("Hybrid quantum-classical data types and nested " - "quantum structs are not allowed.") + emitFatalError( + "Hybrid quantum-classical data types and nested quantum structs are not allowed." + ) return tupleTy if (argType == tuple): argInstance = kwargs['argInstance'] - argTypeToCompareTo = (kwargs['argTypeToCompareTo'] - if 'argTypeToCompareTo' in kwargs else None) if argInstance == None or (len(argInstance) == 0): emitFatalError(f'Cannot infer runtime argument type for {argType}') + argTypeToCompareTo = (kwargs['argTypeToCompareTo'] + if 'argTypeToCompareTo' in kwargs else None) if argTypeToCompareTo is None: eleTypes = [ mlirTypeFromPyType(type(ele), ctx) for ele in argInstance @@ -599,8 +598,9 @@ def mlirTypeFromPyType(argType, ctx, **kwargs): else: tupleTy = argTypeToCompareTo if tupleTy is None: - emitFatalError("Hybrid quantum-classical data types and nested " - "quantum structs are not allowed.") + emitFatalError( + "Hybrid quantum-classical data types and nested quantum structs are not allowed." 
+ ) return tupleTy if argType == qvector or argType == qreg or argType == qview: diff --git a/python/runtime/cudaq/platform/py_alt_launch_kernel.cpp b/python/runtime/cudaq/platform/py_alt_launch_kernel.cpp index ef7b2f1c178..b94a8f06218 100644 --- a/python/runtime/cudaq/platform/py_alt_launch_kernel.cpp +++ b/python/runtime/cudaq/platform/py_alt_launch_kernel.cpp @@ -982,6 +982,17 @@ static MlirModule synthesizeKernel(py::object kernel, py::args runtimeArgs) { pm.addNestedPass(cudaq::opt::createLoopUnroll()); pm.addNestedPass(createCanonicalizerPass()); pm.addPass(createSymbolDCEPass()); + + std::string error_msg; + mlir::DiagnosticEngine &engine = context->getDiagEngine(); + auto handlerId = engine.registerHandler( + [&error_msg](mlir::Diagnostic &diag) -> mlir::LogicalResult { + if (diag.getSeverity() == mlir::DiagnosticSeverity::Error) { + error_msg += diag.str(); + return mlir::failure(false); + } + return mlir::failure(); + }); DefaultTimingManager tm; tm.setEnabled(cudaq::isTimingTagEnabled(cudaq::TIMING_JIT_PASSES)); auto timingScope = tm.getRootScope(); // starts the timer @@ -990,30 +1001,47 @@ static MlirModule synthesizeKernel(py::object kernel, py::args runtimeArgs) { context->disableMultithreading(); if (enablePrintMLIREachPass) pm.enableIRPrinting(); - if (failed(pm.run(cloned))) + if (failed(pm.run(cloned))) { + engine.eraseHandler(handlerId); throw std::runtime_error( - "cudaq::builder failed to JIT compile the Quake representation."); + "failed to JIT compile the Quake representation\n" + error_msg); + } timingScope.stop(); + engine.eraseHandler(handlerId); return wrap(cloned); } static void executeMLIRPassManager(ModuleOp mod, PassManager &pm) { auto enablePrintMLIREachPass = cudaq::getEnvBool("CUDAQ_MLIR_PRINT_EACH_PASS", false); - + auto context = mod.getContext(); if (enablePrintMLIREachPass) { - mod.getContext()->disableMultithreading(); + context->disableMultithreading(); pm.enableIRPrinting(); } + std::string error_msg; + 
mlir::DiagnosticEngine &engine = context->getDiagEngine(); + auto handlerId = engine.registerHandler( + [&error_msg](mlir::Diagnostic &diag) -> mlir::LogicalResult { + if (diag.getSeverity() == mlir::DiagnosticSeverity::Error) { + error_msg += diag.str(); + return mlir::failure(false); + } + return mlir::failure(); + }); DefaultTimingManager tm; tm.setEnabled(cudaq::isTimingTagEnabled(cudaq::TIMING_JIT_PASSES)); auto timingScope = tm.getRootScope(); // starts the timer pm.enableTiming(timingScope); // do this right before pm.run - if (failed(pm.run(mod))) + + if (failed(pm.run(mod))) { + engine.eraseHandler(handlerId); throw std::runtime_error( - "cudaq::builder failed to JIT compile the Quake representation."); + "failed to JIT compile the Quake representation\n" + error_msg); + } timingScope.stop(); + engine.eraseHandler(handlerId); } static ModuleOp cleanLowerToCodegenKernel(ModuleOp mod, diff --git a/python/tests/custom/test_custom_operations.py b/python/tests/custom/test_custom_operations.py index a4ef6d64f1a..8439db84d1a 100644 --- a/python/tests/custom/test_custom_operations.py +++ b/python/tests/custom/test_custom_operations.py @@ -180,6 +180,7 @@ def test_bad_attribute(): cudaq.register_operation("custom_s", np.array([1, 0, 0, 1j])) with pytest.raises(Exception) as error: + @cudaq.kernel def kernel(): q = cudaq.qubit() @@ -220,20 +221,21 @@ def test_invalid_ctrl(): cudaq.register_operation("custom_x", np.array([0, 1, 1, 0])) with pytest.raises(RuntimeError) as error: + @cudaq.kernel def bell(): q = cudaq.qubit() custom_x.ctrl(q) bell.compile() - assert 'controlled operation requested without any control argument(s)' in repr( - error) + assert 'missing value' in repr(error) def test_bug_2452(): cudaq.register_operation("custom_i", np.array([1, 0, 0, 1])) with pytest.raises(RuntimeError) as error: + @cudaq.kernel def kernel1(): qubits = cudaq.qvector(2) @@ -259,14 +261,14 @@ def kernel2(): -1])) with pytest.raises(RuntimeError) as error: + @cudaq.kernel def 
kernel3(): qubits = cudaq.qvector(2) custom_cz(qubits) cudaq.sample(kernel3) - assert 'invalid number of arguments (1) passed to custom_cz (requires 2 arguments)' in repr( - error) + assert 'missing value' in repr(error) # leave for gdb debugging diff --git a/python/tests/kernel/test_control_negations.py b/python/tests/kernel/test_control_negations.py index 325b51653f4..4cdb75a3dfc 100644 --- a/python/tests/kernel/test_control_negations.py +++ b/python/tests/kernel/test_control_negations.py @@ -28,6 +28,15 @@ def control_simple_gate(): counts = cudaq.sample(control_simple_gate) assert counts["01"] == 1000 + @cudaq.kernel + def multi_control_simple_gate(): + c, q = cudaq.qvector(4), cudaq.qubit() + x(c[0], c[3]) + cx(c[0], ~c[1], ~c[2], c[3], q) + + counts = cudaq.sample(multi_control_simple_gate) + assert counts["10011"] == 1000 + @cudaq.kernel def control_rotation_gate(): c, q = cudaq.qubit(), cudaq.qubit() @@ -37,9 +46,19 @@ def control_rotation_gate(): counts = cudaq.sample(control_rotation_gate) assert counts["01"] == 1000 + @cudaq.kernel + def multi_control_rotation_gate(): + c, q = cudaq.qvector(4), cudaq.qubit() + x(c[0], c[3]) + cry(np.pi, c[0], ~c[1], ~c[2], c[3], q) + + counts = cudaq.sample(multi_control_rotation_gate) + assert counts["10011"] == 1000 + + # Note: u3, swap, and exp_pauli do not have a built-in + # c version at the time of writing this. + -# Note: u3, swap, and exp_pauli do not have a built-in c version at -# the time of writing this. 
def test_ctrl_attribute(): @cudaq.kernel @@ -60,16 +79,6 @@ def multi_control_simple_gate(): counts = cudaq.sample(multi_control_simple_gate) assert counts["10011"] == 1000 - @cudaq.kernel - def multi_control_simple_gate2(): - c, q = cudaq.qvector(4), cudaq.qubit() - x(c[0], c[3]) - c1, c2, c3, c4 = c - x.ctrl(c1, ~c2, ~c3, c4, q) - - counts = cudaq.sample(multi_control_simple_gate2) - assert counts["10011"] == 1000 - @cudaq.kernel def control_rotation_gate(): c, q = cudaq.qubit(), cudaq.qubit() @@ -88,16 +97,6 @@ def multi_control_rotation_gate(): counts = cudaq.sample(multi_control_rotation_gate) assert counts["10011"] == 1000 - @cudaq.kernel - def multi_control_rotation_gate2(): - c, q = cudaq.qvector(4), cudaq.qubit() - x(c[0], c[3]) - c1, c2, c3, c4 = c - ry.ctrl(np.pi, c1, ~c2, ~c3, c4, q) - - counts = cudaq.sample(multi_control_rotation_gate2) - assert counts["10011"] == 1000 - @cudaq.kernel def control_swap_gate(): c, q1, q2 = cudaq.qubit(), cudaq.qubit(), cudaq.qubit() @@ -118,17 +117,6 @@ def multi_control_swap_gate(): counts = cudaq.sample(multi_control_swap_gate) assert counts["100101"] == 1000 - @cudaq.kernel - def multi_control_swap_gate2(): - c, q1, q2 = cudaq.qvector(4), cudaq.qubit(), cudaq.qubit() - x(q1) - x(c[0], c[3]) - c1, c2, c3, c4 = c - swap.ctrl(c1, ~c2, ~c3, c4, q1, q2) - - counts = cudaq.sample(multi_control_swap_gate2) - assert counts["100101"] == 1000 - @cudaq.kernel def control_u3_gate(): c, q = cudaq.qubit(), cudaq.qubit() @@ -149,17 +137,6 @@ def multi_control_u3_gate(): counts = cudaq.sample(multi_control_u3_gate) assert counts["10101"] == 1000 - @cudaq.kernel - def multi_control_u3_gate2(): - c, q = cudaq.qvector(4), cudaq.qubit() - x(c[0], c[2]) - c1, c2, c3, c4 = c - t, p, l = np.pi, 0., 0. 
- u3.ctrl(t, p, l, c1, ~c2, c3, ~c4, q) - - counts = cudaq.sample(multi_control_u3_gate2) - assert counts["10101"] == 1000 - cudaq.register_operation("custom_x", np.array([0, 1, 1, 0])) @cudaq.kernel @@ -180,16 +157,6 @@ def multi_control_registered_operation(): counts = cudaq.sample(multi_control_registered_operation) assert counts["10101"] == 1000 - @cudaq.kernel - def multi_control_registered_operation2(): - c, q = cudaq.qvector(4), cudaq.qubit() - x(c[0], c[2]) - c1, c2, c3, c4 = c - custom_x.ctrl(c1, ~c2, c3, ~c4, q) - - counts = cudaq.sample(multi_control_registered_operation2) - assert counts["10101"] == 1000 - def test_cudaq_control(): @@ -204,11 +171,19 @@ def control_kernel(): cudaq.control(custom_x, c, q) counts = cudaq.sample(control_kernel) - print(counts) assert counts["01"] == 1000 - # Note: calling cudaq.control on a registered operation or on a built-in gate is - # not supported at the time of writing this + @cudaq.kernel + def multi_control_kernel(): + c, q = cudaq.qvector(4), cudaq.qubit() + x(c[0], c[3]) + cudaq.control(custom_x, c[0], ~c[1], ~c[2], c[3], q) + + counts = cudaq.sample(multi_control_kernel) + assert counts["10011"] == 1000 + + # Note: calling cudaq.control on a registered operation + # or on a built-in gate is not supported at the time of writing this def test_unsupported_calls(): @@ -251,7 +226,8 @@ def control_registered_operation(): cudaq.control(custom_x, c, q) cudaq.sample(control_registered_operation) - assert "unprocessed kernel reference not yet supported" in str(e.value) + assert "calling cudaq.control or cudaq.adjoint on a globally registered operation is not supported" in str( + e.value) with pytest.raises(RuntimeError) as e: @@ -262,7 +238,8 @@ def control_rotation_gate(): cudaq.control(ry, c, np.pi, q) cudaq.sample(control_rotation_gate) - assert "unprocessed kernel reference not yet supported" in str(e.value) + assert "calling cudaq.control or cudaq.adjoint on a built-in gate is not supported" in str( + e.value) with 
pytest.raises(RuntimeError) as e: @@ -274,8 +251,8 @@ def control_simple_gate(): cx(c, q) cudaq.sample(control_simple_gate) - assert ("unary operator ~ is only supported for values of type qubit" - in str(e.value)) + assert "unary operator ~ is only supported for values of type qubit" in str( + e.value) # leave for gdb debugging diff --git a/python/tests/kernel/test_direct_call_return_kernel.py b/python/tests/kernel/test_direct_call_return_kernel.py index 77cb6e9121a..3d20ef51699 100644 --- a/python/tests/kernel/test_direct_call_return_kernel.py +++ b/python/tests/kernel/test_direct_call_return_kernel.py @@ -203,7 +203,7 @@ def simple_list_bool_no_args() -> list[bool]: @cudaq.kernel def simple_list_bool(n: int, t: list[bool]) -> list[bool]: qubits = cudaq.qvector(n) - return t + return t.copy() result = simple_list_bool(2, [True, False, True]) assert result == [True, False, True] @@ -221,7 +221,7 @@ def simple_list_int_no_args() -> list[int]: @cudaq.kernel def simple_list_int(n: int, t: list[int]) -> list[int]: qubits = cudaq.qvector(n) - return t + return t.copy() result = simple_list_int(2, [-13, 5, 42]) assert result == [-13, 5, 42] @@ -239,7 +239,7 @@ def simple_list_int32_no_args() -> list[np.int32]: @cudaq.kernel def simple_list_int32(n: int, t: list[np.int32]) -> list[np.int32]: qubits = cudaq.qvector(n) - return t + return t.copy() result = simple_list_int32(2, [-13, 5, 42]) assert result == [-13, 5, 42] @@ -257,7 +257,7 @@ def simple_list_int16_no_args() -> list[np.int16]: @cudaq.kernel def simple_list_int16(n: int, t: list[np.int16]) -> list[np.int16]: qubits = cudaq.qvector(n) - return t + return t.copy() result = simple_list_int16(2, [-13, 5, 42]) assert result == [-13, 5, 42] @@ -275,7 +275,7 @@ def simple_list_int8_no_args() -> list[np.int8]: @cudaq.kernel def simple_list_int8(n: int, t: list[np.int8]) -> list[np.int8]: qubits = cudaq.qvector(n) - return t + return t.copy() result = simple_list_int8(2, [-13, 5, 42]) assert result == [-13, 5, 42] @@ 
-293,7 +293,7 @@ def simple_list_int64_no_args() -> list[np.int64]: @cudaq.kernel def simple_list_int64(n: int, t: list[np.int64]) -> list[np.int64]: qubits = cudaq.qvector(n) - return t + return t.copy() result = simple_list_int64(2, [-13, 5, 42]) assert result == [-13, 5, 42] @@ -311,7 +311,7 @@ def simple_list_float_no_args() -> list[float]: @cudaq.kernel def simple_list_float(n: int, t: list[float]) -> list[float]: qubits = cudaq.qvector(n) - return t + return t.copy() result = simple_list_float(2, [-13.2, 5.0, 42.99]) assert result == [-13.2, 5.0, 42.99] @@ -330,7 +330,7 @@ def simple_list_float32_no_args() -> list[np.float32]: @cudaq.kernel def simple_list_float32(n: int, t: list[np.float32]) -> list[np.float32]: qubits = cudaq.qvector(n) - return t + return t.copy() result = simple_list_float32(2, [-13.2, 5.0, 42.99]) assert is_close_array(result, [-13.2, 5.0, 42.99]) @@ -348,7 +348,7 @@ def simple_list_float64_no_args() -> list[np.float64]: @cudaq.kernel def simple_list_float64(n: int, t: list[np.float64]) -> list[np.float64]: qubits = cudaq.qvector(n) - return t + return t.copy() result = simple_list_float64(2, [-13.2, 5.0, 42.99]) assert result == [-13.2, 5.0, 42.99] @@ -383,8 +383,7 @@ def simple_tuple_int_float_assign( return t simple_tuple_int_float_assign(2, (-13, 42.3)) - assert 'indexing into tuple or dataclass must not modify value' in str( - e.value) + assert 'tuple value cannot be modified' in str(e.value) def test_return_tuple_float_int(): @@ -444,17 +443,13 @@ def simple_tuple_int_bool(n: int, t: tuple[int, bool]) -> tuple[int, bool]: def test_return_tuple_int32_bool(): - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def simple_tuple_int32_bool_no_args() -> tuple[np.int32, bool]: - return (-13, True) + @cudaq.kernel + def simple_tuple_int32_bool_no_args() -> tuple[np.int32, bool]: + return (-13, True) - simple_tuple_int32_bool_no_args() - # Note: it may make sense to support that if/since we support - # the cast for the individual 
item types. - assert 'cannot convert value of type !cc.struct<"tuple" {i64, i1}> to the requested type !cc.struct<"tuple" {i32, i1}>' in str( - e.value) + result = simple_tuple_int32_bool_no_args() + # See https://github.com/NVIDIA/cuda-quantum/issues/3524 + assert result == (-13, True) @cudaq.kernel def simple_tuple_int32_bool_no_args1() -> tuple[np.int32, bool]: @@ -475,10 +470,6 @@ def simple_tuple_int32_bool( return t result = simple_tuple_int32_bool(2, (np.int32(-13), True)) - # Note: printing the kernel correctly shows the MLIR - # values return type as "tuple" {i32, i1}, but we don't - # actually create numpy values even when these are requested - # in the signature. # See https://github.com/NVIDIA/cuda-quantum/issues/3524 assert result == (-13, True) @@ -594,7 +585,7 @@ class MyClass: @cudaq.kernel def test_return_dataclass(n: int, t: MyClass) -> MyClass: qubits = cudaq.qvector(n) - return t + return t.copy(deep=True) # TODO: Support recursive aggregate types in kernels. # result = test_return_dataclass(2, MyClass([0,1], 18)) diff --git a/python/tests/kernel/test_explicit_measurements.py b/python/tests/kernel/test_explicit_measurements.py index d14b764570e..5cba790e668 100644 --- a/python/tests/kernel/test_explicit_measurements.py +++ b/python/tests/kernel/test_explicit_measurements.py @@ -110,7 +110,7 @@ def kernel(theta: float, phi: float): assert len(seq[0]) == 20 # num qubits * num_rounds -def test_named_measurment(): +def test_named_measurement(): """ Test for while using "explicit measurements" mode, the sample result will not be saved to a mid-circuit measurement register. 
""" diff --git a/python/tests/kernel/test_kernel_features.py b/python/tests/kernel/test_kernel_features.py index 7ffe71bcb55..8d2f78ea8f5 100644 --- a/python/tests/kernel/test_kernel_features.py +++ b/python/tests/kernel/test_kernel_features.py @@ -350,13 +350,12 @@ def kernel(theta: float): @pytest.mark.parametrize('target', ['default', 'stim']) def test_dynamic_circuit(target): - """ - Test that we correctly sample circuits with mid-circuit measurements and - conditionals. - """ - save_target = cudaq.get_target() - if target != 'default': - cudaq.set_target(target) + """Test that we correctly sample circuits with + mid-circuit measurements and conditionals.""" + + if target == 'stim': + save_target = cudaq.get_target() + cudaq.set_target('stim') @cudaq.kernel def simple(): @@ -388,7 +387,8 @@ def simple(): assert '0' in c0 and '1' in c0 assert '00' in counts and '11' in counts - cudaq.set_target(save_target) + if target == 'stim': + cudaq.set_target(save_target) def test_teleport(): @@ -417,7 +417,8 @@ def teleport(): counts = cudaq.sample(teleport, shots_count=100) counts.dump() - # Note this is testing that we can provide the register name automatically + # Note this is testing that we can provide + # the register name automatically b0 = counts.get_register_counts('b0') assert '0' in b0 and '1' in b0 @@ -446,8 +447,10 @@ def callMe(): counts = cudaq.sample(callMe) assert len(counts) == 1 and '1' in counts - # This test is for a bug where by vqe_kernel thought kernel was a dependency - # because cudaq.kernel is a Call node in the AST. + # This test is for a bug where by + # vqe_kernel thought kernel was a + # dependency because cudaq.kernel + # is a Call node in the AST. 
@cudaq.kernel def kernel(): qubit = cudaq.qvector(2) @@ -638,8 +641,7 @@ def kernel(i: int) -> float: return t[i] ret = kernel() - assert ('non-constant subscript value on a tuple is not supported' - in repr(e)) + assert 'non-constant subscript value on a tuple is not supported' in repr(e) def test_list_creation(): @@ -742,9 +744,8 @@ def kernel(l: list[list[int]]): #FIXME: update broadcast detection logic to allow this case. # https://github.com/NVIDIA/cuda-quantum/issues/2895 counts = cudaq.sample(kernel, [[0, 1]]) - assert ( - 'Invalid runtime argument type. Argument of type list[int] was provided' - in repr(e)) + assert 'Invalid runtime argument type. Argument of type list[int] was provided' in repr( + e) def test_list_creation_with_cast(): @@ -844,6 +845,126 @@ def kernel6(): assert len(counts) == 1 assert '0101' in counts + @cudaq.kernel + def kernel7(): + qubits = cudaq.qvector(5) + r = [i for i in range(2, 5)] + for i in r: + x(qubits[i]) + + counts = cudaq.sample(kernel7) + assert len(counts) == 1 + assert '00111' in counts + + @cudaq.kernel + def kernel8(): + qubits = cudaq.qvector(5) + r = [i for i in range(2, 6, 2)] + for i in r: + x(qubits[i]) + + counts = cudaq.sample(kernel8) + assert len(counts) == 1 + assert '00101' in counts + + @cudaq.kernel + def kernel9(): + qubits = cudaq.qvector(5) + r = [i for i in range(6, 2, 2)] + for i in r: + x(qubits[i]) + + counts = cudaq.sample(kernel9) + assert len(counts) == 1 + assert '00000' in counts + + @cudaq.kernel + def kernel10(): + qubits = cudaq.qvector(5) + r = [i for i in range(3, 0, -2)] + for i in r: + x(qubits[i]) + + counts = cudaq.sample(kernel10) + assert len(counts) == 1 + assert '01010' in counts + + @cudaq.kernel + def kernel11(): + qubits = cudaq.qvector(5) + r = [i for i in range(-5, -2, -2)] + for i in r: + x(qubits[i]) + + counts = cudaq.sample(kernel11) + assert len(counts) == 1 + assert '00000' in counts + + @cudaq.kernel + def kernel12(): + qubits = cudaq.qvector(5) + r = [i for i in 
range(-1, -5, -2)] + for i in r: + x(qubits[-i]) + + counts = cudaq.sample(kernel12) + assert len(counts) == 1 + assert '01010' in counts + + @cudaq.kernel + def kernel13(): + qubits = cudaq.qvector(5) + r = [i for i in range(1, -4, -1)] + for i in r: + if i < 0: + x(qubits[-i]) + else: + x(qubits[i]) + + counts = cudaq.sample(kernel13) + assert len(counts) == 1 + assert '10110' in counts + + @cudaq.kernel + def kernel14(): + qubits = cudaq.qvector(5) + r = [i for i in range(-2, 6, 2)] + for i in r: + if i < 0: + x(qubits[-i]) + else: + x(qubits[i]) + + counts = cudaq.sample(kernel14) + assert len(counts) == 1 + assert '10001' in counts + + with pytest.raises(RuntimeError) as e: + + @cudaq.kernel + def kernel15(): + qubits = cudaq.qvector(5) + r = [i for i in range(1, 4, 0)] + for i in r: + x(qubits[i]) + + cudaq.sample(kernel15) + assert "range step value must be non-zero" in str(e.value) + assert "offending source -> range(1, 4, 0)" in str(e.value) + + with pytest.raises(RuntimeError) as e: + + @cudaq.kernel + def kernel16(v: int): + qubits = cudaq.qvector(5) + r = [i for i in range(1, 4, v)] + for i in r: + x(qubits[i]) + + cudaq.sample(kernel16) + assert "range step value must be a constant" in str(e.value) + assert "offending source -> range(1, 4, v)" in str(e.value) + def test_array_value_assignment(): @@ -1005,23 +1126,6 @@ def canCaptureList(): atol=1e-3) -def test_capture_disallow_change_variable(): - - n = 3 - - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def kernel() -> int: - if True: - cudaq.dbg.ast.print_i64(n) - # Change n, emits an error - n = 4 - return n - - kernel() - - def test_inner_function_capture(): n = 3 @@ -1590,7 +1694,9 @@ def test_kernel(): x(state_reg) test_kernel.compile() - assert 'invalid assignment detected.' 
in repr(e) + + assert 'no valid value was created' in repr(e) + assert '(offending source -> state_reg = cudaq.qubit)' in repr(e) def test_cast_error_1451(): @@ -1626,6 +1732,7 @@ def test_state(N: int): kernel.h(q[0]) test_state.compile() + assert "unknown function call" in repr(e) assert "offending source -> kernel.h(q[0])" in repr(e) @@ -1985,27 +2092,69 @@ def test(input: CustomIntAndFloatType): instance = CustomIntAndFloatType(2, np.pi / 2.) counts = cudaq.sample(test, instance) - counts.dump() assert len(counts) == 2 and '00' in counts and '11' in counts @dataclass(slots=True) class CustomIntAndListFloat: integer: int - array: List[float] + array: list[float] @cudaq.kernel def test(input: CustomIntAndListFloat): qubits = cudaq.qvector(input.integer) ry(input.array[0], qubits[0]) - rx(input.array[1], qubits[0]) + ry(input.array[1], qubits[2]) x.ctrl(qubits[0], qubits[1]) - print(test) - instance = CustomIntAndListFloat(2, [np.pi / 2., np.pi]) + instance = CustomIntAndListFloat(3, [np.pi, np.pi]) counts = cudaq.sample(test, instance) - counts.dump() - assert len(counts) == 2 and '00' in counts and '11' in counts + assert len(counts) == 1 and '111' in counts + @cudaq.kernel + def test(input: CustomIntAndListFloat): + qubits = cudaq.qvector(input.integer) + ry(input.array[1], qubits[0]) + ry(input.array[3], qubits[2]) + x.ctrl(qubits[0], qubits[1]) + + instance = CustomIntAndListFloat(3, [0, np.pi, 0, np.pi]) + counts = cudaq.sample(test, instance) + assert len(counts) == 1 and '111' in counts + + @dataclass(slots=True) + class CustomIntAndListFloat: + array: list[float] + integer: int + + @cudaq.kernel + def test(input: CustomIntAndListFloat): + qubits = cudaq.qvector(input.integer) + ry(input.array[1], qubits[0]) + ry(input.array[3], qubits[2]) + x.ctrl(qubits[0], qubits[1]) + + instance = CustomIntAndListFloat([0, np.pi, 0, np.pi], 3) + counts = cudaq.sample(test, instance) + assert len(counts) == 1 and '111' in counts + + # FIXME: CURRENTLY CRASHES + ''' + 
@dataclass(slots=True) + class CustomIntAndListFloat: + integer: list[int] + array: list[int] + + @cudaq.kernel + def test(input: CustomIntAndListFloat): + qubits = cudaq.qvector(input.integer[-1]) + ry(input.array[1], qubits[0]) + ry(input.array[3], qubits[2]) + x.ctrl(qubits[0], qubits[1]) + + instance = CustomIntAndListFloat([3], [0, np.pi, 0, np.pi]) + counts = cudaq.sample(test, instance) + assert len(counts) == 1 and '111' in counts + ''' # Test that the class can be in a library # and the paths all work out from mock.hello import TestClass @@ -2160,7 +2309,26 @@ def k(): h = NoCanDo(q) k() - assert ('struct types with user specified methods are not allowed.' + assert ('struct types with user specified methods are not allowed' + in repr(e)) + + @dataclass(slots=True) + class NoCanDo2: + a: list[float] + + def bob(self): + pass + + with pytest.raises(RuntimeError) as e: + + @cudaq.kernel + def k(arg: NoCanDo2): + q = cudaq.qvector(len(arg.a)) + for idx, a in enumerate(arg.a): + ry(a, q[idx]) + + k(NoCanDo2([1, 1])) + assert ('struct types with user specified methods are not allowed' in repr(e)) @@ -2186,8 +2354,8 @@ def less_arguments(): rx(3.14) print(less_arguments) - assert ('invalid number of arguments (1) passed to rx (requires at ' - 'least 2 arguments)' in repr(error)) + assert 'missing value' in repr(error) + assert '(offending source -> rx(3.14))' in repr(error) with pytest.raises(RuntimeError) as error: @@ -2197,7 +2365,8 @@ def wrong_arguments(): rx("random_argument", q) print(wrong_arguments) - assert 'rotational parameter must be a float, or int' in repr(error) + assert 'cannot convert value' in repr(error) + assert "(offending source -> rx('random_argument', q))" in repr(error) with pytest.raises(RuntimeError) as error: @@ -2217,8 +2386,8 @@ def invalid_ctrl(): rx.ctrl(np.pi, q) print(invalid_ctrl) - assert ('controlled operation requested without any control argument(s)' - in repr(error)) + assert 'missing value' in repr(error) + assert 
'(offending source -> rx.ctrl(np.pi, q))' in repr(error) def test_control_then_adjoint(): @@ -2463,7 +2632,7 @@ def kernel(op: cudaq.pauli_word): op(q[0]) result = cudaq.sample(kernel, cudaq.pauli_word("X")) - assert "must be a cc.callable" in str(e.value) + assert "object is not callable" in str(e.value) # leave for gdb debugging diff --git a/python/tests/kernel/test_run_async_kernel.py b/python/tests/kernel/test_run_async_kernel.py index d6d10b9e093..a0a88c3d0b8 100644 --- a/python/tests/kernel/test_run_async_kernel.py +++ b/python/tests/kernel/test_run_async_kernel.py @@ -9,13 +9,12 @@ import os import time from dataclasses import dataclass +from typing import Callable import cudaq import numpy as np import pytest -list_err_msg = 'does not yet support returning `list` from entry-point kernels' - def is_close(actual, expected): return np.isclose(actual, expected, atol=1e-6) @@ -314,7 +313,7 @@ def incrementer(i: int) -> int: @cudaq.kernel def kernel_with_list_arg(arg: list[int]) -> list[int]: - result = arg + result = arg.copy() for i in result: incrementer(i) return result @@ -327,8 +326,7 @@ def caller_kernel(arg: list[int]) -> int: result += v return result - res = cudaq.run_async(caller_kernel, [4, 5, 6], shots_count=1) - results = res.get() + results = cudaq.run_async(caller_kernel, [4, 5, 6], shots_count=1).get() assert len(results) == 1 assert results[0] == 15 # 4+1 + 5+1 + 6+1 = 15 @@ -339,37 +337,44 @@ def test_return_list_bool(): def simple_list_bool_no_args() -> list[bool]: return [True, False, True] - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_bool_no_args, shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_bool_no_args, shots_count=2).get() + assert len(results) == 2 + assert results[0] == [True, False, True] + assert results[1] == [True, False, True] @cudaq.kernel def simple_list_bool(n: int) -> list[bool]: qubits = cudaq.qvector(n) return [True, False, True] - with 
pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_bool, 2, shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_bool, 2, shots_count=2).get() + assert len(results) == 2 + assert results[0] == [True, False, True] + assert results[1] == [True, False, True] @cudaq.kernel def simple_list_bool_args(n: int, t: list[bool]) -> list[bool]: qubits = cudaq.qvector(n) - return t + return t.copy() - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_bool_args, 2, [True, False, True]).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_bool_args, + 2, [True, False, True], + shots_count=2).get() + assert len(results) == 2 + assert results[0] == [True, False, True] + assert results[1] == [True, False, True] @cudaq.kernel def simple_list_bool_args_no_broadcast(t: list[bool]) -> list[bool]: qubits = cudaq.qvector(2) - return t + return t.copy() - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_bool_args_no_broadcast, - [True, False, True]).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_bool_args_no_broadcast, + [True, False, True], + shots_count=2).get() + assert len(results) == 2 + assert results[0] == [True, False, True] + assert results[1] == [True, False, True] def test_return_list_int(): @@ -378,18 +383,21 @@ def test_return_list_int(): def simple_list_int_no_args() -> list[int]: return [-13, 5, 42] - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_int_no_args, shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_int_no_args, shots_count=2).get() + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] @cudaq.kernel def simple_list_int(n: int, t: list[int]) -> list[int]: qubits = cudaq.qvector(n) - return t + return t.copy() - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_int, 2, 
[-13, 5, 42], shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_int, 2, [-13, 5, 42], + shots_count=2).get() + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] def test_return_list_int8(): @@ -398,18 +406,21 @@ def test_return_list_int8(): def simple_list_int8_no_args() -> list[np.int8]: return [-13, 5, 42] - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_int8_no_args, shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_int8_no_args, shots_count=2).get() + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] @cudaq.kernel def simple_list_int8(n: int, t: list[np.int8]) -> list[np.int8]: qubits = cudaq.qvector(n) - return t + return t.copy() - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_int8, 2, [-13, 5, 42], shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_int8, 2, [-13, 5, 42], + shots_count=2).get() + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] def test_return_list_int16(): @@ -418,18 +429,21 @@ def test_return_list_int16(): def simple_list_int16_no_args() -> list[np.int16]: return [-13, 5, 42] - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_int16_no_args, shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_int16_no_args, shots_count=2).get() + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] @cudaq.kernel def simple_list_int16(n: int, t: list[np.int16]) -> list[np.int16]: qubits = cudaq.qvector(n) - return t + return t.copy() - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_int16, 2, [-13, 5, 42], shots_count=2).get() - assert list_err_msg in str(e.value) + results = 
cudaq.run_async(simple_list_int16, 2, [-13, 5, 42], + shots_count=2).get() + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] def test_return_list_int32(): @@ -438,18 +452,21 @@ def test_return_list_int32(): def simple_list_int32_no_args() -> list[np.int32]: return [-13, 5, 42] - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_int32_no_args, shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_int32_no_args, shots_count=2).get() + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] @cudaq.kernel def simple_list_int32(n: int, t: list[np.int32]) -> list[np.int32]: qubits = cudaq.qvector(n) - return t + return t.copy() - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_int32, 2, [-13, 5, 42], shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_int32, 2, [-13, 5, 42], + shots_count=2).get() + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] def test_return_list_int64(): @@ -458,18 +475,21 @@ def test_return_list_int64(): def simple_list_int64_no_args() -> list[np.int64]: return [-13, 5, 42] - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_int64_no_args, shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_int64_no_args, shots_count=2).get() + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] @cudaq.kernel def simple_list_int64(n: int, t: list[np.int64]) -> list[np.int64]: qubits = cudaq.qvector(n) - return t + return t.copy() - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_int64, 2, [-13, 5, 42], shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_int64, 2, [-13, 5, 42], + shots_count=2).get() + assert len(results) == 
2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] def test_return_list_float(): @@ -478,20 +498,22 @@ def test_return_list_float(): def simple_list_float_no_args() -> list[float]: return [-13.2, 5., 42.99] - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_float_no_args, shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_float_no_args, shots_count=2).get() + assert len(results) == 2 + assert np.allclose(results[0], [-13.2, 5., 42.99]) + assert np.allclose(results[1], [-13.2, 5., 42.99]) @cudaq.kernel def simple_list_float(n: int, t: list[float]) -> list[float]: qubits = cudaq.qvector(n) - return t + return t.copy() - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_float, - 2, [-13.2, 5.0, 42.99], - shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_float, + 2, [-13.2, 5.0, 42.99], + shots_count=2).get() + assert len(results) == 2 + assert np.allclose(results[0], [-13.2, 5., 42.99]) + assert np.allclose(results[1], [-13.2, 5., 42.99]) def test_return_list_float32(): @@ -500,20 +522,22 @@ def test_return_list_float32(): def simple_list_float32_no_args() -> list[np.float32]: return [-13.2, 5., 42.99] - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_float32_no_args, shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_float32_no_args, shots_count=2).get() + assert len(results) == 2 + assert np.allclose(results[0], [-13.2, 5., 42.99]) + assert np.allclose(results[1], [-13.2, 5., 42.99]) @cudaq.kernel def simple_list_float32(n: int, t: list[np.float32]) -> list[np.float32]: qubits = cudaq.qvector(n) - return t + return t.copy() - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_float32, - 2, [-13.2, 5.0, 42.99], - shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_float32, + 2, 
[-13.2, 5.0, 42.99], + shots_count=2).get() + assert len(results) == 2 + assert np.allclose(results[0], [-13.2, 5., 42.99]) + assert np.allclose(results[1], [-13.2, 5., 42.99]) def test_return_list_float64(): @@ -522,25 +546,22 @@ def test_return_list_float64(): def simple_list_float64_no_args() -> list[np.float64]: return [-13.2, 5., 42.99] - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_float64_no_args, shots_count=2).get() - assert list_err_msg in str(e.value) + results = cudaq.run_async(simple_list_float64_no_args, shots_count=2).get() + assert len(results) == 2 + assert np.allclose(results[0], [-13.2, 5., 42.99]) + assert np.allclose(results[1], [-13.2, 5., 42.99]) @cudaq.kernel def simple_list_float64(n: int, t: list[np.float64]) -> list[np.float64]: qubits = cudaq.qvector(n) - return t - - with pytest.raises(RuntimeError) as e: - cudaq.run_async(simple_list_float64, - 2, [-13.2, 5.0, 42.99], - shots_count=2).get() - assert list_err_msg in str(e.value) - + return t.copy() -# Test tuples -# TODO: Define spec for using tuples in kernels -# https://github.com/NVIDIA/cuda-quantum/issues/3031 + results = cudaq.run_async(simple_list_float64, + 2, [-13.2, 5.0, 42.99], + shots_count=2).get() + assert len(results) == 2 + assert np.allclose(results[0], [-13.2, 5., 42.99]) + assert np.allclose(results[1], [-13.2, 5., 42.99]) def test_return_tuple_int_float(): @@ -575,20 +596,18 @@ def simple_tuple_int_float_assign( return t cudaq.run_async(simple_tuple_int_float_assign, 2, (-13, 11.5)) - assert ('indexing into tuple or dataclass must not modify value' - in str(e.value)) - - with pytest.raises(RuntimeError) as e: + assert 'tuple value cannot be modified' in str(e.value) - @cudaq.kernel - def simple_tuple_int_float_error( - n: int, t: tuple[int, float]) -> tuple[bool, float]: - qubits = cudaq.qvector(n) - return t + @cudaq.kernel + def simple_tuple_int_float_conversion( + n: int, t: tuple[int, float]) -> tuple[bool, float]: + qubits = 
cudaq.qvector(n) + return t - cudaq.run_async(simple_tuple_int_float_error, 2, (-13, 11.5)) - assert ('cannot convert value of type !cc.struct<"tuple" {i64, f64}> to ' - 'the requested type !cc.struct<"tuple" {i1, f64}>' in str(e.value)) + result = cudaq.run_async(simple_tuple_int_float_conversion, + 2, (-13, 42.3), + shots_count=1).get() + assert len(result) == 1 and result[0] == (True, 42.3) def test_return_tuple_float_int(): @@ -655,15 +674,13 @@ def simple_tuple_int_bool(n: int, t: tuple[int, bool]) -> tuple[int, bool]: def test_return_tuple_int32_bool(): - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def simple_tuple_int32_bool_no_args() -> tuple[np.int32, bool]: - return (-13, True) + @cudaq.kernel + def simple_tuple_int32_bool_no_args() -> tuple[np.int32, bool]: + return (-13, True) - cudaq.run_async(simple_tuple_int32_bool_no_args).get() - assert ('cannot convert value of type !cc.struct<"tuple" {i64, i1}> to ' - 'the requested type !cc.struct<"tuple" {i32, i1}>' in str(e.value)) + result = cudaq.run_async(simple_tuple_int32_bool_no_args, + shots_count=1).get() + assert len(result) == 1 and result[0] == (-13, True) @cudaq.kernel def simple_tuple_int32_bool_no_args1() -> tuple[np.int32, bool]: @@ -748,18 +765,16 @@ def simple_dataclass_int_bool_error() -> MyClass: return MyClass(x=-16, y=True) cudaq.run_async(simple_dataclass_int_bool_error, shots_count=2).get() - assert ('invalid number of arguments passed in call to MyClass (0 vs ' - 'required 2)' in repr(e)) - - with pytest.raises(RuntimeError) as e: + assert 'keyword arguments for data classes are not yet supported' in repr(e) - @cudaq.kernel - def simple_dataclass_int_bool_error() -> MyClass: - return MyClass(x=0.13, y=True) + @cudaq.kernel + def simple_dataclass_int_bool() -> MyClass: + return MyClass(2.13, True) - cudaq.run_async(simple_dataclass_int_bool_error, shots_count=2).get() - assert ('invalid number of arguments passed in call to MyClass (0 vs ' - 'required 2)' in repr(e)) + 
results = cudaq.run_async(simple_dataclass_int_bool, shots_count=2).get() + assert len(results) == 2 + assert results[0] == MyClass(2, True) + assert results[1] == MyClass(2, True) def test_return_dataclass_bool_int(): @@ -895,6 +910,7 @@ def simple_return_dataclass(n: int, t: MyClass2) -> MyClass2: def test_run_errors(): + with pytest.raises(RuntimeError) as e: @cudaq.kernel @@ -938,12 +954,13 @@ class MyClass: y: bool @cudaq.kernel - def simple_strucA(t: MyClass) -> MyClass: + def simple_structA(arg: MyClass) -> MyClass: q = cudaq.qubit() + t = arg.copy() t.x = 42 return t - results = cudaq.run_async(simple_strucA, MyClass(-13, True), + results = cudaq.run_async(simple_structA, MyClass(-13, True), shots_count=2).get() print(results) assert len(results) == 2 @@ -957,14 +974,16 @@ class Foo: z: int @cudaq.kernel - def kerneB(t: Foo) -> Foo: + def kernelB(arg: Foo) -> Foo: q = cudaq.qubit() + t = arg.copy() t.z = 100 t.y = 3.14 t.x = True return t - results = cudaq.run_async(kerneB, Foo(False, 6.28, 17), shots_count=2).get() + results = cudaq.run_async(kernelB, Foo(False, 6.28, 17), + shots_count=2).get() print(results) assert len(results) == 2 assert results[0] == Foo(True, 3.14, 100) diff --git a/python/tests/kernel/test_run_kernel.py b/python/tests/kernel/test_run_kernel.py index 01b7c88cd66..6d6f2370ff5 100644 --- a/python/tests/kernel/test_run_kernel.py +++ b/python/tests/kernel/test_run_kernel.py @@ -8,14 +8,13 @@ import os from dataclasses import dataclass +from typing import Callable import cudaq import numpy as np import warnings import pytest -list_err_msg = 'does not yet support returning `list` from entry-point kernels' - skipIfBraketNotInstalled = pytest.mark.skipif( not (cudaq.has_target("braket")), reason='Could not find `braket` in installation') @@ -309,7 +308,7 @@ def incrementer(i: int) -> int: @cudaq.kernel def kernel_with_list_arg(arg: list[int]) -> list[int]: - result = arg + result = arg.copy() for i in result: incrementer(i) return result @@ 
-324,232 +323,312 @@ def caller_kernel(arg: list[int]) -> int: results = cudaq.run(caller_kernel, [4, 5, 6], shots_count=1) assert len(results) == 1 - assert results[0] == 15 # 4+1 + 5+1 + 6+1 = 15 + assert results[0] == 15 # 4 + 5 + 6 = 15 def test_return_list_bool(): - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def simple_list_bool_no_args() -> list[bool]: - return [True, False, True] - - cudaq.run(simple_list_bool_no_args, shots_count=2) - assert list_err_msg in str(e.value) - - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def simple_list_bool(n: int) -> list[bool]: - qubits = cudaq.qvector(n) - return [True, False, True] + @cudaq.kernel + def simple_list_bool_no_args() -> list[bool]: + return [True, False, True] - cudaq.run(simple_list_bool, 2, shots_count=2) - assert list_err_msg in str(e.value) + results = cudaq.run(simple_list_bool_no_args, shots_count=2) + assert len(results) == 2 + assert results[0] == [True, False, True] + assert results[1] == [True, False, True] - with pytest.raises(RuntimeError) as e: + @cudaq.kernel + def simple_list_bool(n: int) -> list[bool]: + qubits = cudaq.qvector(n) + return [True, False, True] - @cudaq.kernel - def simple_list_bool_args(n: int, t: list[bool]) -> list[bool]: - qubits = cudaq.qvector(n) - return t + results = cudaq.run(simple_list_bool, 2, shots_count=2) + assert len(results) == 2 + assert results[0] == [True, False, True] + assert results[1] == [True, False, True] - cudaq.run(simple_list_bool_args, 2, [True, False, True]) - assert list_err_msg in str(e.value) + @cudaq.kernel + def simple_list_bool_args(n: int, t: list[bool]) -> list[bool]: + qubits = cudaq.qvector(n) + return t.copy() - with pytest.raises(RuntimeError) as e: + results = cudaq.run(simple_list_bool_args, + 2, [True, False, True], + shots_count=2) + assert len(results) == 2 + assert results[0] == [True, False, True] + assert results[1] == [True, False, True] - @cudaq.kernel - def simple_list_bool_args_no_broadcast(t: 
list[bool]) -> list[bool]: - qubits = cudaq.qvector(2) - return t + @cudaq.kernel + def simple_list_bool_args_no_broadcast(t: list[bool]) -> list[bool]: + qubits = cudaq.qvector(2) + return t.copy() - cudaq.run(simple_list_bool_args_no_broadcast, [True, False, True]) - assert list_err_msg in str(e.value) + results = cudaq.run(simple_list_bool_args_no_broadcast, [True, False, True], + shots_count=2) + assert len(results) == 2 + assert results[0] == [True, False, True] + assert results[1] == [True, False, True] def test_return_list_int(): - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def simple_list_int_no_args() -> list[int]: - return [-13, 5, 42] - - cudaq.run(simple_list_int_no_args, shots_count=2) - assert list_err_msg in str(e.value) + @cudaq.kernel + def simple_list_int_no_args() -> list[int]: + return [-13, 5, 42] - with pytest.raises(RuntimeError) as e: + results = cudaq.run(simple_list_int_no_args, shots_count=2) + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] - @cudaq.kernel - def simple_list_int(n: int, t: list[int]) -> list[int]: - qubits = cudaq.qvector(n) - return t + @cudaq.kernel + def simple_list_int(n: int, t: list[int]) -> list[int]: + qubits = cudaq.qvector(n) + return t.copy() - cudaq.run(simple_list_int, 2, [-13, 5, 42], shots_count=2) - assert list_err_msg in str(e.value) + results = cudaq.run(simple_list_int, 2, [-13, 5, 42], shots_count=2) + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] def test_return_list_int8(): - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def simple_list_int8_no_args() -> list[np.int8]: - return [-13, 5, 42] - - cudaq.run(simple_list_int8_no_args, shots_count=2) - assert list_err_msg in str(e.value) + @cudaq.kernel + def simple_list_int8_no_args() -> list[np.int8]: + return [-13, 5, 42] - with pytest.raises(RuntimeError) as e: + results = cudaq.run(simple_list_int8_no_args, shots_count=2) + 
assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] - @cudaq.kernel - def simple_list_int8(n: int, t: list[np.int8]) -> list[np.int8]: - qubits = cudaq.qvector(n) - return t + @cudaq.kernel + def simple_list_int8(n: int, t: list[np.int8]) -> list[np.int8]: + qubits = cudaq.qvector(n) + return t.copy() - cudaq.run(simple_list_int8, 2, [-13, 5, 42], shots_count=2) - assert list_err_msg in str(e.value) + results = cudaq.run(simple_list_int8, 2, [-13, 5, 42], shots_count=2) + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] def test_return_list_int16(): - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def simple_list_int16_no_args() -> list[np.int16]: - return [-13, 5, 42] - - cudaq.run(simple_list_int16_no_args, shots_count=2) - assert list_err_msg in str(e.value) + @cudaq.kernel + def simple_list_int16_no_args() -> list[np.int16]: + return [-13, 5, 42] - with pytest.raises(RuntimeError) as e: + results = cudaq.run(simple_list_int16_no_args, shots_count=2) + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] - @cudaq.kernel - def simple_list_int16(n: int, t: list[np.int16]) -> list[np.int16]: - qubits = cudaq.qvector(n) - return t + @cudaq.kernel + def simple_list_int16(n: int, t: list[np.int16]) -> list[np.int16]: + qubits = cudaq.qvector(n) + return t.copy() - cudaq.run(simple_list_int16, 2, [-13, 5, 42], shots_count=2) - assert list_err_msg in str(e.value) + results = cudaq.run(simple_list_int16, 2, [-13, 5, 42], shots_count=2) + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] def test_return_list_int32(): - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def simple_list_int32_no_args() -> list[np.int32]: - return [-13, 5, 42] - - cudaq.run(simple_list_int32_no_args, shots_count=2) - assert list_err_msg in str(e.value) + @cudaq.kernel + def 
simple_list_int32_no_args() -> list[np.int32]: + return [-13, 5, 42] - with pytest.raises(RuntimeError) as e: + results = cudaq.run(simple_list_int32_no_args, shots_count=2) + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] - @cudaq.kernel - def simple_list_int32(n: int, t: list[np.int32]) -> list[np.int32]: - qubits = cudaq.qvector(n) - return t + @cudaq.kernel + def simple_list_int32(n: int, t: list[np.int32]) -> list[np.int32]: + qubits = cudaq.qvector(n) + return t.copy() - cudaq.run(simple_list_int32, 2, [-13, 5, 42], shots_count=2) - assert list_err_msg in str(e.value) + results = cudaq.run(simple_list_int32, 2, [-13, 5, 42], shots_count=2) + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] def test_return_list_int64(): - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def simple_list_int64_no_args() -> list[np.int64]: - return [-13, 5, 42] - - cudaq.run(simple_list_int64_no_args, shots_count=2) - assert list_err_msg in str(e.value) + @cudaq.kernel + def simple_list_int64_no_args() -> list[np.int64]: + return [-13, 5, 42] - with pytest.raises(RuntimeError) as e: + results = cudaq.run(simple_list_int64_no_args, shots_count=2) + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] - @cudaq.kernel - def simple_list_int64(n: int, t: list[np.int64]) -> list[np.int64]: - qubits = cudaq.qvector(n) - return t + @cudaq.kernel + def simple_list_int64(n: int, t: list[np.int64]) -> list[np.int64]: + qubits = cudaq.qvector(n) + return t.copy() - cudaq.run(simple_list_int64, 2, [-13, 5, 42], shots_count=2) - assert list_err_msg in str(e.value) + results = cudaq.run(simple_list_int64, 2, [-13, 5, 42], shots_count=2) + assert len(results) == 2 + assert results[0] == [-13, 5, 42] + assert results[1] == [-13, 5, 42] def test_return_list_float(): - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def 
simple_list_float_no_args() -> list[float]: - return [-13.2, 5., 42.99] - - cudaq.run(simple_list_float_no_args, shots_count=2) - assert list_err_msg in str(e.value) + @cudaq.kernel + def simple_list_float_no_args() -> list[float]: + return [-13.2, 5., 42.99] - with pytest.raises(RuntimeError) as e: + results = cudaq.run(simple_list_float_no_args, shots_count=2) + assert len(results) == 2 + assert is_close_array(results[0], [-13.2, 5., 42.99]) + assert is_close_array(results[1], [-13.2, 5., 42.99]) - @cudaq.kernel - def simple_list_float(n: int, t: list[float]) -> list[float]: - qubits = cudaq.qvector(n) - return t + @cudaq.kernel + def simple_list_float(n: int, t: list[float]) -> list[float]: + qubits = cudaq.qvector(n) + return t.copy() - cudaq.run(simple_list_float, 2, [-13.2, 5.0, 42.99], shots_count=2) - assert list_err_msg in str(e.value) + results = cudaq.run(simple_list_float, + 2, [-13.2, 5.0, 42.99], + shots_count=2) + assert len(results) == 2 + assert is_close_array(results[0], [-13.2, 5., 42.99]) + assert is_close_array(results[1], [-13.2, 5., 42.99]) def test_return_list_float32(): - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def simple_list_float32_no_args() -> list[np.float32]: - return [-13.2, 5., 42.99] - - cudaq.run(simple_list_float32_no_args, shots_count=2) - assert list_err_msg in str(e.value) + @cudaq.kernel + def simple_list_float32_no_args() -> list[np.float32]: + return [-13.2, 5., 42.99] - with pytest.raises(RuntimeError) as e: + results = cudaq.run(simple_list_float32_no_args, shots_count=2) + assert len(results) == 2 + assert is_close_array(results[0], [-13.2, 5., 42.99]) + assert is_close_array(results[1], [-13.2, 5., 42.99]) - @cudaq.kernel - def simple_list_float32(n: int, - t: list[np.float32]) -> list[np.float32]: - qubits = cudaq.qvector(n) - return t + @cudaq.kernel + def simple_list_float32(n: int, t: list[np.float32]) -> list[np.float32]: + qubits = cudaq.qvector(n) + return t.copy() - 
cudaq.run(simple_list_float32, 2, [-13.2, 5.0, 42.99], shots_count=2) - assert list_err_msg in str(e.value) + results = cudaq.run(simple_list_float32, + 2, [-13.2, 5.0, 42.99], + shots_count=2) + assert len(results) == 2 + assert is_close_array(results[0], [-13.2, 5., 42.99]) + assert is_close_array(results[1], [-13.2, 5., 42.99]) def test_return_list_float64(): - with pytest.raises(RuntimeError) as e: + @cudaq.kernel + def simple_list_float64_no_args() -> list[np.float64]: + return [-13.2, 5., 42.99] - @cudaq.kernel - def simple_list_float64_no_args() -> list[np.float64]: - return [-13.2, 5., 42.99] + results = cudaq.run(simple_list_float64_no_args, shots_count=2) + assert len(results) == 2 + assert is_close_array(results[0], [-13.2, 5., 42.99]) + assert is_close_array(results[1], [-13.2, 5., 42.99]) - cudaq.run(simple_list_float64_no_args, shots_count=2) - assert list_err_msg in str(e.value) + @cudaq.kernel + def simple_list_float64(n: int, t: list[np.float64]) -> list[np.float64]: + qubits = cudaq.qvector(n) + return t.copy() - with pytest.raises(RuntimeError) as e: + results = cudaq.run(simple_list_float64, + 2, [-13.2, 5.0, 42.99], + shots_count=2) + assert len(results) == 2 + assert is_close_array(results[0], [-13.2, 5., 42.99]) + assert is_close_array(results[1], [-13.2, 5., 42.99]) - @cudaq.kernel - def simple_list_float64(n: int, - t: list[np.float64]) -> list[np.float64]: - qubits = cudaq.qvector(n) - return t - cudaq.run(simple_list_float64, 2, [-13.2, 5.0, 42.99], shots_count=2) - assert list_err_msg in str(e.value) +def test_return_list_large_size(): + # Returns a large list (dynamic size) to stress test the code generation + @cudaq.kernel + def kernel_with_dynamic_int_array_input(n: int, t: list[int]) -> list[int]: + qubits = cudaq.qvector(n) + return t.copy() -# Test tuples -# TODO: Define spec for using tuples in kernels -# https://github.com/NVIDIA/cuda-quantum/issues/3031 + @cudaq.kernel + def kernel_with_dynamic_float_array_input(n: int, + t: 
list[float]) -> list[float]: + qubits = cudaq.qvector(n) + return t.copy() + + @cudaq.kernel + def kernel_with_dynamic_bool_array_input(n: int, + t: list[bool]) -> list[bool]: + qubits = cudaq.qvector(n) + return t.copy() + + # Test with various sizes (validate dynamic output logging) + for array_size in [10, 15, 100, 167, 1000]: + input_array = list(np.random.randint(-1000, 1000, size=array_size)) + results = cudaq.run(kernel_with_dynamic_int_array_input, + 2, + input_array, + shots_count=2) + assert len(results) == 2 + assert results[0] == input_array + assert results[1] == input_array + + input_array_float = list( + np.random.uniform(-1000.0, 1000.0, size=array_size)) + results = cudaq.run(kernel_with_dynamic_float_array_input, + 2, + input_array_float, + shots_count=2) + assert len(results) == 2 + assert is_close_array(results[0], input_array_float) + assert is_close_array(results[1], input_array_float) + + input_array_bool = [] + for _ in range(array_size): + input_array_bool.append(True if np.random.rand() > 0.5 else False) + results = cudaq.run(kernel_with_dynamic_bool_array_input, + 2, + input_array_bool, + shots_count=2) + assert len(results) == 2 + assert results[0] == input_array_bool + assert results[1] == input_array_bool + + +def test_return_dynamics_measure_results(): + + @cudaq.kernel + def measure_all_qubits(numQubits: int) -> list[bool]: + # Number of qubits is dynamic + qubits = cudaq.qvector(numQubits) + for i in range(numQubits): + if i % 2 == 0: + x(qubits[i]) + + return mz(qubits) + + for numQubits in [1, 3, 5, 11, 20]: + shots = 2 + results = cudaq.run(measure_all_qubits, numQubits, shots_count=shots) + assert len(results) == shots + for res in results: + assert len(res) == numQubits + for i in range(numQubits): + if i % 2 == 0: + assert res[i] == True + else: + assert res[i] == False def test_return_tuple_int_float(): @@ -581,20 +660,18 @@ def simple_tuple_int_float_assign( return t cudaq.run(simple_tuple_int_float_assign, 2, (-13, 11.5)) - 
assert 'indexing into tuple or dataclass must not modify value' in str( - e.value) - - with pytest.raises(RuntimeError) as e: + assert 'tuple value cannot be modified' in str(e.value) - @cudaq.kernel - def simple_tuple_int_float_error( - n: int, t: tuple[int, float]) -> tuple[bool, float]: - qubits = cudaq.qvector(n) - return t + @cudaq.kernel + def simple_tuple_int_float_conversion( + n: int, t: tuple[int, float]) -> tuple[bool, float]: + qubits = cudaq.qvector(n) + return t - cudaq.run(simple_tuple_int_float_error, 2, (-13, 11.5)) - assert 'cannot convert value of type !cc.struct<"tuple" {i64, f64}> to the requested type !cc.struct<"tuple" {i1, f64}>' in str( - e.value) + result = cudaq.run(simple_tuple_int_float_conversion, + 2, (-13, 42.3), + shots_count=1) + assert len(result) == 1 and result[0] == (True, 42.3) def test_return_tuple_float_int(): @@ -654,15 +731,12 @@ def simple_tuple_int_bool(n: int, t: tuple[int, bool]) -> tuple[int, bool]: def test_return_tuple_int32_bool(): - with pytest.raises(RuntimeError) as e: - - @cudaq.kernel - def simple_tuple_int32_bool_no_args() -> tuple[np.int32, bool]: - return (-13, True) + @cudaq.kernel + def simple_tuple_int32_bool_no_args() -> tuple[np.int32, bool]: + return (-13, True) - cudaq.run(simple_tuple_int32_bool_no_args) - assert 'cannot convert value of type !cc.struct<"tuple" {i64, i1}> to the requested type !cc.struct<"tuple" {i32, i1}>' in str( - e.value) + result = cudaq.run(simple_tuple_int32_bool_no_args, shots_count=1) + assert len(result) == 1 and result[0] == (-13, True) @cudaq.kernel def simple_tuple_int32_bool_no_args1() -> tuple[np.int32, bool]: @@ -866,6 +940,7 @@ def simple_return_dataclass(n: int, t: MyClass2) -> MyClass2: def test_run_errors(): + with pytest.raises(RuntimeError) as e: @cudaq.kernel @@ -916,13 +991,33 @@ class MyClass: x: int y: bool + with pytest.raises(RuntimeError) as e: + + @cudaq.kernel + def simple_struc_err(t: MyClass) -> MyClass: + q = cudaq.qubit() + # If we allowed this, 
the expected behavior for Python + # would be that t is modified also in the caller without + # having to return it. We hence give an error to make it + # clear that changes to structs don't propagate past + # function boundaries. + t.x = 42 + return t + + cudaq.run(simple_struc_err, MyClass(-13, True), shots_count=2) + + assert 'value cannot be modified - use `.copy(deep)` to create a new value that can be modified' in repr( + e) + assert '(offending source -> t.x)' in repr(e) + @cudaq.kernel - def simple_strucA(t: MyClass) -> MyClass: + def simple_structA(arg: MyClass) -> MyClass: q = cudaq.qubit() + t = arg.copy() t.x = 42 return t - results = cudaq.run(simple_strucA, MyClass(-13, True), shots_count=2) + results = cudaq.run(simple_structA, MyClass(-13, True), shots_count=2) print(results) assert len(results) == 2 assert results[0] == MyClass(42, True) @@ -935,14 +1030,15 @@ class Foo: z: int @cudaq.kernel - def kerneB(t: Foo) -> Foo: + def kernelB(arg: Foo) -> Foo: q = cudaq.qubit() + t = arg.copy() t.z = 100 t.y = 3.14 t.x = True return t - results = cudaq.run(kerneB, Foo(False, 6.28, 17), shots_count=2) + results = cudaq.run(kernelB, Foo(False, 6.28, 17), shots_count=2) print(results) assert len(results) == 2 assert results[0] == Foo(True, 3.14, 100) diff --git a/python/tests/kernel/test_sample_kernel.py b/python/tests/kernel/test_sample_kernel.py index 03d531ddc8b..4a1fdb2937f 100644 --- a/python/tests/kernel/test_sample_kernel.py +++ b/python/tests/kernel/test_sample_kernel.py @@ -81,6 +81,10 @@ def qpe(nC: int, nQ: int, statePrep: Callable[[cudaq.qubit], None], assert len(counts) == 1 assert '100' in counts + counts_async = cudaq.sample_async(qpe, 3, 1, xGate, tGate).get() + assert len(counts_async) == 1 + assert '100' in counts_async + # Test that we can define kernels after the # definition of a composable kernel like qpe # and use them as input (they get added to the diff --git a/python/tests/visualization/test_draw.py 
b/python/tests/visualization/test_draw.py index 33db2632330..db21104ff78 100644 --- a/python/tests/visualization/test_draw.py +++ b/python/tests/visualization/test_draw.py @@ -138,11 +138,11 @@ def hw_kernel(): cudaq.set_target('ionq', emulate=True) # fmt: on expected_str = R""" - ╭───╮ ╭───╮╭─────╮╭───╮╭───╮ -q0 : ┤ h ├──●─────────────────────●──────────────┤ x ├┤ tdg ├┤ x ├┤ t ├ - ╰───╯ │ │ ╰─┬─╯╰─────╯╰─┬─╯├───┤ -q1 : ───────┼───────────●─────────┼───────────●────●───────────●──┤ t ├ - ╭───╮╭─┴─╮╭─────╮╭─┴─╮╭───╮╭─┴─╮╭─────╮╭─┴─╮╭───╮ ╭───╮ ╰───╯ + ╭───╮ ╭───╮ +q0 : ┤ h ├──────────────●─────────────────────●────●───────────●──┤ t ├ + ╰───╯ │ │ ╭─┴─╮╭─────╮╭─┴─╮├───┤ +q1 : ───────●───────────┼─────────●───────────┼──┤ x ├┤ tdg ├┤ x ├┤ t ├ + ╭───╮╭─┴─╮╭─────╮╭─┴─╮╭───╮╭─┴─╮╭─────╮╭─┴─╮├───┤╰┬───┬╯╰───╯╰───╯ q2 : ┤ h ├┤ x ├┤ tdg ├┤ x ├┤ t ├┤ x ├┤ tdg ├┤ x ├┤ t ├─┤ h ├─────────── ╰───╯╰───╯╰─────╯╰───╯╰───╯╰───╯╰─────╯╰───╯╰───╯ ╰───╯ """ diff --git a/test/AST-Quake/control_flow.cpp b/test/AST-Quake/control_flow.cpp index 95f1311bcd6..3984766e6b3 100644 --- a/test/AST-Quake/control_flow.cpp +++ b/test/AST-Quake/control_flow.cpp @@ -20,27 +20,27 @@ void g3(); void g4(); struct C { - void operator()() __qpu__ { - cudaq::qvector r(2); - g1(); - for (int i = 0; i < 10; ++i) { - if (f1(i)) { - cudaq::qubit q; - x(q,r[0]); - break; - } - x(r[0],r[1]); - g2(); - if (f2(i)) { - y(r[1]); - continue; - } - g3(); - z(r); - } - g4(); - mz(r); - } + void operator()() __qpu__ { + cudaq::qvector r(2); + g1(); + for (int i = 0; i < 10; ++i) { + if (f1(i)) { + cudaq::qubit q; + x(q,r[0]); + break; + } + x(r[0],r[1]); + g2(); + if (f2(i)) { + y(r[1]); + continue; + } + g3(); + z(r); + } + g4(); + mz(r); + } }; // CHECK-LABEL: func.func @__nvqpp__mlirgen__C() @@ -109,26 +109,26 @@ struct C { // CHECK: } struct D { - void operator()() __qpu__ { - cudaq::qvector r(2); - g1(); - for (int i = 0; i < 10; ++i) { - if (f1(i)) { - cudaq::qubit q; - x(q,r[0]); - continue; - } - x(r[0],r[1]); - 
g2(); - if (f2(i)) { - y(r[1]); - break; - } - g3(); - z(r); - } - g4(); - mz(r); + void operator()() __qpu__ { + cudaq::qvector r(2); + g1(); + for (int i = 0; i < 10; ++i) { + if (f1(i)) { + cudaq::qubit q; + x(q,r[0]); + continue; + } + x(r[0],r[1]); + g2(); + if (f2(i)) { + y(r[1]); + break; + } + g3(); + z(r); + } + g4(); + mz(r); } }; @@ -199,25 +199,25 @@ struct D { struct E { void operator()() __qpu__ { - cudaq::qvector r(2); - g1(); - for (int i = 0; i < 10; ++i) { - if (f1(i)) { - cudaq::qubit q; - x(q,r[0]); - return; - } - x(r[0],r[1]); - g2(); - if (f2(i)) { - y(r[1]); - break; - } - g3(); - z(r); - } - g4(); - mz(r); + cudaq::qvector r(2); + g1(); + for (int i = 0; i < 10; ++i) { + if (f1(i)) { + cudaq::qubit q; + x(q,r[0]); + return; + } + x(r[0],r[1]); + g2(); + if (f2(i)) { + y(r[1]); + break; + } + g3(); + z(r); + } + g4(); + mz(r); } }; @@ -286,26 +286,26 @@ struct E { // CHECK: return struct F { - void operator()() __qpu__ { - cudaq::qvector r(2); - g1(); - for (int i = 0; i < 10; ++i) { - if (f1(i)) { - cudaq::qubit q; - x(q,r[0]); - continue; - } - x(r[0],r[1]); - g2(); - if (f2(i)) { - y(r[1]); - return; - } - g3(); - z(r); - } - g4(); - mz(r); + void operator()() __qpu__ { + cudaq::qvector r(2); + g1(); + for (int i = 0; i < 10; ++i) { + if (f1(i)) { + cudaq::qubit q; + x(q,r[0]); + continue; + } + x(r[0],r[1]); + g2(); + if (f2(i)) { + y(r[1]); + return; + } + g3(); + z(r); + } + g4(); + mz(r); } };