diff --git a/aetcd/client.py b/aetcd/client.py index db6279f..a2d2f52 100644 --- a/aetcd/client.py +++ b/aetcd/client.py @@ -106,6 +106,18 @@ class Client: :param str password: Password to be used for authentication. + :param str ca_cert: + Trusted CA certificates path + + :param str cert_key: + Client certificate private key path + + :param str cert_cert: + Client certificate path + + :param bool ssl: + Force use ssl + :param int timeout: Connection timeout in seconds. @@ -122,6 +134,10 @@ def __init__( port: int = 2379, username: typing.Optional[str] = None, password: typing.Optional[str] = None, + ca_cert: typing.Optional[str] = None, + cert_key: typing.Optional[str] = None, + cert_cert: typing.Optional[str] = None, + ssl: typing.Optional[bool] = None, timeout: typing.Optional[int] = None, options: typing.Optional[typing.Dict[str, typing.Any]] = None, ): @@ -129,6 +145,10 @@ def __init__( self._port = port self._username = username self._password = password + self._ca_cert = ca_cert + self._cert_key = cert_key + self._cert_cert = cert_cert + self._ssl = ssl self._timeout = timeout self._options = options or {} @@ -138,6 +158,15 @@ def __init__( 'if using authentication credentials both username and password ' 'must be provided') + client_cert_params = (self._cert_cert, self._cert_key) + if any(client_cert_params) and None in client_cert_params: + raise Exception( + 'if use client certificates both cert_key and cert_key must be provided') + + if self._ssl is None: + if any((ca_cert, cert_key, cert_cert)): + self._ssl = True + self._init_channel_attrs() self.transactions = Transactions() @@ -160,7 +189,30 @@ async def connect(self) -> None: return target = f'{self._host}:{self._port}' - self.channel = rpc.insecure_channel(target, options=self._options.items()) + if self._ssl: + if self._ca_cert: + with open(self._ca_cert, 'rb') as ca_fd: + ca_cert = ca_fd.read() + else: + ca_cert = None + if self._cert_key: + with open(self._cert_key, 'rb') as cert_key_fd: + cert_key = cert_key_fd.read() + else: + cert_key = None + if self._cert_cert: + with open(self._cert_cert, 'rb') as cert_cert_fd: + cert_cert = cert_cert_fd.read() + else: + cert_cert = None + credentials = rpc.ssl_channel_credentials( + root_certificates=ca_cert, + private_key=cert_key, + certificate_chain=cert_cert, + ) + self.channel = rpc.secure_channel(target, credentials, options=self._options.items()) + else: + self.channel = rpc.insecure_channel(target, options=self._options.items()) cred_params = [c is not None for c in (self._username, self._password)] if all(cred_params): diff --git a/aetcd/rpc/__init__.py b/aetcd/rpc/__init__.py index 08e73c4..04ae2a1 100644 --- a/aetcd/rpc/__init__.py +++ b/aetcd/rpc/__init__.py @@ -1,6 +1,7 @@ from grpc import RpcError # noqa: F401 from grpc import Status # noqa: F401 from grpc import StatusCode # noqa: F401 +from grpc import ssl_channel_credentials # noqa: F401 from grpc.aio import AbortError # noqa: F401 from grpc.aio import AioRpcError # noqa: F401 from grpc.aio import BaseError # noqa: F401