From 4317d8554c7cb6b9b5440d4af1521ee1c0815bfe Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Fri, 23 May 2025 11:49:45 +0200 Subject: [PATCH 1/2] wip: First interpreter draft --- brat/Brat/Compile/Hugr.hs | 2 +- brat/Brat/Compile/Interpreter.hs | 318 +++++++++++++++++++++++++++++++ brat/Brat/Compiler.hs | 28 +++ brat/Brat/Graph.hs | 3 + brat/app/Main.hs | 12 +- brat/brat.cabal | 5 +- 6 files changed, 363 insertions(+), 5 deletions(-) create mode 100644 brat/Brat/Compile/Interpreter.hs diff --git a/brat/Brat/Compile/Hugr.hs b/brat/Brat/Compile/Hugr.hs index 67f6413b..ca12db3a 100644 --- a/brat/Brat/Compile/Hugr.hs +++ b/brat/Brat/Compile/Hugr.hs @@ -7,7 +7,7 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeSynonymInstances #-} -module Brat.Compile.Hugr (compile) where +module Brat.Compile.Hugr (compile, Compile, NodeId, CompilationState (..), emptyCS, TypedPort, addNode, addEdge, freshNode, addOp, compileRo, compileCTy, renameAndSortHugr) where import Brat.Constructors.Patterns (pattern CFalse, pattern CTrue) import Brat.Checker.Monad (track, trackM, CheckingSig(..)) diff --git a/brat/Brat/Compile/Interpreter.hs b/brat/Brat/Compile/Interpreter.hs new file mode 100644 index 00000000..123d4732 --- /dev/null +++ b/brat/Brat/Compile/Interpreter.hs @@ -0,0 +1,318 @@ +module Brat.Compile.Interpreter (run, Value(..)) where + +import Brat.Naming (Name, Namespace) +import Brat.Graph (Graph, NodeType (..), Node (BratNode, KernelNode), wiresTo, MatchSequence (..), PrimTest (..), TestMatchData (..)) +import qualified Data.Map as M +import Brat.Syntax.Common +import Brat.Checker.Types (Store, VEnv) +import Brat.Syntax.Value +import Brat.Compile.Hugr +import Control.Monad.State +import Data.Tuple.HT (fst3) +import Control.Monad (forM, foldM, forM_) +import Brat.Syntax.Simple (SimpleTerm (..)) +import Control.Arrow (first) +import Data.List.NonEmpty (NonEmpty(..), toList) +import Brat.QualName (QualName (PrefixName)) +import Data.Hugr +import Debug.Trace (trace) +import Hasochism +import Brat.Constructors.Patterns + + +type HugrPort = TypedPort + +data Value = + IntV Int + | FloatV Double + | BoolV Bool + | VecV [Value] + | ThunkV BratThunk + | KernelV HugrKernel + +data BratThunk = + BratClosure (EvalEnv Brat) Name Name -- Captured environment, src node, tgt node + | BratPrim String String (CTy Brat Z) + +data HugrKernel = + HugrFunc NodeId FunctionType -- Either a user-defined function + | HugrOp String String FunctionType [Value] -- or an operation + deriving Show + +instance Show Value where + show (IntV x) = show x + show (FloatV x) = show x + show (BoolV x) = show x + show (VecV xs) = show xs + show (ThunkV _) = "" + show (KernelV k) = "Kernel (" ++ show k ++ ")" + +-- The data we're tracking for each port in the Brat graph +type family PortData (m :: Mode) where + -- In Brat mode, we track a value for each port + PortData Brat = Value + -- In Kernel mode, we track a Hugr port for each Brat port + -- in the Hugr that is currently under construction + PortData Kernel = HugrPort + + +type EvalEnv m = M.Map OutPort (PortData m) + +data EvalState = EvalState + { evaledBratPorts :: EvalEnv Brat + , evaledKernelPorts :: EvalEnv Kernel + , moduleNode :: Maybe NodeId + , currentParent :: Maybe NodeId + } + +type Eval a = StateT EvalState Compile a + + +emptyEvalEnv = EvalState + { evaledBratPorts = M.empty + , evaledKernelPorts = M.empty + , moduleNode = Nothing + , currentParent = Nothing +} + +getEvaled :: Modey m -> Eval (M.Map OutPort (PortData m)) +getEvaled Braty = gets evaledBratPorts +getEvaled Kerny = gets evaledKernelPorts + +putEvaled :: Modey m -> M.Map OutPort (PortData m) -> Eval () +putEvaled Braty e = get >>= \st -> put (st { evaledBratPorts = e }) +putEvaled Kerny e = get >>= \st -> put (st { evaledKernelPorts = e }) + +getModuleNode :: Eval NodeId +getModuleNode = get >>= \st -> case moduleNode st of + Just node -> pure node + Nothing -> do + id <- lift $ freshNode "module" + lift $ addOp (OpMod $ ModuleOp id) id + put (st { moduleNode = Just id }) + pure id + +evalPort :: Modey m -> OutPort -> Eval (PortData m) +evalPort my port@(Ex node offset) = getEvaled my >>= \evaled -> case M.lookup port evaled of + Just v -> return v + Nothing -> do + graph@(nodes, _) <- lift $ gets bratGraph + inputs <- forM (fst3 <$> wiresTo node graph) (evalPort my) -- TODO: Very inefficient + outputs <- case (my, nodes M.! node) of + (Braty, BratNode thing _ _) -> evalNode Braty thing node inputs + (Kerny, KernelNode thing _ _) -> evalNode Kerny thing node inputs + _ -> error "Internal error: Brat vs kernel node mismatch" + putEvaled my $ evaled `M.union` M.fromList [(Ex node i, v) | (i, v) <- zip [0..] outputs] + pure $ outputs !! offset + +evalNode :: Modey m -> NodeType m -> Name -> [PortData m] -> Eval [PortData m] +evalNode _ Source node _ = error $ "Internal error: Source should be in evaluated state: " ++ show node +evalNode _ Target _ inputs = pure inputs +evalNode _ Id _ inputs = pure inputs +evalNode Braty (Const term) _ [] = pure [evalSimpleTerm term] +evalNode Braty (Constructor con) _ inputs = pure [evalConstructor con inputs] +evalNode Braty (ArithNode op) _ inputs = pure [evalArith op inputs] +evalNode Braty (PatternMatch clauses) _ inputs = evalBratMatch (toList clauses) inputs +evalNode Braty (Eval thunk) _ inputs = evalPort Braty thunk >>= \case + ThunkV th -> evalBratCall th inputs + v -> error $ "Internal error: Not a thunk: " ++ show v +evalNode Kerny (Splice kernel) _ inputs = get >>= \st -> + -- Spliced kernel is Brat value + lift (evalStateT (evalPort Braty kernel) st) >>= \case + KernelV k -> gets currentParent >>= \(Just parent) -> lift (evalKernelSplice k parent inputs) + _ -> error "Internal error: Not a kernel value" +evalNode Braty (Box venv src tgt) node [] = do + graph <- lift $ gets bratGraph + case fst graph M.! node of + (BratNode _ _ [(_, VFun Braty _)]) -> evalBratBox venv src tgt + (BratNode _ _ [(_, VFun Kerny cty)]) -> evalKernelBox node src tgt cty + _ -> error "Internal error: Unexpected box signature" +evalNode Braty (Prim (extension, op)) node [] = do + graph <- lift $ gets bratGraph + case fst graph M.! node of + (BratNode _ _ [(_, VFun Braty cty)]) -> pure [ThunkV (BratPrim extension op cty)] + (BratNode _ _ [(_, VFun Kerny cty)]) -> pure [KernelV (HugrOp extension op (body $ compileCTy cty) [])] + _ -> error "Internal error: Unexpected prim signature" +evalNode _ thing _ _ = error $ "Internal error: Unexpected node in Brat box: " ++ show thing + + +evalBratBox :: VEnv -> Name -> Name -> Eval [Value] +evalBratBox venv src tgt = do + -- Make a closure that captures the entire venv. Haskells laziness ensures + -- that we won't run into problems with recursive definitions + let envPorts = map (fst . first end) (concat $ M.elems venv) + envVals <- forM envPorts (evalPort Braty) + let env = M.fromList (zip envPorts envVals) + pure [ThunkV (BratClosure env src tgt)] + +evalBratCall :: BratThunk -> [Value] -> Eval [Value] +evalBratCall (BratClosure env src tgt) inputs = do + st <- get + graph <- lift $ gets bratGraph + lift $ evalStateT (forM (wiresTo tgt graph) (\(port, _, _) -> evalPort Braty port)) + (st { evaledBratPorts = env `M.union` M.fromList (zip (Ex src <$> [0..]) inputs) + , evaledKernelPorts = M.empty }) +evalBratCall (BratPrim extension op (inRo :->> RPr (_, VFun Kerny cty) R0)) inputs = do + let bratInTys = compileRo inRo + let PolyFuncType _ (FunctionType inTys outTys _) = compileCTy cty + pure [KernelV (HugrOp extension op (FunctionType (bratInTys ++ inTys) outTys []) inputs)] +evalBratCall _ _ = error "todo" + +evalKernelBox :: Name -> Name -> Name -> CTy Kernel Z -> Eval [Value] +evalKernelBox node src tgt cty = do + graph <- lift $ gets bratGraph + -- Build a new Hugr function definition + let name = "" -- TODO + let polyFunTy@(PolyFuncType _ funTy@(FunctionType inTys outTys _)) = compileCTy cty + moduleNode <- getModuleNode + defNode <- lift $ addNode name (OpDefn $ FuncDefn moduleNode name polyFunTy) + let kernelValue = KernelV (HugrFunc defNode funTy) + -- Compile the kernel + st <- get + inpNode <- lift $ addNode "Input" (OpIn $ InputNode defNode inTys) + outputs <- lift $ evalStateT (forM (wiresTo tgt graph) (\(port, _, _) -> evalPort Kerny port)) + -- Mark the kernel port as defined to enable recursive calls + (st { evaledBratPorts = evaledBratPorts st `M.union` M.fromList [(Ex node 0, kernelValue)] + , evaledKernelPorts = M.fromList (zip (Ex src <$> [0..]) (zip (Port inpNode <$> [0..]) inTys)) + , currentParent = Just defNode }) + outNode <- lift $ addNode "Output" (OpOut $ OutputNode defNode outTys) + lift $ forM_ (zip outputs [0..]) (\((p, _), i) -> addEdge (p, Port outNode i)) + pure [kernelValue] + +evalKernelSplice :: HugrKernel -> NodeId -> [HugrPort] -> Compile [HugrPort] +evalKernelSplice (HugrFunc funcNode funcTy@(FunctionType _ outTys _)) parent inputs = do + callNode <- addNode "Call" (OpCall (CallOp parent funcTy)) + forM_ inputs (\(p, _) -> addEdge (p, Port callNode 0)) + addEdge (Port funcNode 0, Port callNode (length inputs)) + pure (zip (Port callNode <$> [0..]) outTys) +evalKernelSplice (HugrOp extension op funcTy@(FunctionType _ outTys _) bratInputs) parent inputs = do + bratInputs <- forM bratInputs (loadBratValue parent) + node <- addNode (extension ++ "." ++ op) (OpCustom $ CustomOp parent extension op funcTy []) + forM_ (zip (bratInputs ++ inputs) [0..]) (\((p, _), i) -> addEdge (p, Port node i)) + pure (zip (Port node <$> [0..]) outTys) + +kernelToHugrFunc :: HugrKernel -> Eval HugrKernel +kernelToHugrFunc k@(HugrFunc _ _) = pure k +kernelToHugrFunc k@(HugrOp extension op funcTy@(FunctionType inTys outTys _) _) = do + moduleNode <- getModuleNode + let name = extension ++ "." ++ op + defNode <- lift $ addNode name (OpDefn $ FuncDefn moduleNode name (PolyFuncType [] funcTy)) + inpNode <- lift $ addNode "Input" (OpIn $ InputNode defNode inTys) + let inputs = zip (Port inpNode <$> [0..]) inTys + outputs <- lift $ evalKernelSplice k defNode inputs + outNode <- lift $ addNode "Output" (OpOut $ OutputNode defNode outTys) + lift $ forM_ (zip outputs [0..]) (\((p, _), i) -> addEdge (p, Port outNode i)) + pure $ HugrFunc defNode funcTy + +evalBratMatch :: [(TestMatchData Brat, Name)] -> [Value] -> Eval [Value] +evalBratMatch ((TestMatchData _ (MatchSequence matchInputs tests matchOutputs), rhs) : rest) inputs = do + -- Add the inputs to the port map + evaled <- getEvaled Braty + putEvaled Braty $ evaled `M.union` M.fromList (zip (end . fst <$> matchInputs) inputs) + -- Run the tests. TODO: Use something like andM instead + result <- and <$> forM tests evalTest + case result of + True -> do + outputs <- forM matchOutputs (evalPort Braty . end . fst) + evalPort Braty (Ex rhs 0) >>= \case + ThunkV th -> evalBratCall th outputs + _ -> error "Internal error: Not a thunk" + False -> evalBratMatch rest inputs +evalBratMatch [] _ = error "No matching clause" + +evalTest :: (Src, PrimTest (BinderType Brat)) -> Eval Bool +evalTest (inputSrc, test) = do + input <- evalPort Braty (end inputSrc) + case test of + PrimLitTest term -> pure $ testLiteral term input + PrimCtorTest ctor ty _ outSrcs -> do + case testCtor ty ctor input of + Nothing -> pure False + Just outputs -> do + evaled <- getEvaled Braty + putEvaled Braty $ evaled `M.union` M.fromList (zip (end . fst <$> outSrcs) outputs) + pure True + +testLiteral :: SimpleTerm -> Value -> Bool +testLiteral (Num x) (IntV y) = x == y +testLiteral (Float x) (FloatV y) = x == y +testLiteral _ _ = error "Internal error: Unexpected literal test" + +testCtor :: QualName -> QualName -> Value -> Maybe [Value] +testCtor CBool CTrue (BoolV True) = Just [] +testCtor CBool CFalse (BoolV False) = Just [] +testCtor CNat CZero (IntV 0) = Just [] +testCtor CNat CSucc (IntV x) | x > 0 = Just [IntV (x - 1)] +testCtor CVec CNil (VecV []) = Just [] +testCtor CVec CCons (VecV (v:vs)) = Just [v, VecV vs] +testCtor _ _ _ = Nothing + +evalConstructor :: QualName -> [Value] -> Value +evalConstructor CTrue [] = BoolV True +evalConstructor CFalse [] = BoolV False +evalConstructor CZero [] = IntV 0 +evalConstructor CNil [] = VecV [] +evalConstructor _ _ = error "Internal error: Unhandled constructor" + +evalSimpleTerm :: SimpleTerm -> Value +evalSimpleTerm (Num x) = IntV x +evalSimpleTerm (Float x) = FloatV x +evalSimpleTerm _ = error "todo" + +evalArith :: ArithOp -> [Value] -> Value +evalArith op [IntV x, IntV y] = IntV $ case op of + Add -> x + y + Sub -> y - x -- What?? + Mul -> x * y + Div -> div x y + Pow -> x ^ y +evalArith op [FloatV x, FloatV y] = FloatV $ case op of + Add -> x + y + Sub -> x - y + Mul -> x * y + Div -> x / y + Pow -> x ** y +evalArith _ _ = error "Bad arith inputs" + +bratValueToHugr :: Value -> (HugrType, HugrValue) +bratValueToHugr (IntV x) = (hugrInt, hvInt x) +bratValueToHugr (FloatV x) = (hugrFloat, hvFloat x) +bratValueToHugr _ = error "todo" + +loadBratValue :: NodeId -> Value -> Compile TypedPort +loadBratValue parent v = do + let (ty, hugrValue) = bratValueToHugr v + const <- addNode "Const" (OpConst $ ConstOp parent hugrValue) + load <- addNode "LoadConst" (OpLoadConstant $ LoadConstantOp parent ty ) + addEdge (Port const 0, Port load 0) + pure (Port load 0, ty) + +-- buildKernelMatch :: [(TestMatchData Kernel, Name)] -> [Value] -> Eval [Value] +-- buildKernelMatch ((TestMatchData _ (MatchSequence matchInputs tests matchOutputs), rhs) : rest) inputs = do +-- _ + + + +evalMain :: Name -> [Value] -> Eval [Value] +evalMain main inputs = evalPort Braty (Ex main 0) >>= \case + ThunkV th -> case inputs of + [] -> error "Missing arguments to entry point" + inputs -> evalBratCall th inputs >>= \case + [KernelV k] -> pure . KernelV <$> kernelToHugrFunc k + vs -> pure vs + KernelV k -> case inputs of + [] -> pure . KernelV <$> kernelToHugrFunc k + _ -> error "Entry point is a kernel. Cannot supply arguments" + v -> pure [v] + +valuesOrHugr :: [Value] -> Compile (Either [Value] (Hugr Int)) +valuesOrHugr [KernelV _] = do + ns <- gets nodes + es <- gets edges + pure . Right $ renameAndSortHugr ns es +valuesOrHugr vs = pure (Left vs) + +run :: Store -> Namespace -> Graph -> Name -> [Value] -> Either [Value] (Hugr Int) +run store ns graph main inputs = + evalState (evalStateT (evalMain main inputs) emptyEvalEnv >>= valuesOrHugr) (emptyCS graph ns store) + diff --git a/brat/Brat/Compiler.hs b/brat/Brat/Compiler.hs index 3414bc85..03852789 100644 --- a/brat/Brat/Compiler.hs +++ b/brat/Brat/Compiler.hs @@ -3,6 +3,7 @@ module Brat.Compiler (printAST ,writeDot ,compileFile ,compileAndPrintFile + ,runFileAndPrintResults ,CompilingHoles(..) ) where @@ -19,6 +20,13 @@ import Control.Monad (when) import Control.Monad.Except import qualified Data.ByteString.Lazy as BS import System.Exit (die) +import Brat.Compile.Interpreter (run, Value) +import Data.Maybe (fromMaybe) +import Brat.QualName (QualName(..)) +import qualified Data.Map as M +import Brat.Syntax.Port (NamedPort(..), OutPort (..)) +import Data.Hugr +import Data.Aeson (encode) printDeclsHoles :: [FilePath] -> String -> IO () printDeclsHoles libDirs file = do @@ -84,3 +92,23 @@ compileAndPrintFile :: [FilePath] -> String -> IO () compileAndPrintFile libDirs file = compileFile libDirs file >>= \case Right bs -> BS.putStr bs Left err -> die (show err) + +runFile :: [FilePath] -> String -> Maybe String -> [Value] -> IO (Either CompilingHoles (Either [Value] (Hugr Int))) +runFile libDirs file function inputs = do + let (checkRoot, newRoot) = split "checking" root + env <- runExceptT $ loadFilename checkRoot libDirs file + (venv, _, holes, defs, outerGraph) <- eitherIO env + -- Lookup the node corresponding to entry point + let entry = case venv M.!? PrefixName [] (fromMaybe "main" function) of + Just [(NamedPort (Ex node _) _, _)] -> node + _ -> error "Couldn't find entry point" + case holes of + [] -> Right <$> evaluate -- turns 'error' into IO 'die' + (run defs newRoot outerGraph entry inputs) + hs -> pure $ Left (CompilingHoles hs) + +runFileAndPrintResults :: [FilePath] -> String -> Maybe String -> [Value] -> IO () +runFileAndPrintResults libDirs file function inputs = runFile libDirs file function inputs >>= \case + Right (Left vs) -> print vs + Right (Right hugr) -> BS.putStr (encode hugr) + Left err -> die (show err) diff --git a/brat/Brat/Graph.hs b/brat/Brat/Graph.hs index 50bad752..d0047192 100644 --- a/brat/Brat/Graph.hs +++ b/brat/Brat/Graph.hs @@ -115,6 +115,9 @@ toGraph (ns, ws) = G.graphFromEdges adj wiresFrom :: Name -> Graph -> [Wire] wiresFrom src (_, ws) = [ w | w@(Ex a _, _, _) <- ws, a == src ] +wiresTo :: Name -> Graph -> [Wire] +wiresTo tgt (_, ws) = [ w | w@(_, _, In a _) <- ws, a == tgt ] + lookupNode :: Name -> Graph -> Maybe Node lookupNode name (ns, _) = M.lookup name ns diff --git a/brat/app/Main.hs b/brat/app/Main.hs index bac393b7..55f3a0ba 100644 --- a/brat/app/Main.hs +++ b/brat/app/Main.hs @@ -2,11 +2,14 @@ import Brat.Compiler import Control.Monad (when) import Options.Applicative +import Brat.Compile.Interpreter (Value(IntV)) data Options = Opt { ast :: Bool, dot :: String, compile :: Bool, + run :: String, + runArgs :: [Int], file :: String, libs :: String, raw :: Bool @@ -15,6 +18,10 @@ data Options = Opt { compileFlag :: Parser Bool compileFlag = switch (long "compile" <> short 'c' <> help "Compile to TIERKREIS") +runOption = strOption (long "run" <> short 'r' <> value "" <> help "Run in interpreter") + +runArgsOptions = many $ option auto (long "args" <> help "Run in interpreter") + astFlag = switch (long "ast" <> help "Print desugared BRAT syntax tree") rawFlag = switch (long "raw" <> help "Print raw BRAT syntax tree") @@ -24,7 +31,7 @@ dotOption = strOption (long "dot" <> value "" <> help "Write graph in Dot format libOption = strOption (long "lib" <> value "" <> help "Look in extra directories for libraries (delimited with ;)") opts :: Parser Options -opts = Opt <$> astFlag <*> dotOption <*> compileFlag <*> strArgument (metavar "FILE") <*> libOption <*> rawFlag +opts = Opt <$> astFlag <*> dotOption <*> compileFlag <*> runOption <*> runArgsOptions <*> strArgument (metavar "FILE") <*> libOption <*> rawFlag -- Parse a list of library directories delimited by a semicolon parseLibs :: String -> [String] @@ -39,4 +46,5 @@ main = do when (ast || raw) $ printAST raw ast file let libDirs = parseLibs libs when (dot /= "") $ writeDot libDirs file dot - if compile then compileAndPrintFile libDirs file else printDeclsHoles libDirs file + if compile then compileAndPrintFile libDirs file else + if run /= "" then runFileAndPrintResults libDirs file (Just run) (IntV <$> runArgs) else printDeclsHoles libDirs file diff --git a/brat/brat.cabal b/brat/brat.cabal index 3873bcf1..fc9648ae 100644 --- a/brat/brat.cabal +++ b/brat/brat.cabal @@ -44,10 +44,10 @@ common warning-flags -Wno-unused-do-bind -Wno-missing-signatures -Wno-noncanonical-monoid-instances - -Werror=unused-imports + -- -Werror=unused-imports -Werror=unused-matches -Werror=missing-methods - -Werror=unused-top-binds + -- -Werror=unused-top-binds -Werror=unused-local-binds -Werror=redundant-constraints -Werror=orphans @@ -72,6 +72,7 @@ library Brat.Checker.SolvePatterns, Brat.Checker.Types, Brat.Compile.Hugr, + Brat.Compile.Interpreter, Brat.Constructors, Brat.Constructors.Patterns, Brat.Error, From 7194313d3ba4ee48a3875d1feff800d06821b16f Mon Sep 17 00:00:00 2001 From: Mark Koch Date: Thu, 4 Dec 2025 17:13:18 +0000 Subject: [PATCH 2/2] Add qubit-money-bit hasochism tracking --- brat/Brat/Constructors.hs | 21 +++++++++++++++- brat/Brat/Syntax/Value.hs | 50 +++++++++++++++++++++++++++++++++++++++ brat/Hasochism.hs | 31 ++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 1 deletion(-) diff --git a/brat/Brat/Constructors.hs b/brat/Brat/Constructors.hs index 16b196bc..d19301b4 100644 --- a/brat/Brat/Constructors.hs +++ b/brat/Brat/Constructors.hs @@ -1,13 +1,14 @@ module Brat.Constructors where import qualified Data.Map as M +import Data.Type.Equality ((:~:)(..)) import Brat.Constructors.Patterns import Brat.QualName (QualName, plain) import Brat.Syntax.Common import Brat.Syntax.Value import Bwd -import Hasochism (N(..), Ny(..)) +import Hasochism (N(..), Ny(..), Some(..), integer2Ny, mulL) -- TODO: Enforce the invariant that the number of pattern variables is n data CtorArgs m where @@ -169,3 +170,21 @@ natConstructors = M.fromList ,(plain "full", (Nothing, nFull)) ,(plain "zero", (Just NP0, id)) ] + + +kernTy :: Val Z -> Some KernTy +kernTy (VCon CQubit []) = Some KTQubit +kernTy (VCon CMoney []) = Some KTMoney +kernTy (VCon CBool []) = Some KTBit +kernTy (VCon CBit []) = Some KTBit +-- Vectors with constant length +kernTy (VCon CVec [elems, VNum (NumValue size Constant0)]) = case (kernTy elems, integer2Ny size) of + (Some elemsTy, Some sizey) -> case mul3 (ktN3y elemsTy) sizey of + Some qmb -> Some (KTVec elemsTy sizey qmb) +kernTy _ = error "kernTy: Not of kind $" + +kernRo :: Ro Kernel Z top -> (top :~: Z, Some KernTy) +kernRo R0 = (Refl, Some KTUnit) +kernRo (RPr (_, ty) ro) = case (kernTy ty, kernRo ro) of + (Some kty, (Refl, Some kro)) -> case add3 (ktN3y kty) (ktN3y kro) of + Some qmb -> (Refl, Some (KTPair kty kro qmb)) diff --git a/brat/Brat/Syntax/Value.hs b/brat/Brat/Syntax/Value.hs index 4bc71fcd..09d91e98 100644 --- a/brat/Brat/Syntax/Value.hs +++ b/brat/Brat/Syntax/Value.hs @@ -618,3 +618,53 @@ stkLen (zx :<< _) = Sy (stkLen zx) numValIsConstant :: NumVal (VVar Z) -> Maybe Integer numValIsConstant (NumValue up Constant0) = pure up numValIsConstant _ = Nothing + + +---------------------- Kernel Types ---------------------- + +data N3 = QMB { numQubits :: N, numMoney :: N, numBits :: N } + +data Add3 :: N3 -> N3 -> N3 -> Type where + Add3 :: AddL ql qr q -> AddL ml mr m -> AddL bl br b -> Add3 (QMB ql ml bl) (QMB qr mr br) (QMB q m b) + +data Mul3 :: N3 -> N -> N3 -> Type where + Mul3 :: MulL q n qn -> MulL m n mn -> MulL b n bn -> Mul3 (QMB q m b) n (QMB qn mn bn) + +data N3y :: N3 -> Type where + QMBy :: Ny q -> Ny m -> Ny b -> N3y (QMB q m b) + +add3 :: N3y l -> N3y r -> Some (Add3 l r) +add3 (QMBy ql ml bl) (QMBy qr mr br) = case (addL ql qr, addL ml mr, addL bl br) of + (Some q, Some m, Some b) -> Some (Add3 q m b) + +add3Tot :: Add3 l r t -> N3y t +add3Tot (Add3 aq am ab) = QMBy (addTot aq) (addTot am) (addTot ab) + +mul3 :: N3y x -> Ny n -> Some (Mul3 x n) +mul3 (QMBy q m b) n = case (mulL q n, mulL m n, mulL b n) of + (Some qn, Some mn, Some bn) -> Some (Mul3 qn mn bn) + +mul3Tot :: Mul3 x n t -> N3y t +mul3Tot (Mul3 q m b) = QMBy (mulTot q) (mulTot m) (mulTot b) + +type ZERO3 = QMB Z Z Z +type QUBIT3 = QMB (S Z) Z Z +type MONEY3 = QMB Z (S Z) Z +type BIT3 = QMB Z Z (S Z) + +data KernTy :: N3 -> Type where + KTQubit :: KernTy QUBIT3 + KTMoney :: KernTy MONEY3 + KTBit :: KernTy BIT3 + KTUnit :: KernTy ZERO3 + KTPair :: KernTy l -> KernTy r -> Add3 l r t -> KernTy t + KTVec :: KernTy x -> Ny n -> Mul3 x n t -> KernTy t + +ktN3y :: KernTy x -> N3y x +ktN3y KTQubit = QMBy (Sy Zy) Zy Zy +ktN3y KTMoney = QMBy Zy (Sy Zy) Zy +ktN3y KTBit = QMBy Zy Zy (Sy Zy) +ktN3y KTUnit = QMBy Zy Zy Zy +ktN3y (KTPair _ _ a) = add3Tot a +ktN3y (KTVec _ _ m) = mul3Tot m + diff --git a/brat/Hasochism.hs b/brat/Hasochism.hs index c0bb7d71..5221934c 100644 --- a/brat/Hasochism.hs +++ b/brat/Hasochism.hs @@ -13,6 +13,11 @@ ny2int :: Ny n -> Int ny2int Zy = 0 ny2int (Sy n) = 1 + ny2int n +integer2Ny :: Integer -> Some Ny +integer2Ny 0 = Some Zy +integer2Ny n | n > 0 = case integer2Ny (n - 1) of Some x -> Some (Sy x) +integer2Ny _ = error "integer2Ny: negative" + instance TestEquality Ny where testEquality Zy Zy = Just Refl testEquality (Sy n) (Sy m) | Just Refl <- testEquality n m = Just Refl @@ -31,3 +36,29 @@ newtype Flip (t :: a -> b -> Type) (y :: b) (x :: a) instance Show (t a b) => Show (Flip t b a) where show = show . getFlip + +-- Not to be confused with AddR in Value.hs where the arguments are flipped +data AddL :: N -> N -> N -> Type where + AddLZ :: Ny out -> AddL Z out out + AddLS :: AddL inn out tot -> AddL (S inn) out (S tot) + +data MulL :: N -> N -> N -> Type where + MulLZ :: Ny inn -> MulL inn Z Z + MulLS :: MulL inn mul prd -> AddL inn prd tot -> MulL inn (S mul) tot + +addL :: Ny l -> Ny r -> Some (AddL l r) +addL Zy out = Some (AddLZ out) +addL (Sy inn) out = case addL inn out of Some tot -> Some (AddLS tot) + +addTot :: AddL l r t -> Ny t +addTot (AddLZ out) = out +addTot (AddLS a) = Sy (addTot a) + +mulTot :: MulL l r t -> Ny t +mulTot (MulLZ _) = Zy +mulTot (MulLS _ a) = addTot a + +mulL :: Ny l -> Ny r -> Some (MulL l r) +mulL inn Zy = Some (MulLZ inn) +mulL inn (Sy mul) = case mulL inn mul of Some prd -> case addL inn (mulTot prd) of Some tot -> Some (MulLS prd tot) +