@@ -638,7 +638,7 @@ def has_error_without_pop(op: parser.InstDef) -> bool:
638638 "restart_backoff_counter" ,
639639)
640640
641- def find_stmt_start (node : parser .InstDef , idx : int ) -> lexer . Token :
641+ def find_stmt_start (node : parser .InstDef , idx : int ) -> int :
642642 assert idx < len (node .block .tokens )
643643 while True :
644644 tkn = node .block .tokens [idx - 1 ]
@@ -648,23 +648,19 @@ def find_stmt_start(node: parser.InstDef, idx: int) -> lexer.Token:
648648 assert idx > 0
649649 while node .block .tokens [idx ].kind == "COMMENT" :
650650 idx += 1
651- return node . block . tokens [ idx ]
651+ return idx
652652
653653
def find_stmt_end(node: parser.InstDef, idx: int) -> int:
    """Return the index one past the SEMI token ending the statement at *idx*.

    Scans forward through the instruction's token stream from *idx* until the
    terminating semicolon is found.
    """
    tokens = node.block.tokens
    assert idx < len(tokens)
    pos = idx + 1
    # Walk forward until the statement's terminating semicolon.
    while tokens[pos].kind != "SEMI":
        pos += 1
    return pos + 1
666661
667- def check_escaping_calls (instr : parser .InstDef , escapes : dict [lexer .Token , tuple [lexer .Token , lexer .Token ]]) -> None :
662+ def check_escaping_calls (instr : parser .InstDef , escapes : list [list [int ]]) -> dict [lexer .Token , tuple [lexer .Token , lexer .Token ]]:
663+ escapes = merge_subsequent_escaping_calls (instr .block .tokens , escapes )
668664 calls = {escapes [t ][0 ] for t in escapes }
669665 in_if = 0
670666 tkn_iter = iter (instr .block .tokens )
@@ -682,9 +678,25 @@ def check_escaping_calls(instr: parser.InstDef, escapes: dict[lexer.Token, tuple
682678 in_if -= 1
683679 elif tkn in calls and in_if :
684680 raise analysis_error (f"Escaping call '{ tkn .text } in condition" , tkn )
681+ return escapes
682+
def merge_subsequent_escaping_calls(tokens: list[lexer.Token], escapes: list[list[int]]) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
    """Merge adjacent or overlapping escaping-call ranges.

    *escapes* is a position-ordered list of ``[call_idx, start, end]`` triples
    of token indices into *tokens*.  A range is merged into its predecessor
    when it starts at or before the predecessor's end, treating intervening
    COMMENT tokens as part of the gap.

    Returns a mapping from the start token of each merged range to the pair
    ``(call token, end token)``.
    """
    if not escapes:
        # Return an empty dict, not the empty input list, so the declared
        # return type holds on every path.
        return {}
    merged = [[*escapes[0]]]  # copy: the last entry's end is mutated in place
    for idx, start, end in escapes[1:]:
        curr_start = start
        prev_end = merged[-1][2]
        # Step backwards over comments separating this range from the previous one.
        while curr_start > prev_end and tokens[curr_start - 1].kind == lexer.COMMENT:
            curr_start -= 1
        if curr_start <= prev_end:
            # Overlapping/adjacent ranges: extend the previous range.
            merged[-1][2] = max(prev_end, end)
        else:
            merged.append([idx, start, end])
    return {tokens[start]: (tokens[idx], tokens[end]) for idx, start, end in merged}
685697
686698def find_escaping_api_calls (instr : parser .InstDef ) -> dict [lexer .Token , tuple [lexer .Token , lexer .Token ]]:
687- result : dict [ lexer . Token , tuple [ lexer . Token , lexer . Token ]] = {}
699+ result : list [ list [ int ]] = []
688700 tokens = instr .block .tokens
689701 for idx , tkn in enumerate (tokens ):
690702 try :
@@ -721,9 +733,8 @@ def find_escaping_api_calls(instr: parser.InstDef) -> dict[lexer.Token, tuple[le
721733 continue
722734 start = find_stmt_start (instr , idx )
723735 end = find_stmt_end (instr , idx )
724- result [start ] = tkn , end
725- check_escaping_calls (instr , result )
726- return result
736+ result .append ((idx , start , end ))
737+ return check_escaping_calls (instr , result )
727738
728739
729740EXITS = {
0 commit comments