How to force an Async Context Manager to Exit


I've been getting into Structured Concurrency recently and this is a pattern that keeps cropping up:

It's nice to use async context managers to access a resource, say a websocket. That's all great if the websocket stays open, but what if it closes? Then we expect our context to be forcibly exited, normally through an exception.

How can I write a context manager that exhibits this behaviour? How can I throw an exception 'into' the calling code's open context? How can I forcibly exit a context?

Here's a simple setup, just for argument's sake:

import asyncio

class ConnectionLostException(Exception):
    pass

# Let's pretend I'm implementing this:
class SomeServiceContextManager:
    def __init__(self, service):
        self.service = service

    async def __aenter__(self):
        await self.service.connect(self.connection_state_callback)
        return self.service

    async def __aexit__(self, exc_type, exc, tb):
        self.service.disconnect()
        return False

    def connection_state_callback(self, state):
        if state == "connection lost":
            print("WHAT DO I DO HERE? How do I inform my consumer and force the exit of their context manager?")

class Consumer:
    def __init__(self, service):
        self.service = service

    async def send_stuff(self):
        try:
            async with SomeServiceContextManager(self.service) as connected_service:
                while True:
                    await asyncio.sleep(1)
                    connected_service.send("hello")
        except ConnectionLostException:  # << how do I raise this from the context manager?
            print("Oh no my connection was lost!!")

How is this generally handled? I've run into it a couple of times when writing context managers!

Here's a slightly more interesting example (hopefully) that shows how things get a bit messy. Say you are receiving through an async loop but want to close your connection if something downstream disconnects:

# SomeServiceContextManager is the same as in the first example.

class Consumer:
    async def translate_stuff(self):
        async with SomeOtherServiceContextManager(self.otherservice) as connected_other_service:
            try:
                async with SomeServiceContextManager(self.service) as connected_service:
                    async for message in connected_other_service.messages():
                        connected_service.send("message received: " + message.text)
            except ConnectionLostException: #<< how do I implement this from the ContextManager?
                print("Oh no my connection was lost - I'll also drop out of the other service connection!!")

Answer 1 (best answer):

Before we get started, let's replace the manual __aenter__() and __aexit__() implementations with contextlib.asynccontextmanager. It takes care of handling exceptions properly and is especially useful when you have nested context managers, which we're going to have in this answer. Here's your snippet rewritten that way:

from contextlib import asynccontextmanager

class SomeServiceConnection:
    def __init__(self, service):
        self.service = service

    async def _connect(self, connection_state_callback):
        await self.service.connect(connection_state_callback)

    async def _disconnect(self):
        self.service.disconnect()

@asynccontextmanager
async def get_service_connection(service):
    connection = SomeServiceConnection(service)
    await connection._connect(
        ...  # to be done
    )
    try:
        yield connection
    finally:
        await connection._disconnect()
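As a quick, self-contained check of the claim that asynccontextmanager handles exceptions for you, here is a minimal sketch. FakeService is a hypothetical stand-in for the real service that just records calls. It shows that the finally clause runs the disconnect even when the body raises, and that the exception still reaches the caller afterwards.

```python
import asyncio
from contextlib import asynccontextmanager


class FakeService:
    """Hypothetical stand-in for the question's service; records every call."""

    def __init__(self):
        self.log = []

    async def connect(self, callback):
        self.log.append("connect")

    def disconnect(self):
        self.log.append("disconnect")


@asynccontextmanager
async def get_service_connection(service):
    await service.connect(lambda state: None)  # callback unused in this sketch
    try:
        yield service
    finally:
        service.disconnect()  # runs on normal exit *and* when the body raises


async def main():
    service = FakeService()
    try:
        async with get_service_connection(service):
            service.log.append("body")
            raise ValueError("boom")
    except ValueError:
        service.log.append("caught ValueError")
    return service.log


print(asyncio.run(main()))
# prints: ['connect', 'body', 'disconnect', 'caught ValueError']
```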

OK, with that out of the way: The core of the answer here is that, if you want to stop running tasks in response to some event, then use a cancel scope.

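import trio

@asynccontextmanager
async def get_service_connection(service):
    connection = SomeServiceConnection(service)
    with trio.CancelScope() as cancel_scope:
        def on_state_change(state):
            # The service reports every state change; only a lost connection cancels.
            if state == "connection lost":
                cancel_scope.cancel()

        await connection._connect(on_state_change)
        try:
            yield connection
        finally:
            await connection._disconnect()
            if cancel_scope.cancel_called:
                raise RuntimeError("connection lost")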

But wait... what if some other exception (or exceptions!) was raised at roughly the same time that the connection was closed? It would be replaced by the exception you raise; Python only keeps it as the new exception's __context__. This is handily dealt with by using a nursery instead. A nursery has its own cancel scope doing the cancellation work, but it also takes care of creating ExceptionGroup objects (formerly known as MultiErrors). Now a task in the nursery just needs to raise an exception when the connection drops. As a bonus, there is a good chance you needed to run a background task to make the callback happen anyway. (If not, e.g. your callback is called from another thread via a Trio token, then use a trio.Event as another answer suggested, and await it from within the nursery.)

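import trio

async def check_connection(connection):
    # Assumes a wait_disconnected() method that returns once the connection
    # has dropped (it is not part of SomeServiceConnection above).
    await connection.wait_disconnected()
    raise RuntimeError("connection lost")

@asynccontextmanager
async def get_service_connection(service):
    connection = SomeServiceConnection(service)
    await connection._connect()  # assumes _connect() no longer needs a state callback
    try:
        async with trio.open_nursery() as nursery:
            nursery.start_soon(check_connection, connection)
            yield connection
            nursery.cancel_scope.cancel()  # normal exit: stop the watcher task
    finally:
        await connection._disconnect()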

Answer 2:

You could use an asyncio.Event to communicate the disconnect from SomeServiceContextManager to Consumer.

Then, instead of sleeping for 1 second in each iteration of the loop with asyncio.sleep, you wait for the disconnect event for up to 1 second; if it does not occur, you continue as before. This can be done, e.g., with asyncio.wait_for.

import asyncio

class ConnectionLostException(Exception):
    pass

class SomeServiceContextManager:
    def __init__(self, service):
        self.service = service
        self._disconnect = asyncio.Event()  # added

    async def __aenter__(self):
        # in case the context manager is reusable and the event was set before
        self._disconnect.clear()  # added
        await self.service.connect(self.connection_state_callback)
        return self  # changed: the consumer needs wait_for_disconnect()

    async def __aexit__(self, exc_type, exc, tb):
        self.service.disconnect()
        return False

    def connection_state_callback(self, state):
        if state == "connection lost":
            self._disconnect.set()  # new

    def send(self, message):  # added: forward sends to the wrapped service
        self.service.send(message)

    async def wait_for_disconnect(self, timeout):  # added
        try:
            await asyncio.wait_for(self._disconnect.wait(), timeout)
        except asyncio.TimeoutError:  # just TimeoutError in newer Python versions
            pass  # OK, no disconnect within timeout
        else:
            raise ConnectionLostException

class Consumer:
    def __init__(self, service):
        self.service = service

    async def send_stuff(self):
        try:
            async with SomeServiceContextManager(self.service) as connected_service:
                while True:
                    # changed
                    await connected_service.wait_for_disconnect(timeout=1.0)
                    connected_service.send("hello")
        except ConnectionLostException:
            print("Oh no my connection was lost!!")
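The core of this answer is the timed wait on the event. Here it is in isolation as a runnable sketch; loop.call_later() is a hypothetical stand-in for the service invoking connection_state_callback("connection lost"), and a counter stands in for the actual sends.

```python
import asyncio


class ConnectionLostException(Exception):
    pass


async def wait_for_disconnect(disconnected: asyncio.Event, timeout: float) -> None:
    """Wait up to `timeout` seconds, raising early if `disconnected` gets set."""
    try:
        await asyncio.wait_for(disconnected.wait(), timeout)
    except asyncio.TimeoutError:
        return  # no disconnect within the timeout: carry on as before
    raise ConnectionLostException("connection lost")


async def send_stuff(lose_after: float) -> int:
    disconnected = asyncio.Event()
    # Hypothetical stand-in for the service reporting "connection lost":
    asyncio.get_running_loop().call_later(lose_after, disconnected.set)
    sends = 0
    try:
        while True:
            await wait_for_disconnect(disconnected, timeout=0.02)
            sends += 1  # stands in for connected_service.send("hello")
    except ConnectionLostException:
        print(f"Oh no my connection was lost after {sends} sends!!")
    return sends


asyncio.run(send_stuff(lose_after=0.1))
```

With this design the consumer only notices the disconnect while it is inside wait_for_disconnect(), so the loop has to keep coming back to that call; work done between calls is not interrupted.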

Answer 3:

Why do you need to do this in the first place? Some sort of ConnectionLost exception will be raised by the current and/or next read or write call on the connection. (Alternatively, a read will return an empty result, which you should use to raise an EOFError.)

There is no need for state-change callbacks. IMHO they're a design mistake. In fact, most of asyncio's connection handling is a mess for historical reasons; you should consider using the anyio library instead, which incidentally raises an end-of-stream exception (anyio.EndOfStream) for you. An additional advantage: your code will also run on Trio.
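To make the first point concrete: with plain asyncio streams, a read at end of stream returns an empty result, which you can turn into an EOFError yourself. The sketch below feeds a StreamReader by hand to simulate a peer that sends two lines and then closes. In real code the reader would come from asyncio.open_connection(), and the asyncio docs discourage creating StreamReader directly, so treat that part as test scaffolding only; read_line_or_raise is a hypothetical helper name.

```python
import asyncio


async def read_line_or_raise(reader: asyncio.StreamReader) -> bytes:
    """Read one line; treat an empty reply (EOF) as a lost connection."""
    line = await reader.readline()
    if not line:  # b"" means the peer closed the connection
        raise EOFError("connection closed by peer")
    return line


async def consume(reader: asyncio.StreamReader) -> list:
    lines = []
    try:
        while True:
            lines.append(await read_line_or_raise(reader))
    except EOFError:
        pass  # connection is gone: leave the loop (and any enclosing contexts)
    return lines


async def demo():
    reader = asyncio.StreamReader()  # normally obtained from asyncio.open_connection()
    reader.feed_data(b"hello\nworld\n")
    reader.feed_eof()  # simulate the peer closing the connection
    return await consume(reader)


print(asyncio.run(demo()))  # prints: [b'hello\n', b'world\n']
```

No callback is involved: the lost connection shows up as an exception in the code that is already reading, which is exactly where an enclosing async with can react to it.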