.NET Task 揭秘(3)async 与 AsyncMethodBuilder

目录

前言

本文为系列博客

  1. 什么是 Task
  2. Task 的回调执行与 await
  3. async 与 AsyncMethodBuilder(本文)
  4. 总结与常见误区(TODO)

上文我们学习了 await 这个语法糖背后的实现,了解了 await 这个关键词是如何去等待 Task 的完成并获取 Task 执行结果。并且我们还实现了一个简单的 awaitable 类型,它可以让我们自定义 await 的行为。

class FooAwaitable<TResult>

{

    // 回调,简化起见,未将其包裹到 TaskContinuation 这样的容器里

    private Action _continuation;



    private TResult _result;

    

    private Exception _exception;



    private volatile bool _completed;



    public bool IsCompleted => _completed;



    // Awaitable 中的关键部分,提供 GetAwaiter 方法

    public FooAwaiter<TResult> GetAwaiter() => new FooAwaiter<TResult>(this);



    public void Run(Func<TResult> func)

    {

        new Thread(() =>

        {

            var result = func();

            TrySetResult(result);

        })

        {

            IsBackground = true

        }.Start();

    }



    private bool AddFooContinuation(Action action)

    {

        if (_completed)

        {

            return false;

        }

        _continuation += action;

        return true;

    }



    internal void TrySetResult(TResult result)

    {

        _result = result;

        _completed = true;

        _continuation?.Invoke();

    }

    

    internal void TrySetException(Exception exception)

    {

        _exception = exception;

        _completed = true;

        _continuation?.Invoke();

    }



    // 1 实现 ICriticalNotifyCompletion

    public struct FooAwaiter<TResult> : ICriticalNotifyCompletion

    {

        private readonly FooAwaitable<TResult> _fooAwaitable;

        

        // 2 实现 IsCompleted 属性

        public bool IsCompleted => _fooAwaitable.IsCompleted;



        public FooAwaiter(FooAwaitable<TResult> fooAwaitable)

        {

            _fooAwaitable = fooAwaitable;

        }



        public void OnCompleted(Action continuation)

        {

            Console.WriteLine("FooAwaiter.OnCompleted");

            if (_fooAwaitable.AddFooContinuation(continuation))

            {

                Console.WriteLine("FooAwaiter.OnCompleted: added continuation");

            }

            else

            {

                Console.WriteLine("FooAwaiter.OnCompleted: already completed, invoking continuation");

                continuation();

            }

        }



        public void UnsafeOnCompleted(Action continuation)

        {

            Console.WriteLine("FooAwaiter.UnsafeOnCompleted");

            if (_fooAwaitable.AddFooContinuation(continuation))

            {

                Console.WriteLine("FooAwaiter.UnsafeOnCompleted: added continuation");

            }

            else

            {

                Console.WriteLine("FooAwaiter.UnsafeOnCompleted: already completed, invoking continuation");

                continuation();

            }

        }



        // 3. 实现 GetResult 方法

        public TResult GetResult()

        {

            if (_fooAwaitable._exception != null)

            {

                // 4. 如果 awaitable 中有异常,则抛出

                throw _fooAwaitable._exception;

            }

            Console.WriteLine("FooAwaiter.GetResult");

            return _fooAwaitable._result;

        }

    }

}

如果在一个方法中使用了 await,那么这个方法就必须添加 async 修饰符。并且这个方法的返回类型通常是 Task 或者 其它 runtime 里定义的 awaitable 类型。

int foo = await FooAsync();

Console.WriteLine(foo); // 1



async Task<int> FooAsync()

{

    await Task.Delay(1000);

    return 1;

}

问题1: 上面的代码中,FooAsync 方法是一个异步方法,它的返回类型是 Task。但代码中的 await FooAsync() 并不会返回 Task,而是返回 int。这是为什么呢?

如果我们把 FooAsync 的返回值改成我们自己实现的 awaitable 类型,编译器会报错:

正在上传…重新上传取消

问题2: 明明我们可以在 FooAwaitable 实例上使用 await 关键词,为什么把它作为 FooAsync 的返回类型就会报错呢?且提示它不是一个 task-like 类型?

实际上我们在上篇文章实现的 awaitable 类型 FooAwaitable,只是支持了 await 关键词,并不是一个完整的 task-like 类型。

而上面两个问题的答案就是本文要讲的内容:AsyncMethodBuilder

AsyncMethodBuilder 介绍

AsyncMethodBuilder 是状态机的重要组成部分#

引用上一篇文章介绍状态机的代码:

class Program

{

    static async Task Main(string[] args)

