Skip to content

Commit

Permalink
feat: simple pretty-printer for KLR
Browse files Browse the repository at this point in the history
This change adds two related things: a pretty-printer for KLR terms,
and tensor names. Tensor names make the pretty printing nicer, but
have a second purpose. By naming all of the tensors, we can scan a KLR
kernel to collect up all of the input, output, and intermediate
tensors that will be needed to run the kernel. For argument tensors,
the generated tensor names are changed to the argument variable names;
this is just for readability.
  • Loading branch information
govereau committed Jan 21, 2025
1 parent 45006ab commit bcaf95d
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 7 deletions.
4 changes: 3 additions & 1 deletion NKL/FFI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import Lean
import NKL.KLR.Pretty
import NKL.Python
import NKL.Trace

namespace NKL
open NKL.KLR

local instance : MonadLift (Except String) IO where
monadLift
Expand All @@ -19,4 +21,4 @@ def parse_json (s : String) : IO Unit := do
let kernel <- Python.Parsing.parse s
let stmts <- NKL.Trace.runNKIKernel kernel
for s in stmts do
IO.println s!"{repr s}"
IO.println (" " ++ toString s) --s!"{s}\n{repr s}"
1 change: 1 addition & 0 deletions NKL/KLR/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace NKL.KLR
abbrev Dtype := String
abbrev Shape := List Int
structure Tensor where
name : String
dtype : Dtype
shape : Shape
deriving Repr, BEq
Expand Down
6 changes: 3 additions & 3 deletions NKL/KLR/Encode.lean
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ private def chkIndex (i : Index) : Bool :=

partial def encExpr : Expr -> ByteArray
| .var s => tag 0x30 [encString s]
| .tensor t => tag 0x31 [encString t.dtype, encList encInt t.shape]
| .tensor t => tag 0x31 [encString t.name, encString t.dtype, encList encInt t.shape]
| .const c => tag 0x32 [encConst c]
| .access e ix => tag 0x33 [encExpr e, encList encIndex ix]
| .call f ax kw => tag 0x34 [encExpr f, encList encExpr ax, encList encKeyword kw]
Expand All @@ -276,7 +276,7 @@ where
partial def decExpr : DecodeM Expr := do
match (<- next) with
| 0x30 => return .var (<- decString)
| 0x31 => return .tensor $ .mk (<- decString) (<- decList decInt)
| 0x31 => return .tensor $ .mk (<- decString) (<- decString) (<- decList decInt)
| 0x32 => return .const (<- decConst)
| 0x33 => return .access (<- decExpr) (<- decList decIndex)
| 0x34 => return .call (<- decExpr) (<- decList decExpr) (<- decList decKeyword)
Expand All @@ -293,7 +293,7 @@ private def ixz := Index.coord (IndexExpr.int 0)

#guard chkExpr nil
#guard chkExpr (.var "var")
#guard chkExpr (.tensor $ .mk "float32" [1,2,3])
#guard chkExpr (.tensor $ .mk "t" "float32" [1,2,3])
#guard chkExpr (.const (.int 1))
#guard chkExpr (.access nil [ixz, ixz, ixz])
#guard chkExpr (.call nil [nil, nil, nil] [("a", nil), ("b", nil)])
Expand Down
84 changes: 84 additions & 0 deletions NKL/KLR/Pretty.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR.Basic

namespace NKL.KLR
open Std

/-
This is a simple pretty printer for KLR terms. At some point, we may want to
make this output valid python syntax that would parse and elaborate to the same
KLR kernel. At the moment, there are too many unknowns to spend time on this.
The format here is just for ease of debugging, feel free to modify as you wish.
-/

private def abracket (f : Format) : Format :=
Format.bracket "<" f ">"

private def ppArgs [ToFormat a] (l : List a) : Format :=
Format.joinSep l ","

def ppTensor (t : Tensor) : Format :=
"%" ++ t.name ++ abracket (t.dtype ++ ":" ++ ppArgs t.shape)

def ppConst : Const -> Format
| .none => "None"
| .bool true => "True"
| .bool false => "False"
| .int i => format i
| .float f => format f
| .string s => "\"" ++ s.push '"'

private def addParens : Nat -> Format -> Format
| 0, f => f
| _, f => f.paren

