import numpy as np
import tensorflow as tf
x = np.arange(0,50)
x = x.reshape((5, 10))
print(x.shape) #(5, 10)
split1, split2, split3 = tf.split(x, num_or_size_splits=[2, 3, 5],
axis = 1)
print(split1.shape, split2.shape, split3.shape, sep ='\n ')
(5, 10)
(5, 2)
(5, 3)
(5, 5)
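The following sketch checks what the split above actually produces. It reuses the same `x`, and its only assumption is TensorFlow 2.x with eager execution (the default). The sizes in the list must add up to the length of the split axis, and each piece holds consecutive columns. Concatenating the pieces along the same axis therefore gives back the original array.

```python
import numpy as np
import tensorflow as tf

x = np.arange(0, 50).reshape((5, 10))
sizes = [2, 3, 5]

# The sizes must add up to the length of the split axis: 2 + 3 + 5 = 10
assert sum(sizes) == x.shape[1]

split1, split2, split3 = tf.split(x, num_or_size_splits=sizes, axis=1)

# split1 holds columns 0-1, split2 columns 2-4, split3 columns 5-9
print(split1.numpy()[0])  # first row of split1: [0 1]

# Concatenating along the same axis restores the original array
restored = tf.concat([split1, split2, split3], axis=1)
print(np.array_equal(restored.numpy(), x))  # True
```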
参数:
num_or_size_split: 向量或数值——裁剪的份数
axis :0——行 1——列
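The sketch below covers the two cases the parameter list describes but the example above does not: an integer `num_or_size_splits` and `axis=0`. It reuses the same 5 x 10 array and assumes TensorFlow 2.x in eager mode. The names `left`, `right` and `rows` are only illustrative.

```python
import numpy as np
import tensorflow as tf

x = np.arange(0, 50).reshape((5, 10))

# Integer: split the 10 columns into 2 equal pieces of 5 columns each
left, right = tf.split(x, num_or_size_splits=2, axis=1)
print(left.shape, right.shape)  # (5, 5) (5, 5)

# axis=0 splits by rows: the 5 rows become pieces of 1 and 4 rows
rows = tf.split(x, num_or_size_splits=[1, 4], axis=0)
print([r.shape.as_list() for r in rows])  # [[1, 10], [4, 10]]
```

An integer only works when it divides the axis length evenly. For example, 10 columns cannot be split into 3 equal pieces, so uneven splits need an explicit list of sizes.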