`tf.split(value, num_or_size_splits, axis=0, num=None, name='split')`

Splits a tensor into sub-tensors.
If `num_or_size_splits` is an integer type, `num_split`, then splits `value`
along dimension `axis` into `num_split` smaller tensors.
Requires that `num_split` evenly divides `value.shape[axis]`.
In short: if `num_or_size_splits` is an integer, `value` is cut into that many equal pieces.
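To make the even-split rule concrete, here is a plain-Python sketch that only computes the output shapes `tf.split` would produce for an integer `num_or_size_splits`. The helper `even_split_shapes` is hypothetical and does not call TensorFlow, and its `ValueError` stands in for whatever error TensorFlow itself raises.

```python
def even_split_shapes(shape, num_split, axis=0):
    """Return the shapes of the num_split pieces when splitting evenly along axis.

    Hypothetical helper that mirrors the documented rule; it does not call TensorFlow.
    """
    axis = axis % len(shape)  # map a negative axis such as -1 to its non-negative index
    dim = shape[axis]
    if num_split <= 0 or dim % num_split != 0:
        raise ValueError(
            f"num_split={num_split} does not evenly divide dimension {axis} of size {dim}")
    piece = list(shape)
    piece[axis] = dim // num_split  # only the split axis shrinks
    return [list(piece) for _ in range(num_split)]


print(even_split_shapes([5, 30], 3, axis=1))  # [[5, 10], [5, 10], [5, 10]]
```

Splitting `[5, 30]` into 4 pieces along axis 1 would raise here, because 30 is not divisible by 4.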
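If `num_or_size_splits` is not an integer type, it is presumed to be a Tensor `size_splits`, and `value` is split into `len(size_splits)` pieces.
In other words, you pass a list of the sizes you want the pieces to have along `axis`; these sizes must add up to `value.shape[axis]`.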
The shape of the `i`-th piece is the same as the shape of `value` except along dimension `axis`, where the size is `size_splits[i]`.
So only the size along `axis` changes; every other dimension stays the same after the split.
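The size-list rule can be sketched the same way. The plain-Python helper `sized_split_shapes` below is hypothetical (not part of TensorFlow); it copies the shape of `value`, swaps in `size_splits[i]` along `axis` for each piece, and rejects a list whose sizes do not add up to the length of that dimension, which TensorFlow also requires.

```python
def sized_split_shapes(shape, size_splits, axis=0):
    """Return the shape of each piece when splitting along axis into the given sizes.

    Hypothetical helper that mirrors the documented rule; it does not call TensorFlow.
    """
    axis = axis % len(shape)  # map a negative axis such as -1 to its non-negative index
    if sum(size_splits) != shape[axis]:
        raise ValueError(
            f"sizes {list(size_splits)} sum to {sum(size_splits)}, "
            f"but dimension {axis} has size {shape[axis]}")
    shapes = []
    for size in size_splits:
        piece = list(shape)
        piece[axis] = size  # the i-th piece takes size_splits[i] along axis
        shapes.append(piece)
    return shapes


print(sized_split_shapes([5, 30], [4, 15, 11], axis=1))  # [[5, 4], [5, 15], [5, 11]]
```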
Example:

```python
import tensorflow as tf

value = tf.zeros([5, 30])  # 'value' is a tensor with shape [5, 30] (zeros as a stand-in)

# Split 'value' into 3 tensors with sizes [4, 15, 11] along dimension 1
split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0)  # [5, 4]
tf.shape(split1)  # [5, 15]
tf.shape(split2)  # [5, 11]

# Split 'value' into 3 tensors along dimension 1
split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
tf.shape(split0)  # [5, 10]
```
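Here is an example of my own:

```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

tensor = [[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]]

with tf.Session() as sess:
    # Split along axis 1 (columns) into 3 equal pieces
    tensor1, tensor2, tensor3 = tf.split(tensor, num_or_size_splits=3, axis=1)
    print(tensor1.eval())
    print('--------------')
    # Split along axis 0 (rows) into 3 equal pieces
    tensor1, tensor2, tensor3 = tf.split(tensor, num_or_size_splits=3, axis=0)
    print('--------------')
    print(tensor1.eval())
    # Pass a list of sizes instead: 1 row, then 2 rows
    tensor1, tensor2 = tf.split(tensor, num_or_size_splits=[1, 2], axis=0)
    print('--------------')
    print(tensor2.eval())
    # I'm not yet sure what the num parameter does, so it is left out here.
```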
Output:

```
[[1]
 [4]
 [7]]
--------------
--------------
[[1 2 3]]
--------------
[[4 5 6]
 [7 8 9]]
```
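To see what `tf.split` is doing without a TensorFlow session, here is a plain-Python sketch that applies the same rules to a 2-D nested list and reproduces the three results above. `split_2d` is a hypothetical helper written only for illustration: it supports axes 0 and 1 only and returns Python lists rather than tensors.

```python
def split_2d(rows, num_or_size_splits, axis=0):
    """Split a 2-D nested list along axis 0 (rows) or axis 1 (columns).

    Hypothetical plain-Python model of tf.split; TensorFlow is not used.
    """
    if axis not in (0, 1):
        raise ValueError("this sketch only handles axis 0 or 1")
    length = len(rows) if axis == 0 else len(rows[0])
    if isinstance(num_or_size_splits, int):
        # Integer: equal pieces, which requires an even division
        if length % num_or_size_splits != 0:
            raise ValueError("num_split must evenly divide the split dimension")
        sizes = [length // num_or_size_splits] * num_or_size_splits
    else:
        # List of sizes: they must add up to the split dimension
        sizes = list(num_or_size_splits)
        if sum(sizes) != length:
            raise ValueError("size_splits must sum to the split dimension")
    pieces, start = [], 0
    for size in sizes:
        if axis == 0:
            pieces.append(rows[start:start + size])  # a block of whole rows
        else:
            pieces.append([row[start:start + size] for row in rows])  # a block of columns
        start += size
    return pieces


tensor = [[1, 2, 3],
          [4, 5, 6],
          [7, 8, 9]]
print(split_2d(tensor, 3, axis=1)[0])       # [[1], [4], [7]]
print(split_2d(tensor, 3, axis=0)[0])       # [[1, 2, 3]]
print(split_2d(tensor, [1, 2], axis=0)[1])  # [[4, 5, 6], [7, 8, 9]]
```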