tf.slice的理解

1 源代码注释的解释

  • This operation extracts a slice of size size from a tensor input starting at the location specified by begin. The slice size is represented as tensor shape, where size[i] is the number of elements of the 'i’th dimension of input that you want to slice. The starting location (begin) for the slice is represented as an offset in each dimension of input. In other words, begin[i] is the offset into the 'i’th dimension of input that you want to slice from.

2 tf.slice函数的参数

slice(input, begin, size, name)
input:输入的tensor变量。
begin:每个维度的起始位置。
size:从每个维度切取的大小。

3 实例解释

测试代码如下:

import tensorflow as tf

data = tf.constant([[[1, 2, 3], [11, 12, 13]], [[21, 22, 23], [31, 32, 33]], [[41, 42, 43], [51, 52, 53]], [[61, 62, 63], [71, 72, 73]]])
slice_ = tf.slice(data, [2, 1, 0], [1, 1, 3])
with tf.Session() as sess:
    data, slice_ = sess.run([data, slice_])
    print('原始大小:\n', data.shape)
    print('原始数据:\n',data)
    print('切取后大小:\n', slice_.shape)
    print('切取后数据:\n',slice_)

结果输出如下:

原始大小:
 (4, 2, 3)
原始数据:
 [[[ 1  2  3]
  [11 12 13]]

 [[21 22 23]
  [31 32 33]]

 [[41 42 43]
  [51 52 53]]

 [[61 62 63]
  [71 72 73]]]
切取后大小:
 (1, 1, 3)
切取后数据:
 [[[51 52 53]]]

可以看到,原始的列表 data 是三维的,即[4, 2, 3],最后切取到的大小也是三维的[1 ,1, 3],切取到的是 [[[51 52 53]]]。

每个维度切取的起始位置分别是[2, 1, 0],并分别在每个维度切取[1,1,3]的大小。(注意python列表的下标是从0开始的)

我们先把最外层的中括号去掉,并以逗号分隔,得到原 data 列表第一维。

第一维是三个元素,分别是:

  • [[1, 2, 3], [11, 12, 13]]
  • [[21, 22, 23], [31, 32, 33]]
  • [[41, 42, 43], [51, 52, 53]]
  • [[61, 62, 63], [71, 72, 73]]

如上介绍,第一维切取的起始位置是3,切取1的大小。所以只得到第一维的第三个元素,即[[41, 42, 43], [51, 52, 53]] 。

接下来我们继续去掉最外层的中括号,以逗号分隔得到data列表第一维切取后的第二维如下:

  • [41, 42, 43]
  • [51, 52, 53]

如上介绍,第二维切取的起始位置是2,切取1的大小。所以得到第二维的第2个元素,即 [51, 52, 53] 。

第三维从第1个位置开始切,切取3个元素,得到 51, 52, 53。

4 进阶

如何在不知道列表各个维度大小的情况下进行切取?

举例如下:

import tensorflow as tf

data = tf.constant([[[1, 2, 3], [11, 12, 13]], [[21, 22, 23], [31, 32, 33]], [[41, 42, 43], [51, 52, 53]], [[61, 62, 63], [71, 72, 73]]])
slice_ = tf.slice(data, [0, 1, 0], [-1, -1, 2])
with tf.Session() as sess:
    data, slice_ = sess.run([data, slice_])
    print('原始大小:\n', data.shape)
    print('原始数据:\n',data)
    print('切取后大小:\n', slice_.shape)
    print('切取后数据:\n',slice_)

输出为:

原始大小:
 (4, 2, 3)
原始数据:
 [[[ 1  2  3]
  [11 12 13]]

 [[21 22 23]
  [31 32 33]]

 [[41 42 43]
  [51 52 53]]

 [[61 62 63]
  [71 72 73]]]
切取后大小:
 (4, 1, 2)
切取后数据:
 [[[11 12]]

 [[31 32]]

 [[51 52]]

 [[71 72]]]

-1可指代最后的一位。

具体请读者自行揣摩。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值