Source code for redbot.core.utils

import asyncio
import warnings
from asyncio import AbstractEventLoop, as_completed, Semaphore
from asyncio.futures import isfuture
from itertools import chain
from typing import (
    Any,
    AsyncIterator,
    AsyncIterable,
    Awaitable,
    Callable,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    TypeVar,
    Union,
    Set,
    TYPE_CHECKING,
)

# Names exported when this module is star-imported.
__all__ = ("bounded_gather", "deduplicate_iterables")

# Generic item type used in the annotations of the async iteration helpers below.
_T = TypeVar("_T")


# Benchmarked to be the fastest method.
def deduplicate_iterables(*iterables):
    """
    Merge every iterable in ``iterables`` into one list of unique items.

    Items keep the position at which they were first seen; later
    duplicates are dropped. Items must be hashable because they are used
    as dictionary keys.
    """
    flattened = chain.from_iterable(iterables)
    # Dict keys are unique and (3.6+) remember insertion order, so the keys
    # of this dict are exactly the de-duplicated items in first-seen order.
    first_seen = dict.fromkeys(flattened)
    return list(first_seen)
# https://github.com/PyCQA/pylint/issues/2717
class AsyncFilter(AsyncIterator[_T], Awaitable[List[_T]]):  # pylint: disable=duplicate-bases
    """Object handed back by `async_filter`; see that function for usage.

    Creating instances of this class directly is not recommended.
    """

    def __init__(
        self,
        func: Callable[[_T], Union[bool, Awaitable[bool]]],
        iterable: Union[AsyncIterable[_T], Iterable[_T]],
    ) -> None:
        self.__func = func
        self.__iterable = iterable

        # Pick the generator that matches the sync/async nature of the
        # predicate and of the iterable; at least one of them must be async.
        predicate_is_async = asyncio.iscoroutinefunction(func)
        if isinstance(iterable, AsyncIterable):
            if predicate_is_async:
                make_generator = self.__async_generator_async_pred
            else:
                make_generator = self.__async_generator_sync_pred
        elif predicate_is_async:
            make_generator = self.__sync_generator_async_pred
        else:
            raise TypeError("Must be either an async predicate, an async iterable, or both.")

        self.__generator_instance = make_generator()

    async def __sync_generator_async_pred(self) -> AsyncIterator[_T]:
        # Plain iterable, awaited predicate.
        for element in self.__iterable:
            keep = await self.__func(element)
            if keep:
                yield element

    async def __async_generator_sync_pred(self) -> AsyncIterator[_T]:
        # Async iterable, plain predicate.
        async for element in self.__iterable:
            keep = self.__func(element)
            if keep:
                yield element

    async def __async_generator_async_pred(self) -> AsyncIterator[_T]:
        # Async iterable, awaited predicate.
        async for element in self.__iterable:
            keep = await self.__func(element)
            if keep:
                yield element

    async def __flatten(self) -> List[_T]:
        # Drain this iterator into a list.
        collected = []
        async for element in self:
            collected.append(element)
        return collected

    def __aiter__(self):
        return self

    def __await__(self):
        # Awaiting the filter resolves to the list of every remaining item.
        return self.__flatten().__await__()

    def __anext__(self) -> Awaitable[_T]:
        # Delegate to whichever generator was selected in __init__.
        return self.__generator_instance.__anext__()


def async_filter(
    func: Callable[[_T], Union[bool, Awaitable[bool]]],
    iterable: Union[AsyncIterable[_T], Iterable[_T]],
) -> AsyncFilter[_T]:
    """Filter an (optionally async) iterable with an (optionally async) predicate.

    The predicate, the iterable, or both must be async.

    Parameters
    ----------
    func : Callable[[T], Union[bool, Awaitable[bool]]]
        A function or coroutine function that receives one item of
        ``iterable`` and returns ``True`` to keep it or ``False`` to drop it.
    iterable : Union[AsyncIterable[_T], Iterable[_T]]
        The iterable or async iterable to filter.

    Raises
    ------
    TypeError
        If neither of the arguments are async.

    Returns
    -------
    AsyncFilter[T]
        An object that can be awaited to obtain a list of the filtered
        items, or used as an async iterator to receive them one by one.

    """
    return AsyncFilter(func, iterable)


