目录
按索引筛选数据,选取第0行和第2行
import torch
data = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
aaa=[0,2]
print(data[aaa])
结果:
tensor([[1, 2, 3],
[7, 8, 9]])
index_select 按行索引筛选:
import torch
data = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
indices = torch.tensor([0, 2]) # 在轴上筛选坐标
result=torch.index_select(data,dim= 0, index=indices) # 指定筛选对象、轴、筛选坐标
print(result)
结果:
tensor([[1, 2, 3],
[7, 8, 9]])