Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement modules and require #397

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions benchmarks/nbody/pallene.pln
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
local m: module = {}

record Body
x: float
y: float
z: float
vx: float
vy: float
vz: float
mass: float
end
typealias Body = {x: float, y: float, z: float, vx: float, vy: float, vz: float, mass: float}

function m.new_body(
x: float, y: float, z: float,
Expand Down
31 changes: 31 additions & 0 deletions doc/manual.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,37 @@ end
function even(n:integer): boolean
if n == 0 then return true else return odd(n-1) end
end
```

## Imported Modules

It's possible to import other Pallene modules using the `require` function, just as is in Lua.

```
local mod = require"mod"
```

There is one caveat, though:
`require` can only be called from a toplevel local variable declaration.

The syntax for calling module functions, or using module variables is the following

```
local mod = require"mod"
local ret = mod.foo()
local var = mod.var
```

As of now, you can pass all types of arguments to imported module function, except for `Records`.
Also, module functions can't return `Records` and module variables can't be `Records`.

Since module variables are constant in pallene, you also can't change their values.
Any attempt to do so, like the one shown below, will yield an error.

```
local mod = require"mod"
mod.var = 1
```

## Expressions and Statements

Expand Down
3 changes: 2 additions & 1 deletion pallene/builtins.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ local ipairs_itertype = T.Function({T.Any(), T.Any()}, {T.Any(), T.Any()})
builtins.functions = {
type = T.Function({ T.Any() }, { T.String() }),
tostring = T.Function({ T.Any() }, { T.String() }),
ipairs = T.Function({T.Array(T.Any())}, {ipairs_itertype, T.Any(), T.Any()})
ipairs = T.Function({T.Array(T.Any())}, {ipairs_itertype, T.Any(), T.Any()}),
require = T.Function({ T.String() }, { T.Any() })
}

builtins.modules = {
Expand Down
109 changes: 100 additions & 9 deletions pallene/checker.lua
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,13 @@ local util = require "pallene.util"
local checker = {}

local Checker = util.Class()
local driver = {}

-- Type-check a Pallene module
-- On success, returns the typechecked module for the program
-- On failure, returns false and a list of compilation errors
function checker.check(prog_ast)
function checker.check(prog_ast, driver_passed)
driver = driver_passed
local co = coroutine.create(function()
return Checker.new():check_program(prog_ast)
end)
Expand Down Expand Up @@ -96,6 +98,8 @@ function Checker:init()
self.module_symbol = false -- checker.Symbol.Module
self.symbol_table = symtab.new() -- string => checker.Symbol
self.ret_types_stack = {} -- stack of types.T
self.imported_modules = {} -- Imported modules
self.in_toplevel_decl = false -- Identifies if we are evaluating a toplevel decl
return self
end

Expand All @@ -115,7 +119,8 @@ end
declare_type("Symbol", {
Type = { "typ" },
Value = { "typ", "def" },
Module = { "typ", "symbols" }, -- Note: a module name can also be a type (e.g. "string")
Module = { "typ", "symbols", "imported_mod_name" },
-- Note: a module name can also be a type (e.g. "string")
})

--
Expand Down Expand Up @@ -154,10 +159,11 @@ function Checker:add_value_symbol(name, typ, def)
return self.symbol_table:add_symbol(name, checker.Symbol.Value(typ, def))
end

function Checker:add_module_symbol(name, typ, symbols)
function Checker:add_module_symbol(name, typ, symbols, imported_mod_name)
assert(type(name) == "string")
assert((not typ) or typedecl.match_tag(typ._tag, "types.T"))
return self.symbol_table:add_symbol(name, checker.Symbol.Module(typ, symbols))
imported_mod_name = imported_mod_name or false
return self.symbol_table:add_symbol(name, checker.Symbol.Module(typ, symbols, imported_mod_name))
end

function Checker:export_value_symbol(name, typ, def)
Expand Down Expand Up @@ -304,25 +310,65 @@ end

-- If the last expression in @rhs is a function call that returns multiple values, add ExtraRet
-- nodes to the end of the list.
function Checker:expand_function_returns(rhs)
function Checker:expand_function_returns(rhs, is_toplevel_decl)
local N = #rhs
local last = rhs[N]
if last and (last._tag == "ast.Exp.CallFunc" or last._tag == "ast.Exp.CallMethod") then
self.in_toplevel_decl = is_toplevel_decl
last = self:check_exp_synthesize(last)
rhs[N] = last
for i = 2, #last._types do
rhs[N-1+i] = ast.Exp.ExtraRet(last.loc, last, i)
end
self.in_toplevel_decl = false
end
end

function Checker:is_the_module_variable(exp)
function Checker:is_the_module_variable(var_name)
return self.module_symbol == self.symbol_table:find_symbol(var_name)
end

function Checker:exp_is_the_module_variable(exp)
-- Check if the expression is the module variable without calling check_exp.
-- Doing that would have raised an exception because it is not a value.
return (
exp._tag == "ast.Exp.Var" and
exp.var._tag == "ast.Var.Name" and
(self.module_symbol == self.symbol_table:find_symbol(exp.var.name)))
self:is_the_module_variable(exp.var.name))
end

function Checker:init_imported_module(var_name, mod_name)
local mod_ast = self.imported_modules[mod_name]
local symbols = {}
for _, tls in ipairs(mod_ast.tls) do
if tls._tag == "ast.Toplevel.Stats" then
local stats = tls.stats
for _, stat in ipairs(stats) do
if stat._tag == "ast.Stat.Functions" then
for _, func_stat in ipairs(stat.funcs) do
if func_stat.module then
local typ = func_stat._type
local def = checker.Def.Function(func_stat)
symbols[func_stat.name] = checker.Symbol.Value(typ, def)
end
end
elseif stat._tag == "ast.Stat.Assign" then
for _, var in ipairs(stat.vars) do
if var._exported_as then
symbols[var.name] = checker.Symbol.Value(var._type, var._def)
end
end
end
end
end
end
self:add_module_symbol(var_name, types.T.String(), symbols, mod_name)
end

local function exp_is_require(exp)
assert(exp._tag == "ast.Exp.Var")
local def = exp.var._def
return def and def._tag == "checker.Def.Builtin" and def.id == "require"
end

function Checker:check_stat(stat, is_toplevel)
Expand All @@ -338,7 +384,7 @@ function Checker:check_stat(stat, is_toplevel)
decl._type = self:from_ast_type(decl.type)
end
else
self:expand_function_returns(stat.exps)
self:expand_function_returns(stat.exps, is_toplevel)
local m = #stat.decls
local n = #stat.exps
if m > n then
Expand All @@ -348,6 +394,11 @@ function Checker:check_stat(stat, is_toplevel)
stat.exps[i] = self:check_initializer_exp(
stat.decls[i], stat.exps[i],
"declaration of local variable '%s'", stat.decls[i].name)
if stat.exps[i]._tag == "ast.Exp.CallFunc" and exp_is_require(stat.exps[i].exp) then
self:init_imported_module(stat.decls[i].name, stat.exps[i].args[i].value)
table.remove(stat.decls, i)
table.remove(stat.exps, i)
end
end
for i = m + 1, n do
stat.exps[i] = self:check_exp_synthesize(stat.exps[i])
Expand Down Expand Up @@ -486,7 +537,7 @@ function Checker:check_stat(stat, is_toplevel)
elseif tag == "ast.Stat.Assign" then

for i, var in ipairs(stat.vars) do
if var._tag == "ast.Var.Dot" and self:is_the_module_variable(var.exp) then
if var._tag == "ast.Var.Dot" and self:exp_is_the_module_variable(var.exp) then
-- Declaring a module field
if not is_toplevel then
type_error(var.loc, "module fields can only be set at the toplevel")
Expand Down Expand Up @@ -533,6 +584,9 @@ function Checker:check_stat(stat, is_toplevel)
if var._exported_as then
self:export_value_symbol(var._exported_as, var._type, var._def)
end
if var._mod_name then
type_error(stat.loc, "Can't assign to imported module variables")
end
end

elseif tag == "ast.Stat.Call" then
Expand Down Expand Up @@ -570,11 +624,19 @@ function Checker:check_stat(stat, is_toplevel)
local arg_types = {}
for i, decl in ipairs(func.value.arg_decls) do
arg_types[i] = self:from_ast_type(decl.type)
if func.module and arg_types[i]._tag == "types.T.Record" then
type_error(decl.type.loc,
"Argument number %d of module function is a record", i)
end
end

local ret_types = {}
for i, ast_typ in ipairs(func.ret_types) do
ret_types[i] = self:from_ast_type(ast_typ)
if func.module and ret_types[i]._tag == "types.T.Record" then
type_error(ast_typ.loc,
"Return value number %d of module function is a record", i)
end
end

local typ = types.T.Function(arg_types, ret_types)
Expand Down Expand Up @@ -670,6 +732,11 @@ function Checker:try_flatten_to_qualified_name(outer_var)
local q = ast.Var.Name(var.loc, table.concat(components, "."))
q._type = sym.typ
q._def = sym.def

local is_builtin = sym.def._tag == "checker.Def.Builtin"
q._mod_name = not is_builtin and not self:is_the_module_variable(root) and
root_sym.imported_mod_name

return q
end

Expand Down Expand Up @@ -747,6 +814,26 @@ function Checker:coerce_numeric_exp_to_float(exp)
end
end

function Checker:check_require(exp)
if not self.in_toplevel_decl then
type_error(exp.loc, "Can only call require from a local variable declaration")
end
local args = exp.args
local arg = args[1]
local fileprefix = string.gsub(arg.value, "%.", "/")
local filename = string.format("%s.pln", fileprefix)
local input, err = driver.load_input(filename)
if err then
type_error(exp.loc, "Can't find module %s\n", arg.value)
end

local module_ast, err = driver.compile_internal(filename, input, "checker", 0)
if not module_ast then
type_error(exp.loc, "Error loading module %s: %s", arg.value, err[1])
end
self.imported_modules[arg.value] = module_ast
end

-- Check (synthesize) the type of a function call expression.
-- If the function returns 0 arguments, it is only allowed in a statement context.
-- Void functions in an expression context are a constant source of headaches.
Expand Down Expand Up @@ -786,6 +873,10 @@ function Checker:check_fun_call(exp, is_stat)
end
exp._types = f_type.ret_types

if exp_is_require(exp.exp) then
self:check_require(exp)
end

return exp
end

Expand Down
Loading