一、合并
合并:将多个张量在某个维度上合并为一个张量。
以某学校班级成绩册数据为例,设张量A 保存了某学校1-5 号班级的成绩册,每个班级30 个学生,共8 门科目,则张量A的shape 为:[5,30,8];同样的方式,张量B 保存了剩下的6 个班级的成绩册,shape 为[6,30,8]。通过合并2 个成绩册,便可得到学校所有班级的成绩册张量C,shape 应为[11,30,8]。
张量的合并可以使用拼接(Concatenate)和堆叠(Stack)操作实现,拼接并不会产生新的维度,而堆叠会创建新维度。选择使用拼接还是堆叠操作来合并张量,取决于具体的场景是否需要创建新维度。
1. 拼接
通过tf.concat(tensors, axis)
完成拼接操作,其中tensors 保存了所有需要合并的张量List,axis 指定需要合并的维度。
"""
拼接
tf.concat(tensors, axis)
tensors:数据
axis:需要进行拼接的维度值,此时需要进行拼接的axis上的值是可以不同的,但其他维度需要是相同的
例如,axis=0表示为在第0维度进行拼接,axis=1表示为在第1维度进行拼接
"""
a = tf.ones([4, 25, 8])
b = tf.ones([2, 25, 8])
a_b = tf.concat([a, b], axis=0)
a_b_shape = a_b.shape # a_b_shape = [6, 25, 8]
a1 = tf.ones([4, 32, 8])
b1 = tf.ones([4, 3, 8])
a1_b1 = tf.concat([a1, b1], axis = 1)
a1_b1_shape = a1_b1.shape # a1_b1_shape = [4, 35, 8]
需要进行拼接的axis上的值是可以不同的,但其他维度需要是相同的
2. 堆叠
如果在合并数据时,希望创建一个新的维度,则需要使用tf.stack
操作
使用 tf.stack(tensors, axis)
可以合并多个张量tensors,其中axis 指定插入新维度的位置,axis 的用法与tf.expand_dims
的一致,当axis ≥ 0时,在axis 之前插入;当axis < 0时,在axis 之后插入新维度。例如shape 为[𝑏, 𝑐, ℎ, 𝑤]的张量,在不同位置通过stack 操作插入新维度,axis 参数对应的插入位置设置如图
"""
堆叠
如果在合并数据时,希望创建一个新的维度,则需要使用tf.stack 操作
tf.stack(tensors,axis)
tensors:数据
axis 指定插入新维度的位置,axis 的用法与tf.expand_dims 的一致,当axis ≥ 0时,在axis 之前插入;当axis < 0时,在axis 之后插入新维度
约束条件:要求被拼接的所有维度相同。
"""
c = tf.random.normal([5, 8])
d = tf.random.normal([5, 8])
c_d = tf.stack([c, d], axis=0)
c_d_shape = c_d.shape # c_d_shape = [2, 5, 8]
二、分割
合并操作的逆过程就是分割,将一个张量分拆为多个张量。
继续考虑成绩册的例子,我们得到整个学校的成绩册张量,shape 为[10,35,8],现在需要将数据在班级维度切割为10 个张量,每个张量保存了对应班级的成绩册。
通过 tf.split(x, axis, num_or_size_splits)
可以完成张量的分割操作,其中
x:待分割张量
axis:分割的维度索引号
num_or_size_splits:切割方案。当num_or_size_splits 为单个数值时,如10,表示切割为10 份;当num_or_size_splits 为List 时,每个元素表示每份的长度,如[2,4,2,2]表示切割为4 份,每份的长度分别为2,4,2,2
"""
分割
将一个张量拆分成多个张量
tf.split(tensor, axis, num_or_size_splits)
tensor:待分割的张量
axis:分割的维度索引号
num_or_size_splits:切割方案。当num_or_size_splits 为单个数值时,如10,表示切割为10 份;当num_or_size_splits 为List 时,每个元素表示每份的长度,如[2,4,2,2]表示切割为4 份,每份的长度分别为2,4,2,2
"""
e = tf.random.normal([5, 10, 10])
result = tf.split(e, axis=1, num_or_size_splits=[1, 2, 3, 4])
result_length = len(result)
print("result_length:", result_length)