Tensor高阶用法:快速掌握Tensor切分、变形等方法

要想在实际的应用中更灵活地用好 Tensor,Tensor 的连接、切分等操作也是必不可少的。我们就通过一些例子和图片来一块学习下。虽然这几个操作比较有难度,但只要耐心分析,然后上手练习,还是可以拿下的。

Tensor 的连接操作

在项目开发中,深度学习某一层神经元的数据可能有多个不同的来源,那么就需要将数据
进行组合,这个组合的操作,我们称之为连接。

cat

连接的操作函数如下:

torch.cat(tensors, dim = 0, out = None)

cat 是 concatnate 的意思,也就是拼接、联系的意思。该函数有两个重要的参数需要你掌握。

第一个参数是 tensors,它很好理解,就是若干个我们准备进行拼接的 Tensor。

第二个参数是 dim,我们回忆一下 Tensor 的定义,Tensor 的维度(秩)是有多种情况
的。比如有两个 3 维的 Tensor,可以有几种不同的拼接方式(如下图),dim 参数就可
以对此作出约定。

 

看到这里,你可能觉得上面画的图是三维的,看起来比较晦涩,所以咱们先从简单的二维
的情况说起,我们先声明两个 3x3 的矩阵,代码如下:

>>> A=torch.ones(3,3)
>>> B=2*torch.ones(3,3)
>>> A
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])
>>> B
tensor([[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])

我们先看看 dim=0 的情况,拼接的结果是怎样的:

>>> C=torch.cat((A,B),0)
>>> C
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[2., 2., 2.],
[2., 2., 2.],
[2., 2., 2.]])

你会发现,两个矩阵是按照“行”的方向拼接的。

我们接下来再看看,dim=1 的情况是怎样的:

>>> D=torch.cat((A,B),1)
>>> D
tensor([[1., 1., 1., 2., 2., 2.],
[1., 1., 1., 2., 2., 2.],
[1., 1., 1., 2., 2., 2.]])

显然,两个矩阵,是按照“列”的方向拼接的。那如果 Tensor 是三维甚至更高维度的呢?其实道理也是一样的,dim 的数值是多少,两个矩阵就会按照相应维度的方向链接两个 Tensor。

看到这里你可能会问了,cat 实际上是将多个 Tensor 在已有的维度上进行连接,那如果想增加新的维度进行连接,又该怎么做呢?这时候就需要 stack 函数登场了。

stack

为了让你加深理解,我们还是结合具体例子来看看。假设我们有两个二维矩阵 Tensor,把
它们“堆叠”放在一起,构成一个三维的 Tensor,如下图:

这相当于原来的维度(秩)是 2,现在变成了 3,变成了一个立体的结构,增加了一个维度。你需要注意的是,这跟前面的 cat 不同,cat 中示意图的例子,原来就是 3 维的,cat之后仍旧是 3 维的,而现在咱们是从 2 维变成了 3 维。

在实际图像算法开发中,咱们有时候需要将多个单通道 Tensor(2 维)合并,得到多通道
的结果(3 维)。而实现这种增加维度拼接的方法,我们把它叫做 stack。

stack 函数的定义如下:

torch.stack(inputs, dim=0)

其中,inputs 表示需要拼接的 Tensor,dim 表示新建立维度的方向。

那 stack 如何使用呢?我们一块来看一个例子:

>>> A=torch.arange(0,4)
>>> A
tensor([0, 1, 2, 3])
>>> B=torch.arange(5,9)
>>> B
tensor([5, 6, 7, 8])
>>> C=torch.stack((A,B),0)
>>> C
tensor([[0, 1, 2, 3],
[5, 6, 7, 8]])
>>> D=torch.stack((A,B),1)
>>> D
tensor([[0, 5],
[1, 6],
[2, 7],
[3, 8]])

