[C#] Burst编译通用算法

27 篇文章 4 订阅
1 篇文章 0 订阅

英文原文:https://www.jacksondunstan.com/articles/5380

许多算法被反复使用:搜索、排序、过滤等。C# 通过 LINQ 和 Array.Sort 等函数提供这些算法,但这些算法无法由 Burst 编译,因为不支持接口和委托。那么我们如何实现这些通用算法以避免一遍又一遍地重写它们呢?今天我们将探索一种解决这个问题的技术。请继续阅读以了解具体方法!

正常方式

假设我们想要创建一个函数,将通过测试的 NativeArray 的所有元素复制到另一个 NativeArray 中。这样的函数可能称为 Filter,如下所示:

public static int Filter<T>(
    NativeArray<T> input,
    NativeArray<T> output,
    Func<T, bool> predicate)
    where T : struct
{
    int destIndex = 0;
    for (int i = 0; i < input.Length; ++i)
    {
        if (predicate(input[i]))
        {
            output[destIndex] = input[i];
            destIndex++;
        }
    }
    return destIndex;
}

然后我们像这样调用这个函数:

// Input is 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
NativeArray<int> input = new NativeArray<int>(10, Allocator.TempJob);
NativeArray<int> output = new NativeArray<int>(10, Allocator.TempJob);
for (int i = 0; i < input.Length; ++i)
{
    input[i] = i;
}
 
// Filter odds
int resultCount = Filter(input, output, val => (val & 1) != 0);
 
// Print results: 1, 3, 5, 7, 9
for (int i = 0; i < resultCount; ++i)
{
    print(output[i]);
}
 
// Cleanup
input.Dispose();
output.Dispose();

如果我们想删除只能过滤 NativeArray 的限制,以便我们可以使用任何类型的集合,我们可以将 Filter 重写为如下所示:

public static int Filter<T>(
    Func<int, bool> isDone,
    Func<int, bool> predicate,
    Action<int, int> addResult)
    where T : struct
{
    int destIndex = 0;
    for (int i = 0; !isDone(i); ++i)
    {
        if (predicate(i))
        {
            addResult(destIndex, i);
            destIndex++;
        }
    }
    return destIndex;
}

以下是该版本的使用方法:

// Input is 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
NativeArray<int> input = new NativeArray<int>(10, Allocator.TempJob);
NativeArray<int> output = new NativeArray<int>(10, Allocator.TempJob);
for (int i = 0; i < input.Length; ++i)
{
    input[i] = i;
}
 
// Filter odds
int resultCount = Filter(
    i => i >= input.Length,
    val => (val & 1) != 0,
    (dest, src) => output[dest] = input[src]);
 
// Print results: 1, 3, 5, 7, 9
for (int i = 0; i < resultCount; ++i)
{
    print(output[i]);
}
 
// Cleanup
input.Dispose();
output.Dispose();

不成功的方法

让我们尝试将第一个版本直接移植到 Burst 编译的作业中:

[BurstCompile]
struct UnsuccessfulFilterArrayJob<T> : IJob
    where T : struct
{
    public Func<T, bool> Predicate;
    [ReadOnly] public NativeArray<T> Input;
    [WriteOnly] public NativeArray<T> Output;
    [WriteOnly] public NativeArray<int> ResultCount;
 
    public void Execute()
    {
        int destIndex = 0;
        for (int i = 0; i < Input.Length; ++i)
        {
            if (Predicate(Input[i]))
            {
                Output[destIndex] = Input[i];
                destIndex++;
            }
        }
        ResultCount[0] = destIndex;
    }
}

我们在这里所做的只是创建一个结构体,将参数和返回值移动到字段,并重命名函数以符合 IJob 的要求。不幸的是,这不会编译:

/Users/builduser/buildslave/unity/build/Runtime/Jobs/Managed/IJob.cs(29,13): error: The delegate managed type System.Func2<System.Int32,System.Boolean> is not supported by burst at Unity.Jobs.IJobExtensions.JobStruct1<UnsuccessfulFilterArrayJob1<System.Int32>>.Execute(ref UnsuccessfulFilterArrayJob1<int> data, System.IntPtr additionalPtr, System.IntPtr bufferRangePatchData, ref Unity.Jobs.LowLevel.Unsafe.JobRanges ranges, int jobIndex) (at /Users/builduser/buildslave/unity/build/Runtime/Jobs/Managed/IJob.cs:29)

While compiling job: System.Void Unity.Jobs.IJobExtensions/JobStruct1<UnsuccessfulFilterArrayJob1<System.Int32>>::Execute(T&,System.IntPtr,System.IntPtr,Unity.Jobs.LowLevel.Unsafe.JobRanges&,System.Int32)

原因是我们使用了委托,这是 Burst 禁止的。任何依赖 Action 和 Func 的代码都无法编译。

不成功的方法:第 2 部分

让我们尝试通过使用结构体实现的接口来避免使用委托。 Burst 肯定允许使用结构体,因此有理由相信这可能有效:

interface IFilterPredicate<T>
{
    bool Test(in T val);
}
 
struct IsOddFilterPredicate : IFilterPredicate<int>
{
    public bool Test(in int val)
    {
        return (val & 1) != 0;
    }
}
 
[BurstCompile]
struct UnsuccessfulFilterArrayJob2<T> : IJob
    where T : struct
{
    public IFilterPredicate<T> Predicate;
    [ReadOnly] public NativeArray<T> Input;
    [WriteOnly] public NativeArray<T> Output;
    [WriteOnly] public NativeArray<int> ResultCount;
 
    public void Execute()
    {
        int destIndex = 0;
        for (int i = 0; i < Input.Length; ++i)
        {
            if (Predicate.Test(Input[i]))
            {
                Output[destIndex] = Input[i];
                destIndex++;
            }
        }
        ResultCount[0] = destIndex;
    }
}

这也无法编译:

/Users/builduser/buildslave/unity/build/Runtime/Export/NativeArray/NativeArray.cs(141,13): error: Unexpected error while processing function IFilterPredicate1.Test(IFilterPredicate1<int>* this, System.Int32& modreq(System.Runtime.InteropServices.InAttribute) val). Reason: Object reference not set to an instance of an object
at Unity.Collections.NativeArray1<System.Int32>.get_Item(Unity.Collections.NativeArray1* this, int index) (at /Users/builduser/buildslave/unity/build/Runtime/Export/NativeArray/NativeArray.cs:138)
at UnsuccessfulFilterArrayJob21<System.Int32>.Execute(UnsuccessfulFilterArrayJob21* this) (at /Users/jackson/Code/UnityPlayground/Assets/TestScript.cs:109)
at Unity.Jobs.IJobExtensions.JobStruct1<UnsuccessfulFilterArrayJob21<System.Int32>>.Execute(ref UnsuccessfulFilterArrayJob2`1 data, System.IntPtr additionalPtr, System.IntPtr bufferRangePatchData, ref Unity.Jobs.LowLevel.Unsafe.JobRanges ranges, int jobIndex) (at /Users/builduser/buildslave/unity/build/Runtime/Jobs/Managed/IJob.cs:30)

Internal compiler exception: System.NullReferenceException: Object reference not set to an instance of an object
at Burst.Compiler.IL.Helpers.CecilExtensions.IsDelegate (Mono.Cecil.TypeDefinition typeDef) [0x0002b] in <3179d4839c86430ca331f2949f40ede5>:0
at Burst.Compiler.IL.Intrinsics.FunctionPointerProcessor.IsDelegateInvoke (Burst.Compiler.IL.Syntax.ILFunctionReference method) [0x0000b] in <3179d4839c86430ca331f2949f40ede5>:0
at Burst.Compiler.IL.Intrinsics.FunctionPointerProcessor.IsHandling (Burst.Compiler.IL.Syntax.ILFunctionReference method) [0x00009] in <3179d4839c86430ca331f2949f40ede5>:0
at Burst.Compiler.IL.Syntax.ILBuilder.IsIntrinsicCall (Burst.Compiler.IL.Syntax.ILFunctionReference binding, System.Boolean& dontVisitImpl) [0x00019] in <3179d4839c86430ca331f2949f40ede5>:0
at Burst.Compiler.IL.Syntax.ILBuilder.CreateFunctionFromRef (Burst.Compiler.IL.Syntax.ILFunctionReference funcRef) [0x0003d] in <3179d4839c86430ca331f2949f40ede5>:0
at Burst.Compiler.IL.Syntax.ILBuilder.VisitPendingFunctionReferences () [0x000c1] in <3179d4839c86430ca331f2949f40ede5>:0

While compiling job: System.Void Unity.Jobs.IJobExtensions/JobStruct1<UnsuccessfulFilterArrayJob21<System.Int32>>::Execute(T&,System.IntPtr,System.IntPtr,Unity.Jobs.LowLevel.Unsafe.JobRanges&,System.Int32)

这里的直接原因是“内部编译器异常”,但实际上这是由于我们在 Job 结构中使用了接口类型。

手动方式

此时,我们可能会放弃并简单地为每个想要使用它的 Burst 编译作业复制过滤算法。这使我们能够删除委托或接口,但代价是在想要使用它的任何地方重复算法代码以及难以维护代码和可读性的相关缺点。看起来是这样的:

[BurstCompile]
struct ManualFilterOddJob : IJob
{
    [ReadOnly] public NativeArray<int> Input;
    [WriteOnly] public NativeArray<int> Output;
    [WriteOnly] public NativeArray<int> ResultCount;
 
    public void Execute()
    {
        int destIndex = 0;
        for (int i = 0; i < Input.Length; ++i)
        {
            if ((Input[i] & 1) != 0)
            {
                Output[destIndex] = Input[i];
                destIndex++;
            }
        }
        ResultCount[0] = destIndex;
    }
}

请注意,if 不再依赖于 predicate 委托,而是直接包含“is odd”检查。

这是编译后的汇编代码,我们稍后会参考:

movsxd  rcx, dword ptr [rdi + 8]
        test    rcx, rcx
        jle     .LBB0_1
        mov     rdx, qword ptr [rdi]
        xor     eax, eax
        .p2align        4, 0x90
.LBB0_6:
        mov     esi, dword ptr [rdx]
        test    sil, 1
        je      .LBB0_3
        mov     r8, qword ptr [rdi + 56]
        cdqe
        mov     dword ptr [r8 + 4*rax], esi
        inc     eax
.LBB0_3:
        add     rdx, 4
        dec     rcx
        jne     .LBB0_6
        mov     rcx, qword ptr [rdi + 112]
        mov     dword ptr [rcx], eax
        ret
.LBB0_1:
        xor     eax, eax
        mov     rcx, qword ptr [rdi + 112]
        mov     dword ptr [rcx], eax
        ret

通用方式

值得庆幸的是,这实际上是不必要的。事实证明,我们可以通过更多地使用 C# 泛型来创建可重用的算法。就是这样:

[BurstCompile]
struct FilterArrayJob<T, TPredicate> : IJob
    where T : struct
    where TPredicate : IFilterPredicate<T>
{
    public TPredicate Predicate;
    [ReadOnly] public NativeArray<T> Input;
    [WriteOnly] public NativeArray<T> Output;
    [WriteOnly] public NativeArray<int> ResultCount;
 
    public void Execute()
    {
        int destIndex = 0;
        for (int i = 0; i < Input.Length; ++i)
        {
            if (Predicate.Test(Input[i]))
            {
                Output[destIndex] = Input[i];
                destIndex++;
            }
        }
        ResultCount[0] = destIndex;
    }
}

这里的关键区别在于我们采用两个类型参数:元素类型 T 和谓词类型 TPredicate。这允许我们明确地告诉 Burst 我们正在使用哪种谓词结构,以便在编译时知道一切。

这确实可以编译并且现在可以像这样使用:

// Input is 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
NativeArray<int> input = new NativeArray<int>(10, Allocator.TempJob);
NativeArray<int> output = new NativeArray<int>(10, Allocator.TempJob);
NativeArray<int> resultCount = new NativeArray<int>(1, Allocator.TempJob);
for (int i = 0; i < input.Length; ++i)
{
    input[i] = i;
}
 
// Filter odds
new FilterArrayJob<int, IsOddFilterPredicate>
{
    Predicate = new IsOddFilterPredicate(),
    Input = input,
    Output = output,
    ResultCount = resultCount
}.Run();
 
// Print results: 1, 3, 5, 7, 9
for (int i = 0; i < resultCount[0]; ++i)
{
    print(output[i]);
}
 
// Cleanup
input.Dispose();
output.Dispose();
resultCount.Dispose();

这编译成与手动版本几乎相同的程序集。唯一的区别是偏移量为 8 个字节:

movsxd  rcx, dword ptr [rdi + 16]
        test    rcx, rcx
        jle     .LBB0_1
        mov     rdx, qword ptr [rdi + 8]
        xor     eax, eax
        .p2align        4, 0x90
.LBB0_6:
        mov     esi, dword ptr [rdx]
        test    sil, 1
        je      .LBB0_3
        mov     r8, qword ptr [rdi + 64]
        cdqe
        mov     dword ptr [r8 + 4*rax], esi
        inc     eax
.LBB0_3:
        add     rdx, 4
        dec     rcx
        jne     .LBB0_6
        mov     rcx, qword ptr [rdi + 120]
        mov     dword ptr [rcx], eax
        ret
.LBB0_1:
        xor     eax, eax
        mov     rcx, qword ptr [rdi + 120]
        mov     dword ptr [rcx], eax
        ret

真正通用的方法

一开始,我们又采取了一步,甚至抽象了 NativeArray,这样我们就可以使用其他类型的集合,例如 NativeChunkedList。让我们在这里应用相同的技术,看看是否可以在 Burst 编译的代码中实现这一目标。首先,这是新接口:

interface IFilterUser<T>
{
    bool IsDone(int index);
    bool Predicate(int index);
    void AddResult(int destIndex, int srcIndex);
}

就像以前一样,我们抽象了输入和输出集合,以便算法不会直接操作它们,甚至不知道它们的类型。相反,它只依赖于始终是 int 的索引。

接下来,这是使用该接口的作业:

[BurstCompile]
struct FilterJob<T, TUser> : IJob
    where T : struct
    where TUser : IFilterUser<T>
{
    public TUser User;
    [WriteOnly] public NativeArray<int> ResultCount;
 
    public void Execute()
    {
        int destIndex = 0;
        for (int i = 0; !User.IsDone(i); ++i)
        {
            if (User.Predicate(i))
            {
                User.AddResult(destIndex, i);
                destIndex++;
            }
        }
        ResultCount[0] = destIndex;
    }
}

这也和以前一样,因为我们所做的只是用对 IsDone 的调用替换循环条件,并用对 AddResult 的调用替换输出写入。

现在让我们看看如何为奇数整数创建“user”结构:

struct IsOddFilterUser : IFilterUser<int>
{
    [ReadOnly] public NativeArray<int> Input;
    [WriteOnly] public NativeArray<int> Output;
 
    public bool IsDone(int index)
    {
        return index >= Input.Length;
    }
 
    public bool Predicate(int index)
    {
        return (Input[index] & 1) != 0;
    }
 
    public void AddResult(int destIndex, int srcIndex)
    {
        Output[destIndex] = Input[srcIndex];
    }
}

以下是如何将它们组合在一起来运行过滤奇数整数的作业:

// Input is 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
NativeArray<int> input = new NativeArray<int>(10, Allocator.TempJob);
NativeArray<int> output = new NativeArray<int>(10, Allocator.TempJob);
NativeArray<int> resultCount = new NativeArray<int>(1, Allocator.TempJob);
for (int i = 0; i < input.Length; ++i)
{
    input[i] = i;
}
 
// Filter odds
new FilterJob<int, IsOddFilterUser>
{
    User = new IsOddFilterUser
    {
        Input = input,
        Output = output
    },
    ResultCount = resultCount
}.Run();
 
// Print results: 1, 3, 5, 7, 9
for (int i = 0; i < resultCount[0]; ++i)
{
    print(output[i]);
}
 
// Cleanup
input.Dispose();
output.Dispose();
resultCount.Dispose();

最后,我们看一下 Burst 编译后的程序集:

movsxd  rcx, dword ptr [rdi + 8]
        test    rcx, rcx
        jle     .LBB0_1
        mov     rdx, qword ptr [rdi]
        xor     eax, eax
        .p2align        4, 0x90
.LBB0_6:
        mov     esi, dword ptr [rdx]
        test    sil, 1
        je      .LBB0_3
        mov     r8, qword ptr [rdi + 56]
        cdqe
        mov     dword ptr [r8 + 4*rax], esi
        inc     eax
.LBB0_3:
        add     rdx, 4
        dec     rcx
        jne     .LBB0_6
        mov     rcx, qword ptr [rdi + 112]
        mov     dword ptr [rcx], eax
        ret
.LBB0_1:
        xor     eax, eax
        mov     rcx, qword ptr [rdi + 112]
        mov     dword ptr [rcx], eax
        ret

令人惊讶的是,这与手动版本的组件完全相同!

结束语

不仅可以用 Burst 编译的代码编写通用算法,而且从性能角度来看,所有这些抽象最终都完全自由。它不像使用 lambda 那样友好,但与手动将算法复制到每个作业类型相比,可读性和代码重用性得到了很大提高。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值