Context cancellation not triggering on a blocking stream.Recv() in a Go gRPC bidirectional stream


I'm working on a gRPC implementation in Go where the server must wait for a response from the client within a certain timeframe. If it doesn't receive a response within that timeframe (currently 35 seconds), it should exit the iteration and wait on the next <-app.driverChannels[driverID[0]]. To achieve this, I've tried using a context with a timeout and calling stream.Recv() within a loop to wait for the client's response. However, stream.Recv() blocks, and the context's timeout never interrupts it.

I've tried removing the wg.Wait(), which causes the client to receive the prompt multiple times, and eventually it stops sending anything back to the client. I've also tried waiting only on processDriverDecision() with a single lock instead of two, which results in the same problem.

I've also tried a shared context with a cancel function, calling cancel() in processDriverDecision(). That function does time out after 35 seconds and calls cancel(), so I expected the ctx.Done() channel to fire in receiveDriverDecision(), but it did not.

Lastly, I've considered using two separate streaming RPCs instead of a single bidirectional stream, but that seems like bad design.

I'm at a loss and have no idea what to try next. I've dealt with blocking channels before, where a timer or ticker firing was enough to break out. ChatGPT hasn't been helpful here either.

I found this doc comment on RecvMsg(), which Recv() uses, so I'm wondering whether what I'm asking is even possible:

// RecvMsg blocks until it receives a message into m or the stream is
// done. It returns io.EOF when the client has performed a CloseSend. On
// any non-EOF error, the stream is aborted and the error contains the
// RPC status.

This loop is the entry point for both of these functions:

    app.driverChannelsLock.Lock()
    _, ok = app.driverChannels[driverID[0]]
    if !ok {
        app.driverChannels[driverID[0]] = make(chan interface{}, 2)
    }
    app.driverChannelsLock.Unlock()

    for {
        select {
        case o := <-app.driverChannels[driverID[0]]:

            log.Printf("RECEIVED an ORDER Type [StreamDispatchOrders]: %v\n", o.(models.Order).ID)
            order := o.(models.Order)

            sortedRoutes, err := app.MapOrderToRoutes(driverID[0], &order)
            if err != nil {
                log.Fatal().Msgf("Failed to MapOrderToRoutes(): %v\n", err)
            }

            routingResponse := []*models.DriverRoute{}

            if len(sortedRoutes) > 1 {
                routingResponse = app.GoogleRoutesAPI(sortedRoutes)
                log.Printf("ROUTING RESPONSE: %v\n", len(routingResponse))
            }

            driverRoutes := mappers.DriverRoutesToRPC(routingResponse)

            //distance := float64(routingResponse.Routes[0].GetDistanceMeters()) / metersInMile
            //minutes := routingResponse.Routes[0].GetDuration().AsDuration().Minutes()
            //log.Printf("DISTANCE (IN METERS): %v | MINUTES (AS DURATION): %v\n", distance, minutes)

            // Implement a delivery calculation function
            log.Printf("ORDER (ESTIMATED EARNINGS) :: %v\n", order.DriverPool.EstimatedEarnings)
            estimatedEarnings := (order.DriverPool.EstimatedEarnings + order.DriverPool.TipAmount) / 2

            for _, od := range driverRoutes {
                if od.Order.Id == order.ID.Hex() {

                    orderPrompt := &pb.OrderPromptRequest{
                        Routes: driverRoutes,
                        // TODO: Calculate Driver FARE beforehand
                        EstimatedEarnings: estimatedEarnings,
                        Distance:          od.Miles,
                        Minutes:           od.Minutes,
                        NewOrder:          mappers.OrderToRPC(&order),
                    }

                    err = stream.Send(&pb.DispatchOrderResponse{
                        Resp: &pb.DispatchOrderResponse_Response{Response: orderPrompt},
                    })

                    if err != nil {
                        log.Printf("Failed to send order: %v\n", err)
                        return status.Errorf(codes.Unavailable, "failed to send order")
                    }
                }
            }

            // Receives a drivers decision, whether by the driver or timeout
            ctx, cancel := context.WithCancel(context.Background())
            defer cancel()

            wg := new(sync.WaitGroup)
            wg.Add(2)

            received := make(chan Received)
            defer close(received)

            // Start a goroutine to perform the Recv operation
            go app.receiveDriverDecision(ctx, stream, wg, received, 1, order.ID, dID, order.DeliveryPhase)

            go app.processDriverDecision(ctx, cancel, stream, order.ID, dID, order.DeliveryPhase, wg, received, 1)

            wg.Wait()

            log.Printf("Receiver Goroutine: Channel closed. Exiting.")
        case <-stream.Context().Done():
            log.Printf("STREAM DISPATCH ORDERS :: CONTEXT DONE")
            return status.Errorf(codes.Canceled, "cancelled client")
        }
    }

Here's the current state of the receive function:

// receiveDriverDecision listens on the StreamDispatchOrders function stream for a decision (true or false) and sends it along to the
// processDriverDecision through the receivedChannel.
func (app *driverServer) receiveDriverDecision(ctx context.Context, stream pb.DriverService_StreamDispatchOrdersServer, group *sync.WaitGroup, receivedChannel chan Received, readWrite int, orderID primitive.ObjectID, driverID primitive.ObjectID, phase models.OrderPhase) {
    defer group.Done() // Ensure group.Done() is called to signal completion
    //for {
    // Create a context with a deadline of 35 seconds
    ctxchild, cancel := context.WithTimeout(stream.Context(), 35*time.Second)
    defer cancel()

        select {
        case <-ctxchild.Done():
            log.Printf("CANCELLATION RECEIVED (RECEIVE DRIVER DECISION) :: EXITING")
            return
        default:
            // Create a child context with a timeout of 35 seconds
            //childCtx, cancel := context.WithTimeout(ctx, 35*time.Second)
            //defer cancel() // Cancel the child context to release resources

            //log.Printf("BEFORE RECEIVE STREAM :: %v", childCtx)

            // Receive data from the stream with the timeout
            ordr, err := stream.Recv()
            if err == io.EOF {
                log.Printf("End of stream")
                return
            } else if err != nil {
                log.Printf("Error receiving from StreamDispatchOrder: %v", err)
                return
            }

            // Process the received order (your logic here)
            recvd := &Received{
                Order:    mappers.RPCToOrder(ordr.GetRequest().GetOrder()),
                DriverID: ordr.GetRequest().GetDriverId(),
                Decision: ordr.GetRequest().GetAcceptDecision(),
            }
            for i := 1; i <= readWrite; i++ {
                receivedChannel <- *recvd
            }
            log.Printf("RECEIVED WAS SENT")
        }
    //}

}

Here is processDriverDecision():

// processDriverDecision listens on the receivedChannel channel and waits for the driver decision from receiveDriverDecision, or it will
// time out after 35 seconds and send an empty message to the client to cancel the order prompt from the screen.
// If the message comes, we will assign the driver or not based on their decision, and insert their decision into driver stats in the assignment collection
func (app *driverServer) processDriverDecision(ctx context.Context, cancel context.CancelFunc, stream pb.DriverService_StreamDispatchOrdersServer, orderID primitive.ObjectID, driverID primitive.ObjectID, phase models.OrderPhase, group *sync.WaitGroup, receivedChannel chan Received, readWrite int) {
    defer group.Done()

    conn, err := app.GetOrderClient()
    if err != nil {
        return
    }

    orderClient := pb.NewOrderServiceClient(conn)

    select {
    // TODO: Implement a oneof approach where the client sends back an isInBackground-type
    // response, so we can send an APN/Google push notification instead.
    case msg := <-receivedChannel:
        estimatedEarnings := (msg.Order.DriverPool.EstimatedEarnings + msg.Order.DriverPool.TipAmount) / 2
        log.Printf("Received RECEIVED MESSAGE: %v", msg)

        orderAssigned, err := orderClient.GetOrderFromAssignedRecords(stream.Context(), &pb.GetOrderFromAssignedRequest{
            OrderId:       msg.Order.ID.Hex(),
            DeliveryPhase: pb.OrderPhase(msg.Order.DeliveryPhase),
        })

        if err != nil {
            log.Printf("Failed to get order from OrderService: %v\n", err)
            return
        }

        if !orderAssigned.GetAssigned() {
            if msg.Decision {
                updatedOrder, err := app.handleAssignDriverToOrder(stream.Context(), msg.DriverID, msg.Order)
                if err != nil {
                    return
                }

                routes, err := app.driver.AddRoutesForOrder(stream.Context(), updatedOrder, msg.DriverID)
                if err != nil {
                    log.Printf("Failed to add order to routes: %v\n", err)
                    return
                }

                //log.Printf("ROUTES: %v\n", routes)
                routes = app.GetSortedRoutes(routes[0].DriverLocation, routes)
                log.Printf("LENGTH OF ROUTES (accepted): %d", len(routes))

                

                //if len(routes) > 1 {
                //  routingResp = app.GoogleRoutesAPI(routes)
                //  log.Printf("ROUTING RESPONSE: %v\n", len(routingResp))
                //}

                newDriverRoutes := mappers.DriverRoutesToRPC(routes)

                //distance := float64(routingResponse.Routes[0].GetDistanceMeters()) / metersInMile
                //minutes := routingResponse.Routes[0].GetDuration().AsDuration().Minutes()

                for _, od := range newDriverRoutes {
                    if od.Order.Id == updatedOrder.Id {
                        log.Printf("MATCHED DRIVER ROUTE (Calculating order prompt)")
                        //estimatedEarnings := (od.Miles * 0.67) + (od.Minutes * 0.225) + order.Cost.DeliveryFee

                        err = stream.Send(&pb.DispatchOrderResponse{Resp: &pb.DispatchOrderResponse_AcceptedResponse{AcceptedResponse: &pb.AcceptedOrderResponse{
                            AssignedSuccessfully: true,
                            Response: &pb.OrderPromptResponse{
                                Routes:            newDriverRoutes,
                                EstimatedEarnings: estimatedEarnings,
                                Distance:          od.Miles,
                                Minutes:           od.Minutes,
                            },
                        }}})

                        if err != nil {
                            log.Printf("Failed to send order: %v\n", err)
                        }
                    }

                    log.Printf("DRIVER ROUTE (LOOP): %v | %v\n", od.Minutes, od.Miles)
                }

                select {
                case app.orderAssignedChan[msg.Order.ID.Hex()] <- true:
                    log.Printf("[StreamDispatchOrder] Sent accepted to channel to break out of dispatch loop")
                default:
                    log.Printf("[StreamDispatchOrder] Failed to send to break out of loop")
                }

            }
        } else {

            err = stream.Send(&pb.DispatchOrderResponse{Resp: &pb.DispatchOrderResponse_AcceptedResponse{AcceptedResponse: &pb.AcceptedOrderResponse{
                AssignedSuccessfully: false,
                Response: &pb.OrderPromptResponse{
                    Routes:            []*pb.Route{},
                    EstimatedEarnings: 0,
                    Distance:          0,
                    Minutes:           0,
                },
            }}})
            if err != nil {
                log.Printf("Failed to SEND (ASSIGNED UNSUCCESSFULLY): %v\n", err)
            }
        }
        //wg.Done()
        log.Printf("Recv operation completed successfully.")
    case <-time.After(35*time.Second):

        log.Printf("SESSION'S CONTEXT RAN OUT OF TIME (PROCESS DRIVER DECISION)")
        cancel()

    }

}
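The send to app.orderAssignedChan above is a non-blocking send: select with a default branch takes default whenever the send cannot proceed immediately. That includes two cases that are easy to miss: if the map has no entry for the order ID, the lookup yields a nil channel, and a send on a nil channel can never proceed; and if the channel is unbuffered with no receiver waiting at that instant, default also runs. The sketch uses a hypothetical trySend helper and a plain map in place of app.orderAssignedChan.

```go
package main

import "fmt"

// trySend attempts a non-blocking send, like the orderAssignedChan select.
func trySend(ch chan bool) bool {
	select {
	case ch <- true:
		return true
	default:
		return false
	}
}

func main() {
	chans := map[string]chan bool{
		"buffered":   make(chan bool, 1),
		"unbuffered": make(chan bool),
	}

	fmt.Println(trySend(chans["missing"]))    // false: a missing key yields a nil channel
	fmt.Println(trySend(chans["unbuffered"])) // false: no receiver is waiting right now
	fmt.Println(trySend(chans["buffered"]))   // true: there is room in the buffer
}
```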