tf.split()在keras中切割张量

TensorFlow中的split函数

1. tf.split()函数

tf.split(
    value,
    num_or_size_splits,
    axis=0,
    num=None,
    name='split'
)

value:传入的tensor,就是传入的矩阵或高维矩阵
num_or_size_splits:切成几份,比如输入的是2,那么会在传入的axis上切成2份。举个栗子:2 x 32 x 32的tensor,在axis = 0上切2份,就是两个1 x 32 x 32,具体顺序如何,一会再讨论。
axis:在哪个轴上进行切割,比如2 x 32 x 32,依次对应axis = 0 , 1 和 2

2. tf.split()在keras中切割张量

Lambda层

keras.layers.Lambda(function, output_shape=None, mask=None, arguments=None)

如果split()想要用到keras中,就必须套入Lambda,作为神经网络的一层出现。具体写法如下:

x = Lambda(tf.split, arguments={'axis': 2, 'num_or_size_splits': 4})(input_tensor)

Lambda层的第一个参数是要作为层出现的函数,第二个参数形式为字典,指的是要传入前面函数的参数,其中key是函数API中定义的参数名称,value是要传入的参数。
这里需要注意,tensor这个参数作为Lambda的层的输入,写在最后的(input_tensor)里

Lambda层的详细介绍见keras中文文档

切割张量的排列顺序

以 input_tensor.shape = (2, 32, 32) 为例,输入到网络的 shape 应该是
(?, 2, 32, 32)
用如下 split() 进行切割

x = Lambda(tf.split, arguments={'axis': 2, 'num_or_size_splits': 4})(input_tensor)

x 的形状应该是

# x.shape = (4,?,2,8,32)

即 split() 会将切割的片段放到axis = 0的位置

我的理解是split()先将 ? x 2 x 32 x 32 切成 ? x 2 x 4 x 8 x 32,再转置成4 x ? x 2 x 8 x32
为什么有如上的理解,和下面要说的还原有关

它们具体是如何排列的,只需要 print 一下 tensor 就可以了!

调整顺序

tf.transpose() 转置函数,第一个参数是tensor,第二个参数是axis的顺序

x = Lambda(tf.transpose, arguments={'perm': [1,2,0,3,4]})(x)
# x.shape = (?,2,4,8,32)

还原

还原成最初的样子以及顺序

# 直接reshape
x = Reshape((2, 32, 32,))(x)

这里由于调整顺序部分已经将tensor的顺序调整为2 x 4 x 8 x 32了,即我理解的split()切割的第一步,因此只需要reshape就可以还原回去了。

注:不知道是不是可以写成 2 x 8 x 4 x 32(也许可以…)

### TensorFlow `tf.split` 的使用方法 在 TensorFlow 中,`tf.split` 是用于将张量沿指定轴分割成多个子张量的函数。以下是关于该函数的具体说明以及示例。 #### 函数定义 `tf.split(value, num_or_size_splits, axis=0, num=None, name='split')` - **value**: 要被分割的输入张量。 - **num_or_size_splits**: 如果是一个整数,则表示要将张量均匀分成多少份;如果是一个列表或一维数组,则表示每一份的大小[^1]。 - **axis**: 表示沿着哪个维度进行分割,默认为 0。 - **name**: 运算名称(可选参数),默认为 'split'。 #### 示例代码 下面提供几个具体的例子来展示如何使用 `tf.split`: ```python import tensorflow as tf # 创建一个简单的二维张量 tensor = tf.constant([[1, 2, 3], [4, 5, 6]]) # 将张量按照第0轴分为两部分 result_1 = tf.split(tensor, num_or_size_splits=2, axis=0) # 打印结果 print(result_1) # 输出: [<tf.Tensor: shape=(1, 3), ...>, <tf.Tensor: shape=(1, 3), ...>] # 将张量按照第1轴分为三部分 result_2 = tf.split(tensor, num_or_size_splits=[1, 1, 1], axis=1) # 打印结果 print(result_2) # 输出: [<tf.Tensor: shape=(2, 1), ...>, ..., <tf.Tensor: shape=(2, 1), ...>] ``` 上述代码展示了两种不同的分割方式:一种是通过指定份数让 TensorFlow 自动计算每一部分的尺寸;另一种则是手动设定各部分的尺寸。 #### 注意事项 当使用 `num_or_size_splits` 参数时需要注意: - 若传入的是整数值 n,则原张量会被平均划分为 n 个形状相同的子张量; - 若传入的是长度为 m 的序列 s,则会依据此序列中的每一个元素所代表的数量依次切分原始数据,并且这些数量之总和应等于目标维度上的全部条目数目[^3]。 此外,在实际应用过程中还需确保待切割方向上能够满足所需划分的要求,否则将会抛出错误提示。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值