在tensorflow 的代码里经常看到tf.split()这个函数,今天来扒一扒这个API的用法
tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)
1
2
3
4
5
6
7
8
Splits a tensor into sub tensors.
If num_or_size_splits is an integer type, num_split, then splits value along dimension axis into num_split smaller tensors. Requires that num_split evenly divides value.shape[axis].
If num_or_size_splits is not an integer type, it is presumed to be a Tensor size_splits, then splits value into len(size_splits) pieces. The shape of the i-th piece has the same size as the value except along dimension axis where the size is size_splits[i].
根据官方文档的说法这个函数的用途简单说就是把一个张量划分成几个子张量。
value:准备切分的张量
num_or_size_splits:准备切成几份
axis : 准备在第几个维度上进行切割
其中分割方式分为两种
1. 如果num_or_size_splits 传入的 是一个整数,那直接在axis=D这个维度上把张量平均切分成几个小张量
2. 如果num_or_size_splits 传入的是一个向量(这里向量各个元素的和要跟原本这个维度的数值相等)就根据这个向量有几个元素分为几项)
举个例子
# 张量为(5, 30)
# 这个时候5是axis=0, 30是axis=1,如果要在axis=1这个维度上把这个张量拆分成三个子张量
#传入向量时
split0, split1, split2 = tf.split(value, [4, 15, 11], 1)
tf.shape(split0) # [5, 4]
tf.shape(split1) # [5, 15]
tf.shape(split2) # [5, 11]
# 传入整数时
split0, split1, split2 = tf.split(value, num_or_size_splits=3, axis=1)
tf.shape(split0) # [5, 10]