切分张量
tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name='split'
)
value:待切分的张量
num_or_size_splits:切分的个数
axis:沿着哪个维度切分
其中分割方式分为两种
- 如果num_or_size_splits 传入的是一个整数,那直接在axis=D这个维度上把张量平均切分成几个小张量 ;
- 如果num_or_size_splits 传入的是一个向量(传入向量各个元素和与原本张量这个维度的数值相等),就根据传入的向量有几个元素将原来张量切分为几项 。
# 张量为(10, 40)
# 这个时候axis=0维度是10,axis=1维度是40,如果要在axis=1维度上把这个张量拆分成四个子张量
#传入向量时
split0, split1, split2, split3 = tf.split(value, [9, 15, 10, 6], 1) # 9+15+10+6=40
tf.shape(split0) # [10, 9]
tf.shape(split1) # [10, 15]
tf.shape(split2) # [10, 10]
tf.shape(split3) # [10, 6]
# 传入整数时
split0, split1, split2 ,split3 = tf.split(value, num_or_size_splits=4, axis=1)
tf.shape(split0) # [10, 10]