-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds a basic parser for NKI kernels, along with test cases taken from the public NKI documentation. The parser is built on top of the standard parser provided by Python's `ast` library.
- Loading branch information
Showing
28 changed files
with
4,574 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
.lake/** | ||
__pycache__/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,8 @@ | ||
/-- Program entry point taking no arguments: prints the fixed greeting `Hello, NKL!`. -/
def main : IO Unit :=
  IO.println s!"Hello, NKL!"
import NKL | ||
|
||
/--
Program entry point.

With no command-line arguments, prints a greeting. Otherwise reads the file
named by the first argument and passes its contents to `NKL.parse_json`;
any further arguments are ignored.
-/
def main (args : List String) : IO Unit :=
  match args with
  | [] => IO.println s!"Hello, NKL!"
  | path :: _ => do
    let contents ← IO.FS.readFile path
    NKL.parse_json contents
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/- | ||
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: Paul Govereau | ||
-/ | ||
|
||
import NKL.NKI | ||
|
||
namespace NKL | ||
|
||
/--
Textual rendering of a constant: `nil` prints as `None`, booleans, integers and
floats use their own `ToString` instances, and strings are emitted as-is
(without surrounding quotes).
-/
instance : ToString Const where
  toString c := match c with
    | .nil => "None"
    | .bool b => s!"{b}"
    | .int i => s!"{i}"
    | .float f => s!"{f}"
    | .string s => s
|
||
|
||
mutual
/-- Render every expression in `es` and join the results with `sep`. -/
private partial def exps_ (sep : String) (es : List Expr) : String :=
  sep.intercalate (es.map expr)

/-- Render a comma-separated list of expressions. -/
private partial def exps (es : List Expr) : String :=
  exps_ "," es

/-- Render a comma-separated list of indices. -/
private partial def ndxs (ixs : List Index) : String :=
  ",".intercalate (ixs.map ndx)

/-- Render a single expression as text. -/
private partial def expr : Expr -> String
  | .value c => toString c
  | .bvar name | .var name _ => name
  | .subscript e ix => s!"{expr e}[{ndxs ix}]"
  -- Binary operators are printed in prefix form: `op(l,r)`.
  | .binop op l r => s!"{op}({expr l},{expr r})"
  | .cond e thn els => s!"{expr thn} if {expr e} else {expr els}"
  | .tuple es => s!"({exps es})"
  | .list es => s!"[{exps es}]"
  | .call f es => s!"{expr f}({exps es})"
  | .gridcall f ix es => s!"{expr f}[{ndxs ix}]({exps es})"

/-- Render a single index: a coordinate, a `l:u:s` slice, or `...`. -/
private partial def ndx : Index -> String
  | .coord e => expr e
  | .slice l u s => exps_ ":" [l, u, s]
  | .dots => "..."
end
|
||
/-- Expressions are rendered with the private printer `expr`. -/
instance : ToString Expr := ⟨expr⟩
|
||
/-- Indices are rendered with the private printer `ndx`. -/
instance : ToString Index := ⟨ndx⟩
|
||
mutual
/-- Render each statement in `l` with indentation prefix `sp`, one statement per line. -/
private partial def stmts sp l :=
  String.intercalate "\n" $ List.map (stmt sp) l

/--
Render a single statement, prefixed with the indentation string `sp`.

Nested bodies (the branches of an `if` and the body of a `for`) are rendered
one space deeper than `sp`.
-/
private partial def stmt (sp : String) (stmt : Stmt) : String :=
  -- Shadow `stmts` with a version already applied to the nested indentation.
  let stmts := stmts (sp ++ " ")
  sp ++ match stmt with
  | .ret e => s!"ret {e}"
  | .assign x e => s!"{x} = {e}"
  -- `else:` must start on its own line at the indentation of the `if`;
  -- a newline separates it from the last statement of the then-branch.
  | .ifstm e thn els => s!"if ({e}):\n{stmts thn}\n{sp}else:\n{stmts els}"
  | .forloop x e b => s!"for {x} in {expr e}:\n{stmts b}"
  | .check e => "assert(" ++ expr e ++ ")"
end
|
||
/-- Statements are rendered with the private printer `stmt`, starting with no indentation. -/
instance : ToString Stmt := ⟨stmt ""⟩
|
||
/--
Print a function to standard output: first a header line `name(arg1,arg2,...)`,
then the body statements, each indented by one space.
-/
def print_nki (f : Fun) : IO Unit := do
  let params := ",".intercalate f.args
  IO.println s!"{f.name}({params})"
  IO.println (stmts " " f.body)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Names exported by a star-import (`from <package> import *`) of this package.
__all__ = [
    "getting_started",
    "layout",
    "index",
    "mm",
    "prof",
    "average_pool",
    "fused_mamba",
    "layernorm",
    "matmul",
    "rmsnorm",
    "sd_attention",
    "tensor_addition",
    "transpose2d",
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
""" | ||
Copyright (C) 2024, Amazon.com. All Rights Reserved | ||
NKI implementation for average pool 2D NKI tutorial. | ||
""" | ||
import numpy as np | ||
import nki | ||
import nki.language as nl | ||
|
||
def tensor_avgpool_kernel_(in_tensor, out_tensor, pool_size):
  """NKI kernel to compute a 2D average-pool operation.

  The result is written into ``out_tensor`` via ``nl.store``; nothing is returned.

  Args:
    in_tensor: an input tensor, of shape C x H x W
    out_tensor: the resulting output tensor, of shape C x (H/pool_size) x (W/pool_size)
    pool_size: an integer representing a (square) pool-window size

  NOTE(review): only the channel counts of the two tensors are asserted to
  match. The index patterns use H // pool_size and W // pool_size, so H and W
  are presumably expected to be multiples of pool_size, and out_tensor's
  H/W extents are presumably expected to equal those quotients -- confirm
  against callers.
  """

  # Get input/output dimensions
  sz_cin, sz_hin, sz_win = in_tensor.shape
  sz_cout, sz_hout, sz_wout = out_tensor.shape
  assert sz_cin == sz_cout

  # Set relevant sizes
  sz_p = sz_cin
  sz_pool = pool_size

  # Generate tensor h/w index patterns
  # 3D indexing according to [C, H, W]
  i_p = nl.arange(sz_p)[:, None, None] # channel index, broadcast to 3D
  i_win = nl.arange(sz_win)[None, None, :]
  i_hin = nl.arange(sz_hin)[None, :, None]

  i_wout = nl.arange(sz_wout)[None, None, :]
  i_hout = nl.arange(sz_hout)[None, :, None]

  # Generate pool index patterns (requires two extra dimensions, for the pool window)
  i_0 = nl.arange(sz_p)[:, None, None, None, None] # channel (p_dim)
  i_1 = nl.arange(sz_hin//sz_pool)[None, :, None, None, None] # y_outer
  i_2 = nl.arange(sz_pool)[None, None, :, None, None] # y_inner
  i_3 = nl.arange(sz_win//sz_pool)[None, None, None, :, None] # x_outer
  i_4 = nl.arange(sz_pool)[None, None, None, None, :] # x_inner

  # Load input data from external memory to on-chip memory
  # Declare ndarray to force a 3D tensor (temporary requirement)
  in_tile = nl.ndarray([sz_p, sz_hin, sz_win], dtype=in_tensor.dtype)
  in_tile[:,:,:] = nl.load(in_tensor[i_p, i_hin, i_win])

  # Perform the pooling operation:
  # We use numpy-style advanced indexing to extend in_tile to 5D, and then reduce-average two dimensions.
  # axis[0] is the index for p_dim, and thus doesn't participate in the reduction operation.
  # axis[1] and axis[2] together index the rows, with axis[2] responsible for inner strides
  # (i.e. inside a pooling window), and axis[1] responsible for the outer strides. As such, we reduce over axis[2].
  # Similarly, axis[3] and axis[4] together index the columns, and we thus reduce over axis[4].
  # The average is the window sum divided by the window area (pool_size * pool_size).
  out_tile = nl.sum(in_tile[i_0, sz_pool*i_1+i_2, sz_pool*i_3+i_4], axis=[2,4]) / (pool_size*pool_size)

  # Store the results back to external memory
  nl.store(out_tensor[i_p, i_hout, i_wout], value=out_tile)
|
||
|
||
# Reference NumPy implementation | ||
def np_average_pool_2D(in_tensor, pool_size):
  """Average-pool a C x H x W array with a square, non-overlapping window.

  The array is viewed as C x (H // pool_size) x pool_size x (W // pool_size) x pool_size
  and averaged over the two window axes. NaN entries are ignored by the mean
  (``np.nanmean``).

  Args:
    in_tensor: array of shape C x H x W.
    pool_size: side length of the square pooling window.

  Returns:
    Array of shape C x (H // pool_size) x (W // pool_size).
  """
  channels, height, width = in_tensor.shape
  windowed_shape = (
      channels,
      height // pool_size, pool_size,
      width // pool_size, pool_size,
  )
  windows = in_tensor.reshape(*windowed_shape)
  # Axes 2 and 4 are the within-window row and column positions.
  return np.nanmean(windows, axis=(2, 4))
Oops, something went wrong.