Skip to content

Commit 683ba33

Browse files
committed
move kms_ssl_contexts
1 parent 5807ba1 commit 683ba33

File tree

5 files changed

+26
-30
lines changed

5 files changed

+26
-30
lines changed

pymongo/asynchronous/encryption.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@
8787
from pymongo.results import BulkWriteResult, DeleteResult
8888
from pymongo.ssl_support import BLOCKING_IO_ERRORS, get_ssl_context
8989
from pymongo.typings import _DocumentType, _DocumentTypeArg
90-
from pymongo.uri_parser_shared import parse_host
90+
from pymongo.uri_parser_shared import _parse_kms_tls_options, parse_host
9191
from pymongo.write_concern import WriteConcern
9292

9393
if TYPE_CHECKING:
@@ -157,6 +157,7 @@ def __init__(
157157
self.mongocryptd_client = mongocryptd_client
158158
self.opts = opts
159159
self._spawned = False
160+
self._kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
160161

161162
async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
162163
"""Complete a KMS request.
@@ -165,11 +166,10 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
165166
166167
:return: None
167168
"""
168-
self.opts._parse_kms_tls_options(_IS_SYNC)
169169
endpoint = kms_context.endpoint
170170
message = kms_context.message
171171
provider = kms_context.kms_provider
172-
ctx = self.opts._kms_ssl_contexts.get(provider)
172+
ctx = self._kms_ssl_contexts.get(provider)
173173
if ctx is None:
174174
# Enable strict certificate verification, OCSP, match hostname, and
175175
# SNI using the system default CA certificates.
@@ -677,7 +677,7 @@ def __init__(
677677
kms_tls_options=kms_tls_options,
678678
key_expiration_ms=key_expiration_ms,
679679
)
680-
opts._parse_kms_tls_options(_IS_SYNC)
680+
self._kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
681681
self._io_callbacks: Optional[_EncryptionIO] = _EncryptionIO(
682682
None, key_vault_coll, None, opts
683683
)

pymongo/encryption_options.py

-6
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,8 @@
3232
from bson import int64
3333
from pymongo.common import validate_is_mapping
3434
from pymongo.errors import ConfigurationError
35-
from pymongo.uri_parser_shared import _parse_kms_tls_options
3635

3736
if TYPE_CHECKING:
38-
from pymongo.pyopenssl_context import SSLContext
3937
from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg
4038

4139

@@ -238,13 +236,9 @@ def __init__(
238236
self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60")
239237
# Maps KMS provider name to a SSLContext.
240238
self._kms_tls_options = kms_tls_options
241-
self._kms_ssl_contexts: dict[str, SSLContext] = {}
242239
self._bypass_query_analysis = bypass_query_analysis
243240
self._key_expiration_ms = key_expiration_ms
244241

245-
def _parse_kms_tls_options(self, is_sync: bool) -> None:
246-
self._kms_ssl_contexts = _parse_kms_tls_options(self._kms_tls_options, is_sync)
247-
248242

249243
class RangeOpts:
250244
"""Options to configure encrypted queries using the range algorithm."""

pymongo/synchronous/encryption.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@
8686
from pymongo.synchronous.database import Database
8787
from pymongo.synchronous.mongo_client import MongoClient
8888
from pymongo.typings import _DocumentType, _DocumentTypeArg
89-
from pymongo.uri_parser_shared import parse_host
89+
from pymongo.uri_parser_shared import _parse_kms_tls_options, parse_host
9090
from pymongo.write_concern import WriteConcern
9191

9292
if TYPE_CHECKING:
@@ -156,6 +156,7 @@ def __init__(
156156
self.mongocryptd_client = mongocryptd_client
157157
self.opts = opts
158158
self._spawned = False
159+
self._kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
159160

160161
def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
161162
"""Complete a KMS request.
@@ -164,11 +165,10 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
164165
165166
:return: None
166167
"""
167-
self.opts._parse_kms_tls_options(_IS_SYNC)
168168
endpoint = kms_context.endpoint
169169
message = kms_context.message
170170
provider = kms_context.kms_provider
171-
ctx = self.opts._kms_ssl_contexts.get(provider)
171+
ctx = self._kms_ssl_contexts.get(provider)
172172
if ctx is None:
173173
# Enable strict certificate verification, OCSP, match hostname, and
174174
# SNI using the system default CA certificates.
@@ -670,7 +670,7 @@ def __init__(
670670
kms_tls_options=kms_tls_options,
671671
key_expiration_ms=key_expiration_ms,
672672
)
673-
opts._parse_kms_tls_options(_IS_SYNC)
673+
self._kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
674674
self._io_callbacks: Optional[_EncryptionIO] = _EncryptionIO(
675675
None, key_vault_coll, None, opts
676676
)

test/asynchronous/test_encryption.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from pymongo.asynchronous.collection import AsyncCollection
4242
from pymongo.asynchronous.helpers import anext
4343
from pymongo.daemon import _spawn_daemon
44+
from pymongo.uri_parser import _parse_kms_tls_options
4445

4546
try:
4647
from pymongo.pyopenssl_context import IS_PYOPENSSL
@@ -141,7 +142,7 @@ def test_init(self):
141142
self.assertEqual(opts._mongocryptd_bypass_spawn, False)
142143
self.assertEqual(opts._mongocryptd_spawn_path, "mongocryptd")
143144
self.assertEqual(opts._mongocryptd_spawn_args, ["--idleShutdownTimeoutSecs=60"])
144-
self.assertEqual(opts._kms_ssl_contexts, {})
145+
self.assertEqual(opts._kms_tls_options, {})
145146

146147
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
147148
def test_init_spawn_args(self):
@@ -189,22 +190,22 @@ def test_init_kms_tls_options(self):
189190
tls_opts: Any
190191
for tls_opts in [None, {}]:
191192
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
192-
self.assertEqual(opts._kms_ssl_contexts, {})
193+
self.assertEqual(opts._kms_tls_options, {})
193194
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tls": True}, "aws": {}})
194-
opts._parse_kms_tls_options(_IS_SYNC)
195-
ctx = opts._kms_ssl_contexts["kmip"]
195+
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
196+
ctx = _kms_ssl_contexts["kmip"]
196197
self.assertEqual(ctx.check_hostname, True)
197198
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
198-
ctx = opts._kms_ssl_contexts["aws"]
199+
ctx = _kms_ssl_contexts["aws"]
199200
self.assertEqual(ctx.check_hostname, True)
200201
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
201202
opts = AutoEncryptionOpts(
202203
{},
203204
"k.d",
204205
kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}},
205206
)
206-
opts._parse_kms_tls_options(_IS_SYNC)
207-
ctx = opts._kms_ssl_contexts["kmip"]
207+
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
208+
ctx = _kms_ssl_contexts["kmip"]
208209
self.assertEqual(ctx.check_hostname, True)
209210
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
210211

@@ -2233,7 +2234,7 @@ async def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self):
22332234
encryption = self.create_client_encryption(
22342235
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options
22352236
)
2236-
ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"]
2237+
ctx = encryption._io_callbacks._kms_ssl_contexts["aws"]
22372238
if not hasattr(ctx, "check_ocsp_endpoint"):
22382239
raise self.skipTest("OCSP not enabled")
22392240
self.assertFalse(ctx.check_ocsp_endpoint)

