diff --git a/litestar/_asgi/asgi_router.py b/litestar/_asgi/asgi_router.py index ebafaf049f..e90011d192 100644 --- a/litestar/_asgi/asgi_router.py +++ b/litestar/_asgi/asgi_router.py @@ -158,27 +158,28 @@ async def lifespan(self, receive: LifeSpanReceive, send: LifeSpanSend) -> None: Returns: None. """ - - message = await receive() shutdown_event: LifeSpanShutdownCompleteEvent = {"type": "lifespan.shutdown.complete"} startup_event: LifeSpanStartupCompleteEvent = {"type": "lifespan.startup.complete"} + await receive() + + started = False try: async with self.app.lifespan(): await send(startup_event) - message = await receive() + started = True + await receive() except BaseException as e: formatted_exception = format_exc() failure_message: LifeSpanStartupFailedEvent | LifeSpanShutdownFailedEvent - if message["type"] == "lifespan.startup": - failure_message = {"type": "lifespan.startup.failed", "message": formatted_exception} - else: + if started: failure_message = {"type": "lifespan.shutdown.failed", "message": formatted_exception} + else: + failure_message = {"type": "lifespan.startup.failed", "message": formatted_exception} await send(failure_message) - raise e await send(shutdown_event) diff --git a/tests/unit/test_asgi_router.py b/tests/unit/test_asgi_router.py index bc9b084a23..b913823cbb 100644 --- a/tests/unit/test_asgi_router.py +++ b/tests/unit/test_asgi_router.py @@ -2,8 +2,9 @@ from contextlib import asynccontextmanager from typing import TYPE_CHECKING, AsyncGenerator, Callable -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock, MagicMock, call +import anyio import pytest from pytest_mock import MockerFixture @@ -194,3 +195,33 @@ async def on_shutdown() -> None: assert send.call_count == 2 assert send.call_args_list[1][0][0] == {"type": "lifespan.shutdown.failed", "message": mock_format_exc.return_value} + + +async def test_lifespan_context_exception_after_startup(mock_format_exc: MagicMock) -> None: + receive = AsyncMock() + receive.return_value = {"type": "lifespan.startup"} + send = AsyncMock() + mock_format_exc.return_value = "foo" + + async def sleep_and_raise() -> None: + await anyio.sleep(0) + raise RuntimeError("An error occurred") + + @asynccontextmanager + async def lifespan(_: Litestar) -> AsyncGenerator[None, None]: + async with anyio.create_task_group() as tg: + tg.start_soon(sleep_and_raise) + yield + + router = ASGIRouter(app=Litestar(lifespan=[lifespan])) + + with pytest.raises(_ExceptionGroup): + await router.lifespan(receive, send) + + assert receive.call_count == 2 + send.assert_has_calls( + [ + call({"type": "lifespan.startup.complete"}), + call({"type": "lifespan.shutdown.failed", "message": mock_format_exc.return_value}), + ] + )