Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: cleanup KLR definitions #17

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion NKL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.Encode
import NKL.FFI
import NKL.KLR
import NKL.Python
104 changes: 2 additions & 102 deletions NKL/KLR.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,105 +3,5 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/


/-!
# Abstract syntax of Core NKL language

This language is the result of "tracing", and is used as the
portable format, a.k.a. Kernel Language Representation (KLR).
-/

namespace NKL.KLR

-- TODO
inductive Ty where

inductive Const where
| none
| bool (value : Bool)
| int (value : Int)
| float (value : Float)
| string (value : String)
deriving Repr, BEq

namespace Const

-- Python-like rules for conversion to boolean
def isTrue : Const -> Bool
| .none => false
| .bool b => b
| .int i => i != 0
| .float f => f != 0.0
| .string s => s != ""

-- Python-like rules for conversion to integer
def toInt : Const -> Except String Int
| .none => throw "none cannot be converted to an integer"
| .bool true => return 1
| .bool false => return 0
| .int i => return i
| .float f =>
-- Python is a bit strange here, it truncates both
-- positive and negative numbers toward zero
if f < 0.0 then
return (Int.ofNat (Float.floor (-f)).toUInt64.toBitVec.toNat).neg
else
return Int.ofNat (Float.floor f).toUInt64.toBitVec.toNat
| .string s =>
match s.toInt? with
| .none => throw s!"string {s} cannot be converted to an integer"
| .some i => return i

end Const

inductive IndexExpr where
| var (name : String)
| int (i : Int)
| neg (expr : IndexExpr)
| add (left right : IndexExpr)
| mul (scalar : Int) (expr : IndexExpr)
| floor (expr : IndexExpr) (scalar : Int)
| ceil (expr : IndexExpr) (scalar : Int)
| mod (expr : IndexExpr) (scalar : Int)
deriving Repr, BEq

inductive Index where
| ellipsis
| coord (e : Option IndexExpr)
| range (l u step : Option IndexExpr)
deriving Repr, BEq

inductive Expr where
| var (x : String)
| const (c : Const)
| tensor (name : String) (shape : List Int)
| tuple (xs : List Expr)
| list (xs : List Expr)
| access (t : Expr) (ix : List Index)
| binop (op : String) (left right : Expr)
| unop (op : String) (e : Expr)
| call (f : Expr) (args : List Expr) (keywords : List (String × Expr))
deriving Repr, BEq

namespace Expr

-- TODO: Just a place-holder for now
def toAffine : Expr -> Except String IndexExpr
| .var v => return .var v
| .const (.int i) => return .int i
| e => throw s!"toAffine unimp {repr e}"

-- TODO: Just a place-holder for now
def simplify : Expr -> Expr :=
fun x => x

end Expr

inductive Stmt where
| pass
| expr (v : Expr)
| ret (v : Expr)
| assign (x : String) (e : Expr)
| loop (x : String) (l u step : IndexExpr) (body : List Stmt)
deriving Repr, BEq
import NKL.KLR.Basic
import NKL.KLR.Encode
108 changes: 108 additions & 0 deletions NKL/KLR/Basic.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/-
govereau marked this conversation as resolved.
Show resolved Hide resolved
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
--import TensorLib.Tensor

/-!
# Abstract syntax of Core NKL language

This language is the result of "tracing", and is used as the
portable format, a.k.a. Kernel Language Representation (KLR).
-/

namespace NKL.KLR

-- TODO switch to tensor lib
--export TensorLib (Tensor Dtype Shape)
-- Mostly, NKL deals with empty tensors, so just check dtype and shape
-- TODO: talk to Sean about a more general BEq for Tensor
--instance : BEq Tensor where
-- beq t₁ t₂ := t₁.dtype == t₂.dtype && t₁.shape == t₂.shape

abbrev Dtype := String
abbrev Shape := List Int
structure Tensor where
dtype : Dtype
shape : Shape
deriving Repr, BEq

