Skip to content

Commit 723b6f1

Browse files
fix: improve timeout handling (#760)
1 parent bb1c72a commit 723b6f1

File tree

4 files changed

+8
-41
lines changed

4 files changed

+8
-41
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -227,18 +227,16 @@ async def connect_async(
227227
raise KeyError(f"Driver '{driver}' is not supported.")
228228

229229
ip_type = kwargs.pop("ip_type", self._ip_type)
230-
timeout = kwargs.pop("timeout", self._timeout)
231-
if "connect_timeout" in kwargs:
232-
timeout = kwargs.pop("connect_timeout")
230+
kwargs["timeout"] = kwargs.get("timeout", self._timeout)
233231

234232
# Host and ssl options come from the certificates and metadata, so we don't
235233
# want the user to specify them.
236234
kwargs.pop("host", None)
237235
kwargs.pop("ssl", None)
238236
kwargs.pop("port", None)
239237

240-
# helper function to wrap in timeout
241-
async def get_connection() -> Any:
238+
# attempt to make connection to Cloud SQL instance
239+
try:
242240
instance_data, ip_address = await instance.connect_info(ip_type)
243241

244242
# format `user` param for automatic IAM database authn
@@ -261,13 +259,8 @@ async def get_connection() -> Any:
261259
)
262260
return await self._loop.run_in_executor(None, connect_partial)
263261

264-
# attempt to make connection to Cloud SQL instance for given timeout
265-
try:
266-
return await asyncio.wait_for(get_connection(), timeout)
267-
except asyncio.TimeoutError:
268-
raise TimeoutError(f"Connection timed out after {timeout}s")
269262
except Exception:
270-
# with any other exception, we attempt a force refresh, then throw the error
263+
# with any exception, we attempt a force refresh, then throw the error
271264
instance.force_refresh()
272265
raise
273266

google/cloud/sql/connector/pymysql.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,9 @@ def connect(
5454
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
5555
server_hostname=ip_address,
5656
)
57-
57+
# pop timeout as timeout arg is called 'connect_timeout' for pymysql
58+
timeout = kwargs.pop("timeout")
59+
kwargs["connect_timeout"] = kwargs.get("connect_timeout", timeout)
5860
# Create pymysql connection object and hand in pre-made connection
5961
conn = pymysql.Connection(host=ip_address, defer_connect=True, **kwargs)
6062
conn.connect(sock)

tests/unit/test_connector.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
limitations under the License.
1515
"""
1616
import asyncio
17-
from typing import Any
1817

1918
from mock import patch
2019
from mocks import MockInstance
@@ -24,34 +23,6 @@
2423
from google.cloud.sql.connector.exceptions import ConnectorLoopError
2524

2625

27-
async def timeout_stub(*args: Any, **kwargs: Any) -> None:
28-
"""Timeout stub for Instance.connect()"""
29-
# sleep 10 seconds
30-
await asyncio.sleep(10)
31-
32-
33-
def test_connect_timeout() -> None:
34-
"""Test that connection times out after custom timeout."""
35-
connect_string = "test-project:test-region:test-instance"
36-
37-
instance = MockInstance()
38-
mock_instances = {}
39-
mock_instances[connect_string] = instance
40-
# stub instance to raise timeout
41-
setattr(instance, "connect_info", timeout_stub)
42-
# init connector
43-
connector = Connector()
44-
# attempt to connect with timeout set to 5s
45-
with patch.dict(connector._instances, mock_instances):
46-
pytest.raises(
47-
TimeoutError,
48-
connector.connect,
49-
connect_string,
50-
"pymysql",
51-
timeout=5,
52-
)
53-
54-
5526
def test_connect_enable_iam_auth_error() -> None:
5627
"""Test that calling connect() with different enable_iam_auth
5728
argument values throws error."""

tests/unit/test_pymysql.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ async def test_pymysql(kwargs: Any) -> None:
4545
"wrap_socket",
4646
partial(context.wrap_socket, do_handshake_on_connect=False),
4747
)
48+
kwargs["timeout"] = 30
4849
with patch("pymysql.Connection") as mock_connect:
4950
mock_connect.return_value = MockConnection
5051
pymysql_connect(ip_addr, context, **kwargs)

0 commit comments

Comments
 (0)