链接:https://www.tensorflow.org/api_docs/python/tf/strided_slice
说明:对传入的tensor执行切片操作,返回切片后的tensor。主要参数
input_, start, end, strides
,strides
代表切片步长。例子:
参考链接:https://github.com/NELSONZHAO/zhihu/tree/master/basic_seq2seq?1521452873816# 'input' is [[[1, 1, 1], [2, 2, 2]], # [[3, 3, 3], [4, 4, 4]], # [[5, 5, 5], [6, 6, 6]]] tf.strided_slice(input, [1, 0, 0], [2, 1, 3], [1, 1, 1]) ==> [[[3, 3, 3]]] # 上面一行代码中[1,0,0]分别代表原数组三个维度的切片起始位置,[2,1,3]代表结束位置。 [1,1,1]代表切片步长,表示在三个维度上切片步长都为1。我们的原始输入数据为3 x 2 x 3, 通过参数我们可以得到,第一个维度上切片start=1,end=2, 第二个维度start=0, end=1,第三个维度start=0, end=3。 我们从里面的维度来看,原始数据的第三个维度有三个元素,切片操作start=0,end=3,stride=1,代表第三个维度上的元素我们全部保留。 同理,在第二个维度上,start=0, end=1, stride=1,代表第二个维度上只保留第一个切片,这样我们就只剩下[[[1,1,1]],[[3,3,3]],[[5,5,5]]]。 接着我们看第一个维度,start=1, end=2, stride=1代表只取第二个切片,因此得到[[[3,3,3]]。以下两个例子同理。 tf.strided_slice(input, [1, 0, 0], [2, 2, 3], [1, 1, 1]) ==> [[[3, 3, 3], [4, 4, 4]]] tf.strided_slice(input, [1, -1, 0], [2, -3, 3], [1, -1, 1]) ==>[[[4, 4, 4], [3, 3, 3]]]
- 注意:如何看numpy数组的维度:以input数组为例,input维度是(3,2,3),即第一维有3个元素,第2维有2个元素,第3维有3个元素。技巧:第一维是看第一个方括号里包了几个元素,第二维是看第二个方括号里包了几个元素。。。。以此类推。
tensorflow学习笔记--tf.strided_slice
最新推荐文章于 2022-03-03 22:30:47 发布