tf.split
split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)
参数:
value
:要分割的 Tensor。num_or_size_splits
:如果为一个标量,那么被分割的Tensor的第axis维度的值必须能被num_or_size_splits整除;否则沿分割维度的大小总和必须与该 value 的第axis维度的值相匹配。axis
:A 0-D int32 Tensor;表示分割的尺寸;必须在[-rank(value), rank(value))范围内;默认为0。num
:可选的,用于指定无法从 size_splits 的形状推断出的输出数。name
:操作的名称(可选)。
返回值:
- 如果 num_or_size_splits 是标量,返回 num_or_size_splits个Tensor对象;
- 如果 num_or_size_splits 是一维张量,则返回由 value 分割产生的 num_or_size_splits.get_shape[0] 个Tensor对象。
import tensorflow as tf
import numpy as np
x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
a = tf.constant(x, shape=[2,3, 5])
b1 = tf.split(a, 3, axis=1) # num_or_size_splits为一个标量
b2 = tf.split(a, [1,1,3], axis=2) #num_or_size_splits是一个一维的tensor
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(a))
print()
print(sess.run(b1))
print()
print(sess.run(b2))
输出
[[[ 1 2 3 4 5]
[ 6 7 8 9 10]
[11 12 13 14 15]]
[[16 17 18 19 20]
[20 20 20 20 20]
[20 20 20 20 20]]]
[array([[[ 1, 2, 3, 4, 5]],
[[16, 17, 18, 19, 20]]]),
array([[[ 6, 7, 8, 9, 10]],
[[20, 20, 20, 20, 20]]]),
array([[[11, 12, 13, 14, 15]],
[[20, 20, 20, 20 20]]])]
[array([[[ 1],
[ 6],
[11]],
[[16],
[20],
[20]]]),
array([[[ 2],
[ 7],
[12]],
[[17],
[20],
[20]]]),
array([[[ 3, 4, 5],
[ 8, 9, 10],
[13, 14, 15]],
[[18, 19, 20],
[20, 20, 20],
[20, 20, 20]]])]
对于b1来说,a的第1个维度是3,这个维度被均匀分割成3份,故每个分割后的tensor对象的维度是
[
2
,
1
,
5
]
[2, 1, 5]
[2,1,5]
对于b2来说,a的第2个维度是5,这个维度按照[1,1,3]的方式分割(1+1+3=5)成3个tensor,每个tensor的shape分别是
[
2
,
3
,
1
]
,
[
2
,
3
,
1
]
,
[
2
,
3
,
3
]
[2, 3, 1], [2, 3, 1],[2, 3, 3]
[2,3,1],[2,3,1],[2,3,3]。