Why does step-wise async iteration in separate tasks behave differently than without tasks?

I just isolated a strange behavior I've observed, but now I'm out of ideas.

My script asynchronously iterates over the stdout stream of a process spawned via SSH (asyncssh) in a somewhat complex way. Sometimes that stream gets truncated, i.e. I get a StopAsyncIteration when I expect more output to come.

Now I've found the culprit: the following lines read from the stream line by line in separate tasks (in order to be able to implement a timeout; I removed that code here, but a sketch of it follows the two variants below):

async def add_next(generator, collector) -> bool:
    with suppress(StopAsyncIteration):
        collector.append(await anext(generator))
        return True
    return False

async def thrtl(generator):
    bucket = []
    while True:
        if not await asyncio.create_task(add_next(generator, bucket)):
            break
    yield bucket

(please note this code does not make much sense - it just reproduces my observation)

The code above truncates the generator's output because awaiting the next element takes place in separate tasks. If I just await add_next directly, everything works as expected:

async def thrtl(generator):
    bucket = []
    while True:
        if not await add_next(generator, bucket):
            break
    yield bucket
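
For context, the reason for reading in a separate task at all is that it makes a per-line timeout possible, e.g. via asyncio.wait_for. A wrapper along these lines (an illustrative sketch, not the exact code I removed) would look like:

async def add_next_with_timeout(generator, collector, timeout=5.0) -> bool:
    try:
        # wait_for runs add_next in its own task and cancels it on timeout -
        # the same task-per-anext pattern as in the "bad" variant above
        return await asyncio.wait_for(add_next(generator, collector), timeout)
    except asyncio.TimeoutError:
        return True  # nothing arrived in time; keep iterating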

This problem does not occur with every async iterable; streams created with create_subprocess_exec, for example, work fine, and so does a plain async generator (see the sketch below). If I connect to a remote machine over a slow connection I even get more lines, and the log shows

INFO:asyncssh:[conn=0, chan=1] Received exit status 0
INFO:asyncssh:[conn=0, chan=1] Received channel close
INFO:asyncssh:[conn=0, chan=1] Channel closed

directly before I get the StopAsyncIteration. So reading from the stream after the process has terminated seems to be part of the problem.
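
For contrast, the same task-per-anext pattern loses nothing with a plain async generator. A minimal, self-contained sketch (the numbers generator is made up for illustration):

import asyncio
from contextlib import suppress

async def numbers():
    for i in range(3):
        yield i

async def add_next(generator, collector) -> bool:
    with suppress(StopAsyncIteration):
        collector.append(await anext(generator))
        return True
    return False

async def main():
    gen = numbers()
    bucket = []
    # awaiting anext() from a fresh task on every iteration works fine here
    while await asyncio.create_task(add_next(gen, bucket)):
        pass
    print(bucket)  # [0, 1, 2] - nothing is lost

asyncio.run(main())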

I could accept this and try to read from the stream differently, but why does it work without the tasks being created?

Here is a full script for reference; note that the effect is easier to observe with a remote connection (rather than connecting to 'localhost'):

import asyncio, asyncssh, logging
from contextlib import suppress

async def add_next(generator, collector) -> bool:
    with suppress(StopAsyncIteration):
        collector.append(await anext(generator))
        return True
    return False

async def thrtl(generator):
    bucket = []
    while True:
        # good:
        #if not await add_next(generator, bucket):
        # bad:
        if not await asyncio.create_task(add_next(generator, bucket)):
            break
    yield bucket

async def streamhandler(stream):
    result = []
    async for rawlines in thrtl(aiter(stream)):
        result += rawlines
    return result
    
async def main():
    ssh_connection = await asyncssh.connect("localhost")

    while True:
        ssh_process = await ssh_connection.create_process("df -P")
        stdout, stderr, completed = await asyncio.gather(
            streamhandler(ssh_process.stdout),
            streamhandler(ssh_process.stderr),
            asyncio.ensure_future(ssh_process.wait()),
        )
        print(stdout)
        print(stderr)
        await asyncio.sleep(2)

logging.basicConfig(level=logging.DEBUG)
asyncio.run(main())

Answer:

As jsbueno pointed out, this behavior can be avoided by having only one task (async context) actually read from the generator and using an asyncio.Queue to hand the read items over to other tasks/contexts. One caveat of this approach is exception handling, especially StopAsyncIteration, since reading from the Queue alone will never tell you that the generator is already exhausted. Still hoping there is a more straightforward approach; here is how I did it:

async def thrtl(generator):
    async def iterate(generator, queue) -> None:
        # the only place that touches the generator
        while True:
            try:
                queue.put_nowait(await anext(generator))
            except Exception as exc:
                # smuggle the exception (incl. StopAsyncIteration) to the consumer
                queue.put_nowait(exc)
                break

    async def add_next(queue, collector) -> None:
        elem = await queue.get()
        if isinstance(elem, Exception):
            raise elem
        collector.append(elem)

    queue = asyncio.Queue()
    reader = asyncio.create_task(iterate(generator, queue))

    bucket = []
    while True:
        try:
            await asyncio.create_task(add_next(queue, bucket))
        except StopAsyncIteration:
            break

    yield bucket
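
For illustration, here is how this simplified version behaves with a plain async generator (the letters generator is made up; thrtl is the function above):

import asyncio

async def letters():
    for ch in "abc":
        yield ch

async def main():
    # everything ends up in a single bucket, yielded once the generator is exhausted
    async for chunk in thrtl(letters()):
        print(chunk)  # ['a', 'b', 'c']

asyncio.run(main())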

The full (more meaningful but also more verbose) function with type hints looks like this now:

import asyncio
from contextlib import suppress
from typing import AsyncIterator, MutableSequence, Sequence, TypeVar

T = TypeVar("T")

async def collect_chunks(
    generator: AsyncIterator[T],
    *,
    postpone: bool = False,
    min_interval: float = 2,
    bucket_size: int = 0,
) -> AsyncIterator[Sequence[T]]:
    """Collect elements read from @generator and wait for a given condition before yielding them
    in chunks.
    Condition is met only after @min_interval seconds have passed since
    [1] first element received since last bucket if @postpone is set to False or
    [2] since last received element if @postpone is set to True.
    If @bucket_size > 0 the chunk will be returned immediately regardless of @postpone if the
    number of collected elements has reached @bucket_size.
    """

    async def iterate(generator: AsyncIterator[T], queue: asyncio.Queue[T | Exception]) -> None:
        """Writes elements read from @generator to @queue in order to not access @generator
        from more than one context
        see https://stackoverflow.com/questions/77245398"""
        while True:
            try:
                queue.put_nowait(await anext(generator))
            except Exception as exc:  # pylint: disable=broad-except
                queue.put_nowait(exc)
                break

    async def add_next(queue: asyncio.Queue[T | Exception], collector: MutableSequence[T]) -> None:
        """Reads one element from @queue and puts it into @collector. Together with `iterate`
        this gives us an awaitable read-only-one-element-with-timeout semantic"""
        elem = await queue.get()
        if isinstance(elem, Exception):
            raise elem
        collector.append(elem)

    event_tunnel: asyncio.Queue[T | Exception] = asyncio.Queue()
    collected_events: MutableSequence[T] = []
    fuse_task = None
    tasks = {
        asyncio.create_task(add_next(event_tunnel, collected_events), name="add_next"),
        asyncio.create_task(iterate(generator, event_tunnel), name="iterate"),
    }

    with suppress(asyncio.CancelledError):
        while True:
            finished, tasks = await asyncio.wait(fs=tasks, return_when=asyncio.FIRST_COMPLETED)

            for finished_task in finished:
                if (event_name := finished_task.get_name()) == "add_next":
                    # in case we're postponing we 'reset' the timeout fuse by removing it
                    if postpone and fuse_task:
                        tasks.remove(fuse_task)
                        fuse_task.cancel()
                        with suppress(asyncio.CancelledError):
                            await fuse_task
                        del fuse_task
                        fuse_task = None

                    if (exception := finished_task.exception()) or (
                        bucket_size and len(collected_events) >= bucket_size
                    ):
                        if collected_events:
                            yield collected_events
                            collected_events.clear()
                        if exception:
                            if isinstance(exception, StopAsyncIteration):
                                return
                            raise exception

                    tasks.add(
                        asyncio.create_task(
                            add_next(event_tunnel, collected_events), name="add_next"
                        )
                    )
                elif event_name == "fuse":
                    if collected_events:
                        yield collected_events
                        collected_events.clear()
                    del fuse_task
                    fuse_task = None
                else:
                    assert event_name == "iterate"

            # we've had a new event - start the timeout fuse
            if not fuse_task and min_interval > 0:
                tasks.add(
                    fuse_task := asyncio.create_task(asyncio.sleep(min_interval), name="fuse")
                )

Still, I haven't read anywhere that an async generator must not be accessed from more than one task/context. And taking into account how exceptions have to be smuggled through the queue, I consider this behavior a (design) bug.

hth