Add task pipeline to asyncio with capped parallelism and lazy input reading

Motivation

I would like to propose adding a standard way to implement async data processing pipelines that are memory-friendly, allow the input data stream to be generated lazily or dynamically, and make it easy to cap parallelism.

I have noticed that some people struggle to implement this without creating a task per input item up front (which ultimately crashes the interpreter for large batches). Unfortunately, neither OpenAI o1 nor GitHub Copilot could generate an implementation I would consider good and robust either.

Using asyncio.as_completed or asyncio.gather requires creating all tasks beforehand: both take the complete collection of awaitables up front and schedule all of them immediately. This prevents generating work dynamically as new information becomes available (as in a web crawler), and it is very memory-intensive or impractical when the input data does not fit into memory. In some cases, simple custom implementations also lack proper backpressure support, which makes the application’s memory footprint grow gradually until the interpreter crashes.
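
To make the first point concrete, here is a minimal sketch of the common Semaphore-plus-gather pattern (the helper names `process` and `naive_map` are illustrative, not from any library). The semaphore does cap how many items run their body at once, but gather() has already created one task per item before any of them starts, as the recorded task count shows.

```python
import asyncio

async def process(item: int, sem: asyncio.Semaphore, peak: list[int]) -> int:
    # Record how many tasks exist at this moment (running or waiting).
    peak[0] = max(peak[0], len(asyncio.all_tasks()))
    async with sem:  # the semaphore caps how many items run the body at once...
        await asyncio.sleep(0)
        return item * 2

async def naive_map(items: list[int], limit: int) -> tuple[list[int], int]:
    sem = asyncio.Semaphore(limit)
    peak = [0]
    # ...but gather() wraps every coroutine in a Task immediately, so one Task
    # per input item exists at the same time, and `items` must be fully known.
    results = await asyncio.gather(*(process(i, sem, peak) for i in items))
    return results, peak[0]

results, peak_tasks = asyncio.run(naive_map(list(range(1000)), limit=4))
print(peak_tasks)  # more than 1000, despite limit=4
```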

Such a utility would be useful, for example, for communicating with rate-limited remote servers while processing a huge or dynamically created stream of input data, e.g. web crawling, generating LLM completions, etc.

I have seen this attempted in many Jupyter notebooks that evaluate LLMs themselves or downstream tasks built on them.

Proposed API

I would like to propose adding the following function to the asyncio module:

async def pipeline[T, R](
    iterable: AsyncIterable[T] | Iterable[T],
    func: Callable[[T], R | Awaitable[R]],
    *args, # passed to func
    parallelism: int = 4,
    **kwargs, # passed to func
) -> AsyncIterable[R]:
    ...

I have no strong feelings about the function’s name, maybe map_concurrent or map_parallel would be better.

To make usage as convenient as possible, the implementation should also handle backpressure: it should throttle how fast it reads its input so that its internal buffers do not consume too much memory. Additionally, the results should be reordered internally to match the input order, so that the calling code needs neither item identifiers nor its own reordering step.
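
The reordering requirement can be met with a small buffer keyed by input index, similar to the `early_results` dict and `next_to_yield` counter in the example implementation further down. A minimal sketch, where `ReorderBuffer` is a hypothetical class name used only for illustration:

```python
from typing import Generic, TypeVar

R = TypeVar("R")

class ReorderBuffer(Generic[R]):
    """Hold results that finish early and release them in input order."""

    def __init__(self) -> None:
        self._early: dict[int, R] = {}
        self._next = 0  # index of the next result the caller may receive

    def push(self, index: int, result: R) -> list[R]:
        """Store one finished result; return all results now releasable, in order."""
        self._early[index] = result
        ready: list[R] = []
        while self._next in self._early:
            ready.append(self._early.pop(self._next))
            self._next += 1
        return ready

    def __len__(self) -> int:
        # Results held back behind a slower, earlier item. A feeder can stop
        # reading input while this exceeds a limit: that is the backpressure.
        return len(self._early)

buf: ReorderBuffer[str] = ReorderBuffer()
print(buf.push(2, "c"))  # []
print(buf.push(0, "a"))  # ['a']
print(buf.push(1, "b"))  # ['b', 'c']
```

The buffer size doubles as the backpressure signal: when it grows, the item at the head of the order is slow, and reading more input would only enlarge the backlog.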

Usage demonstration

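import asyncio
import logging

import aiofiles  # third-party package

async def generate_completion(item: str) -> str | None:
    # Call some 3rd party API to generate output from item
    try:
        # Lines read from the file keep their trailing newline.
        return await some_api_client.do_stuff(item.rstrip("\n"))
    except Exception as e:
        # Catch Exception rather than BaseException so that
        # asyncio.CancelledError still propagates and the pipeline can be cancelled.
        logging.error(e)
        return None

async def main():
    async with aiofiles.open("input.txt", "r") as in_f, aiofiles.open("output.txt", "w") as out_f:
        async for result in asyncio.pipeline(in_f, generate_completion, parallelism=16):
            await out_f.write(f"{result}\n")

asyncio.run(main())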

Design considerations

The proposed API could later be extended by allowing a custom callback or object to be passed as the parallelism parameter, acting as a dynamic parallelism controller, but I did not want to overcomplicate things. Another open question is whether overriding the executor should be supported, and if so, per pipeline or per task.
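
For the executor question, one option that needs no extra parameter is for the caller to wrap a blocking function so that each call returns an awaitable, which the proposed `func: Callable[[T], R | Awaitable[R]]` contract already allows. A sketch, assuming a thread pool is acceptable for the blocking work (`in_executor` and `blocking_work` are hypothetical names); note that `loop.run_in_executor()` forwards positional arguments only:

```python
import asyncio
import threading
from collections.abc import Awaitable, Callable
from concurrent.futures import Executor, ThreadPoolExecutor
from typing import Any, TypeVar

R = TypeVar("R")

def in_executor(executor: Executor, func: Callable[..., R]) -> Callable[..., Awaitable[R]]:
    """Wrap a blocking func so that each call runs on `executor` and returns a Future."""
    def wrapper(item: Any, *args: Any) -> Awaitable[R]:
        loop = asyncio.get_running_loop()
        # run_in_executor() forwards positional arguments only.
        return loop.run_in_executor(executor, func, item, *args)
    return wrapper

def blocking_work(item: int) -> tuple[int, str]:
    # Stand-in for a blocking call (sync HTTP client, file I/O, ...).
    return item * 2, threading.current_thread().name

async def demo() -> list[tuple[int, str]]:
    with ThreadPoolExecutor(max_workers=2, thread_name_prefix="worker") as pool:
        func = in_executor(pool, blocking_work)
        # Any consumer that awaits returned awaitables (as the proposed pipeline
        # does for func) can use `func` as-is; gather() stands in for it here.
        return await asyncio.gather(*(func(i) for i in range(4)))

print(asyncio.run(demo()))  # values 0, 2, 4, 6 computed on "worker_*" threads
```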

I also recognize that result reordering might not always be necessary, and since it can lower overall throughput in some situations, it could be made configurable. However, I would strongly recommend enabling it by default.
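
To illustrate the throughput cost, here is a minimal sketch of a simpler strategy than the queue-based implementation further down: a sliding window of at most `limit` tasks, awaited strictly in input order (`ordered_window_map` is a hypothetical name, and it accepts only a synchronous iterable for brevity). A finished task queued behind a slow one keeps its slot until the slow one is yielded, so new work cannot start even though capacity is idle:

```python
import asyncio
from collections import deque
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
from typing import TypeVar

T = TypeVar("T")
R = TypeVar("R")

async def ordered_window_map(
    items: Iterable[T], func: Callable[[T], Awaitable[R]], limit: int
) -> AsyncIterator[R]:
    """Run at most `limit` calls at once and yield results in input order."""
    window: deque[asyncio.Future[R]] = deque()
    try:
        for item in items:  # input is read lazily, one item per free slot
            window.append(asyncio.ensure_future(func(item)))
            if len(window) >= limit:
                # Head-of-line blocking: tasks behind a slow head may already be
                # done, but they keep their slots until the head is yielded.
                yield await window.popleft()
        while window:
            yield await window.popleft()
    finally:
        for fut in window:  # early exit or error: cancel what is still queued
            fut.cancel()

async def slow_double(x: int) -> int:
    await asyncio.sleep(0.05 if x == 0 else 0.01)
    return x * 2

async def demo() -> list[int]:
    return [r async for r in ordered_window_map(range(6), slow_double, limit=3)]

print(asyncio.run(demo()))  # [0, 2, 4, 6, 8, 10]
```

Yielding results in completion order instead avoids this stall, at the cost of the caller having to track which input each result belongs to.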

I suggest that any error raised from func should cancel the pipeline and be re-raised from the returned AsyncIterable.

An alternative approach to this use case could be to allow passing an AsyncIterable[Awaitable[T]] to asyncio.gather and asyncio.as_completed, along with a parallelism cap.

Existing work

I am currently aware of the following libraries related to this:

Example implementation

Here I attach an example implementation, released into the public domain.

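It does not handle errors yet: as written, an exception raised by func ends its worker task without emitting a terminator, so the consumer waits forever. I would add error handling in a hypothetical PR should there be interest in this (adding an error | None item to the tuples in output_queue seems like the most straightforward way).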

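# SPDX-License-Identifier: CC-PDM-1.0
# Requires Python 3.12+ (PEP 695 type parameter syntax).

import asyncio
from collections.abc import AsyncIterable, Awaitable, Callable, Iterable
from typing import cast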

async def pipeline[T, R](
    iterable: AsyncIterable[T] | Iterable[T],
    func: Callable[[T], R | Awaitable[R]],
    *args,
    parallelism: int = 4,
    **kwargs,
) -> AsyncIterable[R]:
    """
    Asynchronously map a function over an (a)synchronous iterable with bounded
    parallelism.

    Allows concurrent processing of items from the given iterable, up to a
    specified parallelism level. Takes a function that can be either synchronous
    (returning R) or asynchronous (returning Awaitable[R]) and returns an async
    iterable producing results in input order as soon as they become available.
    Internal queues are bounded, preventing consumption of the entire iterable
    at once in memory.

    Args:
        iterable:
            The source of items to process. Can be a synchronous or asynchronous
            iterable.
        func:
            The mapping function to apply to each item. May be synchronous
            (returning R) or asynchronous (returning Awaitable[R]). All *args
            and **kwargs are forwarded to this function.
        parallelism (int):
            Maximum number of concurrent worker tasks. Defaults to 4.
        *args:
            Extra positional arguments passed on to `func`.
        **kwargs:
            Extra keyword arguments passed on to `func`.

    Yields:
        R: The result of applying `func` to each item, in the same order as
        their corresponding inputs.  

    Notes:
        - If the callback is synchronous, it is invoked directly on the event
          loop thread and blocks the loop while it runs, so consider wrapping
          it with asyncio.to_thread() if it does significant blocking work.
        - This implementation uses internal queues to avoid reading from
          `iterable` too far ahead, controlling memory usage.
        - Once an item finishes processing, its result is enqueued and will be
          yielded as soon as all previous results have also been yielded.
        - If the consumer of this async iterable stops consuming early, workers
          may block while attempting to enqueue subsequent results. In that
          case, close the async generator (e.g. with aclose()) to cancel the
          worker tasks and release their resources.
        - If the work for some items is very slow, intermediate results are
          accumulated in an internal buffer until those slow results become
          available, preventing out-of-order yielding.
    """

    input_terminator = cast(T, object())
    output_terminator = cast(R, object())
    input_queue = asyncio.Queue[tuple[int, T]](parallelism)
    output_queue = asyncio.Queue[tuple[int, R]](parallelism)
    feeding_stop = asyncio.Event()
    last_fed = -1
    next_to_yield = 0
    early_results: dict[int, R] = {}

    async def _worker() -> None:
        while True:
            index, item = await input_queue.get()
            if item is input_terminator:
                input_queue.task_done()
                break
            result = func(item, *args, **kwargs)
            if isinstance(result, Awaitable):
                result = cast(R, await result)
            await output_queue.put((index, result))
            input_queue.task_done()

        await output_queue.put((-1, output_terminator))

    def _as_async_iterable(
        iterable: AsyncIterable[T] | Iterable[T],
    ) -> AsyncIterable[T]:
        if isinstance(iterable, AsyncIterable):
            return iterable

        async def _sync_to_async_iterable() -> AsyncIterable[T]:
            for item in iterable:
                yield item

        return _sync_to_async_iterable()

    async def _feeder() -> None:
        nonlocal last_fed
        async for item in _as_async_iterable(iterable):
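            while len(early_results) >= parallelism:
                # There is an item that is taking very long to process. We need
                # to wait for it to finish to avoid blowing up memory. Clear the
                # event first so that a set() left over from an earlier yield
                # does not wake us immediately, and re-check after waking.
                feeding_stop.clear()
                await feeding_stop.wait()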

            last_fed += 1
            await input_queue.put((last_fed, item))

        for _ in range(parallelism):
            await input_queue.put((-1, input_terminator))

    async def _consumer() -> AsyncIterable[R]:
        nonlocal next_to_yield
        remaining_workers = parallelism
        while remaining_workers:
            index, result = await output_queue.get()
            if result is output_terminator:
                remaining_workers -= 1
                output_queue.task_done()
                continue

            early_results[index] = result
            while next_to_yield in early_results:
                # Wake the feeder only when results are actually released, so
                # that early_results shrinks before more input is read.
                feeding_stop.set()

                yield early_results.pop(next_to_yield)
                next_to_yield += 1
            output_queue.task_done()

    tasks = [
        asyncio.create_task(_worker()) for _ in range(parallelism)
    ] + [asyncio.create_task(_feeder())]

    try:
        async for result in _consumer():
            yield result
    finally:
        for task in tasks:
            task.cancel()

        await asyncio.gather(*tasks, return_exceptions=True)
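
For the error handling mentioned before the implementation, a minimal sketch of the `error | None` slot in the output tuples could look like this. It is deliberately simplified (the input is queued up front and results are collected into a list instead of streamed), and `worker`, `map_or_raise` and `fragile` are hypothetical names; the point is only that a worker forwards the exception instead of dying, and the consumer re-raises it and cancels the remaining workers:

```python
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any

async def worker(
    func: Callable[[Any], Awaitable[Any]],
    input_queue: asyncio.Queue,
    output_queue: asyncio.Queue,
) -> None:
    while True:
        index, item = await input_queue.get()
        try:
            result = await func(item)
        except Exception as exc:  # CancelledError is not an Exception, so it still propagates
            await output_queue.put((index, None, exc))
        else:
            await output_queue.put((index, result, None))

async def map_or_raise(
    items: list[Any], func: Callable[[Any], Awaitable[Any]], parallelism: int = 2
) -> list[Any]:
    # Simplified feeding: the whole (small) input is queued up front.
    input_queue: asyncio.Queue = asyncio.Queue()
    for pair in enumerate(items):
        input_queue.put_nowait(pair)
    output_queue: asyncio.Queue = asyncio.Queue(parallelism)
    workers = [
        asyncio.create_task(worker(func, input_queue, output_queue))
        for _ in range(parallelism)
    ]
    results: dict[int, Any] = {}
    try:
        while len(results) < len(items):
            index, result, error = await output_queue.get()
            if error is not None:
                raise error  # first failure ends the run; finally cancels workers
            results[index] = result
        return [results[i] for i in range(len(items))]
    finally:
        for task in workers:
            task.cancel()
        await asyncio.gather(*workers, return_exceptions=True)

async def fragile(x: int) -> int:
    await asyncio.sleep(0)
    if x == 3:
        raise ValueError(f"bad item {x}")
    return x * 10

print(asyncio.run(map_or_raise([0, 1, 2], fragile)))  # [0, 10, 20]
try:
    asyncio.run(map_or_raise(list(range(6)), fragile))
except ValueError as exc:
    print("pipeline failed:", exc)  # pipeline failed: bad item 3
```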