PyTorch 索引与切片

本文详细介绍了PyTorch中如何进行索引和切片操作,包括从前往后或从后往前全取的方法,如使用冒号表示取全部,以及指定范围的切片。还探讨了特殊的选择区间,如使用mask选择和通过flatten index进行选取。同时警告了使用mask选择可能会导致数据被默认打平的情况。
摘要由CSDN通过智能技术生成

indexing

#从第0维往后排
a = torch.rand(4,3,28,28)
print(a[0].shape)
print(a[0,0].shape)
print(a[0,0,0].shape)

print(a[0,0,0,0])

从前或者后面全取 

#从第0维往后排
a = torch.rand(4,3,28,28)
#取最前面的
print(a[:2].shape)
print(a[:2,:1,:,:].shape)
#取最后面的
print(a[2:,1:,:,:].shape)
print(a[2:,-1:,:,:].shape)#-1表示倒数第一
print(a[:,:,::2,::2].shape)

1、:单独出现表示取全部

2、:n表示,从0到n

3、n:表示从n到最后

4、n:m,表示从n到m,不包括m

4、n:m:k,表示从n到m,不包括m,隔行采样,间隔k取一个

特殊的选择某区间

#从第0维往后排,第二个参数必须是tensor
a = torch.rand(4,3,28,28)
print(a.index_select(0,torch.tensor([0,2])).shape)
print(a.index_select(1,torch.tensor([0,2])).shape)

使用...

#从第0维往后排  ...表示剩余的任意长
a = torch.rand(4,3,28,28)
print(a[...].shape)
print(a[0,...].shape)
print(a[...,:2].shape)

select by mask,不建议使用,会把数据默认打平

x = torch.randn(3,4)
print(x)

mask = x.ge(0.5)#大于0.5处为true
print(mask)

print(torch.masked_select(x,mask))
print(torch.masked_select(x,mask).shape)

select by flatten index

也会进行打平,比如查找a[2][3]中最后一个用下标5

x = torch.tensor([ [4,3,5],[6,7,8] ])
print(torch.take(x,torch.tensor([0,2,5])))

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值