boost::beast reconnect websocket after close

508 Views Asked by At

I'm messing around with the Beast async websocket client and want it to re-connect in the same session instance if it becomes disconnected. The issue is that everything works until I attempt the websocket handshake a second time following a disconnect: the TCP connection works, the SSL handshake works (more on that in a bit), but the websocket handshake fails with 'Error: unexpected message (SSL routines, ssl3_read_bytes)'. I'm not sure what's wrong with re-using the ws_. From this post it seems there may not be a good answer?

So far I have modified websocket_client_async.cpp as follows:

Instead of the write, then read, then quit sequence used in the example, I'd like the code to simulate pulling messages from a queue and sending them out via the websocket connection. The websocket connection should be secure, so an ssl_stream is used.

I got around this ws_ re-use issue by destroying the session if it becomes disconnected once a websocket connection has been established, and then creating a new instance for the next connection attempt. This approach seems kind of ham-fisted, and I'm hoping it's possible to re-connect in the same session.

I've hacked up websocket_client_async.cpp - apologies as it's a fair amount of code, but I don't see what could be left out without confusing the issue


// includes ...

namespace beast = boost::beast;         // from <boost/beast.hpp>
namespace http = beast::http;           // from <boost/beast/http.hpp>
namespace websocket = beast::websocket; // from <boost/beast/websocket.hpp>
namespace net = boost::asio;            // from <boost/asio.hpp>
namespace ssl = boost::asio::ssl;

using tcp = boost::asio::ip::tcp;       // from <boost/asio/ip/tcp.hpp>

namespace {
    // State shared between the main thread, the queue-filling thread and the
    // io_context thread. Every flag is given an explicit initial value rather
    // than relying on static zero-initialisation.

    // True while the client should keep working; cleared by main() to request
    // shutdown and polled by the session and by fill_q().
    std::atomic<bool> running_{ false };
    // Set when the connection cannot be continued by the current session, so
    // main() tears the session down and builds a new one.
    std::atomic<bool> ws_connect_broken{ false };
    // Set by fill_q() once every fake message has been queued.
    std::atomic<bool> fill_q_done{ false };

    // Messages waiting to be sent; always accessed with mtx held once the
    // worker threads are running.
    std::queue<std::string> outgoingMessages;

    // Guards outgoingMessages.
    std::mutex mtx;
}

namespace detail {
    // Delay, in whole seconds, between a failed connection attempt and the
    // next attempt.
    constexpr int64_t CONNECT_RETRY_INTERVAL_SEC{ 2 };
}

//------------------------------------------------------------------------------

// Write "<what>: <error message>" followed by a newline to stderr.
void fail(beast::error_code ec, char const* what)
{
    std::cerr << what << ": " << ec.message() << '\n';
}

// Sends a WebSocket message and prints the response
class session : public std::enable_shared_from_this<session>
{
    net::strand< net::io_context::executor_type >& ex_;
    tcp::resolver resolver_;
    websocket::stream<beast::ssl_stream<beast::tcp_stream>> ws_;
    std::string api_key_{ "123456789012345678901234567890123456" };
    beast::flat_buffer buffer_;
    std::string host_;
    std::string text_;
    net::steady_timer& qtmr_;
    net::steady_timer& wtmr_;
    net::steady_timer& ctmr_;
    tcp::resolver::results_type results_;
    bool ws_initialized_ = false;
    bool retry_conn_ = false;
    bool valid_ws_conn_ = false;

public:
    // Resolver and socket require an io_context
    explicit
        session(net::strand< net::io_context::executor_type >& ex, ssl::context& ctx, net::steady_timer& qtmr,
            net::steady_timer& wtmr, net::steady_timer& ctmr)
        : ex_(ex)
        , resolver_(ex_)
        , ws_(ex_, ctx)
        , qtmr_(qtmr)
        , wtmr_(wtmr)
        , ctmr_(ctmr)
    {

        qtmr_.expires_after(std::chrono::seconds(1));
        qtmr_.async_wait(std::bind(&session::on_check_queue, this, std::placeholders::_1));

    }

