Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion aetcd/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,18 @@ class Client:
:param str password:
Password to be used for authentication.

:param str ca_cert:
Trusted CA certificates path

:param str cert_key:
Client certificate private key path

:param str cert_cert:
Client certificate path

:param bool ssl:
Force use ssl

:param int timeout:
Connection timeout in seconds.

Expand All @@ -122,13 +134,21 @@ def __init__(
port: int = 2379,
username: typing.Optional[str] = None,
password: typing.Optional[str] = None,
ca_cert: typing.Optional[str] = None,
cert_key: typing.Optional[str] = None,
cert_cert: typing.Optional[str] = None,
ssl: typing.Optional[bool] = None,
timeout: typing.Optional[int] = None,
options: typing.Optional[typing.Dict[str, typing.Any]] = None,
):
self._host = host
self._port = port
self._username = username
self._password = password
self._ca_cert = ca_cert
self._cert_key = cert_key
self._cert_cert = cert_cert
self._ssl = ssl
self._timeout = timeout
self._options = options or {}

Expand All @@ -138,6 +158,15 @@ def __init__(
'if using authentication credentials both username and password '
'must be provided')

client_cert_params = (self._cert_cert, self._cert_key)
if any(client_cert_params) and None in client_cert_params:
raise Exception(
'if use client certificates both cert_key and cert_key must be provided')

if self._ssl is None:
if any((ca_cert, cert_key, cert_cert)):
self._ssl = True

self._init_channel_attrs()

self.transactions = Transactions()
Expand All @@ -160,7 +189,30 @@ async def connect(self) -> None:
return

target = f'{self._host}:{self._port}'
self.channel = rpc.insecure_channel(target, options=self._options.items())
if self._ssl:
if self._ca_cert:
with open(self._ca_cert, 'rb') as ca_fd:
ca_cert = ca_fd.read()
else:
ca_cert = None
if self._cert_key:
with open(self._cert_key, 'rb') as cert_key_fd:
cert_key = cert_key_fd.read()
else:
cert_key = None
if self._cert_cert:
with open(self._cert_cert, 'rb') as cert_cert_fd:
cert_cert = cert_cert_fd.read()
else:
cert_cert = None
credentials = rpc.ssl_channel_credentials(
root_certificates=ca_cert,
private_key=cert_key,
certificate_chain=cert_cert,
)
self.channel = rpc.secure_channel(target, credentials, options=self._options.items())
else:
self.channel = rpc.insecure_channel(target, options=self._options.items())

cred_params = [c is not None for c in (self._username, self._password)]
if all(cred_params):
Expand Down
1 change: 1 addition & 0 deletions aetcd/rpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from grpc import RpcError # noqa: F401
from grpc import Status # noqa: F401
from grpc import StatusCode # noqa: F401
from grpc import ssl_channel_credentials # noqa: F401
from grpc.aio import AbortError # noqa: F401
from grpc.aio import AioRpcError # noqa: F401
from grpc.aio import BaseError # noqa: F401
Expand Down