Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions Lib/test/test_generated_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -1674,6 +1674,49 @@ def test_pystackref_frompyobject_new_next_to_cmacro(self):
"""
self.run_cases_test(input, output)

def test_escaping_call_next_to_comment(self):
input = """
inst(OP, (--)) {
// comments
/* before */
escaping_call();
// comment
another_escaping_call();
/* another comment */
yet_another_escaping_call();
// couple
/* more */
// comments
final_escaping_call();
// comments
/* after */
}
"""
output = """
TARGET(OP) {
frame->instr_ptr = next_instr;
next_instr += 1;
INSTRUCTION_STATS(OP);
// comments
/* before */
_PyFrame_SetStackPointer(frame, stack_pointer);
escaping_call();
// comment
another_escaping_call();
/* another comment */
yet_another_escaping_call();
// couple
/* more */
// comments
final_escaping_call();
stack_pointer = _PyFrame_GetStackPointer(frame);
// comments
/* after */
DISPATCH();
}
"""
self.run_cases_test(input, output)

def test_pop_input(self):
input = """
inst(OP, (a, b --)) {
Expand Down
34 changes: 25 additions & 9 deletions Tools/cases_generator/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def has_error_without_pop(op: parser.InstDef) -> bool:
"restart_backoff_counter",
)

def find_stmt_start(node: parser.InstDef, idx: int) -> lexer.Token:
def find_stmt_start(node: parser.InstDef, idx: int) -> int:
assert idx < len(node.block.tokens)
while True:
tkn = node.block.tokens[idx-1]
Expand All @@ -648,18 +648,19 @@ def find_stmt_start(node: parser.InstDef, idx: int) -> lexer.Token:
assert idx > 0
while node.block.tokens[idx].kind == "COMMENT":
idx += 1
return node.block.tokens[idx]
return idx


def find_stmt_end(node: parser.InstDef, idx: int) -> lexer.Token:
def find_stmt_end(node: parser.InstDef, idx: int) -> int:
assert idx < len(node.block.tokens)
while True:
idx += 1
tkn = node.block.tokens[idx]
if tkn.kind == "SEMI":
return node.block.tokens[idx+1]
return idx+1

def check_escaping_calls(instr: parser.InstDef, escapes: dict[lexer.Token, tuple[lexer.Token, lexer.Token]]) -> None:
def check_escaping_calls(instr: parser.InstDef, escapes: list[tuple[int, int, int]]) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
escapes = merge_subsequent_escaping_calls(instr.block.tokens, escapes)
calls = {escapes[t][0] for t in escapes}
in_if = 0
tkn_iter = iter(instr.block.tokens)
Expand All @@ -677,9 +678,25 @@ def check_escaping_calls(instr: parser.InstDef, escapes: dict[lexer.Token, tuple
in_if -= 1
elif tkn in calls and in_if:
raise analysis_error(f"Escaping call '{tkn.text} in condition", tkn)
return escapes

def merge_subsequent_escaping_calls(tokens: list[lexer.Token], escapes: list[tuple[int, int, int]]) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
if not escapes:
return {}
merged: list[list[int]] = [[*escapes[0]]]
for idx, start, end in escapes[1:]:
curr_start = start
prev_end = merged[-1][2]
while curr_start > prev_end and tokens[curr_start-1].kind == lexer.COMMENT:
curr_start -= 1
if curr_start <= prev_end:
merged[-1][2] = max(prev_end, end)
else:
merged.append([idx, start, end])
return {tokens[start]: (tokens[idx], tokens[end]) for idx, start, end in merged}

def find_escaping_api_calls(instr: parser.InstDef) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
result: dict[lexer.Token, tuple[lexer.Token, lexer.Token]] = {}
result: list[tuple[int, int, int]] = []
tokens = instr.block.tokens
for idx, tkn in enumerate(tokens):
try:
Expand Down Expand Up @@ -716,9 +733,8 @@ def find_escaping_api_calls(instr: parser.InstDef) -> dict[lexer.Token, tuple[le
continue
start = find_stmt_start(instr, idx)
end = find_stmt_end(instr, idx)
result[start] = tkn, end
check_escaping_calls(instr, result)
return result
result.append((idx, start, end))
return check_escaping_calls(instr, result)


EXITS = {
Expand Down
Loading