结合代码,我们可以看到,首先我们构建了两个 4 元素向量 A 和 B,它们的维度是 1。然后,我们在 dim=0,也就是“行”的方向上新建一个维度,这样维度就成了 2,也就得到了 C。而对于 D,我们则是在 dim=1,也就是“列”的方向上新建维度。

Tensor 的切分操作

学完了连接操作之后,我们再来看看连接的逆操作:切分。
切分就是连接的逆过程,有了刚才的经验,你很容易就会想到,切分的操作也应该有很多
种,比如切片、切块等。没错,切分的操作主要分为三种类型:chunk、split、unbind。
乍一看有不少,其实是因为它们各有特点,适用于不同的使用情景,让我们一起看一下。

chunk

chunk 的作用就是将 Tensor 按照声明的 dim,进行尽可能平均的划分。

比如说,我们有一个 32channel 的特征,需要将其按照 channel 均匀分成 4 组,每组 8
个 channel,这个切分就可以通过 chunk 函数来实现。具体函数如下:

torch.chunk(input, chunks, dim=0)

我们挨个来看看函数中涉及到的三个参数:
首先是 input,它表示要做 chunk 操作的 Tensor。

接着,我们看下 chunks,它代表将要被划分的块的数量,而不是每组的数量。请注意,
chunks 必须是整型。

最后是 dim,想想这个参数是什么意思呢?对,就是按照哪个维度来进行 chunk。

还是跟之前一样,我们通过几个代码例子直观感受一下。我们从一个简单的一维向量开
始:

>>> A=torch.tensor([1,2,3,4,5,6,7,8,9,10])
>>> B = torch.chunk(A, 2, 0)
>>> B
(tensor([1, 2, 3, 4, 5]), tensor([ 6,7,8,9, 10]))

这里我们通过 chunk 函数,将原来 10 位长度的 Tensor A,切分成了两个一样 5 位长度的向量。(注意,B 是两个切分结果组成的 tuple)。

那如果 chunk 参数不能够整除的话,结果会是怎样的呢?我们接着往下看:

>>> B = torch.chunk(A, 3, 0)

>>> B
3 (tensor([1, 2, 3, 4]), tensor([5, 6, 7, 8]), tensor([ 9, 10]))

我们发现,10 位长度的 Tensor A,切分成了三个向量,长度分别是 4,4,2 位。这是怎
么分的呢,不应该是 3,3,4 这样更为平均的方式么?

想要解决问题,就得找到规律。让我们再来看一个更大一点的例子,将 A 改为 17 位长
度。

>>> A=torch.tensor([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17])
2 >>> B = torch.chunk(A, 4, 0)
3 >>> B
4 (tensor([1, 2, 3, 4, 5]), tensor([ 6,
7,
8,
9, 10]), tensor([11, 12, 13, 14,15]),tensor([16,17])

17 位长度的 Tensor A,切分成了四个分别为 5,5,5,2 位长度的向量。这时候你就会
发现,其实在计算每个结果元素个数的时候,chunk 函数是先做除法,然后再向上取整得
到每组的数量。

比如上面这个例子,17/4=4.25,向上取整就是 5,那就先逐个生成若干个长度为 5 的向
量,最后不够的就放在一块,作为最后一个向量(长度 2)。

那如果 chunk 参数大于 Tensor 可以切分的长度,又要怎么办呢?我们实际操作一下,代
码如下:

1 >>> A=torch.tensor([1,2,3])
2 >>> B = torch.chunk(A, 5, 0)
3 >>> B
4 (tensor([1]), tensor([2]), tensor([3]))

显然,被切分的 Tensor 只能分成若干个长度为 1 的向量。
由此可以推论出二维的情况,我们再举一个例子, 看看二维矩阵 Tensor 的情况 :

1 >>> A=torch.ones(4,4)
2 >>> A
3 tensor([[1., 1., 1., 1.],
4 [1., 1., 1., 1.],
5 [1., 1., 1., 1.],
6 [1., 1., 1., 1.]])
7 >>> B = torch.chunk(A, 2, 0)
8 >>> B
9 (tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.]]),
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.]]))