    {

        var a = 1;

        Console.WriteLine(await FooAsync(a));

    }



    static async Task<int> FooAsync(int a)

    {

        int b = 2;

        int c = await BarAsync();

        return a + b + c;

    }



    static async Task<int> BarAsync()

    {

        await Task.Delay(100);

        return 3;

    }

}

由 FooAsync 编译成的 IL 代码经整理后的等效 C# 代码如下:

using System;

using System.Runtime.CompilerServices;

using System.Threading.Tasks;



class Program

{

    static async Task Main(string[] args)

    {

        var a = 1;

        Console.WriteLine(await FooAsync(a));

    }



    static Task<int> FooAsync(int a)

    {

        var stateMachine = new FooStateMachine

        {

            _asyncTaskMethodBuilder = AsyncTaskMethodBuilder<int>.Create(),

    

            _state = -1, // 初始化状态

            _a = a // 将实参拷贝到状态机字段

        };

        // 开始执行状态机

        stateMachine._asyncTaskMethodBuilder.Start(ref stateMachine);

        return stateMachine._asyncTaskMethodBuilder.Task;

    }



    static async Task<int> BarAsync()

    {

        await Task.Delay(100);

        return 3;

    }



    public class FooStateMachine : IAsyncStateMachine

    {

        // 方法的参数和局部变量被编译会字段

        public int _a;

        public AsyncTaskMethodBuilder<int> _asyncTaskMethodBuilder;

        private int _b;



        private int _c;



        // -1: 初始化状态

        // 0: 等到 Task 执行完成

        // -2: 状态机执行完成

        public int _state;



        private TaskAwaiter<int> _taskAwaiter;



        public void MoveNext()

        {

            var result = 0;

            TaskAwaiter<int> taskAwaiter;

            try

            {

                // 状态不是0,代表 Task 未完成

                if (_state != 0)

                {

                    // 初始化局部变量

                    _b = 2;



                    taskAwaiter = Program.BarAsync().GetAwaiter();

                    if (!taskAwaiter.IsCompleted)

                    {

                        // state: -1 => 0,异步等待 Task 完成

                        _state = 0;

                        _taskAwaiter = taskAwaiter;

                        var stateMachine = this;

                        // 内部会调用 将 stateMachine.MoveNext 注册为 Task 的回调

                        _asyncTaskMethodBuilder.AwaitUnsafeOnCompleted(ref taskAwaiter, ref stateMachine);

                        return;

                    }

                }

                else

                {

                    taskAwaiter = _taskAwaiter;

                    // TaskAwaiter 是个结构体,这边相当于是个清空 _taskAwaiter 字段的操作

                    _taskAwaiter = new TaskAwaiter<int>();

                    // state: 0 => -1,状态机恢复到初始化状态

                    _state = -1;

                }



                _c = taskAwaiter.GetResult();

                result = _a + _b + _c;

            }

            catch (Exception e)

            {

                // state: any => -2,状态机执行完成

                _state = -2;

                _asyncTaskMethodBuilder.SetException(e);

                return;

            }



            // state: -1 => -2,状态机执行完成

            _state = -2;

            // 将 result 设置为 FooAsync 方法的返回值

            _asyncTaskMethodBuilder.SetResult(result);

        }



        public void SetStateMachine(IAsyncStateMachine stateMachine)

        {

        }

    }

}

在编译器生成的状态机类中,我们可以看到一个名为 _asyncTaskMethodBuilder 的字段,它的类型是 AsyncTaskMethodBuilder<int>。
这个 AsyncTaskMethodBuilder 就是 Task所绑定的 AsyncMethodBuilder。

AsyncMethodBuilder 的结构#

以 AsyncTaskMethodBuilder<TResult> 为例,我们来看下 AsyncMethodBuilder 的结构:

public struct AsyncTaskMethodBuilder<TResult>

{

    // 保存最后作为返回值的 Task

    private Task<TResult>? m_task;



    // 创建一个 AsyncTaskMethodBuilder

    public static AsyncTaskMethodBuilder<TResult> Create() => default;



    // 开始执行 AsyncTaskMethodBuilder 及其绑定的状态机 

    public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine =>

        AsyncMethodBuilderCore.Start(ref stateMachine);



    // 绑定状态机,但编译器的编译结果不会调用

    public void SetStateMachine(IAsyncStateMachine stateMachine) =>

        AsyncMethodBuilderCore.SetStateMachine(stateMachine, m_task);



    // 将状态机的 MoveNext 方法注册为 async方法 内 await 的 Task 的回调

