feat: basic parser for NKI kernels
Adds a basic parser for NKI kernels, along with test cases taken from
the public NKI documentation. The parser builds on top of
the standard parser in Python's ast library.
govereau committed Oct 29, 2024
1 parent b004bea commit f793a6f
Showing 28 changed files with 4,574 additions and 29 deletions.
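
The interop parser itself is among the files not expanded on this page. As a rough, hypothetical sketch of the approach the commit message describes (building on Python's ast library), a frontend might locate a kernel function and reduce it to a name/args/body record before serializing it for the Lean side:

```python
import ast

def kernel_to_dict(source: str, name: str) -> dict:
    """Hypothetical sketch, not the interop code added by this commit:
    find a function definition by name and emit a dict shaped roughly
    like NKL's Fun structure (name, args, body)."""
    tree = ast.parse(source)
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef) and node.name == name:
            return {
                "name": node.name,
                "args": [a.arg for a in node.args.args],
                # Translating each statement into the NKL Stmt constructors
                # is the real work of the parser and is elided here.
                "body": [ast.dump(stmt) for stmt in node.body],
            }
    raise ValueError(f"function {name!r} not found")
```

The real encoding has to match whatever Lean's derived FromJson instance for Fun expects, which is not shown on this page.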
1 change: 1 addition & 0 deletions .gitignore
@@ -1 +1,2 @@
.lake/**
__pycache__/
2 changes: 0 additions & 2 deletions Export.lean
@@ -81,10 +81,8 @@ run_meta
h.putStr header
flip List.forM (genPython h)
[ `NKL.Const
, `NKL.BinOp
, `NKL.Expr
, `NKL.Index
, `NKL.Stmt
, `NKL.Arg
, `NKL.Fun
]
10 changes: 8 additions & 2 deletions Main.lean
@@ -1,2 +1,8 @@
def main : IO Unit :=
IO.println s!"Hello, NKL!"
import NKL

def main (args : List String) : IO Unit :=
match args with
| .nil => IO.println s!"Hello, NKL!"
| .cons x _ => do
let s <- IO.FS.readFile x
NKL.parse_json s
3 changes: 2 additions & 1 deletion NKL/FFI.lean
@@ -5,6 +5,7 @@ Authors: Paul Govereau
-/
import Lean
import NKL.NKI
import NKL.PrettyPrint

namespace NKL

@@ -17,4 +18,4 @@ def parse_json (json : String) : IO Unit := do
| .ok jsn => do
match Lean.fromJson? jsn with
| .error str => throw $ .userError str
| .ok (_:Fun) => return ()
| .ok (f:Fun) => print_nki f
42 changes: 21 additions & 21 deletions NKL/NKI.lean
@@ -6,9 +6,9 @@ Authors: Paul Govereau
import Lean

/-!
# Concrete Syntax of NKI kernels
# Syntax of NKI kernels
Representation of the "concrete" syntax of NKI kernels
Representation of the abstract syntax of NKI kernels
generated by the python frontend.
-/

@@ -22,43 +22,43 @@ inductive Const where
| string (value: String)
deriving Repr, BEq, Lean.ToJson, Lean.FromJson

inductive BinOp where
| And | Or
| Eq | NotEq | Lt | LtE | Gt | GtE
| Add | Sub | Mul | Div
deriving Repr, BEq, Lean.ToJson, Lean.FromJson

mutual
inductive Expr where
| value (c: Const)
| bvar (name: String)
| var (name: String)
| subscript (tensor: String) (ix: Array Index)
| binop (op: BinOp) (left right: Expr)
| call (f: String) (args: Array Expr)
| var (name value: String)
| subscript (tensor: Expr) (ix: List Index)
| binop (op: String) (left right: Expr)
| cond (e thn els: Expr)
| tuple (xs: List Expr)
| list (xs: List Expr)
| call (f: Expr) (args: List Expr)
| gridcall (f: Expr) (ix: List Index) (args: List Expr)
deriving Repr, BEq, Lean.ToJson, Lean.FromJson

inductive Index where
| coord (i : Expr)
| slice (l u step: Expr)
| dots
deriving Repr, BEq, Lean.ToJson, Lean.FromJson
end

inductive Stmt where
| ret(e: Expr)
| assign (x: String) (e: Expr)
| ret (e: Expr)
| assign (x: Expr) (e: Expr)
| ifstm (e : Expr) (thn els: List Stmt)
| forloop (x: String) (iter: Expr) (body: List Stmt)
| gridcall (f: String) (ix: Array Index) (args: Array Expr)
| check (e : Expr)
deriving Repr, BEq, Lean.ToJson, Lean.FromJson

structure Arg where
name : String
type : Option String := .none
value : Option Const := .none
deriving Repr, BEq, Lean.ToJson, Lean.FromJson
--structure Arg where
-- name : String
-- type : Option String := .none
-- value : Option Const := .none
-- deriving Repr, BEq, Lean.ToJson, Lean.FromJson

structure Fun where
name : String
args : Array Arg
args : List String
body : List Stmt
deriving Repr, BEq, Lean.ToJson, Lean.FromJson
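
To make the constructors above concrete, here is a hypothetical, structure-only illustration (written as Python tuples, not the JSON wire format consumed by NKL.parse_json) of how a statement like out[i] = add(a, b) could decompose into assign, subscript, call, coord, and var nodes; the role of var's second String field is not documented in this diff, so it is simply duplicated here:

```python
# Hypothetical, structure-only sketch; NOT the JSON encoding expected by Lean.
# Constructors are written as ("ctor", args...); source: out[i] = add(a, b)
lhs = ("subscript", ("var", "out", "out"), [("coord", ("var", "i", "i"))])
rhs = ("call", ("var", "add", "add"), [("var", "a", "a"), ("var", "b", "b")])
stmt = ("assign", lhs, rhs)
```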
68 changes: 68 additions & 0 deletions NKL/PrettyPrint.lean
@@ -0,0 +1,68 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/

import NKL.NKI

namespace NKL

instance : ToString Const where
toString
| .nil => "None"
| .bool b => toString b
| .int i => toString i
| .float f => toString f
| .string s => s


mutual
private partial def exps_ s l := String.intercalate s (List.map expr l)
private partial def exps := exps_ ","
private partial def ndxs l := String.intercalate "," (List.map ndx l)

private partial def expr : Expr -> String
| .value c => toString c
| .bvar s | .var s _ => s
| .subscript e ix => expr e ++ "[" ++ ndxs ix ++ "]"
| .binop op l r => op ++ "(" ++ expr l ++ "," ++ expr r ++ ")"
| .cond e thn els => expr thn ++ " if " ++ expr e ++ " else " ++ expr els
| .tuple es => "(" ++ exps es ++ ")"
| .list es => "[" ++ exps es ++ "]"
| .call f es => expr f ++ "(" ++ exps es ++ ")"
| .gridcall f ix es => expr f ++ "[" ++ ndxs ix ++ "](" ++ exps es ++ ")"

private partial def ndx : Index -> String
| .coord e => expr e
| .slice l u s => exps_ ":" [l,u,s]
| .dots => "..."
end

instance : ToString Expr where
toString := expr

instance : ToString Index where
toString := ndx

mutual
private partial def stmts sp l :=
String.intercalate "\n" $ List.map (stmt sp) l

private partial def stmt (sp : String) (stmt : Stmt) : String :=
let stmts := stmts (sp ++ " ")
sp ++ match stmt with
| .ret e => s!"ret {e}"
| .assign x e => s!"{x} = {e}"
| .ifstm e thn els => s!"if ({e}):\n{stmts thn}\n{sp}else:\n{stmts els}"
| .forloop x e b => s!"for {x} in {expr e}:\n{stmts b}"
| .check e => "assert(" ++ expr e ++ ")"
end

instance : ToString Stmt where
toString := stmt ""

def print_nki (f : Fun) : IO Unit := do
IO.println $ f.name ++"("++ String.intercalate "," f.args ++")"
IO.println $ stmts " " f.body

15 changes: 15 additions & 0 deletions interop/examples/__init__.py
@@ -0,0 +1,15 @@
__all__ = [
'getting_started',
'layout',
'index',
'mm',
'prof',
'average_pool',
'fused_mamba',
'layernorm',
'matmul',
'rmsnorm',
'sd_attention',
'tensor_addition',
'transpose2d',
]
66 changes: 66 additions & 0 deletions interop/examples/average_pool.py
@@ -0,0 +1,66 @@
"""
Copyright (C) 2024, Amazon.com. All Rights Reserved
NKI implementation for average pool 2D NKI tutorial.
"""
import numpy as np
import nki
import nki.language as nl

def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size):
"""NKI kernel to compute a 2D avg-pool operation
Args:
in_tensor: an input tensor, of shape C x H x W
pool_size: an integer representing a (square) pool-window size
out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size)
"""

# Get input/output dimensions
sz_cin, sz_hin, sz_win = in_tensor.shape
sz_cout, sz_hout, sz_wout = out_tensor.shape
assert sz_cin == sz_cout

# Set relevant sizes
sz_p = sz_cin
sz_pool = pool_size

# Generate tensor h/w index patterns
# 3D indexing according to [C, H, W]
i_p = nl.arange(sz_p)[:, None, None] # 3D for
i_win = nl.arange(sz_win)[None, None, :]
i_hin = nl.arange(sz_hin)[None, :, None]

i_wout = nl.arange(sz_wout)[None, None, :]
i_hout = nl.arange(sz_hout)[None, :, None]

# Generate pool index patterns (requires two extra dimensions, for the pool window)
i_0 = nl.arange(sz_p)[:, None, None, None, None] #
i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None] # y_outer
i_2 = nl.arange(sz_pool)[None, None, :, None, None] # y_inner
i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None] # x_outer
i_4 = nl.arange(sz_pool)[None, None, None, None, :] # x_inner

# Load input data from external memory to on-chip memory
# Declare ndarray to force a 3D tensor (temporary requirement)
in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype)
in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win])

# Perform the pooling operation:
# We use numpy's advanced indexing to extend in_tile to 5D, and then reduce-average over two dimensions.
# axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation.
# axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
# (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
# Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
out_tile = nl.sum(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) / (pool_size*pool_size)

# Store the results back to external memory
nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile)


# Reference NumPy implementation
def np_average_pool_2D(in_tensor, pool_size):
c, h_in, w_in = in_tensor.shape
reshaped = in_tensor.reshape(c, h_in // pool_size, pool_size, w_in // pool_size, pool_size)
return np.nanmean(reshaped, axis=(2, 4))
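
A quick sanity check of the NumPy reference (hypothetical, not part of this commit): reshaping (C, H, W) into (C, H/pool, pool, W/pool, pool) and averaging over the two inner axes matches a direct window average.

```python
import numpy as np

# Hypothetical check for np_average_pool_2D above; not in this commit.
x = np.arange(2 * 4 * 6, dtype=np.float32).reshape(2, 4, 6)  # C=2, H=4, W=6
pooled = np_average_pool_2D(x, pool_size=2)

assert pooled.shape == (2, 2, 3)
# Each output element is the mean of its 2x2 input window.
assert np.isclose(pooled[0, 0, 0], x[0, 0:2, 0:2].mean())
assert np.isclose(pooled[1, 1, 2], x[1, 2:4, 4:6].mean())
```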