还是跟前面的 cat 一样,这里的 dim 参数,表示的是第 dim 维度方向上进行切分。


刚才介绍的 chunk 函数,是按照“切分成确定的份数”来进行切分的,那如果想按照“每份按照确定的大小”来进行切分,该怎样做呢?PyTorch 也提供了相应的方法,叫做split。

split

split 的函数定义如下,跟前面一样,我们还是分别看看这里涉及的参数。

torch.split(tensor, split_size_or_sections, dim=0)

首先是 tensor,也就是待切分的 Tensor。

然后是 split_size_or_sections 这个参数。当它为整数时,表示将 tensor 按照每块大小为
这个整数的数值来切割;当这个参数为列表时,则表示将此 tensor 切成和列表中元素一样
大小的块。

最后同样是 dim,它定义了要按哪个维度切分。

同样的,我们举几个例子来看一下 split 的具体操作。首先是 split_size_or_sections 是整
数的情况。

>>> A=torch.rand(4,4)
>>> A
tensor([[0.6418, 0.4171, 0.7372, 0.0733],
[0.0935, 0.2372, 0.6912, 0.8677],
[0.5263, 0.4145, 0.9292, 0.5671],
[0.2284, 0.6938, 0.0956, 0.3823]])
>>> B=torch.split(A, 2, 0)
>>> B
(tensor([[0.6418, 0.4171, 0.7372, 0.0733],
[0.0935, 0.2372, 0.6912, 0.8677]]),
tensor([[0.5263, 0.4145, 0.9292, 0.5671],
[0.2284, 0.6938, 0.0956, 0.3823]]))

在这个例子里,我们看到,原来 4x4 大小的 Tensor A,沿着第 0 维度,也就是
沿“行”的方向,按照每组 2“行”的大小进行切分,得到了两个 2x4 大小的 Tensor。

那么问题来了,如果 split_size_or_sections 不能整除对应方向的大小的话,会有怎样的结
果呢?我们将代码稍作修改就好了:

>>> C=torch.split(A, 3, 0)
>>> C
(tensor([[0.6418, 0.4171, 0.7372, 0.0733],
[0.0935, 0.2372, 0.6912, 0.8677],
[0.5263, 0.4145, 0.9292, 0.5671]]),
tensor([[0.2284, 0.6938, 0.0956, 0.3823]]))

根据刚才的代码我们就能发现,原来,PyTorch 会尽可能凑够每一个结果,使得其对应
dim 的数据大小等于 split_size_or_sections。如果最后剩下的不够,那就把剩下的内容放
到一块,作为最后一个结果。

接下来,我们再看一下 split_size_or_sections 是列表时的情况。刚才提到了,当
split_size_or_sections 为列表的时候,表示将此 tensor 切成和列表中元素大小一样的大
小的块,我们来看一段对应的代码:

>>> import torch
>>> A=torch.arange(0,16).view(4,4)
>>> A
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
>>> b=torch.unbind(A, 0)
>>> b
(tensor([0, 1, 2, 3]), tensor([4, 5, 6, 7]), tensor([ 8,  9, 10, 11]), tensor([12, 13, 14, 15]))

在这个例子中,我们首先创建了一个 4x4 的二维矩阵 Tensor,随后我们从第 0 维,也就
是“行”的方向进行切分 ,因为矩阵有 4 行,所以就会得到 4 个结果。

接下来,我们看一下:如果从第 1 维,也就是“列”的方向进行切分,会是怎样的结果
呢:

>>> b=torch.unbind(A, 1)
>>> b
(tensor([ 0,  4,  8, 12]), tensor([ 1,  5,  9, 13]), tensor([ 2,  6, 10, 14]), tensor([ 3,  7, 11, 15]))

不难发现,这里是按照“列”的方向进行拆解的。所以,unbind 是一种降维切分的方
式,相当于删除一个维度之后的结果。

Tensor 的索引操作

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

repinkply

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值