    public void AwaitOnCompleted<TAwaiter, TStateMachine>(

        ref TAwaiter awaiter, ref TStateMachine stateMachine)

        where TAwaiter : INotifyCompletion

        where TStateMachine : IAsyncStateMachine =>

        AwaitOnCompleted(ref awaiter, ref stateMachine, ref m_task);



    // 同上,参考前一篇文章讲 UnsafeOnCompleted 和 OnCompleted 的区别

    [MethodImpl(MethodImplOptions.AggressiveInlining)]

    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(

        ref TAwaiter awaiter, ref TStateMachine stateMachine)

        where TAwaiter : ICriticalNotifyCompletion

        where TStateMachine : IAsyncStateMachine =>

        AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine, ref m_task);



    public Task<TResult> Task

    {

        get => m_task ?? InitializeTaskAsPromise();

    }



    public void SetResult(TResult result)

    {

        if (m_task is null)

        {

            m_task = Threading.Tasks.Task.FromResult(result);

        }

        else

        {

            SetExistingTaskResult(m_task, result);

        }

    }



    public void SetException(Exception exception) => SetException(exception, ref m_task);

}

非泛型的 Task 对应的 AsyncMethodBuilder 是 AsyncTaskMethodBuilder,它的结构与泛型的 AsyncTaskMethodBuilder<TResult> 类似,但因为最终返回的 Task 没有执行结果,它的 SetResult 只是为了标记 Task 的完成状态并触发 Task 的回调。

public struct AsyncTaskMethodBuilder

{

    private Task<VoidTaskResult>? m_task;



    public static AsyncTaskMethodBuilder Create() => default;





    public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine =>

        AsyncMethodBuilderCore.Start(ref stateMachine);



    public void SetStateMachine(IAsyncStateMachine stateMachine) =>

        AsyncMethodBuilderCore.SetStateMachine(stateMachine, task: null);



    public void AwaitOnCompleted<TAwaiter, TStateMachine>(

        ref TAwaiter awaiter, ref TStateMachine stateMachine)

        where TAwaiter : INotifyCompletion

        where TStateMachine : IAsyncStateMachine =>

        AsyncTaskMethodBuilder<VoidTaskResult>.AwaitOnCompleted(ref awaiter, ref stateMachine, ref m_task);



    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(

        ref TAwaiter awaiter, ref TStateMachine stateMachine)

        where TAwaiter : ICriticalNotifyCompletion

        where TStateMachine : IAsyncStateMachine =>

        AsyncTaskMethodBuilder<VoidTaskResult>.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine, ref m_task);



    public Task Task

    {

        get => m_task ?? InitializeTaskAsPromise();

    }



    public void SetResult()

    {

        if (m_task is null)

        {

            m_task = Task.s_cachedCompleted;

        }

        else

        {

            AsyncTaskMethodBuilder<VoidTaskResult>.SetExistingTaskResult(m_task, default!);

        }

    }



    public void SetException(Exception exception) =>

        AsyncTaskMethodBuilder<VoidTaskResult>.SetException(exception, ref m_task);

}            

AsyncMethodBuilder 功能分析#

AsyncTaskMethodBuilder 在 FooAsync 方法的执行过程中,起到了以下作用:

  1. 对内:关联状态机和状态机执行的上下文,管理状态机的生命周期。
  2. 对外:构建一个 Task 对象,作为异步方法的返回值,并会触发该 Task 执行的完成或异常。

为了方便说明,下文我们将 FooAsync 方法返回的 Task 称为 FooTask,BarAsync 方法返回的 Task 称为 BarTask。

对状态机的生命周期进行管理#

状态机通过 _asyncTaskMethodBuilder.Start 方法来启动且其 MoveNext 方式是通过 _asyncTaskMethodBuilder.AwaitUnsafeOnCompleted 方法来注册为 BarTask 的回调的。

对 async 方法的返回值进行包装#

_asyncTaskMethodBuilder 是用来构建一个 Task 对象,_asyncTaskMethodBuilder 的 Task 属性就是 FooAsync 方法返回的 FooTask。通过 _asyncTaskMethodBuilder 的 SetResult 方法,我们可以设置 FooTask 的执行结果, 通过 SetException 方法,我们可以设置 FooTask 的异常。

小结#

一个 AsyncMethodBuilder 是由下面几个部分组成的:

  1. 一个 Task 对象,作为异步方法的返回值。
  2. Create 方法,用来创建 AsyncMethodBuilder。
  3. Start 方法,用来启动状态机。
  4. AwaitOnCompleted/AwaitUnsafeOnCompleted 方法,用来将状态机的 MoveNext 方法注册为 async方法 内 await 的 Task 的回调。
  5. SetResult/SetException 方法,用来标记 Task 的完成状态并触发 Task 的回调。
  6. SetStateMachine 方法,用来关联状态机,不常用,编译结果也不会调用。

