Skip to content

Commit

Permalink
feat: support for python syntax
Browse files Browse the repository at this point in the history
Introduces a direct representation of the python abstract
syntax, and a simple parser for python functions. This is
necessary to support the formalization of the tracing
process.
  • Loading branch information
govereau committed Dec 9, 2024
1 parent ee68c08 commit 6da8f2a
Show file tree
Hide file tree
Showing 5 changed files with 442 additions and 14 deletions.
9 changes: 3 additions & 6 deletions Main.lean
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import NKL

def main (args : List String) : IO Unit :=
match args with
| .nil => IO.println s!"Hello, NKL!"
| .cons x _ => do
let s <- IO.FS.readFile x
NKL.parse_json s
def main : List String -> IO Unit
| [ file ] => IO.FS.readFile file >>= NKL.parse_json
| _ => throw $ .userError "invalid arguments"
1 change: 1 addition & 0 deletions NKL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ Authors: Paul Govereau
import NKL.Encode
import NKL.FFI
import NKL.NKI
import NKL.Python
24 changes: 16 additions & 8 deletions NKL/FFI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,24 @@ Authors: Paul Govereau
import Lean
import NKL.NKI
import NKL.PrettyPrint
import NKL.Python

namespace NKL

-- temporary for testing
local instance : MonadLift (Except String) IO where
monadLift
| .ok x => return x
| .error s => throw $ .userError s

@[export parse_json_old]
def parse_json_old (json : String) : IO Unit := do
let jsn <- Lean.Json.parse json
let f:Fun <- Lean.fromJson? jsn
print_nki f

@[export parse_json]
def parse_json (json : String) : IO Unit := do
match Lean.Json.parse json with
| .error str => throw $ .userError str
| .ok jsn => do
match Lean.fromJson? jsn with
| .error str => throw $ .userError str
| .ok (f:Fun) => print_nki f
def parse_json (s : String) : IO Unit := do
let kernel <- Python.Parsing.parse s
for (n,f) in kernel.funcs do
IO.println s!"found {n}"
IO.println s!"{repr f}"
288 changes: 288 additions & 0 deletions NKL/Python.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,288 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import Lean

/-!
# Abstract syntax of Python functions
Mostly 1-to-1 translation of the Python AST to lean.
see: https://docs.python.org/3/library/ast.html
-/

namespace NKL
namespace Python

deriving instance Repr for Lean.JsonNumber

structure Pos where
lineno : Nat
end_lineno : Nat := 0
col_offset : Nat := 0
end_col_offset : Nat := 0
deriving Repr

inductive Const where
| none
| bool (value: Bool)
| num (value: Lean.JsonNumber)
| string (value: String)
| ellipsis
deriving Repr

inductive Ctx where
| load | store | del
deriving Repr

