定义
tf.strided_slice(
input_,
begin,
end,
strides=None,
begin_mask=0,
end_mask=0,
ellipsis_mask=0,
new_axis_mask=0,
shrink_axis_mask=0,
var=None,
name=None,
)
理解
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]])
#第1维[0,1)的所有数组
>tf.strided_slice(t,[0],[1],[1])
array([[[1, 1, 1],[2, 2, 2]]])
#第1维[0,2)的所有数组
>tf.strided_slice(t,[0],[2],[1])
array([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3],[4, 4, 4]]])
#第1维[0,2),第2维[0,1)的所有数组
>tf.strided_slice(t,[0,0],[2,1],[1,1])
array([[[1, 1, 1]],
[[3, 3, 3]]])
#第1维[0,2),第2维[0,2)的所有数组
>tf.strided_slice(t,[0,0],[2,2],[1,1])
array([[[1, 1, 1],
[2, 2, 2]],
[[3, 3, 3],
[4, 4, 4]]])
#第1维[0,2),第2维[0,2),第3维[0,1)的所有数组
>tf.strided_slice(t,[0,0,0],[2,2,1],[1,1,1])
array([[[1],
[2]],
[[3],
[4]]])
#第1维[0,2),第2维[0,2),第3维[0,2)的所有数组
>tf.strided_slice(t,[0,0,0],[2,2,2],[1,1,1])
array([[[1, 1],
[2, 2]],
[[3, 3],
[4, 4]]])