    // Start the asynchronous operation
    void
        run(
            char const* host,
            char const* port)
    {
        // Save these for later
        host_ = host;

        // Set SNI Hostname (many hosts need this to handshake successfully)
        if (!SSL_set_tlsext_host_name(ws_.next_layer().native_handle(), host))
        {
            beast::error_code ec{ static_cast<int>(::ERR_get_error()), net::error::get_ssl_category() };
            std::cerr << ec.message() << "\n";
            return;
        }

        // Look up the domain name
        resolver_.async_resolve(host, port, beast::bind_front_handler(&session::on_resolve, shared_from_this()));
    }

    void
        on_resolve(beast::error_code ec, tcp::resolver::results_type results)
    {
        if (ec) {
            return fail(ec, "resolve");
        }

        results_ = results;

        boost::system::error_code ec2;
        on_try_connect(ec2);
    }

    void
        on_try_connect(const boost::system::error_code& ec)
    {
        if (ec != boost::asio::error::operation_aborted && running_ && !ws_connect_broken) {
            std::cout << "on_try_connect() trying tcp connection to " << host_ << std::endl;    

            // Set the timeout for the operation
            beast::get_lowest_layer(ws_).expires_after(std::chrono::seconds(5));

            // Make the connection on the IP address we get from a lookup
            beast::get_lowest_layer(ws_).async_connect(
                results_, beast::bind_front_handler(&session::on_tcp_connect, shared_from_this()));
        }
        else {
            std::cout << "on_try_connect() closing" << std::endl;   
            ws_.async_close(websocket::close_code::normal,
                beast::bind_front_handler(&session::on_close, shared_from_this()));
            return;
        }
    }

    void handle_error_retry(beast::error_code ec, const char* str)
    {
        if (!ec) { return; }

        retry_conn_ = ec != boost::asio::error::operation_aborted && running_;

        std::cout << str << " closing" << std::endl;    
        ws_.async_close(websocket::close_code::normal,
            beast::bind_front_handler(&session::on_close, shared_from_this()));
    }

    void
        on_tcp_connect(beast::error_code ec, tcp::resolver::results_type::endpoint_type ep)
    {
        if (ec) {
            return handle_error_retry(ec, "on_tcp_connect");
        }

        // Turn off the timeout on the tcp_stream, because
        // the websocket stream has its own timeout system.
        beast::get_lowest_layer(ws_).expires_never();

        if (!ws_initialized_) {
            ws_initialized_ = true;

            // Set suggested timeout settings for the websocket
            ws_.set_option(websocket::stream_base::timeout::suggested(beast::role_type::client));

            // Set a decorator to change the User-Agent of the handshake
            ws_.set_option(websocket::stream_base::decorator(
                [this](websocket::request_type& req)
                {
                    req.set(http::field::user_agent, std::string(BOOST_BEAST_VERSION_STRING) + " websocket-client-async");
                    req.set("X-Api-Key", api_key_);
                }));

            ws_.binary(true);

            // Update the host_ string. This will provide the value of the
            // Host HTTP header during the WebSocket handshake.
            // See https://tools.ietf.org/html/rfc7230#section-5.4
            host_ += ':' + std::to_string(ep.port());
        }

        // Perform the ssl handshake
        ws_.next_layer().async_handshake(ssl::stream_base::client, beast::bind_front_handler(&session::on_ssl_handshake, shared_from_this()));
    }

    void
        on_ssl_handshake(beast::error_code ec)
    {
        if (ec) {
            return handle_error_retry(ec, "on_ssl_handshake");
        }

        std::cout << "ssl handshake succeeded" << std::endl;    

        // Perform the websocket handshake
        ws_.async_handshake(host_, "/", beast::bind_front_handler(&session::on_ws_handshake, shared_from_this()));
    }

