diff --git a/CHANGELOG.md b/CHANGELOG.md index fdfdc6d896..5f8c23040a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -55,6 +55,9 @@ These changes are available on the `master` branch, but have not yet been releas ([#2714](https://github.com/Pycord-Development/pycord/pull/2714)) - Added the ability to pass a `datetime.time` object to `format_dt`. ([#2747](https://github.com/Pycord-Development/pycord/pull/2747)) +- Added the ability to pass an `overlap` parameter to the `loop` decorator and `Loop` + class, allowing concurrent iterations if enabled. + ([#2765](https://github.com/Pycord-Development/pycord/pull/2765)) - Added `discord.Interaction.created_at`. ([#2801](https://github.com/Pycord-Development/pycord/pull/2801)) diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index af34cc6844..5276427e7d 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -26,6 +26,7 @@ from __future__ import annotations import asyncio +import contextvars import datetime import inspect import sys @@ -46,6 +47,9 @@ LF = TypeVar("LF", bound=_func) FT = TypeVar("FT", bound=_func) ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]]) +_current_loop_ctx: contextvars.ContextVar[int] = contextvars.ContextVar( + "_current_loop_ctx", default=None +) class SleepHandle: @@ -59,10 +63,14 @@ def __init__( relative_delta = discord.utils.compute_timedelta(dt) self.handle = loop.call_later(relative_delta, future.set_result, True) + def _set_result_safe(self): + if not self.future.done(): + self.future.set_result(True) + def recalculate(self, dt: datetime.datetime) -> None: self.handle.cancel() relative_delta = discord.utils.compute_timedelta(dt) - self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) + self.handle = self.loop.call_later(relative_delta, self._set_result_safe) def wait(self) -> asyncio.Future[Any]: return self.future @@ -91,10 +99,12 @@ def __init__( count: int | None, reconnect: bool, loop: asyncio.AbstractEventLoop, + overlap: bool | int, ) -> None: self.coro: LF = coro self.reconnect: bool = reconnect self.loop: asyncio.AbstractEventLoop = loop + self.overlap: bool | int = overlap self.count: int | None = count self._current_loop = 0 self._handle: SleepHandle = MISSING @@ -115,6 +125,7 @@ def __init__( self._is_being_cancelled = False self._has_failed = False self._stop_next_iteration = False + self._tasks: list[asyncio.Task[Any]] = [] if self.count is not None and self.count <= 0: raise ValueError("count must be greater than 0 or None.") @@ -128,6 +139,12 @@ def __init__( raise TypeError( f"Expected coroutine function, not {type(self.coro).__name__!r}." ) + if not isinstance(overlap, (bool, int)): + raise TypeError("overlap must be a bool or a positive integer.") + elif not isinstance(overlap, bool) and isinstance(overlap, int): + if overlap <= 1: + raise ValueError("overlap as an integer must be greater than 1.") + self._semaphore = asyncio.Semaphore(overlap) async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None: coro = getattr(self, f"_{name}") @@ -166,7 +183,27 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None: self._last_iteration = self._next_iteration self._next_iteration = self._get_next_sleep_time() try: - await self.coro(*args, **kwargs) + token = _current_loop_ctx.set(self._current_loop) + if not self.overlap: + await self.coro(*args, **kwargs) + else: + + async def run_with_semaphore(): + async with self._semaphore: + await self.coro(*args, **kwargs) + + task = asyncio.create_task( + ( + self.coro(*args, **kwargs) + if self.overlap is True + else run_with_semaphore() + ), + name=f"pycord-loop-{self.coro.__name__}-{self._current_loop}", + ) + task.add_done_callback(self._tasks.remove) + self._tasks.append(task) + + _current_loop_ctx.reset(token) self._last_iteration_failed = False backoff = ExponentialBackoff() except self._valid_exception: @@ -192,6 +229,9 @@ async def _loop(self, *args: Any, **kwargs: Any) -> None: except asyncio.CancelledError: self._is_being_cancelled = True + for task in self._tasks: + task.cancel() + await asyncio.gather(*self._tasks, return_exceptions=True) raise except Exception as exc: self._has_failed = True @@ -218,6 +258,7 @@ def __get__(self, obj: T, objtype: type[T]) -> Loop[LF]: count=self.count, reconnect=self.reconnect, loop=self.loop, + overlap=self.overlap, ) copy._injected = obj copy._before_loop = self._before_loop @@ -269,7 +310,11 @@ def time(self) -> list[datetime.time] | None: @property def current_loop(self) -> int: """The current iteration of the loop.""" - return self._current_loop + return ( + _current_loop_ctx.get() + if _current_loop_ctx.get() is not None + else self._current_loop + ) @property def next_iteration(self) -> datetime.datetime | None: @@ -738,6 +783,7 @@ def loop( count: int | None = None, reconnect: bool = True, loop: asyncio.AbstractEventLoop = MISSING, + overlap: bool | int = False, ) -> Callable[[LF], Loop[LF]]: """A decorator that schedules a task in the background for you with optional reconnect logic. The decorator returns a :class:`Loop`. @@ -773,6 +819,11 @@ def loop( loop: :class:`asyncio.AbstractEventLoop` The loop to use to register the task, if not given defaults to :func:`asyncio.get_event_loop`. + overlap: Union[:class:`bool`, :class:`int`] + Controls whether overlapping executions of the task loop are allowed. + Set to False (default) to run iterations one at a time, True for unlimited overlap, or an int to cap the number of concurrent runs. + + .. versionadded:: 2.7 Raises ------ @@ -793,6 +844,7 @@ def decorator(func: LF) -> Loop[LF]: time=time, reconnect=reconnect, loop=loop, + overlap=overlap, ) return decorator