How can I get the gRPC status code in Rust tonic middleware?

I've tried to implement a logging middleware for a tonic Server that records elapsed time, gRPC status code, etc. I referred to https://tokio.rs/blog/2021-05-14-inventing-the-service-trait and https://github.com/tower-rs/tower/blob/master/guides/building-a-middleware-from-scratch.md to learn how to build a middleware. Here is my implementation:

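use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;

use pin_project::pin_project;
use tower::Service;
// assuming the debug!/error! macros come from the tracing crate
use tracing::{debug, error};

pub struct GrpcLoggingMiddleware<S> {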
    service_name: &'static str,
    inner: S,
}

impl<S> GrpcLoggingMiddleware<S> {
    pub fn new(inner: S, service_name: &'static str) -> Self {
        Self {
            inner,
            service_name,
        }
    }
}

impl<S, Res> Service<http::request::Request<tonic::transport::Body>> for GrpcLoggingMiddleware<S>
where
    S: Service<
        http::request::Request<tonic::transport::Body>,
        Response = http::response::Response<Res>,
    >,
    S::Future: Future<Output = Result<http::response::Response<Res>, S::Error>>,
    S::Error: std::fmt::Debug,
    Res: std::fmt::Debug,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = ResponseFuture<S::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, request: http::request::Request<tonic::transport::Body>) -> Self::Future {
        // wrap response future to avoid Pin<Box<_>> overhead
        let path = String::from(request.uri().path());
        ResponseFuture {
            response_future: self.inner.call(request),
            path,
            start: Instant::now(),
        }
    }
}

#[pin_project]
pub struct ResponseFuture<F> {
    #[pin]
    response_future: F,
    path: String,
    start: Instant,
}

impl<F, Res, E> Future for ResponseFuture<F>
where
    F: Future<Output = Result<http::response::Response<Res>, E>>,
    E: std::fmt::Debug,
    Res: std::fmt::Debug,
{
    type Output = Result<http::response::Response<Res>, E>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        match this.response_future.poll(cx) {
            Poll::Ready(result) => {
                match &result {
                    Err(_) => error!("grpc response error"),
                    Ok(response) => debug!(
                        "finish grpc request, path: {}, elapse: {}us, status code: {}",
                        this.path,
                        this.start.elapsed().as_micros(),
                        response.status(),
                    ),
                }
                return Poll::Ready(result);
            }
            Poll::Pending => {
                return Poll::Pending;
            }
        }
    }
}

and layered it:

let layer = tower::layer::layer_fn(|service| GrpcLoggingMiddleware::new(service, "default"));

Server::builder()
    .layer(layer)
    .add_service(...)
    .serve(addr).await?;

However, the middleware only gets an http::response::Response object, which exposes only the HTTP status code. I want to capture the gRPC status code, which is wrapped in the HTTP body.

So how can I capture the gRPC status code in the middleware? The only way I can think of is to deserialize the HTTP body, but that adds overhead I'd like to avoid.

BEST ANSWER

The gRPC status code is returned in the HTTP headers under the grpc-status key when the response is trailers-only, which is what tonic sends when a handler returns an error; for a successful call it is sent in the trailers instead. In your ResponseFuture's poll method, you can extract that value from the response headers in the Ok arm of the match statement.

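fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    let this = self.project();
    match this.response_future.poll(cx) {
        Poll::Ready(result) => {
            match &result {
                Err(_) => error!("grpc response error"),
                Ok(response) => {
                    // grpc-status is only present in the headers of
                    // trailers-only responses (typically errors)
                    let grpc_status_code = status_code_from(response.headers());
                    debug!(
                        "finish grpc request, path: {}, elapse: {}us, grpc status: {}",
                        this.path,
                        this.start.elapsed().as_micros(),
                        grpc_status_code,
                    );
                }
            }
            Poll::Ready(result)
        }
        Poll::Pending => Poll::Pending,
    }
}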
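Note that the headers only carry grpc-status for trailers-only responses; when a call succeeds, tonic sends the status in the HTTP/2 trailers, which the middleware would only see by wrapping the response body and reading its trailers once the body is finished. That body wrapper depends on the http-body version in use and is not sketched here. The std-only sketch below shows just the final decision, using a hypothetical helper `resolve_grpc_status` (not part of tonic) that takes the header value and the trailer value, each possibly missing, and returns the numeric code if one is known yet.

```rust
/// Hypothetical helper (not part of tonic): pick the gRPC status of a call.
///
/// `header` is the `grpc-status` header value, present on trailers-only
/// responses; `trailer` is the `grpc-status` trailer value, available once
/// the response body has been fully streamed. Returns `None` while neither
/// is available.
pub fn resolve_grpc_status(header: Option<&str>, trailer: Option<&str>) -> Option<i32> {
    // A trailers-only response has no trailers, and a normal response has no
    // grpc-status header, so at most one of the two is expected to be set.
    header
        .or(trailer)
        // An unparsable value is treated as INTERNAL (13), the same fallback
        // the answer's status_code_from uses.
        .map(|value| value.parse::<i32>().unwrap_or(13))
}
```

In the middleware, `header` would come from `response.headers().get("grpc-status")` and `trailer` from the trailers map yielded at the end of the body, each converted with `to_str()` first.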

A naive implementation of the status_code_from above could look like the following:

pub fn status_code_from(headers: &hyper::HeaderMap) -> &str {
    // tonic only puts grpc-status in the headers for trailers-only
    // responses (typically errors); on success it is sent in the
    // trailers, so a missing header is treated as OK here
    match headers.get("grpc-status") {
        Some(value) => {
            // if the gRPC status code cannot be converted to a string,
            // return INTERNAL instead of panicking
            let code: &str = value.to_str().unwrap_or("13");
            println!("got grpc status code {}", code);
            match code {
                "0" => "OK",
                "1" => "CANCELLED",
                "2" => "UNKNOWN",
                "3" => "INVALID_ARGUMENT",
                "4" => "DEADLINE_EXCEEDED",
                "5" => "NOT_FOUND",
                "6" => "ALREADY_EXISTS",
                "7" => "PERMISSION_DENIED",
                "8" => "RESOURCE_EXHAUSTED",
                "9" => "FAILED_PRECONDITION",
                "10" => "ABORTED",
                "11" => "OUT_OF_RANGE",
                "12" => "UNIMPLEMENTED",
                "13" => "INTERNAL",
                "14" => "UNAVAILABLE",
                "15" => "DATA_LOSS",
                "16" => "UNAUTHENTICATED",
                _ => "UNKNOWN",
            }
        }
        None => "OK",
    }
}
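The string match above can also be written as a table lookup. The sketch below is a hypothetical variant, `grpc_code_name`, that takes the raw header bytes (what `HeaderValue::as_bytes` returns), parses them as a number and indexes an array of the 17 status names from the gRPC specification. It roughly keeps the answer's fallbacks: a missing header is OK, bytes that are not valid UTF-8 are INTERNAL, and anything that is not a known code is UNKNOWN. It is not an exact match: `HeaderValue::to_str` is stricter than a UTF-8 check, and numeric forms such as "05" parse to the same code here.

```rust
/// gRPC status names indexed by their numeric code (0..=16).
const GRPC_CODE_NAMES: [&str; 17] = [
    "OK", "CANCELLED", "UNKNOWN", "INVALID_ARGUMENT", "DEADLINE_EXCEEDED",
    "NOT_FOUND", "ALREADY_EXISTS", "PERMISSION_DENIED", "RESOURCE_EXHAUSTED",
    "FAILED_PRECONDITION", "ABORTED", "OUT_OF_RANGE", "UNIMPLEMENTED",
    "INTERNAL", "UNAVAILABLE", "DATA_LOSS", "UNAUTHENTICATED",
];

/// Hypothetical alternative to `status_code_from` that works on raw header
/// bytes, so it needs nothing beyond std.
pub fn grpc_code_name(raw: Option<&[u8]>) -> &'static str {
    match raw {
        // No grpc-status header: treated as OK, like the answer's version.
        None => "OK",
        Some(bytes) => match std::str::from_utf8(bytes) {
            // Not valid text: INTERNAL, mirroring the unwrap_or("13") fallback.
            Err(_) => "INTERNAL",
            Ok(text) => text
                .parse::<usize>()
                .ok()
                // Look the code up; out-of-range codes yield None.
                .and_then(|code| GRPC_CODE_NAMES.get(code).copied())
                // Unrecognized value: UNKNOWN, like the `_` arm above.
                .unwrap_or("UNKNOWN"),
        },
    }
}
```

In the middleware it could be called as `grpc_code_name(response.headers().get("grpc-status").map(|v| v.as_bytes()))`.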

ANOTHER ANSWER

You can try tower_http. The following code gives you the gRPC status code of failed requests:

use std::time::Duration;
use tower_http::classify::GrpcFailureClass;
use tracing::Span;

Server::builder()
    .layer(tower_http::trace::TraceLayer::new_for_grpc().on_failure(
        |error: GrpcFailureClass, _latency: Duration, _span: &Span| {
            tracing::error!("something went wrong: {:?}", error);
        },
    ))
    .add_service(...)
    .serve(addr).await?;
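As far as I know, tower_http's `GrpcFailureClass` has a `Code(NonZeroI32)` variant for calls that finish with a non-OK grpc-status and an `Error(String)` variant for other errors, and the gRPC classifier behind `TraceLayer::new_for_grpc` checks the headers first and falls back to the trailers at the end of the stream, which is why it also sees statuses a header-only check misses. The std-only sketch below models only the code part of that idea with a hypothetical `classify_grpc_status` function; it is not tower_http's implementation.

```rust
use std::num::NonZeroI32;

/// Hypothetical classifier modeled on the idea behind
/// `TraceLayer::new_for_grpc()`: a gRPC status of 0 (OK) is a success,
/// anything else is a failure carrying its non-zero code.
pub fn classify_grpc_status(code: i32) -> Result<(), NonZeroI32> {
    match NonZeroI32::new(code) {
        // 0 cannot be a NonZeroI32, so `None` means the call succeeded.
        None => Ok(()),
        // Any other code is reported as a failure, which is the case where
        // an `on_failure` callback would run.
        Some(failure) => Err(failure),
    }
}
```

Assuming the variant shape described above, the `on_failure` closure could log the numeric code by matching on `GrpcFailureClass::Code(code)` and calling `code.get()`.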