髙维张量tensor之间的乘法,均类似于二维矩阵的乘法,但需要满足最后两个维度符合矩阵乘法要求的维度,且其他维度均必须要相等的原则
import tensorflow as tf
import numpy as np
a = tf.constant(np.arange(1, 13, dtype=np.float32), shape=[2, 2, 3])
b = tf.constant(np.arange(1, 13, dtype=np.float32), shape=[2, 3, 2])
a1 = tf.constant(np.arange(1, 25, dtype=np.float32), shape=[2, 2, 2, 3])
b1 = tf.constant(np.arange(1, 25, dtype=np.float32), shape=[2, 2, 3, 2])
c1 = tf.matmul(a, b)
c2 = tf.matmul(a[0, :, :], b[0, :, :])
c3 = tf.matmul(a1, b1)
c4 = tf.matmul(a1[0, 0, :, :], b1[0, 0, :, :])
with tf.Session() as sess:
print(c1.shape)
print(sess.run(c1))
print(sess.run(c2))
print(np.all(sess.run(c1)[0] == sess.run(c2)))
print("-"*20)
print(c3.shape)
print(sess.run(c3))
print(sess.run(c4))
print(np.all(sess.run(c3)[0, 0] == sess.run(c4)))