本文按字典序排序,方便索引。
torch.cat()
torch.item():返回单元素tensor的元素值
x = torch.tensor([[1]])
print(x.item()) # 1
x = torch.tensor(2.5)
print(x.item()) # 2.5
torch.topk():返回tensor中前k大元素及下标
1.函数形式:
torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None) -> (Tensor, LongTensor)
返回给定tensor的前k大元素和其下标
2.参数:
3.例子:
>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1., 2., 3., 4., 5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))