diff --git a/py/dml/codegen.py b/py/dml/codegen.py index 5674999f0..b309b209e 100644 --- a/py/dml/codegen.py +++ b/py/dml/codegen.py @@ -3769,6 +3769,21 @@ def codegen_method_func(func): if defined(e) else e) inline_scope.add(ExpressionSymbol(name, inlined_arg, method.site)) inp = [(n, t) for (n, t) in func.inp if isinstance(t, DMLType)] + if indices: + out_of_bounds = ' || '.join( + f'_idx{i} >= {dimsize}' + for (i, dimsize) in enumerate(method.dimsizes)) + reset = ' = '.join(f'_idx{i}' for i in range(method.dimensions)) + # squeeze in index validation code on one line, to avoid polluting + # code coverage analysis + validation = ( + f'if (unlikely({out_of_bounds})) {{' + f' {reset} = 0; _DML_fault(' + f'"{ctree.quote_filename(method.site.filename())}",' + f' {method.site.lineno},' + ' "indices out of bounds in method call"); }') + else: + validation = None with ErrorContext(method): location = Location(method, indices) @@ -3781,7 +3796,8 @@ def codegen_method_func(func): method.site, inp, func.outp, func.throws, func.independent, memoization, method.astcode, method.default_method.default_sym(indices), - location, inline_scope, method.rbrace_site) + location, inline_scope, method.rbrace_site, + validation) return code def codegen_return(site, outp, throws, retvals): @@ -3814,7 +3830,7 @@ def codegen_return(site, outp, throws, retvals): return mkCompound(site, stmts) def codegen_method(site, inp, outp, throws, independent, memoization, ast, - default, location, fnscope, rbrace_site): + default, location, fnscope, rbrace_site, validation=None): with (crep.DeviceInstanceContext() if not independent else contextlib.nullcontext()): for (arg, etype) in inp: @@ -3867,6 +3883,8 @@ def prelude(): code.append(mkAssignStatement(site, param, init)) else: code = [] + if validation: + code.append(mkInline(site, validation)) with fail_handler, exit_handler: code.append(codegen_statement(ast, location, fnscope)) @@ -3882,6 +3900,8 @@ def prelude(): [subs] = ast.args with fail_handler, exit_handler: body = prelude() + if validation: + body.append(mkInline(site, validation)) body.extend(codegen_statements(subs, location, fnscope)) code = mkCompound(site, body) if code.control_flow().fallthrough: