最近在看CIFAR10的代码,其中,在cifar10_input.py里面,出现了
- code1:
# The first bytes represent the label, which we convert from uint8->int32
result.label = tf.cast(
tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32
)
于是,对tf.strided_slice()和tf.cast()有了认识。于是,整理如下:
- tf.cast()的认识
def cast(x, dtype, name=None):
官方解释:Casts a tensor to a new type.
其中,
x:是一个Tensor或者SparseTensor
dtype:目标类型
name:这个op的名字(可选)
举例说明:
# tensor `a` is [6.3, 7.4], dtype=tf.float
tf.cast(a, tf.int32) ==> [6, 7] # dtype=tf.int32
因此,在code1代码中,result.label的type由tf.cast()从uint8变成了int32.
- tf.strided_slice()的认识
def strided_slice(input_,
begin,
end,
strides=None,
begin_mask=0,
end_mask=0,
ellipsis_mask=0,
new_axis_mask=0,
shrink_axis_mask=0,
var=None,
name=None):
官方解释:
To a first order, this operation extracts a slice of size `end - begin`
from a tensor `input`
starting at the location specified by `begin`. The slice continues by adding
`stride` to the `begin` index until all dimensions are not less than `end`.
Note that components of stride can be negative, which causes a reverse
slice.
简而言之,就是:
从输入tensor ‘input’中提取一个从‘begin’位置开始,长度为’end - begin’的片段。片段增加步长为’stride’,直到所有的维度不小于‘end’.但是,如果stride的中有负数,那么,会产生一个顺序相反的slice.
举例说明:
# 'input' is [[[1, 1, 1], [2, 2, 2]],
# [[3, 3, 3], [4, 4, 4]],
# [[5, 5, 5], [6, 6, 6]]]
tf.strided_slice(input, [1, 0, 0], [2, 1, 3], [1, 1, 1]) ==> [[[3, 3, 3]]]
tf.strided_slice(input, [1, 0, 0], [2, 2, 3], [1, 1, 1]) ==> [[[3, 3, 3],
[4, 4, 4]]]
tf.strided_slice(input, [1, -1, 0], [2, -3, 3], [1, -1, 1]) ==>[[[4, 4, 4],
[3, 3, 3]]]
在上面例子第三个tf.strided_slice()中,
begin = [1, -1, 0]
end = [2, -3, 3]
strides = [1, -1, 1]
其中,begin中的 -1 表示要从第二维最后一个元素开始,strides中的 -1 表示第二维中每次增长步长为-1,于是,取出的元素下标是-1, -2, -3,… ,且因为 strides中的第二维步长为负数,所以,第二维元素取出后是反方向,而end中的 -3 表示截至于第二维中倒数第二个元素(包括倒数第二个元素,下标为-2),所以,最终,输出结果为[[[4, 4, 4], [3, 3, 3]]]