合并与分割
tf.concat #拼接
tf.split # 分割
tf.stack # 堆叠
tf.unstack #分割,unstack是split中的一种
concat 拼接
a = tf.ones([4, 35, 8])
b = tf.ones([2, 35, 8])
c = tf.concat([a, b], axis=0)
c.shape
a = tf.ones([4, 32, 8])
b = tf.ones([4, 3, 8])
tf.concat([a, b], axis=1).shape
[6,35,8]其实可以理解为[4+2,35,8],下一段代码结果同理
这里需要注意的是,为什么可以理解为相加,一是因为代码中的axis=0指定了轴,二是除指定轴之外的其他轴数字都相等
从而可以得到concat的一个条件:
需要拼接的维度可以数字不等,但其他维度的数字必须相等
concat合并,但不增加维度
这可以与下面的stack有所区分
stack堆叠
a = tf.ones([4, 35, 8])
b = tf.ones([4, 35, 8])
tf.concat([a, b], axis=-1).shape # out:TensorShape([4, 35, 16])
tf.stack([a, b], axis=-1).shape #out:TensorShape([4, 35, 8, 2])
tf.stack([a, b],axis=0).shape #out:TensorShape([2, 4, 35, 8])
在这三个代码的输出结果,可以很直观感受出代码的作用不同
concat实际是指定轴数字相加,其余维度数字不变(但必须相同),但维度的数目没有增加或减少
stack实际是在指定处加上一个维度,增加了一个维度
由于是复习,此时我又有了一个疑问,为什么可以直接增加,有没有条件?
翻开笔记,发现stack是有一个条件的,比concat更加严苛的条件
stack要求所有维度相等
最后总结一下两个“合并”代码总结一下:
concat需要拼接的维度可以数字不等,但其他维度的数字必须相等
concat合并,但不增加维度
stack要求所有维度数字都相等,才可以堆叠
stack堆叠,增加维度
unstack分割
a = tf.ones([4, 35, 8])
b = tf.ones([4, 35, 8])
c = tf.stack([a, b])
c.shape #TensorShape([2, 4, 35, 8])
aa,bb=tf.unstack(c,axis=0)
aa.shape,bb.shape #(TensorShape([4, 35, 8]), TensorShape([4, 35, 8]))
如果不明显,不太看得出来,换一个
res = tf.unstack(c,axis=3)
res[0].shape # TensorShape([2, 4, 35])
len(res) # 8
[2, 4, 35, 8]
在unstack中axis=0时为[4, 35, 8]
在unstack中axis=3时为[2, 4, 35]
可以很容易看出,当指定维度时,该维度直接“没了”,其实就是被分割掉了
axis=0时此时维度上数为2,就是被分割成了两段,每一段为[4, 35, 8],下一个同理(res长度可以看出为8)
split分割
res=tf.split(c,axis=3,num_or_size_splits=2)
len(res) # 2
与unstack相比其实就是多了一个num_or_size_splits
即
split可以指定打散的数量
res[0].shape #TensorShape([2, 4, 35, 4])
res=tf.split(c,axis=3,num_or_size_splits=[2,2,4])
res[0].shape #TensorShape([2, 4, 35, 2])
res[2].shape # TensorShape([2, 4, 35, 4])
res=tf.split(c,axis=3,num_or_size_splits=[2,2,4])
是指在第4个维度上,分成三个,分别为2,2,4
即[2, 4, 35, 8]被分成[2, 4, 35, 2],[2, 4, 35, 2],[2, 4, 35, 4]