-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This change adds two related things: a pretty-printer for KLR terms, and tensor names. Tensor names make the pretty printing nicer, but have a second purpose. By naming all of the tensors, we can scan a KLR kernel to collect up all of the input, output, and intermediate tensors that will be needed to run the kernel. For argument tensors, the generated tensor names are changed to the argument variable names; this is just for readability.
- Loading branch information
Showing
5 changed files
with
101 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
/- | ||
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
Released under Apache 2.0 license as described in the file LICENSE. | ||
Authors: Paul Govereau | ||
-/ | ||
import NKL.KLR.Basic | ||
|
||
namespace NKL.KLR | ||
open Std | ||
|
||
/- | ||
This is a simple pretty printer for KLR terms. At some point, we may want to | ||
make this output valid python syntax that would parse and elaborate to the same | ||
KLR kernel. At the moment, there are too many unknowns to spend time on this. | ||
The format here is just for ease of debugging, feel free to modify as you wish. | ||
-/ | ||
|
||
private def abracket (f : Format) : Format := | ||
Format.bracket "<" f ">" | ||
|
||
private def ppArgs [ToFormat a] (l : List a) : Format := | ||
Format.joinSep l "," | ||
|
||
def ppTensor (t : Tensor) : Format := | ||
"%" ++ t.name ++ abracket (t.dtype ++ ":" ++ ppArgs t.shape) | ||
|
||
def ppConst : Const -> Format | ||
| .none => "None" | ||
| .bool true => "True" | ||
| .bool false => "False" | ||
| .int i => format i | ||
| .float f => format f | ||
| .string s => "\"" ++ s.push '"' | ||
|
||
private def addParens : Nat -> Format -> Format | ||
| 0, f => f | ||
| _, f => f.paren | ||
|
||
def ppIndexExpr (n : Nat) : IndexExpr -> Format | ||
| .var x => x | ||
| .int i => format i | ||
| .neg e => "-" ++ ppIndexExpr (n+1) e | ||
| .add l r => addParens n $ ppIndexExpr 1 l ++ "+" ++ ppIndexExpr 1 r | ||
| .mul i e => addParens n $ format i ++ "*" ++ ppIndexExpr 1 e | ||
| .floor e i => addParens n $ ppIndexExpr 1 e ++ "/" ++ format i | ||
| .ceil e i => "ceil" ++ Format.paren (ppIndexExpr 0 e ++","++ format i) | ||
| .mod e i => addParens n $ ppIndexExpr 1 e ++ "%" ++ format i | ||
|
||
def ppIndexExpr? : Option IndexExpr -> Format | ||
| none => "None" | ||
| some e => ppIndexExpr 0 e | ||
|
||
def ppIndex : Index -> Format | ||
| .ellipsis => "..." | ||
| .coord e => ppIndexExpr? e | ||
| .slice l u s => .joinSep ([l,u,s].map ppIndexExpr?) ":" | ||
|
||
partial def ppExpr : Expr -> Format | ||
| .var x => x | ||
| .const c => ppConst c | ||
| .tensor t => ppTensor t | ||
| .access t ix => .fill (ppExpr t ++ .sbracket (.joinSep (ix.map ppIndex) ",")) | ||
| .call f args kwargs => | ||
let args := args.map ppExpr | ||
let kwargs := kwargs.map fun (x,e) => x ++ "=" ++ ppExpr e | ||
.fill (ppExpr f ++ .paren (ppArgs (args ++ kwargs))) | ||
|
||
def ppStmt : Stmt -> Format | ||
| .pass => "pass" | ||
| .expr e => ppExpr e | ||
| .ret e => "ret" ++ ppExpr e | ||
| .assign x e => x ++ " = " ++ ppExpr e | ||
| .loop _ _ _ _ _ => "<loop>" | ||
|
||
instance : ToFormat Tensor where format := ppTensor | ||
instance : ToFormat Const where format := ppConst | ||
instance : ToFormat IndexExpr where format := ppIndexExpr 0 | ||
instance : ToFormat Index where format := ppIndex | ||
instance : ToFormat Expr where format := ppExpr | ||
instance : ToFormat Stmt where format := ppStmt | ||
|
||
@[default_instance] | ||
instance [ToFormat a]: ToString a where | ||
toString s := Format.pretty (format s) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters