Skip to content

Commit 910e1c1

Browse files
Ashutosh0xwukath
authored andcommitted
fix: prevent ReDoS in code block extraction
Merge #6118 ## Summary - Replace regular expression-based code block extraction with a simple and safe string-find based search. This avoids exponential backtracking (ReDoS) when processing long or repeating inputs with missing trailing delimiters. - Add unit tests to verify standard behavior and test against ReDoS vulnerability. Co-authored-by: Kathy Wu <wukathy@google.com> PiperOrigin-RevId: 933834549
1 parent 5cfef01 commit 910e1c1

2 files changed

Lines changed: 204 additions & 16 deletions

File tree

src/google/adk/code_executors/code_execution_utils.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import binascii
2121
import copy
2222
import dataclasses
23-
import re
2423
from typing import List
2524
from typing import Optional
2625

@@ -145,31 +144,41 @@ def extract_code_and_truncate_content(
145144
first_text_part = copy.deepcopy(text_parts[0])
146145
response_text = '\n'.join([p.text for p in text_parts])
147146

148-
# Find the first code block.
149-
leading_delimiter_pattern = '|'.join(d[0] for d in code_block_delimiters)
150-
trailing_delimiter_pattern = '|'.join(d[1] for d in code_block_delimiters)
151-
pattern = re.compile(
152-
(
153-
rf'(?P<prefix>.*?)({leading_delimiter_pattern})(?P<code>.*?)({trailing_delimiter_pattern})(?P<suffix>.*?)$'
154-
).encode(),
155-
re.DOTALL,
156-
)
157-
pattern_match = pattern.search(response_text.encode())
158-
if pattern_match is None:
147+
# Find the first code block using simple string search
148+
best_start = -1
149+
best_end = -1
150+
best_lead_len = 0
151+
152+
for lead, trail in code_block_delimiters:
153+
start_idx = response_text.find(lead)
154+
if start_idx == -1:
155+
continue
156+
code_start = start_idx + len(lead)
157+
end_idx = response_text.find(trail, code_start)
158+
if end_idx == -1:
159+
continue
160+
# Pick the earliest occurring code block.
161+
if best_start == -1 or start_idx < best_start:
162+
best_start = start_idx
163+
best_end = end_idx
164+
best_lead_len = len(lead)
165+
166+
if best_start == -1:
159167
return
160168

161-
code_str = pattern_match.group('code').decode()
169+
code_str = response_text[best_start + best_lead_len : best_end]
162170
if not code_str:
163171
return
164172

165173
content.parts = []
166-
if pattern_match.group('prefix'):
167-
first_text_part.text = pattern_match.group('prefix').decode()
174+
prefix_text = response_text[:best_start]
175+
if prefix_text:
176+
first_text_part.text = prefix_text
168177
content.parts.append(first_text_part)
169178
content.parts.append(
170179
CodeExecutionUtils.build_executable_code_part(code_str)
171180
)
172-
return pattern_match.group('code').decode()
181+
return code_str
173182

174183
@staticmethod
175184
def build_executable_code_part(code: str) -> types.Part:
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import signal
16+
17+
from google.adk.code_executors import code_execution_utils
18+
from google.genai import types
19+
20+
21+
def test_extract_code_and_truncate_content_basic():
22+
"""Tests basic code extraction and content truncation."""
23+
content = types.Content(
24+
role="model",
25+
parts=[
26+
types.Part(
27+
text=(
28+
"Here is some code:\n```python\nx = 1\n```\nAnd some text"
29+
" after."
30+
)
31+
)
32+
],
33+
)
34+
delimiters = [("```python\n", "\n```")]
35+
code = (
36+
code_execution_utils.CodeExecutionUtils.extract_code_and_truncate_content(
37+
content, delimiters
38+
)
39+
)
40+
assert code == "x = 1"
41+
assert len(content.parts) == 2
42+
assert content.parts[0].text == "Here is some code:\n"
43+
assert content.parts[1].executable_code.code == "x = 1"
44+
45+
46+
def test_extract_code_and_truncate_content_multiple_blocks():
47+
"""Tests that the first code block is extracted when multiple exist."""
48+
content = types.Content(
49+
role="model",
50+
parts=[
51+
types.Part(
52+
text=(
53+
"First:\n"
54+
"```python\n"
55+
"x = 1\n"
56+
"```\n"
57+
"Second:\n"
58+
"```python\n"
59+
"y = 2\n"
60+
"```"
61+
)
62+
)
63+
],
64+
)
65+
delimiters = [("```python\n", "\n```")]
66+
code = (
67+
code_execution_utils.CodeExecutionUtils.extract_code_and_truncate_content(
68+
content, delimiters
69+
)
70+
)
71+
assert code == "x = 1"
72+
assert len(content.parts) == 2
73+
assert content.parts[0].text == "First:\n"
74+
assert content.parts[1].executable_code.code == "x = 1"
75+
76+
77+
def test_extract_code_and_truncate_content_no_delimiter():
78+
"""Tests when no delimiters are found in the content."""
79+
content = types.Content(
80+
role="model",
81+
parts=[types.Part(text="Just plain text without code.")],
82+
)
83+
delimiters = [("```python\n", "\n```")]
84+
code = (
85+
code_execution_utils.CodeExecutionUtils.extract_code_and_truncate_content(
86+
content, delimiters
87+
)
88+
)
89+
assert code is None
90+
# Content should be unmodified.
91+
assert len(content.parts) == 1
92+
assert content.parts[0].text == "Just plain text without code."
93+
94+
95+
def test_extract_code_and_truncate_content_redos_vulnerability():
96+
"""Tests that a string that would cause ReDoS behaves reasonably."""
97+
# Construct a long string that contains repeating patterns without matching delimiters.
98+
# The old regex pattern would backtrack exponentially.
99+
ticks = "`" * 3
100+
long_invalid_payload = ticks + "python\n" + "x = 1\n" * 5000 + "not_matching"
101+
content = types.Content(
102+
role="model",
103+
parts=[types.Part(text=long_invalid_payload)],
104+
)
105+
delimiters = [(ticks + "python\n", "\n" + ticks)]
106+
107+
def handler(_signum, _frame):
108+
raise TimeoutError("Test timed out (possible ReDoS regression)")
109+
110+
signal.signal(signal.SIGALRM, handler)
111+
signal.alarm(2)
112+
try:
113+
# If ReDoS vulnerability exists, this call will hang or take a very long time.
114+
code = code_execution_utils.CodeExecutionUtils.extract_code_and_truncate_content(
115+
content, delimiters
116+
)
117+
finally:
118+
signal.alarm(0)
119+
assert code is None
120+
121+
122+
def test_extract_code_and_truncate_content_multiple_delimiter_pairs():
123+
"""Tests code extraction when multiple different delimiter pairs are provided."""
124+
ticks = "`" * 3
125+
# Case 1: First delimiter pair matches first
126+
content = types.Content(
127+
role="model",
128+
parts=[
129+
types.Part(
130+
text="Here is tool code:\n"
131+
+ ticks
132+
+ "tool_code\nx = 1\n"
133+
+ ticks
134+
+ "\nAnd python code:\n"
135+
+ ticks
136+
+ "python\ny = 2\n"
137+
+ ticks
138+
)
139+
],
140+
)
141+
delimiters = [
142+
(ticks + "tool_code\n", "\n" + ticks),
143+
(ticks + "python\n", "\n" + ticks),
144+
]
145+
code = (
146+
code_execution_utils.CodeExecutionUtils.extract_code_and_truncate_content(
147+
content, delimiters
148+
)
149+
)
150+
assert code == "x = 1"
151+
assert len(content.parts) == 2
152+
assert content.parts[0].text == "Here is tool code:\n"
153+
assert content.parts[1].executable_code.code == "x = 1"
154+
155+
# Case 2: Second delimiter pair matches first
156+
content = types.Content(
157+
role="model",
158+
parts=[
159+
types.Part(
160+
text="Here is python code:\n"
161+
+ ticks
162+
+ "python\ny = 2\n"
163+
+ ticks
164+
+ "\nAnd tool code:\n"
165+
+ ticks
166+
+ "tool_code\nx = 1\n"
167+
+ ticks
168+
)
169+
],
170+
)
171+
code = (
172+
code_execution_utils.CodeExecutionUtils.extract_code_and_truncate_content(
173+
content, delimiters
174+
)
175+
)
176+
assert code == "y = 2"
177+
assert len(content.parts) == 2
178+
assert content.parts[0].text == "Here is python code:\n"
179+
assert content.parts[1].executable_code.code == "y = 2"

0 commit comments

Comments
 (0)