Skip to content

Commit

Permalink
fix: asgi lifespan msg after lifespan context exception (#3315)
Browse files Browse the repository at this point in the history
An exception raised within an asgi lifespan context manager would result in a "lifespan.startup.failed" message being sent after we've already sent a "lifespan.startup.complete" message. This would cause uvicorn to raise a `STATE_TRANSITION_ERROR` assertion error due to their [check for that condition][1].

This PR modifies `ASGIRouter.lifespan()` so that it sends a shutdown failure message if we've already confirmed startup. This is consistent with [starlette's behavior][2] under the same conditions.

[1]: https://github.com/encode/uvicorn/blob/a2219eb2ed2bbda4143a0fb18c4b0578881b1ae8/uvicorn/lifespan/on.py#L115-L117
[2]: https://github.com/encode/starlette/blob/4e453ce91940cc7c995e6c728e3fdf341c039056/starlette/routing.py#L744-L745
  • Loading branch information
peterschutt authored Apr 6, 2024
1 parent b5d9c6f commit fac641a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
15 changes: 8 additions & 7 deletions litestar/_asgi/asgi_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,27 +158,28 @@ async def lifespan(self, receive: LifeSpanReceive, send: LifeSpanSend) -> None:
Returns:
None.
"""

message = await receive()
shutdown_event: LifeSpanShutdownCompleteEvent = {"type": "lifespan.shutdown.complete"}
startup_event: LifeSpanStartupCompleteEvent = {"type": "lifespan.startup.complete"}

await receive()

started = False
try:
async with self.app.lifespan():
await send(startup_event)
message = await receive()
started = True
await receive()

except BaseException as e:
formatted_exception = format_exc()
failure_message: LifeSpanStartupFailedEvent | LifeSpanShutdownFailedEvent

if message["type"] == "lifespan.startup":
failure_message = {"type": "lifespan.startup.failed", "message": formatted_exception}
else:
if started:
failure_message = {"type": "lifespan.shutdown.failed", "message": formatted_exception}
else:
failure_message = {"type": "lifespan.startup.failed", "message": formatted_exception}

await send(failure_message)

raise e

await send(shutdown_event)
33 changes: 32 additions & 1 deletion tests/unit/test_asgi_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncGenerator, Callable
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, call

import anyio
import pytest
from pytest_mock import MockerFixture

Expand Down Expand Up @@ -194,3 +195,33 @@ async def on_shutdown() -> None:

assert send.call_count == 2
assert send.call_args_list[1][0][0] == {"type": "lifespan.shutdown.failed", "message": mock_format_exc.return_value}


async def test_lifespan_context_exception_after_startup(mock_format_exc: MagicMock) -> None:
receive = AsyncMock()
receive.return_value = {"type": "lifespan.startup"}
send = AsyncMock()
mock_format_exc.return_value = "foo"

async def sleep_and_raise() -> None:
await anyio.sleep(0)
raise RuntimeError("An error occurred")

@asynccontextmanager
async def lifespan(_: Litestar) -> AsyncGenerator[None, None]:
async with anyio.create_task_group() as tg:
tg.start_soon(sleep_and_raise)
yield

router = ASGIRouter(app=Litestar(lifespan=[lifespan]))

with pytest.raises(_ExceptionGroup):
await router.lifespan(receive, send)

assert receive.call_count == 2
send.assert_has_calls(
[
call({"type": "lifespan.startup.complete"}),
call({"type": "lifespan.shutdown.failed", "message": mock_format_exc.return_value}),
]
)

0 comments on commit fac641a

Please sign in to comment.