-- TODO
inductive Typ where
govereau marked this conversation as resolved.
Show resolved Hide resolved

inductive Const where
| none
| bool (value : Bool)
| int (value : Int)
| float (value : Float)
| string (value : String)
deriving Repr, BEq

namespace Const

-- Python-like rules for conversion to boolean
def isTrue : Const -> Bool
| .none => false
| .bool b => b
| .int i => i != 0
| .float f => f != 0.0
| .string s => s != ""

-- Python-like rules for conversion to integer
def toInt : Const -> Except String Int
| .none => throw "none cannot be converted to an integer"
| .bool true => return 1
| .bool false => return 0
| .int i => return i
| .float f =>
-- Python is a bit strange here, it truncates both
-- positive and negative numbers toward zero
if f < 0.0 then
return (Int.ofNat (Float.floor (-f)).toUInt64.toNat).neg
else
return Int.ofNat (Float.floor f).toUInt64.toNat
| .string s =>
-- Fortunately, Lean's String.toInt appears to be compatible
-- with Python's int(string) conversion.
match s.toInt? with
govereau marked this conversation as resolved.
Show resolved Hide resolved
| .none => throw s!"string {s} cannot be converted to an integer"
| .some i => return i

end Const

-- This correspondes to the "Quasi-Affine Expressions" in Neuron.
-- Note, `floor` is the usual integer division.
inductive IndexExpr where
| var (name : String)
| int (i : Int)
| neg (expr : IndexExpr)
| add (left right : IndexExpr)
| mul (scalar : Int) (expr : IndexExpr)
| floor (expr : IndexExpr) (scalar : Int)
govereau marked this conversation as resolved.
Show resolved Hide resolved
| ceil (expr : IndexExpr) (scalar : Int)
| mod (expr : IndexExpr) (scalar : Int)
deriving Repr, BEq

-- Note: `np.newindex` is represented as `(.coord none)`
inductive Index where
| ellipsis
| coord (e : Option IndexExpr)
govereau marked this conversation as resolved.
Show resolved Hide resolved
| slice (l u step : Option IndexExpr)
deriving Repr, BEq

inductive Expr where
| var (x : String)
| const (c : Const)
| tensor (t : Tensor)
| access (t : Expr) (ix : List Index)
| call (f : Expr) (args : List Expr) (kwargs : List (String × Expr))
govereau marked this conversation as resolved.
Show resolved Hide resolved
deriving Repr, BEq

inductive Stmt where
| pass
| expr (v : Expr)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is expr a statement in a language without side effects?

Copy link
Collaborator Author

@govereau govereau Jan 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good question. The (admittedly bad) answer is that I am not sure KLR is the final step. Maybe it isn't pure, but there is a transform we can do to make it pure. For example, we see stuff like:

# in nki.isa
 def builtin_custom_op(fn_name, file_name, ucode_version, isa_version, srcs, dsts, **kwargs):

If you call this function, it would look like:

def kernel(...):
  ...
  builtin_custom_op("name", "file", 1, 1, a, b)
  return b

Right, now the call to builtin_custom_op would be one of these expression statements. Not sure what to do here? I think we need to look one-by-one at the ISA operations.

| ret (v : Expr)
| assign (x : String) (e : Expr)
| loop (x : String) (l u step : IndexExpr) (body : List Stmt)
deriving Repr, BEq
38 changes: 13 additions & 25 deletions NKL/Encode.lean → NKL/KLR/Encode.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR
import NKL.KLR.Basic

/-!
# Serialization and Deserialization
Expand Down Expand Up @@ -239,15 +239,15 @@ private def ie_var : IndexExpr := .var "s"
def encIndex : Index -> ByteArray
| .ellipsis => tag 0x20 []
| .coord e => tag 0x21 [enc e]
| .range l u s => tag 0x22 [enc l, enc u, enc s]
| .slice l u s => tag 0x22 [enc l, enc u, enc s]
where
enc := encOption encIndexExpr

