【五分钟学习Tensorflow系列】tf.strided_slice()原理及应用

文章同步更新在公众号 AIPlayer,欢迎扫码关注,共同进步

目录

一、原理

二、使用场景


一、原理

1、函数原型

tf.strided_slice(input_, begin, end, strides=None,
                 begin_mask=0, end_mask=0, ellipsis_mask=0,
                 new_axis_mask=0, shrink_axis_mask=0, 
                 var=None, name=None)

    这个运算从给定的 input_ 张量中提取一个尺寸 (end-begin)/stride 的片段,从 begin 片段指定的位置开始,以步长 stride 添加索引,直到所有维度都不小于 end,这里的 stride 可以是负值,表示反向切片。

这里需要注意的两点是:

  • 取值区间为 [ begin, end ),左闭右开

  • 返回的张量各维度大小的计算:abs(end - begin) / stride

2、举例说明

    假设需要输入的张量为:​​​​​​​

input = tf.constant([[[1,2,3], [4,5,6]],
                     [[7,8,9], [10,11,12]],
                     [[13,14,15], [16,17,18]]])

    下面列举了在不同维度上使用不同步长的分片结果:

(1)示例1​​​​​​​

tf.strided_slice(input, 
                 begin=[1, 0, 0], 
                 end=[2, 1, 3], 
                 strides=[1, 1, 1])

  • 对于第0维,begin是1,end是2,begin+stride=2,2大于等于end的第0维,所以不用继续加stride,取值区间为[1,2),第0维返回索引为1的元素,即[[7,8,9], [10,11,12]]

  • 对于第1维,取值区间为[0,1),在第0维结果的基础上,第1维返回索引为0的元素,即[7,8,9]

  • 对于第2维,取值区间为[0,3),在第1维结果的基础上,第2维返回索引为0,1,2的元素,即[7,8,9]

所以最终结果为[[[7,8,9]]]

(2)示例2​​​​​​​

tf.strided_slice(input, 
                 begin=[1, 0, 0], 
                 end=[2, 2, 3], 
                 strides=[1, 1, 1])

  • 对于第0维,取值区间为[1,2),返回第0维索引为1的元素,即[[7,8,9], [10,11,12]]

  • 对于第1维,取值区间为[0,2),在第0维结果的基础上,返回第1维索引为0,1的元素,即[7, 8, 9], [10, 11, 12]

  • 对于第2维,取值区间为[0,3),在第1维结果的基础上,返回第2维索引为0,1,2的元素,即[7,8,9], [10,11,12]

所以最终结果[[[7,8,9], [10,11,12]]]

(3)示例3​​​​​​​

tf.strided_slice(input, 
                 begin=[1, -1, 0], 
                 end=[2, -3, 3], 
                 strides=[1, -1, 1])

  • 对于第0维,取值区间[1,2),第0维返回索引为1的元素,即[[7,8,9], [10,11,12]]

  • 对于第1维,取值区间[-1,-3),在第0维结果的基础上,返回第1维索引为-1,-2的元素,表示倒序取值,即[10,11,12],[7,8,9]

  • 对于第2维,取值区间为[0,3),在第1维结果的基础上,返回第2维索引为0,1,2的元素,即[10,11,12],[7,8,9]

所以最终结果为[[[10,11,12],[7,8,9]]]

二、使用场景

    tf.strided_slice()通常会使用在序列化模型的数据处理中,比如在 seq2seq 模型中,需要把目标数据切片再送入到decoder网络中:​​​​​​​

ending = tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1])
dec_input = tf.concat([tf.fill([batch_size, 1], target_letter_to_int['<s>']), 
                       ending], 1)

 

文章同步更新在公众号 AIPlayer,欢迎扫码关注,共同进步

往期内容:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值