How do I use an asio::strand in a library that provides both blocking and asynchronous functions


I'm trying to learn asio by writing a library that connects to a specific device over a UDP connection.

When a device turns on, it goes into broadcast mode, where it broadcasts its identity to the network. So the host machine needs to listen on a UDP port for packets that match the device format, and keep track of devices that have already been detected, since a device will continue to broadcast even after connection.

I did some prototyping in Python, and an async generator lent itself quite well to this use case. In asio, the natural analog seems to be asio::ip::tcp::acceptor.

I am expecting my custom acceptor to be run from within a multi-threaded io_context because each device has a high data rate. So I think this means I need to use a strand for the acceptor, as well as one for each device controller, to order the operations (roughly the setup sketched after the code below). I have two questions on how to use the strand with respect to this bit of example code:

template <typename Executor>
class Acceptor {
    std::unordered_set<asio::ip::udp::endpoint> accepted_connections;
    asio::basic_datagram_socket<asio::ip::udp, Executor> receive_sock;

    std::shared_ptr<Device<Executor>> accept(asio::error_code& ec) {
        std::array<std::byte, buffer_size> buffer;
        typename asio::basic_datagram_socket<asio::ip::udp, Executor>::endpoint_type remote_endpoint;
        receive_sock.receive_from(asio::buffer(buffer), remote_endpoint, {}, ec);
        if (ec) {
            return {};
        }
        if (accepted_connections.contains(remote_endpoint)) {
            return {};
        }
        if (auto ret = std::start_lifetime_as<Header>(buffer.data())->validate(); !ret) {
            ec = ret.error();
            return {};
        }
        accepted_connections.emplace(remote_endpoint);
        return MakeDevice(remote_endpoint, buffer);
    }
};
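
For reference, this is roughly the setup I am planning around (just a sketch; the pool size and names are placeholders, not my real code):

asio::thread_pool io(4);                       // multi-threaded execution context
auto acceptor_strand = asio::make_strand(io);  // serializes the acceptor's state
auto device_strand   = asio::make_strand(io);  // one strand per device controller
// Acceptor<decltype(acceptor_strand)> would own receive_sock on acceptor_strand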

#1) How do I modify my blocking accept function to run inside the strand so that any access to the accepted_connections set is serialized correctly? From what I can tell, all of the options for execution on the strand (defer, dispatch, execute, post) are allowed to push the function onto the strand's queue and execute it later. This makes checking the error_code impossible: if the executor runs the contents of this function later, then when accept returns immediately, the ec variable won't yet be set. Due to the indirection that supports different native sockets, looking at the source code for how socket Acceptor::accept(endpoint& peer_endpoint, error_code& ec) works has not been very informative.

#2) When implementing async_accept I think I need to use async_initiate to make it compatible with other completion tokens. But how do I incorporate the strand into that so that access to the accepted_connections set is kept in order? Do I await the strand inside my coroutine, or do I need to do something in my async_initiate to make the coroutine run in the strand?


Answered by sehe:

First off:

I've beaten your sample into submission as a self-contained example: https://coliru.stacked-crooked.com/a/b70927fe6853d5bf

That out of the way, let's answer your questions:

Question #1

#1) How do I modify my blocking accept function to run inside the strand so that any access to the accepted_connections set is serialized correctly?

Blocking operations should not be run on the service.

Even if you did, you would be requiring the user to switch to the appropriate context, instead of doing that "transparently" for them².

[#1...] From what I can tell, all of the options for execution on the strand (defer, dispatch, execute, post) are allowed to push the function onto the strand's queue and execute it later. This makes checking the error_code impossible: if the executor runs the contents of this function later, then when accept returns immediately, the ec variable won't yet be set.

Yes. If you must shoehorn the sync accept into a blocking handler¹, the typical pattern would look like this (see How can I get a future from boost::asio::post?):

return asio::post(ex, asio::use_future([] {
             // ...
           }))
    .get();

So in your example:

Handle accept(error_code& ec) {
    return asio::post( //
               receive_sock.get_executor(), asio::use_future([this, &ec]() -> Handle {
                   Buffer        buffer;
                   udp::endpoint remote_endpoint;
                   receive_sock.receive_from(asio::buffer(buffer), remote_endpoint, {}, ec);
                   if (ec) {
                       return {};
                   }
                   if (accepted_connections.contains(remote_endpoint)) {
                       return {};
                   }
                   if (auto ret = std::start_lifetime_as<Header>(buffer.data())->validate(); !ret) {
                       ec = ret.error();
                       return {};
                   }
                   accepted_connections.emplace(remote_endpoint);
                   return MakeDevice(remote_endpoint, buffer);
               }))
        .get();
}

I'm not sure what it buys you, but it does address your question.
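
To make the caveat¹ concrete, here's a usage sketch (my illustration only, using the Acceptor from the full example further down): the blocking accept() above must only be called from a thread that is not one of the pool threads, because future::get() would otherwise block the very thread that the posted work needs:

asio::thread_pool ioc(1); // single pool thread, to make the problem obvious
Acceptor acc(asio::make_strand(ioc), 7878);

error_code ec;
auto dev = acc.accept(ec); // OK: main() is not a pool thread, so the posted work can still run

// NOT OK: from a pool thread, future::get() blocks the only thread that
// could ever run the posted lambda -> the soft lock from footnote ¹
// asio::post(ioc, [&] { error_code ec2; acc.accept(ec2); });

ioc.join();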

Question #2

#2) When implementing async_accept I think I need to use async_initiate to make it compatible with other completion tokens.

That's one way to make composed operations, yes: https://www.boost.org/doc/libs/1_84_0/doc/html/boost_asio/examples/cpp20_examples.html#boost_asio.examples.cpp20_examples.operations
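
For orientation, a bare async_initiate skeleton could look like this (a sketch only, using the Handle type from the full example below; the tested code further down uses asio::async_compose instead):

template <asio::completion_token_for<void(error_code, Handle)> Token>
auto async_accept(Token&& token) {
    return asio::async_initiate<Token, void(error_code, Handle)>(
        [this](auto handler) {
            // start the underlying async operation(s) here and eventually
            // invoke std::move(handler)(ec, handle) exactly once
        },
        token);
}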

[But] how do I incorporate the strand into that so that access to the accepted_connections set is kept in order? Do I await the strand inside my coroutine, or do I need to do something in my async_initiate to make the coroutine run in the strand?

The latter. Again, as under #1, the caller is responsible for calling the initiation on the strand, IFF the implementation requires that (it's up to you to document whether it is required). However, you will at least want the set access to happen on the strand, and also honor the caller's executor specification, if any.

To this end, Asio I/O objects have their own executor, which serves as the default when the completion token has no executor associated with it. Composed operations must use the associated executor for any (intermediate) handlers they invoke.
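
Inside a composed operation that lookup boils down to something like this (illustrative only; `handler` stands for whatever handler the completion token produced):

// Falls back to the I/O object's executor (our strand) when the caller did not
// bind one; with bind_executor(Other_strand, cb) it yields Other_strand instead.
auto ex = asio::get_associated_executor(handler, receive_sock.get_executor());
asio::dispatch(ex, [/*...*/]{ /* run the intermediate handler there */ });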

Note how this is at odds with your goal of forcing specific intermediate step(s) to be on the strand. I think the best you can do in this case is to dispatch so that the intermediate completion resumes on the strand:

asio::dispatch(bind_executor(get_executor(), continuation));

Here's a tested example which exercises different completion token types, including bound ones and c++20 coroutines:

Live On Coliru

#include <boost/asio.hpp>
#include <atomic>
#include <expected>
#include <iomanip>
#include <iostream>
#include <memory>
#include <unordered_set>
// standalone asio compat
namespace asio = boost::asio;
using boost::system::error_code;

// c++23 compat hack
namespace std {
    template <typename T> decltype(auto) start_lifetime_as(void* p) { return reinterpret_cast<T*>(p); }
} // namespace std
using asio::ip::udp;

static constexpr size_t buffer_size = 65520;
using Buffer                        = std::array<std::byte, buffer_size>;

static std::string AsPayload(Buffer const& b, size_t n) {
    std::string s(reinterpret_cast<char const*>(b.data()), std::min(n, b.size()));
    if (auto EOL = s.find_last_not_of("\n") + 1)
        s.resize(EOL);
    return s;
}

struct Header {
    char                                      raw[24];
    constexpr std::expected<bool, error_code> validate() const { return true; }
};

static_assert(std::is_standard_layout_v<Header>);
static_assert(std::is_trivial_v<Header>);

template <typename Executor> struct Device : std::enable_shared_from_this<Device<Executor>> {
    using Endpoint = udp::endpoint;
    Device(Executor, Endpoint ep, std::string name) : ep_(std::move(ep)), name_(std::move(name)) {}

    Endpoint    endpoint() const { return ep_;   }
    std::string name()     const { return name_; }

  private:
    Endpoint    ep_;
    std::string name_;

    friend std::ostream& operator<<(std::ostream& os, Device const& d) {
        return os << "{Device " << d.endpoint() << ", " << quoted(d.name()) << "}";
    }
};

template <typename Executor> class Acceptor {
  public:
    using executor_type = Executor;
    Executor get_executor() { return receive_sock.get_executor(); }
    Executor get_executor() const { return receive_sock.get_executor(); }

    Acceptor(Executor ex, uint16_t port) : receive_sock(ex, {{}, port}) {}

    using Handle = std::shared_ptr<Device<Executor>>;
    Handle accept(error_code& ec) {
// #define DONT_DO_THIS
#ifdef DONT_DO_THIS
        return asio::post( //
                   receive_sock.get_executor(), asio::use_future([this, &ec]() -> Handle {
#endif
                       Buffer        buffer;
                       udp::endpoint remote_endpoint;
                       auto n = receive_sock.receive_from(asio::buffer(buffer), remote_endpoint, {}, ec);

                       if (ec || accepted_connections.contains(remote_endpoint)) {
                           return {};
                       }
                       if (auto ret = std::start_lifetime_as<Header>(buffer.data())->validate(); !ret) {
                           ec = ret.error();
                           return {};
                       }
                       accepted_connections.emplace(remote_endpoint);
                       return MakeDevice(remote_endpoint, AsPayload(buffer, n));
#ifdef DONT_DO_THIS
                   }))
            .get();
#endif
    }

    using AcceptSig = void(error_code, Handle);

    template <asio::completion_token_for<AcceptSig> Token> //
    auto async_accept(Token&& token) {
        struct StableState {
            Buffer        buffer_;
            size_t        num_bytes_ = {};
            udp::endpoint remote_endpoint_;

            std::string Payload() const { return AsPayload(buffer_, num_bytes_); }
        };

        auto impl = [this, state = std::make_unique<StableState>(), coro = asio::coroutine{}] //
            (auto& self, error_code ec = {}, size_t n = {}) mutable {
                BOOST_ASIO_CORO_REENTER(coro) {
                    BOOST_ASIO_CORO_YIELD receive_sock.async_receive_from(
                        asio::buffer(state->buffer_), state->remote_endpoint_, std::move(self));

                    if (ec)
                        return std::move(self).complete(ec, Handle{});

                    state->num_bytes_ = n;

                    // switch over to strand
                    BOOST_ASIO_CORO_YIELD asio::dispatch(bind_executor(get_executor(), std::move(self)));

                    assert_strand();
                    if (accepted_connections.contains(state->remote_endpoint_))
                        return std::move(self).complete(ec, Handle{});

                    if (auto ret = std::start_lifetime_as<Header>(state->buffer_.data())->validate(); !ret)
                        return std::move(self).complete(ret.error(), Handle{});

                    assert_strand();
                    accepted_connections.emplace(state->remote_endpoint_);

                    //// forcibly leave strand, keeping caller's restrictions
                    // BOOST_ASIO_CORO_YIELD asio::post(std::move(self));

                    return std::move(self).complete(error_code{},
                                                    MakeDevice(state->remote_endpoint_, state->Payload()));
                }
            };
        return asio::async_compose<Token, AcceptSig>(std::move(impl), token, receive_sock);
    }

  private:
    std::unordered_set<udp::endpoint>          accepted_connections;
    asio::basic_datagram_socket<udp, Executor> receive_sock;

    void assert_strand() {
        // HACKY ASSERT, assuming concrete strand executor type
        assert(get_executor().running_in_this_thread());
    }
    auto MakeDevice(udp::endpoint ep, std::string s) {
        return std::make_shared<Device<Executor>>(receive_sock.get_executor(), std::move(ep), std::move(s));
    }
};

int main() {
    asio::thread_pool ioc;

    auto A_strand     = make_strand(ioc);
    auto Other_strand = make_strand(ioc);
    assert(A_strand != Other_strand);

    auto report = [&](std::string caption, auto h, error_code const& ec = {}) {
        static std::atomic_int tid_gen {};
        thread_local int const tid = tid_gen++;
        std::cout << "#" << tid << " " << std::setw(23) << std::left << (caption + ":");
        if (h)
            std::cout << *h;
        else
            std::cout << "NULL";

        std::cout << " (" << ec.message() << ")";

        if (A_strand.running_in_this_thread())
            std::cout << " [on A_strand]" ;
        if (Other_strand.running_in_this_thread())
            std::cout << " [on Other_strand]" ;
        std::cout << std::endl;
    };

    Acceptor a(A_strand, 7878);
    using Handle = decltype(a)::Handle;

    {   // synchronous
        error_code ec;
        report("Synchronous", a.accept(ec), ec);
    }

    {
        // use_future
        std::future<Handle> dev = a.async_accept(asio::use_future);
        try {
            report("use_future", dev.get());
        } catch (boost::system::system_error const& se) {
            report("use_future", Handle{}, se.code());
        }
    }

    // callback
    auto callback = [report](std::string caption) {
        return [=](error_code ec, Handle h) { report(caption + " callback", h, ec); };
    };

    a.async_accept(callback("plain"));
    a.async_accept(bind_executor(ioc.get_executor(), callback("bound")));
    a.async_accept(bind_executor(Other_strand, callback("Other_strand")));
    a.async_accept(bind_executor(A_strand, callback("A_strand")));

    // c++20 coro
    auto coro = [&a, report](auto caption) -> asio::awaitable<void> {
        try {
            report(caption, co_await a.async_accept(asio::deferred));
        } catch (boost::system::system_error const& se) {
            report(caption, Handle{}, se.code());
        }
    };

    co_spawn(A_strand,     coro("A_strand coro"),     asio::detached);
    co_spawn(Other_strand, coro("Other_strand coro"), asio::detached);
    co_spawn(ioc,          coro("plain coro"),        asio::detached);

    ioc.join();
}

Online output:

g++ -std=c++23 -O2 -Wall -pedantic -pthread main.cpp
./a.out &
sleep 2;
for a in foo{1..7}; do sleep .5; nc -u -w0 127.0.0.1 7878 <<<"$a"; done
#0 Synchronous:           {Device 127.0.0.1:36447, "foo1"} (Success)
#0 use_future:            {Device 127.0.0.1:51578, "foo2"} (Success)
#1 plain callback:        {Device 127.0.0.1:44773, "foo3"} (Success) [on A_strand]
#2 bound callback:        {Device 127.0.0.1:36262, "foo4"} (Success) [on A_strand]
#3 Other_strand callback: {Device 127.0.0.1:52930, "foo5"} (Success) [on A_strand] [on Other_strand]
#4 A_strand callback:     {Device 127.0.0.1:58049, "foo6"} (Success) [on A_strand]
#5 plain coro:            {Device 127.0.0.1:46936, "foo7"} (Success) [on A_strand]

Locally, with a steady stream of faux devices:

grep -v "'" /etc/dictionaries-common/words | sort -R | while read w
do
    sleep .5
    (set -x ; nc -u -s 127.0.0.$((1 + $RANDOM % 254)) localhost 7878 -w 0 <<< "$w")
done

We get:

[screenshot of the local output]


¹ (hint: don't; it will easily cause your threads to soft lock)

² The latter more closely approximates the Active Object pattern (see e.g. boost::asio and Active Object for inspiration).
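
For readers unfamiliar with that pattern, a rough flavour of it (my own sketch, not code from the question): every public member function posts its work onto the object's own strand and hands back a future:

struct ActiveCounter {
    explicit ActiveCounter(asio::thread_pool& ioc) : strand_(asio::make_strand(ioc)) {}

    std::future<int> increment() { // callable from any thread; state only touched on strand_
        return asio::post(strand_, asio::use_future([this] { return ++count_; }));
    }

  private:
    asio::strand<asio::thread_pool::executor_type> strand_;
    int count_ = 0;
};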