tf.strided_slice是多维切片函数,网上给出了很多的说明,可是还是不容易理解,这里给出自己的理解。
直接在官网的示例上给出分析
strided_slice(
input_, begin, end, strides=None, begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0, var=None, name=None
)
前四个参数分别为:输入数据,开始切片处,终止切片处,步长。区间为开区间
# 'input' is [[[1, 1, 1], [2, 2, 2]],
# [[3, 3, 3], [4, 4, 4]],
# [[5, 5, 5], [6, 6, 6]]]
tf.strided_slice(input, [1, 0, 0], [2, 1, 3], [1, 1, 1]) ==> [[[3, 3, 3]]]
tf.strided_slice(input, [1, 0, 0], [2, 2, 3], [1, 1, 1]) ==> [[[3, 3, 3], [4, 4, 4]]]
tf.strided_slice(input, [1, -1, 0], [2, -3, 3], [1, -1, 1]) ==>[[[4, 4, 4], [3, 3, 3]]]
第一个例子的第一维为(1,2)所以切出了[[3,3,3],[4,4,4]],第二维为(0,1),所以切出了[3,3,3],第三维为(0,3),所以最终切出了[[[3,3,3]]].
同理可得第二个例子,依次可以得到[[3,3,3],[4,4,4]],[[3,3,3],[4,4,4]],[[[3, 3, 3], [4, 4, 4]]]。
第三个例子大家自己试验一下吧。