计算机一级电子表格TF函数,tf.slice()函数详解(极详细)

tf.slice()是TensorFlow库中分割张量的一个函数,其定义为def slice(input_, begin, size, name=None):。tf.slice()函数的那些参数设置实在是不好理解,查了好多资料才理解,所以这边记录一下。

1.官方注释

官方的注释如下:

"""Extracts a slice from a tensor.

This operation extracts a slice of size `size` from a tensor `input` starting

at the location specified by `begin`. The slice `size` is represented as a

tensor shape, where `size[i]` is the number of elements of the 'i'th dimension

of `input` that you want to slice. The starting location (`begin`) for the

slice is represented as an offset in each dimension of `input`. In other

words, `begin[i]` is the offset into the 'i'th dimension of `input` that you

want to slice from.

Note that @{tf.Tensor.__getitem__} is typically a more pythonic way to

perform slices, as it allows you to write `foo[3:7, :-2]` instead of

`tf.slice([3, 0], [4, foo.get_shape()[1]-2])`.

`begin` is zero-based; `size` is one-based. If `size[i]` is -1,

all remaining elements in dimension i are included in the

slice. In other words, this is equivalent to setting:

`size[i] = input.dim_size(i) - begin[i]`

This operation requires that:

`0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n]

翻译一下就是:

tf.slice()函数的作用就是从张量中提取想要的切片。此操作从由begin指定位置开始的张量input中提取一个尺寸size的切片.切片size被表示为张量形状,其中size[i]是你想要分割的input的第i维的元素的数量.切片的起始位置(begin)表示为每个input维度的偏移量.换句话说,begin[i]是你想从中分割出来的input的“第i个维度”的偏移量。

请注意,tf.Tensor.__getitem__通常是执行切片的python方式,因为它允许您写foo[3:7, :-2],而不是tf.slice([3, 0], [4, foo.get_shape()[1]-2]).

begin是基于零的;size是一个基础.如果size[i]是-1,则维度i中的所有其余元素都包含在切片中.

换句话说,这相当于设置:

size[i] = input.dim_size(i) - begin[i]

该操作要求:

0 <= begin[i] <= begin[i] + size[i] <= Di for i in [0, n]

看完注释还是挺懵的,下面看看解释。

2.参数解释

def slice(input_, begin, size, name=None):

...

return gen_array_ops._slice(input_, begin, size, name=name)

input_

input_类型为一个tensor,表示的是输入的tensor,也就是被切的那个

begin

begin是一个int32或int64类型的tensor,表示的是每一个维度的起始位置

size

size是一个int32或int64类型的tensor,表示的是每个维度要拿的元素数

name=None

name是操作的名称,可写可不写

return

返回一个和输入类型一样的tensor

3.例子

还是通过例子来讲解会比较容易理解

例1

t = tf.constant([[[1, 1, 1], [2, 2, 2]],

[[3, 3, 3], [4, 4, 4]],

[[5, 5, 5], [6, 6, 6]]])

tf.slice(t, [1, 0, 0], [1, 1, 3]) # 输出[[[3, 3, 3]]]

首先作为一个3维数组t,要先明白他的shape是[3,2,3].

这个shape是怎么来的呢?咱们把这个t分解一下看就好理解了。那一大堆有括号的t,只看它最外面的括号的话,可以看成是:

t = [A, B, C] #这是第一维度

然后每一个里面有两个东西,可以写成:

A = [i, j], B = [k, l], C = [m, n] #这是第二维度

最后,这i, j, k, l, m, n里面分别是:

i = [1, 1, 1], j = [2, 2, 2], k = [3, 3 ,3], l = [4, 4, 4], m = [5, 5, 5], n = [6, 6, 6] # 这是第三维度

所以shape就是中括号 [ ] 的层级里单位的数量。

对于t来说,最外面括号里有3个东西,分别是A, B, C。这三个东西每个里面有两个玩意儿, i和j, k和l, m和n。他们里面每一个又有3个数字。所以t的shape是[3,2,3]。

有了这个基础,我们再来看例子:

t = tf.constant([[[1, 1, 1], [2, 2, 2]],

[[3, 3, 3], [4, 4, 4]],

[[5, 5, 5], [6, 6, 6]]])

tf.slice(t, [1, 0, 0], [1, 1, 3]) # begin = [1, 0, 0]

有了这个基础,我们再来看例子:

tf.slice(t, [1, 0, 0], [1, 1, 3]) # begin = [1, 0, 0]

注意一下,python的数组index是从0开始的。

这里根据顺序我们知道,begin是[1, 0, 0], size是[1, 1, 3]. 他们两个数组的意义是从左至右,每一个数字代表一个维度。上面说了begin的意思是起始位置,那么[1, 0, 0]的意思是在3个维度中,每个维度从哪里算起。

第一维度是[A, B, C]。 begin里[1, 0, 0]是1,也就是从B算起。其次第二维度里B = [k, l](注意啊,我这里只写了B = [k, l],可不代表只有B有用,如果size里第一个数字是2的话,B和C都会被取的),begin里第二个数是0,也就是从k算起。第三维度k = [3, 3 ,3],begin里第三个数是0,就是从第一个3算起。

到现在都能看懂吧?知道了这三个起始点之后,再来看size。

size的意思是每个维度的大小,也就是每个维度取几个元素。size的应该是最后输出的tensor的shape。

例子里面:

tf.slice(t, [1, 0, 0], [1, 1, 3]) # size = [1, 1, 3]

size里第一个是1,意思是在第一个维度取1个元素。t = [A, B, C] begin是起算是B,取一个那就是B了呗。那么第一维度结果就是[B]

size第二个也是1,第二维度B = [k, l], begin里起算是k,取一个是k。那么第二维度结果是[[k]]。

size第三个是3,第三维度k = [3, 3 ,3],begin里起算是第一个3。三个3取3个数,那就要把三个3都取了,所以是

[[[3, 3, 3]]]

看懂了吗?是不是有点像代数?[B]里把B换成[k], 再把k换成[3, 3 ,3]。最后注意中括号的数量,和size一样是[1, 1, 3].

例2

t = tf.constant([[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]])

tf.slice(t, [1, 0, 0], [-1, -1, -1])

对于这种情况,源代码注释中有一句话:

If `size[i]` is -1, all remaining elements in dimension i are included in the slice. In other words, this is equivalent to setting: `size[i] = input.dim_size(i) - begin[i]`

也就是说,如果size输入值是-1的话,在那个维度剩下的数都会slice走。上面的例子中,begin是[1, 0, 0]。三个维度都是-1的话,那么结果: 第一维度是[B,C];第二维度是[[k, l], [m, n]]; 第三维度是[[[3,3,3], [4,4,4]], [[5,5,5], [6,6,6]]]

例3

import tensorflow as tf

sess = tf.Session()

input = tf.constant([[[[1, 1, 1], [2, 2, 2]],

[[3, 3, 3], [4, 4, 4]],

[[5, 5, 5], [6, 6, 6]]],

[[[1, 1, 1], [2, 2, 2]],

[[3, 3, 3], [4, 4, 4]],

[[5, 5, 5], [6, 6, 6]]]])

print(input)

output=tf.slice(input,[0,0,0,1],[2,3,2,1])

print(output)

print(sess.run(output))

参考

码字不易,如果您觉得有帮助,麻烦点个赞再走呗~

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
假设我们有一个包含100个样本的数据集,每个样本有两个特征,一个是图像数据,一个是标签。我们希望使用TensorFlow的队列机制异步读取这些数据,并进行训练。 首先,我们可以使用tf.train.slice_input_producer函数将数据集切分成若干个batch,然后每个batch通过多个线程异步读取数据: ```python import tensorflow as tf # 构造数据集 data = [] for i in range(100): image = ... # 加载图像数据 label = ... # 加载标签数据 data.append((image, label)) # 定义batch大小和线程数 batch_size = 32 num_threads = 4 # 使用slice_input_producer函数将数据集切分成若干个batch image_batch, label_batch = tf.train.slice_input_producer(data, batch_size=batch_size, num_threads=num_threads) # 定义数据预处理函数 def preprocess(image, label): # 对图像数据进行预处理 image = ... # 对标签数据进行预处理 label = ... return image, label # 使用map函数将数据预处理函数应用到每个batch中的每个样本 image_batch, label_batch = tf.map_fn(preprocess, (image_batch, label_batch)) # 定义模型 ... # 定义损失函数 ... # 定义优化器 ... # 定义训练操作 train_op = ... # 启动会话 with tf.Session() as sess: # 初始化变量 sess.run(tf.global_variables_initializer()) # 启动多线程读取数据 coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) # 训练模型 for i in range(num_steps): _, loss_val = sess.run([train_op, loss]) # 关闭多线程 coord.request_stop() coord.join(threads) ``` 在上面的代码中,我们首先定义了一个包含100个样本的数据集。然后,使用tf.train.slice_input_producer函数将数据集切分成若干个batch,并通过多个线程异步读取数据。接着,我们定义了一个数据预处理函数,并使用tf.map_fn函数将其应用到每个batch中的每个样本。最后,我们定义了模型、损失函数和优化器,并使用tf.Session启动会话进行训练。在训练过程中,我们启动多线程读取数据,并在训练完成后关闭多线程。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值