diff --git a/benchmarks/nbody/pallene.pln b/benchmarks/nbody/pallene.pln index 4b0b92c0..9f374182 100644 --- a/benchmarks/nbody/pallene.pln +++ b/benchmarks/nbody/pallene.pln @@ -1,14 +1,6 @@ local m: module = {} -record Body - x: float - y: float - z: float - vx: float - vy: float - vz: float - mass: float -end +typealias Body = {x: float, y: float, z: float, vx: float, vy: float, vz: float, mass: float} function m.new_body( x: float, y: float, z: float, diff --git a/doc/manual.md b/doc/manual.md index 1e810e5c..ae05f1b0 100644 --- a/doc/manual.md +++ b/doc/manual.md @@ -382,6 +382,37 @@ end function even(n:integer): boolean if n == 0 then return true else return odd(n-1) end end +``` + +## Imported Modules + +It's possible to import other Pallene modules using the `require` function, just as is in Lua. + +``` +local mod = require"mod" +``` + +There is one caveat, though: +`require` can only be called from a toplevel local variable declaration. + +The syntax for calling module functions, or using module variables is the following + +``` +local mod = require"mod" +local ret = mod.foo() +local var = mod.var +``` + +As of now, you can pass all types of arguments to imported module function, except for `Records`. +Also, module functions can't return `Records` and module variables can't be `Records`. + +Since module variables are constant in pallene, you also can't change their values. +Any attempt to do so, like the one shown below, will yield an error. + +``` +local mod = require"mod" +mod.var = 1 +``` ## Expressions and Statements diff --git a/pallene/builtins.lua b/pallene/builtins.lua index b390c9e8..fdb87009 100644 --- a/pallene/builtins.lua +++ b/pallene/builtins.lua @@ -15,7 +15,8 @@ local ipairs_itertype = T.Function({T.Any(), T.Any()}, {T.Any(), T.Any()}) builtins.functions = { type = T.Function({ T.Any() }, { T.String() }), tostring = T.Function({ T.Any() }, { T.String() }), - ipairs = T.Function({T.Array(T.Any())}, {ipairs_itertype, T.Any(), T.Any()}) + ipairs = T.Function({T.Array(T.Any())}, {ipairs_itertype, T.Any(), T.Any()}), + require = T.Function({ T.String() }, { T.Any() }) } builtins.modules = { diff --git a/pallene/checker.lua b/pallene/checker.lua index 7130daa2..9958ae15 100644 --- a/pallene/checker.lua +++ b/pallene/checker.lua @@ -46,11 +46,13 @@ local util = require "pallene.util" local checker = {} local Checker = util.Class() +local driver = {} -- Type-check a Pallene module -- On success, returns the typechecked module for the program -- On failure, returns false and a list of compilation errors -function checker.check(prog_ast) +function checker.check(prog_ast, driver_passed) + driver = driver_passed local co = coroutine.create(function() return Checker.new():check_program(prog_ast) end) @@ -96,6 +98,8 @@ function Checker:init() self.module_symbol = false -- checker.Symbol.Module self.symbol_table = symtab.new() -- string => checker.Symbol self.ret_types_stack = {} -- stack of types.T + self.imported_modules = {} -- Imported modules + self.in_toplevel_decl = false -- Identifies if we are evaluating a toplevel decl return self end @@ -115,7 +119,8 @@ end declare_type("Symbol", { Type = { "typ" }, Value = { "typ", "def" }, - Module = { "typ", "symbols" }, -- Note: a module name can also be a type (e.g. "string") + Module = { "typ", "symbols", "imported_mod_name" }, + -- Note: a module name can also be a type (e.g. "string") }) -- @@ -154,10 +159,11 @@ function Checker:add_value_symbol(name, typ, def) return self.symbol_table:add_symbol(name, checker.Symbol.Value(typ, def)) end -function Checker:add_module_symbol(name, typ, symbols) +function Checker:add_module_symbol(name, typ, symbols, imported_mod_name) assert(type(name) == "string") assert((not typ) or typedecl.match_tag(typ._tag, "types.T")) - return self.symbol_table:add_symbol(name, checker.Symbol.Module(typ, symbols)) + imported_mod_name = imported_mod_name or false + return self.symbol_table:add_symbol(name, checker.Symbol.Module(typ, symbols, imported_mod_name)) end function Checker:export_value_symbol(name, typ, def) @@ -304,25 +310,65 @@ end -- If the last expression in @rhs is a function call that returns multiple values, add ExtraRet -- nodes to the end of the list. -function Checker:expand_function_returns(rhs) +function Checker:expand_function_returns(rhs, is_toplevel_decl) local N = #rhs local last = rhs[N] if last and (last._tag == "ast.Exp.CallFunc" or last._tag == "ast.Exp.CallMethod") then + self.in_toplevel_decl = is_toplevel_decl last = self:check_exp_synthesize(last) rhs[N] = last for i = 2, #last._types do rhs[N-1+i] = ast.Exp.ExtraRet(last.loc, last, i) end + self.in_toplevel_decl = false end end -function Checker:is_the_module_variable(exp) +function Checker:is_the_module_variable(var_name) + return self.module_symbol == self.symbol_table:find_symbol(var_name) +end + +function Checker:exp_is_the_module_variable(exp) -- Check if the expression is the module variable without calling check_exp. -- Doing that would have raised an exception because it is not a value. return ( exp._tag == "ast.Exp.Var" and exp.var._tag == "ast.Var.Name" and - (self.module_symbol == self.symbol_table:find_symbol(exp.var.name))) + self:is_the_module_variable(exp.var.name)) +end + +function Checker:init_imported_module(var_name, mod_name) + local mod_ast = self.imported_modules[mod_name] + local symbols = {} + for _, tls in ipairs(mod_ast.tls) do + if tls._tag == "ast.Toplevel.Stats" then + local stats = tls.stats + for _, stat in ipairs(stats) do + if stat._tag == "ast.Stat.Functions" then + for _, func_stat in ipairs(stat.funcs) do + if func_stat.module then + local typ = func_stat._type + local def = checker.Def.Function(func_stat) + symbols[func_stat.name] = checker.Symbol.Value(typ, def) + end + end + elseif stat._tag == "ast.Stat.Assign" then + for _, var in ipairs(stat.vars) do + if var._exported_as then + symbols[var.name] = checker.Symbol.Value(var._type, var._def) + end + end + end + end + end + end + self:add_module_symbol(var_name, types.T.String(), symbols, mod_name) +end + +local function exp_is_require(exp) + assert(exp._tag == "ast.Exp.Var") + local def = exp.var._def + return def and def._tag == "checker.Def.Builtin" and def.id == "require" end function Checker:check_stat(stat, is_toplevel) @@ -338,7 +384,7 @@ function Checker:check_stat(stat, is_toplevel) decl._type = self:from_ast_type(decl.type) end else - self:expand_function_returns(stat.exps) + self:expand_function_returns(stat.exps, is_toplevel) local m = #stat.decls local n = #stat.exps if m > n then @@ -348,6 +394,11 @@ function Checker:check_stat(stat, is_toplevel) stat.exps[i] = self:check_initializer_exp( stat.decls[i], stat.exps[i], "declaration of local variable '%s'", stat.decls[i].name) + if stat.exps[i]._tag == "ast.Exp.CallFunc" and exp_is_require(stat.exps[i].exp) then + self:init_imported_module(stat.decls[i].name, stat.exps[i].args[i].value) + table.remove(stat.decls, i) + table.remove(stat.exps, i) + end end for i = m + 1, n do stat.exps[i] = self:check_exp_synthesize(stat.exps[i]) @@ -486,7 +537,7 @@ function Checker:check_stat(stat, is_toplevel) elseif tag == "ast.Stat.Assign" then for i, var in ipairs(stat.vars) do - if var._tag == "ast.Var.Dot" and self:is_the_module_variable(var.exp) then + if var._tag == "ast.Var.Dot" and self:exp_is_the_module_variable(var.exp) then -- Declaring a module field if not is_toplevel then type_error(var.loc, "module fields can only be set at the toplevel") @@ -533,6 +584,9 @@ function Checker:check_stat(stat, is_toplevel) if var._exported_as then self:export_value_symbol(var._exported_as, var._type, var._def) end + if var._mod_name then + type_error(stat.loc, "Can't assign to imported module variables") + end end elseif tag == "ast.Stat.Call" then @@ -570,11 +624,19 @@ function Checker:check_stat(stat, is_toplevel) local arg_types = {} for i, decl in ipairs(func.value.arg_decls) do arg_types[i] = self:from_ast_type(decl.type) + if func.module and arg_types[i]._tag == "types.T.Record" then + type_error(decl.type.loc, + "Argument number %d of module function is a record", i) + end end local ret_types = {} for i, ast_typ in ipairs(func.ret_types) do ret_types[i] = self:from_ast_type(ast_typ) + if func.module and ret_types[i]._tag == "types.T.Record" then + type_error(ast_typ.loc, + "Return value number %d of module function is a record", i) + end end local typ = types.T.Function(arg_types, ret_types) @@ -670,6 +732,11 @@ function Checker:try_flatten_to_qualified_name(outer_var) local q = ast.Var.Name(var.loc, table.concat(components, ".")) q._type = sym.typ q._def = sym.def + + local is_builtin = sym.def._tag == "checker.Def.Builtin" + q._mod_name = not is_builtin and not self:is_the_module_variable(root) and + root_sym.imported_mod_name + return q end @@ -747,6 +814,26 @@ function Checker:coerce_numeric_exp_to_float(exp) end end +function Checker:check_require(exp) + if not self.in_toplevel_decl then + type_error(exp.loc, "Can only call require from a local variable declaration") + end + local args = exp.args + local arg = args[1] + local fileprefix = string.gsub(arg.value, "%.", "/") + local filename = string.format("%s.pln", fileprefix) + local input, err = driver.load_input(filename) + if err then + type_error(exp.loc, "Can't find module %s\n", arg.value) + end + + local module_ast, err = driver.compile_internal(filename, input, "checker", 0) + if not module_ast then + type_error(exp.loc, "Error loading module %s: %s", arg.value, err[1]) + end + self.imported_modules[arg.value] = module_ast +end + -- Check (synthesize) the type of a function call expression. -- If the function returns 0 arguments, it is only allowed in a statement context. -- Void functions in an expression context are a constant source of headaches. @@ -786,6 +873,10 @@ function Checker:check_fun_call(exp, is_stat) end exp._types = f_type.ret_types + if exp_is_require(exp.exp) then + self:check_require(exp) + end + return exp end diff --git a/pallene/coder.lua b/pallene/coder.lua index ab65921e..10fd4e12 100644 --- a/pallene/coder.lua +++ b/pallene/coder.lua @@ -58,6 +58,8 @@ function Coder:init(module, modname, filename) self.upvalue_of_string = {} -- str => integer self.upvalue_of_function = {} -- f_id => integer self.upvalue_of_global = {} -- g_id => integer + self.upvalue_of_imported_function = {} -- f_id => integer + self.upvalue_of_imported_var = {} -- v_id => integer self:init_upvalues() self.record_ids = {} -- types.T.Record => integer @@ -332,6 +334,14 @@ function Coder:c_value(value) local f_id = value.id local typ = self.module.functions[f_id].typ return lua_value(typ, self:function_upvalue_slot(f_id)) + elseif tag == "ir.Value.ImportedFunction" then + local f_id = value.id + local typ = self.module.imported_functions[f_id].typ + return lua_value(typ, self:imported_function_upvalue_slot(f_id)) + elseif tag == "ir.Value.ImportedVar" then + local v_id = value.id + local typ = self.module.imported_vars[v_id].typ + return lua_value(typ, self:imported_var_upvalue_slot(v_id)) elseif typedecl.match_tag(tag, "ir.Value") then typedecl.tag_error(tag, "unable to get C expression for this value type.") else @@ -627,7 +637,14 @@ typedecl.declare(coder, "coder", "Upvalue", { Metatable = {"typ"}, String = {"str"}, Function = {"f_id"}, + ImportedFunction = {"f_id"}, Global = {"g_id"}, + ImportedVar = {"v_id"}, +}) + +typedecl.declare(coder, "coder", "Imported", { + Function = {"f_id", "name"}, + Var = {"v_id", "name"}, }) function Coder:init_upvalues() @@ -683,6 +700,16 @@ function Coder:init_upvalues() table.insert(self.upvalues, coder.Upvalue.Global(g_id)) self.upvalue_of_global[g_id] = #self.upvalues end + + for f_id = 1, #self.module.imported_functions do + table.insert(self.upvalues, coder.Upvalue.ImportedFunction(f_id)) + self.upvalue_of_imported_function[f_id] = #self.upvalues + end + + for v_id = 1, #self.module.imported_vars do + table.insert(self.upvalues, coder.Upvalue.ImportedVar(v_id)) + self.upvalue_of_imported_var[v_id] = #self.upvalues + end end local function upvalue_slot(ix) @@ -704,6 +731,16 @@ function Coder:function_upvalue_slot(f_id) return upvalue_slot(ix) end +function Coder:imported_function_upvalue_slot(f_id) + local ix = assert(self.upvalue_of_imported_function[f_id]) + return upvalue_slot(ix) +end + +function Coder:imported_var_upvalue_slot(v_id) + local ix = assert(self.upvalue_of_imported_var[v_id]) + return upvalue_slot(ix) +end + function Coder:global_upvalue_slot(g_id) local ix = assert(self.upvalue_of_global[g_id]) return upvalue_slot(ix) @@ -1661,6 +1698,7 @@ end function Coder:generate_luaopen_function() local init_constants = {} + local modules = {} for ix, upv in ipairs(self.upvalues) do local tag = upv._tag if tag ~= "coder.Upvalue.Global" then @@ -1683,16 +1721,56 @@ function Coder:generate_luaopen_function() entry_point = self:lua_entry_point_name(upv.f_id), ix = C.integer(self.upvalue_of_function[upv.f_id]), })) + elseif tag == "coder.Upvalue.ImportedFunction" then + local mod_name = self.module.imported_functions[upv.f_id].mod + local field_name = self.module.imported_functions[upv.f_id].name + modules[mod_name] = modules[mod_name] or {} + table.insert(modules[mod_name], coder.Imported.Function(upv.f_id, field_name)) + elseif tag == "coder.Upvalue.ImportedVar" then + local mod_name = self.module.imported_vars[upv.v_id].mod + local field_name = self.module.imported_vars[upv.v_id].name + modules[mod_name] = modules[mod_name] or {} + table.insert(modules[mod_name], coder.Imported.Var(upv.v_id, field_name)) else typedecl.tag_error(tag) end - table.insert(init_constants, util.render([[ - lua_setiuservalue(L, globals, $ix); - /**/ - ]], { - ix = C.integer(ix), - })) + if tag ~= "coder.Upvalue.ImportedFunction" and tag ~= "coder.Upvalue.ImportedVar" then + table.insert(init_constants, util.render([[ + lua_setiuservalue(L, globals, $ix); + /**/ + ]], { + ix = C.integer(ix), + })) + end + end + end + + for mod_name, fields in pairs(modules) do + table.insert(init_constants, util.render([[ + lua_getglobal(L, "require"); + lua_pushstring(L, "$mod_name"); + lua_call(L, 1, 1); + if (PALLENE_UNLIKELY(lua_type(L, -1) != LUA_TTABLE)) + luaL_error(L, "Module not found"); + ]], { mod_name = mod_name })) + for _, field in ipairs(fields) do + local tag = field._tag + local ix + if tag == "coder.Imported.Function" then + ix = C.integer(self.upvalue_of_imported_function[field.f_id]) + elseif tag == "coder.Imported.Var" then + ix = C.integer(self.upvalue_of_imported_var[field.v_id]) + else + typedecl.tag_error(tag) + end + + table.insert(init_constants, util.render([[ + lua_getfield(L, -1, "$field_name"); + if (PALLENE_UNLIKELY(lua_type(L, -1) == LUA_TNIL)) + luaL_error(L, "field %s is nil", "$field_name"); + lua_setiuservalue(L, globals, $ix); + ]], { field_name = field.name, ix = ix })) end end diff --git a/pallene/constant_propagation.lua b/pallene/constant_propagation.lua index 3acd02b8..afca8165 100644 --- a/pallene/constant_propagation.lua +++ b/pallene/constant_propagation.lua @@ -10,13 +10,14 @@ local constant_propagation = {} local function is_constant_value(v) local tag = v._tag - if tag == "ir.Value.Nil" then return true - elseif tag == "ir.Value.Bool" then return true - elseif tag == "ir.Value.Integer" then return true - elseif tag == "ir.Value.Float" then return true - elseif tag == "ir.Value.String" then return true - elseif tag == "ir.Value.LocalVar" then return false - elseif tag == "ir.Value.Function" then return true + if tag == "ir.Value.Nil" then return true + elseif tag == "ir.Value.Bool" then return true + elseif tag == "ir.Value.Integer" then return true + elseif tag == "ir.Value.Float" then return true + elseif tag == "ir.Value.String" then return true + elseif tag == "ir.Value.LocalVar" then return false + elseif tag == "ir.Value.Function" then return true + elseif tag == "ir.Value.ImportedVar" then return true else typedecl.tag_error(tag) end diff --git a/pallene/driver.lua b/pallene/driver.lua index 32225242..4ae71df3 100644 --- a/pallene/driver.lua +++ b/pallene/driver.lua @@ -64,7 +64,7 @@ function driver.compile_internal(filename, input, stop_after, opt_level) return prog_ast, errs end - prog_ast, errs = checker.check(prog_ast) + prog_ast, errs = checker.check(prog_ast, driver) if stop_after == "checker" or not prog_ast then return prog_ast, errs end diff --git a/pallene/ir.lua b/pallene/ir.lua index c591b80d..6fff50fd 100644 --- a/pallene/ir.lua +++ b/pallene/ir.lua @@ -33,6 +33,8 @@ function ir.Module() globals = {}, -- list of ir.VarDecl exported_functions = {}, -- list of function ids exported_globals = {}, -- list of variable ids + imported_functions = {}, -- list of imported functions + imported_vars = {}, -- list of imported variables } end @@ -61,6 +63,14 @@ function ir.Function(loc, name, typ) } end +function ir.ImportedFunction(name, typ, mod) + return { + name = name, -- string + typ = typ, -- Type + mod = mod, -- Module name + } +end + --- --- Mutate modules -- @@ -88,6 +98,12 @@ function ir.add_exported_global(module, g_id) table.insert(module.exported_globals, g_id) end +function ir.add_imported_function(module, mod_name, name, typ) + table.insert(module.imported_functions, ir.ImportedFunction(name, typ, mod_name)) + + return #module.imported_functions +end + -- -- Function variables -- @@ -114,14 +130,16 @@ end -- declare_type("Value", { - Nil = {}, - Bool = {"value"}, - Integer = {"value"}, - Float = {"value"}, - String = {"value"}, - LocalVar = {"id"}, - Upvalue = {"id"}, - Function = {"id"}, + Nil = {}, + Bool = {"value"}, + Integer = {"value"}, + Float = {"value"}, + String = {"value"}, + LocalVar = {"id"}, + Upvalue = {"id"}, + Function = {"id"}, + ImportedFunction = {"id"}, + ImportedVar = {"id"}, }) -- declare_type("Cmd" diff --git a/pallene/print_ir.lua b/pallene/print_ir.lua index 4377dd3a..33f4e356 100644 --- a/pallene/print_ir.lua +++ b/pallene/print_ir.lua @@ -32,20 +32,26 @@ local function Fun(id) end end +local function ImportedFun(id) + return "impf"..id +end + local function Global(id) return "g"..id end local function Val(val) local tag = val._tag - if tag == "ir.Value.Nil" then return "nil" - elseif tag == "ir.Value.Bool" then return tostring(val.value) - elseif tag == "ir.Value.Integer" then return tostring(val.value) - elseif tag == "ir.Value.Float" then return C.float(val.value) - elseif tag == "ir.Value.String" then return C.string(val.value) - elseif tag == "ir.Value.LocalVar" then return Var(val.id) - elseif tag == "ir.Value.Upvalue" then return Upval(val.id) - elseif tag == "ir.Value.Function" then return Fun(val.id) + if tag == "ir.Value.Nil" then return "nil" + elseif tag == "ir.Value.Bool" then return tostring(val.value) + elseif tag == "ir.Value.Integer" then return tostring(val.value) + elseif tag == "ir.Value.Float" then return C.float(val.value) + elseif tag == "ir.Value.String" then return C.string(val.value) + elseif tag == "ir.Value.LocalVar" then return Var(val.id) + elseif tag == "ir.Value.Upvalue" then return Upval(val.id) + elseif tag == "ir.Value.Function" then return Fun(val.id) + elseif tag == "ir.Value.ImportedFunction" then return ImportedFun(val.id) + elseif tag == "ir.Value.ImportedVar" then return ImportedFun(val.id) else typedecl.tag_error(tag) end diff --git a/pallene/to_ir.lua b/pallene/to_ir.lua index 82d5f9c4..46c3b373 100644 --- a/pallene/to_ir.lua +++ b/pallene/to_ir.lua @@ -152,6 +152,17 @@ function ToIR:register_lambda(exp, name) return f_id end +function ToIR:register_imported_function(stat, mod_name) + assert(stat._tag == "ast.FuncStat.FuncStat") + assert(stat.module) + + return ir.add_imported_function(self.module, mod_name, stat.name, stat.value._type) +end + +function ToIR:is_cross_module_call(exp) + return exp.exp.var._mod_name +end + function ToIR:convert_toplevel(prog_ast) -- Create the $init function (it must have ID = 1) @@ -743,6 +754,17 @@ function ToIR:exp_to_value(cmds, exp, _recursive) local var = exp.var if var._tag == "ast.Var.Name" then local def = var._def + if var._mod_name then + if def._tag == "checker.Def.Variable" then + local var_name = def.decl._exported_as + local imported_var = { name = var_name, mod = var._mod_name, typ = var._type } + table.insert(self.module.imported_vars, imported_var) + return ir.Value.ImportedVar(#self.module.imported_vars) + else + local id = self:register_imported_function(def.func, var._mod_name) + return ir.Value.ImportedFunction(id) + end + end if def._tag == "checker.Def.Variable" then local var_info = self:resolve_variable(def.decl) if var_info._tag == "to_ir.Var.LocalVar" then @@ -892,7 +914,8 @@ function ToIR:exp_to_assignment(cmds, dst, exp) local f_val if def and ( def._tag == "checker.Def.Builtin" or - def._tag == "checker.Def.Function") then + def._tag == "checker.Def.Function") and + not self:is_cross_module_call(exp) then f_val = false else f_val = self:exp_to_value(cmds, exp.exp) @@ -932,8 +955,12 @@ function ToIR:exp_to_assignment(cmds, dst, exp) end elseif def and def._tag == "checker.Def.Function" then - local f_id = assert(self.fun_id_of_exp[def.func.value]) - table.insert(cmds, ir.Cmd.CallStatic(loc, f_typ, dsts, f_id, xs)) + if self:is_cross_module_call(exp) then + table.insert(cmds, ir.Cmd.CallDyn(loc, f_typ, dsts, f_val, xs)) + else + local f_id = assert(self.fun_id_of_exp[def.func.value]) + table.insert(cmds, ir.Cmd.CallStatic(loc, f_typ, dsts, f_id, xs)) + end else table.insert(cmds, ir.Cmd.CallDyn(loc, f_typ, dsts, f_val, xs)) diff --git a/spec/checker_spec.lua b/spec/checker_spec.lua index f306bbe9..afffa8be 100644 --- a/spec/checker_spec.lua +++ b/spec/checker_spec.lua @@ -115,6 +115,31 @@ describe("Module", function() ]], "module field 'x' does not exist") end) + it("forbids exported function from having Records as arguments", function() + assert_error([[ + record Point + x: float + y: float + end + function m.f(i: integer, p: Point): integer + return i + end + ]], "Argument number 2 of module function is a record") + end) + + it("forbids exported function from having Records as return values", function() + assert_error([[ + record Point + x: float + y: float + end + function m.f(i: integer): Point + local p: Point = {x = 1.4, y = -2.0} + return p + end + ]], "Return value number 1 of module function is a record") + end) + end) -- @@ -550,7 +575,8 @@ describe("Field acess (dot)", function() x: integer y: integer end - function m.f(p: Point) + local f + function f(p: Point) local _ = p.z end ]], "field 'z' not found in type 'Point'") @@ -562,7 +588,8 @@ describe("Field acess (dot)", function() x: integer y: integer end - function m.f(p: Point) + local f + function f(p: Point) p.x = "hello" end ]], "expected integer but found string in assignment") @@ -920,4 +947,126 @@ describe("Table constructor", function() ]], "type hint for table initializer is not an array, table, or record type") end) + describe("Test imported modules", function() + + local function assert_module_error(code, expected_error) + local program = util.render([[ + typealias point = {x: float, y: float} + local p1: point = {x = 1.0, y = 2.0} + local p2: point = {x = 2.0, y = 1.0} + local test = require"spec.modtest" + function m.f() + $code + end + ]], { code = code }) + + assert_error(program, expected_error) + end + + it("First argument to add must be an integer", function() + assert_module_error([[ + local ret = test.add(1.0, 2) + ]], "expected integer but found float in argument 1 of call to function") + end) + + it("Second argument to add must be an integer", function() + assert_module_error([[ + local ret = test.add(1, 2.0) + ]], "expected integer but found float in argument 2 of call to function") + end) + + it("First argument to sub must be an integer", function() + assert_module_error([[ + local ret = test.sub("wrong arg", 2) + ]], "expected integer but found string in argument 1 of call to function") + end) + + it("First argument to addpoints must be a point", function() + assert_module_error([[ + local ret = test.addpoints(1, p2) + ]], "expected { x: float, y: float } but found integer in argument 1 of call to function") + end) + + it("Second argument to addpoints must be a point", function() + assert_module_error([[ + local ret = test.addpoints(p1, 2) + ]], "expected { x: float, y: float } but found integer in argument 2 of call to function") + end) + + it("Return value of addpoints is a point", function() + assert_module_error([[ + local iret: integer = test.addpoints(p1, p2) + ]], "expected integer but found { x: float, y: float } in declaration of local variable 'iret'") + end) + + it("Wrong number of arguments to addpoints", function() + assert_module_error([[ + local pret: point = test.addpoints(p1, p2, p1) + ]], "function expects 2 argument(s) but received 3") + end) + + it("Call non existent module function", function() + assert_module_error([[ + local ret = test.unknown() + ]], "module field 'unknown' does not exist") + end) + + it("Access not existent module variable", function() + assert_module_error([[ + local i = test.unknown + ]], "module field 'unknown' does not exist") + end) + + it("Forbid assignment of imported module variable", function() + assert_module_error([[ + test.var = 1 + ]], "Can't assign to imported module variables") + end) + + it("Assign imported variable of type integer to local variable of type float", function() + assert_module_error([[ + local varloc: float = test.var + ]], "expected float but found integer in declaration of local variable 'varloc'") + end) + + end) + + describe("Forbid require from being stored in anything but a toplevel local variable", function() + + local function assert_module_error(code, expected_error) + local program = util.render([[ + local dummy + function dummy(s: any): any + return s + end + $code + ]], { code = code }) + + assert_error(program, expected_error) + end + + it("Forbid require from being called in a function argument", function() + assert_module_error([[ + local test = dummy(require"test") + ]], "Can only call require from a local variable declaration") + end) + + it("Forbid require from being called indirectly", function() + assert_module_error([[ + local dummy2 + function dummy2(s: string): any + return require(s) + end + ]], "Can only call require from a local variable declaration") + end) + + --TODO: Improve the error message for this case + it("Forbid multiple requires in one line", function() + assert_module_error([[ + local m1, m2 = require'mod1', require'spec.modtest' + ]], "Can only call require from a local variable declaration") + end) + + end) + end) diff --git a/spec/coder_spec.lua b/spec/coder_spec.lua index cdd57f7b..336570d4 100644 --- a/spec/coder_spec.lua +++ b/spec/coder_spec.lua @@ -7,13 +7,21 @@ local execution_tests = require "spec.execution_tests" -- This is because those test cases are used for both the C backend and the Lua backend. -- -local function compile(filename, pallene_code) - assert(util.set_file_contents(filename, pallene_code)) +local function compile_file(filename) local cmd = string.format("./pallenec %s", util.shell_quote(filename)) local ok, _, _, errmsg = util.outputs_of_execute(cmd) assert(ok, errmsg) end +local function compile_code(filename, pallene_code) + assert(util.set_file_contents(filename, pallene_code)) + compile_file(filename) +end + +do -- Compile modtest + compile_file("spec/modtest.pln") +end + describe("#c_backend /", function () - execution_tests.run(compile, 'c', _ENV, false) + execution_tests.run(compile_code, 'c', _ENV, false) end) diff --git a/spec/execution_tests.lua b/spec/execution_tests.lua index 2b620054..08f79593 100644 --- a/spec/execution_tests.lua +++ b/spec/execution_tests.lua @@ -648,7 +648,8 @@ function execution_tests.run(compile_file, backend, _ENV, only_compile) y: float end - function m.points(): {Point} + local points + function points(): {Point} return { { x = 1.0, y = 2.0 }, { x = 1.0, y = 2.0 }, @@ -656,36 +657,53 @@ function execution_tests.run(compile_file, backend, _ENV, only_compile) } end - function m.eq_point(p: Point, q: Point): boolean + local eq_point + function eq_point(p: Point, q: Point): boolean return p == q end - function m.ne_point(p: Point, q: Point): boolean + local ne_point + function ne_point(p: Point, q: Point): boolean return p ~= q end - ]]) - it("==", function() - run_test([[ - local p = test.points() + function m.eq_points(): boolean + local p = points() for i = 1, #p do for j = 1, #p do local ok = (i == j) - assert(ok == test.eq_point(p[i], p[j])) + if ok ~= eq_point(p[i], p[j]) then + return false + end end end - ]]) - end) + return true + end - it("~=", function() - run_test([[ - local p = test.points() + function m.ne_points(): boolean + local p = points() for i = 1, #p do for j = 1, #p do local ok = (i ~= j) - assert(ok == test.ne_point(p[i], p[j])) + if ok ~= ne_point(p[i], p[j]) then + return false + end end end + return true + end + + ]]) + + it("==", function() + run_test([[ + assert(test.eq_points()) + ]]) + end) + + it("~=", function() + run_test([[ + assert(test.ne_points()) ]]) end) end) @@ -697,17 +715,17 @@ function execution_tests.run(compile_file, backend, _ENV, only_compile) { "integer" , "integer", "17"}, { "float" , "float", "3.14"}, { "string" , "string", "'hello'"}, - { "function" , "integer->string", "tostring"}, { "array" , "{integer}", "{10,20}"}, { "table" , "{x: integer}", "{x = 1}"}, - { "record" , "Empty", "test.new_empty()"}, + { "record" , "Empty", "new_empty()"}, { "any" , "any", "17"}, } local record_decls = [[ record Empty end - function m.new_empty(): Empty + local new_empty + function new_empty(): Empty return {} end ]] @@ -720,35 +738,59 @@ function execution_tests.run(compile_file, backend, _ENV, only_compile) local name, typ, value = test[1], test[2], test[3] pallene_code[i] = util.render([[ - function m.from_${name}(x: ${typ}): any + local from_${name} + function from_${name}(x: ${typ}): any return (x as any) end - function m.to_${name}(x: any): ${typ} + local to_${name} + function to_${name}(x: any): ${typ} return (x as ${typ}) end + + function m.from_${name}(): boolean + local x: $typ = ${value} + return (x as any) == from_${name}(x) + end + + function m.to_${name}(): boolean + local x: $typ = ${value} + return x == to_${name}(x) + end + ]], { name = name, typ = typ, + value = value, }) test_to[name] = util.render([[ - local x = ${value} - assert(x == test.from_${name}(x)) + assert(test.to_${name}()) ]], { name = name, - value = value }) test_from[name] = util.render([[ - local x = ${value} - assert(x == test.from_${name}(x)) + assert(test.from_${name}()) ]], { name = name, - value = value }) end + pallene_code[#pallene_code + 1] = [[ + function m.to_integer_dummy(a: any): integer + return (a as integer) + end + + function m.to_function(x: any): (integer->string) + return (x as integer->string) + end + + function m.from_function(x: integer->string): any + return (x as any) + end + ]] + compile( record_decls .. "\n" .. table.concat(pallene_code, "\n") @@ -764,9 +806,21 @@ function execution_tests.run(compile_file, backend, _ENV, only_compile) it("any->" .. name, function() run_test(test_from[name]) end) end + it("any -> function", function() run_test([[ + local x = tostring + assert(x == test.to_function(x)) + ]]) + end) + + it("function -> any", function() run_test([[ + local x = tostring + assert(x == test.from_function(x)) + ]]) + end) + it("detects downcast error", function() run_test([[ - assert_pallene_error("wrong type for downcasted value", test.to_integer, "hello") + assert_pallene_error("wrong type for downcasted value", test.to_integer_dummy, "hello") ]]) end) end) @@ -1235,52 +1289,50 @@ function execution_tests.run(compile_file, backend, _ENV, only_compile) compile([[ typealias Float = float typealias FLOAT = float - function m.Float2float(x: Float): float return x end function m.float2Float(x: float): Float return x end function m.Float2FLOAT(x: Float): FLOAT return x end - record point x: Float end - typealias Point = point typealias Points = {Point} - function m.newPoint(x: Float): Point + local newPoint + function newPoint(x: Float): Point return {x = x} end - - function m.get(p: Point): FLOAT + local get + function get(p: Point): FLOAT return p.x end - - function m.addPoint(ps: Points, p: Point) + local addPoint + function addPoint(ps: Points, p: Point) ps[#ps + 1] = p end - ]]) - it("converts between typealiases of the same type", function() - run_test([[ - assert(1.1 == test.Float2float(1.1)) - assert(1.1 == test.float2Float(1.1)) - assert(1.1 == test.Float2FLOAT(1.1)) - ]]) - end) + function m.get(): FLOAT + local p = newPoint(1.1) + return get(p) + end + + function m.addPoint(): boolean + local p = newPoint(1.1) + local ps: Points = {} + addPoint(ps, p) + return p == ps[1] + end + ]]) it("creates a records with typealiases", function() run_test([[ - local p = test.newPoint(1.1) - assert(1.1 == test.get(p)) + assert(1.1 == test.get()) ]]) end) it("manipulates typealias of an array", function() run_test([[ - local p = test.newPoint(1.1) - local ps = {} - test.addPoint(ps, p) - assert(p == ps[1]) + assert(test.addPoint()) ]]) end) end) @@ -1292,117 +1344,40 @@ function execution_tests.run(compile_file, backend, _ENV, only_compile) y: {integer} end - function m.make_foo(x: integer, y: {integer}): Foo + local make_foo + function make_foo(x: integer, y: {integer}): Foo return { x = x, y = y } end - function m.get_x(foo: Foo): integer - return foo.x - end - - function m.set_x(foo: Foo, x: integer) - foo.x = x + function m.getset(): boolean + local foo = make_foo(123, {}) + local res1 = 123 == foo.x + foo.x = 456 + return res1 and foo.x == 456 end - function m.get_y(foo: Foo): {integer} - return foo.y - end - - function m.set_y(foo: Foo, y: {integer}) - foo.y = y - end - - record Prim - x: integer - end - - function m.make_prim(x: integer): Prim - return { x = x } - end - - record Gc - x: {integer} - end - - function m.make_gc(x: {integer}): Gc - return { x = x } - end - - record Empty - end - - function m.make_empty(): Empty - return {} + function m.getset_gc(): boolean + local a: {integer} = {} + local b: {integer} = {} + local foo = make_foo(123, a) + local res1 = a == foo.y + foo.y = b + return res1 and foo.y == b end ]]) - it("create records", function() - run_test([[ - local x = test.make_foo(123, {}) - assert_is_pallene_record(x) - ]]) - end) - it("get/set primitive fields in pallene", function() run_test([[ - local foo = test.make_foo(123, {}) - assert(123 == test.get_x(foo)) - test.set_x(foo, 456) - assert(456 == test.get_x(foo)) + assert(test.getset()) ]]) end) it("get/set gc fields in pallene", function() run_test([[ - local a = {} - local b = {} - local foo = test.make_foo(123, a) - assert(a == test.get_y(foo)) - test.set_y(foo, b) - assert(b == test.get_y(foo)) - ]]) - end) - - it("create records with only primitive fields", function() - run_test([[ - local x = test.make_prim(123) - assert_is_pallene_record(x) - ]]) - end) - - it("create records with only gc fields", function() - run_test([[ - local x = test.make_gc({}) - assert_is_pallene_record(x) + assert(test.getset_gc()) ]]) end) - it("create empty records", function() - run_test([[ - local x = test.make_empty() - assert_is_pallene_record(x) - ]]) - end) - - it("check record tags", function() - -- TODO: change this message to mention the relevant record types - -- instead of only saying "userdata" - run_test([[ - local prim = test.make_prim(123) - assert_pallene_error("expected userdata but found userdata", test.get_x, prim) - ]]) - end) - - -- The follow test case is special. Therefore, we manually check the backend we are testing - -- before executing it. - if backend == "c" then - it("protect record metatables", function() - run_test([[ - local x = test.make_prim(123) - assert(getmetatable(x) == false) - ]]) - end) - end end) describe("I/O", function() @@ -2066,11 +2041,13 @@ function execution_tests.run(compile_file, backend, _ENV, only_compile) y: integer end - function m.new_rpoint(x:integer, y:integer): RPoint + local new_rpoint + function new_rpoint(x:integer, y:integer): RPoint return {x = x, y = y} end - function m.get_rpoint_fields(p:RPoint): (integer, integer) + local get_rpoint_fields + function get_rpoint_fields(p:RPoint): (integer, integer) return p.x, p.y end @@ -2127,16 +2104,34 @@ function execution_tests.run(compile_file, backend, _ENV, only_compile) return a, b end - function m.assign_recs_1(a:RPoint, b:RPoint, c:integer, d:integer): (RPoint, RPoint) + local assign_recs_1 + function assign_recs_1(a:RPoint, b:RPoint, c:integer, d:integer): (RPoint, RPoint) a, a.x, a.y = b, c, d return a, b end - function m.assign_recs_2(a:RPoint, b:RPoint, c:integer, d:integer): (RPoint, RPoint) + function m.assign_recs_1(): boolean + local a, b = new_rpoint(10, 20), new_rpoint(30, 40) + local a2, b2 = assign_recs_1(a, b, 50, 60) + local ax, ay = get_rpoint_fields(a) + local bx, by = get_rpoint_fields(b) + return b == a2 and b == b2 and 50 == ax and 60 == ay and 30 == bx and 40 == by + end + + local assign_recs_2 + function assign_recs_2(a:RPoint, b:RPoint, c:integer, d:integer): (RPoint, RPoint) a.x, a.y, a = c, d, b return a, b end + function m.assign_recs_2(): boolean + local a, b = new_rpoint(10, 20), new_rpoint(30, 40) + local a2, b2 = assign_recs_2(a, b, 50, 60) + local ax, ay = get_rpoint_fields(a) + local bx, by = get_rpoint_fields(b) + return a2 == b and b2 == b and 50 == ax and 60 == ay and 30 == bx and 40 == by + end + function m.assign_same_var(): integer local a:integer a, a, a = 10, 20, 30 @@ -2266,33 +2261,13 @@ function execution_tests.run(compile_file, backend, _ENV, only_compile) it("use temporary variables correctly on records assignments 1", function() run_test([[ - local a, b = test.new_rpoint(10, 20), test.new_rpoint(30, 40) - local t = table.pack(test.assign_recs_1(a, b, 50, 60)) - local ax, ay = test.get_rpoint_fields(a) - local bx, by = test.get_rpoint_fields(b) - assert(2 == t.n) - assert(b == t[1]) - assert(b == t[2]) - assert(50 == ax) - assert(60 == ay) - assert(30 == bx) - assert(40 == by) + assert(test.assign_recs_1()) ]]) end) it("use temporary variables correctly on records assignments 2", function() run_test([[ - local a, b = test.new_rpoint(10, 20), test.new_rpoint(30, 40) - local t = table.pack(test.assign_recs_2(a, b, 50, 60)) - local ax, ay = test.get_rpoint_fields(a) - local bx, by = test.get_rpoint_fields(b) - assert(2 == t.n) - assert(b == t[1]) - assert(b == t[2]) - assert(50 == ax) - assert(60 == ay) - assert(30 == bx) - assert(40 == by) + assert(test.assign_recs_2()) ]]) end) @@ -2571,6 +2546,78 @@ function execution_tests.run(compile_file, backend, _ENV, only_compile) run_test([[ assert(1 == test.f()) ]]) end) end) + + describe("Test imported module", function() + compile([[ + typealias point = {x: float, y: float} + local modtest = require"spec.modtest" + + function m.add(x1: integer, x2: integer): integer + return modtest.add(x1, x2) + end + + function m.sub(x1: integer, x2: integer): integer + return modtest.sub(x1, x2) + end + + function m.get_var(): integer + return modtest.var + end + + function m.addpoints(x1: float, y1: float, x2: float, y2: float): (float, float) + local p1: point = {x = x1, y = y1} + local p2: point = {x = x2, y = y2} + + local pret: point = modtest.addpoints(p1, p2) + + return pret.x, pret.y + end + local require + function require(s: string): string + return s + end + + function m.require(s: string): string + return require(s) + end + ]]) + + it("works correctly with a add", function() + run_test([[ + local x = test.add(1, 29) + assert(x == 30) + ]]) + end) + + it("works correctly with a sub", function() + run_test([[ + local x = test.sub(1, 29) + assert(x == -28) + ]]) + end) + + it("works getting module var", function() + run_test([[ + local x = test.get_var() + assert(x == 1) + ]]) + end) + + it("works correctly with a table as argument", function() + run_test([[ + local x, y = test.addpoints(1.0, 2.7, 2.0, 2.0) + assert(x == 3.0) + assert(y == 4.7) + ]]) + end) + + it("Assert require can be shadowed", function() + run_test([[ + assert(assert(test.require("test") == "test")) + ]]) + end) + + end) end return execution_tests diff --git a/spec/modtest.pln b/spec/modtest.pln new file mode 100644 index 00000000..aa7bc9a5 --- /dev/null +++ b/spec/modtest.pln @@ -0,0 +1,20 @@ +local m: module = {} + +typealias point = {x: float, y: float} + +function m.add(x1: integer, x2: integer): integer + return x1 + x2 +end + +function m.sub(x1: integer, x2: integer): integer + return x1 - x2 +end + +function m.addpoints(p1: point, p2: point): point + local padd: point = {x = p1.x + p2.x, y = p1.y + p2.y} + return padd +end + +m.var = 1 + +return m