async def async_enumerate(
    async_iterable: AsyncIterable[_T], start: int = 0
) -> AsyncIterator[Tuple[int, _T]]:
    """Async iterable version of `enumerate`.

    Parameters
    ----------
    async_iterable : AsyncIterable[T]
        The async iterable whose items are numbered.
    start : int
        The index given to the first item. Defaults to 0.

    Returns
    -------
    AsyncIterator[Tuple[int, T]]
        An async iterator of tuples in the form of ``(index, item)``.

    """
    index = start
    async for element in async_iterable:
        yield index, element
        index += 1


async def _sem_wrapper(sem, task):
    # Await ``task`` while holding ``sem`` so concurrency stays bounded.
    async with sem:
        outcome = await task
        return outcome
def bounded_gather_iter(
    *coros_or_futures,
    loop: Optional[AbstractEventLoop] = None,
    limit: int = 4,
    semaphore: Optional[Semaphore] = None,
) -> Iterator[Awaitable[Any]]:
    """
    An iterator that returns tasks as they are ready, but limits the
    number of tasks running at a time.

    Must be called while an event loop is running.

    Parameters
    ----------
    *coros_or_futures
        The awaitables to run in a bounded concurrent fashion.
    loop : asyncio.AbstractEventLoop
        Deprecated and ignored. Passing a value only emits a
        `DeprecationWarning`; the running event loop is always used.
    limit : Optional[`int`]
        The maximum number of concurrent tasks. Used when no ``semaphore``
        is passed.
    semaphore : Optional[:class:`asyncio.Semaphore`]
        The semaphore to use for bounding tasks. If `None`, create one
        using ``limit``.

    Returns
    -------
    Iterator[Awaitable[Any]]
        The iterator produced by :func:`asyncio.as_completed` over the
        semaphore-bounded awaitables.

    Raises
    ------
    TypeError
        When ``semaphore`` is `None` and ``limit`` is not an int greater
        than 0.
    ValueError
        When one of the passed futures is tied to a different event loop.
    RuntimeError
        When there is no running event loop.

    """
    if loop is not None:
        warnings.warn(
            "Explicitly passing the loop will not work in Red 3.4+ and is currently ignored. "
            "Call this from the related event loop.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller, not this helper
        )

    loop = asyncio.get_running_loop()

    if semaphore is None:
        if not isinstance(limit, int) or limit <= 0:
            raise TypeError("limit must be an int > 0")

        semaphore = Semaphore(limit)

    # Validate every future before creating any wrapper coroutine, so that a
    # rejected call does not leave freshly created wrappers un-awaited.
    for cof in coros_or_futures:
        if isfuture(cof) and cof._loop is not loop:
            raise ValueError("futures are tied to different event loops")

    pending = [_sem_wrapper(semaphore, cof) for cof in coros_or_futures]

    return as_completed(pending)
def bounded_gather(
    *coros_or_futures,
    loop: Optional[AbstractEventLoop] = None,
    return_exceptions: bool = False,
    limit: int = 4,
    semaphore: Optional[Semaphore] = None,
) -> Awaitable[List[Any]]:
    """
    A semaphore-bounded wrapper to :meth:`asyncio.gather`.

    Must be called while an event loop is running.

    Parameters
    ----------
    *coros_or_futures
        The awaitables to run in a bounded concurrent fashion.
    loop : asyncio.AbstractEventLoop
        Deprecated and ignored. Passing a value only emits a
        `DeprecationWarning`; the running event loop is always used.
    return_exceptions : bool
        If true, gather exceptions in the result list instead of raising.
    limit : Optional[`int`]
        The maximum number of concurrent tasks. Used when no ``semaphore``
        is passed.
    semaphore : Optional[:class:`asyncio.Semaphore`]
        The semaphore to use for bounding tasks. If `None`, create one
        using ``limit``.

    Returns
    -------
    Awaitable[List[Any]]
        The awaitable returned by :meth:`asyncio.gather` over the
        semaphore-bounded awaitables.

    Raises
    ------
    TypeError
        When ``semaphore`` is `None` and ``limit`` is not an int greater
        than 0.
    RuntimeError
        When there is no running event loop.

    """
    if loop is not None:
        warnings.warn(
            "Explicitly passing the loop will not work in Red 3.4+ and is currently ignored. "
            "Call this from the related event loop.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller, not this helper
        )

    # Fail fast with RuntimeError when called outside a running event loop.
    asyncio.get_running_loop()

    if semaphore is None:
        if not isinstance(limit, int) or limit <= 0:
            raise TypeError("limit must be an int > 0")

        semaphore = Semaphore(limit)

    tasks = (_sem_wrapper(semaphore, task) for task in coros_or_futures)

    return asyncio.gather(*tasks, return_exceptions=return_exceptions)