1 源代码注释的解释
- This operation extracts a slice of size
size
from a tensorinput
starting at the location specified bybegin
. The slicesize
is represented as tensor shape, wheresize[i]
is the number of elements of the 'i’th dimension ofinput
that you want to slice. The starting location (begin
) for the slice is represented as an offset in each dimension ofinput
. In other words,begin[i]
is the offset into the 'i’th dimension ofinput
that you want to slice from.
2 tf.slice函数的参数
slice(input, begin, size, name)
input:输入的tensor变量。
begin:每个维度的起始位置。
size:从每个维度切取的大小。
3 实例解释
测试代码如下:
import tensorflow as tf
data = tf.constant([[[1, 2, 3], [11, 12, 13]], [[21, 22, 23], [31, 32, 33]], [[41, 42, 43], [51, 52, 53]], [[61, 62, 63], [71, 72, 73]]])
slice_ = tf.slice(data, [2, 1, 0], [1, 1, 3])
with tf.Session() as sess:
data, slice_ = sess.run([data, slice_])
print('原始大小:\n', data.shape)
print('原始数据:\n',data)
print('切取后大小:\n', slice_.shape)
print('切取后数据:\n',slice_)
结果输出如下:
原始大小:
(4, 2, 3)
原始数据:
[[[ 1 2 3]
[11 12 13]]
[[21 22 23]
[31 32 33]]
[[41 42 43]
[51 52 53]]
[[61 62 63]
[71 72 73]]]
切取后大小:
(1, 1, 3)
切取后数据:
[[[51 52 53]]]
可以看到,原始的列表 data 是三维的,即[4, 2, 3],最后切取到的大小也是三维的[1 ,1, 3],切取到的是 [[[51 52 53]]]。
每个维度切取的起始位置分别是[2, 1, 0],并分别在每个维度切取[1,1,3]的大小。(注意python列表的下标是从0开始的)
我们先把最外层的中括号去掉,并以逗号分隔,得到原 data 列表第一维。
第一维是三个元素,分别是:
- [[1, 2, 3], [11, 12, 13]]
- [[21, 22, 23], [31, 32, 33]]
- [[41, 42, 43], [51, 52, 53]]
- [[61, 62, 63], [71, 72, 73]]
如上介绍,第一维切取的起始位置是3,切取1的大小。所以只得到第一维的第三个元素,即[[41, 42, 43], [51, 52, 53]] 。
接下来我们继续去掉最外层的中括号,以逗号分隔得到data列表第一维切取后的第二维如下:
- [41, 42, 43]
- [51, 52, 53]
如上介绍,第二维切取的起始位置是2,切取1的大小。所以得到第二维的第2个元素,即 [51, 52, 53] 。
第三维从第1个位置开始切,切取3个元素,得到 51, 52, 53。
4 进阶
如何在不知道列表各个维度大小的情况下进行切取?
举例如下:
import tensorflow as tf
data = tf.constant([[[1, 2, 3], [11, 12, 13]], [[21, 22, 23], [31, 32, 33]], [[41, 42, 43], [51, 52, 53]], [[61, 62, 63], [71, 72, 73]]])
slice_ = tf.slice(data, [0, 1, 0], [-1, -1, 2])
with tf.Session() as sess:
data, slice_ = sess.run([data, slice_])
print('原始大小:\n', data.shape)
print('原始数据:\n',data)
print('切取后大小:\n', slice_.shape)
print('切取后数据:\n',slice_)
输出为:
原始大小:
(4, 2, 3)
原始数据:
[[[ 1 2 3]
[11 12 13]]
[[21 22 23]
[31 32 33]]
[[41 42 43]
[51 52 53]]
[[61 62 63]
[71 72 73]]]
切取后大小:
(4, 1, 2)
切取后数据:
[[[11 12]]
[[31 32]]
[[51 52]]
[[71 72]]]
-1可指代最后的一位。
具体请读者自行揣摩。