一、环境
TensorFlow API r1.14(rc)
CUDA 9.0 V9.0.176
Python 3.6.3
二、官方说明
把张量分解成子张量
https://tensorflow.google.cn/versions/r1.14/api_docs/python/tf/split
tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name=‘split’
)
参数:
value:要分割的张量
num_or_size_splits:可以是整数(指定把张量划分为几分,需要注意的是必须能整除 value.shape[axis])、一维张量或Python 列表(指定划分输出的每一个子张量的大小,需要注意的是张量或列表中的元素之和需要等于 value 要划分的维度的大小)
axis:整数或标量 int32 张量,指定沿那个维度分割张量,数值必须在 [-rank(value), rank(value)] 之间,默认是 0。
num:可选参数。当不能从 num_or_size_splits 的形状推断输出的数量时,通过该参数指定
name:可选参数。操作的名称
三、实例
>>> import tensorflow as tf
>>> import numpy as np
>>> tf.enable_eager_execution()
>>> data = np.random.random((5,10))
>>> data_tensor = tf.constant(data)
# 标量
>>> splited_interger = 5
>>> split0_0, split0_1, split0_2, split0_3, split0_4 = tf.split(data_tensor, splited_interger, -1)
>>> tf.shape(split0_0)
<tf.Tensor: id=9, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split0_1)
<tf.Tensor: id=11, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split0_2)
<tf.Tensor: id=13, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split0_3)
<tf.Tensor: id=15, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split0_4)
<tf.Tensor: id=17, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
# Python 列表
>>> splited_list = [2,3,5]
>>> split1_0, split1_1, split1_2 = tf.split(data_tensor, splited_list, -1)
>>> tf.shape(split1_0)
<tf.Tensor: id=24, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split1_1)
<tf.Tensor: id=26, shape=(2,), dtype=int32, numpy=array([5, 3], dtype=int32)>
>>> tf.shape(split1_2)
<tf.Tensor: id=28, shape=(2,), dtype=int32, numpy=array([5, 5], dtype=int32)>
# 1 维张量
>>> splited_tensor = tf.constant(splited_list)
>>> split2_0, split2_1, split2_2 = tf.split(data_tensor, splited_tensor, -1)
>>> tf.shape(split2_0)
<tf.Tensor: id=38, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split2_1)
<tf.Tensor: id=40, shape=(2,), dtype=int32, numpy=array([5, 3], dtype=int32)>
>>> tf.shape(split2_2)
<tf.Tensor: id=42, shape=(2,), dtype=int32, numpy=array([5, 5], dtype=int32)>