Requirements of the TensorFlow function matmul
import tensorflow as tf
a=tf.constant([[[1,2,3],[4,5,6]],[[1,2,3],[4,5,6]]])
b=tf.constant([[[1,2],[3,4],[5,6]],[[1,2],[3,4],[5,6]]])
print(a.shape) #2*2*3
print(b.shape) #2*3*2
c=tf.matmul(a,b)
print(c)
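As a quick check of what the batched product above computes, here is a minimal sketch, assuming TensorFlow 2.x with eager execution. It compares `tf.matmul` on the 3-D tensors with multiplying each pair of matrices separately; the name `per_batch` is only illustrative.

```python
import numpy as np
import tensorflow as tf

a = tf.constant([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])      # shape (2, 2, 3)
b = tf.constant([[[1, 2], [3, 4], [5, 6]], [[1, 2], [3, 4], [5, 6]]])  # shape (2, 3, 2)

# Batched product: one (2x3) @ (3x2) product per leading index.
c = tf.matmul(a, b)

# The same thing done by hand: multiply matrix i of a with matrix i of b, then stack.
per_batch = tf.stack([tf.matmul(a[i], b[i]) for i in range(a.shape[0])])

print(c.shape)                                        # (2, 2, 2)
print(np.array_equal(c.numpy(), per_batch.numpy()))   # True
```

Each 2x2 slice of `c` is `[[22, 28], [49, 64]]`, since both batch entries hold the same matrices.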
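Conclusion: the last two dimensions of a and b must be compatible for matrix multiplication, and the remaining (batch) dimensions must be equal. Sometimes, though, there is some implicit behavior: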
a=tf.constant([[[1,2,3],[4,5,6]],[[1,2,3],[4,5,6]]])
b=tf.constant([[1,2],[3,4],[5,6]])
print(a.shape)
print(b.shape) #3*2
c=tf.matmul(a,b)
print(c)
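This example, for instance, is valid. Conceptually, TensorFlow treats b as if it had been passed through expand_dims and tile and then stacked to match a (in other words, the batch dimensions are broadcast), so something interesting happens: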
a=tf.constant([[[1,2,3],[4,5,6]],[[1,2,3],[4,5,6]]])
b=tf.constant([[1,2],[3,4],[5,6]])#3*2
a=tf.constant([[[1,2,3],[4,5,6]],[[1,2,3],[4,5,6]]])
b=tf.constant([[[1,2],[3,4],[5,6]]]) #1*3*2
a=tf.constant([[[1,2,3],[4,5,6]],[[1,2,3],[4,5,6]]])
b=tf.constant([[[1,2],[3,4],[5,6]],[[1,2],[3,4],[5,6]]]) #2*3*2
These three forms are equivalent. However, when b's batch dimensions cannot be made to match a's by tiling, an error is raised:
a=tf.constant([[[1,2,3],[4,5,6]],[[1,2,3],[4,5,6]]])
b=tf.constant([[[1,2],[3,4],[5,6]],[[1,2],[3,4],[5,6]],[[1,2],[3,4],[5,6]]]) #3*3*2
c=tf.matmul(a,b) # raises the error below
tensorflow.python.framework.errors_impl.InvalidArgumentError: In[0] and In[1] must have compatible batch dimensions: [2,2,3] vs. [3,3,2] [Op:BatchMatMulV2]
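The equivalence and the failure case can be put side by side. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution. It writes out the expand_dims and tile step explicitly to show that it gives the same result as passing the 2-D b directly, and then triggers the batch-dimension error. The names `b2d`, `b_tiled` and `b_bad` are illustrative, and catching `ValueError` as well is a precaution in case a different TensorFlow version reports the shape mismatch that way.

```python
import numpy as np
import tensorflow as tf

a = tf.constant([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]])  # shape (2, 2, 3)
b2d = tf.constant([[1, 2], [3, 4], [5, 6]])                        # shape (3, 2)

# Make the broadcast explicit: add a batch axis, then repeat it to match a.
b_tiled = tf.tile(tf.expand_dims(b2d, axis=0), [a.shape[0], 1, 1])  # shape (2, 3, 2)

implicit = tf.matmul(a, b2d)      # b is broadcast over a's batch dimension
explicit = tf.matmul(a, b_tiled)  # b already has the matching batch dimension
same = np.array_equal(implicit.numpy(), explicit.numpy())
print(same)

# A batch size of 3 cannot be broadcast against a batch size of 2.
b_bad = tf.tile(tf.expand_dims(b2d, axis=0), [3, 1, 1])  # shape (3, 3, 2)
try:
    tf.matmul(a, b_bad)
    failed = False
except (tf.errors.InvalidArgumentError, ValueError) as err:
    failed = True
    print(type(err).__name__)
```

A batch dimension of size 1 (or a missing one) can be stretched to any size, which is why the 3*2 and 1*3*2 forms work while 3*3*2 does not.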