diff --git a/google/api/field_behavior_pb2.py b/google/api/field_behavior_pb2.py new file mode 100644 index 000000000..5496a3016 --- /dev/null +++ b/google/api/field_behavior_pb2.py @@ -0,0 +1,52 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# type: ignore +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: google/api/field_behavior.proto +# isort: skip_file +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b"\n\x1fgoogle/api/field_behavior.proto\x12\ngoogle.api\x1a google/protobuf/descriptor.proto*\xb6\x01\n\rFieldBehavior\x12\x1e\n\x1a\x46IELD_BEHAVIOR_UNSPECIFIED\x10\x00\x12\x0c\n\x08OPTIONAL\x10\x01\x12\x0c\n\x08REQUIRED\x10\x02\x12\x0f\n\x0bOUTPUT_ONLY\x10\x03\x12\x0e\n\nINPUT_ONLY\x10\x04\x12\r\n\tIMMUTABLE\x10\x05\x12\x12\n\x0eUNORDERED_LIST\x10\x06\x12\x15\n\x11NON_EMPTY_DEFAULT\x10\x07\x12\x0e\n\nIDENTIFIER\x10\x08:Q\n\x0e\x66ield_behavior\x12\x1d.google.protobuf.FieldOptions\x18\x9c\x08 
\x03(\x0e\x32\x19.google.api.FieldBehaviorBp\n\x0e\x63om.google.apiB\x12\x46ieldBehaviorProtoP\x01ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\xa2\x02\x04GAPIb\x06proto3" +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "google.api.field_behavior_pb2", _globals +) +if _descriptor._USE_C_DESCRIPTORS == False: + google_dot_protobuf_dot_descriptor__pb2.FieldOptions.RegisterExtension( + field_behavior + ) + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"\n\016com.google.apiB\022FieldBehaviorProtoP\001ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\242\002\004GAPI" + _globals["_FIELDBEHAVIOR"]._serialized_start = 82 + _globals["_FIELDBEHAVIOR"]._serialized_end = 264 +# @@protoc_insertion_point(module_scope) diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index 11508ce17..90cc13884 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -91,6 +91,9 @@ def __init__( self._sqladmin_api_endpoint = DEFAULT_SERVICE_ENDPOINT else: self._sqladmin_api_endpoint = sqladmin_api_endpoint + # asyncpg does not currently support using metadata exchange + # only use metadata exchange for sync drivers + self._use_metadata = False if driver == "asyncpg" else True self._user_agent = user_agent async def _get_metadata( @@ -204,7 +207,10 @@ async def _get_ephemeral( url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/projects/{project}/instances/{instance}:generateEphemeralCert" - data = {"public_key": pub_key} + data = { + "public_key": pub_key, + "use_metadata_exchange": self._use_metadata, + } if enable_iam_auth: # down-scope credentials with only IAM login scope (refreshes them too) diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index c76092a40..c00b11617 100755 --- 
a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -21,9 +21,10 @@ import logging import os import socket +import struct from threading import Thread from types import TracebackType -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional, TYPE_CHECKING, Union import google.auth from google.auth.credentials import Credentials @@ -44,11 +45,17 @@ from google.cloud.sql.connector.resolver import DnsResolver from google.cloud.sql.connector.utils import format_database_user from google.cloud.sql.connector.utils import generate_keys +import google.cloud.sql.proto.cloud_sql_metadata_exchange_pb2 as connectorspb + +if TYPE_CHECKING: + import ssl logger = logging.getLogger(name=__name__) ASYNC_DRIVERS = ["asyncpg"] SERVER_PROXY_PORT = 3307 +# the maximum amount of time to wait before aborting a metadata exchange +IO_TIMEOUT = 30 _DEFAULT_SCHEME = "https://" _DEFAULT_UNIVERSE_DOMAIN = "googleapis.com" _SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}" @@ -391,6 +398,9 @@ async def connect_async( socket.create_connection((ip_address, SERVER_PROXY_PORT)), server_hostname=ip_address, ) + # Perform Metadata Exchange Protocol + metadata_partial = partial(self.metadata_exchange, sock) + sock = await self._loop.run_in_executor(None, metadata_partial) # If this connection was opened using a domain name, then store it # for later in case we need to forcibly close it on failover. if conn_info.conn_name.domain_name: @@ -409,6 +419,86 @@ async def connect_async( await monitored_cache.force_refresh() raise + def metadata_exchange(self, sock: ssl.SSLSocket) -> ssl.SSLSocket: + """ + Sends metadata about the connection prior to the database + protocol taking over. + The exchange consists of four steps: + 1. Prepare a CloudSQLConnectRequest including the socket protocol and + the user agent. + 2. Write the size of the message as a big endian uint32 (4 bytes) to + the server followed by the serialized message. 
The length does not + include the initial four bytes. + 3. Read a big endian uint32 (4 bytes) from the server. This is the + CloudSQLConnectResponse message length and does not include the + initial four bytes. + 4. Parse the response using the message length in step 3. If the + response is not OK, return the response's error. If there is no error, + the metadata exchange has succeeded and the connection is complete. + Args: + sock (ssl.SSLSocket): The mTLS/SSL socket to perform metadata + exchange on. + Returns: + sock (ssl.SSLSocket): mTLS/SSL socket connected to Cloud SQL Proxy + server. + """ + # form metadata exchange request + req = connectorspb.CloudSQLConnectRequest( + user_agent=f"{self._client._user_agent}", # type: ignore + protocol_type=connectorspb.CloudSQLConnectRequest.TCP, + ) + + # set I/O timeout + sock.settimeout(IO_TIMEOUT) + + # pack big-endian unsigned integer (4 bytes) + packed_len = struct.pack(">I", req.ByteSize()) + + # send metadata message length and request message + sock.sendall(packed_len + req.SerializeToString()) + + # form metadata exchange response + resp = connectorspb.CloudSQLConnectResponse() + + # read metadata message length (4 bytes) + message_len_buffer_size = struct.Struct(">I").size + message_len_buffer = b"" + while message_len_buffer_size > 0: + chunk = sock.recv(message_len_buffer_size) + if not chunk: + raise RuntimeError( + "Connection closed while getting metadata exchange length!" + ) + message_len_buffer += chunk + message_len_buffer_size -= len(chunk) + + (message_len,) = struct.unpack(">I", message_len_buffer) + + # read metadata exchange message + buffer = b"" + while message_len > 0: + chunk = sock.recv(message_len) + if not chunk: + raise RuntimeError( + "Connection closed while performing metadata exchange!" 
+ ) + buffer += chunk + message_len -= len(chunk) + + # parse metadata exchange response from buffer + resp.ParseFromString(buffer) + + # reset socket back to blocking mode + sock.setblocking(True) + + # validate metadata exchange response + if resp.response_code != connectorspb.CloudSQLConnectResponse.OK: + raise ValueError( + f"Metadata Exchange request has failed with error: {resp.error}" + ) + + return sock + async def _remove_cached( self, instance_connection_string: str, enable_iam_auth: bool ) -> None: diff --git a/google/cloud/sql/proto/cloud_sql_metadata_exchange_pb2.py b/google/cloud/sql/proto/cloud_sql_metadata_exchange_pb2.py new file mode 100644 index 000000000..2c6f3e7b0 --- /dev/null +++ b/google/cloud/sql/proto/cloud_sql_metadata_exchange_pb2.py @@ -0,0 +1,56 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# source: google/cloud/sql/v1beta4/cloud_sql_metadata_exchange.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.api import field_behavior_pb2 as google_dot_api_dot_field__behavior__pb2 + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n:google/cloud/sql/v1beta4/cloud_sql_metadata_exchange.proto\x12\x18google.cloud.sql.v1beta4\x1a\x1fgoogle/api/field_behavior.proto"\xc8\x01\n\x16\x43loudSQLConnectRequest\x12\x17\n\nuser_agent\x18\x01 \x01(\tB\x03\xe0\x41\x01\x12T\n\rprotocol_type\x18\x02 \x01(\x0e\x32=.google.cloud.sql.v1beta4.CloudSQLConnectRequest.ProtocolType"?\n\x0cProtocolType\x12\x1d\n\x19PROTOCOL_TYPE_UNSPECIFIED\x10\x00\x12\x07\n\x03TCP\x10\x01\x12\x07\n\x03UDS\x10\x02"\xc6\x01\n\x17\x43loudSQLConnectResponse\x12U\n\rresponse_code\x18\x01 \x01(\x0e\x32>.google.cloud.sql.v1beta4.CloudSQLConnectResponse.ResponseCode\x12\x12\n\x05\x65rror\x18\x02 \x01(\tB\x03\xe0\x41\x01"@\n\x0cResponseCode\x12\x1d\n\x19RESPONSE_CODE_UNSPECIFIED\x10\x00\x12\x06\n\x02OK\x10\x01\x12\t\n\x05\x45RROR\x10\x02\x62\x06proto3' +) + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "google.cloud.sql.v1beta4.cloud_sql_metadata_exchange_pb2", globals() +) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + _CLOUDSQLCONNECTREQUEST.fields_by_name["user_agent"]._options = None + _CLOUDSQLCONNECTREQUEST.fields_by_name["user_agent"]._serialized_options = ( + b"\340A\001" + ) + _CLOUDSQLCONNECTRESPONSE.fields_by_name["error"]._options = None + _CLOUDSQLCONNECTRESPONSE.fields_by_name["error"]._serialized_options = b"\340A\001" + 
_CLOUDSQLCONNECTREQUEST._serialized_start = 122 + _CLOUDSQLCONNECTREQUEST._serialized_end = 322 + _CLOUDSQLCONNECTREQUEST_PROTOCOLTYPE._serialized_start = 259 + _CLOUDSQLCONNECTREQUEST_PROTOCOLTYPE._serialized_end = 322 + _CLOUDSQLCONNECTRESPONSE._serialized_start = 325 + _CLOUDSQLCONNECTRESPONSE._serialized_end = 523 + _CLOUDSQLCONNECTRESPONSE_RESPONSECODE._serialized_start = 459 + _CLOUDSQLCONNECTRESPONSE_RESPONSECODE._serialized_end = 523 +# @@protoc_insertion_point(module_scope) diff --git a/google/cloud/sql/proto/cloud_sql_metadata_exchange_pb2.pyi b/google/cloud/sql/proto/cloud_sql_metadata_exchange_pb2.pyi new file mode 100644 index 000000000..d70285f86 --- /dev/null +++ b/google/cloud/sql/proto/cloud_sql_metadata_exchange_pb2.pyi @@ -0,0 +1,67 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import ClassVar as _ClassVar +from typing import Optional as _Optional +from typing import Union as _Union + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper + +from google.api import field_behavior_pb2 as _field_behavior_pb2 + +DESCRIPTOR: _descriptor.FileDescriptor + +class CloudSQLConnectRequest(_message.Message): + __slots__ = ["protocol_type", "user_agent"] + + class ProtocolType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] # type: ignore + + PROTOCOL_TYPE_FIELD_NUMBER: _ClassVar[int] + PROTOCOL_TYPE_UNSPECIFIED: CloudSQLConnectRequest.ProtocolType + TCP: CloudSQLConnectRequest.ProtocolType + UDS: CloudSQLConnectRequest.ProtocolType + USER_AGENT_FIELD_NUMBER: _ClassVar[int] + protocol_type: CloudSQLConnectRequest.ProtocolType + user_agent: str + def __init__( + self, + user_agent: _Optional[str] = ..., + protocol_type: _Optional[ + _Union[CloudSQLConnectRequest.ProtocolType, str] + ] = ..., + ) -> None: ... + +class CloudSQLConnectResponse(_message.Message): + __slots__ = ["error", "response_code"] + + class ResponseCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] # type: ignore + + ERROR: CloudSQLConnectResponse.ResponseCode + ERROR_FIELD_NUMBER: _ClassVar[int] + OK: CloudSQLConnectResponse.ResponseCode + RESPONSE_CODE_FIELD_NUMBER: _ClassVar[int] + RESPONSE_CODE_UNSPECIFIED: CloudSQLConnectResponse.ResponseCode + error: str + response_code: CloudSQLConnectResponse.ResponseCode + def __init__( + self, + response_code: _Optional[ + _Union[CloudSQLConnectResponse.ResponseCode, str] + ] = ..., + error: _Optional[str] = ..., + ) -> None: ... 
diff --git a/pyproject.toml b/pyproject.toml index 8a694369b..e0a3c5746 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ "dnspython>=2.0.0", "Requests", "google-auth>=2.28.0", + "protobuf", ] dynamic = ["version"] diff --git a/requirements.txt b/requirements.txt index 1dc6bc047..61d786fc9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ cryptography==44.0.2 dnspython==2.7.0 Requests==2.32.3 google-auth==2.38.0 +protobuf==6.30.0 \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 83d7a78f3..e1fa0326a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,7 @@ from unit.mocks import create_ssl_context # type: ignore from unit.mocks import FakeCredentials # type: ignore from unit.mocks import FakeCSQLInstance # type: ignore +from unit.mocks import metadata_exchange from google.cloud.sql.connector.client import CloudSQLClient from google.cloud.sql.connector.connection_name import ConnectionName @@ -84,10 +85,11 @@ def fake_credentials() -> FakeCredentials: return FakeCredentials() -async def start_proxy_server(instance: FakeCSQLInstance) -> None: - """Run local proxy server capable of performing mTLS""" +async def start_proxy_server( + instance: FakeCSQLInstance, port: int = 3307, use_metadata: bool = True +) -> None: + """Run local proxy server capable of performing mTLS and metadata exchange""" ip_address = "127.0.0.1" - port = 3307 # create socket with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: # create SSL/TLS context @@ -116,12 +118,15 @@ async def start_proxy_server(instance: FakeCSQLInstance) -> None: with context.wrap_socket(sock, server_side=True) as ssock: while True: conn, _ = ssock.accept() + if use_metadata: + metadata_exchange(conn) + conn.sendall(instance.name.encode("utf-8")) conn.close() @pytest.fixture(scope="session") -def proxy_server(fake_instance: FakeCSQLInstance) -> None: - """Run local proxy server capable of performing mTLS""" 
+def proxy_server_with_metadata(fake_instance: FakeCSQLInstance) -> None: + """Run local proxy server capable of performing metadata exchange""" thread = Thread( target=asyncio.run, args=( @@ -135,6 +140,18 @@ def proxy_server(fake_instance: FakeCSQLInstance) -> None: thread.join(1.0) # add a delay to allow the proxy server to start +@pytest.fixture(scope="session") +def proxy_server(fake_instance: FakeCSQLInstance) -> None: + """Run local proxy server capable of performing mTLS""" + thread = Thread( + target=asyncio.run, + args=(start_proxy_server(fake_instance, 3308, False),), + daemon=True, + ) + thread.start() + thread.join(1.0) # add a delay to allow the proxy server to start + + @pytest.fixture async def context(fake_instance: FakeCSQLInstance) -> ssl.SSLContext: return await create_ssl_context(fake_instance) diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 66bf64a32..e7dd7dcea 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -21,6 +21,7 @@ import datetime import json import ssl +import struct from typing import Any, Callable, Literal, Optional from aiofiles.tempfile import TemporaryDirectory @@ -38,6 +39,7 @@ from google.cloud.sql.connector.connector import _DEFAULT_UNIVERSE_DOMAIN from google.cloud.sql.connector.utils import generate_keys from google.cloud.sql.connector.utils import write_to_file +import google.cloud.sql.proto.cloud_sql_metadata_exchange_pb2 as connectorspb class FakeCredentials: @@ -298,3 +300,65 @@ async def generate_ephemeral(self, request: Any) -> web.Response: } } return web.Response(content_type="application/json", body=json.dumps(response)) + + +def metadata_exchange(sock: ssl.SSLSocket) -> None: + """ + Mimics server side metadata exchange behavior in four steps: + + 1. Read a big endian uint32 (4 bytes) from the client. This is the number of + bytes the message consumes. The length does not include the initial four + bytes. + + 2. 
Read the message from the client using the message length and parse + it into a CloudSQLConnectRequest message. + + The real server implementation will then validate the client has connection + permissions using the provided OAuth2 token based on the auth type. Here in + the test implementation, the server does nothing. + + 3. Prepare a response and write the size of the response as a big endian + uint32 (4 bytes) + + 4. Serialize the response to bytes and write those to the client as well. + + Subsequent interactions with the test server use the database protocol. + """ + # read metadata message length (4 bytes) + message_len_buffer_size = struct.Struct(">I").size + message_len_buffer = b"" + while message_len_buffer_size > 0: + chunk = sock.recv(message_len_buffer_size) + if not chunk: + raise RuntimeError( + "Connection closed while getting metadata exchange length!" + ) + message_len_buffer += chunk + message_len_buffer_size -= len(chunk) + + (message_len,) = struct.unpack(">I", message_len_buffer) + + # read metadata exchange message + buffer = b"" + while message_len > 0: + chunk = sock.recv(message_len) + if not chunk: + raise RuntimeError("Connection closed while performing metadata exchange!") + buffer += chunk + message_len -= len(chunk) + + # form metadata exchange request to be received from client + message = connectorspb.CloudSQLConnectRequest() + # parse metadata exchange request from buffer + message.ParseFromString(buffer) + + # form metadata exchange response to send to client + resp = connectorspb.CloudSQLConnectResponse( + response_code=connectorspb.CloudSQLConnectResponse.OK + ) + + # pack big-endian unsigned integer (4 bytes) + resp_len = struct.pack(">I", resp.ByteSize()) + + # send metadata response length and response message + sock.sendall(resp_len + resp.SerializeToString()) diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 498c947cc..81e55fb9c 100644 --- a/tests/unit/test_connector.py +++ 
b/tests/unit/test_connector.py @@ -257,6 +257,32 @@ def test_Connector_connect_bad_ip_type( ) +@pytest.mark.usefixtures("proxy_server_with_metadata") +async def test_connect( + fake_credentials: Credentials, fake_client: CloudSQLClient +) -> None: + """ + Test that connector.connect returns connection object. + """ + client = fake_client + async with Connector( + credentials=fake_credentials, loop=asyncio.get_running_loop() + ) as connector: + connector._client = client + # patch db connection creation + with patch("google.cloud.sql.connector.pg8000.connect") as mock_connect: + mock_connect.return_value = True + connection = await connector.connect_async( + "test-project:test-region:test-instance", + "pg8000", + user="my-user", + password="my-pass", + db="my-db", + ) + # check connection is returned + assert connection is True + + @pytest.mark.asyncio async def test_Connector_connect_async( fake_credentials: Credentials, fake_client: CloudSQLClient diff --git a/tests/unit/test_monitored_cache.py b/tests/unit/test_monitored_cache.py index 1c1f1df86..3589ac9ff 100644 --- a/tests/unit/test_monitored_cache.py +++ b/tests/unit/test_monitored_cache.py @@ -180,7 +180,7 @@ async def test_MonitoredCache_check_domain_name( # configure a local socket ip_addr = "127.0.0.1" sock = context.wrap_socket( - socket.create_connection((ip_addr, 3307)), + socket.create_connection((ip_addr, 3308)), server_hostname=ip_addr, ) # verify socket is open @@ -218,7 +218,7 @@ async def test_MonitoredCache_purge_closed_sockets( # configure a local socket ip_addr = "127.0.0.1" sock = context.wrap_socket( - socket.create_connection((ip_addr, 3307)), + socket.create_connection((ip_addr, 3308)), server_hostname=ip_addr, ) diff --git a/tests/unit/test_pg8000.py b/tests/unit/test_pg8000.py index 2c003b8a9..2ace9d056 100644 --- a/tests/unit/test_pg8000.py +++ b/tests/unit/test_pg8000.py @@ -29,7 +29,7 @@ async def test_pg8000(context: ssl.SSLContext, kwargs: Any) -> None: """Test to verify that pg8000 
gets to proper connection call.""" ip_addr = "127.0.0.1" sock = context.wrap_socket( - socket.create_connection((ip_addr, 3307)), + socket.create_connection((ip_addr, 3308)), server_hostname=ip_addr, ) with patch("pg8000.dbapi.connect") as mock_connect: diff --git a/tests/unit/test_pymysql.py b/tests/unit/test_pymysql.py index 13cd8e98a..aacbf17c1 100644 --- a/tests/unit/test_pymysql.py +++ b/tests/unit/test_pymysql.py @@ -38,7 +38,7 @@ async def test_pymysql(context: ssl.SSLContext, kwargs: Any) -> None: """Test to verify that pymysql gets to proper connection call.""" ip_addr = "127.0.0.1" sock = context.wrap_socket( - socket.create_connection((ip_addr, 3307)), + socket.create_connection((ip_addr, 3308)), server_hostname=ip_addr, ) kwargs["timeout"] = 30 diff --git a/tests/unit/test_pytds.py b/tests/unit/test_pytds.py index faa20ad8c..207a3e89d 100644 --- a/tests/unit/test_pytds.py +++ b/tests/unit/test_pytds.py @@ -41,7 +41,7 @@ async def test_pytds(context: ssl.SSLContext, kwargs: Any) -> None: """Test to verify that pytds gets to proper connection call.""" ip_addr = "127.0.0.1" sock = context.wrap_socket( - socket.create_connection((ip_addr, 3307)), + socket.create_connection((ip_addr, 3308)), server_hostname=ip_addr, ) @@ -61,7 +61,7 @@ async def test_pytds_platform_error(context: ssl.SSLContext, kwargs: Any) -> Non setattr(platform, "system", stub_platform_linux) assert platform.system() == "Linux" sock = context.wrap_socket( - socket.create_connection((ip_addr, 3307)), + socket.create_connection((ip_addr, 3308)), server_hostname=ip_addr, ) # add active_directory_auth to kwargs @@ -84,7 +84,7 @@ async def test_pytds_windows_active_directory_auth( setattr(platform, "system", stub_platform_windows) assert platform.system() == "Windows" sock = context.wrap_socket( - socket.create_connection((ip_addr, 3307)), + socket.create_connection((ip_addr, 3308)), server_hostname=ip_addr, ) # add active_directory_auth and server_name to kwargs