关于tf.strided_slice函数tensorflow官方文档介绍链接https://www.tensorflow.org/api_docs/python/tf/strided_slice过于简单,本文使用简短的测试用例补充说明strided_slice的详细用法。首先给出接口定义:
tf.strided_slice(
input_,
begin,
end,
strides=None,
begin_mask=0,
end_mask=0,
ellipsis_mask=0,
new_axis_mask=0,
shrink_axis_mask=0,
var=None,
name=None
)
其中input表示一个tensor,tf.strided_slice完成对该tensor的切片操作,begin,end 以及strides参数是长度为N的向量(N等于input的维度)。begin[i]、end[i]表示对应的维度的切片开始下标和结束下标。使用python测试用例说明如下:
import tensorflow as tf
import numpy as np
data_a = [
[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]
]
print np.shape(data_a)
c = tf.strided_slice(data_a, [0,0,0],[3,2,3],[1,1,1], 0, 0, 0, 0, 0)
c1 = tf.strided_slice(data_a, [1,0,0],[3,2,3],[1,1,1], 0, 0, 0, 0, 0)
c2 = tf.strided_slice(data_a, [0,0,0],[2,2,3],[1,1,1], 0, 0, 0, 0)
c3 = tf.strided_slice(data_a, [0,0,0],[3,2,3],[2,1,1], 0, 0, 0, 0)
with tf.Session() as session:
session.run(tf.global_variables_initializer())
out = session.run(c)
out1 = session.run(c1)
out2 = session.run(c2)
out3 = session.run(c3)
print 'c shape:', np.shape(out)
print out
print 'c1 shape:', np.shape(out1)
print out1
print 'c2 shape:', np.shape(out2)
print out2
print 'c3 shape:', np.shape(out3)
print out3
结果分析:
#c作为参考结果,begin、end均跟原始shape一致,stride步长为1,结果为原始输入不做切片:
c shape: (3, 2, 3)
[
[[1 1 1] [2 2 2]]
[[3 3 3] [4 4 4]]
[[5 5 5] [6 6 6]]
]
#c1的begin[0] = 1表示第一个维度从下标1开始切分数据:
c1 shape: (2, 2, 3)
[
[[3 3 3] [4 4 4]]
[[5 5 5] [6 6 6]]
]
#c2的end[0]=2,表示第一个维度切分出索引为0,1的数据,索引为2及以后的数据均被切除:
c2 shape: (2, 2, 3)
[
[[1 1 1] [2 2 2]]
[[3 3 3] [4 4 4]]
]
#c3的stride[0]=2,表示第一个维度按步长为2进行切分,所以中间的[[3 3 3] [4 4 4]]给切掉了:
c3 shape: (1, 2, 3)
[
[[1 1 1] [2 2 2]]
[[5 5 5] [6 6 6]]
]
1.begin_mask掩码:使用二进制flag对input tensor不同维度进行标志,第i位设置为1则begin[i]参数对应的第i维度设置无效,表示该维度的起始索引从0开始。
参考:c = tf.strided_slice(data_a, [0,0,2],[3,2,3],[1,1,1], 0, 0, 0, 0, 0)
结果:shape: (3, 2, 1)
[
[[1][2]]
[[3][4]]
[[5][6]]
]
测试1:c = tf.strided_slice(data_a, [0,0,2],[3,2,3],[1,1,1], 0b100, 0, 0, 0, 0)
结果:shape: (3, 2, 3)
[
[[1 1 1] [2 2 2]]
[[3 3 3] [4 4 4]]
[[5 5 5] [6 6 6]]
]
测试2:c = tf.strided_slice(data_a, [0,0,2],[3,2,3],[1,1,1], 0b110, 0, 0, 0, 0)
结果:shape: (3, 2, 3)
[
[[1 1 1] [2 2 2]]
[[3 3 3] [4 4 4]]
[[5 5 5] [6 6 6]]
]
2.end_mask掩码:功能类似begin_mask。使用二进制flag对input tensor不同维度进行标志,第i位设置为1则end参数对应的该维度设置无效,表示该维度切分的结束索引到列表最后即切分到尽可能大的范围。
参考:c = tf.strided_slice(data_a, [0,0,0],[3,2,1],[1,1,1], 0, 0, 0, 0, 0)
结果:shape: (3, 2, 2)
[
[[1] [2]]
[[3] [4]]
[[5] [6]]
]
测试:c = tf.strided_slice(data_a, [0,0,0],[3,2,1],[1,1,1], 0, 0b100, 0, 0, 0)
结果:shape: (3, 2, 3)
[
[[1 1 1] [2 2 2]]
[[3 3 3][4 4 4]]
[[5 5 5] [6 6 6]]
]
3.ellipsis_mask 掩码: 不为零的维度不需要进行切分操作(只允许出现一个非零位)
举例:
c = tf.strided_slice(data_a, [0,0,2],[3,2,3],[1,1,1], 0, 0, 0, 0, 0)
结果:shape: (3, 2, 2)
[
[[1] [2]]
[[3] [4]]
[[5] [6]]
]
c = tf.strided_slice(data_a, [0,0,2],[3,2,3],[1,1,1], 0, 0, 0b100, 0, 0)
结果:shape: (3, 2, 3)
[
[[1 1 1] [2 2 2]]
[[3 3 3] [4 4 4]]
[[5 5 5] [6 6 6]]
]
4. new_axis_mask掩码:如果第i位出现1,则begin, end, and stride对所有维度参数无效,并在第1位上增加一个大小为1的维度。
参考:c = tf.strided_slice(data_a, [1,0,0],[3,2,3],[1,1,1], 0, 0, 0, 0b000, 0)
结果:shape: (2, 2, 3)
[
[[3 3 3] [4 4 4]]
[[5 5 5] [6 6 6]]
]
测试:c = tf.strided_slice(data_a, [1,0,0],[3,2,3],[1,1,1], 0, 0, 0, 0b001, 0)
结果:shape: (1, 2, 2, 3)
[
[
[[1 1 1] [2 2 2]]
[[3 3 3] [4 4 4]]
]
]
5. shrink_axis_mask掩码:第i位设置为1则意味着第i维度缩小为1。
参考:c1 = tf.strided_slice(data_a, [0,0,0],[3,2,3],[1,1,1], 0, 0, 0, 0, 0)
结果:shape: (3, 2, 3)
[
[[1 1 1] [2 2 2]]
[[3 3 3] [4 4 4]]
[[5 5 5] [6 6 6]]
]
测试:c1 = tf.strided_slice(data_a, [0,0,0],[3,2,3],[1,1,1], 0, 0, 0, 0, 0b100)
结果:shape: (3, 2)
[
[1 2]
[3 4]
[5 6]
]