主要用于截取张量的部分内容。该函数的原型是:
strided_slice(
input_,
begin,
end,
strides=None,
begin_mask=0,
end_mask=0,
ellipsis_mask=0,
new_axis_mask=0,
shrink_axis_mask=0,
var=None,
name=None
)
我们经常用的是前四个参数,它们的含义如下:
- input:输入的原始张量
- begin:每一阶的的起始位置,一阶张量,维度等于input的阶数
- end:每一阶的结束位置(是开区间),一阶张量,维度等于input的阶数
- strides:每一阶的步长,一阶张量,维度等于input的阶数
输出:阶数与input相同的张量
例:
import tensorflow as tf
data = [[[1, 2, 3], [2, 3, 4]],
[[3, 4,5], [4, 5, 6]],
[[6, 7, 8], [7, 8, 9]]]
x = tf.strided_slice(data,[0,0,0],[1,1,1])
y = tf.strided_slice(data,[0,0,0],[3,2,2],[1,1,1])
z = tf.strided_slice(data,[0,0,0],[2,2,2],[1,2,1])
with tf.Session() as sess:
print(sess.run(x))
print(sess.run(y))
print(sess.run(z))