# (None, 49, 49, 16, 16, 3) --->49个(None,49,16,16,3) 拆分成list
x = Lambda(tf.split, arguments={'axis': 1, 'num_or_size_splits': 49})(x)
如果split()想要用到keras中,就必须套入Lambda,作为神经网络的一层出现。
lambda中文文档
例:? x 2 x 32 x 32 切成 ? x 2 x 4 x 8 x 32,再转置成4 x ? x 2 x 8 x32
1 用如下 split() 进行切割
x = Lambda(tf.split, arguments={'axis': 2, 'num_or_size_splits': 4})(input_tensor)
# x.shape = (4,?,2,8,32)
2 tf.transpose() 转置函数,第一个参数是tensor,第二个参数是axis的顺序
x = Lambda(tf.transpose, arguments={'perm': [1,2,0,3,4]})(x)
# x.shape = (?,2,4,8,32)
3 还原成最初的样子以及顺序
# 直接reshape
x = Reshape((2, 32, 32,))(x)