1. 简单行、列索引
torch.manual_seed(0)
data1 = torch.randint(0,10,[4,5])
'''
tensor([[4, 9, 3, 0, 3],
[9, 7, 3, 7, 3],
[1, 6, 6, 9, 8],
[6, 6, 8, 4, 3]])
'''
# 获得某行的元素
data1_1 = data1[0] # 第一行 tensor([4, 9, 3, 0, 3])
# 获得某列的元素
data1_1_ = data1[:,0] # 第一列 tensor([4, 9, 1, 6])
# 获得某行某列元素
data1_11 = data1[0,0] # 第一行第一列的元素 tensor(4)
2. 列表索引
# 列表索引
data1_13_12 = data1[[0,2],[1,3]] # 一维列表,列表长度相同 获得(0,1)-第一行第二列,(2,3)-第三行第四列的元素 tensor([9, 9])
data1_13_12_ = data1[[[0],[2]],[1,3]] # 二维列表,列表长度相同 第一行与第三行的(第二列,第四列)的元素
'''
tensor([[9, 0],
[6, 9]])
'''
3. 布尔索引
data1_da5 = data1>5
'''
tensor([[False, True, False, False, False],
[ True, True, False, True, False],
[False, True, True, True, True],
[ True, True, True, False, False]])
'''
# 获得张量中大于5的所有元素
data1_dayu5 = data1[data1>5] # tensor([9, 9, 7, 7, 6, 6, 9, 8, 6, 6, 8])
# 获得第二列元素大于6的行
data1_2dayu6 = data1[data1[:,1]>6,:]
'''
tensor([[4, 9, 3, 0, 3],
[9, 7, 3, 7, 3]])
'''
# 获得第二行元素大于7的所有列
data1_2dayu7 = data1[:,data1[1,:]>7]
'''
tensor([[4],
[9],
[1],
[6]])
'''
4. 多维索引
data2 = torch.randint(0,10,[2,3,4])
'''
tensor([[[6, 9, 1, 4],
[4, 1, 9, 9],
[9, 0, 1, 2]],
[[3, 0, 5, 5],
[2, 9, 1, 8],
[8, 3, 6, 9]]])
'''
# 按照第1个维度选择(组)
data2_10 = data2[0,:,:] # 第1个维度的第0个数据
'''
tensor([[6, 9, 1, 4],
[4, 1, 9, 9],
[9, 0, 1, 2]])
'''
# 按照第2个维度选择(行)
data2_21 = data2[:,1,:] # 第2个维度的第1个数据
'''
tensor([[4, 1, 9, 9],
[2, 9, 1, 8]])
'''
# 按照第3个维度选择(列)
data2_32 = data2[:,:,2] # 第3个维度的第2个数据
'''
tensor([[1, 9, 1],
[5, 1, 6]])
'''