import torch
import numpy as np
#索引
a = torch.rand(4,3,28,28)
print(a[0].shape,a[0,0].shape,a[0,0,2,4])#torch.Size([3, 28, 28]) torch.Size([28, 28]) tensor(0.4030)
print(a[:2].shape)#torch.Size([2, 3, 28, 28])
print(a[:2,:1,:,:],a[:2,-1,:,:])#-1代表反向索引,一般我们会使用正向索引,-1代表反向的第一个元素;
print(a[:,:,0:28:2,0:28:2])#隔行采样
#特殊位置获取
print(a.index_select(0,torch.tensor([0,2]))) #采集第0个维度上的第0和第2个数据
print(a[2,...])#...表示后面三个维度默认
print(a[2,...,:4])#...是自动推测的并不是每个.代表一个维度
#mask掩码索引,问题是它会把数据打平
x = torch.randn(3,4)
mask = x.ge(0.5)#大于等于0.5
print(torch.masked_select(x,mask))#取出所有大于等于0.5的值或索引
#
#torch.take(x,mask)