feat: add assignments to tuples and lists #22

Merged: 1 commit, Jan 22, 2025
137 changes: 111 additions & 26 deletions NKL/Trace/Python.lean
@@ -21,7 +21,7 @@ def const : Const -> ErrorM Term
| .ellipsis => throw "unsupported use of ellipsis"

/-
- Evaluating index expressions.
+ # Evaluating index expressions

An index expression occurs only within a subscript expression. For example, in
the expression:
@@ -85,6 +85,113 @@ def access : Term -> List KLR.Index -> TraceM Term
  | .list l, ix => list_access l ix
  | .expr e _, ix => return .expr (.access e ix) (.any "?".toName)

/-
# Handling of assignment statements

Assignment statements can take forms such as:

  x = y = 1
  a, y = (1,2)

or even

  (x,y), z = a, [b,c] = f()

The current implementation requires that these kinds of complex assignments are
expanded at tracing time. That is, in the example above the call to f must
generate a tuple or list of the right size at tracing time. This will then be
expanded out to assignments of the individual variables.

In general, we need to make sure the right-hand side is only evaluated once. For
example, consider the assignment below:

  (x,y) = a = (f(), 1)

The following is an incorrect translation, because f is called twice.

  a = (f(), 1)
  x = f()  # INCORRECT
  y = 1

One correct translation is:

  tmp = f()
  a = (tmp, 1)
  x = tmp
  y = 1

The extra variable `tmp` is only needed if the right-hand side has side effects
or is expensive to compute. Otherwise, it is safe to copy the right-hand side
everywhere it is needed.

In the correct translation above, only the first assignment (`tmp = f()`) is
emitted into the translated function. The other three assignments are placed in
the environment, but not emitted. Therefore any uses of a, x, or y will be
replaced with their assigned values. This is effectively a simple form of
constant propagation and dead-code elimination for simple assignments.
-/
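
The single-evaluation requirement described in the comment above mirrors Python's own semantics for chained assignment. The sketch below is a plain-Python illustration and not part of the PR; the helper names (`make_counter`, `chained`, `expanded`, `incorrect`) are hypothetical and exist only to count how often `f` runs under each translation.

```python
def make_counter():
    """Return a function f plus a dict that counts how often f was called."""
    count = {"n": 0}

    def f():
        count["n"] += 1
        return 10

    return f, count


def chained():
    # Original statement: the right-hand side is evaluated exactly once.
    f, count = make_counter()
    (x, y) = a = (f(), 1)
    return x, y, a, count["n"]


def expanded():
    # The correct translation: hoist the call into a temporary first.
    f, count = make_counter()
    tmp = f()
    a = (tmp, 1)
    x = tmp
    y = 1
    return x, y, a, count["n"]


def incorrect():
    # The incorrect translation: f ends up being called twice.
    f, count = make_counter()
    a = (f(), 1)
    x = f()
    y = 1
    return x, y, a, count["n"]


print(chained())       # (10, 1, (10, 1), 1)
print(expanded())      # same values, f called once
print(incorrect()[3])  # 2
```

The values bound to `x`, `y`, and `a` happen to agree in all three versions here because `f` is deterministic; only the call count exposes the difference, which is exactly what matters when `f` has side effects.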

-- Convert an expression in assignment context (an L-Value).
partial def LValue : Expr -> Tracer Term
  | .exprPos e' p => withPos p (lval e')
where
  lval : Expr' -> Tracer Term
  | .name id .store => return .expr (.var id) (.any "?".toName)
  | .tuple l .store => return .tuple (<- l.mapM LValue)
  | .list l .store => return .list (<- l.mapM LValue)
  | _ => throw "cannot assign to expression"

-- Convert an R-Value to a pure expression, emitting
-- additional assignments as needed.
partial def RValue : Term -> Tracer Term
  | .object o => return .object o
  | .tuple l => return .tuple (<- l.mapM RValue)
  | .list l => return .list (<- l.mapM RValue)
  | .expr e@(.call _ _ _) ty => do
      let v := (<- genName).toString
      add_stmt (.assign v e)
      return .expr (.var v) ty
  | .expr e ty => return .expr e ty

-- Create an assignment to a KLR Expr; the target must be a variable.
-- TODO: should we support tensors and sub-tensors?
def assignExpr (e : KLR.Expr) (t : Term) : Tracer Unit := do
  match e with
  | .var x => extend x.toName t
  | _ => throw s!"cannot assign to {repr e}"

-- Unpack an RValue, must be a list or tuple
def unpack : Term -> Tracer (List Term)
  | .object o => throw s!"cannot unpack non-iterable object {o.name}"
  | .expr _ t => throw s!"cannot unpack non-iterable object {repr t}"
  | .tuple l | .list l => return l

-- Assign to a term, handling unpacking for tuples and lists
def assignTerm (x : Term) (e : Term) : Tracer Unit := do
  match x with
  | .object o => throw s!"cannot assign to {o.name}"
  | .tuple l
  | .list l => assignList l (<- unpack e)
  | .expr x _ => assignExpr x e
where
  -- First list: assignment targets; second list: unpacked values.
  assignList : List Term -> List Term -> Tracer Unit
  | [], [] => return ()
  | [], _ => throw "too many values to unpack"   -- targets exhausted, values remain
  | _, [] => throw "not enough values to unpack" -- values exhausted, targets remain
  | x::xs, t::ts => do
      assignTerm x t
      assignList xs ts

-- Top-level assignment handling
-- e.g. x1 = x2 = e
def assign (xs : List Term) (e : Term) : Tracer Unit := do
Review comment (Collaborator): Pretty.

  let e <- RValue e
  for x in xs do
    assignTerm x e
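
To make the flow through `RValue`, `unpack`, `assignTerm`, and `assign` concrete, here is a rough Python model of the same idea. It is a sketch under stated assumptions, not the Lean implementation: `Var`, `Call`, `Tracer`, and the `tmpN` naming scheme are hypothetical stand-ins, Python tuples and lists stand in for the `.tuple` and `.list` terms, and the error messages follow Python's own wording.

```python
from dataclasses import dataclass, field
from itertools import count


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for a variable term."""
    name: str


@dataclass(frozen=True)
class Call:
    """Hypothetical stand-in for a call expression (may have side effects)."""
    func: str


@dataclass
class Tracer:
    env: dict = field(default_factory=dict)    # name -> term; never emitted
    stmts: list = field(default_factory=list)  # emitted (name, expr) assignments
    fresh: count = field(default_factory=count)

    def rvalue(self, t):
        """Make t pure by hoisting each call into a fresh variable."""
        if isinstance(t, (tuple, list)):
            return type(t)(self.rvalue(x) for x in t)
        if isinstance(t, Call):
            v = f"tmp{next(self.fresh)}"
            self.stmts.append((v, t))
            return Var(v)
        return t

    def assign_term(self, target, value):
        """Bind target to value, unpacking tuples and lists element-wise."""
        if isinstance(target, (tuple, list)):
            if not isinstance(value, (tuple, list)):
                raise ValueError("cannot unpack non-iterable object")
            if len(value) > len(target):
                raise ValueError("too many values to unpack")
            if len(value) < len(target):
                raise ValueError("not enough values to unpack")
            for x, v in zip(target, value):
                self.assign_term(x, v)
        elif isinstance(target, Var):
            self.env[target.name] = value
        else:
            raise ValueError(f"cannot assign to {target!r}")

    def assign(self, targets, value):
        """Handle x1 = x2 = e: purify e once, then bind every target."""
        value = self.rvalue(value)
        for target in targets:
            self.assign_term(target, value)


# Models: (x, y) = a = (f(), 1)
tr = Tracer()
x, y, a = Var("x"), Var("y"), Var("a")
tr.assign([(x, y), a], (Call("f"), 1))
print(tr.stmts)  # one emitted statement, binding tmp0 to the call
print(tr.env)    # x, y and a live only in the environment
```

In this model only `tmp0 = f()` reaches the emitted statement list, while `x`, `y`, and `a` are recorded in the environment, which matches the behaviour the comment block describes for the correct translation.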

/-
# Expressions and Statements
-/

mutual
partial def expr : Expr -> Tracer Item
@@ -146,39 +253,17 @@ partial def expr' : Expr' -> Tracer Item
  partial def keyword (f : Expr -> Tracer a) : Keyword -> Tracer (String × a)
    | .keyword id e p => withPos p do return (id, (<- f e))

- -- When looking for a variable we rely on the store attribute
- -- from the Python parser to check if it is a defining use.
- partial def var : Expr -> Tracer String
-   | .exprPos (.name id .store) _ => return id
-   | _ => throw "expecting variable"
-
- -- When we perform an assignment, we will either add to the environment
- -- the term found on the RHS, or the variable itself. The latter case
- -- allows us to lookup and find the variable without substituting
- -- its definition.
- partial def assign (xs : List Expr) (e : Expr) : Tracer Unit := do
-   let xs <- xs.mapM var
-   let e <- term e
-   match e with
-   | .expr (.const _) _ => xs.forM fun x => extend x.toName e
-   | .expr e ty => xs.forM fun x => do
-       extend x.toName (.expr (.var xs[0]!) ty)
-       add_stmt (KLR.Stmt.assign x e)
-   | t => xs.forM fun x => extend x.toName t

  partial def stmt : Stmt -> Tracer Unit
    | .stmtPos s' p => withPos p (stmt' s')

  partial def stmt' : Stmt' -> Tracer Unit
-   | .expr (.exprPos (.const _) _) => return ()
    | .expr e => do
-       match <- term e with
-       | .expr e _ => add_stmt (.expr e)
-       | _ => return () -- effects are done, can be removed from KLR
+       let t <- term e
+       let _ <- RValue t
Review comment (Collaborator, Author): I am not sure if this is clear. But my thinking is that if you have an expression as a statement, then we can use the RValue conversion to figure out if it "has side-effects". If it does, the RValue translation will insert an assignment to a fresh variable, otherwise we just discard the expression (e.g. doc strings, etc.)
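
The idea in this comment can be illustrated with a small Python sketch. It is only a model under assumptions: `Call` and `trace_expr_stmt` are hypothetical names, a call is treated as the only kind of term with side effects, and fresh names are derived from the length of the statement list purely to keep the example short.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Call:
    """Hypothetical stand-in for a call expression."""
    func: str


def trace_expr_stmt(term, stmts):
    """Emit `tmpN = call` for each call inside term; drop everything else."""
    if isinstance(term, (tuple, list)):
        for t in term:
            trace_expr_stmt(t, stmts)
    elif isinstance(term, Call):
        stmts.append((f"tmp{len(stmts)}", term))
    # Constants (for example docstrings) fall through and emit nothing.


stmts = []
trace_expr_stmt("a docstring", stmts)   # discarded
trace_expr_stmt(Call("f"), stmts)       # kept as an assignment to tmp0
trace_expr_stmt((1, Call("g")), stmts)  # only the call survives, as tmp1
print(stmts)
```

The result of the conversion is thrown away on purpose; the only thing that matters for an expression statement is which assignments were emitted along the way.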

    | .assert e => do
        let t <- term e
        if (<- t.isFalse) then throw "assertion failed"
-   | .assign xs e => assign xs e
+   | .assign xs e => do assign (<- xs.mapM LValue) (<- term e)
    | .augAssign x op e => do
        stmt' (.assign [x] (.exprPos (.binOp op x e) (<- getPos)))
    | .annAssign _ _ .none => return ()
15 changes: 15 additions & 0 deletions interop/test/test_basic.py
@@ -65,13 +65,28 @@ def expr_bool_op(t):
  1 or None # evals to 1
  (False,) or 1 # evals to (False,)

def assign(t):
  x = y = 1
  assert x == y
  x, y = [1,2]
  assert x == 1
  assert y == 2
  (x,y), z = a, [b,c] = ((1,2),(3,4))
  assert x == 1
  assert y == 2
  assert z == (3,4)
  assert a == (1,2)
  assert b == 3
  assert c == 4

@pytest.mark.parametrize("f", [
  const_stmt,
  expr_name,
  expr_tuple,
  expr_list,
  expr_subscript,
  expr_bool_op,
  assign
  ])
def test_succeed(f):
  t = np.ndarray(10)