How to write a custom awaiter

[This post is part of a series How to await a storyboard, and other things]

 

The normal behavior of the "await" operator on a task is to suspend execution of the method; then, when the task operand has finished, to resume execution on the same SynchronizationContext and with the same ExecutionContext.

You can replace the resumption half of that behavior (everything that happens once the task operand has finished) with your own logic by awaiting an operand of your own type (not a task) and writing a custom awaiter for it. It's not a common need. Stephen Toub has already written an excellent article that covers this topic ("Await anything"). He explains that a custom awaiter is for when "you need full control over how (rather than when) the method resumes". I wanted to build on his article from a compiler-writer's perspective and with some more example code. Here are some examples I've seen of custom awaiters.

  • Await task1.ConfigureAwait(false) . Here execution doesn't make the effort to go back to the prior synchronization context, and often resumes on whatever thread the task happened to signal its completion upon. This is an efficiency boost, since it's faster than having to post back onto the prior SynchronizationContext.
  • Await Task.Yield() . Here there isn't an operand that finishes, as such. Instead this awaiter schedules execution to resume shortly on the same SynchronizationContext for some vague meaning of "shortly"... (Actually, the meaning of "shortly" is so vague that the only predictable use of this API is to force a suspension of an async method).

The above two scenarios were thought common enough that they're provided in the .NET 4.5 framework. The rest aren't...

  • Await HibernateAsync("a.hib") . I wrote an awaiter like this which didn't have any useful resumption. Instead it used the fact that a custom awaiter can (undocumentedly) use reflection to get at the current state of the async method, and thereby serialize it to disk.
  • Await task1.OnCoroutine(crm1) . I saw some code which bypassed the SynchronizationContext mechanism completely. Instead it used its own co-routine managers: a set of lists-of-actions, which they executed in round-robin fashion.
  • Await SwitchToThreadpool() / SwitchToDispatcher(dispatcher). I saw some code which used these two primitives to switch between threadpool and UI thread, as an alternative to the more common "Await Task.Run(...)" and "Await Dispatcher.RunAsync(...)" which both involve lambdas. It came from a wish to expunge all lambdas from a user's code.
  • Await SwitchPriority(High) . Similar to the above example, it might be useful to start executing at a higher priority.
  • Dim x = Await thread1.Enter() : ... : Await x.Leave() . I thought of this as a synchronization primitive. The idea is that most of my code runs happily on multiple threads, but I block off certain critical sections which must be executed on a particular thread maybe because they use thread-local-storage.
  • Await task1.Log() . Here I was trying to track down a bug where Await wasn't resuming. I wrote a custom awaiter which wrote to a debug log at every stage along the way of suspending and resuming an async method.
  • Await task1.WithCulture() . Normally, CurrentThread.CurrentCulture doesn't flow across await points. You can write a custom awaiter which does make it flow.
  • Await task1.TemporarilyConfigureAwait(false) . Maybe you might want the efficiency-boost of task.ConfigureAwait(false), but you still want it to restore SynchronizationContext.Current so that subsequent awaits will return to it.
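
To make one of these concrete, here is a minimal sketch of what a SwitchToThreadpool() awaiter could look like. It is not the code I saw; the names ThreadPoolSwitcher and Switchers are hypothetical, and the sketch assumes it's acceptable for the awaitable to act as its own awaiter.

```vb
Imports System
Imports System.Runtime.CompilerServices
Imports System.Threading

' Hypothetical awaitable: awaiting it resumes the async method on a threadpool thread.
Public Structure ThreadPoolSwitcher
    Implements INotifyCompletion

    ' The awaitable is its own awaiter.
    Public Function GetAwaiter() As ThreadPoolSwitcher
        Return Me
    End Function

    ' Always report "not completed", so the async method always suspends
    ' and OnCompleted is always called.
    Public ReadOnly Property IsCompleted As Boolean
        Get
            Return False
        End Get
    End Property

    ' Invoke the continuation exactly once, on the threadpool.
    Public Sub OnCompleted(continuation As Action) Implements INotifyCompletion.OnCompleted
        ThreadPool.QueueUserWorkItem(Sub(state) continuation())
    End Sub

    ' Nothing to return, so "Await SwitchToThreadpool()" is used as a statement.
    Public Sub GetResult()
    End Sub
End Structure

Public Module Switchers
    Public Function SwitchToThreadpool() As ThreadPoolSwitcher
        Return New ThreadPoolSwitcher()
    End Function
End Module
```

Inside an async method you would write "Await SwitchToThreadpool()" and the rest of the method would carry on from a threadpool thread, with no lambda in the user's code.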

Here's some background reading to explain the concepts...

What the compiler requires from a custom awaiter

When you use the Await operator, the compiler expands it out into something a bit like this:

' ORIGINAL CODE
Dim r = Await t

' APPROXIMATELY WHAT THE COMPILER GENERATES
Dim tmp = t.GetAwaiter()
If Not tmp.IsCompleted Then
    Dim ec = ExecutionContext.Capture()
    CType(tmp, INotifyCompletion).OnCompleted(Sub() ExecutionContext.Run(ec, K1))
    Return
End If
K1:
Dim r = tmp.GetResult()

' Some things not shown in the above simplification:
' The two lines that capture "ec" and call OnCompleted are executed from within a routine inside mscorlib
' The argument "K1" represents a delegate that, when executed, resumes execution at the label K1.
' If "tmp" is a struct, any mutations after IsCompleted before GetResult may be lost.
' If tmp implements ICriticalNotifyCompletion, then it calls tmp.UnsafeOnCompleted instead.
' The variable "ec" gets disposed at the right time inside the Sub()
' If ec is null, then it invokes K1 directly instead of within ExecutionContext.Run
' There are optimizations for the case where ExecutionContext is unmodified.


From the compiler perspective, here's the bare minimum that you need to implement to be able to compile "Await t":

  1. Dim tmp = t.GetAwaiter() must compile, and GetAwaiter must be either an instance method which takes no parameters, or an extension method which takes just one "this" parameter. (From this you can deduce that GetAwaiter cannot take optional parameters, and can take generic type parameters only if it's an extension method).
      
  2. Dim b = tmp.IsCompleted must compile, and IsCompleted must be a property which takes no parameters and returns a Boolean. (From this you can deduce that IsCompleted must be an instance property).
      
  3. tmp must implement INotifyCompletion. It can optionally also implement ICriticalNotifyCompletion. These must be real CLR interface implementations: the compiler will not use dynamic casts or user-defined conversions. If the delegate passed to OnCompleted/UnsafeOnCompleted is executed zero times then the async method will never resume. If it is executed more than once, then the behavior of the async method is unspecified.
      
  4. tmp.GetResult() must compile, and GetResult must be an instance method which takes no parameters. If GetResult is a function, then this function's return-value is used as the result of the Await operator. If GetResult is a sub (i.e. void-returning), then the Await operator returns void.
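
As an illustration of rule 1, here is a small sketch in the spirit of the "Await anything" article: an extension GetAwaiter that makes a TimeSpan awaitable. The module names are hypothetical. Rules 2 to 4 are satisfied by reusing the framework's own TaskAwaiter, so nothing else has to be written.

```vb
Imports System
Imports System.Runtime.CompilerServices
Imports System.Threading.Tasks

Public Module TimeSpanAwaitExtensions
    ' Rule 1: GetAwaiter may be an extension method that takes only the "this" parameter.
    ' The returned TaskAwaiter already has IsCompleted, OnCompleted and GetResult.
    <Extension>
    Public Function GetAwaiter(delay As TimeSpan) As TaskAwaiter
        Return Task.Delay(delay).GetAwaiter()
    End Function
End Module

Public Module TimeSpanAwaitDemo
    Public Async Function PauseAsync() As Task
        ' Compiles only because the extension method above is in scope.
        Await TimeSpan.FromMilliseconds(50)
    End Function
End Module
```

Because GetResult on the non-generic TaskAwaiter is a Sub, the Await here can only be used as a statement, which is rule 4 in action.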
      

If you do a late-bound await then the rules are a little different. A late-bound await is "await t" where t has type dynamic (C#) or Object (VB with Option Strict Off).

1. Dim tmp = t.GetAwaiter() must compile as a late-bound invocation, and it will use the normal late-binder rules: it works with optional parameters, and IDynamicMetaObjectProvider, and even with a delegate field named GetAwaiter.

2. Dim b = tmp.IsCompleted must compile as a normal late-bound property access.

3. tmp must implement INotifyCompletion or ICriticalNotifyCompletion. As before, it must implement them through CLR interface implementation; not through dynamic casts.

4. tmp.GetResult() must compile as a normal late-bound invocation, and the result of the late-bound invocation is the result of the late-bound Await operator.
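
Here is a minimal sketch of a late-bound await in VB, assuming a file compiled with Option Strict Off. The module and function names are hypothetical. The operand is typed as Object, so GetAwaiter, IsCompleted and GetResult are all resolved at runtime, and the result of the Await is itself an Object.

```vb
Option Strict Off

Imports System.Threading.Tasks

Public Module LateBoundDemo
    Public Async Function DemoAsync() As Task(Of Object)
        ' Statically this is just an Object; at runtime it is a Task(Of Integer).
        Dim t As Object = Task.FromResult(42)

        ' Late-bound await: the awaiter members are looked up by the late binder.
        Dim r = Await t

        Return r
    End Function
End Module
```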
  

These are only the bare-minimum requirements for the code to compile and to have reasonably well-specified behavior at runtime. Everything else about how you implement GetAwaiter/IsCompleted/OnCompleted/GetResult is up to you and to convention.
  

What conventions the user expects from a custom awaiter

Here is code for a custom awaiter that satisfies all the compiler's requirements, and behaves much like the default TaskAwaiter ("suspend execution of the method; then, when the operand has finished, resume execution on the same SynchronizationContext and with the same ExecutionContext").

This code is pointless! The only reason you'd write a custom awaiter is because you want something different from the default TaskAwaiter. (After all, if you were happy with the default TaskAwaiter behavior, then it'd be more efficient and more robust to just create one via TaskCompletionSource.Task.GetAwaiter). I'm putting this code here just as a starting point, to diverge from it. I've also made fields public to keep the code simple.

Imports System
Imports System.Collections.Generic
Imports System.Runtime.CompilerServices
Imports System.Threading
Imports System.Threading.Tasks

Public Class MyTask(Of T)
    Public result As T
    Public isCompleted As Boolean
    Public continuations As New List(Of Action)
    Public mutex As New Object

    Public Function GetAwaiter() As MyTaskAwaiter(Of T)
        Return New MyTaskAwaiter(Of T) With {.task = Me}
    End Function

    Public Sub SetResult(value As T)
        Dim cc As List(Of Action)
        SyncLock mutex
            result = value
            isCompleted = True
            cc = continuations
            continuations = Nothing
        End SyncLock
        ' Run the continuations outside the lock
        For Each c In cc : c() : Next
    End Sub

    Public Sub AddContinuation(continuation As Action)
        SyncLock mutex
            If Not isCompleted Then continuations.Add(continuation) : Return
        End SyncLock
        ' Already completed: schedule the continuation rather than running it inline
        Task.Run(continuation)
    End Sub
End Class


Public Class MyTaskAwaiter(Of T)
    Implements INotifyCompletion

    Public task As MyTask(Of T)

    Public ReadOnly Property IsCompleted As Boolean
        Get
            SyncLock task.mutex
                Return task.isCompleted
            End SyncLock
        End Get
    End Property

    Public Function GetResult() As T
        Return task.result
    End Function

    Public Sub OnCompleted(continuation As Action) _
    Implements INotifyCompletion.OnCompleted
        ' Capture the current context now, so the continuation can be posted back to it later
        Dim sc = If(SynchronizationContext.Current, New SynchronizationContext)
        task.AddContinuation(Sub() sc.Post(Sub() continuation(), Nothing))
    End Sub
End Class

 

Alternatively, if your code allows partially trusted callers, then the above code would be a security hole. Here's how we could write MyTaskAwaiter instead for this case:

 

<Assembly: Security.AllowPartiallyTrustedCallers>

 

' Assumes Imports System.Runtime.CompilerServices and Imports System.Threading at the top of the file
Public Class MyTaskAwaiter(Of T) : Implements ICriticalNotifyCompletion
    Public task As MyTask(Of T)

    Public ReadOnly Property IsCompleted As Boolean
        Get
            SyncLock task.mutex
                Return task.isCompleted
            End SyncLock
        End Get
    End Property

    Public Function GetResult() As T
        Return task.result
    End Function

    <Security.SecurityCritical>
    Public Sub UnsafeOnCompleted(continuation As Action) _
    Implements ICriticalNotifyCompletion.UnsafeOnCompleted
        Dim sc = If(SynchronizationContext.Current, New SynchronizationContext)
        task.AddContinuation(Sub() sc.Post(Sub() continuation(), Nothing))
    End Sub

    Public Sub OnCompleted(continuation As Action) _
    Implements INotifyCompletion.OnCompleted
        Throw New NotSupportedException()
    End Sub
End Class

 

 

AllowPartiallyTrustedCallers. Part of the .NET security model is that, if a method in your assembly can be invoked from a partially trusted caller, then none of your APIs can allow that caller to break free of the ExecutionContext they started in. A partially-trusted attacker might call the first MyTaskAwaiter.OnCompleted() directly (i.e. bypassing the Await operator) and hence execute an arbitrary lambda in the ExecutionContext of whoever invoked MyTask.SetResult(). The second version works around this by implementing ICriticalNotifyCompletion and putting <SecurityCritical> on the UnsafeOnCompleted method to prevent partially-trusted callers from invoking it. We already saw that the compiler prefers to use the ICriticalNotifyCompletion interface if it's available, over INotifyCompletion. We also saw that the generated code's call to UnsafeOnCompleted is made indirectly via a routine in mscorlib, which is why it's able to invoke <SecurityCritical> code even if your assembly is only partially trusted. (The code for MyTask would also have to be tightened up, so that continuations/AddContinuation can't be touched by partially trusted callers).

 

Race conditions. I believe the above code is free of race conditions and other concurrency bugs but I'm not sure. My first version suffered from several concurrency bugs. First bug: I forgot the lock on MyTask.AddContinuation(), allowing this interleaving:

  1. Thread1 invokes MyTask.AddContinuation()
  2. Thread1 passes the "Not isCompleted" test
  3. Thread2 invokes MyTask.SetResult()
  4. Thread2 sets isCompleted to true and executes all delegates in "continuations"
  5. Thread1 adds its continuation to "continuations"
  6. RESULT: the continuation from Thread1 never gets executed.

Second bug: I forgot the lock on MyTaskAwaiter.IsCompleted, so the code was vulnerable to memory reordering:

  1. Thread1 invokes SetResult
        WRITE:result
        WRITE:isCompleted = true
  2. Thread2 invokes IsCompleted followed by GetResult
        READ:isCompleted returns true
        READ:result
  3. If the hardware re-ordered the two WRITE operations (allowed by ECMA but not CLR), then we might get this interleaving:
        WRITE:isCompleted = true
        READ:isCompleted returns true
        READ:result returns an invalid value
        WRITE:result
  4. If the hardware did speculative execution and re-ordered the two READ operations (I don't know if this is even allowed!) then we might get this interleaving:
        READ:result returns an invalid value
        WRITE:result
        WRITE:isCompleted = true
        READ:isCompleted returns true
  5. RESULT: in both cases the GetResult() call retrieved an invalid value

Third bug: In MyTask.SetResult(), the first version of my code looped over the continuations and executed them from inside the SyncLock. But this would have led to deadlock if any of the continuations themselves tried to await mytask from a different thread. The solution was to execute the continuations outside the SyncLock.

Fourth bug: In AddContinuation, the first version of my code executed the continuation directly. But this would have blown the stack if I had a loop which awaited a MyTask that had already completed. The solution was to schedule the continuation for later execution with Task.Run(continuation).

 

NotSupportedException. The AllowPartiallyTrustedCallers version of the code throws a NotSupportedException for anyone who calls its implementation of INotifyCompletion.OnCompleted. I figured there was no need to implement it, and it's hard to get right, so I'm better off just throwing the exception. After all, I'm only doing my custom awaiter to support the Await operator, and the compiler specs spell out the fact that the compiler won't invoke INotifyCompletion.OnCompleted on an awaiter that implements ICriticalNotifyCompletion.

 

What to do if SynchronizationContext.Current is null? My custom awaiter uses this code:
    Dim sc = If(SynchronizationContext.Current, New SynchronizationContext)   ' VB
    var sc = SynchronizationContext.Current ?? new SynchronizationContext();  // C#

I did this as part of implementing the common convention that the "Await" operator will resume execution on whatever SynchronizationContext was current before it started. The question arises of what to do if SynchronizationContext.Current was null, as is the case in a Console application or a unit test, or in the body of Task.Run(): where should the continuation be posted? In my code, the new synchronization context will post the continuation-delegate to the threadpool, which is decent-enough behavior.
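
A quick way to see that fallback in action is the following sketch, assuming a plain console application where no SynchronizationContext has been installed. The module name is hypothetical. It posts a delegate through the fallback context and reports which kind of thread ran it.

```vb
Imports System
Imports System.Threading

Module NullContextDemo
    Sub Main()
        ' In a console app SynchronizationContext.Current is Nothing, so the fallback is used.
        Dim sc = If(SynchronizationContext.Current, New SynchronizationContext)
        Dim done As New ManualResetEventSlim()

        ' The base SynchronizationContext.Post queues the delegate to the threadpool.
        sc.Post(Sub(state)
                    Console.WriteLine("Threadpool thread: {0}", Thread.CurrentThread.IsThreadPoolThread)
                    done.Set()
                End Sub, Nothing)

        ' Keep Main alive until the posted delegate has run.
        done.Wait()
    End Sub
End Module
```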
  

Conclusions

The custom awaiter code is difficult – difficult to follow the conventions, and difficult to write without concurrency bugs. If at all possible it's better to stick to the standard TaskAwaiter, or at least to build your own custom awaiter as a wrapper around TaskAwaiter.
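
As a sketch of that wrapper approach, here is roughly how the "Await task1.Log()" idea from the list of examples could be built on top of TaskAwaiter. The names LoggingAwaiter and LoggingExtensions are hypothetical, and this is not the code I originally wrote. All of the scheduling is left to the inner TaskAwaiter, so the wrapper has no locks and no continuation list of its own. It implements only INotifyCompletion, so it makes no attempt to handle the partially-trusted-caller case.

```vb
Imports System
Imports System.Diagnostics
Imports System.Runtime.CompilerServices
Imports System.Threading.Tasks

' Hypothetical wrapper: delegates everything to TaskAwaiter and only adds logging.
Public Class LoggingAwaiter(Of T)
    Implements INotifyCompletion

    Private ReadOnly inner As TaskAwaiter(Of T)

    Public Sub New(source As Task(Of T))
        inner = source.GetAwaiter()
    End Sub

    ' The wrapper is its own awaiter.
    Public Function GetAwaiter() As LoggingAwaiter(Of T)
        Return Me
    End Function

    Public ReadOnly Property IsCompleted As Boolean
        Get
            Dim completed = inner.IsCompleted
            Debug.WriteLine("IsCompleted = " & completed.ToString())
            Return completed
        End Get
    End Property

    Public Sub OnCompleted(continuation As Action) Implements INotifyCompletion.OnCompleted
        Debug.WriteLine("Suspending")
        ' TaskAwaiter decides where and when the continuation runs; we only log around it.
        inner.OnCompleted(Sub()
                              Debug.WriteLine("Resuming")
                              continuation()
                          End Sub)
    End Sub

    Public Function GetResult() As T
        Debug.WriteLine("GetResult")
        Return inner.GetResult()
    End Function
End Class

Public Module LoggingExtensions
    <Extension>
    Public Function Log(Of T)(source As Task(Of T)) As LoggingAwaiter(Of T)
        Return New LoggingAwaiter(Of T)(source)
    End Function
End Module
```

With this in scope, "Dim x = Await Task.Run(Function() 42).Log()" behaves like a normal await on the task, with a debug trace at each stage.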