def slice(input_, begin, size, name=None):
# pylint: disable=redefined-builtin
"""Extracts a slice from a tensor.
This operation extracts a slice of size `size` from a tensor `input` starting
at the location specified by `begin`. The slice `size` is represented as a
tensor shape, where `size[i]` is the number of elements of the 'i'th dimension
of `input` that you want to slice. The starting location (`begin`) for the
slice is represented as an offset in each dimension of `input`. In other
words, `begin[i]` is the offset into the 'i'th dimension of `input` that you
want to slice from.
Note that `tf.Tensor.__getitem__` is typically a more pythonic way to
perform slices, as it allows you to write `foo[3:7, :-2]` instead of
`tf.slice(foo, [3, 0], [4, foo.get_shape()[1]-2])`.
`begin` is zero-based; `size` is one-based. If `size[i]` is -1,
all remaining elements in dimension i are included in the
slice. In other words, this is equivalent to setting:
`size[i] = input.dim_size(i) - begin[i]`
This operation requires that:
`0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n]`
For example:
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]])
tf.slice(t, [1, 0, 0], [1, 1, 3]) # [[[3, 3, 3]]]
tf.slice(t, [1, 0, 0], [1, 2, 3]) # [[[3, 3, 3],
# [4, 4, 4]]]
tf.slice(t, [1, 0, 0], [2, 1, 3]) # [[[3, 3, 3]],
# [[5, 5, 5]]]
Args:
input_: A `Tensor`.
begin: An `int32` or `int64` `Tensor`.
size: An `int32` or `int64` `Tensor`.
name: A name for the operation (optional).
Returns:
A `Tensor` the same type as `input`.
"""
这个函数式tensorflow中的一个张量切片函数,三个参数:input, begen, size.
简单来说:input是输入张量,begin切片每个维度开始的索引,size是每个维度切片的大小。
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]])
tf.slice(t, [1, 0, 0], [1, 1, 3]) # [[[3, 3, 3]]]
tf.slice(t, [1, 0, 0], [1, 2, 3]) # [[[3, 3, 3],
# [4, 4, 4]]]
tf.slice(t, [1, 0, 0], [2, 1, 3]) # [[[3, 3, 3]],
# [[5, 5, 5]]]
begin: 1 0 0 表示从[[3, 3, 3], [4, 4, 4]] 开始,size : [1, 1, 3] 表示切一个,得到结果:[[3, 3, 3], [4, 4, 4]]
然后 从0 开始切,切1 个, 得到: [3,3,3]
然后从0 开始切,切3 个,得到 : [3,3,3]
最终结果就是:[[[3,3,3]]]
如果稍微改变下: begin :1,1,1, size:[2,1,2] 结果如何?
从1 开始切,切2个,得到:
[[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]]
从1 开始切,切1个:得到:
[[ [4, 4, 4]],
[ [6, 6, 6]]]
从1开始切,切2个,得到:
[[[4,4]],[[6,6]]] 形状就是: 2,1,3
代码验证:
t = tf.constant([[[1, 1, 1], [2, 2, 2]],
[[3, 3, 3], [4, 4, 4]],
[[5, 5, 5], [6, 6, 6]]])
x = tf.slice(t, [1, 1, 1], [2, 1, 2])
sess = tf.Session()
sess.run(x)
Out[21]:
array([[[4, 4]],
[[6, 6]]])
可以看到个之前的解读是一样的,然后如果将size 的某个参数改成-1的话,则默认全部输出该维度的值。
pytorch 中好像没有这么方便的函数,不过可以手动进行切:
t[1:3,1:2,1:3]
Out[22]: <tf.Tensor 'strided_slice:0' shape=(2, 1, 2) dtype=int32>
sess.run(t[1:3,1:2,1:3])
Out[23]:
array([[[4, 4]],
[[6, 6]]])
结果是一样的。