proper use of contextvars with asyncio.create_server

The asyncio loop.create_server method listens on the specified address and port and calls the methods of the protocol returned by the supplied factory callable, which implements BaseProtocol. Everything works fine with simple examples like this one:

import asyncio


# Define a protocol class that inherits from asyncio.Protocol
class EchoProtocol(asyncio.Protocol):
    # This method is called when a new client connection is established
    def connection_made(self, transport):
        # Save a reference to the transport object
        self.transport = transport
        # Get the peer name of the client
        peername = transport.get_extra_info("peername")
        # Print a message
        print(f"Connection from {peername}")

    # This method is called when data is received from the client
    def data_received(self, data):
        # Decode the data from bytes to string
        message = data.decode()
        # Print a message
        print(f"Data received: {message}")
        # Send back the same data to the client
        self.transport.write(data)
        # Print a message
        print(f"Data sent: {message}")

    # This method is called when the client connection is closed
    def connection_lost(self, exc):
        # Print a message
        print("Connection closed")
        # Close the transport
        self.transport.close()


# Define an asynchronous function that creates and runs the server
async def main():
    # Get the current event loop
    loop = asyncio.get_running_loop()
    # Create a TCP server using the loop and the protocol class
    server = await loop.create_server(EchoProtocol, "127.0.0.1", 8888)
    # Get the server address and port
    addr = server.sockets[0].getsockname()
    # Print a message
    print(f"Serving on {addr}")
    # Run the server until it is stopped
    async with server:
        await server.serve_forever()


# Run the main function using asyncio.run()
asyncio.run(main())

The processing in this example is very simple, and one doesn't need to keep track of any state between invocations of the data_received method.

However, in the real app I'm trying to write, I do need to keep the state of the dialog with the connected client. I double-checked and found that the same asyncio.Protocol object is used for every newly connecting client, so I can't put my state there.

What is the right way of keeping the data per connected client? Say, I want to output the sequential number with every reply, counting them in the example above. But I want to count for every connected client separately, from 0 onwards without any misses.

I tried to use the contextvars module, but apparently I don't understand how it connects with asyncio, create_server, etc.

I can create a contextvars.ContextVar and assign it to some global variable, but how is this different from using the global variable directly?

I searched high and low for the concept of keeping the state of the server session in asyncio but could find nothing that I could understand and use.

I tried to implement something similar to this question:

import asyncio
import contextvars
test = contextvars.ContextVar('test')


# Define a protocol class that inherits from asyncio.Protocol
class EchoProtocol(asyncio.Protocol):
    count = contextvars.ContextVar('count')
    # This method is called when a new client connection is established
    def connection_made(self, transport):
        # Save a reference to the transport object
        self.transport = transport
        # Get the peer name of the client
        peername = transport.get_extra_info("peername")
        # Print a message
        self.count.set(0)
        print(f"Connection from {peername}")
        print(self)

    # This method is called when data is received from the client
    def data_received(self, data):
        # Decode the data from bytes to string
        message = data.decode()
        # Print a message
        print(f"Data received: {message}")
        # Send back the same data to the client
        self.transport.write(data)
        # Print a message
        print(f"Data sent:{self.count.get()}: {message}")
        self.count.set(self.count.get()+1)

    # This method is called when the client connection is closed
    def connection_lost(self, exc):
        # Print a message
        print("Connection closed")
        # Close the transport
        self.transport.close()


# Define an asynchronous function that creates and runs the server
async def main():
    # Get the current event loop
    loop = asyncio.get_running_loop()
    # Create a TCP server using the loop and the protocol class
    server = await loop.create_server(EchoProtocol, "127.0.0.1", 8888)
    # Get the server address and port
    addr = server.sockets[0].getsockname()

    test.set("this is a check of context var propagation")
    # Print a message
    print(f"Serving on {addr}")
    # Run the server until it is stopped
    async with server:
        await server.serve_forever()


# Run the main function using asyncio.run()
asyncio.run(main(), debug=True)

However, this snippet doesn't work:

Fatal error: protocol.data_received() call failed.
handle_traceback: Handle created at (most recent call last):
  File "/home/anton/.config/JetBrains/PyCharmCE2023.2/scratches/scratch.py", line 61, in <module>
    asyncio.run(main(), debug=True)
  File "/usr/lib/python3.12/asyncio/runners.py", line 194, in run
    return runner.run(main)
  File "/usr/lib/python3.12/asyncio/runners.py", line 118, in run
    return self._loop.run_until_complete(task)
  File "/usr/lib/python3.12/asyncio/base_events.py", line 651, in run_until_complete
    self.run_forever()
  File "/usr/lib/python3.12/asyncio/base_events.py", line 618, in run_forever
    self._run_once()
  File "/usr/lib/python3.12/asyncio/base_events.py", line 1943, in _run_once
    handle._run()
  File "/usr/lib/python3.12/asyncio/events.py", line 84, in _run
    self._context.run(self._callback, *self._args)
  File "/usr/lib/python3.12/asyncio/selector_events.py", line 924, in _add_reader
    self._loop._add_reader(fd, callback, *args)
  File "/usr/lib/python3.12/asyncio/selector_events.py", line 276, in _add_reader
    handle = events.Handle(callback, args, self, None)
protocol: <__main__.EchoProtocol object at 0x7fe46a0a76b0>
transport: <_SelectorSocketTransport fd=7 read=polling write=<idle, bufsize=0>>
Traceback (most recent call last):
  File "/usr/lib/python3.12/asyncio/selector_events.py", line 1023, in _read_ready__data_received
    self._protocol.data_received(data)
  File "/home/anton/.config/JetBrains/PyCharmCE2023.2/scratches/scratch.py", line 32, in data_received
    print(f"Data sent:{self.count.get()}: {message}")
                       ^^^^^^^^^^^^^^^^
LookupError: <ContextVar name='count' at 0x7fe46a3993a0>
There is 1 best solution below

Simon Kocurek On

"I can create a contextvars.ContextVar and assign it to some global variable, but how is this different from using the global variable directly?"

ContextVar is a low-level construct.

It is much like threading.local(), which creates a seemingly global variable whose contents differ depending on which thread accesses it.

Context variables are similar in that regard, except that instead of holding different contents per thread, they hold different contents per asynchronous context (and per executing thread).
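
To make the "different contents per context" point concrete, here is a minimal sketch using plain contextvars.copy_context(), with no asyncio involved. The names count, set_count and read_count are made up for illustration. The traceback in the question shows the event loop running each callback handle through self._context.run(...), so it seems plausible (an assumption, not something verified against the asyncio source here) that connection_made and data_received ran in separate context copies, which would explain the LookupError.

```python
import contextvars

# Hypothetical variable, mirroring the `count` attribute from the question.
count = contextvars.ContextVar("count")


def set_count():
    # Sets the value only in whichever context this function runs in.
    count.set(0)


def read_count():
    # Returns the value visible in the current context, or a marker if unset.
    try:
        return count.get()
    except LookupError:
        return "unset"


# Two independent copies of the current context.
ctx_a = contextvars.copy_context()
ctx_b = contextvars.copy_context()

ctx_a.run(set_count)
seen_in_a = ctx_a.run(read_count)  # the value set inside ctx_a
seen_in_b = ctx_b.run(read_count)  # ctx_b never saw the set() call

print(seen_in_a, seen_in_b)
```

A value set inside one context copy stays in that copy, which is why a ContextVar does not behave like an ordinary global here.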

So, to answer your question: while it looks like a global variable, each asynchronous context has its own content there. For this reason, it is ideal for storing things like a request_id.
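
A small sketch of the request_id idea, assuming each request is handled in its own asyncio task (a task copies the current context when it is created). The names handle and demo are hypothetical and only serve the illustration.

```python
import asyncio
import contextvars

# Looks like a global, but each task sees its own value.
request_id = contextvars.ContextVar("request_id", default=None)


async def handle(name, results):
    # Set the value inside this task's own context copy.
    request_id.set(name)
    # Yield to the event loop so the two tasks interleave.
    await asyncio.sleep(0)
    # Still our own value, even though the other task set a different one.
    results[name] = request_id.get()


async def demo():
    results = {}
    # gather() wraps each coroutine in a task with its own context copy.
    await asyncio.gather(handle("a", results), handle("b", results))
    # The parent context was never modified, so the default is returned.
    return results, request_id.get()


per_task, outer = asyncio.run(demo())
print(per_task, outer)
```

With an ordinary global variable, the second task's assignment would have overwritten the first one before it was read back.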

This is also explained in the official contextvars documentation.

If you are familiar with Node.js, context variables are basically Python's counterpart to AsyncLocalStorage.
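
For the per-client counter itself, a ContextVar may not be needed at all. As far as I understand the asyncio documentation, the first argument of loop.create_server is a protocol factory that is called once per accepted connection, so each client should get its own protocol instance, and plain instance attributes can hold the state. Below is a minimal sketch under that assumption; CountingEchoProtocol, ask and demo are hypothetical names, and port 0 simply asks the OS for a free port.

```python
import asyncio


class CountingEchoProtocol(asyncio.Protocol):
    def connection_made(self, transport):
        self.transport = transport
        # Per-connection state lives on this instance.
        self.count = 0

    def data_received(self, data):
        # Prefix the echoed data with this connection's sequence number.
        self.transport.write(f"{self.count}:".encode() + data)
        self.count += 1


async def ask(port, messages):
    # Test client: send newline-terminated messages, collect one reply each.
    reader, writer = await asyncio.open_connection("127.0.0.1", port)
    replies = []
    for msg in messages:
        writer.write(msg)
        await writer.drain()
        replies.append(await reader.readline())
    writer.close()
    await writer.wait_closed()
    return replies


async def demo():
    loop = asyncio.get_running_loop()
    # The class itself serves as the factory: one instance per connection.
    server = await loop.create_server(CountingEchoProtocol, "127.0.0.1", 0)
    port = server.sockets[0].getsockname()[1]
    async with server:
        # Two clients talk to the server concurrently.
        return await asyncio.gather(
            ask(port, [b"a\n", b"b\n"]),
            ask(port, [b"x\n"]),
        )


first, second = asyncio.run(demo())
print(first, second)
```

Each client's numbering starts at 0 and increases independently of the other client. Printing id(self) in connection_made is a quick way to check whether your own server really creates a separate instance per connection.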