pytorch 笔记:应用:根据内积结果评估模型预测准确性

1 问题描述

  • 任务是评估一个预测模型的准确性,该模型通过计算embedding的内积进行位置预测
  • 具体来说,有以下数据:

    • 张量 a (大小为 40x10),包含 40 个位置的embedding(每个位置的embedding维度为10)。

    • 张量 b (大小为 4x10),包含一个batch中4个预测位置的embedding。

    • 张量 c (大小为 4x1),包含四个ground-truth位置的索引。

  • 任务是对于 b 中的每个位置,计算它与 a 中所有40个位置的embedding的内积,然后选取内积值最高的10个位置。

  • 接下来,检查 c 中的ground-truth位置索引是否在这10个最高内积位置中

2 python 实现

2.1 获取a,b,c

import torch

a = torch.randn(40, 10)  # 40个位置的embedding


b = torch.randn(4, 10)   # 4个预测位置的embedding
b
'''
tensor([[-1.7334e-02, -4.0604e-01,  9.0610e-01,  1.7381e+00, -1.1258e+00,
         -4.9452e-01, -2.1200e+00, -7.3516e-01, -1.6682e+00,  1.7613e-01],
        [-2.0892e-01,  8.8255e-01, -4.0820e-02, -1.4790e+00, -1.5859e+00,
          5.1649e-01, -3.2593e-01,  6.4271e-01,  8.9277e-01, -6.6575e-01],
        [ 1.1684e-01, -7.3740e-01,  1.0661e+00, -1.0934e+00, -6.1928e-01,
          1.2838e+00,  5.3154e-01,  2.1426e+00,  3.3756e-02, -7.1108e-01],
        [-2.0004e+00,  1.6285e+00,  1.5834e-01, -7.2439e-01, -4.5901e-01,
          9.4934e-01, -1.3431e+00, -2.1714e+00,  1.3512e-01,  5.9950e-04]])
'''


c = torch.randint(0, 40, (4, 1))  # ground-truth位置的索引
c
'''
tensor([[12],
        [27],
        [ 6],
        [12]])
'''

2.2 计算b与a的所有位置的内积

similarity = torch.matmul(b, a.t())  
similarity

2.3  对每一行(每个预测位置)取最大的10个值的索引

top10_indices = torch.topk(similarity, 10, dim=1).indices  
top10_indices 
'''
tensor([[25, 12,  1, 13, 11, 10, 19,  8, 37,  3],
        [26, 35, 20,  5, 30, 10, 22, 16,  0, 38],
        [27, 19, 22, 30, 23, 14, 35, 26,  2,  9],
        [26, 38, 11, 10,  1, 32, 12,  5,  6, 21]])
'''

2.4 判断是否在top10中

is_in_top10 = torch.zeros(4, dtype=torch.bool)
for i in range(4):
    is_in_top10[i] = c[i] in top10_indices[i]
is_in_top10
#tensor([ True, False, False,  True])

  • 9
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值