tf.tensordot是tensorflow中tensor矩阵相乘的API,可以进行任意维度的矩阵相乘
(1).tf.tensordot函数详细介绍如下:
tf.tensordot(
a,
b,
axes,
name=None
)
"""
Args:
a:类型为float32或者float64的tensor
b:和a有相同的type,即张量同类型,但不要求同维度
axes:可以为int32,也可以是list,为int32,表示取a的最后几个维度,与b的前面几个维度相乘,再累加求和,消去(收缩)相乘维度
为list,则是指定a的哪几个维度与b的哪几个维度相乘,消去(收缩)这些相乘的维度
name:操作命名
"""
(2).代码演示(举四维Tensor与三维Tensor相乘的例子)
1.获取一个shape=(2,1,3,2)的随机数矩阵a,以及一个shape=(2,3,1)的矩阵b
import tensorflow as tf
a = tf.constant([0,1,2,1,3,4,5,2,3,4,5,0],shape=[2,1,3,2])
b =tf.constant([1,3,2,3,1,2],shape=[2,3,1])
with tf.Session() as sess:
print("a的shape:",a.shape)
print("b的shape:",b.shape)
print("a的值:",sess.run(a))
print("b的值:",sess.run(b))