Implementing a simple ForEachAsync, part 2

Stephen Toub - MSFT

After my previous post, I received several emails and comments from folks asking why I chose to implement ForEachAsync the way I did.  My goal with that post wasn’t to prescribe a particular approach to iteration, but rather to answer a specific question I’d received; obviously, however, I didn’t provide enough background. Let me take a step back, then, to put that post in context.

Iteration is a common development task, and there are many different variations on how iteration might be implemented.  For example, a basic synchronous ForEach might be implemented as follows:

public static void ForEach<T>(this IEnumerable<T> source, Action<T> body)
{
    foreach(var item in source) body(item);
}

That, however, encapsulates just one particular semantic, that of looping through the source, executing the action one element at a time, and stopping if an exception is thrown.  Here’s another implementation, this time continuing the processing even if an exception is thrown, propagating any exceptions only once we’re done with the whole loop:

public static void ForEach<T>(this IEnumerable<T> source, Action<T> body)
{
    List<Exception> exceptions = null;
    foreach(var item in source)
    {
        try { body(item); }
        catch(Exception exc)
        {
            if (exceptions == null) exceptions = new List<Exception>();
            exceptions.Add(exc);
        }

    }
    if (exceptions != null)
        throw new AggregateException(exceptions);
}
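To make the difference between these two semantics concrete, here’s a small self-contained sketch. The two loop methods are the implementations above under distinct names (ForEachFailFast and ForEachAggregate). SyncForEachDemo and Run are hypothetical names used only for illustration. Run drives either loop over the items 1 through 5 with a body that throws on even items. It then reports how many items were visited and how many exceptions reached the caller.

```csharp
using System;
using System.Collections.Generic;

public static class SyncForEachDemo
{
    // First variant above: the first exception thrown by body ends the loop.
    public static void ForEachFailFast<T>(this IEnumerable<T> source, Action<T> body)
    {
        foreach (var item in source) body(item);
    }

    // Second variant above: keep going, then throw one AggregateException at the end.
    public static void ForEachAggregate<T>(this IEnumerable<T> source, Action<T> body)
    {
        List<Exception> exceptions = null;
        foreach (var item in source)
        {
            try { body(item); }
            catch (Exception exc)
            {
                if (exceptions == null) exceptions = new List<Exception>();
                exceptions.Add(exc);
            }
        }
        if (exceptions != null)
            throw new AggregateException(exceptions);
    }

    // Hypothetical helper: runs one loop variant over 1..5 with a body that fails
    // on even items; returns items visited and exceptions seen by the caller.
    public static (int Visited, int Reported) Run(Action<IEnumerable<int>, Action<int>> forEach)
    {
        int visited = 0;
        try
        {
            forEach(new[] { 1, 2, 3, 4, 5 }, i =>
            {
                visited++;
                if (i % 2 == 0) throw new InvalidOperationException($"item {i}");
            });
            return (visited, 0);
        }
        catch (AggregateException ae) { return (visited, ae.InnerExceptions.Count); }
        catch (InvalidOperationException) { return (visited, 1); }
    }
}
```

With the fail-fast loop, only items 1 and 2 are visited and a single exception escapes. With the aggregating loop, all five items are visited and the caller sees one AggregateException holding the two failures.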

These are both synchronous examples.  Once asynchrony is introduced, additional variations are possible.  We can of course create asynchronous versions that match the two examples just shown, e.g.

public static async Task ForEachAsync<T>(this IEnumerable<T> source, Func<T,Task> body)
{
    foreach(var item in source) await body(item);
}

and:

public static async Task ForEachAsync<T>(this IEnumerable<T> source, Func<T,Task> body)
{
    List<Exception> exceptions = null;
    foreach(var item in source)
    {
        try { await body(item); }
        catch(Exception exc)
        {
            if (exceptions == null) exceptions = new List<Exception>();
            exceptions.Add(exc);
        }

    }
    if (exceptions != null)  
        throw new AggregateException(exceptions);
}

respectively. But we can also go beyond this.  Once we’re able to launch work asynchronously, we can achieve concurrency and parallelism, invoking the body for each element and waiting on them all at the end, rather than waiting for each in turn, e.g.

public static Task ForEachAsync<T>(this IEnumerable<T> source, Func<T,Task> body)
{
    return Task.WhenAll(
        from item in source
        select body(item));
}

This serially invokes all of the body delegates, but it allows any continuations used in the bodies to run concurrently (depending on whether we’re in a serializing SynchronizationContext and whether the code in the body delegate is forcing continuations back to that context).  We could force more parallelism by wrapping each body invocation in a Task:

public static Task ForEachAsync<T>(this IEnumerable<T> source, Func<T, Task> body)
{
    return Task.WhenAll(
        from item in source
        select Task.Run(() => body(item)));
}

This will schedule a Task to invoke the body for each item and will then asynchronously wait for all of the async invocations to complete.  Note that this also means that the code run by the body delegate won’t be forced back to the current SynchronizationContext, even if there is one, since the async invocations are occurring on ThreadPool threads where there is no SynchronizationContext set.
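Here’s a sketch of one way to observe that difference, under two assumptions. First, a plain SynchronizationContext instance stands in for something like a UI context. The base class doesn’t serialize anything, so this shows only which bodies can see the context, not any ordering or deadlock behavior. Second, the bodies complete synchronously, so blocking with Wait() is safe in this demo. ContextDemo, ForEachDirectAsync, ForEachRunAsync, and CountBodiesSeeingContext are hypothetical names. The two loop methods are the two WhenAll-based variants above.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class ContextDemo
{
    // The plain WhenAll variant: each body is invoked directly on the calling thread.
    public static Task ForEachDirectAsync<T>(this IEnumerable<T> source, Func<T, Task> body)
    {
        return Task.WhenAll(from item in source select body(item));
    }

    // The Task.Run variant: each body is queued to the ThreadPool.
    public static Task ForEachRunAsync<T>(this IEnumerable<T> source, Func<T, Task> body)
    {
        return Task.WhenAll(from item in source select Task.Run(() => body(item)));
    }

    // Hypothetical helper: installs a plain SynchronizationContext on the calling
    // thread, runs 'count' bodies through each variant, and counts how many bodies
    // observed a non-null SynchronizationContext.Current when they ran.
    public static (int Direct, int ViaTaskRun) CountBodiesSeeingContext(int count)
    {
        var previous = SynchronizationContext.Current;
        SynchronizationContext.SetSynchronizationContext(new SynchronizationContext());
        try
        {
            int direct = 0, viaTaskRun = 0;

            // Blocking with Wait() is safe only because these bodies complete
            // synchronously and never need to post back to the context.
            Enumerable.Range(0, count).ForEachDirectAsync(i =>
            {
                if (SynchronizationContext.Current != null) Interlocked.Increment(ref direct);
                return Task.CompletedTask;
            }).Wait();

            Enumerable.Range(0, count).ForEachRunAsync(i =>
            {
                if (SynchronizationContext.Current != null) Interlocked.Increment(ref viaTaskRun);
                return Task.CompletedTask;
            }).Wait();

            return (direct, viaTaskRun);
        }
        finally
        {
            // Restore whatever context the caller had.
            SynchronizationContext.SetSynchronizationContext(previous);
        }
    }
}
```

Every body run through ForEachDirectAsync sees the installed context because its synchronous portion executes on the calling thread. No body run through ForEachRunAsync sees it because each one starts on a ThreadPool thread.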

We could further expand on this if we wanted to limit the number of operations that are able to run in parallel.  One way to achieve that is to partition the input data set into N partitions, where N is the desired maximum degree of parallelism, and schedule a separate task to begin the execution for each partition (this uses the Partitioner class from the System.Collections.Concurrent namespace):

public static Task ForEachAsync<T>(this IEnumerable<T> source, int dop, Func<T, Task> body)
{
    return Task.WhenAll(
        from partition in Partitioner.Create(source).GetPartitions(dop)
        select Task.Run(async delegate {
            using (partition)
                while (partition.MoveNext())
                    await body(partition.Current);
        }));
}

This last example is similar in nature to Parallel.ForEach, with the primary difference being that Parallel.ForEach is a synchronous method and uses synchronous delegates.
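To see that comparison in practice, here’s a sketch that measures the peak number of concurrently running bodies in two cases: the partitioned ForEachAsync above, and Parallel.ForEach with ParallelOptions.MaxDegreeOfParallelism. DopDemo, MeasureAsync, MeasureParallel, and UpdateMax are hypothetical names. The 5 ms Task.Delay and Thread.Sleep calls are assumed stand-ins for real asynchronous and synchronous work.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class DopDemo
{
    // The partitioned ForEachAsync from above: dop partitions, each processed sequentially.
    public static Task ForEachAsync<T>(this IEnumerable<T> source, int dop, Func<T, Task> body)
    {
        return Task.WhenAll(
            from partition in Partitioner.Create(source).GetPartitions(dop)
            select Task.Run(async delegate
            {
                using (partition)
                    while (partition.MoveNext())
                        await body(partition.Current);
            }));
    }

    // Asynchronous bodies: at most dop are in flight, and none holds a thread
    // while awaiting Task.Delay.
    public static async Task<(int Peak, int Calls)> MeasureAsync(int items, int dop)
    {
        int inFlight = 0, peak = 0, calls = 0;
        await Enumerable.Range(0, items).ForEachAsync(dop, async i =>
        {
            UpdateMax(ref peak, Interlocked.Increment(ref inFlight));
            await Task.Delay(5);
            Interlocked.Increment(ref calls);
            Interlocked.Decrement(ref inFlight);
        });
        return (peak, calls);
    }

    // Synchronous bodies: Parallel.ForEach blocks the caller until every item is
    // done, with MaxDegreeOfParallelism playing the role of dop.
    public static (int Peak, int Calls) MeasureParallel(int items, int dop)
    {
        int inFlight = 0, peak = 0, calls = 0;
        var options = new ParallelOptions { MaxDegreeOfParallelism = dop };
        Parallel.ForEach(Enumerable.Range(0, items), options, i =>
        {
            UpdateMax(ref peak, Interlocked.Increment(ref inFlight));
            Thread.Sleep(5); // blocks this thread for the duration
            Interlocked.Increment(ref calls);
            Interlocked.Decrement(ref inFlight);
        });
        return (peak, calls);
    }

    // Raises target to candidate if candidate is larger, without taking a lock.
    static void UpdateMax(ref int target, int candidate)
    {
        int current;
        while ((current = Volatile.Read(ref target)) < candidate &&
               Interlocked.CompareExchange(ref target, candidate, current) != current)
        {
        }
    }
}
```

In both cases the peak should never exceed dop. The difference is that the asynchronous version doesn’t hold a thread while a body is waiting, whereas each Parallel.ForEach body blocks its thread for the full Thread.Sleep.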

The point is that there are many different semantics possible for iteration, and each will result in different design choices and implementations.  The ForEachAsync example from my previous post was just one more such variation, accounting for the behavior that I’d been asked about.  As should now hopefully be obvious from this post, it is in no way the only way to iterate asynchronously.
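As one more illustration of mixing and matching, here’s a sketch that combines two of the semantics above. It bounds concurrency via partitioning, and it also continues past failures and reports them all at the end. CombinedDemo and RunAsync are hypothetical names. Using a ConcurrentQueue to collect exceptions from partitions that may run in parallel is just one assumed approach.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class CombinedDemo
{
    // Hypothetical variant: at most dop bodies in flight, a failing body doesn't
    // end its partition, and all failures are thrown together at the end.
    public static async Task ForEachAsync<T>(this IEnumerable<T> source, int dop, Func<T, Task> body)
    {
        var exceptions = new ConcurrentQueue<Exception>();
        await Task.WhenAll(
            from partition in Partitioner.Create(source).GetPartitions(dop)
            select Task.Run(async delegate
            {
                using (partition)
                    while (partition.MoveNext())
                    {
                        // Record the failure and move on to this partition's next item.
                        try { await body(partition.Current); }
                        catch (Exception exc) { exceptions.Enqueue(exc); }
                    }
            }));
        if (!exceptions.IsEmpty)
            throw new AggregateException(exceptions);
    }

    // Hypothetical helper: every fourth item fails after yielding; returns the
    // number of items visited and the number of failures the caller observed.
    public static async Task<(int Visited, int Failures)> RunAsync(int items, int dop)
    {
        int visited = 0;
        try
        {
            await Enumerable.Range(1, items).ForEachAsync(dop, async i =>
            {
                Interlocked.Increment(ref visited);
                await Task.Yield();
                if (i % 4 == 0) throw new InvalidOperationException($"item {i}");
            });
            return (visited, 0);
        }
        catch (AggregateException ae)
        {
            return (visited, ae.InnerExceptions.Count);
        }
    }
}
```

ForEachAsync here is an async method that throws an AggregateException, so awaiting it rethrows that same AggregateException. The caller can therefore inspect InnerExceptions to see every failure: for twelve items, that’s three failures, from items 4, 8, and 12. The same holds when awaiting the exception-aggregating ForEachAsync shown earlier.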

Thanks for all the interest.

2 comments


  • Marc Selman

    I’ve added an optional CancellationToken so the loop can be cancelled when, for example, an exception occurs.

    public static Task ForEachAsync<T>(this IEnumerable<T> source, int dop, Func<T, Task> body, CancellationToken? ct = null)
    {
    	return Task.WhenAll(
    		from partition in Partitioner.Create(source).GetPartitions(dop)
    		select Task.Run(async delegate
    		{
    			using (partition)
    				while ((ct == null || !ct.Value.IsCancellationRequested) && partition.MoveNext())
    					await body(partition.Current);
    		}));
    }

    Then you can use it like this:

    var cts = new CancellationTokenSource();
    await items.ForEachAsync(20, async item =>
    {
    	try
    	{
    		await PerformAsyncTask(item);
    	}
    	catch (Exception ex)
    	{
    		cts.Cancel();
    	}
    }, cts.Token);
    
    if (cts.IsCancellationRequested)
    {
    	// An exception occurred and the ForEach loop was cancelled
    }
  • Alastair Crabtree

    Thanks for this post; it was useful. Another variant I found helpful was for when I needed all exceptions aggregated, not just the first one (which is the default), as shown below.

    public static Task ForEachAsync<T>(
        this IEnumerable<T> source, int dop, Func<T, Task> body)
    {
        var exceptions = new ConcurrentBag<AggregateException>();

        void ObserveException(Task task)
        {
            if (task.Exception != null)
            {
                exceptions.Add(task.Exception);
            }
        }

        void RaiseExceptions(Task _)
        {
            if (exceptions.Any())
                throw (exceptions.Count == 1 ? exceptions.Single() : new AggregateException(exceptions))
                    .Flatten();
        }

        return Task.WhenAll(
                from partition in Partitioner.Create(source).GetPartitions(dop)
                select Task.Run(async delegate
                {
                    using (partition)
                        while (partition.MoveNext())
                            await body(partition.Current)
                                .ContinueWith(ObserveException);
                }))
            .ContinueWith(ObserveException)
            .ContinueWith(RaiseExceptions);
    }
