indexing
#从第0维往后排
a = torch.rand(4,3,28,28)
print(a[0].shape)
print(a[0,0].shape)
print(a[0,0,0].shape)
print(a[0,0,0,0])
从前或者后面全取
#从第0维往后排
a = torch.rand(4,3,28,28)
#取最前面的
print(a[:2].shape)
print(a[:2,:1,:,:].shape)
#取最后面的
print(a[2:,1:,:,:].shape)
print(a[2:,-1:,:,:].shape)#-1表示倒数第一
print(a[:,:,::2,::2].shape)
1、:单独出现表示取全部
2、:n表示,从0到n
3、n:表示从n到最后
4、n:m,表示从n到m,不包括m
4、n:m:k,表示从n到m,不包括m,隔行采样,间隔k取一个
特殊的选择某区间
#从第0维往后排,第二个参数必须是tensor
a = torch.rand(4,3,28,28)
print(a.index_select(0,torch.tensor([0,2])).shape)
print(a.index_select(1,torch.tensor([0,2])).shape)
使用...
#从第0维往后排 ...表示剩余的任意长
a = torch.rand(4,3,28,28)
print(a[...].shape)
print(a[0,...].shape)
print(a[...,:2].shape)
select by mask,不建议使用,会把数据默认打平
x = torch.randn(3,4)
print(x)
mask = x.ge(0.5)#大于0.5处为true
print(mask)
print(torch.masked_select(x,mask))
print(torch.masked_select(x,mask).shape)
select by flatten index
也会进行打平,比如查找a[2][3]中最后一个用下标5
x = torch.tensor([ [4,3,5],[6,7,8] ])
print(torch.take(x,torch.tensor([0,2,5])))