pytorch第03天 索引与切片

1 索引

看下面的程序,重点看注释,就能明白索引什么意思

import torch
a = torch.rand(2, 3, 4)
print('---------a------------')
print(a)
print(a.size())

print('--------a[0]----------')
print(a[0])				# 表示第一个维度取0,其他维度没有标明,表示全选
print(a[0].size())

print('--------a[0, 1]---------')
print(a[0, 1])					# 表示第一个维度取0,第二个维度取1
print(a[0, 1].size())

print('--------a[0, 1, 2]------')
print(a[0, 1, 2])			# 三个索引,能取到具体的元素,返回的是标量
print(a[0, 1, 2].size())

print(--------a[0,:,2]--------)
print(a[0, : , 2]).size())	
# 第一个维度为0,第3个维度为2,符合条件的元素全部提取出来组成一个新的张量

输出

---------a------------
tensor([[[0.8452, 0.1938, 0.8032, 0.5945],
         [0.3180, 0.9786, 0.5027, 0.6467],
         [0.8329, 0.2143, 0.5065, 0.9333]],

        [[0.7726, 0.8236, 0.0623, 0.3531],
         [0.6295, 0.9886, 0.8412, 0.3464],
         [0.7716, 0.9486, 0.9734, 0.4873]]])
torch.Size([2, 3, 4])
--------a[0]----------
tensor([[0.8452, 0.1938, 0.8032, 0.5945],
        [0.3180, 0.9786, 0.5027, 0.6467],
        [0.8329, 0.2143, 0.5065, 0.9333]])
torch.Size([3, 4])
--------a[0, 1]---------
tensor([0.3180, 0.9786, 0.5027, 0.6467])
torch.Size([4])
--------a[0, 1, 2]------
tensor(0.5027)
torch.Size([])
--------a[0,:,2]--------
torch.Size([3])

另外,a也可以用a[…]表示,a[1]也可以用a[1, …]表示,
a[1, :, 2]可以用a[1, …, 2]表示
即省略号表示后面的维度全选,省略号什么时候用呢?假如一个张量是5维的,我固定住第一个维度为1,第5个维度是3,其他不约束,那么可以用a[1, …, 3]来表示,这样可以代替a[1, :, :, :, 3]”。在维度比较高的时候,用冒号比较方便。
通过索引,可以从原来的张量中,截取一部分作为新的张量,切片的原理也是类似。

2 切片

切片的使用方法与numpy的切片、python中列表、字符串等容器的切片方法完全一致。
在提取a的切片的时候,对于某一个维度,可以使用start🔚step,例如

print(a[0,2, 0:4:2])  # 第3个维度,不包括4

也可以将 start🔚step 简化,比如,

:,表示取所有
3:,表示从3开始到结束,3以前的索引都不取
:2,表示从0开始取到22后面的元素都不取
3::2,表示从3开始,以2为步长取索引
:5:2,以5作为结束(不包括5),以2为步长取索引
::,表示取所有
::2,从0到最后,以2位步长取索引,能不能包括最后,要看步长

只要理解了start : end : step,就能看懂那么多的切片了,看懂就行,不需要记住

3 负索引与切片

a = torch.rand(2, 3, 4)
print(a)
print(a[1, 1, :-1])

输出

tensor([[[0.3718, 0.9764, 0.6455, 0.5590],
         [0.0086, 0.9308, 0.5603, 0.6590],
         [0.8180, 0.2191, 0.9816, 0.2556]],

        [[0.8498, 0.8754, 0.7551, 0.3621],
         [0.6529, 0.1800, 0.1788, 0.1257],
         [0.2079, 0.4583, 0.9558, 0.8517]]])
tensor([0.6529, 0.1800, 0.1788])

在a[1, 1, :-1]中,前两个索引定位到了向量[0.6529, 0.1800, 0.1788, 0.1257]
这四个元素的负索引为[-4, -3, -2, -1],那么 :-1的意思就是,从-4开始到-1,但不包括-1

4 选择特定的索引作为切片

index_select函数,第一个参数指定哪一个维度,0表示第一个维度,1表示第二个维度,因为程序是根据第一个参数从size中取值,而size是列表。第二个参数,用向量的形式表示在这个维度取哪些索引,它必须是tensor向量。

import torch
a = torch.rand(4, 6)
print(a)
index_tensor = torch.tensor([0, 2])
b = a.index_select(1, index_tensor) # 相当于 a[:, [0,2]],即在第2个维度选择0和2两个索引
print(b)
print(b.size())
print(a[:, [0,2]])

输出

tensor([[0.1224, 0.4920, 0.9343, 0.2703, 0.8450, 0.8305],
        [0.0208, 0.2733, 0.3877, 0.6427, 0.4784, 0.4948],
        [0.2023, 0.7746, 0.4477, 0.5680, 0.0073, 0.6333],
        [0.9234, 0.9261, 0.9253, 0.8811, 0.2696, 0.0863]])
tensor([[0.1224, 0.9343],
        [0.0208, 0.3877],
        [0.2023, 0.4477],
        [0.9234, 0.9253]])
torch.Size([4, 2])
tensor([[0.1224, 0.9343],
        [0.0208, 0.3877],
        [0.2023, 0.4477],
        [0.9234, 0.9253]])

5 mask编码

x = torch.rand(3,4)
print(x)
mask = x.ge(0.5)		# 把张量x中,大于0.5的元素筛选出来
print(mask)
print(mask.type())		# mask是一个布尔型矩阵,维度与x相同
r = torch.masked_select(x, mask)	# 将x中,符合mask条件的,都筛选出来组合成一个向量
print(r)
print(r.size())

输出

tensor([[0.1655, 0.2331, 0.1466, 0.2827],
        [0.8231, 0.8934, 0.5205, 0.2169],
        [0.9436, 0.9027, 0.9388, 0.2888]])
tensor([[False, False, False, False],
        [ True,  True,  True, False],
        [ True,  True,  True, False]])
torch.BoolTensor
tensor([0.8231, 0.8934, 0.5205, 0.9436, 0.9027, 0.9388])
torch.Size([6])

从上面的程序与输出结果可以看到,mask编码是把筛选之后得到的元素重新组合成一个向量,向量多长,要看原张量中有多少个元素符合条件。

6 今日重点

1 如果a是4维张量,那么a[0],a[0, 1],a[0, 1, 2],a[0, : , 2]各表示几维,它们分别代表什么意思?a[0, :, 2]和a[0, …, 2]有什么区别?
2 如果a的某个维度为2 : 5 : 2,代表什么意思?切片的通用公式是?
3 如何在a的第一个维度,选择索引号为0和1的切片作为一个新的张量
4 知道torch.masked_select的功能

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值