async void

为了让 async 方法适配传统的事件回调,C# 引入了 async void 的概念。

var foo = new Foo();

foo.OnSayHello += FooAsync;

foo.SayHello();



Console.ReadLine();



async void FooAsync(object sender, EventArgs e)

{

    var args = e as SayHelloEventArgs;

    await Task.Delay(1000);

    Console.WriteLine(args.Message);

}



class Foo

{

    public event EventHandler OnSayHello;



    public void SayHello()

    {

        OnSayHello.Invoke(this, new SayHelloEventArgs { Message = "Hello" });

    }

}



class SayHelloEventArgs : EventArgs

{

    public string Message { get; set; }

}

async void 也有一个对应的 AsyncVoidMethodBuilder。

    public struct AsyncVoidMethodBuilder

    {

        // AsyncVoidMethodBuilder 是对 AsyncTaskMethodBuilder 的封装

        private AsyncTaskMethodBuilder _builder;



        public static AsyncVoidMethodBuilder Create()

        {

            // ...

        }



        public void Start<TStateMachine>(ref TStateMachine stateMachine) where TStateMachine : IAsyncStateMachine =>

            AsyncMethodBuilderCore.Start(ref stateMachine);



        public void SetStateMachine(IAsyncStateMachine stateMachine) =>

            _builder.SetStateMachine(stateMachine);



        public void AwaitOnCompleted<TAwaiter, TStateMachine>(

            ref TAwaiter awaiter, ref TStateMachine stateMachine)

            where TAwaiter : INotifyCompletion

            where TStateMachine : IAsyncStateMachine =>

            _builder.AwaitOnCompleted(ref awaiter, ref stateMachine);



        public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(

            ref TAwaiter awaiter, ref TStateMachine stateMachine)

            where TAwaiter : ICriticalNotifyCompletion

            where TStateMachine : IAsyncStateMachine =>

            _builder.AwaitUnsafeOnCompleted(ref awaiter, ref stateMachine);



        public void SetResult()

        {

            // 仅仅是做 runtime 的一些状态标记

        }



        public void SetException(Exception exception)

        {

            // 这个异常只能通过 TaskScheduler.UnobservedTaskException 事件来捕获

        }



        // 因为没有返回值,这个 Task 不对外暴露

        private Task Task => _builder.Task;

    }

自定义 AsyncMethodBuilder

自定义一个 AsyncMethodBuilder,不需要实现任意接口,只需要实现上面说的那 6 个主要组成部分,编译器就能够正常编译。

awaitable 绑定 AsyncMethodBuilder 的方式有两种:

  1. 在 awaitable 类型上添加 AsyncMethodBuilderAttribute 来绑定 AsyncMethodBuilder。
  2. 在 async 方法上添加 AsyncMethodBuilderAttribute 来绑定 AsyncMethodBuilder,用来覆盖 awaitable 类型上的 AsyncMethodBuilderAttribute(前提是 awaitable 类型上有 AsyncMethodBuilderAttribute)。
struct FooAsyncMethodBuilder<TResult>

{

    private FooAwaitable<TResult> _awaitable;



    // 1. 定义 Task 属性

    public FooAwaitable<TResult> Task

    {

        get

        {

            Console.WriteLine("FooAsyncMethodBuilder.Task");

            return _awaitable;

        }

    }

    

    // 2. 定义 Create 方法

    public static FooAsyncMethodBuilder<TResult> Create()

    {

        Console.WriteLine("FooAsyncMethodBuilder.Create");

        var awaitable = new FooAwaitable<TResult>();

        var builder = new FooAsyncMethodBuilder<TResult>

        {

            _awaitable = awaitable,

        };

        return builder;

    }

    

    // 3. 定义 Start 方法

    public void Start<TStateMachine>(ref TStateMachine stateMachine)

        where TStateMachine : IAsyncStateMachine

    {

        Console.WriteLine("FooAsyncMethodBuilder.Start");

        stateMachine.MoveNext();

    }



    

    // 4. 定义 AwaitOnCompleted/AwaitUnsafeOnCompleted 方法

    

    // 如果 awaiter 实现了 INotifyCompletion 接口,就调用 AwaitOnCompleted 方法

    public void AwaitOnCompleted<TAwaiter, TStateMachine>(ref TAwaiter awaiter, ref TStateMachine stateMachine)

