tf计算矩阵维度_XLA 探究:矩阵乘法

1. 矩阵乘法
矩阵乘法是被广泛使用的算子,最简单的实现方法为:

for i in range(M):
    for j in range(N):
         C(i, j) = 0
         for k in range(K):
                C(i, j) += A(i, k) * B(k, j)

2. 在 Tensorflow 中定义矩阵乘法
Tensorflow中提供的线性代数和张量操作 API 允许对于一种计算有多种表达方式,推荐的方式往往是尽量利用已经存在的 API 原子性地表达一个运算,比如 tf.nn.conv2d, tf.linalg.matmul 等等,而不是将其拆开成多个操作。但是在深度学习网络的发展过程中,必然会有新的算子被提出,所以利用已有的运算拼凑新的运算是具有意义的。在本文里我们先尝试利用已有运算拼凑矩阵乘法,用来观测 XLA 对于拼凑出的运算(其实是计算图)会进行怎样的操作。2.1 直接调用 matmul API
直接使用预设好的 API

def gemm1(A, B):
     return tf.linalg.matmul(A, B)


对于这种计算图,XLA 的优化在于消除冗余的 reshape 等节点。但优化前后都会调用 dot 这个运算 API 完成计算,dot 本身就在 HLO instruction 之中。2.2 利用升维降维计算矩阵乘法
先将输入矩阵升维到三维,进行逐点相乘后再降维累加得到输出矩阵

def gemm2(A, B):
     return tf.reduce_sum(
            tf.multiply(
            tf.tile(tf.expand_dims(A, -1), [1, 1, B.shape[1]]), 
            tf.tile(tf.expand_dims(B, 0), [A.shape[0], 1, 1])
        ),
 axis=1
    )


打印出 XLA 的 IR,首先看图形:
初始计算图中有很多冗余的reshape以及broadcast。

841f73d2c2b6d6f6da5403c209e32fc8.png
优化前

优化后的计算图简洁了很多。

ce8598a352f28addcc99ac06e5282b28.png

再来对比文本形式的 HLO IR(其BNF文法见https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/g3doc/hlo_parser.md)。


优化前的 IR:

HloModule cluster_15890044661264488385__.30

%Sum-reduction.21 (x.22: f32[], y.23: f32[]) -> f32[] {
  %x.22 = f32[] parameter(0)
  %y.23 = f32[] paramete
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值