tf.matmul()是常见的矩阵相乘运算,而tf.multiply()则是两个矩阵中对应元素的相乘运算。
具体用法:
-
multiply 等同与* ,用于计算矩阵之间的element-wise 乘法,要求矩阵的形状必须一致(或者是其中一个维度为1),否则会报错.
-
matmul 是tensor的矩阵乘法, 参与运算的两个tensor维度、数据类型必须一致,
参与运算的是最后两维形成的矩阵,如果tensor是二维矩阵,则等同于矩阵乘法.
代码示例
import tensorflow as tf
a = tf.constant([1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12], shape=[2, 3, 2])
b = tf.constant([1, 2, 3, 4, 5, 6], shape=[2, 3, 1])
c = a*b
with tf.Session():
print(a.eval())
print(b.eval())
print(c.eval())
输出为:
>> a
[[[ 1 2]
[ 3 4]
[ 5 6]]
[[ 7 8]
[ 9 10]
[11 12]]]
>>b
[[[1]
[2]
[3]]
[[4]
[5]
[6]]]
>>c
[[[ 1 2]
[ 6 8]
[15 18]]
[[28 32]
[45 50]
[66 72]]]
可见c的shape是[2,3,2]。
既然提到这了,就顺便说一下怎么理解三维数组的shape吧。
像shape=[a,b,c],就可以理解为底面积是[b,c],高是a,更准确的理解这个数组共是a页,每一页都有一个b行c列的二维数组,可以说a个b*c的二维数组构成了这个[a,b,c]的三维数组。