        where TAwaiter : INotifyCompletion

        where TStateMachine : IAsyncStateMachine

    {

        Console.WriteLine("FooAsyncMethodBuilder.AwaitOnCompleted");

        awaiter.OnCompleted(stateMachine.MoveNext);

    }



    [SecuritySafeCritical]

    public void AwaitUnsafeOnCompleted<TAwaiter, TStateMachine>(

        ref TAwaiter awaiter,

        ref TStateMachine stateMachine)

        where TAwaiter : ICriticalNotifyCompletion

        where TStateMachine : IAsyncStateMachine

    {

        Console.WriteLine("FooAsyncMethodBuilder.AwaitUnsafeOnCompleted");

        awaiter.UnsafeOnCompleted(stateMachine.MoveNext);

    }



    // 5. 定义 SetResult/SetException 方法

    public void SetResult(TResult result)

    {

        Console.WriteLine("FooAsyncMethodBuilder.SetResult");

        _awaitable.TrySetResult(result);

    }

    

    public void SetException(Exception exception)

    {

        Console.WriteLine("FooAsyncMethodBuilder.SetException");

        _awaitable.TrySetException(exception);

    }

    

    // 6. 定义 SetStateMachine 方法,虽然编译器不会调用,但是编译器要求必须有这个方法

    public void SetStateMachine(IAsyncStateMachine stateMachine)

    {

        Console.WriteLine("FooAsyncMethodBuilder.SetStateMachine");

    }

}



// 7. 通过 AsyncMethodBuilderAttribute 绑定 FooAsyncMethodBuilder

[AsyncMethodBuilder(typeof(FooAsyncMethodBuilder<>))]

class FooAwaitable<TResult>

{

    // ...

}
Console.WriteLine("await Foo1Async()");

int foo1= await Foo1Async();

Console.WriteLine("Foo1Async() result: " + foo1);

Console.WriteLine();



Console.WriteLine("await Foo2Async()");



int foo2 = await Foo2Async();

Console.WriteLine("Foo2Async() result: " + foo2);

Console.WriteLine();



Console.WriteLine("await FooExceptionAsync()");

try

{

    await FooExceptionAsync();

}

catch (Exception e)

{

    Console.WriteLine(e.Message);

}



async FooAwaitable<int> Foo1Async()

{

    await Task.Delay(1000);

    return 1;

}



// 覆盖默认的 AsyncMethodBuilder,使用 FooAsyncMethodBuilder2

// 本文省略了 FooAsyncMethodBuilder2 的定义,可以参考上面的 FooAsyncMethodBuilder

[AsyncMethodBuilder(typeof(FooAsyncMethodBuilder2<>))]

async FooAwaitable<int> Foo2Async()

{

    await Task.Delay(1000);

    return 2;

}

执行结果:

await Foo1Async()

FooAsyncMethodBuilder.Create

FooAsyncMethodBuilder.Start

FooAsyncMethodBuilder.AwaitUnsafeOnCompleted

FooAsyncMethodBuilder.Task

FooAwaiter.UnsafeOnCompleted

FooAwaiter.UnsafeOnCompleted: added continuation

FooAsyncMethodBuilder.SetResult

FooAwaiter.GetResult

Foo1Async() result: 1



await Foo2Async()

FooAsyncMethodBuilder2.Create

FooAsyncMethodBuilder2.Start

FooAsyncMethodBuilder2.AwaitUnsafeOnCompleted

FooAsyncMethodBuilder2.Task

FooAwaiter.UnsafeOnCompleted

FooAwaiter.UnsafeOnCompleted: added continuation

FooAsyncMethodBuilder2.SetResult

FooAwaiter.GetResult

Foo2Async() result: 2



await FooExceptionAsync()

FooAsyncMethodBuilder.Create

FooAsyncMethodBuilder.Start

FooAsyncMethodBuilder.AwaitUnsafeOnCompleted

FooAsyncMethodBuilder.Task

FooAwaiter.UnsafeOnCompleted

FooAwaiter.UnsafeOnCompleted: added continuation

FooAsyncMethodBuilder.SetException

Exception from FooExceptionAsync

在方法上添加 AsyncMethodBuilderAttribute 的功能是后来才添加的,通过这个功能,可以覆盖 awaitable 类型上的 AsyncMethodBuilderAttribute,以便进行性能优化。例如 .NET 6 开始提供的 PoolingAsyncValueTaskMethodBuilder,对原始的 AsyncValueTaskMethodBuilder 进行了池化处理,可以通过在方法上添加 AsyncMethodBuilderAttribute 来使用。

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值