# Optimizing NMT (Neural Machine Translation) with TensorFlow + TVM

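Conceptually, a batched matrix multiplication performs one independent GEMM per batch index:

```c
// Single-matrix GEMM c = a * b (a: M x K, b: K x N, c: M x N), row-major.
void DoGemm(const float *a, const float *b, float *c, int M, int N, int K);

// Batched GEMM over packed row-major matrices: C[i] = A[i] * B[i] for each i.
void BatchedGemm(const float *A, const float *B, float *C,
                 int M, int N, int K, int batch_dimension) {
  for (int i = 0; i < batch_dimension; ++i)
    DoGemm(A + i * M * K, B + i * K * N, C + i * M * N, M, N, K);
}
```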
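
To make these semantics concrete, here is a plain-Python reference for what `BatchedGemm` computes. It is a minimal sketch on nested lists. The helpers `gemm` and `batched_gemm` are hypothetical names used only for illustration, not part of cuBLAS or TVM.

```python
def gemm(a, b):
    """Single GEMM on nested lists: returns a @ b for a (M x K) and b (K x N)."""
    m, k, n = len(a), len(b), len(b[0])
    return [[sum(a[y][p] * b[p][x] for p in range(k)) for x in range(n)]
            for y in range(m)]


def batched_gemm(A, B):
    """Batched GEMM: one independent gemm per batch index, C[i] = A[i] @ B[i]."""
    if len(A) != len(B):
        raise ValueError("A and B must have the same batch dimension")
    return [gemm(a, b) for a, b in zip(A, B)]


A = [[[1, 2], [3, 4]], [[0, 1], [1, 0]]]  # batch=2, M=2, K=2
B = [[[5, 6], [7, 8]], [[2, 3], [4, 5]]]  # batch=2, K=2, N=2
C = batched_gemm(A, B)
print(C)  # [[[19, 22], [43, 50]], [[4, 5], [2, 3]]]
```

Because the batch entries do not depend on each other, the batch loop can run in parallel. The TVM schedule later in this post exploits this by binding the (fused) batch axis to GPU blocks.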

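## The performance problem with cuBLAS batch matmul

The batch matmul implementation in cuBLAS is far from efficient here, so TVM is used to generate efficient batch matmul kernels tailored to the NMT workload.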

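The batch matmul computation is declared in TVM as follows:

```python
import tvm  # TVM 0.x API as in the original post; TVM 0.7+ moves these under tvm.te

# computation representation
# batch, M, N, K are the problem sizes; symbolic here, concrete ints also work
batch, M, N, K = tvm.var("batch"), tvm.var("M"), tvm.var("N"), tvm.var("K")
A = tvm.placeholder((batch, M, K), name='A')
B = tvm.placeholder((batch, K, N), name='B')
k = tvm.reduce_axis((0, K), 'k')
C = tvm.compute((batch, M, N),
                lambda b, y, x: tvm.sum(A[b, y, k] * B[b, k, x], axis=k),
                name='C')
```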
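
`tvm.compute(shape, fcompute)` declares each output element as a function of its index, and `tvm.sum(..., axis=k)` reduces over the `reduce_axis` `k`. The sketch below evaluates the same lambda eagerly in plain Python so the semantics can be checked on small inputs. The `compute` helper is a hypothetical toy stand-in, not TVM's API: real TVM builds a symbolic expression and generates code from a separate schedule.

```python
import itertools


def compute(shape, fcompute):
    """Toy, eager stand-in for tvm.compute (hypothetical helper): call fcompute
    at every index of `shape` and return a dict mapping index tuple -> value."""
    return {idx: fcompute(*idx)
            for idx in itertools.product(*(range(n) for n in shape))}


batch, M, N, K = 2, 2, 3, 4
# Deterministic small inputs, indexed as A[b][y][k] and B[b][k][x].
A = [[[b + y + k for k in range(K)] for y in range(M)] for b in range(batch)]
B = [[[b * k - x for x in range(N)] for k in range(K)] for b in range(batch)]

# Same index expression as the TVM lambda; tvm.sum over the reduce axis k
# becomes a Python sum over range(K).
C = compute((batch, M, N),
            lambda b, y, x: sum(A[b][y][k] * B[b][k][x] for k in range(K)))
print(C[(1, 1, 2)])  # -2
```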

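The schedule partitions the 4-D output C of shape (batch_size, features, M, N), which is the layout used by the fused patterns, across thread blocks, virtual threads and threads. num_thread_y, num_thread_x, vthread_y and vthread_x are tuning parameters assumed to be defined beforehand.

```python
# block partitioning
s = tvm.create_schedule(C.op)
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
thread_yz = tvm.thread_axis((0, vthread_y), "vthread", name="vy")
thread_xz = tvm.thread_axis((0, vthread_x), "vthread", name="vx")
BB, FF, MM, PP = s[C].op.axis
BBFF = s[C].fuse(BB, FF)  # rows of the output grid: batch_size * features
MMPP = s[C].fuse(MM, PP)  # columns of the output grid: M * N
by, ty_block = s[C].split(BBFF, factor=num_thread_y * vthread_y)
bx, tx_block = s[C].split(MMPP, factor=num_thread_x * vthread_x)
s[C].bind(by, tvm.thread_axis("blockIdx.y"))
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
vty, ty = s[C].split(ty_block, nparts=vthread_y)
vtx, tx = s[C].split(tx_block, nparts=vthread_x)
s[C].reorder(by, bx, vty, vtx, ty, tx)
s[C].bind(vty, thread_yz)
s[C].bind(vtx, thread_xz)
s[C].bind(ty, thread_y)
s[C].bind(tx, thread_x)
```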
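
The sketch below shows how the fuse/split/bind calls map output elements to GPU threads by replaying the resulting loop nest in plain Python. It follows TVM's index arithmetic:

- a fuse gives `outer * inner_extent + inner`;
- a split with `factor=F` gives `outer * F + inner`;
- a split with `nparts=P` gives the outer loop extent `P`.

`schedule_coverage` and the parameter values in the usage line are hypothetical. The boundary check stands in for the guard TVM adds when extents do not divide evenly. The sketch checks that every element of C is written exactly once.

```python
def schedule_coverage(batch_size, features, M, N,
                      num_thread_y, vthread_y, num_thread_x, vthread_x):
    """Walk the loop nest of the block-partitioning schedule (hypothetical
    helper) and count how many times each output element C[b, f, m, n] is written."""
    rows, cols = batch_size * features, M * N   # extents of BBFF and MMPP
    tile_y = num_thread_y * vthread_y           # split factor applied to BBFF
    tile_x = num_thread_x * vthread_x           # split factor applied to MMPP
    hits = {}
    for by in range(-(-rows // tile_y)):        # blockIdx.y (ceil division)
        for bx in range(-(-cols // tile_x)):    # blockIdx.x
            for vty in range(vthread_y):        # virtual thread, y
                for vtx in range(vthread_x):    # virtual thread, x
                    for ty in range(num_thread_y):      # threadIdx.y
                        for tx in range(num_thread_x):  # threadIdx.x
                            r = by * tile_y + vty * num_thread_y + ty
                            c = bx * tile_x + vtx * num_thread_x + tx
                            if r >= rows or c >= cols:
                                continue        # out-of-range guard
                            b, f = divmod(r, features)  # undo fuse(BB, FF)
                            m, n = divmod(c, N)         # undo fuse(MM, PP)
                            hits[(b, f, m, n)] = hits.get((b, f, m, n), 0) + 1
    return hits


hits = schedule_coverage(batch_size=2, features=3, M=4, N=5,
                         num_thread_y=2, vthread_y=2, num_thread_x=4, vthread_x=2)
print(len(hits), set(hits.values()))  # 120 {1}
```

The virtual-thread loops sit outside the thread loops. As a result, each physical thread (ty, tx) handles vthread_y * vthread_x output elements, spaced num_thread_y rows and num_thread_x columns apart.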

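Fusing batch matmul with a transpose of B: the transpose is folded into the index expression `B[yb, yf, x, k]`, so no separate transpose operator is needed.

```python
# computation representation
batch_size, features = tvm.var("batch_size"), tvm.var("features")
A = tvm.placeholder((batch_size, features, M, K), name='A')
# B has shape (N, K) rather than (K, N) because B is transposed in this fusion pattern
B = tvm.placeholder((batch_size, features, N, K), name='B')
# ENTER (shape (batch_size, 1, M, N)) is declared here but not used by C in this snippet
ENTER = tvm.placeholder((batch_size, 1, M, N), name='ENTER')
k = tvm.reduce_axis((0, K), 'k')
C = tvm.compute(
    (batch_size, features, M, N),
    lambda yb, yf, m, x: tvm.sum(A[yb, yf, m, k] * B[yb, yf, x, k], axis=k),
    name='C')
```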
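
The sketch below checks, in plain Python on small random integer inputs, why the transpose can be fused. Reading `B[b][f][x][k]` inside the reduction gives the same result as first transposing B's last two axes and then running an ordinary batch matmul. All helper names are hypothetical, and the sizes are arbitrary illustrative values.

```python
import random


def transpose_last2(t):
    """Separate transpose: swap the last two axes, t[b][f][i][j] -> [b][f][j][i]."""
    return [[[list(col) for col in zip(*mat)] for mat in per_batch] for per_batch in t]


def bmm(A, B):
    """Plain batch matmul: C[b][f][m][x] = sum_k A[b][f][m][k] * B[b][f][k][x]."""
    return [[[[sum(row[k] * Bm[k][x] for k in range(len(Bm)))
               for x in range(len(Bm[0]))] for row in Am]
             for Am, Bm in zip(Af, Bf)] for Af, Bf in zip(A, B)]


def bmm_transposed_b(A, B):
    """Fused form, B stored as (N, K): C[b][f][m][x] = sum_k A[b][f][m][k] * B[b][f][x][k]."""
    return [[[[sum(p * q for p, q in zip(row, brow)) for brow in Bm] for row in Am]
             for Am, Bm in zip(Af, Bf)] for Af, Bf in zip(A, B)]


def rand4(shape, rng):
    """Random small-integer 4-D nested list of the given shape."""
    n0, n1, n2, n3 = shape
    return [[[[rng.randint(-3, 3) for _ in range(n3)] for _ in range(n2)]
             for _ in range(n1)] for _ in range(n0)]


rng = random.Random(0)
batch_size, features, M, N, K = 2, 3, 2, 5, 4
A = rand4((batch_size, features, M, K), rng)
B = rand4((batch_size, features, N, K), rng)  # transposed layout, as in the pattern
fused = bmm_transposed_b(A, B)
separate = bmm(A, transpose_last2(B))         # explicit transpose, then matmul
print(fused == separate)  # True
```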

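Fusing batch matmul with a transpose of its output: C is written directly in (batch_size, M, features, N) layout, with the features and M axes swapped, while A and B are read in their original layout.

```python
# computation representation
A = tvm.placeholder((batch_size, features, M, K), name='A')
B = tvm.placeholder((batch_size, features, K, N), name='B')
k = tvm.reduce_axis((0, K), 'k')
C = tvm.compute(
    (batch_size, M, features, N),
    lambda yb, m, yf, x: tvm.sum(A[yb, yf, m, k] * B[yb, yf, k, x], axis=k),
    name='C')
```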
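
The next sketch shows that this pattern only changes the order of the output indices. Computing `C[b][m][f][x]` directly matches running a plain batch matmul followed by a separate swap of the features and M axes. The helper names and sizes are hypothetical, chosen for illustration.

```python
import random


def bmm(A, B):
    """Batch matmul: C[b][f][m][x] = sum_k A[b][f][m][k] * B[b][f][k][x]."""
    nb, nf, M, K, N = len(A), len(A[0]), len(A[0][0]), len(B[0][0]), len(B[0][0][0])
    return [[[[sum(A[b][f][m][k] * B[b][f][k][x] for k in range(K)) for x in range(N)]
              for m in range(M)] for f in range(nf)] for b in range(nb)]


def bmm_output_transposed(A, B):
    """Fused form: writes C[b][m][f][x] directly, with the features and M axes swapped."""
    nb, nf, M, K, N = len(A), len(A[0]), len(A[0][0]), len(B[0][0]), len(B[0][0][0])
    return [[[[sum(A[b][f][m][k] * B[b][f][k][x] for k in range(K)) for x in range(N)]
              for f in range(nf)] for m in range(M)] for b in range(nb)]


def swap_axes_1_2(t):
    """Separate transpose kernel: result[b][m][f] = t[b][f][m]."""
    return [[[t[b][f][m] for f in range(len(t[b]))] for m in range(len(t[b][0]))]
            for b in range(len(t))]


rng = random.Random(1)
batch_size, features, M, N, K = 2, 3, 2, 5, 4
A = [[[[rng.randint(-3, 3) for _ in range(K)] for _ in range(M)]
      for _ in range(features)] for _ in range(batch_size)]
B = [[[[rng.randint(-3, 3) for _ in range(N)] for _ in range(K)]
      for _ in range(features)] for _ in range(batch_size)]

fused = bmm_output_transposed(A, B)   # shape (batch_size, M, features, N)
separate = swap_axes_1_2(bmm(A, B))   # matmul, then a separate transpose
print(fused == separate)  # True
```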

Measured execution times:

• TF-R1.4 BatchMatmul: 513.9 μs
• TF-R1.4 BatchMatmul + Transpose (separate): 541.9 μs
• TVM BatchMatmul: 37.62 μs
• TVM BatchMatmul + Transpose (fused): 38.39 μs
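
The timings above imply two numbers worth spelling out: the speedup of TVM over TF-R1.4, and the extra cost the transpose adds in each case. The short calculation below derives them from the listed values only. The ratios do not depend on the time unit.

```python
# Kernel timings copied from the list above.
tf_bmm, tf_bmm_transpose = 513.9, 541.9      # TF-R1.4: matmul, matmul + separate transpose
tvm_bmm, tvm_bmm_transpose = 37.62, 38.39    # TVM: matmul, matmul + fused transpose

speedup_bmm = tf_bmm / tvm_bmm                                  # about 13.66x
speedup_with_transpose = tf_bmm_transpose / tvm_bmm_transpose   # about 14.12x

# Extra time the transpose adds on top of the plain batch matmul.
transpose_overhead_tf = tf_bmm_transpose - tf_bmm      # about 28.0 (separate kernel)
transpose_overhead_tvm = tvm_bmm_transpose - tvm_bmm   # about 0.77 (fused)

print(f"speedup: {speedup_bmm:.2f}x plain, {speedup_with_transpose:.2f}x with transpose")
print(f"transpose overhead: TF {transpose_overhead_tf:.2f}, TVM {transpose_overhead_tvm:.2f}")
```

Fusing the transpose into the batch matmul kernel nearly eliminates its cost: it adds about 0.77 instead of about 28.0.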

wujianming_110117

你的鼓励将是我创作的最大动力

C币 余额
2C币 4C币 6C币 10C币 20C币 50C币
• 举报
• 一键三连

点赞Mark关注该博主, 随时了解TA的最新博文

07-02 2610
01-30 7296
03-02 1万+
02-02 5734
07-30 2867
03-01 8833
04-01 4万+
11-13
06-06 9563