From 754ed88cbbbcdfd388577a57c4581e9d01c5a00f Mon Sep 17 00:00:00 2001 From: Paul Govereau Date: Mon, 20 Jan 2025 15:13:48 -0500 Subject: [PATCH] feat: add assignments to tuples and lists This change adds support for assignment like: `x, y = z`. Python calls this "unpacking", and you can convert between lists and tuples as long as they are the right length. --- NKL/Trace/Python.lean | 137 ++++++++++++++++++++++++++++++------- interop/test/test_basic.py | 15 ++++ 2 files changed, 126 insertions(+), 26 deletions(-) diff --git a/NKL/Trace/Python.lean b/NKL/Trace/Python.lean index 47667e9..0394453 100644 --- a/NKL/Trace/Python.lean +++ b/NKL/Trace/Python.lean @@ -21,7 +21,7 @@ def const : Const -> ErrorM Term | .ellipsis => throw "unsupported use of ellipsis" /- -Evaluating index expressions. +# Evaluating index expressions An index expression occurs only within a subscript expression. For example, in the expression: @@ -85,6 +85,113 @@ def access : Term -> List KLR.Index -> TraceM Term | .list l, ix => list_access l ix | .expr e _, ix => return .expr (.access e ix) (.any "?".toName) +/- +# Handling of assignment statements + +Assignments can be things like: + + x = y = 1 + a, y = (1,2) + +or even + + (x,y), z = a, [b,c] = f() + +The current implementation requires that these kinds of complex assignments are +expanded at tracing time. That is, in the example above the call to f must +generate a tuple or list of the right size at tracing time. This will then be +expanded out to assignments of the individual variables. + +In general, we need to make sure the right-hand side is only evaluated once. For +example, consider the assignment below: + + (x,y) = a = (f(), 1) + +The following is and incorrect translation, because f is called twice. + + a = (f(), 1) + x = f() # INCORRECT + y = 1 + +One correct translation is: + + tmp = f() + a = (tmp, 1) + x = tmp + y = 1 + +The extra variable `tmp` is only needed if the right-hand side has side-effects +or is expensive to compute. Otherwise, it is safe to copy the right-hand side +everywhere it is needed. + +In the above case, only the first assignment is emitted to the translated +function. The other three assignments are placed in the environment, but not +emitted. Therefore any uses of a, x, or y will be replaced with their +assignments. This is effectively a simple form of constant propagation and +dead-code elimination for simple assignments. +-/ + +-- Convert an expression in assignment context (an L-Value). +partial def LValue : Expr -> Tracer Term + | .exprPos e' p => withPos p (lval e') +where + lval : Expr' -> Tracer Term + | .name id .store => return .expr (.var id) (.any "?".toName) + | .tuple l .store => return .tuple (<- l.mapM LValue) + | .list l .store => return .list (<- l.mapM LValue) + | _ => throw "cannot assign to expression" + +-- Convert an R-Value to a pure expression, emitting +-- additional assignments as needed. +partial def RValue : Term -> Tracer Term + | .object o => return .object o + | .tuple l => return .tuple (<- l.mapM RValue) + | .list l => return .list (<- l.mapM RValue) + | .expr e@(.call _ _ _) ty => do + let v := (<- genName).toString + add_stmt (.assign v e) + return .expr (.var v) ty + | .expr e ty => return .expr e ty + +-- Create an assignment to a KLR Expr, this must be a variable +-- TODO: should we support tensors and sub-tensors? +def assignExpr (e : KLR.Expr) (t : Term) : Tracer Unit := do + match e with + | .var x => extend x.toName t + | _ => throw s!"cannot assign to {repr e}" + +-- Unpack an RValue, must be a list or tuple +def unpack : Term -> Tracer (List Term) + | .object o => throw s!"cannot unpack non-iterable object {o.name}" + | .expr _ t => throw s!"cannot unpack non-iterable object {repr t}" + | .tuple l | .list l => return l + +-- Assign to a term, handling unpacking for tuples and lists +def assignTerm (x : Term) (e : Term) : Tracer Unit := do + match x with + | .object o => throw s!"cannot assign to {o.name}" + | .tuple l + | .list l => assignList l (<- unpack e) + | .expr x _ => assignExpr x e +where + assignList : List Term -> List Term -> Tracer Unit + | [], [] => return () + | [], _ => throw "not enough values to unpack" + | _, [] => throw "too many values to unpack" + | x::xs, t::ts => do + assignTerm x t; + assignList xs ts + +-- Top-level assignment handling +-- e.g. x1 = x2 = e +def assign (xs : List Term) (e : Term) : Tracer Unit := do + let e <- RValue e + for x in xs do + assignTerm x e + +/- +# Expressions and Statements +-/ mutual partial def expr : Expr -> Tracer Item @@ -146,39 +253,17 @@ partial def expr' : Expr' -> Tracer Item partial def keyword (f : Expr -> Tracer a) : Keyword -> Tracer (String × a) | .keyword id e p => withPos p do return (id, (<- f e)) --- When looking for a variable we rely on the store attribute --- from the Python parser to check if it is a defining use. -partial def var : Expr -> Tracer String - | .exprPos (.name id .store) _ => return id - | _ => throw "expecting variable" - --- When we perform an assignment, we will either add to the environment --- the term found on the RHS, or the variable itself. The latter case --- allows us to lookup and find the variable without substituting --- its definition. -partial def assign (xs : List Expr) (e : Expr) : Tracer Unit := do - let xs <- xs.mapM var - let e <- term e - match e with - | .expr (.const _) _ => xs.forM fun x => extend x.toName e - | .expr e ty => xs.forM fun x => do - extend x.toName (.expr (.var xs[0]!) ty) - add_stmt (KLR.Stmt.assign x e) - | t => xs.forM fun x => extend x.toName t - partial def stmt : Stmt -> Tracer Unit | .stmtPos s' p => withPos p (stmt' s') partial def stmt' : Stmt' -> Tracer Unit - | .expr (.exprPos (.const _) _) => return () | .expr e => do - match <- term e with - | .expr e _ => add_stmt (.expr e) - | _ => return () -- effects are done, can be removed from KLR + let t <- term e + let _ <- RValue t | .assert e => do let t <- term e if (<- t.isFalse) then throw "assertion failed" - | .assign xs e => assign xs e + | .assign xs e => do assign (<- xs.mapM LValue) (<- term e) | .augAssign x op e => do stmt' (.assign [x] (.exprPos (.binOp op x e) (<- getPos))) | .annAssign _ _ .none => return () diff --git a/interop/test/test_basic.py b/interop/test/test_basic.py index 71c64be..dfbfb0e 100644 --- a/interop/test/test_basic.py +++ b/interop/test/test_basic.py @@ -65,6 +65,20 @@ def expr_bool_op(t): 1 or None # evals to 1 (False,) or 1 # evals to (False,) +def assign(t): + x = y = 1 + assert x == y + x, y = [1,2] + assert x == 1 + assert y == 2 + (x,y), z = a, [b,c] = ((1,2),(3,4)) + assert x == 1 + assert y == 2 + assert z == (3,4) + assert a == (1,2) + assert b == 3 + assert c == 4 + @pytest.mark.parametrize("f", [ const_stmt, expr_name, @@ -72,6 +86,7 @@ def expr_bool_op(t): expr_list, expr_subscript, expr_bool_op, + assign ]) def test_succeed(f): t = np.ndarray(10)