合并和分割的接口:
- tf.concat 拼接
- tf.split 分割
- tf.stack 堆叠
- tf.unstack 分割
tf.concat
例如:
a 收集 1-4 班的成绩: [class1-4, students, scores]
b 收集 5-8 班的成绩: [class5-6, students, scores]
# concat
a = tf.ones([4,35,8])
b = tf.ones([2,35,8])
c = tf.concat([a,b],axis=0)
print(c.shape)#(6, 35, 8)
a = tf.ones([4,35,8])
b = tf.ones([4,3,8])
print(tf.concat([a, b], axis=1).shape)#(4, 38, 8)
tf.stack 可以增加维度
School1:[classes, students, scores]
School2:[classes, students, scores]
[schools, classes, students, scores]
# stack
a = tf.ones([4,35,8])
b = tf.ones([4,35,8])
print(tf.stack([a, b], axis=0).shape)#(2, 4, 35, 8)
合并总结
合并的两种方式 : concat \ stack
对于concat :
- 维度要相同
- 某一维度的数值可以不相同,但是其他的维度的数值必须相同
- 总的来说,是把两个维度的某一个维度增大
对于stack :
- 维度要完全相同
- 总的来说,是在原有的维度上添加一个维度
tf.unstack
# unstack
print(c.shape)#(2, 4, 35, 8)
aa,bb = tf.unstack(c,axis=0)
print(aa.shape,bb.shape)
res = tf.unstack(c,axis=3)#(4, 35, 8) (4, 35, 8)
print(res[0].shape,res[7].shape)#(2, 4, 35) (2, 4, 35)
tf.split
#split
print(c.shape)#(2, 4, 35, 8)
res = tf.unstack(c,axis=3)#(4, 35, 8) (4, 35, 8)
print(len(res))
res = tf.split(c,axis=3,num_or_size_splits=2)
print(len(res))
res = tf.split(c,axis=3,num_or_size_splits=[2,2,4])
print(res[0].shape)#(2, 4, 35, 2)
print(res[2].shape)#(2, 4, 35, 4)
总结分割
分割的方式有两种: unstack \ split
对于 unstack :
- 指定某一维度,会自动将某一维度分割为(这个维度对应的值n)n个
- unstack是stack的逆
对于 split:
- 可以通过参数 num_or_size_splits 来指定分割的块,它会自动均分
- 也可以给参数 num_or_size_splits 传递一个序列,使得矩阵分割成指定的块