mutual
inductive Expr where
| exprPos (expr : Expr') (pos : Pos)
deriving Repr

inductive Expr' where
| const (value: Const)
| name (id: String) (ctx : Ctx)
| attr (value : Expr) (id : String) (ctx : Ctx)
| tuple (xs: List Expr) (ctx : Ctx)
| list (xs: List Expr) (ctx : Ctx)
| subscript (tensor: Expr) (ix: List Expr) (ctx : Ctx)
| slice (l u step: Option Expr)
| boolOp (op : String) (values : List Expr)
| binOp (op : String) (left right : Expr)
| unaryOp (op : String) (operand : Expr)
| compare (left : Expr) (ops : List String) (comparators : List Expr)
| ifExp (test body orelse : Expr)
| call (f: Expr) (args: List Expr) (keywords : List Keyword)
deriving Repr

inductive Keyword where
| keyword (id : String) (value : Expr) (pos : Pos)
deriving Repr
end

mutual
inductive Stmt where
| stmtPos (stmt : Stmt') (pos : Pos)
deriving Repr

inductive Stmt' where
| pass
| expr (e : Expr)
| assert (e : Expr)
| ret (e: Expr)
| assign (xs: List Expr) (e: Expr)
| augAssign (x : Expr) (op : String) (e : Expr)
| annAssign (x : Expr) (annotation : Expr) (value : Option Expr)
| forLoop (x : Expr) (iter: Expr) (body: List Stmt) (orelse : List Stmt)
| ifStm (e : Expr) (thn els: List Stmt)
deriving Repr
end

structure Args where
posonlyargs : List String
args : List String
defaults: List Expr
vararg : Option String
kwonlyargs : List String
kw_defaults: List Expr
kwarg : Option String
deriving Repr

structure Fun where
source : String
args : Args
defaults: List Const
body: List Stmt
deriving Repr

structure Kernel where
entry : String
funcs : List (String × Fun)
globals : List (String × Option String)

-------------------------------------------------------------------------------
-- Converting Python AST from Json

namespace Parsing
open Lean

-- I am using a state monad only to provide better error messages: the source
-- span (Pos) is saved while traversing the tree to identify the location
-- of any errors in the original program

abbrev Parser := EStateM String Pos

local instance : MonadLift (Except String) Parser where
monadLift
| .ok x => return x
| .error s => throw s

private def str : Json -> Parser String :=
monadLift ∘ Json.getStr?

private def field (f: Json -> Parser a) (j : Json) (name : String) : Parser a :=
j.getObjVal? name >>= f

private def field? (f: Json -> Parser a) (j : Json) (name : String) : Parser (Option a) :=
try let x <- field f j name; return (some x)
catch _ => return none

private def list (f: Json -> Parser a) : Json -> Parser (List a)
| .arr arr => arr.toList.mapM f
| json => return [(<- f json )]

private def dict (f : Json -> Parser a) : Json -> Parser (List (String × a))
| .obj kvs => kvs.toArray.toList.mapM fun p => return (p.1, (<- f p.2))
| _ => throw s!"expecting dictionary"

private def opt (p : Json -> Parser a) : Json -> Parser (Option a)
| .null => return none
| j => return (some (<- p j))

-- Note: this will not fail, but can produce an invalid Pos
private def pos (j: Json) : Parser Pos :=
return {
lineno := (<- nat "lineno")
end_lineno := (<- nat "end_lineno")
col_offset := (<- nat "col_offset")
end_col_offset := (<- nat "end_col_offset")
}
where
nat (name : String) : Parser Nat :=
tryCatch (nat' name) fun _ => return 0
nat' (name : String) : Parser Nat := do
let obj <- j.getObjVal? name
Json.getNat? obj

private def withPos (p : String -> Json -> Parser b) (f : b -> Pos -> a) : Json -> Parser a
| .obj (.node _ _ key val _) => do
let pos <- pos val
set pos
let exp <- p key val
return (f exp pos)
| _ => throw "expecting object"

private def withSrc (source : String) (p : Parser a) : Parser a :=
try set { lineno := 0 : Pos } ; p
catch e => get >>= throw ∘ genError e
where
genError (err : String) (pos : Pos) : String :=
let lines := source.splitOn "\n"
let lineno := pos.lineno - 1
let colno := pos.col_offset
let line := if lines.length < lineno
then "<source not available>"
else lines[lineno]!
let indent := (Nat.repeat (List.cons ' ') colno List.nil).asString
s!"line {lineno}:\n{line}\n{indent}^-- {err}"

-------------------------------------------------------------------------------
-- Python AST Json objects

def const : Json -> Parser Const
| .null => return .none
| .bool b => return (.bool b)
| .num jn => return (.num jn)
| .str "..." => return .ellipsis
| .str s => return (.string s)
| _ => throw "expecting constant"

def exprCtx : Json -> Parser Ctx
| .str "Load" => return .load
| .str "Store" => return .store
| .str "Del" => return .del
| _ => throw "expecting ctx"

partial def expr (j : Json) : Parser Expr :=
withPos expr' Expr.exprPos j
where
expr' (key : String) (j : Json) : Parser Expr' := do
let strs := field (list str) j
let str := field str j
let ctx := field exprCtx j
let const := field const j
let exprs := field (list expr) j
let expr? := field (opt expr) j
let expr := field expr j
let keywords := field (list keyword) j
match key with
| "Constant" => return (.const (<- const "value"))
| "Name" => return (.name (<- str "id") (<- ctx "ctx"))
| "Attribute" => return (.attr (<- expr "value") (<- str "attr") (<- ctx "ctx"))
| "Tuple" => return (.tuple (<- exprs "elts") (<- ctx "ctx"))
| "List" => return (.list (<- exprs "elts") (<- ctx "ctx"))
| "Subscript" => return (.subscript (<- expr "value") (<- exprs "slice") (<- ctx "ctx"))
| "Slice" => return (.slice (<- expr? "lower") (<- expr? "upper") (<- expr? "step"))
| "BoolOp" => return (.boolOp (<- str "op") (<- exprs "values"))
| "BinOp" => return (.binOp (<- str "op") (<- expr "left") (<- expr "right"))
| "UnaryOp" => return (.unaryOp (<- str "op") (<- expr "operand"))
| "Compare" => return (.compare (<- expr "left") (<- strs "ops") (<- exprs "comparators"))
| "IfExp" => return (.ifExp (<- expr "test") (<- expr "body") (<- expr "orelse"))
| "Call" => return (.call (<- expr "func") (<- exprs "args") (<- keywords "keywords"))
| _ => throw s!"unsupported python construct {key}"

keyword (j: Json) : Parser Keyword := do
let j <- j.getObjVal? "keyword"
return ⟨ <- field str j "arg", <- field expr j "value", <- pos j ⟩

partial def stmt (j : Json) : Parser Stmt :=
withPos stmt' Stmt.stmtPos j
where
stmt' (key : String) (j : Json) : Parser Stmt' := do
let str := field str j
let exprs := field (list expr) j
let expr? := field (opt expr) j
let expr := field expr j
let stmts := field (list stmt) j
match key with
| "Pass" => return .pass
| "Expr" => return (.expr (<- expr "value"))
| "Assert" => return (.assert (<- expr "test"))
| "Return" => return (.ret (<- expr "value"))
| "Assign" => return (.assign (<- exprs "targets") (<- expr "value"))
| "AugAssign" => return (.augAssign (<- expr "target") (<- str "op") (<- expr "value"))
| "AnnAssign" => return (.annAssign (<- expr "target") (<- expr "annotation") (<- expr? "value"))
| "For" => return (.forLoop (<- expr "target") (<- expr "iter") (<- stmts "body") (<- stmts "orelse"))
| "If" => return (.ifStm (<- expr "test") (<- stmts "body") (<- stmts "orelse"))
| _ => throw s!"unsupported python construct {key}"

def arguments (j : Json) : Parser Args := do
let obj <- j.getObjVal? "arguments"
let arg? := field (opt arg) obj
let args := field (list arg) obj
let exprs := field (list expr) obj
return {
posonlyargs := (<- args "posonlyargs")
args := (<- args "args")
defaults := (<- exprs "defaults")
vararg := (<- arg? "vararg")
kwonlyargs := (<- args "kwonlyargs")
kw_defaults := (<- exprs "kw_defaults")
kwarg := (<- arg? "kwarg")
}
where
arg (j : Json) : Parser String := do
let obj <- j.getObjVal? "arg"
return (<- field str obj "arg")

def function (j : Json) : Parser Fun := do
let source <- field str j "source"
withSrc source do
let args <- field arguments j "args"
let defaults <- field (list const) j "defaults"
let body <- field (list stmt) j "body"
return Fun.mk source args defaults body

def kernel (j : Json) : Parser Kernel := do
let name <- field str j "entry"
let funcs <- field (dict function) j "funcs"
let globals <- field (dict (opt str)) j "globals"
return Kernel.mk name funcs globals

def parse (s : String) : Except String Kernel := do
let jsn <- Json.parse s
match kernel jsn { lineno := 0 } with
| .ok x _ => .ok x
| .error s _ => .error s
Loading

0 comments on commit 6da8f2a

Please sign in to comment.