How to throw a cancellation exception in an async ForEach implementation?


I have a custom async for each implementation that is defined and used as follows:

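using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

// Extension method: must be declared inside a non-generic static class.
// Runs body over source with partitionCount concurrent workers.
public static Task ForEachAsync<T>(this IEnumerable<T> source, int partitionCount, Func<T, Task> body)
{
    return Task.WhenAll(
        // One partition per worker; each worker awaits body for its items in sequence.
        from partition in Partitioner.Create(source).GetPartitions(partitionCount)
        select Task.Run(async delegate
        {
            using (partition)
            {
                while (partition.MoveNext())
                {
                    await body(partition.Current).ConfigureAwait(false);
                }
            }
        })
    );
}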

...

List<long> ids = new List...

await ids.ForEachAsync(8,
    async (id) =>
    {
        await myTask(id);
    }
);

This works great, but now I need to modify this to allow for a cancellation token to be passed in. I have tried something as simple as this:

List<long> ids = new List...

await ids.ForEachAsync(8,
    async (id) =>
    {
        myToken.ThrowIfCancellationRequested();
        await myTask(id);
    }
);

But this fails ungracefully. Rather than an OperationCanceledException bubbling up, as I expected, I receive an exception thrown by one of the threads as a result of the cancellation. I have also tried passing the token into the async extension method, but that didn't seem to work either. Can someone please provide guidance on how this should be done? Thanks.

Best answer

To get the exception to bubble up, you need to pass the token in to Task.Run. It only takes a small modification to your code:

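using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static Task ForEachAsync<T>(this IEnumerable<T> source, int partitionCount, Func<T, Task> body, CancellationToken token = default(CancellationToken))
{
    // WhenAll completes as Canceled if any worker is canceled and none faulted.
    return Task.WhenAll(
        from partition in Partitioner.Create(source).GetPartitions(partitionCount)
        select Task.Run(async delegate
        {
            using (partition)
            {
                while (partition.MoveNext())
                {
                    await body(partition.Current).ConfigureAwait(false);
                }
            }
        }, token) //token passed in: a worker whose token is already canceled never starts
    );
}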
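Task.Run uses the token to avoid starting the delegate; once the delegate is running, cancellation has to be observed by the code inside it (for example the ThrowIfCancellationRequested call in the body). A minimal sketch of the "never starts" case, assuming C# 8 or later: with an already-canceled token, Task.Run returns a Canceled task without running the delegate, and awaiting it throws TaskCanceledException, which derives from OperationCanceledException. The class and method names are hypothetical.

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class TaskRunCancellationDemo
{
    // Calls Task.Run with a token that is already canceled and reports
    // whether the delegate ran, the task's final status, and whether
    // awaiting it threw an OperationCanceledException.
    public static async Task<(bool ran, TaskStatus status, bool threwOce)> RunWithCanceledTokenAsync()
    {
        using var cts = new CancellationTokenSource();
        cts.Cancel(); // cancel before the worker is scheduled

        bool ran = false;
        Task worker = Task.Run(async () =>
        {
            ran = true; // not reached: Task.Run sees the canceled token first
            await Task.Delay(10);
        }, cts.Token);

        bool threwOce = false;
        try
        {
            await worker;
        }
        catch (OperationCanceledException)
        {
            // Awaiting a canceled task throws TaskCanceledException,
            // which derives from OperationCanceledException.
            threwOce = true;
        }

        return (ran, worker.Status, threwOce);
    }

    public static async Task Main()
    {
        var (ran, status, threwOce) = await RunWithCanceledTokenAsync();
        Console.WriteLine($"delegate ran: {ran}, status: {status}, OCE: {threwOce}");
    }
}
```

Run as a console program, it prints "delegate ran: False, status: Canceled, OCE: True".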

The modified ForEachAsync is then used like this:

await ids.ForEachAsync(8,
    async (id) =>
    {
        myToken.ThrowIfCancellationRequested();
        await myTask(id);
    },
    myToken //token passed in
);
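A possible variation, not part of the answer above: a hypothetical overload whose body receives the token, with each worker checking it before processing its next item, so the extension method honors cancellation even when the body does not. The demo cancels after 20 items have started and shows the cancellation surfacing as an OperationCanceledException at the await. CancellableForEach, CancellationDemo, RunAsync and the Task.Delay stand-in for myTask are assumptions for illustration; C# 8 or later is assumed.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class CancellableForEach
{
    // Hypothetical overload: the body receives the token, and every worker
    // checks it before processing its next item.
    public static Task ForEachAsync<T>(this IEnumerable<T> source, int partitionCount,
        Func<T, CancellationToken, Task> body, CancellationToken token = default)
    {
        return Task.WhenAll(
            from partition in Partitioner.Create(source).GetPartitions(partitionCount)
            select Task.Run(async delegate
            {
                using (partition)
                {
                    while (partition.MoveNext())
                    {
                        // Stop this worker as soon as cancellation is requested.
                        token.ThrowIfCancellationRequested();
                        await body(partition.Current, token).ConfigureAwait(false);
                    }
                }
            }, token));
    }
}

public static class CancellationDemo
{
    // Cancels after 20 items have started and reports whether the await
    // ended with an OperationCanceledException, plus how many items started.
    public static async Task<(bool canceled, int started)> RunAsync()
    {
        List<long> ids = Enumerable.Range(1, 1000).Select(i => (long)i).ToList();
        using var cts = new CancellationTokenSource();
        int started = 0;

        try
        {
            await ids.ForEachAsync(8, async (id, ct) =>
            {
                if (Interlocked.Increment(ref started) == 20)
                    cts.Cancel(); // simulate the caller requesting cancellation
                await Task.Delay(1, ct); // stand-in for myTask(id); honors the token
            }, cts.Token);
            return (false, started);
        }
        catch (OperationCanceledException)
        {
            return (true, started);
        }
    }

    public static async Task Main()
    {
        var (canceled, started) = await RunAsync();
        Console.WriteLine($"canceled: {canceled}, items started: {started}");
    }
}
```

Because Task.WhenAll reports Canceled only when no worker faulted, any other exception thrown by myTask during cancellation will still surface as that exception rather than as an OperationCanceledException; catching OperationCanceledException around the await only covers the clean case.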