def ppIndexExpr (n : Nat) : IndexExpr -> Format
| .var x => x
| .int i => format i
| .neg e => "-" ++ ppIndexExpr (n+1) e
| .add l r => addParens n $ ppIndexExpr 1 l ++ "+" ++ ppIndexExpr 1 r
| .mul i e => addParens n $ format i ++ "*" ++ ppIndexExpr 1 e
| .floor e i => addParens n $ ppIndexExpr 1 e ++ "/" ++ format i
| .ceil e i => "ceil" ++ Format.paren (ppIndexExpr 0 e ++","++ format i)
| .mod e i => addParens n $ ppIndexExpr 1 e ++ "%" ++ format i

def ppIndexExpr? : Option IndexExpr -> Format
| none => "None"
| some e => ppIndexExpr 0 e

def ppIndex : Index -> Format
| .ellipsis => "..."
| .coord e => ppIndexExpr? e
| .slice l u s => .joinSep ([l,u,s].map ppIndexExpr?) ":"

partial def ppExpr : Expr -> Format
| .var x => x
| .const c => ppConst c
| .tensor t => ppTensor t
| .access t ix => .fill (ppExpr t ++ .sbracket (.joinSep (ix.map ppIndex) ","))
| .call f args kwargs =>
let args := args.map ppExpr
let kwargs := kwargs.map fun (x,e) => x ++ "=" ++ ppExpr e
.fill (ppExpr f ++ .paren (ppArgs (args ++ kwargs)))

def ppStmt : Stmt -> Format
| .pass => "pass"
| .expr e => ppExpr e
| .ret e => "ret" ++ ppExpr e
| .assign x e => x ++ " = " ++ ppExpr e
| .loop _ _ _ _ _ => "<loop>"

instance : ToFormat Tensor where format := ppTensor
instance : ToFormat Const where format := ppConst
instance : ToFormat IndexExpr where format := ppIndexExpr 0
instance : ToFormat Index where format := ppIndex
instance : ToFormat Expr where format := ppExpr
instance : ToFormat Stmt where format := ppStmt

@[default_instance]
instance [ToFormat a]: ToString a where
toString s := Format.pretty (format s)
13 changes: 10 additions & 3 deletions NKL/Trace/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ partial def expr' : Expr' -> Tracer Item
| .const c => return .term (<- const c)
| .tensor s dty => do
let shape <- s.mapM integer
return .term (.expr (.tensor ⟨ dty, shape ⟩) (.tensor dty shape))
let name <- genName "t".toName
return .term (.expr (.tensor ⟨ name.toString, dty, shape ⟩) (.tensor dty shape))
| .name id _ => lookup_item id.toName
| .attr (.exprPos e p) id _ => do withPos p ((<- expr' e).attr id)
| .tuple l _ => return .term (.tuple (<- l.mapM term))
Expand Down Expand Up @@ -190,6 +191,7 @@ partial def stmt' : Stmt' -> Tracer Unit
partial def bind_args (f : Fun)
(args : List Term)
(kwargs : List (String × Term))
(rename : Bool := false)
: Tracer (List (String × Term)) := do
if f.args.vararg != none || f.args.kwarg != none then
throw "var args not supported"
Expand All @@ -208,16 +210,21 @@ partial def bind_args (f : Fun)
return (x, <- term' e)
else
throw s!"argument {x} not supplied"
-- rename tensors if asked to
let argmap := if rename then argmap.map renameTensors else argmap
return argmap
where
renameTensors : String × Term -> String × Term
| (s, .expr (.tensor t) ty) => (s, .expr (.tensor {t with name := s}) ty)
| other => other

-- For a function call, first evaluate the argument in the current environment.
-- Then enter a new environment and evaluate the function statements.
partial def function_call (f : Fun)
(args : List Term)
(kwargs : List (String × Term))
: Tracer Unit := do
let args <- bind_args f args kwargs
--let args <- args.mapM fun (x,e) => return (x, e)
let args <- bind_args f args kwargs true
withSrc f.source $ enterFun $ do
args.forM fun (x,e) => do extend x.toName e
f.body.forM stmt
Expand Down

0 comments on commit bcaf95d

Please sign in to comment.