Skip to content

Commit

Permalink
feat: tracing for python source functions
Browse files Browse the repository at this point in the history
This patch adds basic tracing for user python functions. The main code
is in Python.lean, and depends on definitions in Basic.lean and
NKI.lean, which are incomplete. As more primitives are implemented,
more user kernels will be supported.
  • Loading branch information
govereau committed Jan 21, 2025
1 parent 578dc59 commit 9ffeee1
Show file tree
Hide file tree
Showing 7 changed files with 432 additions and 27 deletions.
13 changes: 4 additions & 9 deletions NKL/FFI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Paul Govereau
-/
import Lean
import NKL.Python
import NKL.Trace

namespace NKL

Expand All @@ -16,12 +17,6 @@ local instance : MonadLift (Except String) IO where
@[export parse_json]
def parse_json (s : String) : IO Unit := do
let kernel <- Python.Parsing.parse s
let names := kernel.funcs.map fun x => x.fst
let names := String.intercalate "," names
IO.println s!"Found functions: {names}"
for x in kernel.args do
IO.println s!"arg: {repr x}"
for x in kernel.kwargs do
IO.println s!"arg: {repr x}"
for x in kernel.globals do
IO.println s!"global: {repr x}"
let stmts <- NKL.Trace.runNKIKernel kernel
for s in stmts do
IO.println s!"{repr s}"
24 changes: 10 additions & 14 deletions NKL/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ then the structure will be populated with:
defaults = [1, 2]
vararg = "args"
kwonlyargs = [d, e]
kw_defaults = [None, 3]
kw_defaults = [("e", 3)]
kwarg = "kwargs"
Note, defaults and kw_defaults are inconsistent in how they treat
missing arguments, but this is just how it works in the python AST.
Note, this is slightly different from the official Python AST, which
encodes the kw_defaults as a list with None for missing defaults.
-/
structure Args where
posonlyargs : List String
Expand All @@ -122,17 +122,13 @@ structure Args where
kwarg : Option String
deriving Repr

def Args.names (ax : Args) : List String :=
let xs := ax.posonlyargs.append ax.args
let xs := match ax.vararg with | none => xs | some x => xs.append [x]
let xs := xs.append ax.kwonlyargs
let xs := match ax.kwarg with | none => xs | some x => xs.append [x]
xs

def Args.all_defaults (ax : Args) : List (String × Expr') :=
let args := ax.posonlyargs ++ ax.args
let dflt := args.reverse.zip ax.defaults.reverse
dflt ++ ax.kw_defaults
def Args.names (args : Args) : List String :=
args.posonlyargs ++ args.args ++ args.kwonlyargs

def Args.all_defaults (args : Args) : List (String × Expr') :=
let pargs := args.posonlyargs ++ args.args
let dflt := pargs.reverse.zip args.defaults.reverse
dflt ++ args.kw_defaults

structure Fun where
source : String
Expand Down
13 changes: 12 additions & 1 deletion NKL/Trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR
import NKL.Python
import NKL.Trace.Types
import NKL.Trace.Basic
import NKL.Trace.Builtin
--import NKL.Trace.Python
import NKL.Trace.Python
import NKL.Trace.NKI

namespace NKL.Trace

def runNKIKernel (k : NKL.Python.Kernel) : Except String (List NKL.KLR.Stmt) :=
tracer ⟨ .ofList NKIEnv, #[] ⟩ do
traceKernel k
let g <- get
return g.body.toList
3 changes: 0 additions & 3 deletions NKL/Trace/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -118,11 +118,8 @@ def unOp : String -> Term -> TraceM Term
| op, _ => throw s!"unimp {op}"

-- Comparison operators
-- TODO: need to think about comparison of tensors, in NKI this is object equality,
-- but I suspect this doesn't make sense and may be a source of bugs.

def cmpOp : String -> Term -> Term -> TraceM Bool
| "Eq", .expr l _, .expr r _ => return l == r
| s, l, r => throw s!"unsupported comparison operator {s} {repr l} {repr r}"

def compare : Term -> List String -> List Term -> TraceM Term
Expand Down
49 changes: 49 additions & 0 deletions NKL/Trace/NKI.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR
import NKL.Trace.Types
import NKL.Trace.Builtin

/-
# NKI built-ins
This module defines the builtin constants used by tracing for NKI kernels.
-/
namespace NKL.Trace
open NKL.KLR

private def module (s : String) : Name × Item :=
let name := s.toName
(name, .module name)

private def const_var (s : String) : Name × Item :=
let name := s.toName
(name, .term (.expr (.var s) (.any name)))

/-
Note: this object contains a bunch of architecture parameters that
need to be set according to which HW we are compiling for.
TODO: figure out the mechanism for this.
-/
def tile_size : Global :=
let name := "nki.langauge.tile_size".toName
{ name := name
, attr := attrs
, call := uncallable name
}
where
attrs : GlobalAttr
| "pmax" => return .expr (.const $ .int 128) .int
| a => throw s!"unsupported attribute {a}"

def NKIEnv : List (Name × Item) :=
[ module "nki"
, module "nki.language"
, const_var "nki.language.add"
, const_var "nki.language.load"
, const_var "nki.language.store"
, ("nki.language.tile_size".toName, .global tile_size)
]
Loading

0 comments on commit 9ffeee1

Please sign in to comment.