1.作用
取一个tensor的topk元素
2.使用方法
dim=0表示按照列求topn,dim=1表示按照行求topn,None情况下,dim=1.
任务一:
取top1:
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
print(pred)
values, indices = pred.topk(1, dim=0, largest=True, sorted=True)
print(indices)
print(values)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=0, keepdim=True)
print(indices_max)
print(indices_max == indices)
输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
tensor([[1, 1, 1, 1, 1]])
tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]])
tensor([[1, 1, 1, 1, 1]])
tensor([[True, True, True, True, True]])
任务二:
按行取出topk,将小于topk的置为inf
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
print(pred)
top_k = 2
filter_value=-float('Inf')
indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None]
print(indices_to_remove)
pred[indices_to_remove] = filter_value # 对于topk之外的其他元素的logits值设为负无穷
print(pred)
输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.0145, 0.4053],
[ 0.7265, 1.4164, 1.3443, 1.2035, 1.8823],
[-0.4451, 0.1673, 1.2590, -2.0757, 1.7255],
[ 0.2021, 0.3041, 0.1383, 0.3849, -1.6311]])
tensor([[4],
[4],
[4],
[3]])
tensor([[0.4053],
[1.8823],
[1.7255],
[0.3849]])
tensor([[ True, False, True, True, False],
[ True, False, True, True, False],
[ True, True, False, True, False],
[ True, False, True, False, True]])
tensor([[ -inf, -0.3873, -inf, -inf, 0.4053],
[ -inf, 1.4164, -inf, -inf, 1.8823],
[ -inf, -inf, 1.2590, -inf, 1.7255],
[ -inf, 0.3041, -inf, 0.3849, -inf]])
参考:
https://blog.csdn.net/qq_34914551/article/details/103738160
https://blog.csdn.net/u014264373/article/details/86525621
https://blog.csdn.net/weixin_45062709/article/details/102711885