def decIndex : DecodeM Index := do
match (<- next) with
| 0x20 => return .ellipsis
| 0x21 => return .coord (<- dec)
| 0x22 => return .range (<- dec) (<- dec) (<- dec)
| 0x22 => return .slice (<- dec) (<- dec) (<- dec)
| t => throw s!"Unknown tag in Index {t}"
where
dec:= decOption decIndexExpr
Expand All @@ -258,36 +258,28 @@ private def chkIndex (i : Index) : Bool :=
#guard chkIndex .ellipsis
#guard chkIndex (.coord none)
#guard chkIndex (.coord $ some ie_var)
#guard chkIndex (.range (some ie_var) none none)
#guard chkIndex (.slice (some ie_var) none none)

------------------------------------------------------------------------------
-- Expressions

partial def encExpr : Expr -> ByteArray
| .var s => tag 0x30 [encString s]
| .tensor t s => tag 0x31 [encString t, encList encInt s]
| .const c => tag 0x32 [encConst c]
| .tuple es => tag 0x33 [encList encExpr es]
| .list es => tag 0x34 [encList encExpr es]
| .access e ix => tag 0x35 [encExpr e, encList encIndex ix]
| .binop op l r => tag 0x36 [encString op, encExpr l, encExpr r]
| .unop op e => tag 0x37 [encString op, encExpr e]
| .call f ax kw => tag 0x38 [encExpr f, encList encExpr ax, encList encKeyword kw]
| .var s => tag 0x30 [encString s]
| .tensor t => tag 0x31 [encString t.dtype, encList encInt t.shape]
| .const c => tag 0x32 [encConst c]
| .access e ix => tag 0x33 [encExpr e, encList encIndex ix]
| .call f ax kw => tag 0x34 [encExpr f, encList encExpr ax, encList encKeyword kw]
where
encKeyword : String × Expr -> ByteArray
| (key, expr) => (encString key).append (encExpr expr)

partial def decExpr : DecodeM Expr := do
match (<- next) with
| 0x30 => return .var (<- decString)
| 0x31 => return .tensor (<- decString) (<- decList decInt)
| 0x31 => return .tensor $ .mk (<- decString) (<- decList decInt)
| 0x32 => return .const (<- decConst)
| 0x33 => return .tuple (<- decList decExpr)
| 0x34 => return .list (<- decList decExpr)
| 0x35 => return .access (<- decExpr) (<- decList decIndex)
| 0x36 => return .binop (<- decString) (<- decExpr) (<- decExpr)
| 0x37 => return .unop (<- decString) (<- decExpr)
| 0x38 => return .call (<- decExpr) (<- decList decExpr) (<- decList decKeyword)
| 0x33 => return .access (<- decExpr) (<- decList decIndex)
| 0x34 => return .call (<- decExpr) (<- decList decExpr) (<- decList decKeyword)
| t => throw s!"Unknown tag in Expr {t}"
where
decKeyword : DecodeM (String × Expr) :=
Expand All @@ -301,13 +293,9 @@ private def ixz := Index.coord (IndexExpr.int 0)

#guard chkExpr nil
#guard chkExpr (.var "var")
#guard chkExpr (.tensor "float32" [1,2,3])
#guard chkExpr (.tensor $ .mk "float32" [1,2,3])
#guard chkExpr (.const (.int 1))
#guard chkExpr (.tuple [nil, nil, nil])
#guard chkExpr (.list [nil, nil, nil])
#guard chkExpr (.access nil [ixz, ixz, ixz])
#guard chkExpr (.binop "op" nil nil)
#guard chkExpr (.unop "op" nil)
#guard chkExpr (.call nil [nil, nil, nil] [("a", nil), ("b", nil)])

------------------------------------------------------------------------------
Expand Down
Loading