Torchserve streaming of inference responses with gRPC


I am trying to send a single request to a TorchServe server and retrieve a stream of responses. Processing the request takes some time and I would like to receive intermediate updates over the course of the run. I am quite new to TorchServe and especially to gRPC, but I assume that I either need to write a custom endpoint plugin for TorchServe or alter the source code directly, since the current TorchServe proto files only support unary gRPC calls.

I have found examples of near real-time video inference that implement a version of client-side streaming via request batching; however, that is not what I need.

Question: Is there a way to implement server-side response streaming in the latest TorchServe version, or would I need to change the proto files and the Java source to allow for it?


1 Answer


There appears to be support for streaming within the TorchServe framework.

TorchServe's gRPC API adds a server-side streaming endpoint, StreamPredictions, which allows a sequence of inference responses to be sent over the same gRPC stream.

service InferenceAPIsService {
    // Check health status of the TorchServe server.
    rpc Ping(google.protobuf.Empty) returns (TorchServeHealthResponse) {}

    // Predictions entry point to get inference using default model version.
    rpc Predictions(PredictionsRequest) returns (PredictionResponse) {}

    // Streaming response for an inference request.
    rpc StreamPredictions(PredictionsRequest) returns (stream PredictionResponse) {}
}

NOTE: This API forces the batchSize to be one. Make sure to account for this in your custom handler logic.
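
On the client side, the stream is consumed by iterating over the response returned by StreamPredictions. Below is a minimal sketch, assuming Python stubs (inference_pb2, inference_pb2_grpc) generated from TorchServe's inference.proto, a server on the default gRPC inference port 7070, and placeholder values for the model name and input file; the function name stream_predictions is just illustrative.

import grpc

import inference_pb2
import inference_pb2_grpc


def stream_predictions(model_name, model_input_path):
    # Connect to TorchServe's gRPC inference endpoint (default port 7070).
    channel = grpc.insecure_channel("localhost:7070")
    stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel)

    with open(model_input_path, "rb") as f:
        request = inference_pb2.PredictionsRequest(
            model_name=model_name, input={"data": f.read()}
        )

    # StreamPredictions returns an iterator; each item is one PredictionResponse.
    for response in stub.StreamPredictions(request):
        print(response.prediction.decode("utf-8"))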

The backend handler calls send_intermediate_predict_response to send each intermediate result to the frontend, and returns the last result in the usual way. For example:

from ts.protocol.otf_message_handler import send_intermediate_predict_response


def handle(data, context):
    if isinstance(data, list):
        # Push three intermediate results to the client over the open stream.
        for i in range(3):
            send_intermediate_predict_response(
                ["intermediate_response"], context.request_ids,
                "Intermediate Prediction success", 200, context,
            )
        # The return value is delivered as the final response on the stream.
        return ["hello world "]