Execute middleware before and after request in Rust warp

I would like to track in-flight connections in warp such that a metrics counter is incremented before the request is handled and decremented after it was processed.

I attempted to solve this by using a "no-op" filter at the start of the chain and a custom logging filter at the end of the chain, something like this:

/// Increment the request count metric before the request starts.
fn with_start_call_metrics() -> impl Filter<Extract = (), Error = Infallible> + Clone {
    warp::any()
        .and(path::full())
        .map(|path: FullPath| {
            HttpMetrics::inc_in_flight(path.as_str());
        })
        .untuple_one()
}

/// Decrement the request count metric after the request ended.
fn with_end_call_metrics() -> Log<fn(Info<'_>)> {
    warp::log::custom(|info| {
        HttpMetrics::dec_in_flight(info.path());
        // ... track more metrics, e.g. info.elapsed() ...
    })
}

The problem arises when a long-running request (/slow in the code below) is started and the connection is dropped before the request could be processed completely (e.g. CTRL-C on curl).

In this case, the slow route is simply aborted by warp and the with_end_call_metrics filter below is never reached:

use std::time::Duration;
use warp::Filter;

#[tokio::main]
async fn main() {
    let hello = warp::path!("hello" / String).and_then(hello);
    let slow = warp::path!("slow").and_then(slow);

    warp::serve(
        with_start_call_metrics()
            .and(
                hello.or(slow), // ... and more ...
            )
            // If the call (e.g. of `slow`) is cancelled, this is never reached.
            .with(with_end_call_metrics()),
    )
    .run(([127, 0, 0, 1], 8080))
    .await;
}

async fn hello(name: String) -> Result<impl warp::Reply, warp::Rejection> {
    Ok(format!("Hello, {}!", name))
}

async fn slow() -> Result<impl warp::Reply, warp::Rejection> {
    tokio::time::sleep(Duration::from_secs(5)).await;
    Ok(format!("That was slow."))
}

I understand this is normal behavior, and that the recommended approach is to rely on the Drop implementation of a value owned by the request's future, since Drop runs even when that future is cancelled. So, something like:

async fn in_theory<F, T, E>(filter: F) -> Result<T, E>
where
    F: Filter<Extract = T, Error = E>
{
    let guard = TrackingGuard::new();
    filter.await
}
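As a sanity check of that idea outside warp, here is a minimal, dependency-free sketch. It assumes a hypothetical global `IN_FLIGHT` counter in place of `HttpMetrics` and a hypothetical `TrackingGuard` like the one above. The guard is a local of the handler's future, so dropping that future while it is parked at an `.await` (which is what a server does when the connection goes away) runs the guard's `Drop`. The hand-rolled no-op waker only exists so the future can be polled once without a runtime.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::{Context, RawWaker, RawWakerVTable, Waker};

/// Hypothetical in-flight gauge standing in for `HttpMetrics`.
static IN_FLIGHT: AtomicUsize = AtomicUsize::new(0);

/// Increments the gauge when created and decrements it when dropped.
struct TrackingGuard;

impl TrackingGuard {
    fn new() -> Self {
        IN_FLIGHT.fetch_add(1, Ordering::SeqCst);
        TrackingGuard
    }
}

impl Drop for TrackingGuard {
    fn drop(&mut self) {
        IN_FLIGHT.fetch_sub(1, Ordering::SeqCst);
    }
}

/// A waker that does nothing; enough to poll a future once by hand.
fn noop_waker() -> Waker {
    fn clone(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    // SAFETY: none of the vtable functions touch the (null) data pointer.
    unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}

/// Starts a "slow handler", polls it once (it stays pending), then drops it
/// the way a server drops the future of a disconnected client.
/// Returns the gauge value while in flight and after the drop.
fn simulate_cancelled_request() -> (usize, usize) {
    let mut handler: Pin<Box<dyn Future<Output = ()>>> = Box::pin(async {
        let _guard = TrackingGuard::new();
        // Never completes, standing in for a very slow handler.
        std::future::pending::<()>().await;
    });

    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);
    assert!(handler.as_mut().poll(&mut cx).is_pending());
    let during = IN_FLIGHT.load(Ordering::SeqCst);

    drop(handler); // cancellation: the future is dropped mid-await
    let after = IN_FLIGHT.load(Ordering::SeqCst);
    (during, after)
}
```

Binding with `let _guard` matters here: `let _ = TrackingGuard::new();` would drop the guard immediately.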

But that doesn't work. I tried using wrap_fn like so:

pub fn in_theory<F>(filter: F) -> Result<F::Extract, F::Error>
where
    F: Filter + Clone,
{
    warp::any()
        .and(filter)
        .wrap_fn(|f| async { 
             // ... magic here ...
             f.await 
        })
}

but regardless of what I try, it always ends up with an error like this:

error[E0277]: the trait bound `<F as warp::filter::FilterBase>::Error: reject::sealed::CombineRejection<Infallible>` is not satisfied
   --> src/metrics.rs:255:25
    |
255 |         warp::any().and(filter).wrap_fn(|f| async { f.await })
    |                     --- ^^^^^^ the trait `reject::sealed::CombineRejection<Infallible>` is not implemented for `<F as warp::filter::FilterBase>::Error`
    |                     |
    |                     required by a bound introduced by this call

And that cannot be specified, because reject::sealed is not a public module. Any help is appreciated!

Accepted answer (sunside):

As was suggested in the comments, moving away from warp filters and building the middleware with Tower helped. I had to rewrite the server hosting code to use hyper::Server directly, but this was only a mild inconvenience.


I started off with an HttpCallMetrics service wrapping an inner service S. Since I am tracking HTTP responses, that service ultimately needs to produce a hyper::Response<O>, where the type parameter O is the response body type.

The PhantomData field exists so that O appears on the struct. Without it, the Service implementation below would not compile, because O would only appear in a trait bound and would therefore be an unconstrained type parameter (error E0207).

use std::marker::PhantomData;

#[derive(Clone)]
pub struct HttpCallMetrics<S, O> {
    inner: S,
    _phantom: PhantomData<O>,
}

impl<S, O> HttpCallMetrics<S, O> {
    pub fn new(inner: S) -> Self {
        Self {
            inner,
            _phantom: PhantomData,
        }
    }
}
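One caveat with `#[derive(Clone)]` here: the derive adds an `O: Clone` bound, so the wrapper is only `Clone` when the body type is. A possible alternative, sketched below with hypothetical names (`inner()`, `NotCloneBody`), is to implement `Clone` by hand and store `PhantomData<fn() -> O>`, which also keeps the wrapper's `Send`/`Sync` independent of `O`.

```rust
use std::marker::PhantomData;

/// Same shape as `HttpCallMetrics`, but `O` is stored as `fn() -> O`:
/// `fn` pointers are always `Send + Sync`, so `O` no longer affects them.
pub struct HttpCallMetrics<S, O> {
    inner: S,
    _phantom: PhantomData<fn() -> O>,
}

impl<S, O> HttpCallMetrics<S, O> {
    pub fn new(inner: S) -> Self {
        Self {
            inner,
            _phantom: PhantomData,
        }
    }

    /// Hypothetical accessor, only here so the example can inspect `inner`.
    pub fn inner(&self) -> &S {
        &self.inner
    }
}

// Written by hand: `#[derive(Clone)]` would additionally require `O: Clone`.
impl<S: Clone, O> Clone for HttpCallMetrics<S, O> {
    fn clone(&self) -> Self {
        Self::new(self.inner.clone())
    }
}

/// Hypothetical body type that is deliberately not `Clone`.
pub struct NotCloneBody;
```

With this, `HttpCallMetrics<S, NotCloneBody>` is still `Clone` as long as `S` is.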

Because it is about HTTP metrics, the service deals specifically with HTTP requests and hence implements Service<Request<B>> for any body type B. Likewise, the wrapped service must also be a Service<Request<B>>, and its response must be convertible into a Response<O>.

The HttpCallMetrics service produces a custom future, HttpCallMetricsFuture, that takes care of the metrics tracking; this avoids boxing. Apart from that, since recording metrics never needs to wait, it simply forwards its poll_ready call to the wrapped inner service.

When called, an HttpCallMetricTracker instance is created from the request. This struct holds basic request information (HTTP method, version, path, start time) and implements Drop: when dropped, it registers that the request terminated. Because the tracker is owned by the returned future, this works whether the request finishes successfully or is cancelled and its future dropped.

use std::task::{Context, Poll};

use hyper::Request;
use tower::Service;

impl<S, B, O> Service<Request<B>> for HttpCallMetrics<S, O>
where
    S: Service<Request<B>>,
    S::Response: Into<hyper::Response<O>>,
{
    type Response = hyper::Response<O>;
    type Error = S::Error;
    type Future = HttpCallMetricsFuture<S::Future, O, Self::Error>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: Request<B>) -> Self::Future {
        let tracker = HttpCallMetricTracker::start(&request);
        HttpCallMetricsFuture::new(self.inner.call(request), tracker)
    }
}
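To illustrate the wrapping pattern without pulling in tower or hyper, here is a sketch built on a hypothetical, stripped-down `MiniService` trait (no `poll_ready`, no error type) and a hypothetical `Echo` inner service. It shows the key property: the tracker is created in `call` and moved into the returned future, so the gauge goes back down when that future is dropped, even if it was never polled to completion. Requiring `Unpin` on the inner future is a simplification that avoids needing `pin_project`.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};

/// Hypothetical, stripped-down stand-in for `tower::Service`.
trait MiniService<Req> {
    type Future: Future;
    fn call(&mut self, req: Req) -> Self::Future;
}

/// Hypothetical inner service that answers immediately.
struct Echo;

impl MiniService<u32> for Echo {
    type Future = std::future::Ready<u32>;
    fn call(&mut self, req: u32) -> Self::Future {
        std::future::ready(req)
    }
}

/// Bumps the shared gauge on creation and lowers it again on drop.
struct Tracker(Arc<AtomicUsize>);

impl Tracker {
    fn start(gauge: &Arc<AtomicUsize>) -> Self {
        gauge.fetch_add(1, Ordering::SeqCst);
        Tracker(Arc::clone(gauge))
    }
}

impl Drop for Tracker {
    fn drop(&mut self) {
        self.0.fetch_sub(1, Ordering::SeqCst);
    }
}

/// The middleware: creates the tracker in `call` and moves it into the future.
struct Metrics<S> {
    inner: S,
    gauge: Arc<AtomicUsize>,
}

impl<S, Req> MiniService<Req> for Metrics<S>
where
    S: MiniService<Req>,
    S::Future: Unpin,
{
    type Future = Tracked<S::Future>;

    fn call(&mut self, req: Req) -> Self::Future {
        let tracker = Tracker::start(&self.gauge);
        Tracked { inner: self.inner.call(req), _tracker: tracker }
    }
}

/// Forwards polling to the inner future while owning the tracker.
struct Tracked<F> {
    inner: F,
    _tracker: Tracker,
}

impl<F: Future + Unpin> Future for Tracked<F> {
    type Output = F::Output;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
        Pin::new(&mut self.inner).poll(cx)
    }
}
```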

The future again needs PhantomData, here to carry the response body type O and the error type E of the wrapped service's future.

use pin_project::pin_project;

#[pin_project]
pub struct HttpCallMetricsFuture<F, O, E> {
    #[pin]
    future: F,
    tracker: HttpCallMetricTracker,
    _phantom: PhantomData<(O, E)>,
}

impl<F, O, E> HttpCallMetricsFuture<F, O, E> {
    fn new(future: F, tracker: HttpCallMetricTracker) -> Self {
        Self {
            future,
            tracker,
            _phantom: PhantomData,
        }
    }
}

The implementation is then comparatively simple. The poll call is forwarded to the wrapped inner future, and the method returns Poll::Pending early if that future is still pending.

Once the inner future returns Poll::Ready, its result is inspected. If it is Ok(reply), the reply is converted into a hyper::Response, the tracker records the response's status and version, and the response is returned.

In case of an error, the tracker records the failure and the error is returned as is.

use std::future::Future;
use std::pin::Pin;

impl<F, R, O, E> Future for HttpCallMetricsFuture<F, O, E>
where
    F: Future<Output = Result<R, E>>,
    R: Into<hyper::Response<O>>,
{
    type Output = Result<hyper::Response<O>, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        let response = match this.future.poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(reply) => reply,
        };

        let result = match response {
            Ok(reply) => {
                let response = reply.into();
                this.tracker
                    .set_state(ResultState::Result(response.status(), response.version()));
                Ok(response)
            }
            Err(e) => {
                this.tracker.set_state(ResultState::Failed);
                Err(e)
            }
        };
        Poll::Ready(result)
    }
}
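The same poll-forwarding logic can be sketched with only the standard library. In this version (hypothetical names throughout) the inner future is boxed, which makes the wrapper `Unpin` and removes the need for pin projection, at the cost of one allocation per request, which is exactly the boxing the custom future above avoids. `Outcome` is a simplified stand-in for `ResultState`.

```rust
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

/// Hypothetical, simplified stand-in for `ResultState`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Outcome {
    Pending,
    Succeeded,
    Failed,
}

/// Forwards `poll` to the inner future and records how it finished.
struct Recording<F> {
    inner: Pin<Box<F>>, // boxed, so `Recording<F>` is `Unpin` for every `F`
    outcome: Outcome,
}

impl<F> Recording<F> {
    fn new(inner: F) -> Self {
        Self {
            inner: Box::pin(inner),
            outcome: Outcome::Pending,
        }
    }
}

impl<F, T, E> Future for Recording<F>
where
    F: Future<Output = Result<T, E>>,
{
    type Output = Result<T, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut(); // allowed because `Recording<F>: Unpin`
        let result = match this.inner.as_mut().poll(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(result) => result,
        };
        this.outcome = if result.is_ok() {
            Outcome::Succeeded
        } else {
            Outcome::Failed
        };
        Poll::Ready(result)
    }
}

/// A waker that does nothing; enough to poll a future once by hand.
fn noop_waker() -> Waker {
    fn clone(_: *const ()) -> RawWaker {
        RawWaker::new(std::ptr::null(), &VTABLE)
    }
    fn noop(_: *const ()) {}
    static VTABLE: RawWakerVTable = RawWakerVTable::new(clone, noop, noop, noop);
    // SAFETY: none of the vtable functions touch the (null) data pointer.
    unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}

/// Polls an `Unpin` future exactly once with the no-op waker.
fn poll_once<F: Future + Unpin>(fut: &mut F) -> Poll<F::Output> {
    let waker = noop_waker();
    let mut cx = Context::from_waker(&waker);
    Pin::new(fut).poll(&mut cx)
}
```

A future that is still pending leaves the outcome untouched, which is the state a `Drop` implementation would later interpret as a cancellation.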

The HttpCallMetricTracker is more or less trivial: it increments the call metrics when constructed and decrements them when dropped.

The only interesting part is the state: Cell<ResultState> field. It lets the Drop implementation decide what, if anything, should be logged. It's not strictly required here.

use std::cell::Cell;
use std::time::{Duration, Instant};

use hyper::{StatusCode, Version};

struct HttpCallMetricTracker {
    version: Version,
    method: hyper::Method,
    path: String,
    start: Instant,
    state: Cell<ResultState>,
}

pub enum ResultState {
    /// The result was already processed.
    None,
    /// Request was started.
    Started,
    /// The result failed with an error.
    Failed,
    /// The result is an actual HTTP response.
    Result(StatusCode, Version),
}

impl HttpCallMetricTracker {
    fn start<B>(request: &Request<B>) -> Self {
        // increase "requests in flight" metric
        Self {
            version: request.version(),
            method: request.method().clone(),
            path: request.uri().path().to_string(),
            start: Instant::now(),
            // `Started`, not `None`: `Drop` returns early on `None`, which would
            // skip the decrement for cancelled requests.
            state: Cell::new(ResultState::Started),
        }
    }

    fn set_state(&self, state: ResultState) {
        self.state.set(state)
    }

    fn duration(&self) -> Duration {
        self.start.elapsed()
    }
}

impl Drop for HttpCallMetricTracker {
    fn drop(&mut self) {
        match self.state.replace(ResultState::None) {
            ResultState::None => {
                // This was already handled; don't decrement metrics again.
                return;
            }
            ResultState::Started => {
                // The request was cancelled before a result was produced.
            }
            ResultState::Failed => {
                // handle "fail" state
            }
            ResultState::Result(status, version) => {
                // handle "meaningful result" state
            }
        }

        // decrease "requests in flight" metric
    }
}
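Regarding the `Cell`: `HttpCallMetricsFuture::poll` reaches the tracker through the pin projection as `&mut HttpCallMetricTracker`, so a plain `Option` field with a `&mut self` setter appears to suffice. A minimal sketch of that variant, using hypothetical global counters instead of real metrics, where `None` simply means "no result yet" and `Drop` decrements the in-flight gauge exactly once:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Hypothetical counters standing in for real metrics.
static IN_FLIGHT: AtomicUsize = AtomicUsize::new(0);
static CANCELLED: AtomicUsize = AtomicUsize::new(0);
static FAILED: AtomicUsize = AtomicUsize::new(0);
static COMPLETED: AtomicUsize = AtomicUsize::new(0);

/// How a request finished; absent (`None`) while it is still running.
enum Finished {
    Failed,
    Response(u16), // status code only, for brevity
}

struct Tracker {
    finished: Option<Finished>,
}

impl Tracker {
    fn start() -> Self {
        IN_FLIGHT.fetch_add(1, Ordering::SeqCst);
        Tracker { finished: None }
    }

    /// `&mut self` is enough: the future's `poll` has mutable access.
    fn finish(&mut self, state: Finished) {
        self.finished = Some(state);
    }
}

impl Drop for Tracker {
    fn drop(&mut self) {
        match self.finished.take() {
            // Dropped before any result: the request was cancelled.
            None => CANCELLED.fetch_add(1, Ordering::SeqCst),
            Some(Finished::Failed) => FAILED.fetch_add(1, Ordering::SeqCst),
            Some(Finished::Response(_status)) => COMPLETED.fetch_add(1, Ordering::SeqCst),
        };
        // `drop` runs once per tracker, so no "already handled" state is needed.
        IN_FLIGHT.fetch_sub(1, Ordering::SeqCst);
    }
}
```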

As far as hosting goes, the code now looks something like this:

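use std::convert::Infallible;
use std::net::{SocketAddr, TcpListener};

use hyper::service::make_service_fn;
use hyper::Server;
use tower::ServiceBuilder;

let make_svc = make_service_fn(|_conn| {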
    let tx = shutdown_tx.clone();

    async move {
        // Convert the warp filter into a Tower service.
        let svc = warp::service(
            hello
                .or(slow)
                .or(filters::metrics_endpoint())
                .or(filters::health_endpoints())
                .or(filters::shutdown_endpoint(tx)),
        );

        // Wrap it into the metrics service.
        let svc = services::HttpCallMetrics::new(svc);

        Ok::<_, Infallible>(svc)
    }
});

let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
let listener = TcpListener::bind(addr).unwrap();

// Using a ServiceBuilder is not strictly required.
let builder = ServiceBuilder::new().service(make_svc);

Server::from_tcp(listener)
    .unwrap()
    .serve(builder)
    .with_graceful_shutdown(async move {
        shutdown_rx.recv().await.ok();
    })
    .await?;

That said, tower_http::trace also exists and seems to cover all of the above. I will likely migrate to it later on, but this exercise helped me tremendously in understanding Tower in the first place.