How to batch an IAsyncEnumerable<T>, enforcing a maximum interval policy between consecutive batches?


I have an asynchronous sequence (stream) of messages that sometimes arrive in bursts and sometimes sporadically, and I would like to process them in batches of 10 messages per batch. I also want to enforce an upper limit on the latency between receiving a message and processing it, so a batch with fewer than 10 messages should also be processed if 5 seconds have passed since receiving the first message of the batch. I found that I can solve the first part of the problem by using the Buffer operator from the System.Interactive.Async package:

IAsyncEnumerable<Message> source = GetStreamOfMessages();
IAsyncEnumerable<IList<Message>> batches = source.Buffer(10);
await foreach (IList<Message> batch in batches)
{
    // Process batch
}

The signature of the Buffer operator:

public static IAsyncEnumerable<IList<TSource>> Buffer<TSource>(
    this IAsyncEnumerable<TSource> source, int count);

Unfortunately the Buffer operator has no overload with a TimeSpan parameter, so I can't solve the second part of the problem as easily. I'll have to implement a batching operator with a timer myself. My question is: how can I implement a variant of the Buffer operator that has the signature below?

public static IAsyncEnumerable<IList<TSource>> Buffer<TSource>(
    this IAsyncEnumerable<TSource> source, TimeSpan timeSpan, int count);

The timeSpan parameter should affect the behavior of the Buffer operator like so:

  1. A batch must be emitted when the timeSpan has elapsed after emitting the previous batch (or initially after the invocation of the Buffer method).
  2. An empty batch must be emitted if the timeSpan has elapsed after emitting the previous batch, and no messages have been received during this time.
  3. Emitting batches more frequently than every timeSpan implies that the batches are full. Emitting a batch with fewer than count messages before the timeSpan has elapsed is not desirable.

I am OK with adding external dependencies to my project if needed, like the System.Interactive.Async or the System.Linq.Async packages.

P.S. This question was inspired by a recent question related to channels and memory leaks.

There are 2 answers below.

Accepted Answer

The solution below uses the PeriodicTimer class (.NET 6) for receiving timer notifications, and the Task.WhenAny method for coordinating the timer and enumeration tasks. The PeriodicTimer class is more convenient than the Task.Delay method for this purpose, because it can be disposed directly, instead of requiring an accompanying CancellationTokenSource.
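To make that convenience concrete, here is a minimal standalone sketch (not part of the operator below): disposing a PeriodicTimer completes a pending WaitForNextTickAsync with false, so no accompanying CancellationTokenSource is needed to unblock the wait.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

class PeriodicTimerDisposalDemo
{
    static async Task Main()
    {
        // A long period, so the tick cannot fire on its own during the demo.
        var timer = new PeriodicTimer(TimeSpan.FromMinutes(1));
        Task<bool> tick = timer.WaitForNextTickAsync().AsTask();

        // Disposing the timer completes the pending wait with false,
        // which is how a stale timer can be stopped directly.
        timer.Dispose();
        Console.WriteLine(await tick); // False
    }
}
```

Achieving the same with Task.Delay would require creating, cancelling and disposing a CancellationTokenSource, plus swallowing the resulting TaskCanceledException.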

/// <summary>
/// Splits the elements of a sequence into chunks that are emitted when either
/// they are full, or a given amount of time has elapsed after requesting the
/// previous chunk.
/// </summary>
public static async IAsyncEnumerable<IList<TSource>> Buffer<TSource>(
    this IAsyncEnumerable<TSource> source, TimeSpan timeSpan, int count,
    [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
    ArgumentNullException.ThrowIfNull(source);
    if (timeSpan < TimeSpan.FromMilliseconds(1.0))
        throw new ArgumentOutOfRangeException(nameof(timeSpan));
    if (count < 1) throw new ArgumentOutOfRangeException(nameof(count));

    using CancellationTokenSource linkedCts = CancellationTokenSource
        .CreateLinkedTokenSource(cancellationToken);
    PeriodicTimer timer = null;
    Task<bool> StartTimer()
    {
        timer = new(timeSpan);
        return timer.WaitForNextTickAsync().AsTask();
    }
    IAsyncEnumerator<TSource> enumerator = source
        .GetAsyncEnumerator(linkedCts.Token);
    Task<bool> moveNext = null;
    try
    {
        List<TSource> buffer = new();
        TSource[] ConsumeBuffer()
        {
            timer?.Dispose();
            TSource[] array = buffer.ToArray();
            buffer.Clear();
            if (buffer.Capacity > count) buffer.Capacity = count;
            return array;
        }
        Task<bool> timerTickTask = StartTimer();
        while (true)
        {
            if (moveNext is null)
            {
                if (timerTickTask.IsCompleted)
                {
                    Debug.Assert(timerTickTask.Result);
                    yield return ConsumeBuffer();
                    timerTickTask = StartTimer();
                }
                moveNext = enumerator.MoveNextAsync().AsTask();
            }
            if (!moveNext.IsCompleted)
            {
                Task completedTask = await Task.WhenAny(moveNext, timerTickTask)
                    .ConfigureAwait(false);
                if (ReferenceEquals(completedTask, timerTickTask))
                {
                    Debug.Assert(timerTickTask.IsCompleted);
                    Debug.Assert(timerTickTask.Result);
                    yield return ConsumeBuffer();
                    timerTickTask = StartTimer();
                    continue;
                }
            }
            Debug.Assert(moveNext.IsCompleted);
            bool moved = await moveNext.ConfigureAwait(false);
            moveNext = null;
            if (!moved) break;
            TSource item = enumerator.Current;
            buffer.Add(item);
            if (buffer.Count == count)
            {
                yield return ConsumeBuffer();
                timerTickTask = StartTimer();
            }
        }
        if (buffer.Count > 0) yield return ConsumeBuffer();
    }
    finally
    {
        // Cancel the enumerator, for more responsive completion.
        try { linkedCts.Cancel(); }
        finally
        {
            // The last moveNext must be completed before disposing.
            if (moveNext is not null && !moveNext.IsCompleted)
                await Task.WhenAny(moveNext).ConfigureAwait(false);
            await enumerator.DisposeAsync().ConfigureAwait(false);
            timer?.Dispose();
        }
    }
}

The timer is restarted each time a chunk is emitted, after the consumer has finished consuming the chunk.
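The core coordination can also be seen in isolation. The following standalone sketch (with made-up names and timings, not part of the operator above) races a pending element-producing task against the timer tick with Task.WhenAny; whichever completes first decides whether a batch is emitted:

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

class WhenAnyRaceDemo
{
    static async Task Main()
    {
        using var timer = new PeriodicTimer(TimeSpan.FromMilliseconds(50));
        Task<bool> tick = timer.WaitForNextTickAsync().AsTask();

        // Stands in for a MoveNextAsync call that takes longer than the period.
        Task<bool> moveNext = SlowMoveNextAsync();

        Task completed = await Task.WhenAny(moveNext, tick);
        // The timer wins the race here, so a (possibly partial) batch
        // would be emitted without waiting for the next element.
        Console.WriteLine(completed == tick ? "tick first" : "element first");
    }

    static async Task<bool> SlowMoveNextAsync()
    {
        await Task.Delay(TimeSpan.FromSeconds(5));
        return true;
    }
}
```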

Online demo.

This implementation is destructive, meaning that in case the source sequence fails or the enumeration is canceled, any elements that have been consumed previously from the source and are buffered, will be lost. See this question for ideas about how to inject a non-destructive behavior.

Care has been taken to avoid leaking fire-and-forget MoveNextAsync operations or timers.

For an implementation that uses the Task.Delay method instead of the PeriodicTimer class, and so can be used on .NET versions earlier than 6.0, you can look at the 7th revision of this answer. That revision also includes a tempting but flawed Rx-based implementation.

Second Answer

What about using a Channel to achieve the required functionality? Is there any flaw in using something like this extension method to read from a queue until a timeout has expired?

public static async Task<List<T>> ReadWithTimeoutAsync<T>(this ChannelReader<T> reader, TimeSpan readTOut, CancellationToken cancellationToken)
{
    // Created with the timeout baked in, and disposed on all paths.
    using var timeoutTokenSrc = new CancellationTokenSource(readTOut);

    var messages = new List<T>();

    using (CancellationTokenSource linkedCts =
        CancellationTokenSource.CreateLinkedTokenSource(timeoutTokenSrc.Token, cancellationToken))
    {
        try
        {
            await foreach (var item in reader.ReadAllAsync(linkedCts.Token))
            {
                messages.Add(item);
                // Checked explicitly, because ReadAllAsync can drain already
                // buffered items without observing the token between them.
                linkedCts.Token.ThrowIfCancellationRequested();
            }

            Console.WriteLine("All messages read.");
        }
        catch (OperationCanceledException)
        {
            if (timeoutTokenSrc.Token.IsCancellationRequested)
            {
                Console.WriteLine($"Delay ({readTOut.TotalMilliseconds} msec) for reading items from message channel has expired.");
            }
            else if (cancellationToken.IsCancellationRequested)
            {
                Console.WriteLine("Cancelling per user request.");
                cancellationToken.ThrowIfCancellationRequested();
            }
        }
    }

    return messages;
}

To combine the timeout with the max. batch size, one more token source could be added:

public static async Task<List<T>> ReadBatchWithTimeoutAsync<T>(this ChannelReader<T> reader, int maxBatchSize, TimeSpan readTOut, CancellationToken cancellationToken)
{
    // Both token sources are disposed on all paths via using declarations.
    using var timeoutTokenSrc = new CancellationTokenSource(readTOut);
    using var maxSizeTokenSrc = new CancellationTokenSource();

    var messages = new List<T>();

    using (CancellationTokenSource linkedCts =
        CancellationTokenSource.CreateLinkedTokenSource(timeoutTokenSrc.Token, maxSizeTokenSrc.Token, cancellationToken))
    {
        try
        {
            await foreach (var item in reader.ReadAllAsync(linkedCts.Token))
            {
                messages.Add(item);
                if (messages.Count >= maxBatchSize)
                {
                    maxSizeTokenSrc.Cancel();
                }
                linkedCts.Token.ThrowIfCancellationRequested();
            }....