关于函数torch.topk用法的思考


开始介绍之前先来点哲理性的思考,为什么函数 torch.topk,他的名字会叫 topk呢?

个人认为名称是来源于 “top k”,在这种情况下,它表示 “前 k 个最大值”。

假设我们有一个形状为 ( 2 , 3 , 4 ) (2, 3, 4) (2,3,4) 的三维张量 A A A,如下所示:

A = torch.tensor([[[ 1,  3,  5,  7],
                   [ 2,  4,  6,  8],
                   [ 9, 11, 13, 15]],
                  [[16, 18, 20, 22],
                   [17, 19, 21, 23],
                   [10, 12, 14, 24]]])

1. 沿着dim=0

沿着 dim=0(即在子矩阵之间进行比较):

k = 1
topk_values, topk_indices = torch.topk(A, k=k, dim=0)
A = torch.tensor([[[ 1,  3,  5,  7],
                   [ 2,  4,  6,  8],
                   [ 9, 11, 13, 15]],
                  [[16, 18, 20, 22],
                   [17, 19, 21, 23],
                   [10, 12, 14, 24]]])

沿着 dim=0(即在子矩阵之间进行比较):

topk_values = tensor([[[16, 18, 20, 22],
                        [17, 19, 21, 23],
                        [10, 12, 14, 24]]])

topk_indices = tensor([[[1, 1, 1, 1],
                         [1, 1, 1, 1],
                         [1, 1, 1, 1]]])

那此时我们令k = 3会发生什么?很显然,我们并没有三个子矩阵,所以此时程序会报错。给大家看一下程序的错误:

Traceback (most recent call last):
  File "E:\Learning_Material\Junior_Second_Semester\Adademic_Research\BusterNet_pytorch-master\test_2023_4_7.py", line 14, in <module>
    topk_values, topk_indices = torch.topk(A, k=k, dim=0)
RuntimeError: selected index k out of range

2. 沿着dim=1

沿着 dim=1(即在行之间进行比较):

k = 2
topk_values, topk_indices = torch.topk(A, k=k, dim=1)
topk_values = tensor([[[ 9, 11, 13, 15],
         			   [ 2,  4,  6,  8]],

        			   [[17, 19, 21, 24],
         				[16, 18, 20, 23]]])

topk_indices = tensor([[[2, 2, 2, 2],
         				[1, 1, 1, 1]],

        			   [[1, 1, 1, 2],
         				[0, 0, 0, 1]]])

在沿着行进行比较的情况下,比较是不会在子矩阵,也就是更高维度上发生的。仅仅在子矩阵的维度上比较一个子矩阵中最大的行数。

3. 沿着dim=2

沿着 dim=2(即在列之间进行比较):

k = 2
topk_values, topk_indices = torch.topk(A, k=k, dim=2)
topk_values = tensor([[[ 7,  5],
         			   [ 8,  6],
         			   [15, 13]],

        			  [[22, 20],
         		   	   [23, 21],
         			   [24, 14]]])

topk_indices = tensor([[[3, 2],
         				[3, 2],
         				[3, 2]],

        			   [[3, 2],
         				[3, 2],
         				[3, 2]]])

沿着dim=2比较,此时就不会牵扯到更高的两个维度,只会在最后一个维度之内进行排序比较

4. 总结

没有什么别的经验,希望对大家有用就好!

这段代码的作用是根据给定的top_k和top_p值,过滤掉logits分数较低的预测结果,从而生成更准确的预测结果。下面是每个变量的含义和每句代码的语法: 1. `logits`:一个张量,表示对应词汇表中单词的logits分数。 2. `top_k`:一个整数,表示要保留的最高可能性预测的数量。 3. `top_p`:一个浮点数,表示要保留的累积概率质量。 4. `filter_value`:一个浮点数,用于过滤掉不想要的预测。默认值为负无穷大。 5. `assert`:断言语句,用于判断logits张量的维度是否为1,如果维度不为1,程序将会报错并停止运行。 6. `logits.dim()`:张量的维度数。 7. `top_k = min(top_k, logits.size(-1))`:将top_k值与logits张量的最后一维大小进行比较,保证top_k值不会大于张量的维度。 8. `if top_k > 0:`:如果指定了top_k值,则进行以下操作。 9. `indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]`:返回logits张量中最后一维的最大值的top_k个元素,并将剩余元素的值设置为过滤值, 然后返回不需要的结果的索引。 10. `logits[indices_to_remove] = filter_value`:将logits张量中的索引为indices_to_remove的元素的值设置为过滤值。 11. `if top_p > 0.0:`:如果指定了top_p值,则进行以下操作。 12. `sorted_logits, sorted_indices = torch.sort(logits, descending=True)`:按照降序对logits张量进行排序,并返回排序后的结果和对应的索引。 13. `cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)`:计算softmax函数的累积概率值。 14. `sorted_indices_to_remove = cumulative_probs > top_p`:返回累积概率大于top_p的索引。 15. `sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()`:将索引向右移一位,保留第一个索引。 16. `sorted_indices_to_remove[..., 0] = 0`:将第一个索引设置为0。 17. `indices_to_remove = sorted_indices[sorted_indices_to_remove]`:返回不需要的结果的索引。 18. `logits[indices_to_remove] = filter_value`:将logits张量中的索引为indices_to_remove的元素的值设置为过滤值。 19. `return logits`:返回过滤后的logits张量。
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

No_one-_-2022

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值