【Pytorch笔记】使用Tensor作为索引

只关心具体技术的读者,可以直接跳到section “使用Tensor作为索引”。
这篇文章只讲解索引二维矩阵的做法,更高维度请自行推广。

题外话1:为了写出高效运行的代码,应当尽力避免在pytorch框架下使用for循环。多利用pytorch的内置函数是好的,这能利用GPU的并行处理提高计算效率。这也就是记下这篇文章的主要目的。这篇文章对应的实际问题是,我在做一个子空间对抗训练的算法,需要对每个batch的data进行按类别的奇异值分解,并将得到的奇异值进行线性组合作为对抗扰动以便进行对抗训练。当然,这篇文章主要关注的是使用tensor作为索引的技术细节,而不是算法思想。还记得第一次进实验室的写深度学习代码的时候,由于还没有摆脱C++的思维,在同样是处理逐类别问题时,使用了巨量的for循环、还有各种诡异丑陋的dict、list的操作,被学长深深地吐槽。。从那次以后,我在写pytorch的时候都会试图找到最佳的实现方案,多用人家已经实现好的函数,少造轮子。

题外话2:我认为对于pytorch的学习,问题导向是效率最高的。即当工程实现面临问题时,针对这个具体问题来进行进行探索,这样既有趣味,效率又高。当然,我并非反对初学者从零开始系统学习pytorch,如果有充足的时间,系统学习还是有必要的。只是我个人觉得,跟着教程学习印象不深,也不见得能学到工程实践中哪个地方用什么方法实现最简单、最高效这种精髓。

所要解决的实际问题

为了讲解tensor索引技术,这里只给出问题的简化描述:现在我们有一个batch的label y y y,维度是[1, bsz]。每个元素是对应的类别号(Cifar10数据集,0~9)。现在我们希望构造一个0-1 mask,mask的维度是[bsz, num_classes*dim]。其中,bsz是batch size,num_classes是类别数量(也就是10),dim是子空间维度。这些变量的具体含义在这里不重要。这个mask形如这样:比如,bsz=3,dim=2,y=[2,0,3]时

[[0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
 [1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],
 [0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0]]

也就是希望mask的第i行,从y[i]算起的dim个元素为1,其他都是0.
继续对问题抽象,也就是我们拥有一个希望对某一二维矩阵的每一行的特定列进行索引。

使用Tensor作为索引

使用一个一维Tensor作为索引

得到的是取出对应行,输出为第1、0、2行。

import torch
a=torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
print(a[[1,0,2]])

#输出:tensor([[4., 5., 6.],
#             [1., 2., 3.],
#             [7., 8., 9.]])

使用多个一维Tensor作为索引

a[b, c]得到的是把b的每个元素作为dim 0的索引,把c的每个元素作为dim 1的索引,对应的a的元素值。这里就是依次索引到a[1,0]、a[0,0]、a[2,0]。
即a[b,c][i]=a[b[i], c[i]]

import torch
a=torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
b=torch.Tensor([1,0,2]).long() #tensor作为索引,数据类型必须是long
c=torch.Tensor([0,0,0]).long()
print(a[b,c])

#输出:tensor([4., 1., 7.])
使用多个二维Tensor作为索引

这种情况就是把上一种情况推广,每一行可以索引多个值,也就是c的第1维所指定的那些值。c的第1维维度是几,就访问几个列。比如,在这个例子中,b和c的第1维有两个维度,就是访问两个列。b和c第0维有三个向量,对应a的三行(第0维的三个向量)。索引了a[1,0] a[1,1], a[0,0], a[0,2], a[2,1], a[2,2]。

import torch
a=torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
b=torch.Tensor([[1,1],[0,0],[2,2]]).long() 
c=torch.Tensor([[0,1],[0,2],[1,2]]).long()
a[b,c]=99

#输出:tensor([[99.,  2., 99.],
#             [99., 99.,  6.],
#             [ 7., 99., 99.]])

其他的正确写法:

a=torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
b=torch.Tensor([[1],[0],[2]]).long() #可以广播成[[1,1],[0,0],[2,2]],就跟上个例子一样了
c=torch.Tensor([[0,1],[0,2],[1,2]]).long()
a[b,c]=99
print(a)

#输出:tensor([[99.,  2., 99.],
#             [99., 99.,  6.],
#    

错误写法:

a=torch.Tensor([[1,2,3],[4,5,6],[7,8,9]])
b=torch.Tensor([1,0,2]).long() 
c=torch.Tensor([[0,1],[0,2],[1,2]]).long()
a[b,c]=99
print(a)
'''
报错:
Traceback (most recent call last):
  File "/home/xxx.py", line 6, in <module>
    a[b,c]=99
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [3], [3, 2]
'''

原因是[1,0,2]不能广播到[[0,1],[0,2],[1,2]]。

使用多个二维Tensor作为索引,就可以解决最开始提出的问题了。

import torch

mask=torch.zeros(3,20)
y=torch.Tensor([2,0,3])
row_index=torch.tensor([[0],[1],[2]]).long()

column_index=(torch.arange(start=0,end=2,step=1).repeat(3,1).transpose(0,1)+y).transpose(0,1).long()
'''tensor([[2., 3.],
           [0., 1.],
           [3., 4.]])'''

mask[row_index, column_index]=1
'''tensor([[0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
           [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
           [0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])'''
  • 4
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值