diff --git a/aws_embedded_metrics/metric_scope/__init__.py b/aws_embedded_metrics/metric_scope/__init__.py index 3389945..4d0ff34 100644 --- a/aws_embedded_metrics/metric_scope/__init__.py +++ b/aws_embedded_metrics/metric_scope/__init__.py @@ -29,15 +29,10 @@ async def async_gen_wrapper(*args, **kwargs): # type: ignore kwargs["metrics"] = logger try: - fn_gen = fn(*args, **kwargs) - while True: - result = await fn_gen.__anext__() - await logger.flush() + async for result in fn(*args, **kwargs): yield result - except Exception as ex: + finally: await logger.flush() - if not isinstance(ex, StopIteration): - raise return cast(F, async_gen_wrapper) @@ -49,15 +44,10 @@ def gen_wrapper(*args, **kwargs): # type: ignore kwargs["metrics"] = logger try: - fn_gen = fn(*args, **kwargs) - while True: - result = next(fn_gen) - asyncio.run(logger.flush()) + for result in fn(*args, **kwargs): yield result - except Exception as ex: + finally: asyncio.run(logger.flush()) - if not isinstance(ex, StopIteration): - raise return cast(F, gen_wrapper) diff --git a/tests/metric_scope/test_metric_scope.py b/tests/metric_scope/test_metric_scope.py index 9ebd1f1..31fe4d2 100644 --- a/tests/metric_scope/test_metric_scope.py +++ b/tests/metric_scope/test_metric_scope.py @@ -169,7 +169,38 @@ def my_handler(metrics): assert expected_timestamp_second == actual_timestamp_second -def test_sync_scope_iterates_generator(mock_logger): +@pytest.mark.asyncio +async def test_async_generator_completes_successfully(mock_logger): + expected_results = [1, 2, 3] + + @metric_scope + async def my_handler(): + for item in expected_results: + yield item + + actual_results = [] + async for result in my_handler(): + actual_results.append(result) + + assert actual_results == expected_results + assert InvocationTracker.invocations == 1 + + +def test_sync_generator_completes_successfully(mock_logger): + expected_results = [1, 2, 3] + + @metric_scope + def my_handler(): + yield from expected_results + + actual_results = [] + for result in my_handler(): + actual_results.append(result) + + assert actual_results == expected_results + assert InvocationTracker.invocations == 1 + +def test_sync_generator_handles_exception(mock_logger): expected_results = [1, 2] @metric_scope @@ -183,11 +214,11 @@ def my_handler(): actual_results.append(result) assert actual_results == expected_results - assert InvocationTracker.invocations == 3 + assert InvocationTracker.invocations == 1 @pytest.mark.asyncio -async def test_async_scope_iterates_async_generator(mock_logger): +async def test_async_generator_handles_exception(mock_logger): expected_results = [1, 2] @metric_scope @@ -203,7 +234,7 @@ async def my_handler(): actual_results.append(result) assert actual_results == expected_results - assert InvocationTracker.invocations == 3 + assert InvocationTracker.invocations == 1 # Test helpers