|
4 | 4 | from asyncio import Lock as AsyncLock |
5 | 5 | from asyncio import sleep as async_sleep |
6 | 6 | from typing import Optional, Tuple, Union |
| 7 | +from unittest.mock import AsyncMock, call |
7 | 8 |
|
8 | 9 | import pytest |
9 | 10 | import pytest_asyncio |
10 | 11 | import redis |
11 | | -from mock.mock import Mock, call |
12 | 12 | from redis import AuthenticationError, DataError, RedisError, ResponseError |
13 | 13 | from redis.asyncio import Connection, ConnectionPool, Redis |
14 | 14 | from redis.asyncio.retry import Retry |
@@ -340,10 +340,10 @@ class TestStreamingCredentialProvider: |
340 | 340 | indirect=True, |
341 | 341 | ) |
342 | 342 | async def test_async_re_auth_all_connections(self, credential_provider): |
343 | | - mock_connection = Mock(spec=Connection) |
| 343 | + mock_connection = AsyncMock(spec=Connection) |
344 | 344 | mock_connection.retry = Retry(NoBackoff(), 0) |
345 | | - mock_another_connection = Mock(spec=Connection) |
346 | | - mock_pool = Mock(spec=ConnectionPool) |
| 345 | + mock_another_connection = AsyncMock(spec=Connection) |
| 346 | + mock_pool = AsyncMock(spec=ConnectionPool) |
347 | 347 | mock_pool.connection_kwargs = { |
348 | 348 | "credential_provider": credential_provider, |
349 | 349 | } |
@@ -391,16 +391,16 @@ async def re_auth_callback(token): |
391 | 391 | indirect=True, |
392 | 392 | ) |
393 | 393 | async def test_async_re_auth_partial_connections(self, credential_provider): |
394 | | - mock_connection = Mock(spec=Connection) |
| 394 | + mock_connection = AsyncMock(spec=Connection) |
395 | 395 | mock_connection.retry = Retry(NoBackoff(), 3) |
396 | | - mock_another_connection = Mock(spec=Connection) |
| 396 | + mock_another_connection = AsyncMock(spec=Connection) |
397 | 397 | mock_another_connection.retry = Retry(NoBackoff(), 3) |
398 | | - mock_failed_connection = Mock(spec=Connection) |
| 398 | + mock_failed_connection = AsyncMock(spec=Connection) |
399 | 399 | mock_failed_connection.read_response.side_effect = ConnectionError( |
400 | 400 | "Failed auth" |
401 | 401 | ) |
402 | 402 | mock_failed_connection.retry = Retry(NoBackoff(), 3) |
403 | | - mock_pool = Mock(spec=ConnectionPool) |
| 403 | + mock_pool = AsyncMock(spec=ConnectionPool) |
404 | 404 | mock_pool.connection_kwargs = { |
405 | 405 | "credential_provider": credential_provider, |
406 | 406 | } |
@@ -454,21 +454,28 @@ async def re_auth_callback(token): |
454 | 454 | indirect=True, |
455 | 455 | ) |
456 | 456 | async def test_re_auth_pub_sub_in_resp3(self, credential_provider): |
457 | | - mock_pubsub_connection = Mock(spec=Connection) |
| 457 | + mock_pubsub_connection = AsyncMock(spec=Connection) |
458 | 458 | mock_pubsub_connection.get_protocol.return_value = 3 |
459 | 459 | mock_pubsub_connection.credential_provider = credential_provider |
460 | 460 | mock_pubsub_connection.retry = Retry(NoBackoff(), 3) |
461 | | - mock_another_connection = Mock(spec=Connection) |
| 461 | + mock_another_connection = AsyncMock(spec=Connection) |
462 | 462 | mock_another_connection.retry = Retry(NoBackoff(), 3) |
463 | 463 |
|
464 | | - mock_pool = Mock(spec=ConnectionPool) |
| 464 | + mock_pool = AsyncMock(spec=ConnectionPool) |
465 | 465 | mock_pool.connection_kwargs = { |
466 | 466 | "credential_provider": credential_provider, |
467 | 467 | } |
468 | | - mock_pool.get_connection.side_effect = [ |
469 | | - mock_pubsub_connection, |
470 | | - mock_another_connection, |
471 | | - ] |
| 468 | + |
| 469 | + async def get_connection_side_effect(): |
| 470 | + if not hasattr(get_connection_side_effect, "call_count"): |
| 471 | + get_connection_side_effect.call_count = 0 |
| 472 | + result = [mock_pubsub_connection, mock_another_connection][ |
| 473 | + get_connection_side_effect.call_count |
| 474 | + ] |
| 475 | + get_connection_side_effect.call_count += 1 |
| 476 | + return result |
| 477 | + |
| 478 | + mock_pool.get_connection = AsyncMock(side_effect=get_connection_side_effect) |
472 | 479 | mock_pool._available_connections = [mock_another_connection] |
473 | 480 | mock_pool._lock = AsyncLock() |
474 | 481 | auth_token = None |
@@ -516,21 +523,28 @@ async def re_auth_callback(token): |
516 | 523 | indirect=True, |
517 | 524 | ) |
518 | 525 | async def test_do_not_re_auth_pub_sub_in_resp2(self, credential_provider): |
519 | | - mock_pubsub_connection = Mock(spec=Connection) |
| 526 | + mock_pubsub_connection = AsyncMock(spec=Connection) |
520 | 527 | mock_pubsub_connection.get_protocol.return_value = 2 |
521 | 528 | mock_pubsub_connection.credential_provider = credential_provider |
522 | 529 | mock_pubsub_connection.retry = Retry(NoBackoff(), 3) |
523 | | - mock_another_connection = Mock(spec=Connection) |
| 530 | + mock_another_connection = AsyncMock(spec=Connection) |
524 | 531 | mock_another_connection.retry = Retry(NoBackoff(), 3) |
525 | 532 |
|
526 | | - mock_pool = Mock(spec=ConnectionPool) |
| 533 | + mock_pool = AsyncMock(spec=ConnectionPool) |
527 | 534 | mock_pool.connection_kwargs = { |
528 | 535 | "credential_provider": credential_provider, |
529 | 536 | } |
530 | | - mock_pool.get_connection.side_effect = [ |
531 | | - mock_pubsub_connection, |
532 | | - mock_another_connection, |
533 | | - ] |
| 537 | + |
| 538 | + async def get_connection_side_effect(): |
| 539 | + if not hasattr(get_connection_side_effect, "call_count"): |
| 540 | + get_connection_side_effect.call_count = 0 |
| 541 | + result = [mock_pubsub_connection, mock_another_connection][ |
| 542 | + get_connection_side_effect.call_count |
| 543 | + ] |
| 544 | + get_connection_side_effect.call_count += 1 |
| 545 | + return result |
| 546 | + |
| 547 | + mock_pool.get_connection = AsyncMock(side_effect=get_connection_side_effect) |
534 | 548 | mock_pool._available_connections = [mock_another_connection] |
535 | 549 | mock_pool._lock = AsyncLock() |
536 | 550 | auth_token = None |
@@ -583,10 +597,10 @@ async def test_fails_on_token_renewal(self, credential_provider): |
583 | 597 | RequestTokenErr, |
584 | 598 | RequestTokenErr, |
585 | 599 | ] |
586 | | - mock_connection = Mock(spec=Connection) |
| 600 | + mock_connection = AsyncMock(spec=Connection) |
587 | 601 | mock_connection.retry = Retry(NoBackoff(), 0) |
588 | | - mock_another_connection = Mock(spec=Connection) |
589 | | - mock_pool = Mock(spec=ConnectionPool) |
| 602 | + mock_another_connection = AsyncMock(spec=Connection) |
| 603 | + mock_pool = AsyncMock(spec=ConnectionPool) |
590 | 604 | mock_pool.connection_kwargs = { |
591 | 605 | "credential_provider": credential_provider, |
592 | 606 | } |
|
0 commit comments