test/test_encryption.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from pymongo.daemon import _spawn_daemon
4242
from pymongo.synchronous.collection import Collection
4343
from pymongo.synchronous.helpers import next
44+
from pymongo.uri_parser import _parse_kms_tls_options
4445

4546
try:
4647
from pymongo.pyopenssl_context import IS_PYOPENSSL
@@ -141,7 +142,7 @@ def test_init(self):
141142
self.assertEqual(opts._mongocryptd_bypass_spawn, False)
142143
self.assertEqual(opts._mongocryptd_spawn_path, "mongocryptd")
143144
self.assertEqual(opts._mongocryptd_spawn_args, ["--idleShutdownTimeoutSecs=60"])
144-
self.assertEqual(opts._kms_ssl_contexts, {})
145+
self.assertEqual(opts._kms_tls_options, {})
145146

146147
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
147148
def test_init_spawn_args(self):
@@ -189,22 +190,22 @@ def test_init_kms_tls_options(self):
189190
tls_opts: Any
190191
for tls_opts in [None, {}]:
191192
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts)
192-
self.assertEqual(opts._kms_ssl_contexts, {})
193+
self.assertEqual(opts._kms_tls_options, {})
193194
opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tls": True}, "aws": {}})
194-
opts._parse_kms_tls_options(_IS_SYNC)
195-
ctx = opts._kms_ssl_contexts["kmip"]
195+
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
196+
ctx = _kms_ssl_contexts["kmip"]
196197
self.assertEqual(ctx.check_hostname, True)
197198
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
198-
ctx = opts._kms_ssl_contexts["aws"]
199+
ctx = _kms_ssl_contexts["aws"]
199200
self.assertEqual(ctx.check_hostname, True)
200201
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
201202
opts = AutoEncryptionOpts(
202203
{},
203204
"k.d",
204205
kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}},
205206
)
206-
opts._parse_kms_tls_options(_IS_SYNC)
207-
ctx = opts._kms_ssl_contexts["kmip"]
207+
_kms_ssl_contexts = _parse_kms_tls_options(opts._kms_tls_options, _IS_SYNC)
208+
ctx = _kms_ssl_contexts["kmip"]
208209
self.assertEqual(ctx.check_hostname, True)
209210
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
210211

@@ -2225,7 +2226,7 @@ def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self):
22252226
encryption = self.create_client_encryption(
22262227
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options
22272228
)
2228-
ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"]
2229+
ctx = encryption._io_callbacks._kms_ssl_contexts["aws"]
22292230
if not hasattr(ctx, "check_ocsp_endpoint"):
22302231
raise self.skipTest("OCSP not enabled")
22312232
self.assertFalse(ctx.check_ocsp_endpoint)

0 commit comments

Comments
 (0)