Skip to content

Commit

Permalink
feat: add support for DNS names with Connector (#1204)
Browse files Browse the repository at this point in the history
The Connector may be configured to use a DNS name to look up the instance
name instead of configuring the connector with the instance connection name directly.

Add a DNS TXT record for the Cloud SQL instance to a private DNS server
or a private Google Cloud DNS Zone used by your application. For example:

Record type: TXT
Name: prod-db.mycompany.example.com – This is the domain name used by the application
Value: my-project:my-region:my-instance – This is the instance connection name

Configure the Connector to use a DNS name via setting resolver=DnsResolver
  • Loading branch information
jackwotherspoon authored Dec 4, 2024
1 parent 11f9fe9 commit 1a8f274
Show file tree
Hide file tree
Showing 14 changed files with 309 additions and 26 deletions.
63 changes: 63 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,69 @@ conn = connector.connect(
)
```

### Using DNS domain names to identify instances

The connector can be configured to use DNS to look up an instance. This would
allow you to configure your application to connect to a database instance, and
centrally configure which instance in your DNS zone.

#### Configure your DNS Records

Add a DNS TXT record for the Cloud SQL instance to a **private** DNS server
or a private Google Cloud DNS Zone used by your application.

> [!NOTE]
>
> You are strongly discouraged from adding DNS records for your
> Cloud SQL instances to a public DNS server. This would allow anyone on the
> internet to discover the Cloud SQL instance name.
For example: suppose you wanted to use the domain name
`prod-db.mycompany.example.com` to connect to your database instance
`my-project:region:my-instance`. You would create the following DNS record:

* Record type: `TXT`
* Name: `prod-db.mycompany.example.com` – This is the domain name used by the application
* Value: `my-project:my-region:my-instance` – This is the Cloud SQL instance connection name

#### Configure the connector

Configure the connector to resolve DNS names by initializing it with
`resolver=DnsResolver` and replacing the instance connection name with the DNS
name in `connector.connect`:

```python
from google.cloud.sql.connector import Connector, DnsResolver
import pymysql
import sqlalchemy

# helper function to return SQLAlchemy connection pool
def init_connection_pool(connector: Connector) -> sqlalchemy.engine.Engine:
# function used to generate database connection
def getconn() -> pymysql.connections.Connection:
conn = connector.connect(
"prod-db.mycompany.example.com", # using DNS name
"pymysql",
user="my-user",
password="my-password",
db="my-db-name"
)
return conn

# create connection pool
pool = sqlalchemy.create_engine(
"mysql+pymysql://",
creator=getconn,
)
return pool

# initialize Cloud SQL Python Connector with `resolver=DnsResolver`
with Connector(resolver=DnsResolver) as connector:
# initialize connection pool
pool = init_connection_pool(connector)
# ... use SQLAlchemy engine normally
```

### Using the Python Connector with Python Web Frameworks

The Python Connector can be used alongside popular Python web frameworks such
Expand Down
4 changes: 4 additions & 0 deletions google/cloud/sql/connector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
from google.cloud.sql.connector.connector import create_async_connector
from google.cloud.sql.connector.enums import IPTypes
from google.cloud.sql.connector.enums import RefreshStrategy
from google.cloud.sql.connector.resolver import DefaultResolver
from google.cloud.sql.connector.resolver import DnsResolver
from google.cloud.sql.connector.version import __version__

__all__ = [
"__version__",
"create_async_connector",
"Connector",
"DefaultResolver",
"DnsResolver",
"IPTypes",
"RefreshStrategy",
]
16 changes: 14 additions & 2 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import google.cloud.sql.connector.pg8000 as pg8000
import google.cloud.sql.connector.pymysql as pymysql
import google.cloud.sql.connector.pytds as pytds
from google.cloud.sql.connector.resolver import DefaultResolver
from google.cloud.sql.connector.resolver import DnsResolver
from google.cloud.sql.connector.utils import format_database_user
from google.cloud.sql.connector.utils import generate_keys

Expand All @@ -63,6 +65,7 @@ def __init__(
user_agent: Optional[str] = None,
universe_domain: Optional[str] = None,
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
resolver: Type[DefaultResolver] | Type[DnsResolver] = DefaultResolver,
) -> None:
"""Initializes a Connector instance.
Expand Down Expand Up @@ -104,6 +107,13 @@ def __init__(
of the following: RefreshStrategy.LAZY ("LAZY") or
RefreshStrategy.BACKGROUND ("BACKGROUND").
Default: RefreshStrategy.BACKGROUND
resolver (DefaultResolver | DnsResolver): The class name of the
resolver to use for resolving the Cloud SQL instance connection
name. To resolve a DNS record to an instance connection name, use
DnsResolver.
Default: DefaultResolver
"""
# if refresh_strategy is str, convert to RefreshStrategy enum
if isinstance(refresh_strategy, str):
Expand Down Expand Up @@ -157,6 +167,7 @@ def __init__(
self._enable_iam_auth = enable_iam_auth
self._quota_project = quota_project
self._user_agent = user_agent
self._resolver = resolver()
# if ip_type is str, convert to IPTypes enum
if isinstance(ip_type, str):
ip_type = IPTypes._from_str(ip_type)
Expand Down Expand Up @@ -269,13 +280,14 @@ async def connect_async(
if (instance_connection_string, enable_iam_auth) in self._cache:
cache = self._cache[(instance_connection_string, enable_iam_auth)]
else:
conn_name = await self._resolver.resolve(instance_connection_string)
if self._refresh_strategy == RefreshStrategy.LAZY:
logger.debug(
f"['{instance_connection_string}']: Refresh strategy is set"
" to lazy refresh"
)
cache = LazyRefreshCache(
instance_connection_string,
conn_name,
self._client,
self._keys,
enable_iam_auth,
Expand All @@ -286,7 +298,7 @@ async def connect_async(
" to backgound refresh"
)
cache = RefreshAheadCache(
instance_connection_string,
conn_name,
self._client,
self._keys,
enable_iam_auth,
Expand Down
7 changes: 7 additions & 0 deletions google/cloud/sql/connector/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,10 @@ class IncompatibleDriverError(Exception):
Exception to be raised when the database driver given is for the wrong
database engine. (i.e. asyncpg for a MySQL database)
"""


class DnsResolutionError(Exception):
"""
Exception to be raised when an instance connection name can not be resolved
from a DNS record.
"""
10 changes: 4 additions & 6 deletions google/cloud/sql/connector/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import RefreshNotValidError
from google.cloud.sql.connector.rate_limiter import AsyncRateLimiter
from google.cloud.sql.connector.refresh_utils import _is_valid
Expand All @@ -45,25 +45,23 @@ class RefreshAheadCache:

def __init__(
self,
instance_connection_string: str,
conn_name: ConnectionName,
client: CloudSQLClient,
keys: asyncio.Future,
enable_iam_auth: bool = False,
) -> None:
"""Initializes a RefreshAheadCache instance.
Args:
instance_connection_string (str): The Cloud SQL Instance's
connection string (also known as an instance connection name).
conn_name (ConnectionName): The Cloud SQL instance's
connection name.
client (CloudSQLClient): The Cloud SQL Client instance.
keys (asyncio.Future): A future to the client's public-private key
pair.
enable_iam_auth (bool): Enables automatic IAM database authentication
(Postgres and MySQL) as the default authentication method for all
connections.
"""
# validate and parse instance connection name
conn_name = _parse_instance_connection_name(instance_connection_string)
self._project, self._region, self._instance = (
conn_name.project,
conn_name.region,
Expand Down
10 changes: 4 additions & 6 deletions google/cloud/sql/connector/lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_info import ConnectionInfo
from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.refresh_utils import _refresh_buffer

logger = logging.getLogger(name=__name__)
Expand All @@ -38,25 +38,23 @@ class LazyRefreshCache:

def __init__(
self,
instance_connection_string: str,
conn_name: ConnectionName,
client: CloudSQLClient,
keys: asyncio.Future,
enable_iam_auth: bool = False,
) -> None:
"""Initializes a LazyRefreshCache instance.
Args:
instance_connection_string (str): The Cloud SQL Instance's
connection string (also known as an instance connection name).
conn_name (ConnectionName): The Cloud SQL instance's
connection name.
client (CloudSQLClient): The Cloud SQL Client instance.
keys (asyncio.Future): A future to the client's public-private key
pair.
enable_iam_auth (bool): Enables automatic IAM database authentication
(Postgres and MySQL) as the default authentication method for all
connections.
"""
# validate and parse instance connection name
conn_name = _parse_instance_connection_name(instance_connection_string)
self._project, self._region, self._instance = (
conn_name.project,
conn_name.region,
Expand Down
67 changes: 67 additions & 0 deletions google/cloud/sql/connector/resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import dns.asyncresolver

from google.cloud.sql.connector.connection_name import _parse_instance_connection_name
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import DnsResolutionError


class DefaultResolver:
"""DefaultResolver simply validates and parses instance connection name."""

async def resolve(self, connection_name: str) -> ConnectionName:
return _parse_instance_connection_name(connection_name)


class DnsResolver(dns.asyncresolver.Resolver):
"""
DnsResolver resolves domain names into instance connection names using
TXT records in DNS.
"""

async def resolve(self, dns: str) -> ConnectionName: # type: ignore
try:
conn_name = _parse_instance_connection_name(dns)
except ValueError:
# The connection name was not project:region:instance format.
# Attempt to query a TXT record to get connection name.
conn_name = await self.query_dns(dns)
return conn_name

async def query_dns(self, dns: str) -> ConnectionName:
try:
# Attempt to query the TXT records.
records = await super().resolve(dns, "TXT", raise_on_no_answer=True)
# Sort the TXT record values alphabetically, strip quotes as record
# values can be returned as raw strings
rdata = [record.to_text().strip('"') for record in records]
rdata.sort()
# Attempt to parse records, returning the first valid record.
for record in rdata:
try:
conn_name = _parse_instance_connection_name(record)
return conn_name
except Exception:
continue
# If all records failed to parse, throw error
raise DnsResolutionError(
f"Unable to parse TXT record for `{dns}` -> `{rdata[0]}`"
)
# Don't override above DnsResolutionError
except DnsResolutionError:
raise
except Exception as e:
raise DnsResolutionError(f"Unable to resolve TXT record for `{dns}`") from e
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
aiofiles==24.1.0
aiohttp==3.11.9
cryptography==44.0.0
dnspython==2.7.0
Requests==2.32.3
google-auth==2.36.0
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"aiofiles",
"aiohttp",
"cryptography>=42.0.0",
"dnspython>=2.0.0",
"Requests",
"google-auth>=2.28.0",
]
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from unit.mocks import FakeCSQLInstance # type: ignore

from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.instance import RefreshAheadCache
from google.cloud.sql.connector.utils import generate_keys

Expand Down Expand Up @@ -144,7 +145,7 @@ async def fake_client(
async def cache(fake_client: CloudSQLClient) -> AsyncGenerator[RefreshAheadCache, None]:
keys = asyncio.create_task(generate_keys())
cache = RefreshAheadCache(
"test-project:test-region:test-instance",
ConnectionName("test-project", "test-region", "test-instance"),
client=fake_client,
keys=keys,
)
Expand Down
17 changes: 9 additions & 8 deletions tests/unit/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from google.cloud.sql.connector import create_async_connector
from google.cloud.sql.connector import IPTypes
from google.cloud.sql.connector.client import CloudSQLClient
from google.cloud.sql.connector.connection_name import ConnectionName
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
from google.cloud.sql.connector.exceptions import IncompatibleDriverError
from google.cloud.sql.connector.instance import RefreshAheadCache
Expand Down Expand Up @@ -322,18 +323,18 @@ async def test_Connector_remove_cached_bad_instance(
async with Connector(
credentials=fake_credentials, loop=asyncio.get_running_loop()
) as connector:
conn_name = "bad-project:bad-region:bad-inst"
conn_name = ConnectionName("bad-project", "bad-region", "bad-inst")
# populate cache
cache = RefreshAheadCache(conn_name, fake_client, connector._keys)
connector._cache[(conn_name, False)] = cache
connector._cache[(str(conn_name), False)] = cache
# aiohttp client should throw a 404 ClientResponseError
with pytest.raises(ClientResponseError):
await connector.connect_async(
conn_name,
str(conn_name),
"pg8000",
)
# check that cache has been removed from dict
assert (conn_name, False) not in connector._cache
assert (str(conn_name), False) not in connector._cache


async def test_Connector_remove_cached_no_ip_type(
Expand All @@ -348,21 +349,21 @@ async def test_Connector_remove_cached_no_ip_type(
async with Connector(
credentials=fake_credentials, loop=asyncio.get_running_loop()
) as connector:
conn_name = "test-project:test-region:test-instance"
conn_name = ConnectionName("test-project", "test-region", "test-instance")
# populate cache
cache = RefreshAheadCache(conn_name, fake_client, connector._keys)
connector._cache[(conn_name, False)] = cache
connector._cache[(str(conn_name), False)] = cache
# test instance does not have Private IP, thus should invalidate cache
with pytest.raises(CloudSQLIPTypeError):
await connector.connect_async(
conn_name,
str(conn_name),
"pg8000",
user="my-user",
password="my-pass",
ip_type="private",
)
# check that cache has been removed from dict
assert (conn_name, False) not in connector._cache
assert (str(conn_name), False) not in connector._cache


def test_default_universe_domain(fake_credentials: Credentials) -> None:
Expand Down
Loading

0 comments on commit 1a8f274

Please sign in to comment.