tensorflow学习1.1
1.关于tf.stack()的使用方法
stack在英文中有堆积,堆成垛的意思。不难理解,tf.stack 是将一个tensor按维数堆积的函数。
tf.stack(values, axis=0, name=‘stack’)
values : [tensor1, tensor2, …] 一个tensor列表
axis :维数(堆积第几维)
举例:
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.stack([t1, t2]) => [ [[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]] ]
#默认axis为0维, stack后shape仍为(2, 2, 3)
tf.stack([t1, t2], axis = 1) => [ [[1, 2, 3], [7, 8, 9]], [[4, 5, 6], [10, 11, 12]] ]
#按1维stack后shape仍为(2, 2, 3)
tf.stack([t1, t2], axis = 2) => [ [[1, 7], [2, 8], [3, 9]], [[4, 10], [5, 11], [6, 12]] ]
#按2维stack后shape为(2, 3, 2)
2.关于tf.concat()的使用方法
concat 即 concatenate, 连接的意思。指将两个tensor按维数连接起来。有将原tensor降维的过程。
tf.concat(values, axis, name=‘concat’)
values: [tensor1, tensor2, …] 注意:A list of Tensor objects or a single Tensor
axis: 维数(确定第几维被连接)
举例:
t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
#[t1, t2]的shape 为(2, 2, 3), 按0维concat后变为(4, 3)
tf.concat([t1, t2], 1) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
#[t1, t2]的shape 为(2, 2, 3), 按1维concat后变为(2, 6)
注意:此时axis不能为2,因为2维已经是最底层了,不能再被降维了。
#tensor t3 with shape [2, 3]
#tensor t4 with shape [2, 3]
tf.shape(tf.concat([t3, t4], 0)) ==> [4, 3]
tf.shape(tf.concat([t3, t4], 1)) ==> [2, 6]