gh-128760: Merge subsequent escaping calls to avoid redundant stack spills when escaping calls are separated by comments #128761

Closed · wants to merge 11 commits
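
The cases generator wraps each escaping call in a stack spill/reload pair: it emits `_PyFrame_SetStackPointer(frame, stack_pointer)` before the call and `stack_pointer = _PyFrame_GetStackPointer(frame)` after it (both visible in the expected output of the test below). Before this change, a run of escaping calls separated only by comments was not recognized as a single region, so each call could get its own redundant spill/reload pair; the analyzer now merges such runs and emits one pair around the whole region.

A minimal, self-contained sketch of the merging idea, assuming a toy token model (`Tok` and `merge_spans` are illustrative names, not the generator's API):

```python
from dataclasses import dataclass

@dataclass
class Tok:
    kind: str  # e.g. "IDENTIFIER", "SEMI", "COMMENT"
    text: str

def merge_spans(tokens: list[Tok], spans: list[tuple[int, int]]) -> list[tuple[int, int]]:
    """Merge half-open [start, end) spans whose gaps contain only comments."""
    if not spans:
        return []
    merged = [list(spans[0])]
    for start, end in spans[1:]:
        cur = start
        # Walk backwards across the gap; stop at the first non-comment token.
        while cur > merged[-1][1] and tokens[cur - 1].kind == "COMMENT":
            cur -= 1
        if cur <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], end)  # comments only: extend
        else:
            merged.append([start, end])              # real code: new region
    return [(s, e) for s, e in merged]

# f(); /* comment */ g();  ->  one merged region, hence one spill/reload pair
toks = [Tok("IDENTIFIER", "f"), Tok("SEMI", ";"),
        Tok("COMMENT", "/* comment */"),
        Tok("IDENTIFIER", "g"), Tok("SEMI", ";")]
assert merge_spans(toks, [(0, 2), (3, 5)]) == [(0, 5)]
```
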
Lib/test/test_generated_cases.py · 43 additions, 0 deletions
@@ -1674,6 +1674,49 @@ def test_pystackref_frompyobject_new_next_to_cmacro(self):
         """
         self.run_cases_test(input, output)
 
+    def test_escaping_call_next_to_comment(self):
+        input = """
+        inst(OP, (--)) {
+            // comments
+            /* before */
+            escaping_call();
+            // comment
+            another_escaping_call();
+            /* another comment */
+            yet_another_escaping_call();
+            // couple
+            /* more */
+            // comments
+            final_escaping_call();
+            // comments
+            /* after */
+        }
+        """
+        output = """
+        TARGET(OP) {
+            frame->instr_ptr = next_instr;
+            next_instr += 1;
+            INSTRUCTION_STATS(OP);
+            // comments
+            /* before */
+            _PyFrame_SetStackPointer(frame, stack_pointer);
+            escaping_call();
+            // comment
+            another_escaping_call();
+            /* another comment */
+            yet_another_escaping_call();
+            // couple
+            /* more */
+            // comments
+            final_escaping_call();
+            stack_pointer = _PyFrame_GetStackPointer(frame);
+            // comments
+            /* after */
+            DISPATCH();
+        }
+        """
+        self.run_cases_test(input, output)
+
     def test_pop_input(self):
         input = """
         inst(OP, (a, b --)) {
Tools/cases_generator/analyzer.py · 25 additions, 9 deletions
@@ -638,7 +638,7 @@ def has_error_without_pop(op: parser.InstDef) -> bool:
     "restart_backoff_counter",
 )
 
-def find_stmt_start(node: parser.InstDef, idx: int) -> lexer.Token:
+def find_stmt_start(node: parser.InstDef, idx: int) -> int:
     assert idx < len(node.block.tokens)
     while True:
         tkn = node.block.tokens[idx-1]
@@ -648,18 +648,19 @@ def find_stmt_start(node: parser.InstDef, idx: int) -> lexer.Token:
         assert idx > 0
     while node.block.tokens[idx].kind == "COMMENT":
         idx += 1
-    return node.block.tokens[idx]
+    return idx
 
 
-def find_stmt_end(node: parser.InstDef, idx: int) -> lexer.Token:
+def find_stmt_end(node: parser.InstDef, idx: int) -> int:
     assert idx < len(node.block.tokens)
     while True:
         idx += 1
         tkn = node.block.tokens[idx]
         if tkn.kind == "SEMI":
-            return node.block.tokens[idx+1]
+            return idx+1
 
-def check_escaping_calls(instr: parser.InstDef, escapes: dict[lexer.Token, tuple[lexer.Token, lexer.Token]]) -> None:
+def check_escaping_calls(instr: parser.InstDef, escapes: list[tuple[int, int, int]]) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
+    escapes = merge_subsequent_escaping_calls(instr.block.tokens, escapes)
     calls = {escapes[t][0] for t in escapes}
     in_if = 0
     tkn_iter = iter(instr.block.tokens)
@@ -677,9 +678,25 @@ def check_escaping_calls(instr: parser.InstDef, escapes: dict[lexer.Token, tuple
             in_if -= 1
         elif tkn in calls and in_if:
             raise analysis_error(f"Escaping call '{tkn.text} in condition", tkn)
+    return escapes
 
+def merge_subsequent_escaping_calls(tokens: list[lexer.Token], escapes: list[tuple[int, int, int]]) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
+    if not escapes:
+        return {}
+    merged: list[list[int]] = [[*escapes[0]]]
+    for idx, start, end in escapes[1:]:
+        curr_start = start
+        prev_end = merged[-1][2]
+        while curr_start > prev_end and tokens[curr_start-1].kind == lexer.COMMENT:
+            curr_start -= 1
+        if curr_start <= prev_end:
+            merged[-1][2] = max(prev_end, end)
+        else:
+            merged.append([idx, start, end])
+    return {tokens[start]: (tokens[idx], tokens[end]) for idx, start, end in merged}
+
 def find_escaping_api_calls(instr: parser.InstDef) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
-    result: dict[lexer.Token, tuple[lexer.Token, lexer.Token]] = {}
+    result: list[tuple[int, int, int]] = []
     tokens = instr.block.tokens
     for idx, tkn in enumerate(tokens):
         try:
@@ -716,9 +733,8 @@ def find_escaping_api_calls(instr: parser.InstDef) -> dict[lexer.Token, tuple[le
             continue
         start = find_stmt_start(instr, idx)
        end = find_stmt_end(instr, idx)
-        result[start] = tkn, end
-    check_escaping_calls(instr, result)
-    return result
+        result.append((idx, start, end))
+    return check_escaping_calls(instr, result)
 
 
 EXITS = {
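
One subtlety behind the backward scan in `merge_subsequent_escaping_calls`: `find_stmt_start` deliberately steps forward past leading comments, so a comment sitting between two escaping statements belongs to neither statement's span. A gap is therefore mergeable exactly when it contains nothing but comment tokens, which the backward walk from the next span's start checks. A toy restatement of that gap test (`comments_only` is a hypothetical helper, not generator code):

```python
def comments_only(kinds: list[str], gap: range) -> bool:
    """True when every token kind in the gap is a comment."""
    return all(kinds[i] == "COMMENT" for i in gap)

kinds = ["IDENTIFIER", "SEMI", "COMMENT", "COMMENT", "IDENTIFIER", "SEMI"]
assert comments_only(kinds, range(2, 4))      # comments only: spans merge
kinds[3] = "IDENTIFIER"
assert not comments_only(kinds, range(2, 4))  # real code: spans stay separate
```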