效果:
将一个张量值value切分成子张量列表。
tf.split(
value, num_or_size_splits, axis=0, num=None, name='split'
)
如果num_or_size_splits为整数,则将张量value其沿维度axis拆分成大小为num_or_size_splits较小的张量。这就要求 value.shape[axis] 能被 num_or_size_splits整除。
如果num_or_size_splits为一维张量(或列表),则将value其拆分为 len(num_or_size_splits)个元素。第i个元素的形状与value的相同,除了沿维度轴的大小为num_or_size_splitting [i]。
参数含义
参数名称 | 具体含义 |
---|---|
value | 需要切分的张量 |
num_or_size_splits | 表示沿轴分割数的整数或一维整数张量或 包含沿轴每个输出张量大小的Python列表。如果是标量,则必须均匀分割value.shape[axis];否则,沿拆分轴的大小之和必须与值的大小之和匹配。 |
axis | 一个整数或int32类型张量。表示切分的维度。必须在[rank(value), rank(value)]范围内。默认值为0。 |
num | 可选,用于指定,当不能从size_split的形状推断输出的数量。 |
name | 操作的名称(可选)。 |
返回值
如果num_or_size_splitting是标量,则返回num_or_size_splitting个张量对象的列表;如果num_or_size_splitting是一维张量,则返回num_or_size_splitting.get_shape[0]个分割value得到的张量对象。
实例
value = tf.Variable(tf.random.uniform([5, 6], -1, 1))
#沿着axis=1将value切分为3个张量
s0, s1 = tf.split(value, num_or_size_splits=2, axis=1)
s0的值:
#将value按照尺寸[1,2,1]切分,在轴axis=1上
split0, split1, split2 = tf.split(value, [1, 2, 1], 1)
split0
split0:
split1: