以下内容翻译自:Optimize Deep Learning GPU Operators with TVM: A Depthwise Convolution Example
高效的深度学习算子是深度学习系统的核心。通常这些算子很难优化,并且需要高性能计算专家的努力。TVM,端到端张量IR/DSL堆栈,使得这项任务更容易。
这个博客教你如何在TVM的帮助下编写高性能GPU运算核心。我们使用深度卷积(即topi.nn.depthwise_conv2d_nchw)作为示例,并演示如何在tensorflow中优化手动调优过的CUDA内核。在不同的工作负载下,我们的最终版本比tf-1.2中的优化内核快2到4倍,启用算子融合时速度快了3x-7倍。以下是在GTX1080上,filter size= [1,256,3,3],stride = [1,1],padding ='SAME’的测试结果:
Depthwise Convolution介绍
深度卷积是现代架构的重要组成部分,如Xception和MobileNet。这是一种降低深度神经网络计算复杂度的有效方法。
source: http://machinethink.net/blog/googles-mobile-net-architecture-on-iphone/
在TVM中,深度卷积可以被声明为:
# padding stage
PaddedInput = tvm.compute(
(batch, in_channel, height_after_pad, width_after_pad),
lambda b, c, i, j: tvm.select(
tvm.all(i >= pad_top, i - pad_top < in_height, j >= pad_left, j - pad_left < in_width),
Input[b, c, i - pad_top, j - pad_left], tvm.const(0.0)),
name="PaddedInput")
# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
(batch, out_channel, out_height, out_width),
lambda b, c, i, j: tvm.sum(
PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] * Filter[c/channel_multiplier, c%channel_multiplier, di, dj],
axis=[di, dj]),
name='DepthwiseConv2d')
通用GPU优化指南
本部分简要介绍了优化CUDA代码时应该了解的三个概念:数据重用,共享内存和存储体冲突。如果你已了解它们,很好,那么你可以跳过这部分。
数据重用
在现代计算体系结构中,从存储器加载数据的成本远高于进行单个浮点计算。因此,我们总是希望在输入数据加载到寄存器或共享内存(缓存)后重新使用输入数据。
深度卷积有两种形式的数据重用:
- 滤波器重用
- 输入重用
滤波器重用发生在滤波器滑动窗口并进行多次计算时;输入重用是通过平铺来实现的,我们以3x3深度转换为例:
如果没有平铺,每个线程加载3x3输入数据并计算1个输出元素。16个线程一起有9x16负载。
通过平铺,每个线程加载4x4输入数据并计算2x2输出元素。4个线程一起有16x4负载。
共享内存和Bank Conflicts
共享内存可以被看作是GPU中的缓存。它是片上的,比全局存储器要快得多。
共享内存按块分配。通常的做法是将全局内存中的数据加载到共享内存中,然后块中的所有线程都从共享内存中读取数据。
共享内存的大小是有限的(通常是48K),所以我们必须注意共享内存溢出。此外,分配给一个块的共享内存太多会限制每个多处理器的活动块数量。
共享内存的另一个性能问题是Bank Conflicts。共享内存被分成可以同时访问的大小相同的内存模块(bank),但是,如果多个线程访问相同的存储体(导致bank冲突),访问将被串行化,从而降低有效带宽。
共享存储体的组织方式使得连续的地址被分配给连续的存储体。为了避免存储体冲突,最好连续的线程访问连续的内存地址,如下所示(每种颜色代表一个共享内存组):
有关共享内存和存储体冲突的更多详细信息,请参阅Nvidia的博客。
好吧,现在让我们开始优化TVM中的深度卷积。
Schedule优化
内联计算PaddedInput以节省内存分配
正如我们从第1部分看到的那样,填充被明确地声明为一个单独的阶段。我们在线计算它以避免冗余内存分配:
s = tvm.create_schedule(Output.op)
s[PaddedInput].compute_inline()
将一个大通道分成较小的块
深度卷积的一个简单的调度是一个cuda块负责一个输入通道和相应的滤波器,将它们加载到共享内存中,然后计算:
IS = s.cache_read(PaddedInput, "shared", [DepthwiseConv2d])
FS = s.cache_read(Filter, "shared", [DepthwiseConv2d])
block_y = tvm.thread_axis("blockIdx.y")
block_x = tvm.thread_axis("blockIdx.x")
# bind the dimension of batch (N in NCHW) with block_y
s[Output].bind(Output.op.axis[0], block_y)
# bind the dimension of channel (C in NCHW) with block_x
s[Output].bind(Output.op.axis[1], block_x)
我们在GTX 1080上测试1000次运行的平均时间成本,并与tensorflow中的depthwise conv2d进行比较。结果如下:
Input | Filter | stride | tf-1.2 SAME pad (us) | TVM SAME pad (us) |
---|---|---|---|---|
[1, 256, 21, 21] | [256, 1, 3, 3] | [1, 1] | 16.1 | 9.1 |
[1, 256, 32, 32] | [256, 1, 3, 3] | [1, 1] | 34.8 | 14.5 |
[1, 256, 64, 64] | [256, 1, 3, 3] | [1, 1] | 130.9 | 98.9 |
[1, 256, 96, 96] | [256, 1, 3, 3] | [1, 1] | 251.6 | 387.4 |
正如我们所看到的,这个调度表在21x21或32x32这样的小特征图下表现良好,然而,随着特征图增加到大于64x64,其性能严重下降。一个主要原因是分配的共享内存过多 一个块限制每个多处理器的活动块数量。
我们修改调度表将一个大通道分成更小的块。例如,一个通道(64x64或96x96)被分成32x32的块,一个cuda块处理一个32x32的块:
blocking_h = 32
blocking_w = 32
# split the dimension of height (H in NCHW)
bx1, _ = s[Output].split(Output.op.axis[2], factor=blocking_h)
# split the dimension of width (W in NCHW)
bx2, _ = s[Output].split(Output.op.axis[3], factor=blocking_w)
# assign one 32 x 32 block to one cuda block
by = s[Output].fuse(Output.op.axis[0], Output.op.axis[1])
s[Output].bind(by, block_y)
bx = s[Output].fuse(bx1, bx2)
s[Output].bind(bx, block_x)
这是新的结果:
Input | [blocking_h, blocking_w] | tf-1.2 SAME pad (us) | TVM SAME pad (us) |
---|---|---|---|
[1, 256, 64, 64] | [32, 32] | 130.9 | 63.4 |
[1, 256, 96, 96] | [32, 32] | 251.6 | 132.5 |
我们的分块策略有效!对于64x64尺寸通道,它带来1.6倍的加速(98.9us->63.4us); 对于96x96尺寸通道,它带来了2.9倍的加速(387.4us->132.5us)。
调整线程号参数
如何在一个cuda块中安排32x32线程的工作负载?直观地说,它应该是这样的:
num_thread_y = 8
num_thread_x = 8
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
ty, yi = s[Output].split(h_dim, nparts=num_thread_y)
tx, xi = s[Output].split(w_dim, nparts=num_thread_x)
s[Output].reorder(ty, tx, yi, xi)
s[Output].bind(ty, thread_y)
s[Output].bind(tx, thread_x)
调度中有两个参数:num_thread_y
和num_thread_x
。如何确定它们的最佳组合? 那么,我们先做一些实验。以下是Filter = [256,1,3,3]和stride = [1,1]的结果:
Case | Input | num_thread_y | num_thread_x | TVM SAME pad (us) |
---|---|---|---|---|
1 | [1, 256, 32, 32] | 8 | 32 | 9.7 |
2 | [1, 256, 32, 32] | 4 | 32 | 8.8 |
3 | [1, 256, 32, 32] | 1 | 32 | 17.7 |
4 | [1, 256, 32, 32] | 32 | 1 | 32.5 |
上面一些有趣的观察结果:
- 情况2比情况1快。在情况2中,每个线程计算输出中的8×1分片,其对应于输入中的10×3分片。它比情况1的4x1分片具有更好的数据重用性。
- 情况3比情况2慢。这是因为在情况3中,每个线程的工作量太大并且导致本地存储器读取的很多成本。
- 情况4比情况3慢。这是因为
num_thread_x=32
确保没有存储体冲突,而num_thread_y=32
不能。
总结我们从以上观察得出的结论:
- 大块分片有利于数据重用,但对本地内存读取不利。
num_thread_y
和num_thread_x
对存储体冲突的影响是不同的。- 要找到
num_thread_y
和num_thread_x
的最佳组合,可以实现有效的共享内存访问(避免存储库冲突),数据重用和本地内存读取之间的平衡。
非常棘手。那么,我们应该做些什么才能找到最佳组合?答案是蛮力搜索。我们可以将num_thread_y
和num_thread_x
作为参数传递给schedule函数,并尝试所有可能的组合以找到最优的一个。这可以在TVM中轻松完成:
def schedule_depthwise_conv2d(..., num_thread_y=8, num_thread_x=8):
num_thread_y = num_thread_y
num_thread_x = num_thread_x
do_schedule_as_usual
return schedule
min_time_cost = inf
for num_thread_y, num_thread_x in all_possible_combinations:
schedule = schedule_depthwise_conv2d(..., num_thread_y=num_thread_y, num_thread_x=num_thread_x)
time_cost = test_depthwise_conv2d(..., schedule)
if time_cost < min_time_cost:
min_time_cost = time_cost
optimal_combination = [num_thread_y, num_thread_x]
实际上,它可以被看作是一个简单的自动调度程序。
Vthread和Stripped模式
引入TVM中的Vthread(虚拟线程)以支持分步模式。我们可以这样使用它:
num_vthread_y = 2
num_vthread_x = 2
num_thread_y = 8
num_thread_x = 8
thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy")
thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
# split the dimension of height (H in NCHW) twice
tvy, vyi = s[Output].split(h_dim, nparts=num_vthread_y)
ty, yi = s[Output].split(vyi, nparts=num_thread_y)
# split the dimension of width (W in NCHW) twice
tvx, vxi = s[Output].split(w_dim, nparts=num_vthread_x)
tx, xi = s[Output].split(vxi, nparts=num_thread_x)
# bind thread and vthread respectively
s[Output].bind(tvy, thread_vy)
s[Output].bind(tvx, thread_vx)
s[Output].bind(ty, thread_y)
s[Output].bind(tx, thread_x)
s[Output].reorder(tvy, tvx, ty, tx, yi, xi)
让我们打印IR以查看vthread的作用:
/* Input = [1, 1, 32, 32], Filter = [1, 1, 3, 3], stride = [1, 1], padding = 'SAME' */
produce DepthwiseConv2d {
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
// attr [iter_var(threadIdx.y, Range(min=0, extent=8), threadIdx.y)] thread_extent = 8
// attr [iter_var(threadIdx.x, Range(min=0, extent=8), threadIdx.x)] thread_extent = 8
for (i.inner.inner.inner, 0, 2) {
for (j.inner.inner.inner, 0, 2) {
DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] = 0.000000f
DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 512)] = 0.000000f
DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 16)] = 0.000000f
DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 528)] = 0.000000f
for (di, 0, 3) {
for (dj, 0, 3) {
DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] = (DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] + (tvm_if_then_else(((((((1 - di) - i.inner.inner.inner) <= (((blockIdx.x*16) + threadIdx.y)*2)) && ((((blockIdx.x*16) + threadIdx.y)*2) < ((33 - di) - i.inner.inner.inner))) && (((1 - dj) - j.inner.inner.inner) <= (threadIdx.x*2))) && ((threadIdx.x*2) < ((33 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + -33)], 0.000000f)*Filter[((di*3) + dj)]))
DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 512)] = (DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 512)] + (tvm_if_then_else(((((((-15 - di) - i.inner.inner.inner) <= (((blockIdx.x*16) + threadIdx.y)*2)) && ((((blockIdx.x*16) + threadIdx.y)*2) < ((17 - di) - i.inner.inner.inner))) && (((1 - dj) - j.inner.inner.inner) <= (threadIdx.x*2))) && ((threadIdx.x*2) < ((33 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + 479)], 0.000000f)*Filter[((di*3) + dj)]))
DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 16)] = (DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 16)] + (tvm_if_then_else(((((((1 - di) - i.inner.inner.inner) <= (((blockIdx.x*16) + threadIdx.y)*2)) && ((((blockIdx.x*16) + threadIdx.y)*2) < ((33 - di) - i.inner.inner.inner))) && (((-15 - dj) - j.inner.inner.inner) <= (threadIdx.x*2))) && ((threadIdx.x*2) < ((17 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + -17)], 0.000000f)*Filter[((di*3) + dj)]))
DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 528)] = (DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 528)] + (tvm_if_then_else(((((((-15 - di) - i.inner.inner.inner) <= (((blockIdx.x*16) + threadIdx.y)*2)) && ((((blockIdx.x*16) + threadIdx.y)*2) < ((17 - di) - i.inner.inner.inner))) && (((-15 - dj) - j.inner.inner.inner) <= (threadIdx.x*2))) && ((threadIdx.x*2) < ((17 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + 495)], 0.000000f)*Filter[((di*3) + dj)]))
}
}
}
}
}
没有vthread(只设置为1),IR是:
/* Input = [1, 1, 32, 32], Filter = [1, 1, 3, 3], stride = [1, 1], padding = 'SAME' */
produce DepthwiseConv2d {
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
// attr [iter_var(threadIdx.y, Range(min=0, extent=8), threadIdx.y)] thread_extent = 8
// attr [iter_var(threadIdx.x, Range(min=0, extent=8), threadIdx.x)] thread_extent = 8
for (i.inner.inner.inner, 0, 4) {
for (j.inner.inner.inner, 0, 4) {
DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] = 0.000000f
for (di, 0, 3) {
for (dj, 0, 3) {
DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] = (DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] + (tvm_if_then_else(((((((1 - di) - i.inner.inner.inner) <= (((blockIdx.x*8) + threadIdx.y)*4)) && ((((blockIdx.x*8) + threadIdx.y)*4) < ((33 - di) - i.inner.inner.inner))) && (((1 - dj) - j.inner.inner.inner) <= (threadIdx.x*4))) && ((threadIdx.x*4) < ((33 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + -33)], 0.000000f)*Filter[((di*3) + dj)]))
}
}
}
}
}
正如我们所看到的,当num_vthread_y = 2
和num_vthread_x = 2
时,32 x 32通道被分成四个16 x 16的子通道。每个线程一次计算四个输出元素,一个子通道中有一个元素。
以下是Filter = [256,1,3,3],stride = [1,1],blocking_h = 32,blocking_w = 32的结果:
|Case | Input | num_thread_y, num_thread_x | num_vthread_y, num_vthread_x | TVM SAME pad (us)|
|—|---|—|---|
|1 | [1, 256, 96, 96] | 8, 8 | 1, 1 | 132.5|
|2 | [1, 256, 96, 96] | 8, 8 | 1, 4 | 103.1|
|3 | [1, 256, 96, 96] | 4, 32| 1, 1 | 95.9 |
|4 | [1, 256, 96, 96] | 8, 16| 1, 2 | 90.9 |
情况2比情况1更快。这是因为在情况2中num_thread_x = 8
和num_vthread_x = 4
一起确保连续线程访问连续内存地址,从而避免存储库冲突(如下所示)(每种颜色表示一个线程的工作负载):
理论上,情况3和4应该是相样快,因为它们每个线程具有相同的工作量,并且都享有高效的共享内存访问。不知怎的,案例4就是更快一点。
还记得tensorflow的速度吗?是251.6us,现在TVM速度提高了2.8倍。387.4 -> 132.5 -> 95.9 -> 90.9,分块帮助最大; 调整线程号节约37us; vthread节约额外的5us。
事实上,在更大或更多通道的卷积上,TVM比tensorflow更快(因为更多的滤波器重用):
Input | Filter | stride | tf-1.2 SAME pad (us) | TVM SAME pad (us) | How faster is TVM |
---|---|---|---|---|---|
[1, 256, 96, 96] | [256, 1, 3, 3] | [1, 1] | 251.6 | 90.9 | 2.8x |
[1, 256, 96, 96] | [256, 1, 5, 5] | [1, 1] | 597.6 | 128.9 | 4.6x |
[1, 256, 96, 96] | [256, 2, 3, 3] | [1, 1] | 659.9 | 143.7 | 4.6x |
[1, 256, 96, 96] | [256, 2, 5, 5] | [1, 1] | 1203.9 170.5 | 7.1x |
算子融合
我们可以在深度学习中进行的一种典型优化是运算符融合,即在单个内核中将多个运算符一起计算,而不将中间结果保存回全局内存。TVM支持开箱即用。
考虑神经网络中的常见模式:depthwise_conv2d + scale_shift + relu。 我们可以通过稍微修改原始调度表将三个算子融合为一个:
DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, stride, padding)
ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)
Output = Relu # is no longer DepthwiseConv2d
s[ScaleShift].compute_inline() # this line fuses ScaleShift, explicitly
s[DepthwiseConv2d].set_scope("local") # this line fuses DepthwiseConv2d, implicitly
schedule(Output) # schedule for Output the same way we schedule for DepthwiseConv2d as discussed above
s[DepthwiseConv2d].compute_at(s[Output], tx) # tx is the inner most axis, bound to threadIdx.x
它会产生像这样的IR:
/* Input = [1, 1, 32, 32], Filter = [1, 1, 3, 3], stride = [1, 1], padding = 'SAME' */
produce Relu {
// attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
// attr [DepthwiseConv2d] storage_scope = "local"
allocate DepthwiseConv2d[float32 * 1 * 1 * 4 * 4]
// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
// attr [iter_var(threadIdx.y, Range(min=0, extent=8), threadIdx.y)] thread_extent = 8
// attr [iter_var(threadIdx.x, Range(min=0, extent=8), threadIdx.x)] thread_extent = 8
produce DepthwiseConv2d {
for (i, 0, 4) {
for (j, 0, 4) {
DepthwiseConv2d[((i*4) + j)] = 0.000000f
for (di, 0, 3) {
for (dj, 0, 3) {
DepthwiseConv2d[((i*4) + j)] = (DepthwiseConv2d[((i*4) + j)] + (tvm_if_then_else(((((((1 - di) - i) <= (((blockIdx.x*8) + threadIdx.y)*4)) && ((((blockIdx.x*8) + threadIdx.y)*4) < ((33 - di) - i))) && (((1 - dj) - j) <= (threadIdx.x*4))) && ((threadIdx.x*4) < ((33 - dj) - j))), Input[(((((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i*32)) + j) + (di*32)) + dj) + -33)], 0.000000f)*Filter[((di*3) + dj)]))
}
}
}
}
}
for (i2.inner.inner.inner, 0, 4) {
for (i3.inner.inner.inner, 0, 4) {
Relu[((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i2.inner.inner.inner*32)) + i3.inner.inner.inner)] = max(((DepthwiseConv2d[((i2.inner.inner.inner*4) + i3.inner.inner.inner)]*Scale[0]) + Shift[0]), 0.000000f)
}
}
}
正如我们所看到的,每个线程在将depthwise_conv2d的结果写入全局内存之前计算scale_shift和relu。融合的运算符与单个depthwise_conv2d一样快。以下是Input = [1,256,96,96],Filter = [256,1,3,3],stride = [1,1],padding ='SAME’的结果:
- tf-1.2 depthwise_conv2d:251.6 us
- tf-1.2 depthwise_conv2d + scale_shift + relu(单独):419.9 us
- TVM depthwise_conv2d:90.9 us
- TVM depthwise_conv2d + scale_shift + relu(融合):91.5 us
算子融合的优势是显而易见的。
这不是终点,TVM可以以更智能的方式进行算子融合。你可以参考这个并阅读下面提供的源代码。
让我们看看代码
- Declare: https://github.com/dmlc/tvm/blob/master/topi/python/topi/nn/convolution.py
- Schedule: https://github.com/dmlc/tvm/blob/master/topi/python/topi/cuda/depthwise_conv2d.py
- Test: https://github.com/dmlc/tvm/blob/master/topi/recipe/conv/depthwise_conv2d_test.py
致谢
作者非常感谢陈天奇的有益建议和鼓舞人心的讨论。