在指定维度上取出指定索引数据。
import torch
a=torch.tensor([[0,1,2],
[3,4,5],
[6,7,8]
])
i=torch.tensor([0,0,1,2,2])
c=torch.index_select(a,0,i)
print(c)
a是三行三列的。
这里指定维度为行,也就是取
第0行
第0行
第1行
第2行
第2行
输出大小为5x3
输出:
tensor([[0, 1, 2],
[0, 1, 2],
[3, 4, 5],
[6, 7, 8],
[6, 7, 8]])
tenor类型转换为long
tensor = torch.randn(2, 2)
print(tensor.type())
# torch.long() 将tensor转换为long类型
long_tensor = tensor.long()
print(long_tensor.type())