import tensorflow as tf
import numpy as np
a = np.reshape(range(24),(4,2,3))
print(a)
b = tf.split(a,2,1)
print(b)
c = tf.squeeze(b[0])
print(c)
输出结果:
[[[ 0 1 2] [ 3 4 5]] [[ 6 7 8] [ 9 10 11]] [[12 13 14] [15 16 17]] [[18 19 20] [21 22 23]]] [<tf.Tensor: shape=(4, 1, 3), dtype=int64, numpy= array([[[ 0, 1, 2]], [[ 6, 7, 8]], [[12, 13, 14]], [[18, 19, 20]]])>, <tf.Tensor: shape=(4, 1, 3), dtype=int64, numpy= array([[[ 3, 4, 5]], [[ 9, 10, 11]], [[15, 16, 17]], [[21, 22, 23]]])>] tf.Tensor( [[ 0 1 2] [ 6 7 8] [12 13 14] [18 19 20]], shape=(4, 3), dtype=int64)
经过分割后的b是一个list,包含两个数组元素,每个元素的shape都是4 * 1 * 3,将第一个数组元素经过tf.squeeze()函数处理,可以看到维度变为4 * 3,去掉了维数为1的维度。
截取自知乎:https://zhuanlan.zhihu.com/p/52087724