TensorFlow 中张量切分操作 tf.split 使用实例

一、环境

TensorFlow API r1.14(rc)

CUDA 9.0 V9.0.176

Python 3.6.3

二、官方说明

把张量分解成子张量

https://tensorflow.google.cn/versions/r1.14/api_docs/python/tf/split

tf.split(
value,
num_or_size_splits,
axis=0,
num=None,
name=‘split’
)

参数:
value:要分割的张量
num_or_size_splits:可以是整数(指定把张量划分为几分,需要注意的是必须能整除 value.shape[axis])、一维张量或Python 列表(指定划分输出的每一个子张量的大小,需要注意的是张量或列表中的元素之和需要等于 value 要划分的维度的大小)
axis:整数或标量 int32 张量,指定沿那个维度分割张量,数值必须在 [-rank(value), rank(value)] 之间,默认是 0。
num:可选参数。当不能从 num_or_size_splits 的形状推断输出的数量时,通过该参数指定
name:可选参数。操作的名称

三、实例

>>> import tensorflow as tf
>>> import numpy as np
>>> tf.enable_eager_execution()
>>> data = np.random.random((5,10))
>>> data_tensor = tf.constant(data)

# 标量
>>> splited_interger = 5
>>> split0_0, split0_1, split0_2, split0_3, split0_4 = tf.split(data_tensor, splited_interger, -1)

>>> tf.shape(split0_0)
<tf.Tensor: id=9, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split0_1)
<tf.Tensor: id=11, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split0_2)
<tf.Tensor: id=13, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split0_3)
<tf.Tensor: id=15, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split0_4)
<tf.Tensor: id=17, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>

# Python 列表
>>> splited_list = [2,3,5]
>>> split1_0, split1_1, split1_2 = tf.split(data_tensor, splited_list, -1)
>>> tf.shape(split1_0)

<tf.Tensor: id=24, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split1_1)
<tf.Tensor: id=26, shape=(2,), dtype=int32, numpy=array([5, 3], dtype=int32)>
>>> tf.shape(split1_2)
<tf.Tensor: id=28, shape=(2,), dtype=int32, numpy=array([5, 5], dtype=int32)>

# 1 维张量
>>> splited_tensor = tf.constant(splited_list)
>>> split2_0, split2_1, split2_2 = tf.split(data_tensor, splited_tensor, -1)
>>> tf.shape(split2_0)

<tf.Tensor: id=38, shape=(2,), dtype=int32, numpy=array([5, 2], dtype=int32)>
>>> tf.shape(split2_1)
<tf.Tensor: id=40, shape=(2,), dtype=int32, numpy=array([5, 3], dtype=int32)>
>>> tf.shape(split2_2)
<tf.Tensor: id=42, shape=(2,), dtype=int32, numpy=array([5, 5], dtype=int32)>
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

csdn-WJW

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值