文章同步更新在公众号 AIPlayer,欢迎扫码关注,共同进步
目录
一、原理
1、函数原型
tf.strided_slice(input_, begin, end, strides=None,
begin_mask=0, end_mask=0, ellipsis_mask=0,
new_axis_mask=0, shrink_axis_mask=0,
var=None, name=None)
这个运算从给定的 input_ 张量中提取一个尺寸 (end-begin)/stride 的片段,从 begin 片段指定的位置开始,以步长 stride 添加索引,直到所有维度都不小于 end,这里的 stride 可以是负值,表示反向切片。
这里需要注意的两点是:
-
取值区间为 [ begin, end ),左闭右开
-
返回的张量各维度大小的计算:abs(end - begin) / stride
2、举例说明
假设需要输入的张量为:
input = tf.constant([[[1,2,3], [4,5,6]],
[[7,8,9], [10,11,12]],
[[13,14,15], [16,17,18]]])
下面列举了在不同维度上使用不同步长的分片结果:
(1)示例1
tf.strided_slice(input,
begin=[1, 0, 0],
end=[2, 1, 3],
strides=[1, 1, 1])
-
对于第0维,begin是1,end是2,begin+stride=2,2大于等于end的第0维,所以不用继续加stride,取值区间为[1,2),第0维返回索引为1的元素,即[[7,8,9], [10,11,12]]
-
对于第1维,取值区间为[0,1),在第0维结果的基础上,第1维返回索引为0的元素,即[7,8,9]
-
对于第2维,取值区间为[0,3),在第1维结果的基础上,第2维返回索引为0,1,2的元素,即[7,8,9]
所以最终结果为[[[7,8,9]]]
(2)示例2
tf.strided_slice(input,
begin=[1, 0, 0],
end=[2, 2, 3],
strides=[1, 1, 1])
-
对于第0维,取值区间为[1,2),返回第0维索引为1的元素,即[[7,8,9], [10,11,12]]
-
对于第1维,取值区间为[0,2),在第0维结果的基础上,返回第1维索引为0,1的元素,即[7, 8, 9], [10, 11, 12]
-
对于第2维,取值区间为[0,3),在第1维结果的基础上,返回第2维索引为0,1,2的元素,即[7,8,9], [10,11,12]
所以最终结果[[[7,8,9], [10,11,12]]]
(3)示例3
tf.strided_slice(input,
begin=[1, -1, 0],
end=[2, -3, 3],
strides=[1, -1, 1])
-
对于第0维,取值区间[1,2),第0维返回索引为1的元素,即[[7,8,9], [10,11,12]]
-
对于第1维,取值区间[-1,-3),在第0维结果的基础上,返回第1维索引为-1,-2的元素,表示倒序取值,即[10,11,12],[7,8,9]
-
对于第2维,取值区间为[0,3),在第1维结果的基础上,返回第2维索引为0,1,2的元素,即[10,11,12],[7,8,9]
所以最终结果为[[[10,11,12],[7,8,9]]]
二、使用场景
tf.strided_slice()通常会使用在序列化模型的数据处理中,比如在 seq2seq 模型中,需要把目标数据切片再送入到decoder网络中:
ending = tf.strided_slice(targets, [0, 0], [batch_size, -1], [1, 1])
dec_input = tf.concat([tf.fill([batch_size, 1], target_letter_to_int['<s>']),
ending], 1)
文章同步更新在公众号 AIPlayer,欢迎扫码关注,共同进步
往期内容: