@@ -638,7 +638,7 @@ def has_error_without_pop(op: parser.InstDef) -> bool:
     "restart_backoff_counter",
 )
 
-def find_stmt_start(node: parser.InstDef, idx: int) -> lexer.Token:
+def find_stmt_start(node: parser.InstDef, idx: int) -> int:
     assert idx < len(node.block.tokens)
     while True:
         tkn = node.block.tokens[idx-1]
@@ -648,23 +648,19 @@ def find_stmt_start(node: parser.InstDef, idx: int) -> lexer.Token:
         assert idx > 0
     while node.block.tokens[idx].kind == "COMMENT":
         idx += 1
-    return node.block.tokens[idx]
+    return idx
 
 
-def find_stmt_end(node: parser.InstDef, idx: int) -> lexer.Token:
-    tokens = node.block.tokens
-    assert idx < len(tokens)
+def find_stmt_end(node: parser.InstDef, idx: int) -> int:
+    assert idx < len(node.block.tokens)
     while True:
         idx += 1
-        tkn = tokens[idx]
+        tkn = node.block.tokens[idx]
         if tkn.kind == "SEMI":
-            end = idx + 1
-            while end < len(tokens) and tokens[end].kind == lexer.COMMENT:
-                end += 1
-            assert end < len(tokens)
-            return tokens[end]
+            return idx + 1
 
-def check_escaping_calls(instr: parser.InstDef, escapes: dict[lexer.Token, tuple[lexer.Token, lexer.Token]]) -> None:
+def check_escaping_calls(instr: parser.InstDef, escapes: list[list[int]]) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
+    escapes = merge_subsequent_escaping_calls(instr.block.tokens, escapes)
     calls = {escapes[t][0] for t in escapes}
     in_if = 0
     tkn_iter = iter(instr.block.tokens)
@@ -682,9 +678,25 @@ def check_escaping_calls(instr: parser.InstDef, escapes: dict[lexer.Token, tuple[lexer.Token, lexer.Token]]) -> None:
                 in_if -= 1
         elif tkn in calls and in_if:
raise analysis_error (f"Escaping call '{ tkn .text } in condition" , tkn )
681
+    return escapes
+
+def merge_subsequent_escaping_calls(tokens: list[lexer.Token], escapes: list[list[int]]) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
+    if not escapes:
+        return escapes
+    merged = [[*escapes[0]]]
+    for idx, start, end in escapes[1:]:
+        curr_start = start
+        prev_end = merged[-1][2]
+        while curr_start > prev_end and tokens[curr_start - 1].kind == lexer.COMMENT:
+            curr_start -= 1
+        if curr_start <= prev_end:
+            merged[-1][2] = max(prev_end, end)
+        else:
+            merged.append([idx, start, end])
+    return {tokens[start]: (tokens[idx], tokens[end]) for idx, start, end in merged}
 
 
 def find_escaping_api_calls(instr: parser.InstDef) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
-    result: dict[lexer.Token, tuple[lexer.Token, lexer.Token]] = {}
+    result: list[list[int]] = []
     tokens = instr.block.tokens
     for idx, tkn in enumerate(tokens):
         try:
@@ -721,9 +733,8 @@ def find_escaping_api_calls(instr: parser.InstDef) -> dict[lexer.Token, tuple[lexer.Token, lexer.Token]]:
             continue
         start = find_stmt_start(instr, idx)
         end = find_stmt_end(instr, idx)
-        result[start] = tkn, end
-    check_escaping_calls(instr, result)
-    return result
+        result.append((idx, start, end))
+    return check_escaping_calls(instr, result)
 
 
 EXITS = {
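Note on the change: find_stmt_start and find_stmt_end now return token indices (with an exclusive end, one past the SEMI) instead of lexer.Token objects, which lets the new merge_subsequent_escaping_calls compare statement boundaries numerically and walk backwards over comments. Below is a minimal, self-contained sketch of that merge logic, not the real generator: Tok is a hypothetical stand-in for lexer.Token, and the token stream is invented.

# A minimal sketch of the merge logic in merge_subsequent_escaping_calls.
# "Tok" is a hypothetical stand-in for lexer.Token; only .kind matters here.
from dataclasses import dataclass

COMMENT = "COMMENT"  # stand-in for lexer.COMMENT


@dataclass(frozen=True)
class Tok:
    kind: str
    text: str


def merge(tokens: list[Tok], escapes: list[list[int]]) -> dict[Tok, tuple[Tok, Tok]]:
    # escapes holds [call_idx, stmt_start, stmt_end] triples in source order,
    # where stmt_end is exclusive (one past the SEMI), as find_stmt_end returns.
    if not escapes:
        return {}
    merged = [[*escapes[0]]]
    for idx, start, end in escapes[1:]:
        curr_start = start
        prev_end = merged[-1][2]
        # Step back over comments separating this statement from the previous one.
        while curr_start > prev_end and tokens[curr_start - 1].kind == COMMENT:
            curr_start -= 1
        if curr_start <= prev_end:
            merged[-1][2] = max(prev_end, end)  # contiguous: extend the region
        else:
            merged.append([idx, start, end])    # real gap: start a new region
    return {tokens[start]: (tokens[idx], tokens[end]) for idx, start, end in merged}


# f(); /* note */ g();  -- with a trailing token so the exclusive end resolves.
toks = [
    Tok("IDENTIFIER", "f"), Tok("LPAREN", "("), Tok("RPAREN", ")"), Tok("SEMI", ";"),
    Tok(COMMENT, "/* note */"),
    Tok("IDENTIFIER", "g"), Tok("LPAREN", "("), Tok("RPAREN", ")"), Tok("SEMI", ";"),
    Tok("RBRACE", "}"),
]
regions = merge(toks, [[0, 0, 4], [5, 5, 9]])
assert len(regions) == 1  # the comment-separated calls merged into one region

The exclusive end index is what makes the curr_start <= prev_end adjacency test work: a statement that begins exactly where the previous one ended, or inside the comments trailing it, is folded into the same escaping region.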