tf.strided_slice() 用法
strided_slice(
input_,
begin,
end,
strides=None,
begin_mask=0,
end_mask=0,
ellipsis_mask=0,
new_axis_mask=0,
shrink_axis_mask=0,
var=None,
name=None
)
简单来看这个函数重点在于前四个参数:
- 第一个是输入数据
- 第二个是开始切片的地方
- 第三个是终止切片的地方
- 第四个是步长
这里的关键在于如何在多维的情况切片,官网给的例子是
tf.strided_slice( input_, begin,end,strides )
input = [[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]]
tf.strided_slice(input, [1, 0, 0], [2, 1, 3], [1, 1, 1]) ==> [[[3, 3, 3]]]
tf.strided_slice(input, [1, 0, 0], [2, 2, 3], [1, 1, 1]) ==> [[[3, 3, 3],
[4, 4, 4]]]
tf.strided_slice(input, [1, -1, 0], [2, -3, 3], [1, -1, 1]) ==>[[[4, 4, 4],
[3, 3, 3]]]
[切片的时候是左闭右开)
示例1
tf.strided_slice(t, [1, 0, 0], [2, 1, 3], [1, 1, 1])
- 一个维度一个维度地看.
- begin的第0维是1,end的第0维是2,取值区间为[1,2),
- 第0维返回索引为1的元素,即[[3, 3, 3], [4, 4, 4]]
- 第1维,取值区间为[0,1),在第0维结果的基础上,第1维返回索引为0的元素,即[3, 3, 3]
- 第2维,取值区间为[0,3),在第1维结果的基础上,第2维返回索引为0,1,2的元素,即[3, 3, 3]
- 最终结果为[3, 3, 3]
示例2
tf.strided_slice(t, [1, 0, 0], [2, 2, 3], [1, 1, 1])
- 第0维取值区间为[1,2),返回第0维索引为1的元素,即[[3, 3, 3], [4, 4, 4]]
- 第1维取值区间为[0,2),在第0维结果的基础上,返回第1维索引为0,1的元素,即[3, 3, 3], [4, 4, 4]
- 第2维取值区间为[0,3],在第1维结果的基础上,返回第2维索引为0,1,2的元素,即[3, 3, 3], [4, 4, 4],这里注意,因- 为第1维结果是两个list,0,1,2这三个索引分别作用于这两个list
- 最终结果[[3, 3, 3], [4, 4, 4]]
示例3
tf.strided_slice(t, [1, -1, 0], [2, -3, 3], [1, -1, 1])
- 第0维取值区间[1,2),第0维返回索引为1的元素,即[[3, 3, 3], [4, 4, 4]]
- 第1维取值区间[-1,-3),在第0维结果的基础上,返回第1维索引为-1,-2的元素,即[4, 4, 4],[3, 3, 3]
- 第2维取值区间为[0,3),在第1维结果的基础上,返回第2维索引为0,1,2的元素,即[4, 4, 4],[3, 3, 3]
- 最终结果[[4, 4, 4],[3, 3, 3]]