import tensorflow as tf
x = tf.constant([1, 4])
print(x.get_shape()) # (2,)
y = tf.constant([2, 5])
z = tf.constant([3, 6])
axis_0 = tf.stack([x, y, z]) # [[1, 4], [2, 5], [3, 6]] (Pack along first dim.)
print(axis_0.get_shape()) # (3, 2)
axis_1 = tf.stack([x, y, z], axis=1) # [[1, 2, 3], [4, 5, 6]]
print(axis_1.get_shape()) # (2, 3)
tf.stack(values, axis=0, name="stack)
作用:将由秩为R的Tensor组成的列表堆叠成一个秩为R+1的Tensor
假如列表长度为N,也就是有N个秩为R且shape为(A, B,C)的Tensor,此时R其实是3,
如果axis=0,输出的Tensor的shape为(N, A, B, C);
如果axis=1,输出的Tensor的shape为(A,N,B,C)
在上述代码中 下x, y,z的秩为1,因此得到的axis_0和axis_1的秩为2,这里N=3,A=2,所以axis_0的shape为(3,2),axis_1的shape为(2, 3)
参数:
values:有相同形状和类型的Tensor对象的一个列表
axis: 堆叠沿着的轴
name:该操作的名字(可选)
返回值: 一个跟valuse有相同类型的堆叠的Tensor