From 3dcd5590547f26922a0258d5b4475065617e57b3 Mon Sep 17 00:00:00 2001 From: Paul Govereau Date: Mon, 30 Dec 2024 14:11:56 -0500 Subject: [PATCH] chore: port encoder to KLR The KLR language is the official intermediate format. This patch ports the serialization and deserialization code to this new data type. --- NKL/Encode.lean | 206 ++++++++++++++++++++++++++++++++++-------------- 1 file changed, 145 insertions(+), 61 deletions(-) diff --git a/NKL/Encode.lean b/NKL/Encode.lean index 49c5a6e..7fe38b5 100644 --- a/NKL/Encode.lean +++ b/NKL/Encode.lean @@ -3,14 +3,14 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. Released under Apache 2.0 license as described in the file LICENSE. Authors: Paul Govereau -/ -import NKL.NKI +import NKL.KLR /-! # Serialization and Deserialization -/ -namespace NKL +namespace NKL.KLR -- All of the encode function are pure; decoding uses an instance of EStateM. @@ -25,10 +25,10 @@ def decode (f : DecodeM a) (ba : ByteArray) : Except String a := | .error s _ => .error s def decodeFile (f : DecodeM a) (path : System.FilePath) : IO a := do - let buf <- IO.FS.readBinFile path - match decode f buf with - | .ok x => return x - | .error s => throw $ IO.userError s + let buf <- IO.FS.readBinFile path + match decode f buf with + | .ok x => return x + | .error s => throw $ IO.userError s private def next : DecodeM UInt8 := do let it <- get @@ -132,7 +132,6 @@ private def decString : DecodeM String := do ------------------------------------------------------------------------------ -- Lists are encoded as a length followed by a sequence of encoded values --- TODO: not efficient private def encList (f : a -> ByteArray) (l : List a) : ByteArray := let rec mapa : List a -> ByteArray @@ -151,117 +150,202 @@ private def decList (f : DecodeM a) : DecodeM (List a) := do #guard decode' (decList decInt) (encList encInt [1,2,3]) = some [1,2,3] ------------------------------------------------------------------------------ --- Finally, constants are encoded with a tag followed by the values +-- Options are encoded using a tag followed by the encoded value private def tag (t : UInt8) : List ByteArray -> ByteArray := List.foldl ByteArray.append (.mk #[t]) +private def encOption (f : a -> ByteArray) : Option a -> ByteArray + | .none => tag 0 [] + | .some x => tag 1 [f x] + +private def decOption (f : DecodeM a) : DecodeM (Option a) := do + match (<- next) with + | 0 => return .none + | 1 => f + | t => throw s!"invalid option tag {t}" + +#guard decode' (decOption decInt) (encOption encInt none) = some none +#guard decode' (decOption decInt) (encOption encInt $ some 1) = some (some 1) + +------------------------------------------------------------------------------ +-- Constants are encoded with a tag followed by the values + def encConst : Const -> ByteArray - | .nil => tag 0x00 [] + | .none => tag 0x00 [] | .bool false => tag 0x01 [] | .bool true => tag 0x02 [] | .int i => tag 0x03 [encInt i] | .float f => tag 0x04 [encFloat f] | .string s => tag 0x05 [encString s] - | .dots => tag 0x06 [] def decConst : DecodeM Const := do let val <- next match val with - | 0x00 => return .nil + | 0x00 => return .none | 0x01 => return .bool false | 0x02 => return .bool true | 0x03 => return .int (<- decInt) | 0x04 => return .float (<- decFloat) | 0x05 => return .string (<- decString) - | 0x06 => return .dots | _ => throw s!"Unknown Const tag value {val}" private def chkConst (c: Const) : Bool := (decode' decConst $ encConst c) == some c -#guard chkConst .nil +#guard chkConst .none #guard chkConst (.bool true) #guard chkConst (.bool false) #guard chkConst (.int 1) #guard chkConst (.float 1.0) #guard chkConst (.string "str") -#guard chkConst .dots + +------------------------------------------------------------------------------ +-- Affine Expressions + +def encIndexExpr : IndexExpr -> ByteArray + | .var name => tag 0x10 [encString name] + | .int i => tag 0x11 [encInt i] + | .neg e => tag 0x12 [encIndexExpr e] + | .add l r => tag 0x13 [encIndexExpr l, encIndexExpr r] + | .mul i e => tag 0x14 [encInt i, encIndexExpr e] + | .floor e i => tag 0x15 [encIndexExpr e, encInt i] + | .ceil e i => tag 0x16 [encIndexExpr e, encInt i] + | .mod e i => tag 0x17 [encIndexExpr e, encInt i] + +partial def decIndexExpr : DecodeM IndexExpr := do + match (<- next) with + | 0x10 => return .var (<- decString) + | 0x11 => return .int (<- decInt) + | 0x12 => return .neg (<- decIndexExpr) + | 0x13 => return .add (<- decIndexExpr) (<- decIndexExpr) + | 0x14 => return .mul (<- decInt) (<- decIndexExpr) + | 0x15 => return .floor (<- decIndexExpr) (<- decInt) + | 0x16 => return .ceil (<- decIndexExpr) (<- decInt) + | 0x17 => return .mod (<- decIndexExpr) (<- decInt) + | t => throw s!"Unknown tag in IndexExpr {t}" + +private def chkIE (e: IndexExpr) : Bool := + (decode' decIndexExpr $ encIndexExpr e) == some e + +private def ie_var : IndexExpr := .var "s" + +#guard chkIE (.var "v") +#guard chkIE (.int 1) +#guard chkIE (.neg ie_var) +#guard chkIE (.add ie_var ie_var) +#guard chkIE (.mul 2 ie_var) +#guard chkIE (.floor ie_var 2) +#guard chkIE (.ceil ie_var 2) +#guard chkIE (.mod ie_var 2) + +def encIndex : Index -> ByteArray + | .ellipsis => tag 0x20 [] + | .coord e => tag 0x21 [enc e] + | .range l u s => tag 0x22 [enc l, enc u, enc s] +where + enc := encOption encIndexExpr + +def decIndex : DecodeM Index := do + match (<- next) with + | 0x20 => return .ellipsis + | 0x21 => return .coord (<- dec) + | 0x22 => return .range (<- dec) (<- dec) (<- dec) + | t => throw s!"Unknown tag in Index {t}" +where + dec:= decOption decIndexExpr + +private def chkIndex (i : Index) : Bool := + (decode' decIndex $ encIndex i) == some i + +#guard chkIndex .ellipsis +#guard chkIndex (.coord none) +#guard chkIndex (.coord $ some ie_var) +#guard chkIndex (.range (some ie_var) none none) ------------------------------------------------------------------------------ -- Expressions partial def encExpr : Expr -> ByteArray - | .value c => tag 0x10 [encConst c] - | .bvar s => tag 0x11 [encString s] - | .var s _ => tag 0x12 [encString s] - | .subscript e ix => tag 0x13 [encExpr e, encList encExpr ix] - | .slice l u step => tag 0x14 [encExpr l, encExpr u, encExpr step] - | .binop op l r => tag 0x15 [encString op, encExpr l, encExpr r] - | .cond c t e => tag 0x16 [encExpr c, encExpr t, encExpr e] - | .tuple es => tag 0x17 [encList encExpr es] - | .list es => tag 0x18 [encList encExpr es] - | .call f ax => tag 0x19 [encExpr f, encList encExpr ax] - | .gridcall f ix ax => tag 0x1a [encExpr f, encList encExpr ix, encList encExpr ax] + | .var s => tag 0x30 [encString s] + | .tensor t s => tag 0x31 [encString t, encList encInt s] + | .const c => tag 0x32 [encConst c] + | .tuple es => tag 0x33 [encList encExpr es] + | .list es => tag 0x34 [encList encExpr es] + | .access e ix => tag 0x35 [encExpr e, encList encIndex ix] + | .binop op l r => tag 0x36 [encString op, encExpr l, encExpr r] + | .unop op e => tag 0x37 [encString op, encExpr e] + | .call f ax kw => tag 0x38 [encExpr f, encList encExpr ax, encList encKeyword kw] +where + encKeyword : String × Expr -> ByteArray + | (key, expr) => (encString key).append (encExpr expr) partial def decExpr : DecodeM Expr := do match (<- next) with - | 0x10 => return .value (<- decConst) - | 0x11 => return .bvar (<- decString) - | 0x12 => return .var (<- decString) "" - | 0x13 => return .subscript (<- decExpr) (<- decList decExpr) - | 0x14 => return .slice (<- decExpr) (<- decExpr) (<- decExpr) - | 0x15 => return .binop (<- decString) (<- decExpr) (<- decExpr) - | 0x16 => return .cond (<- decExpr) (<- decExpr) (<- decExpr) - | 0x17 => return .tuple (<- decList decExpr) - | 0x18 => return .list (<- decList decExpr) - | 0x19 => return .call (<- decExpr) (<- decList decExpr) - | 0x1a => return .gridcall (<- decExpr) (<- decList decExpr) (<- decList decExpr) + | 0x30 => return .var (<- decString) + | 0x31 => return .tensor (<- decString) (<- decList decInt) + | 0x32 => return .const (<- decConst) + | 0x33 => return .tuple (<- decList decExpr) + | 0x34 => return .list (<- decList decExpr) + | 0x35 => return .access (<- decExpr) (<- decList decIndex) + | 0x36 => return .binop (<- decString) (<- decExpr) (<- decExpr) + | 0x37 => return .unop (<- decString) (<- decExpr) + | 0x38 => return .call (<- decExpr) (<- decList decExpr) (<- decList decKeyword) | t => throw s!"Unknown tag in Expr {t}" +where + decKeyword : DecodeM (String × Expr) := + return ((<- decString), (<- decExpr)) private def chkExpr (e : Expr) : Bool := (decode' decExpr $ encExpr e) == some e -private def nil := Expr.value .nil +private def nil := Expr.const .none +private def ixz := Index.coord (IndexExpr.int 0) #guard chkExpr nil -#guard chkExpr (.bvar "var") -#guard chkExpr (.var "var" "") -#guard chkExpr (.subscript nil [nil, nil, nil]) -#guard chkExpr (.slice nil nil nil) -#guard chkExpr (.binop "op" nil nil) -#guard chkExpr (.cond nil nil nil) +#guard chkExpr (.var "var") +#guard chkExpr (.tensor "float32" [1,2,3]) +#guard chkExpr (.const (.int 1)) #guard chkExpr (.tuple [nil, nil, nil]) #guard chkExpr (.list [nil, nil, nil]) -#guard chkExpr (.call nil [nil, nil, nil]) -#guard chkExpr (.gridcall nil [nil, nil, nil] [nil, nil, nil]) +#guard chkExpr (.access nil [ixz, ixz, ixz]) +#guard chkExpr (.binop "op" nil nil) +#guard chkExpr (.unop "op" nil) +#guard chkExpr (.call nil [nil, nil, nil] [("a", nil), ("b", nil)]) ------------------------------------------------------------------------------ -- Statements partial def encStmt : Stmt -> ByteArray - | .ret e => tag 0x30 [encExpr e] - | .assign x e => tag 0x31 [encExpr x, encExpr e] - | .ifstm c t e => tag 0x32 [encExpr c, encList encStmt t, encList encStmt e] - | .forloop x e b => tag 0x33 [encString x, encExpr e, encList encStmt b] - | .check e => tag 0x34 [encExpr e] + | .pass => tag 0x40 [] + | .expr e => tag 0x41 [encExpr e] + | .ret e => tag 0x42 [encExpr e] + | .assign x e => tag 0x43 [encString x, encExpr e] + | .loop x l u step body => + tag 0x44 [ encString x, + encIndexExpr l, encIndexExpr u, encIndexExpr step, + encList encStmt body ] partial def decStmt : DecodeM Stmt := do match (<- next) with - | 0x30 => return .ret (<- decExpr) - | 0x31 => return .assign (<- decExpr) (<- decExpr) - | 0x32 => return .ifstm (<- decExpr) (<- decList decStmt) (<- decList decStmt) - | 0x33 => return .forloop (<- decString) (<- decExpr) (<- decList decStmt) - | 0x34 => return .check (<- decExpr) + | 0x40 => return .pass + | 0x41 => return .expr (<- decExpr) + | 0x42 => return .ret (<- decExpr) + | 0x43 => return .assign (<- decString) (<- decExpr) + | 0x44 => do + let x <- decString + let l <- decIndexExpr + let u <- decIndexExpr + let step <- decIndexExpr + let body <- decList decStmt + return .loop x l u step body | t => throw s!"Unknown tag in Stmt {t}" private def chkStmt (s : Stmt) : Bool := (decode' decStmt $ encStmt s) == some s -private def stm := Stmt.check nil - +#guard chkStmt .pass +#guard chkStmt (.expr nil) #guard chkStmt (.ret nil) -#guard chkStmt (.assign nil nil) -#guard chkStmt (.ifstm nil [stm, stm, stm] [stm, stm, stm]) -#guard chkStmt (.forloop "x" nil [stm, stm, stm]) -#guard chkStmt (.check nil) +#guard chkStmt (.assign "x" nil) +#guard chkStmt (.loop "x" ie_var ie_var ie_var [.pass, .pass])