    void
        on_ws_handshake(beast::error_code ec)
    {
        if (ec) {
            return handle_error_retry(ec, "on_ws_handshake");
        }

        std::cout << "ws handshake succeeded" << std::endl; 
        valid_ws_conn_ = true;

        wtmr_.expires_after(std::chrono::seconds(0));
        wtmr_.async_wait(std::bind(&session::do_write, this, std::placeholders::_1));
    }

    void
        do_write(const boost::system::error_code& ec)
    {
        if (ec) {
            return handle_error_retry(ec, "do_write");
        }

        std::string msg;
        bool has_msg = false;
        {
            std::lock_guard<std::mutex> lock(mtx);
            if (!outgoingMessages.empty()) {
                msg = outgoingMessages.front();
                outgoingMessages.pop();
                has_msg = true;
            }
        }
        if (has_msg) {
            ws_.async_write(net::buffer(msg), boost::beast::bind_front_handler(&session::on_write, shared_from_this()));
            std::cout << "wrote \"" << msg << "\"" << std::endl;    
        }
        else
        {
            // Repeat write
            if (running_)
            {
                //ex_.get_inner_executor().context().poll();

                wtmr_.expires_after(std::chrono::seconds(1));
                wtmr_.async_wait(std::bind(&session::do_write, this, std::placeholders::_1));
            }
            else {
                std::cout << "do_write() closing" << std::endl; 
                ws_.async_close(websocket::close_code::normal,
                    beast::bind_front_handler(&session::on_close, shared_from_this()));
            }
        }
    }

    void
        on_write(beast::error_code ec, std::size_t bytes_transferred)
    {
        boost::ignore_unused(bytes_transferred);

        if (ec) {
            return handle_error_retry(ec, "on_write");
        }

        // see if anything more in the queue to write
        wtmr_.expires_after(std::chrono::milliseconds(500));
        wtmr_.async_wait(std::bind(&session::do_write, this, std::placeholders::_1));
    }

    void
        on_check_queue(const boost::system::error_code& ec)
    {
        if (ec)
            return fail(ec, "on_check_queue");

        if (running_.load()) {
            std::cout << "on_check_queue" << std::endl; 

            // TODO do some stats, check if queue is backing up, etc

            qtmr_.expires_after(std::chrono::seconds(1));
            qtmr_.async_wait(std::bind(&session::on_check_queue, this, std::placeholders::_1));
        }
    }

    void
        on_close(beast::error_code ec)
    {
        if (ec) {
            std::cout << "on_close: " << ec.message() << std::endl;
        }
        else {
            std::cout << "on_close: no error"  << std::endl;
        }

        if (valid_ws_conn_) {
            ws_connect_broken = true;
=====> was trying async call to try_connect here, but ws handshake fails
//ctmr_.expires_after(std::chrono::seconds(detail::CONNECT_RETRY_INTERVAL_SEC));
//ctmr_.async_wait(std::bind(&session::on_try_connect, this, std::placeholders::_1));
            return;
        }

        if (retry_conn_) {
            ctmr_.expires_after(std::chrono::seconds(detail::CONNECT_RETRY_INTERVAL_SEC));
            ctmr_.async_wait(std::bind(&session::on_try_connect, this, std::placeholders::_1));
        }
    }

};

//------------------------------------------------------------------------------

