While writing some code I ran into batched matrix multiplication and needed a NumPy implementation: A = [batch, height, width], B = [batch, width, height], C = AB = [batch, height, height]. In TensorFlow this is a single call to tf.matmul, but NumPy has no equally obvious three-dimensional matrix multiply at first glance (np.matmul and the @ operator do broadcast over the leading batch axis in NumPy >= 1.10; here I use np.einsum).
The idea behind batched matrix multiplication is simple: C[0] = A[0] B[0], C[1] = A[1] B[1], and so on. That is, multiply the matrices of A and B sample by sample along the batch dimension, then stack the results into C. In NumPy this takes a single line with np.einsum.
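The per-sample idea above can be written out as an explicit loop before collapsing it into einsum. A minimal sketch (variable names are mine):

```python
import numpy as np

batch, height, width = 2, 4, 2
A = np.random.rand(batch, height, width)
B = np.random.rand(batch, width, height)

# Multiply each pair of matrices along the batch dimension, then stack into C.
C = np.stack([A[i] @ B[i] for i in range(batch)])

print(C.shape)  # (2, 4, 4)
```

The loop makes the semantics explicit; einsum expresses the same contraction without a Python-level loop.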
A quick primer: np.einsum('ij, jk', A, B) multiplies matrix A by matrix B, equivalent to np.dot(A, B); this is the 2-D case.
For 3-D A and B, use np.einsum("ijk,ikn->ijn", A, B): "ijk" labels the axes of A, "ikn" labels the axes of B, and "ijn" defines the output axes. The index k appears in both inputs but not in the output, so it is summed over.
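For reference, newer NumPy versions ship a built-in batched multiply: np.matmul (the @ operator) broadcasts over leading batch dimensions, so the einsum form can be cross-checked against it. A quick sketch:

```python
import numpy as np

A = np.random.rand(2, 4, 2)
B = np.random.rand(2, 2, 4)

# Both forms compute the same per-sample products over the batch axis.
C_einsum = np.einsum("ijk,ikn->ijn", A, B)
C_matmul = A @ B  # same as np.matmul(A, B)

print(np.allclose(C_einsum, C_matmul))  # True
```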
import numpy as np
import tensorflow as tf

if __name__ == "__main__":
    batch_size = 2
    height = 4
    width = 2
    a = np.random.rand(batch_size, height, width)
    b = np.random.rand(batch_size, width, height)
    print("*************** a input {} ****************".format(a.shape))
    print(a)
    print("*************** b input {} ****************".format(b.shape))
    print(b)

    # TensorFlow 1.x graph-mode reference result
    aa = tf.placeholder(tf.float32, [batch_size, height, width])
    bb = tf.placeholder(tf.float32, [batch_size, width, height])
    cc = tf.matmul(aa, bb)  # must use the placeholders, not the NumPy arrays
    with tf.Session() as sess:
        out = sess.run(cc, feed_dict={aa: a, bb: b})
        print("*************** tf output {} ****************".format(out.shape))
        print(out)

    # NumPy equivalent via einsum
    xx = np.einsum("ijk,ikn->ijn", a, b)
    print("\n*************** numpy output {} ****************".format(xx.shape))
    print(xx)

    err_max2 = np.amax(np.absolute(np.subtract(out, xx)))
    print("\nmax error between tf and numpy: {}".format(err_max2))
# *************** a input (2, 4, 2) ****************
# [[[0.48151815 0.59571173]
# [0.54950679 0.07559809]
# [0.54483139 0.49344093]
# [0.66313407 0.7736222 ]]
#
# [[0.71144517 0.25567787]
# [0.82224508 0.87165079]
# [0.27935693 0.10498713]
# [0.39752717 0.62073428]]]
# *************** b input (2, 2, 4) ****************
# [[[0.43953543 0.51880854 0.35398745 0.59761315]
# [0.1339994 0.46699152 0.9858384 0.67810861]]
#
# [[0.91692865 0.63337183 0.52427425 0.77657735]
# [0.14262564 0.92203296 0.27971297 0.95416443]]]
# *************** tf output (2, 4, 4) ****************
# [[[0.2914693 0.52800806 0.75772688 0.69171883]
# [0.2516578 0.32039248 0.26904601 0.3796562 ]
# [0.30559349 0.51309591 0.67931649 0.66020494]
# [0.39513583 0.70531463 0.99740761 0.92089751]]
#
# [[0.68881068 0.68635275 0.4445088 0.79645093]
# [0.87825983 1.32447763 0.67489395 1.47023509]
# [0.27112423 0.2737384 0.1758259 0.31711724]
# [0.45303667 0.82411997 0.38204068 0.90099316]]]
#
# *************** numpy output (2, 4, 4) ****************
# [[[0.2914693 0.52800806 0.75772688 0.69171883]
# [0.2516578 0.32039248 0.26904601 0.3796562 ]
# [0.30559349 0.51309591 0.67931649 0.66020494]
# [0.39513583 0.70531463 0.99740761 0.92089751]]
#
# [[0.68881068 0.68635275 0.4445088 0.79645093]
# [0.87825983 1.32447763 0.67489395 1.47023509]
# [0.27112423 0.2737384 0.1758259 0.31711724]
# [0.45303667 0.82411997 0.38204068 0.90099316]]]
#
# max error between tf and numpy: 0.0