本文都是参考一个开源的项目:Unity3dAsyncAwaitUtil/releases
之前写的文章可以参考:Unity 中的 async-await 关键字解析
本文要实现的就是:await 可以直接接 协程 yield return 的东西
此处就拿 WaitForSecond 举例子:
void Start()
{
ForIEnumator();
}
async void ForIEnumator()
{
await WaitThenThrow();
}
IEnumerator WaitThenThrow()
{
yield return WaitThenThrowNested();
}
IEnumerator WaitThenThrowNested()
{
Debug.Log("Waiting 1 second...");
yield return new WaitForSeconds(1.0f);
throw new Exception("zxcv");
}
那么具体是怎么实现的呢?
首先 await 后面跟着的是 一个 INotifyCompletion 类型,很显然 WaitForSeconds 不是,那我们就可以写一个扩展方法,叫做 GetAwaiter 兼容一下
public static IEnumatorAwaiter<object> GetAwaiter(this IEnumerator coroutine)
{
IEnumatorAwaiter<object> awaiter = new IEnumatorAwaiter<object>();
SyncContextUtil.RunOnUnityScheduler(() =>
{
CoroutineRunner.Instance.StartCoroutine(
new IEnumeratorWrapper<object>(coroutine, awaiter).Run());
});
return awaiter;
}
public static IEnumatorAwaiter<T> GetAwaiter<T>(this IEnumerator<T> coroutine)
{
IEnumatorAwaiter<T> awaiter = new IEnumatorAwaiter<T>();
SyncContextUtil.RunOnUnityScheduler(() =>
{
CoroutineRunner.Instance.StartCoroutine(
new IEnumeratorWrapper<T>(coroutine, awaiter).Run());
});
return awaiter;
}
跟之前一样的就不说了哈,看一下之前的文章:Unity 中的 async-await 支持YieldInstruction
新创建了一个反省的 Awaiter 类:IEnumatorAwaiter
public class IEnumatorAwaiter<T> : INotifyCompletion
{
bool _isDone;
Exception _exception;
Action _continuation;
T _result;
public bool IsCompleted
{
get { return _isDone; }
}
public T GetResult()
{
AssertUtil.Assert(_isDone);
if (_exception != null)
{
//当跑起了一个异步线程,并用 await 异步等待时,通过这个 可以在主线程捕获异步线程的异常
// ExceptionDispatchInfo.Capture 可以重新抛出被捕获时的调用栈(StackTrace)
ExceptionDispatchInfo.Capture(_exception).Throw();
}
return _result;
}
public void Complete(T result, Exception e)
{
AssertUtil.Assert(!_isDone);
_isDone = true;
_exception = e;
_result = result;
// 当 await unity yield 时候,所有都在 Unity主线程里统一触发
if (_continuation != null)
{
SyncContextUtil.RunOnUnityScheduler(_continuation);
}
}
void INotifyCompletion.OnCompleted(Action continuation)
{
AssertUtil.Assert(_continuation == null);
AssertUtil.Assert(!_isDone);
_continuation = continuation;
}
}
跟之前不一样得是 多了一个 IEnumeratorWrapper这个类
那么他具体是干啥的呢?
1.封装一层 IEnumerator 就是对应的 Run 方法
2.进行异常捕获处理,嵌套 IEnumerator 使用栈记录
public class IEnumeratorWrapper<T>
{
//保存 IEnumatorAwaiter,就是一个带泛型的 Awaiter
readonly IEnumatorAwaiter<T> _awaiter;
readonly Stack<IEnumerator> _processStack;
public IEnumeratorWrapper(IEnumerator coroutine, IEnumatorAwaiter<T> awaiter)
{
_processStack = new Stack<IEnumerator>();
_processStack.Push(coroutine);
_awaiter = awaiter;
}
public IEnumerator Run()
{
while (true)
{
IEnumerator topWorker = _processStack.Peek();
bool isDone;
try
{
isDone = !topWorker.MoveNext();
}
catch (Exception e)
{
//通过反射获取 IEnumerators 协程方法的实际名称,把它添加到异常输出
var objectTrace = GenerateObjectTrace(_processStack);
if (objectTrace.Any())
{
_awaiter.Complete(default(T), new Exception(GenerateObjectTraceMessage(objectTrace), e));
}
else
{
_awaiter.Complete(default(T), e);
}
yield break;
}
if (isDone)
{
_processStack.Pop();
if (_processStack.Count == 0)
{
_awaiter.Complete((T)topWorker.Current, null);
yield break;
}
}
//在这里管理嵌套异常的异常捕获
if (topWorker.Current is IEnumerator)
{
_processStack.Push((IEnumerator)topWorker.Current);
}
else
{
//将当前值返回到unity引擎,以便它可以处理 WaitForSeconds,WaitToEndOfFrame 等
yield return topWorker.Current;
}
}
}
string GenerateObjectTraceMessage(List<Type> objTrace)
{
var result = new StringBuilder();
foreach (var objType in objTrace)
{
if (result.Length != 0)
{
result.Append(" -> ");
}
result.Append(objType.ToString());
}
result.AppendLine();
return "Unity Coroutine Object Trace: " + result.ToString();
}
static List<Type> GenerateObjectTrace(IEnumerable<IEnumerator> enumerators)
{
var objTrace = new List<Type>();
foreach (var enumerator in enumerators)
{
var field = enumerator.GetType().GetField("$this", BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance);
if (field == null)
{
continue;
}
var obj = field.GetValue(enumerator);
if (obj == null)
{
continue;
}
var objType = obj.GetType();
//Any() 如果源序列包含任何元素,则为 true;否则为 false。
if (!objTrace.Any() || objType != objTrace.Last())
{
objTrace.Add(objType);
}
}
objTrace.Reverse();
return objTrace;
}
}