发现一个简单,或者说有点意思的理解tf.stack的方法
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
c0=tf.stack([t1, t2], 0)
c1=tf.stack([t1, t2], 1)
c2=tf.stack([t1, t2], 2)
with tf.Session() as sess:
print('c0')
print(sess.run(c0))
print('c1')
print(sess.run(c1))
print('c2')
print(sess.run(c2))
//结果如下:
c0
[[[ 1 2 3]
[ 4 5 6]]
[[ 7 8 9]
[10 11 12]]]
c1
[[[ 1 2 3]
[ 7 8 9]]
[[ 4 5 6]
[10 11 12]]]
c2
[[[ 1 7]
[ 2 8]
[ 3 9]]
[[ 4 10]
[ 5 11]
[ 6 12]]]
思路如下:
我们发现上文中t1和t2都是shape是为(2,3)的矩阵,那么把tf.stack是把两个矩阵拼接起来。关键问题是如何拼接,拼接的逻辑是如何的。
首先我们发现axis这个参数,这个参数你可以理解成从左到又第几层方括号里的元素作为最小单元。最小单元的意思就是在合并时候不可分割。例如axis=0,就是从左到右第1个括号里面的单元不可分割。就得到c0。如果axis=1,那就是从左到右第二个括号里面的内容不可分割,然后两个一一对应的组装,注意这里,是两个需要stack的矩阵,按照最小单元进行一一对应的组装得到新的矩阵。如果axis=2,那就是从左到右第三