1. 矩阵乘法
矩阵乘法是被广泛使用的算子,最简单的实现方法为:
for i in range(M):
for j in range(N):
C(i, j) = 0
for k in range(K):
C(i, j) += A(i, k) * B(k, j)
2. 在 Tensorflow 中定义矩阵乘法
Tensorflow中提供的线性代数和张量操作 API 允许对于一种计算有多种表达方式,推荐的方式往往是尽量利用已经存在的 API 原子性地表达一个运算,比如 tf.nn.conv2d, tf.linalg.matmul 等等,而不是将其拆开成多个操作。但是在深度学习网络的发展过程中,必然会有新的算子被提出,所以利用已有的运算拼凑新的运算是具有意义的。在本文里我们先尝试利用已有运算拼凑矩阵乘法,用来观测 XLA 对于拼凑出的运算(其实是计算图)会进行怎样的操作。2.1 直接调用 matmul API
直接使用预设好的 API
def gemm1(A, B):
return tf.linalg.matmul(A, B)
对于这种计算图,XLA 的优化在于消除冗余的 reshape 等节点。但优化前后都会调用 dot 这个运算 API 完成计算,dot 本身就在 HLO instruction 之中。2.2 利用升维降维计算矩阵乘法
先将输入矩阵升维到三维,进行逐点相乘后再降维累加得到输出矩阵
def gemm2(A, B):
return tf.reduce_sum(
tf.multiply(
tf.tile(tf.expand_dims(A, -1), [1, 1, B.shape[1]]),
tf.tile(tf.expand_dims(B, 0), [A.shape[0], 1, 1])
),
axis=1
)
打印出 XLA 的 IR,首先看图形:
初始计算图中有很多冗余的reshape以及broadcast。
优化后的计算图简洁了很多。
再来对比文本形式的 HLO IR(其BNF文法见https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/g3doc/hlo_parser.md)。
优化前的 IR:
HloModule cluster_15890044661264488385__.30
%Sum-reduction.21 (x.22: f32[], y.23: f32[]) -> f32[] {
%x.22 = f32[] parameter(0)
%y.23 = f32[] paramete