// Build the timers, TLS context and session for one connection lifetime, then
// run `ioc` on the calling thread until ioc.run() returns.
//
// ioc  - io_context that performs all I/O for the session.
// host - server host name (also used for SNI and the Host header).
// port - server port or service name.
void start(net::io_context& ioc, const std::string& host, const std::string& port)
{
    // The timers are declared before the session, which only holds references
    // to them, so they are destroyed after it.
    boost::asio::steady_timer           ctmr_{ ioc };   // connect timer
    boost::asio::steady_timer           wtmr_{ ioc };   // write timer
    boost::asio::steady_timer           qtmr_{ ioc };   // queue check timer

    // The SSL context is required, and holds certificates
    ssl::context ctx{ ssl::context::tlsv12_client };
    // NOTE(review): verify_none disables server certificate verification, so
    // the peer is not authenticated. Acceptable for experiments only; load
    // trusted roots and use ssl::verify_peer before real use.
    ctx.set_verify_mode(ssl::verify_none);

    // Launch the asynchronous operation
    auto strd = net::make_strand(ioc);
    auto s = std::make_shared<session>(strd, ctx, qtmr_, wtmr_, ctmr_);
    s->run(host.c_str(), port.c_str());

    // Run the I/O service. The call returns when the io_context is stopped or
    // has no work left.
    ioc.run();
}

// Producer used for testing: pushes a fixed list of fake messages onto
// outgoingMessages, one every three seconds. Returns early, without setting
// fill_q_done, as soon as running_ is found to be false after a push.
void fill_q()
{
    // Sleep for `t` seconds, then queue `str` under the queue mutex. The
    // string is moved into the queue instead of being copied.
    auto sleep_push = [](int64_t t, std::string&& str) {
        std::this_thread::sleep_for(std::chrono::seconds(t));
        std::lock_guard<std::mutex> lock(mtx);
        outgoingMessages.push(std::move(str));
    };

    // push some fake messages on a queue periodically
    std::vector<std::string> msgs = { "yo", "ho",  "ho",  "and",  "a",  "bottle",  "of",  "rum",
        "navy", "grog", "for", "me"};
    for (auto& msg : msgs) {
        sleep_push(3, std::move(msg));
        if (!running_) { return; }
    }
    fill_q_done = true;
}

// Entry point. Usage: websocket-client-async <host> <port>
//
// Runs the client on a worker thread next to a producer thread that fills the
// message queue. Once per second it checks whether the session flagged the
// connection as broken (tear everything down and start over) or whether the
// producer has finished (shut down and exit).
int main(int argc, char** argv)
{
    // Check command line arguments.
    if (argc != 3)
    {
        std::cerr <<
            "Usage: websocket-client-async <host> <port>\n" <<
            "Example:\n" << "    websocket-client-async echo.websocket.org 443\n";
        return EXIT_FAILURE;
    }
    auto const host = argv[1];
    auto const port = argv[2];

    net::io_context io_;

    using wgt = net::executor_work_guard<net::io_context::executor_type>;
    std::unique_ptr<wgt> work_guard_;
    std::unique_ptr<std::thread> thrd, thrd_q;

    // No worker thread exists yet, so the queue can be touched without mtx.
    outgoingMessages.push("hello"); // fake queue message

    // Request shutdown and wait for both worker threads to finish.
    auto destroy_conn = [&]() {
        // destroy existing connection
        std::cout << "destroy existing connection" << std::endl;

        work_guard_.reset();    // let io_.run() return once its work is done
        running_ = false;       // ask the session and the producer to stop
        thrd_q->join();
        thrd_q.reset();
        thrd->join();
        thrd.reset();
        io_.stop();
    };

    ws_connect_broken = true;       // true forces creation of initial session
    for (;;) {

        if (ws_connect_broken) {
            if (thrd) {
                destroy_conn();
            }

            // create new connection, set initial state and start threads
            std::cout << "create new connection" << std::endl;

            io_.restart();
            work_guard_ = std::make_unique<wgt>(io_.get_executor());
            running_ = true;
            fill_q_done = false;
            ws_connect_broken = false;
            thrd_q = std::make_unique<std::thread>([&]() { fill_q(); });
            thrd = std::make_unique<std::thread>([&]() { start(io_, host, port); });
        }

        std::this_thread::sleep_for(std::chrono::seconds(1));
        if (!running_ || fill_q_done) { break; }
    }

    if (thrd) {
        destroy_conn();
    }

    std::cout << "done" << std::endl;

    return EXIT_SUCCESS;
}

0

There are 0 best solutions below